In [4]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm_notebook

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler

import nltk
import spacy
import string
import evaluate
import transformers

from sklearn.model_selection import train_test_split
from transformers import T5ForConditionalGeneration, T5TokenizerFast

import loralib as lora

import warnings
warnings.filterwarnings("ignore")

In [5]:
class QA_Dataset(Dataset):
    '''
    Follow the question answering input format of UnifiedQA: https://arxiv.org/pdf/2005.00700.pdf
    '''
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.question = self.data['question']
        self.choices = self.data['choices']
        self.label = self.data['label']
        
    def __len__(self):
        return len(self.question)
    
    def __getitem__(self, idx):
        question = self.question[idx]
        choices = self.choices[idx]
        label = int(self.label[idx])
        answer = choices[label]
        
        # Append choices to question following style of UnifiedQA
        # question \n (A) c1 (B) c2 . . .       
        letters = ['(A)', '(B)', '(C)', '(D)', '(E)']
        question = question + ' \n'
        for i, c in enumerate(choices):
            question += f' {letters[i]} {c}'
        question_for_tok =  question
        answer_for_tok = answer
        question_tokenized = self.tokenizer(question_for_tok, max_length=self.q_len, padding="max_length",
                                            truncation=True, pad_to_max_length=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer_for_tok, max_length=self.t_len, padding="max_length", 
                                          truncation=True, pad_to_max_length=True, add_special_tokens=True)
    
        return {
            "input_ids": torch.tensor(question_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(question_tokenized["attention_mask"], dtype=torch.long),
            "decoder_input_ids": torch.tensor(answer_tokenized["input_ids"], dtype=torch.long),
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long),
            "question": question,
            "ref_answer": answer,
        }


class Trainer:
    def __init__(self, model, optimizer, tokenizer, train_loader, val_loader, use_lora=False, device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.use_lora = use_lora
        self.device = device
        self.bleu = evaluate.load("google_bleu")

    def train_epoch(self):
        self.model = self.model.to(self.device)
        self.model.train()
        if self.use_lora:
            lora.mark_only_lora_as_trainable(self.model)

        train_loss = 0
        train_batch_count = 0
        for batch in tqdm_notebook(self.train_loader, desc="Training batches"):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["decoder_input_ids"].to(self.device)
            decoder_attention_mask = batch["decoder_attention_mask"].to(self.device)
    
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 labels=labels,
                                 decoder_attention_mask=decoder_attention_mask)
    
            self.optimizer.zero_grad()
            outputs.loss.backward()
            self.optimizer.step()
            train_loss += outputs.loss.item()
            train_batch_count += 1

        return train_loss / train_batch_count

    def validate_epoch(self):
        self.model = self.model.to(self.device)
        self.model.eval()
        val_loss = 0
        val_batch_count = 0
        predicted_answers = []
        ref_answers = []
        correct_num = 0
        total_num = 0
        
        for batch in tqdm_notebook(self.val_loader, desc="Validation batches"):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["decoder_input_ids"].to(self.device)
            decoder_attention_mask = batch["decoder_attention_mask"].to(self.device)
            
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     labels=labels,
                                     decoder_attention_mask=decoder_attention_mask)
                val_loss += outputs.loss.item()
                val_batch_count += 1
    
            # Store val outputs and metrics
            bs = outputs.logits.shape[0]
            for b_idx in range(bs):
                logits = outputs.logits[b_idx]
                tokens = torch.argmax(logits, dim=1)
                end_tok_idx = (tokens == 1).nonzero()
    
                if end_tok_idx.size(0) > 0:
                    end_tok_idx = end_tok_idx[0].item()
                    if end_tok_idx+1 < tokens.size(0):
                        tokens[end_tok_idx+1:] = 0
                
                predicted_answer = self.tokenizer.decode(tokens, skip_special_tokens=True)
                ref_answer = batch['ref_answer'][b_idx]
    
                predicted_answers.append(predicted_answer)
                ref_answers.append(ref_answer)
                if ref_answer in predicted_answer:
                    correct_num += 1
                total_num += 1
        
        # Finish calculating val metrics
        val_acc = correct_num / total_num
        bleu_score = self.bleu.compute(predictions=predicted_answers, references=ref_answers)['google_bleu']

        return {'val_loss': val_loss/val_batch_count, 'val_acc': val_acc, 'bleu_score': bleu_score}


class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_acc = -float('inf')

    def early_stop(self, validation_acc):
        if validation_acc > self.min_validation_acc + self.min_delta:
            self.min_validation_acc = validation_acc
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

def apply_lora(model, num_blocks=12, model_d=768, lora_r=16):
    # Apply LoRA to all attention matrices in the transformer block: q,k,v,o
    for i in range(num_blocks):
        model.encoder.block[i].layer[0].SelfAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        
        model.decoder.block[i].layer[0].SelfAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        
        model.decoder.block[i].layer[1].EncDecAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)

In [32]:
model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)
use_lora = True
lora_r = 16
if use_lora:
    apply_lora(model, lora_r=lora_r)
    model.load_state_dict(torch.load('t5-base.pth'), strict=False)

model.load_state_dict(torch.load('results/released_weights/t5_base_lora_scienceqa.pth'), strict=False)

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
optimizer = optim.AdamW(model.parameters(), lr=3e-3, eps=1e-8, weight_decay=0.0) # For lora, use 3e-3, otherwise 1e-4 learning rate
q_len = 512   # Question Length
t_len = 64    # Target Length
train_batch_size = 16
val_batch_size = 8
device = 'cuda:0'
num_epochs = 20

In [35]:
# Loading the data
#val_file = 'datasets/CommonsenseQA/commonsenseqa_mcq_val.json'
#val_file = 'datasets/MedQA/medqa_mcq_val.json'
val_file = 'datasets/ScienceQA/scienceqa_mcq_val.json'
    
with open(val_file) as file:
    val_data = json.load(file)

# Create Dataframes
val_data = pd.DataFrame(val_data)

# Dataloader
val_dataset = QA_Dataset(tokenizer, val_data, q_len, t_len)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

trainer = Trainer(model, optimizer, tokenizer, None, val_loader, use_lora=use_lora, device=device)

In [36]:
trainer.validate_epoch()

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

{'val_loss': 0.010684275164776273,
 'val_acc': 0.7933768656716418,
 'bleu_score': 0.9690484724535622}

In [37]:
i = 12
data = val_dataset.__getitem__(i)
input_ids = data['input_ids'].to(device).unsqueeze(0)
attention_mask = data['attention_mask'].to(device).unsqueeze(0)

outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask)
predicted_answer = tokenizer.decode(outputs.flatten(), skip_special_tokens=True)

print(data['question'], '\n')
print('predicted_answer:', predicted_answer)
print('correct_answer:', data['ref_answer'])

Compare the motion of three ships. Which ship was moving at the lowest speed? 
 (A) a ship that moved 555kilometers west in 10hours (B) a ship that moved 95kilometers south in 10hours (C) a ship that moved 460kilometers south in 10hours 

predicted_answer: a ship that moved 95kilometers south in 10hours
correct_answer: a ship that moved 95kilometers south in 10hours
