In [75]:
import torch
import numpy as np
from transformers import BertTokenizer
from torch import nn
from transformers import BertModel
import pandas as pd
from torch.optim import Adam
from tqdm import tqdm

In [77]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [78]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):

        self.labels = [label for label in df['labels']]
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 32, truncation=True,
                                return_tensors="pt") for text in df['drugname']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

### Train / Test split

In [87]:
datapath_train = 'drug-train.csv'
df_train = pd.read_csv(datapath_train)

datapath_test = 'drug-test.csv'
df_val = pd.read_csv(datapath_test)

In [89]:
print(len(df_train),len(df_val))

343 136


### Build Classification Model


In [90]:
class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased')

        for param in self.bert.parameters():
            param.requires_grad = True

        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 76)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)

        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

### Train Function

In [91]:
def train(model, train_data, val_data, learning_rate, epochs):
    train, val = Dataset(train_data), Dataset(val_data)

    train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)

    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()

    for epoch_num in range(epochs):

            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label)
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0

            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label)
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
                | Train Accuracy: {total_acc_train / len(train_data): .3f} \
                | Val Loss: {total_loss_val / len(val_data): .3f} \
                | Val Accuracy: {total_acc_val / len(val_data): .3f}')

In [92]:
EPOCHS = 50
model = BertClassifier()
LR = 1e-6

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [93]:
train(model, df_train, df_val, LR, EPOCHS)

100%|██████████| 172/172 [00:14<00:00, 12.00it/s]


Epochs: 1 | Train Loss:  2.198                 | Train Accuracy:  0.026                 | Val Loss:  2.112                 | Val Accuracy:  0.066


100%|██████████| 172/172 [00:14<00:00, 11.57it/s]


Epochs: 2 | Train Loss:  2.119                 | Train Accuracy:  0.038                 | Val Loss:  2.055                 | Val Accuracy:  0.081


100%|██████████| 172/172 [00:14<00:00, 11.51it/s]


Epochs: 3 | Train Loss:  2.052                 | Train Accuracy:  0.122                 | Val Loss:  1.965                 | Val Accuracy:  0.221


100%|██████████| 172/172 [00:14<00:00, 11.79it/s]


Epochs: 4 | Train Loss:  1.973                 | Train Accuracy:  0.166                 | Val Loss:  1.863                 | Val Accuracy:  0.279


100%|██████████| 172/172 [00:14<00:00, 11.58it/s]


Epochs: 5 | Train Loss:  1.849                 | Train Accuracy:  0.274                 | Val Loss:  1.742                 | Val Accuracy:  0.412


100%|██████████| 172/172 [00:15<00:00, 11.38it/s]


Epochs: 6 | Train Loss:  1.762                 | Train Accuracy:  0.350                 | Val Loss:  1.655                 | Val Accuracy:  0.441


100%|██████████| 172/172 [00:14<00:00, 11.56it/s]


Epochs: 7 | Train Loss:  1.653                 | Train Accuracy:  0.458                 | Val Loss:  1.549                 | Val Accuracy:  0.537


100%|██████████| 172/172 [00:15<00:00, 11.37it/s]


Epochs: 8 | Train Loss:  1.600                 | Train Accuracy:  0.487                 | Val Loss:  1.477                 | Val Accuracy:  0.522


100%|██████████| 172/172 [00:15<00:00, 11.44it/s]


Epochs: 9 | Train Loss:  1.515                 | Train Accuracy:  0.504                 | Val Loss:  1.418                 | Val Accuracy:  0.610


100%|██████████| 172/172 [00:14<00:00, 11.69it/s]


Epochs: 10 | Train Loss:  1.462                 | Train Accuracy:  0.577                 | Val Loss:  1.373                 | Val Accuracy:  0.640


100%|██████████| 172/172 [00:14<00:00, 11.85it/s]


Epochs: 11 | Train Loss:  1.420                 | Train Accuracy:  0.621                 | Val Loss:  1.283                 | Val Accuracy:  0.699


100%|██████████| 172/172 [00:14<00:00, 11.88it/s]


Epochs: 12 | Train Loss:  1.348                 | Train Accuracy:  0.641                 | Val Loss:  1.222                 | Val Accuracy:  0.735


100%|██████████| 172/172 [00:14<00:00, 12.21it/s]


Epochs: 13 | Train Loss:  1.291                 | Train Accuracy:  0.685                 | Val Loss:  1.185                 | Val Accuracy:  0.721


100%|██████████| 172/172 [00:14<00:00, 11.77it/s]


Epochs: 14 | Train Loss:  1.244                 | Train Accuracy:  0.691                 | Val Loss:  1.135                 | Val Accuracy:  0.779


100%|██████████| 172/172 [00:14<00:00, 11.63it/s]


Epochs: 15 | Train Loss:  1.212                 | Train Accuracy:  0.711                 | Val Loss:  1.098                 | Val Accuracy:  0.787


100%|██████████| 172/172 [00:15<00:00, 11.21it/s]


Epochs: 16 | Train Loss:  1.184                 | Train Accuracy:  0.708                 | Val Loss:  1.079                 | Val Accuracy:  0.787


100%|██████████| 172/172 [00:15<00:00, 11.27it/s]


Epochs: 17 | Train Loss:  1.133                 | Train Accuracy:  0.749                 | Val Loss:  1.016                 | Val Accuracy:  0.816


100%|██████████| 172/172 [00:15<00:00, 11.41it/s]


Epochs: 18 | Train Loss:  1.124                 | Train Accuracy:  0.758                 | Val Loss:  0.995                 | Val Accuracy:  0.838


100%|██████████| 172/172 [00:14<00:00, 11.73it/s]


Epochs: 19 | Train Loss:  1.081                 | Train Accuracy:  0.778                 | Val Loss:  0.985                 | Val Accuracy:  0.838


100%|██████████| 172/172 [00:15<00:00, 11.29it/s]


Epochs: 20 | Train Loss:  1.037                 | Train Accuracy:  0.802                 | Val Loss:  0.938                 | Val Accuracy:  0.875


100%|██████████| 172/172 [00:14<00:00, 11.49it/s]


Epochs: 21 | Train Loss:  1.003                 | Train Accuracy:  0.799                 | Val Loss:  0.898                 | Val Accuracy:  0.882


100%|██████████| 172/172 [00:14<00:00, 11.54it/s]


Epochs: 22 | Train Loss:  0.971                 | Train Accuracy:  0.816                 | Val Loss:  0.860                 | Val Accuracy:  0.890


100%|██████████| 172/172 [00:15<00:00, 11.42it/s]


Epochs: 23 | Train Loss:  0.961                 | Train Accuracy:  0.828                 | Val Loss:  0.838                 | Val Accuracy:  0.904


100%|██████████| 172/172 [00:14<00:00, 11.49it/s]


Epochs: 24 | Train Loss:  0.932                 | Train Accuracy:  0.843                 | Val Loss:  0.810                 | Val Accuracy:  0.890


100%|██████████| 172/172 [00:14<00:00, 11.55it/s]


Epochs: 25 | Train Loss:  0.910                 | Train Accuracy:  0.828                 | Val Loss:  0.791                 | Val Accuracy:  0.919


100%|██████████| 172/172 [00:15<00:00, 11.35it/s]


Epochs: 26 | Train Loss:  0.887                 | Train Accuracy:  0.854                 | Val Loss:  0.758                 | Val Accuracy:  0.912


100%|██████████| 172/172 [00:14<00:00, 11.53it/s]


Epochs: 27 | Train Loss:  0.847                 | Train Accuracy:  0.863                 | Val Loss:  0.761                 | Val Accuracy:  0.919


100%|██████████| 172/172 [00:14<00:00, 11.87it/s]


Epochs: 28 | Train Loss:  0.853                 | Train Accuracy:  0.878                 | Val Loss:  0.726                 | Val Accuracy:  0.926


100%|██████████| 172/172 [00:14<00:00, 11.72it/s]


Epochs: 29 | Train Loss:  0.809                 | Train Accuracy:  0.883                 | Val Loss:  0.703                 | Val Accuracy:  0.912


100%|██████████| 172/172 [00:14<00:00, 11.49it/s]


Epochs: 30 | Train Loss:  0.807                 | Train Accuracy:  0.892                 | Val Loss:  0.677                 | Val Accuracy:  0.912


100%|██████████| 172/172 [00:14<00:00, 12.18it/s]


Epochs: 31 | Train Loss:  0.773                 | Train Accuracy:  0.892                 | Val Loss:  0.669                 | Val Accuracy:  0.919


100%|██████████| 172/172 [00:15<00:00, 11.32it/s]


Epochs: 32 | Train Loss:  0.768                 | Train Accuracy:  0.892                 | Val Loss:  0.678                 | Val Accuracy:  0.926


100%|██████████| 172/172 [00:15<00:00, 11.07it/s]


Epochs: 33 | Train Loss:  0.737                 | Train Accuracy:  0.898                 | Val Loss:  0.631                 | Val Accuracy:  0.926


100%|██████████| 172/172 [00:14<00:00, 12.11it/s]


Epochs: 34 | Train Loss:  0.724                 | Train Accuracy:  0.895                 | Val Loss:  0.619                 | Val Accuracy:  0.941


100%|██████████| 172/172 [00:14<00:00, 11.94it/s]


Epochs: 35 | Train Loss:  0.701                 | Train Accuracy:  0.901                 | Val Loss:  0.601                 | Val Accuracy:  0.949


100%|██████████| 172/172 [00:14<00:00, 11.64it/s]


Epochs: 36 | Train Loss:  0.681                 | Train Accuracy:  0.915                 | Val Loss:  0.608                 | Val Accuracy:  0.934


100%|██████████| 172/172 [00:14<00:00, 11.81it/s]


Epochs: 37 | Train Loss:  0.665                 | Train Accuracy:  0.915                 | Val Loss:  0.569                 | Val Accuracy:  0.949


100%|██████████| 172/172 [00:14<00:00, 11.89it/s]


Epochs: 38 | Train Loss:  0.666                 | Train Accuracy:  0.921                 | Val Loss:  0.548                 | Val Accuracy:  0.956


100%|██████████| 172/172 [00:15<00:00, 11.21it/s]


Epochs: 39 | Train Loss:  0.637                 | Train Accuracy:  0.927                 | Val Loss:  0.534                 | Val Accuracy:  0.971


100%|██████████| 172/172 [00:15<00:00, 11.46it/s]


Epochs: 40 | Train Loss:  0.615                 | Train Accuracy:  0.933                 | Val Loss:  0.536                 | Val Accuracy:  0.971


100%|██████████| 172/172 [00:14<00:00, 11.50it/s]


Epochs: 41 | Train Loss:  0.617                 | Train Accuracy:  0.936                 | Val Loss:  0.525                 | Val Accuracy:  0.963


100%|██████████| 172/172 [00:14<00:00, 11.55it/s]


Epochs: 42 | Train Loss:  0.590                 | Train Accuracy:  0.936                 | Val Loss:  0.513                 | Val Accuracy:  0.963


100%|██████████| 172/172 [00:15<00:00, 11.43it/s]


Epochs: 43 | Train Loss:  0.587                 | Train Accuracy:  0.939                 | Val Loss:  0.489                 | Val Accuracy:  0.971


100%|██████████| 172/172 [00:14<00:00, 11.60it/s]


Epochs: 44 | Train Loss:  0.567                 | Train Accuracy:  0.942                 | Val Loss:  0.493                 | Val Accuracy:  0.978


100%|██████████| 172/172 [00:14<00:00, 11.89it/s]


Epochs: 45 | Train Loss:  0.551                 | Train Accuracy:  0.948                 | Val Loss:  0.488                 | Val Accuracy:  0.978


100%|██████████| 172/172 [00:15<00:00, 10.97it/s]


Epochs: 46 | Train Loss:  0.544                 | Train Accuracy:  0.933                 | Val Loss:  0.481                 | Val Accuracy:  0.963


100%|██████████| 172/172 [00:13<00:00, 12.52it/s]


Epochs: 47 | Train Loss:  0.546                 | Train Accuracy:  0.956                 | Val Loss:  0.461                 | Val Accuracy:  0.956


100%|██████████| 172/172 [00:14<00:00, 12.18it/s]


Epochs: 48 | Train Loss:  0.528                 | Train Accuracy:  0.945                 | Val Loss:  0.467                 | Val Accuracy:  0.963


100%|██████████| 172/172 [00:14<00:00, 11.52it/s]


Epochs: 49 | Train Loss:  0.512                 | Train Accuracy:  0.956                 | Val Loss:  0.434                 | Val Accuracy:  0.971


100%|██████████| 172/172 [00:14<00:00, 11.67it/s]


Epochs: 50 | Train Loss:  0.493                 | Train Accuracy:  0.962                 | Val Loss:  0.437                 | Val Accuracy:  0.978


### Evaluate

In [94]:
def evaluate(model, test_data):
    test = Dataset(test_data)

    test_dataloader = torch.utils.data.DataLoader(test, batch_size=2)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda:

        model = model.cuda()

    total_acc_test = 0
    with torch.no_grad():

        for test_input, test_label in test_dataloader:

              test_label = test_label.to(device)
              mask = test_input['attention_mask'].to(device)
              input_id = test_input['input_ids'].squeeze(1).to(device)

              output = model(input_id, mask)

              acc = (output.argmax(dim=1) == test_label).sum().item()
              total_acc_test += acc
    
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')
    

In [95]:
evaluate(model, df_val)

Test Accuracy:  0.963
