In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import numpy as np
import pandas as pd
from torch import nn
from utils import getData
from torch.optim import Adam
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertModel

In [3]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        
        self.labels = data.label.values
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in data.text.values]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [4]:
class BertClassifier(nn.Module):
    def __init__(self, dropout, model_name):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 2)
        self.softmax = nn.Softmax()
        
    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        output =  self.softmax(linear_output)

        return output

In [5]:
from datasets import load_metric
f1_fun = load_metric("f1")

In [6]:
def train(model, train, val, learning_rate, epochs, batch_size):

    train_dataloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)

    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()

    for epoch_num in range(epochs):

            total_acc_train = 0
            total_f1_train = 0 
            total_loss_train = 0

            for train_input, train_label in tqdm(train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label)
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc
                
                f1=f1_fun.compute(predictions=output.argmax(dim=1), references=train_label)['f1']
                total_f1_train+=f1 
                
                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0
            total_f1_val = 0 

            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label)
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
                    
                    f1=f1_fun.compute(predictions=output.argmax(dim=1), references=val_label)['f1']
                    total_f1_val+=f1 
            
            print(
                f"""Epochs: {epoch_num + 1} | \
                Train Loss: {total_loss_train / len(train): .3f} | \
                Train Accuracy: {total_acc_train / len(train): .3f} | \
                Train F1: {total_f1_train / len(train): .3f} |\
                Val Loss: {total_loss_val / len(val): .3f} | \
                Val Accuracy: {total_acc_val / len(val): .3f} |
                Val F1: {total_f1_val / len(val): .3f}
                """)
                  

In [7]:
train_df, test_df = getData(sub_task="A", return_type="pandas")
train_dataset = Dataset(data=train_df, tokenizer= BertTokenizer.from_pretrained("UBC-NLP/MARBERT"))
test_dataset = Dataset(data=test_df, tokenizer= BertTokenizer.from_pretrained("UBC-NLP/MARBERT"))

In [8]:
EPOCHS = 10
torch.manual_seed(2903)
model = BertClassifier(dropout= 0.1, model_name="UBC-NLP/MARBERT")
LR = 2e-5

Some weights of the model checkpoint at UBC-NLP/MARBERT were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
train(model, train_dataset, test_dataset, LR, EPOCHS, batch_size=4)

  0%|          | 0/2222 [00:00<?, ?it/s]

Epochs: 1 |                 Train Loss:  0.155 |                 Train Accuracy:  0.683 |                 Train F1:  0.045 |                Val Loss:  0.154 |                 Val Accuracy:  0.651 |
                Val F1:  0.096
                


  0%|          | 0/2222 [00:00<?, ?it/s]