In [23]:
import os
import sys

import torch.utils.data

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from tqdm import tqdm
import torch
from pandas import read_csv
from torch.utils.data import Dataset
from torch import nn
from pathlib import Path
from transformers import BertConfig, BertTokenizer, BertModel

In [24]:
# k-mer size; it is baked into both directory names below.
K = 3

# All locations are resolved relative to the parent of the notebook's working directory.
BASE_DIR = Path('..')
PRETRAINED_MODEL_DIR = BASE_DIR.joinpath('models', 'dna-bert', f'{K}-new-12w-0')
# NOTE(review): DATASET_DIR is not referenced by any cell visible in this notebook.
DATASET_DIR = BASE_DIR.joinpath('data', 'processed', 'H.sapiens495', f'kmer_{K}')

In [25]:
# Tokenizer, configuration and encoder weights are all read from the same checkpoint directory.
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_DIR)
config = BertConfig.from_pretrained(PRETRAINED_MODEL_DIR)
# Loading a bare BertModel leaves the checkpoint's `cls.predictions.*` weights unused,
# which is what the warning printed below this cell reports.
pre_model = BertModel.from_pretrained(
    PRETRAINED_MODEL_DIR,
    config=config,
)

Some weights of the model checkpoint at ..\models\dna-bert\3-new-12w-0 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class KmerDataset(Dataset):
    """Map-style dataset backed by one CSV split that is tokenized lazily, one row per access.

    The CSV is read from
    ``data_dir / 'processed' / dataset / f'kmer_text_{k}' / f'{set_type}.csv'``
    and must provide an ``x1`` column (the text handed to the tokenizer) and a
    ``label`` column.

    NOTE(review): ``k`` is annotated as ``str`` but this notebook passes the
    integer ``K``; it is only interpolated into the directory name, so both work.
    """

    def __init__(self, data_dir: Path, dataset: str, set_type: str, k: str, tokenizer: BertTokenizer):
        csv_path = data_dir / 'processed' / dataset / f'kmer_text_{k}' / f'{set_type}.csv'
        frame = read_csv(csv_path)

        self.tokenizer = tokenizer
        self.sequences, self.labels = frame['x1'], frame['label']

    def __len__(self):
        """Number of rows in the loaded split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, token_type_ids, label)`` as ``torch.long`` tensors."""
        encoding = self.tokenizer.encode_plus(
            self.sequences.iloc[index],
            None,
            return_token_type_ids=True,
        )

        # No padding or truncation is requested here, so tensor lengths follow the tokenizer output.
        features = [
            torch.tensor(encoding[field], dtype=torch.long)
            for field in ('input_ids', 'attention_mask', 'token_type_ids')
        ]
        target = torch.tensor(self.labels.iloc[index], dtype=torch.long)

        return (*features, target)


In [26]:
# Build the validation and training splits (in that order) with identical settings;
# only the split name differs between the two datasets.
valid_data, train_data = (
    KmerDataset(Path('..') / 'data', 'H.sapiens100', split_name, K, tokenizer)
    for split_name in ('valid', 'train')
)

In [27]:
# Batch iterators over the two splits; both shuffle and load in the main process.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
valid_data_loader = torch.utils.data.DataLoader(
    dataset=valid_data,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

In [30]:
class PsiSequenceClassifier(nn.Module):
    """Sequence classifier: a pretrained encoder followed by dropout and a linear head.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, num_labels)``.
    The training code feeds the output to ``CrossEntropyLoss``, which applies
    log-softmax itself; squashing the logits through a sigmoid first (as this
    class used to do) confines them to (0, 1) and caps how confident the loss
    can ever see the model be. Apply ``torch.softmax(logits, dim=1)`` outside the
    model when probabilities are needed. The arg-max class is unaffected by this
    change because the sigmoid is monotonic.

    Args:
        pre_model: Encoder module called with ``input_ids``, ``attention_mask``
            and ``token_type_ids`` keywords; element ``0`` of its output is used
            as the per-token hidden states.
        num_labels: Number of output classes (default ``2``).
        dropout: Dropout probability applied before the linear head (default ``0.2``).
    """

    def __init__(self, pre_model, num_labels=2, dropout=.2):
        super().__init__()

        self.l1 = pre_model
        # Take the head's input width from the encoder config when it exposes one;
        # fall back to the previously hard-coded 768 otherwise.
        hidden_size = getattr(getattr(pre_model, 'config', None), 'hidden_size', 768)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode the batch and classify it from the hidden state at position 0.

        Returns:
            Tensor of raw logits, one row per input sequence.
        """
        encoder_output = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_states = encoder_output[0]
        # Position 0 is the [CLS] slot when the tokenizer adds special tokens.
        first_token = hidden_states[:, 0]

        return self.classifier(self.dropout(first_token))

In [31]:
    # Training setup shared by train() and valid() below.
    # NOTE(review): `device` is assigned but nothing visible in this notebook moves
    # the model or the batches onto it.
    device = 'cpu'
    model = PsiSequenceClassifier(pre_model)
    # 1e-3 is the same value the original spelled as 10e-4.
    # NOTE(review): confirm this learning rate is intended for fine-tuning the encoder.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_function = nn.CrossEntropyLoss()

    def calculate_accuracy(y_pred, y_true):
        """Count the positions where ``y_pred`` equals ``y_true``.

        Despite the name, the result is a plain Python count of matches, not a
        ratio; callers divide by the number of examples themselves.
        """
        matches = torch.eq(y_pred, y_true)
        return int(matches.sum())


    def train(epoch, training_loader):
        """Run one optimisation pass over ``training_loader`` and print epoch metrics.

        Operates on the notebook-level ``model``, ``loss_function`` and ``optimizer``.

        Args:
            epoch: Epoch number; used only in the printed summary.
            training_loader: Iterable yielding ``(ids, mask, token_type_ids, labels)``
                batches.

        Returns:
            ``(epoch_loss, epoch_accu)``: the mean loss per batch and the accuracy
            in percent, or ``(0.0, 0.0)`` when the loader yields no batches.
        """
        tr_loss = 0
        n_correct = 0
        nb_tr_steps = 0
        nb_tr_examples = 0
        model.train()
        # Wrap the loader itself rather than enumerate(...): tqdm can then read the
        # loader's length and show a real progress bar instead of a bare counter.
        for ids, mask, token_type_ids, labels in tqdm(training_loader):
            outputs = model(ids, mask, token_type_ids)
            loss = loss_function(outputs, labels)
            tr_loss += loss.item()
            # detach() replaces the legacy `.data` access; predictions need no gradient.
            _, big_idx = torch.max(outputs.detach(), dim=1)
            n_correct += calculate_accuracy(big_idx, labels)

            nb_tr_steps += 1
            nb_tr_examples += labels.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if nb_tr_steps == 0 or nb_tr_examples == 0:
            # Nothing was processed; avoid dividing by zero below.
            print(f'Epoch {epoch}: training loader produced no batches.')
            return 0.0, 0.0

        epoch_loss = tr_loss/nb_tr_steps
        epoch_accu = (n_correct*100)/nb_tr_examples
        print(f'The Total Accuracy for Epoch {epoch}: {epoch_accu}')
        print(f"Training Loss Epoch: {epoch_loss}")
        print(f"Training Accuracy Epoch: {epoch_accu}")

        return epoch_loss, epoch_accu


    # Number of full passes over the training loader.
    EPOCHS = 10
    for epoch in range(0, EPOCHS):
        train(epoch=epoch, training_loader=train_data_loader)

    def valid(model, testing_loader, log_every=5000):
        """Evaluate ``model`` on ``testing_loader`` without gradient tracking.

        Uses the notebook-level ``loss_function``. The model is switched to eval
        mode and left in that mode.

        Args:
            model: Module called as ``model(ids, mask, token_type_ids)``.
            testing_loader: Iterable yielding ``(ids, mask, token_type_ids, labels)``
                batches.
            log_every: Print running loss/accuracy every this many steps (step 0
                included). ``0`` or ``None`` disables the intermediate logging.

        Returns:
            Accuracy over all evaluated examples, in percent; ``0.0`` when the
            loader yields no batches.
        """
        model.eval()
        n_correct = 0
        tr_loss = 0
        nb_tr_steps = 0
        nb_tr_examples = 0
        with torch.no_grad():
            # Wrap the loader itself so tqdm can read its length for the progress bar.
            for step, (ids, mask, token_type_ids, labels) in enumerate(tqdm(testing_loader)):
                outputs = model(ids, mask, token_type_ids)
                loss = loss_function(outputs, labels)
                tr_loss += loss.item()
                _, big_idx = torch.max(outputs, dim=1)
                n_correct += calculate_accuracy(big_idx, labels)

                nb_tr_steps += 1
                nb_tr_examples += labels.size(0)

                if log_every and step % log_every == 0:
                    loss_step = tr_loss/nb_tr_steps
                    accu_step = (n_correct*100)/nb_tr_examples
                    # The message reports the actual interval; it used to claim
                    # "100 steps" while the condition checked every 5000.
                    print(f"Validation Loss per {log_every} steps: {loss_step}")
                    print(f"Validation Accuracy per {log_every} steps: {accu_step}")

        if nb_tr_steps == 0 or nb_tr_examples == 0:
            # Nothing was evaluated; avoid dividing by zero below.
            print("Validation loader produced no batches.")
            return 0.0

        epoch_loss = tr_loss/nb_tr_steps
        epoch_accu = (n_correct*100)/nb_tr_examples
        print(f"Validation Loss Epoch: {epoch_loss}")
        print(f"Validation Accuracy Epoch: {epoch_accu}")

        return epoch_accu
    # Score the trained model on the validation split and report the percentage.
    acc = valid(model=model, testing_loader=valid_data_loader)
    print("Accuracy on validation data = %0.2f%%" % (acc,))

1it [00:06,  6.55s/it]


KeyboardInterrupt: 