In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm_notebook

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler

import nltk
import spacy
import string
import evaluate
import transformers

from sklearn.model_selection import train_test_split
from transformers import T5ForConditionalGeneration, T5TokenizerFast

import loralib as lora

import warnings
warnings.filterwarnings("ignore")

In [2]:
class QA_Dataset(Dataset):
    '''
    Follow the question answering input format of UnifiedQA: https://arxiv.org/pdf/2005.00700.pdf
    '''
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.question = self.data['question']
        self.choices = self.data['choices']
        self.label = self.data['label']
        
    def __len__(self):
        return len(self.question)
    
    def __getitem__(self, idx):
        question = self.question[idx]
        choices = self.choices[idx]
        label = int(self.label[idx])
        answer = choices[label]
        
        # Append choices to question following style of UnifiedQA
        # question \n (A) c1 (B) c2 . . .       
        letters = ['(A)', '(B)', '(C)', '(D)', '(E)']
        question = question + ' \n'
        for i, c in enumerate(choices):
            question += f' {letters[i]} {c}'
        question_for_tok =  question
        answer_for_tok = answer
        question_tokenized = self.tokenizer(question_for_tok, max_length=self.q_len, padding="max_length",
                                            truncation=True, pad_to_max_length=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer_for_tok, max_length=self.t_len, padding="max_length", 
                                          truncation=True, pad_to_max_length=True, add_special_tokens=True)
    
        return {
            "input_ids": torch.tensor(question_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(question_tokenized["attention_mask"], dtype=torch.long),
            "decoder_input_ids": torch.tensor(answer_tokenized["input_ids"], dtype=torch.long),
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long),
            "question": question,
            "ref_answer": answer,
        }


class Trainer:
    def __init__(self, model, optimizer, tokenizer, 
                 train_loader, 
                 val_loader1, 
                 val_loader2, 
                 val_loader3, 
                 device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.val_loader1 = val_loader1
        self.val_loader2 = val_loader2
        self.val_loader3 = val_loader3
        self.device = device
        self.bleu = evaluate.load("google_bleu")

    def train_epoch(self):
        self.model = self.model.to(self.device)
        self.model.train()
        lora.mark_only_lora_as_trainable(self.model)

        train_loss = 0
        train_batch_count = 0
        for batch in tqdm_notebook(self.train_loader, desc="Training batches"):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["decoder_input_ids"].to(self.device)
            decoder_attention_mask = batch["decoder_attention_mask"].to(self.device)
    
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 labels=labels,
                                 decoder_attention_mask=decoder_attention_mask)
    
            self.optimizer.zero_grad()
            outputs.loss.backward()
            self.optimizer.step()
            train_loss += outputs.loss.item()
            train_batch_count += 1

        return train_loss / train_batch_count

    def validate_epoch(self, dataset_num):
        self.model = self.model.to(self.device)
        self.model.eval()
        val_loss = 0
        val_batch_count = 0
        predicted_answers = []
        ref_answers = []
        correct_num = 0
        total_num = 0

        if dataset_num == 1:
            val_loader = self.val_loader1
        elif dataset_num == 2:
            val_loader = self.val_loader2
        elif dataset_num == 3:
            val_loader = self.val_loader3
        
        for batch in tqdm_notebook(val_loader, desc="Validation batches"):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["decoder_input_ids"].to(self.device)
            decoder_attention_mask = batch["decoder_attention_mask"].to(self.device)
            
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     labels=labels,
                                     decoder_attention_mask=decoder_attention_mask)
                val_loss += outputs.loss.item()
                val_batch_count += 1
    
            # Store val outputs and metrics
            bs = outputs.logits.shape[0]
            for b_idx in range(bs):
                logits = outputs.logits[b_idx]
                tokens = torch.argmax(logits, dim=1)
                end_tok_idx = (tokens == 1).nonzero()
    
                if end_tok_idx.size(0) > 0:
                    end_tok_idx = end_tok_idx[0].item()
                    if end_tok_idx+1 < tokens.size(0):
                        tokens[end_tok_idx+1:] = 0
                
                predicted_answer = self.tokenizer.decode(tokens, skip_special_tokens=True)
                ref_answer = batch['ref_answer'][b_idx]
    
                predicted_answers.append(predicted_answer)
                ref_answers.append(ref_answer)
                if ref_answer in predicted_answer:
                    correct_num += 1
                total_num += 1
        
        # Finish calculating val metrics
        val_acc = correct_num / total_num
        bleu_score = self.bleu.compute(predictions=predicted_answers, references=ref_answers)['google_bleu']

        return {'val_loss': val_loss/val_batch_count, 'val_acc': val_acc, 'bleu_score': bleu_score}


class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_acc = -float('inf')

    def early_stop(self, validation_acc):
        if validation_acc > self.min_validation_acc + self.min_delta:
            self.min_validation_acc = validation_acc
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

def apply_lora(model, num_blocks=12, model_d=768, lora_r=16):
    # Apply LoRA to all attention matrices in the transformer block: q,k,v,o
    for i in range(num_blocks):
        model.encoder.block[i].layer[0].SelfAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.encoder.block[i].layer[0].SelfAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        
        model.decoder.block[i].layer[0].SelfAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[0].SelfAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        
        model.decoder.block[i].layer[1].EncDecAttention.q = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.k = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.v = lora.Linear(model_d, model_d, r=lora_r, bias=False)
        model.decoder.block[i].layer[1].EncDecAttention.o = lora.Linear(model_d, model_d, r=lora_r, bias=False)

In [3]:
model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)
use_lora = True
lora_r = 16
if use_lora:
    apply_lora(model, lora_r=lora_r)
    model.load_state_dict(torch.load('t5-base.pth'), strict=False)

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
optimizer = optim.AdamW(model.parameters(), lr=3e-3, eps=1e-8, weight_decay=0.0) # For lora, use 3e-3, otherwise 1e-4 learning rate
q_len = 512   # Question Length
t_len = 64    # Target Length
train_batch_size = 16
val_batch_size = 8
device = 'cuda:0'
num_epochs = 100

In [4]:
# Loading the data
commonsense_train_file = 'datasets/CommonsenseQA/commonsenseqa_mcq_train.json'
commonsense_val_file = 'datasets/CommonsenseQA/commonsenseqa_mcq_val.json'
medical_train_file = 'datasets/MedQA/medqa_mcq_train.json'
medical_val_file = 'datasets/MedQA/medqa_mcq_val.json'
science_train_file = 'datasets/ScienceQA/scienceqa_mcq_train.json'
science_val_file = 'datasets/ScienceQA/scienceqa_mcq_val.json'

with open(commonsense_train_file) as file:
    commonsense_train_data = json.load(file)
    
with open(commonsense_val_file) as file:
    commonsense_val_data = json.load(file)

with open(medical_train_file) as file:
    medical_train_data = json.load(file)
    
with open(medical_val_file) as file:
    medical_val_data = json.load(file)

with open(science_train_file) as file:
    science_train_data = json.load(file)
    
with open(science_val_file) as file:
    science_val_data = json.load(file)

# Create Dataframes
commonsense_train_data = pd.DataFrame(commonsense_train_data)[:256]
commonsense_val_data = pd.DataFrame(commonsense_val_data)
medical_train_data = pd.DataFrame(medical_train_data)[:256]
medical_val_data = pd.DataFrame(medical_val_data)
science_train_data = pd.DataFrame(science_train_data)[:256]
science_val_data = pd.DataFrame(science_val_data)

# Dataset
commonsense_train_dataset = QA_Dataset(tokenizer, commonsense_train_data, q_len, t_len)
commonsense_val_dataset = QA_Dataset(tokenizer, commonsense_val_data, q_len, t_len)
medical_train_dataset = QA_Dataset(tokenizer, medical_train_data, q_len, t_len)
medical_val_dataset = QA_Dataset(tokenizer, medical_val_data, q_len, t_len)
science_train_dataset = QA_Dataset(tokenizer, science_train_data, q_len, t_len)
science_val_dataset = QA_Dataset(tokenizer, science_val_data, q_len, t_len)
combined_train_dataset = torch.utils.data.ConcatDataset([commonsense_train_dataset, medical_train_dataset, science_train_dataset])

# Dataloader
commonsense_train_loader = DataLoader(commonsense_train_dataset, batch_size=train_batch_size, shuffle=True)
commonsense_val_loader = DataLoader(commonsense_val_dataset, batch_size=val_batch_size, shuffle=False)

medical_train_loader = DataLoader(medical_train_dataset, batch_size=train_batch_size, shuffle=True)
medical_val_loader = DataLoader(medical_val_dataset, batch_size=val_batch_size, shuffle=False)

science_train_loader = DataLoader(science_train_dataset, batch_size=train_batch_size, shuffle=True)
science_val_loader = DataLoader(science_val_dataset, batch_size=val_batch_size, shuffle=False)

combined_train_loader = DataLoader(combined_train_dataset, batch_size=train_batch_size, shuffle=True)

# Trainer
trainer = Trainer(model, optimizer, tokenizer, 
                  combined_train_loader, 
                  commonsense_val_loader,
                  medical_val_loader,
                  science_val_loader,
                  device=device)

In [5]:
# Training loop
loss_log = []
val_metrics_log = []
best_val_acc = -1.0
best_model_path = ''
early_stopping = EarlyStopper(patience=3, min_delta=1e-3)

# Initial validation
val_metrics1 = trainer.validate_epoch(1)
val_metrics2 = trainer.validate_epoch(2)
val_metrics3 = trainer.validate_epoch(3)

for val_metrics in [val_metrics1, val_metrics2, val_metrics3]:
    val_loss = val_metrics['val_loss']
    val_acc = val_metrics['val_acc']
    val_bleu = val_metrics['bleu_score']
    print(f"{0}/{num_epochs} -> Validation Acc: {val_acc:.3f} \tValidation Bleu: {val_bleu:.3f}")

for epoch in range(num_epochs):
    train_loss = trainer.train_epoch()

    val_metrics1 = trainer.validate_epoch(1)
    val_metrics2 = trainer.validate_epoch(2)
    val_metrics3 = trainer.validate_epoch(3)

    avg_val_loss = 0.0
    avg_val_acc = 0.0
    avg_val_bleu = 0.0
    for val_metrics in [val_metrics1, val_metrics2, val_metrics3]:
        val_loss = val_metrics['val_loss']
        val_acc = val_metrics['val_acc']
        val_bleu = val_metrics['bleu_score']
        avg_val_loss += val_loss
        avg_val_acc += val_acc
        avg_val_bleu += val_bleu
        print(f"{epoch+1}/{num_epochs} -> Validation Acc: {val_acc:.3f} \tValidation Bleu: {val_bleu:.3f}")

    avg_val_loss /= 3
    avg_val_acc /= 3
    avg_val_bleu /= 3
    
    loss_log.append((train_loss, avg_val_loss))
    val_metrics_log.append((avg_val_acc, avg_val_bleu))
    
    print('Saving model...')   
    checkpoint_path = f'results/combined_dataset/t5_base_lora_epoch{epoch+1}.pth'
    torch.save(lora.cgm_state_dict(model), checkpoint_path)

    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        best_model_path = checkpoint_path
    
    print(f"{epoch+1}/{num_epochs} -> Train loss: {train_loss} \tValidation loss: {avg_val_loss} " + \
          f"\tValidation Acc: {avg_val_acc:.3f} \tValidation Bleu: {avg_val_bleu:.3f}")

    if early_stopping.early_stop(avg_val_acc):
        break

torch.save(torch.load(best_model_path), 'results/combined_dataset/t5_base_lora_best.pth')

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

0/100 -> Validation Acc: 0.054 	Validation Bleu: 0.003
0/100 -> Validation Acc: 0.004 	Validation Bleu: 0.082
0/100 -> Validation Acc: 0.093 	Validation Bleu: 0.111


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

1/100 -> Validation Acc: 0.037 	Validation Bleu: 0.246
1/100 -> Validation Acc: 0.024 	Validation Bleu: 0.455
1/100 -> Validation Acc: 0.069 	Validation Bleu: 0.678
Saving model...
1/100 -> Train loss: 2.873419631893436 	Validation loss: 0.22735098754626767 	Validation Acc: 0.043 	Validation Bleu: 0.460


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

2/100 -> Validation Acc: 0.273 	Validation Bleu: 0.437
2/100 -> Validation Acc: 0.190 	Validation Bleu: 0.599
2/100 -> Validation Acc: 0.416 	Validation Bleu: 0.733
Saving model...
2/100 -> Train loss: 0.14651714987121522 	Validation loss: 0.04526600818272355 	Validation Acc: 0.293 	Validation Bleu: 0.589


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

3/100 -> Validation Acc: 0.272 	Validation Bleu: 0.440
3/100 -> Validation Acc: 0.199 	Validation Bleu: 0.662
3/100 -> Validation Acc: 0.472 	Validation Bleu: 0.818
Saving model...
3/100 -> Train loss: 0.06316140720931192 	Validation loss: 0.03918512609883584 	Validation Acc: 0.314 	Validation Bleu: 0.640


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

4/100 -> Validation Acc: 0.265 	Validation Bleu: 0.438
4/100 -> Validation Acc: 0.197 	Validation Bleu: 0.665
4/100 -> Validation Acc: 0.504 	Validation Bleu: 0.858
Saving model...
4/100 -> Train loss: 0.048982751982597016 	Validation loss: 0.037080078579854285 	Validation Acc: 0.322 	Validation Bleu: 0.654


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

5/100 -> Validation Acc: 0.270 	Validation Bleu: 0.448
5/100 -> Validation Acc: 0.194 	Validation Bleu: 0.670
5/100 -> Validation Acc: 0.509 	Validation Bleu: 0.869
Saving model...
5/100 -> Train loss: 0.04578633338678628 	Validation loss: 0.03624834646861452 	Validation Acc: 0.325 	Validation Bleu: 0.662


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

6/100 -> Validation Acc: 0.275 	Validation Bleu: 0.447
6/100 -> Validation Acc: 0.197 	Validation Bleu: 0.678
6/100 -> Validation Acc: 0.510 	Validation Bleu: 0.865
Saving model...
6/100 -> Train loss: 0.041164281273571156 	Validation loss: 0.03632412072179692 	Validation Acc: 0.328 	Validation Bleu: 0.663


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

7/100 -> Validation Acc: 0.268 	Validation Bleu: 0.439
7/100 -> Validation Acc: 0.204 	Validation Bleu: 0.683
7/100 -> Validation Acc: 0.511 	Validation Bleu: 0.882
Saving model...
7/100 -> Train loss: 0.03886799289224049 	Validation loss: 0.035324038505530925 	Validation Acc: 0.327 	Validation Bleu: 0.668


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

8/100 -> Validation Acc: 0.276 	Validation Bleu: 0.443
8/100 -> Validation Acc: 0.208 	Validation Bleu: 0.681
8/100 -> Validation Acc: 0.519 	Validation Bleu: 0.884
Saving model...
8/100 -> Train loss: 0.03552251122891903 	Validation loss: 0.03584271560789006 	Validation Acc: 0.334 	Validation Bleu: 0.669


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

9/100 -> Validation Acc: 0.269 	Validation Bleu: 0.439
9/100 -> Validation Acc: 0.208 	Validation Bleu: 0.685
9/100 -> Validation Acc: 0.529 	Validation Bleu: 0.885
Saving model...
9/100 -> Train loss: 0.03216498193796724 	Validation loss: 0.03679481343336426 	Validation Acc: 0.335 	Validation Bleu: 0.669


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

10/100 -> Validation Acc: 0.276 	Validation Bleu: 0.442
10/100 -> Validation Acc: 0.204 	Validation Bleu: 0.691
10/100 -> Validation Acc: 0.539 	Validation Bleu: 0.891
Saving model...
10/100 -> Train loss: 0.030888583976775408 	Validation loss: 0.0353155946599893 	Validation Acc: 0.339 	Validation Bleu: 0.675


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

11/100 -> Validation Acc: 0.273 	Validation Bleu: 0.442
11/100 -> Validation Acc: 0.198 	Validation Bleu: 0.693
11/100 -> Validation Acc: 0.540 	Validation Bleu: 0.895
Saving model...
11/100 -> Train loss: 0.031221371958963573 	Validation loss: 0.03513013039270322 	Validation Acc: 0.337 	Validation Bleu: 0.677


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

12/100 -> Validation Acc: 0.291 	Validation Bleu: 0.457
12/100 -> Validation Acc: 0.208 	Validation Bleu: 0.690
12/100 -> Validation Acc: 0.528 	Validation Bleu: 0.896
Saving model...
12/100 -> Train loss: 0.027723512495867908 	Validation loss: 0.036249504208571905 	Validation Acc: 0.342 	Validation Bleu: 0.681


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

13/100 -> Validation Acc: 0.293 	Validation Bleu: 0.456
13/100 -> Validation Acc: 0.210 	Validation Bleu: 0.704
13/100 -> Validation Acc: 0.538 	Validation Bleu: 0.897
Saving model...
13/100 -> Train loss: 0.027334815240465105 	Validation loss: 0.0365947201083035 	Validation Acc: 0.347 	Validation Bleu: 0.685


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

14/100 -> Validation Acc: 0.287 	Validation Bleu: 0.457
14/100 -> Validation Acc: 0.215 	Validation Bleu: 0.697
14/100 -> Validation Acc: 0.541 	Validation Bleu: 0.893
Saving model...
14/100 -> Train loss: 0.025598849053494632 	Validation loss: 0.03707989630847227 	Validation Acc: 0.348 	Validation Bleu: 0.682


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

15/100 -> Validation Acc: 0.295 	Validation Bleu: 0.458
15/100 -> Validation Acc: 0.217 	Validation Bleu: 0.694
15/100 -> Validation Acc: 0.542 	Validation Bleu: 0.890
Saving model...
15/100 -> Train loss: 0.024422992932765435 	Validation loss: 0.03822092952176806 	Validation Acc: 0.351 	Validation Bleu: 0.681


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

16/100 -> Validation Acc: 0.302 	Validation Bleu: 0.464
16/100 -> Validation Acc: 0.215 	Validation Bleu: 0.699
16/100 -> Validation Acc: 0.550 	Validation Bleu: 0.901
Saving model...
16/100 -> Train loss: 0.02281761683601265 	Validation loss: 0.04095962001919321 	Validation Acc: 0.356 	Validation Bleu: 0.688


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

17/100 -> Validation Acc: 0.281 	Validation Bleu: 0.450
17/100 -> Validation Acc: 0.223 	Validation Bleu: 0.713
17/100 -> Validation Acc: 0.538 	Validation Bleu: 0.904
Saving model...
17/100 -> Train loss: 0.021152956876903772 	Validation loss: 0.04046632647784868 	Validation Acc: 0.347 	Validation Bleu: 0.689


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

18/100 -> Validation Acc: 0.288 	Validation Bleu: 0.458
18/100 -> Validation Acc: 0.241 	Validation Bleu: 0.725
18/100 -> Validation Acc: 0.528 	Validation Bleu: 0.900
Saving model...
18/100 -> Train loss: 0.019902685986987006 	Validation loss: 0.0407011130415709 	Validation Acc: 0.352 	Validation Bleu: 0.694


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

19/100 -> Validation Acc: 0.301 	Validation Bleu: 0.462
19/100 -> Validation Acc: 0.228 	Validation Bleu: 0.716
19/100 -> Validation Acc: 0.553 	Validation Bleu: 0.908
Saving model...
19/100 -> Train loss: 0.019368583549900602 	Validation loss: 0.044071596497436145 	Validation Acc: 0.360 	Validation Bleu: 0.695


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

20/100 -> Validation Acc: 0.292 	Validation Bleu: 0.459
20/100 -> Validation Acc: 0.228 	Validation Bleu: 0.699
20/100 -> Validation Acc: 0.536 	Validation Bleu: 0.897
Saving model...
20/100 -> Train loss: 0.018121032684575766 	Validation loss: 0.042442650503216245 	Validation Acc: 0.352 	Validation Bleu: 0.685


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

21/100 -> Validation Acc: 0.296 	Validation Bleu: 0.454
21/100 -> Validation Acc: 0.216 	Validation Bleu: 0.695
21/100 -> Validation Acc: 0.552 	Validation Bleu: 0.908
Saving model...
21/100 -> Train loss: 0.01649641814098383 	Validation loss: 0.04859397884015479 	Validation Acc: 0.355 	Validation Bleu: 0.686


Training batches:   0%|          | 0/48 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/159 [00:00<?, ?it/s]

Validation batches:   0%|          | 0/268 [00:00<?, ?it/s]

22/100 -> Validation Acc: 0.305 	Validation Bleu: 0.470
22/100 -> Validation Acc: 0.226 	Validation Bleu: 0.719
22/100 -> Validation Acc: 0.540 	Validation Bleu: 0.909
Saving model...
22/100 -> Train loss: 0.015007802847928057 	Validation loss: 0.04881291497357706 	Validation Acc: 0.357 	Validation Bleu: 0.699


In [16]:
# Qualitative examples
i = 1
data = val_dataset.__getitem__(i)
input_ids = data['input_ids'].to(device).unsqueeze(0)
attention_mask = data['attention_mask'].to(device).unsqueeze(0)

outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask)
predicted_answer = tokenizer.decode(outputs.flatten(), skip_special_tokens=True)

print(data['question'], '\n')
print('predicted_answer:', predicted_answer)
print('correct_answer:', data['ref_answer'])

What do people aim to do at work? 
 (A) complete job (B) learn from each other (C) kill animals (D) wear hats (E) talk to each other 

predicted_answer: complete job
correct_answer: complete job
