In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm_notebook

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler

import nltk
import spacy
import string
import evaluate
import transformers

from sklearn.model_selection import train_test_split
from transformers import T5ForConditionalGeneration, T5TokenizerFast

import loralib as lora

import warnings
warnings.filterwarnings("ignore")

In [2]:
class QA_Dataset(Dataset):
    '''
    Multiple-choice QA dataset that follows the question answering input format
    of UnifiedQA: https://arxiv.org/pdf/2005.00700.pdf

    Each item is built from one row of `dataframe`, which must provide the
    columns 'question' (str), 'choices' (sequence of str) and 'label' (index of
    the correct choice). The model input is "question \\n (A) c1 (B) c2 ..." and
    the target is the text of the correct choice.

    Args:
        tokenizer: callable used to tokenize the prompt and the answer; it must
            return a mapping with "input_ids" and "attention_mask".
        dataframe: table holding the columns described above.
        q_len: length the tokenized prompt is padded / truncated to.
        t_len: length the tokenized answer is padded / truncated to.
    '''
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.question = self.data['question']
        self.choices = self.data['choices']
        self.label = self.data['label']

    def __len__(self):
        return len(self.question)

    @staticmethod
    def _row_value(column, idx):
        '''Return the idx-th entry of `column` by position.'''
        # A pandas Series indexed with a plain int is looked up by *label*,
        # which fails (or returns the wrong row) as soon as the frame does not
        # carry a 0..n-1 index, e.g. after filtering or a train/test split.
        # The sampler hands out positions, so use positional access when the
        # column offers it and plain indexing otherwise (lists, arrays).
        positional = getattr(column, 'iloc', None)
        return column[idx] if positional is None else positional[idx]

    @staticmethod
    def _choice_tag(position):
        '''Return the option tag for a 0-based position: 0 -> "(A)", 1 -> "(B)", ...'''
        return f'({chr(ord("A") + position)})'

    def __getitem__(self, idx):
        question = self._row_value(self.question, idx)
        choices = self._row_value(self.choices, idx)
        label = int(self._row_value(self.label, idx))
        answer = choices[label]

        # Append choices to question following style of UnifiedQA
        # question \n (A) c1 (B) c2 . . .
        # Tags are generated on the fly so that rows with more than five
        # choices do not raise an IndexError.
        question = question + ' \n' + ''.join(
            f' {self._choice_tag(i)} {c}' for i, c in enumerate(choices)
        )

        # padding="max_length" already pads to max_length; the deprecated
        # pad_to_max_length flag is therefore not passed any more.
        question_tokenized = self.tokenizer(question, max_length=self.q_len, padding="max_length",
                                            truncation=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer, max_length=self.t_len, padding="max_length",
                                          truncation=True, add_special_tokens=True)

        return {
            "input_ids": torch.tensor(question_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(question_tokenized["attention_mask"], dtype=torch.long),
            "decoder_input_ids": torch.tensor(answer_tokenized["input_ids"], dtype=torch.long),
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long),
            # The formatted prompt (question plus choices), not the raw question
            "question": question,
            "ref_answer": answer,
        }


class Trainer:
    '''
    Trains one model on three datasets jointly and validates it on each of them.

    During training the three train loaders are iterated in lockstep; for every
    step one batch of each loader is processed and the optimizer is stepped
    once per batch. Validation runs on one validation loader at a time.

    Args:
        model: model called with input_ids / attention_mask / labels /
            decoder_attention_mask; its output must expose `.loss` and `.logits`.
        optimizer: optimizer stepped after every training batch.
        tokenizer: tokenizer used to decode predicted token ids.
        train_loaderN, val_loaderN: train / validation loaders of dataset N.
        device: device the model and the batches are moved to.
    '''

    # Token ids used when cleaning up predictions. They are fixed constants
    # that correspond to the T5 vocabulary (pad = 0, end-of-sequence = 1).
    _PAD_TOKEN_ID = 0
    _EOS_TOKEN_ID = 1

    def __init__(self, model, optimizer, tokenizer, 
                 train_loader1, val_loader1, 
                 train_loader2, val_loader2, 
                 train_loader3, val_loader3, 
                 device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.train_loader1 = train_loader1
        self.val_loader1 = val_loader1
        self.train_loader2 = train_loader2
        self.val_loader2 = val_loader2
        self.train_loader3 = train_loader3
        self.val_loader3 = val_loader3
        self.device = device
        self.bleu = evaluate.load("google_bleu")

        # train_epoch zips the loaders, which would silently drop the extra
        # batches of a longer loader, so all three must have the same length.
        assert len(train_loader1) == len(train_loader2)
        assert len(train_loader2) == len(train_loader3)

    def _forward(self, batch):
        '''Move the tensors of `batch` to the device and run the model on them.'''
        return self.model(
            input_ids=batch["input_ids"].to(self.device),
            attention_mask=batch["attention_mask"].to(self.device),
            # NOTE(review): padded target positions are passed as labels
            # unchanged, so they take part in the loss. Replacing them with
            # -100 would exclude them - confirm which behaviour is intended.
            labels=batch["decoder_input_ids"].to(self.device),
            decoder_attention_mask=batch["decoder_attention_mask"].to(self.device),
        )

    def train_epoch(self):
        '''Run one pass over the three train loaders and return the mean batch loss.'''
        self.model = self.model.to(self.device)
        self.model.train()

        train_loss = 0
        train_batch_count = 0
        for multi_data_batch in tqdm_notebook(zip(self.train_loader1, self.train_loader2, self.train_loader3), 
                                              desc="Training batches", total=len(self.train_loader1)):
            # One optimizer update per dataset batch
            for batch in multi_data_batch:
                outputs = self._forward(batch)

                self.optimizer.zero_grad()
                outputs.loss.backward()
                self.optimizer.step()
                train_loss += outputs.loss.item()
                train_batch_count += 1

        return train_loss / train_batch_count

    def validate_epoch(self, dataset_num):
        '''
        Evaluate the model on one validation loader.

        Args:
            dataset_num: 1, 2 or 3, selecting val_loader1 / 2 / 3.

        Returns:
            dict with 'val_loss' (mean batch loss), 'val_acc' (share of
            examples whose reference answer is contained in the prediction)
            and 'bleu_score' (google_bleu over all predictions).

        Raises:
            ValueError: if dataset_num does not select one of the loaders.
        '''
        # Validate the selector before touching the model so that a typo
        # fails immediately with a clear message instead of an unbound local.
        val_loaders = {1: self.val_loader1, 2: self.val_loader2, 3: self.val_loader3}
        if dataset_num not in val_loaders:
            raise ValueError(f"dataset_num must be 1, 2 or 3, got {dataset_num!r}")
        val_loader = val_loaders[dataset_num]

        self.model = self.model.to(self.device)
        self.model.eval()
        val_loss = 0
        val_batch_count = 0
        predicted_answers = []
        ref_answers = []
        correct_num = 0
        total_num = 0

        for batch in tqdm_notebook(val_loader, desc="Validation batches"):
            with torch.no_grad():
                outputs = self._forward(batch)
                val_loss += outputs.loss.item()
                val_batch_count += 1

            # Store val outputs and metrics
            # Greedy prediction for every position of every sequence at once
            batch_tokens = torch.argmax(outputs.logits, dim=-1)
            for b_idx in range(batch_tokens.shape[0]):
                tokens = batch_tokens[b_idx]
                end_tok_idx = (tokens == self._EOS_TOKEN_ID).nonzero()

                # Everything after the first end-of-sequence token is blanked
                # out with padding so that it is dropped when decoding.
                if end_tok_idx.size(0) > 0:
                    end_tok_idx = end_tok_idx[0].item()
                    if end_tok_idx+1 < tokens.size(0):
                        tokens[end_tok_idx+1:] = self._PAD_TOKEN_ID

                predicted_answer = self.tokenizer.decode(tokens, skip_special_tokens=True)
                ref_answer = batch['ref_answer'][b_idx]

                predicted_answers.append(predicted_answer)
                ref_answers.append(ref_answer)
                # Lenient match: the reference only has to occur in the prediction
                if ref_answer in predicted_answer:
                    correct_num += 1
                total_num += 1

        # Finish calculating val metrics
        val_acc = correct_num / total_num
        bleu_score = self.bleu.compute(predictions=predicted_answers, references=ref_answers)['google_bleu']

        return {'val_loss': val_loss/val_batch_count, 'val_acc': val_acc, 'bleu_score': bleu_score}


class EarlyStopper:
    '''
    Signals when validation accuracy has stopped improving.

    An accuracy counts as an improvement when it exceeds the best value seen so
    far by more than `min_delta`. After `patience` consecutive calls without an
    improvement, `early_stop` returns True.
    '''
    def __init__(self, patience=1, min_delta=0):
        self.counter = 0
        # Best accuracy seen so far (the attribute name is historical)
        self.min_validation_acc = -float('inf')
        self.min_delta = min_delta
        self.patience = patience

    def early_stop(self, validation_acc):
        '''Record `validation_acc` and return True when training should stop.'''
        threshold = self.min_validation_acc + self.min_delta
        if validation_acc > threshold:
            # New best: remember it and restart the patience window
            self.min_validation_acc = validation_acc
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

def apply_lora(model, num_blocks=12, model_d=768, lora_r=16):
    '''
    Replace the q, k, v and o projections of every attention module in the
    first `num_blocks` encoder and decoder blocks with new `lora.Linear` layers.

    Covers encoder self-attention, decoder self-attention and decoder
    cross-attention. The new layers are model_d x model_d, use rank `lora_r`
    and have no bias. The model is modified in place; nothing is returned.
    '''
    # (transformer stack, index of the sub-layer inside a block, attention attribute)
    attention_sites = (
        ('encoder', 0, 'SelfAttention'),
        ('decoder', 0, 'SelfAttention'),
        ('decoder', 1, 'EncDecAttention'),
    )

    for block_idx in range(num_blocks):
        for stack_name, sublayer_idx, attention_name in attention_sites:
            for projection in ('q', 'k', 'v', 'o'):
                replacement = lora.Linear(model_d, model_d, r=lora_r, bias=False)
                sublayer = getattr(model, stack_name).block[block_idx].layer[sublayer_idx]
                setattr(getattr(sublayer, attention_name), projection, replacement)

def apply_lora_cgm(model,
                   commonsense_lora_weights,
                   medical_lora_weights,
                   science_lora_weights,
                   num_blocks=12):
    '''
    Replace the q, k, v and o projections of every attention module in the
    first `num_blocks` encoder and decoder blocks with `lora.CGMLinear` layers.

    Covers encoder self-attention, decoder self-attention and decoder
    cross-attention. Each new layer receives the matching `lora_A` and `lora_B`
    entries of the three weight mappings, always ordered commonsense, medical,
    science, and is created with bias=False. The model is modified in place;
    nothing is returned.

    Args:
        model: model exposing `encoder.block[i]` and `decoder.block[i]`.
        commonsense_lora_weights, medical_lora_weights, science_lora_weights:
            mappings keyed like
            'encoder.block.{i}.layer.0.SelfAttention.q.lora_A'.
        num_blocks: number of blocks to process in each stack.

    Raises:
        KeyError: if one of the mappings lacks an expected entry.
    '''
    expert_weights = (commonsense_lora_weights, medical_lora_weights, science_lora_weights)

    # (transformer stack, index of the sub-layer inside a block, attention attribute).
    # The same three parts also form the prefix of the weight keys.
    attention_sites = (
        ('encoder', 0, 'SelfAttention'),
        ('decoder', 0, 'SelfAttention'),
        ('decoder', 1, 'EncDecAttention'),
    )

    # Apply LoRA to all attention matrices in the transformer block: q,k,v,o
    for i in range(num_blocks):
        for stack_name, sublayer_idx, attention_name in attention_sites:
            for projection in ('q', 'k', 'v', 'o'):
                key = f'{stack_name}.block.{i}.layer.{sublayer_idx}.{attention_name}.{projection}'
                lora_A_weights = [weights[f'{key}.lora_A'] for weights in expert_weights]
                lora_B_weights = [weights[f'{key}.lora_B'] for weights in expert_weights]
                mixed_layer = lora.CGMLinear(lora_A_weights, lora_B_weights, bias=False)

                sublayer = getattr(model, stack_name).block[i].layer[sublayer_idx]
                setattr(getattr(sublayer, attention_name), projection, mixed_layer)

In [3]:
# Run configuration
q_len = 512   # Question Length
t_len = 64    # Target Length
train_batch_size = 16
val_batch_size = 8
num_epochs = 100
seed = 42
# Fall back to the CPU when no GPU is visible instead of failing on the first .to(device)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed torch so that weight initialisation and data shuffling are repeatable
torch.manual_seed(seed)

# Load base model
model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)

# Load finetuned LoRA weights
# NOTE: torch.load unpickles the file content - only load checkpoints from a trusted source.
commonsense_lora_weights = torch.load('results/commonsense_qa/t5_base_lora_best.pth')
medical_lora_weights = torch.load('results/medical_qa/t5_base_lora_best.pth')
science_lora_weights = torch.load('results/science_qa/t5_base_lora_best.pth')

# Apply lora cgm and reload base model weights
# (the replaced projection layers are new modules, so the base weights are loaded again;
# strict=False tolerates keys that exist on only one side)
apply_lora_cgm(model, commonsense_lora_weights, medical_lora_weights, science_lora_weights)
model.load_state_dict(torch.load('t5-base.pth'), strict=False)

# Set only mixing module weights to require grad
for n, p in model.named_parameters():
    p.requires_grad = 'context_gated_mixing' in n

# Hand only the trainable parameters to the optimizer and fail loudly when the
# name filter above matched nothing, instead of "training" a fully frozen model.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError("No parameter name contains 'context_gated_mixing'; nothing would be trained.")

tokenizer = T5TokenizerFast.from_pretrained('t5-base')
optimizer = optim.AdamW(trainable_params, lr=3e-3, eps=1e-8, weight_decay=0.0) # For lora, use 3e-3, otherwise 1e-4 learning rate

In [4]:
# Loading the data
commonsense_train_file = 'datasets/CommonsenseQA/commonsenseqa_mcq_train.json'
commonsense_val_file = 'datasets/CommonsenseQA/commonsenseqa_mcq_val.json'
medical_train_file = 'datasets/MedQA/medqa_mcq_train.json'
medical_val_file = 'datasets/MedQA/medqa_mcq_val.json'
science_train_file = 'datasets/ScienceQA/scienceqa_mcq_train.json'
science_val_file = 'datasets/ScienceQA/scienceqa_mcq_val.json'

# Every training set is cut to the same number of rows, which also gives the
# three train loaders the same number of batches.
train_subset_size = 256


def load_qa_dataframe(path, max_rows=None):
    '''
    Read a JSON file of QA records into a DataFrame.

    Args:
        path: location of the JSON file.
        max_rows: if given, keep only the first `max_rows` rows.
    '''
    # Explicit encoding: the platform default is not UTF-8 everywhere and the
    # questions contain non-ASCII characters.
    with open(path, encoding='utf-8') as file:
        frame = pd.DataFrame(json.load(file))
    return frame if max_rows is None else frame[:max_rows]


# Create Dataframes
commonsense_train_data = load_qa_dataframe(commonsense_train_file, train_subset_size)
commonsense_val_data = load_qa_dataframe(commonsense_val_file)
medical_train_data = load_qa_dataframe(medical_train_file, train_subset_size)
medical_val_data = load_qa_dataframe(medical_val_file)
science_train_data = load_qa_dataframe(science_train_file, train_subset_size)
science_val_data = load_qa_dataframe(science_val_file)

# Dataset
commonsense_train_dataset = QA_Dataset(tokenizer, commonsense_train_data, q_len, t_len)
commonsense_val_dataset = QA_Dataset(tokenizer, commonsense_val_data, q_len, t_len)
medical_train_dataset = QA_Dataset(tokenizer, medical_train_data, q_len, t_len)
medical_val_dataset = QA_Dataset(tokenizer, medical_val_data, q_len, t_len)
science_train_dataset = QA_Dataset(tokenizer, science_train_data, q_len, t_len)
science_val_dataset = QA_Dataset(tokenizer, science_val_data, q_len, t_len)


# Dataloader
commonsense_train_loader = DataLoader(commonsense_train_dataset, batch_size=train_batch_size, shuffle=True)
commonsense_val_loader = DataLoader(commonsense_val_dataset, batch_size=val_batch_size, shuffle=False)

medical_train_loader = DataLoader(medical_train_dataset, batch_size=train_batch_size, shuffle=True)
medical_val_loader = DataLoader(medical_val_dataset, batch_size=val_batch_size, shuffle=False)

science_train_loader = DataLoader(science_train_dataset, batch_size=train_batch_size, shuffle=True)
science_val_loader = DataLoader(science_val_dataset, batch_size=val_batch_size, shuffle=False)

# Trainer
trainer = Trainer(model, optimizer, tokenizer, 
                  commonsense_train_loader, commonsense_val_loader,
                  medical_train_loader, medical_val_loader,
                  science_train_loader, science_val_loader,
                  device=device)

In [5]:
# Training loop
loss_log = []
val_metrics_log = []
best_val_acc = -1.0
best_model_path = ''
early_stopping = EarlyStopper(patience=5, min_delta=1e-3)

# Epoch numbering starts at 100, so checkpoints are named epoch101, epoch102, ...
# NOTE(review): presumably this continues an earlier 100-epoch run whose weights
# are still in the kernel - set start_epoch = 0 for a fresh run; confirm.
start_epoch = 100
final_epoch = start_epoch + num_epochs
best_checkpoint_path = 'results/cgm_coarse/t5_base_lora_cgm_best.pth'

# Initial validation
val_metrics1 = trainer.validate_epoch(1)
val_metrics2 = trainer.validate_epoch(2)
val_metrics3 = trainer.validate_epoch(3)

for val_metrics in [val_metrics1, val_metrics2, val_metrics3]:
    val_loss = val_metrics['val_loss']
    val_acc = val_metrics['val_acc']
    val_bleu = val_metrics['bleu_score']
    print(f"{start_epoch}/{final_epoch} -> Validation Acc: {val_acc:.3f} \tValidation Bleu: {val_bleu:.3f}")

for epoch in range(start_epoch, final_epoch):
    train_loss = trainer.train_epoch()

    val_metrics1 = trainer.validate_epoch(1)
    val_metrics2 = trainer.validate_epoch(2)
    val_metrics3 = trainer.validate_epoch(3)

    # Report every dataset separately and average the metrics over the three
    avg_val_loss = 0.0
    avg_val_acc = 0.0
    avg_val_bleu = 0.0
    for val_metrics in [val_metrics1, val_metrics2, val_metrics3]:
        val_loss = val_metrics['val_loss']
        val_acc = val_metrics['val_acc']
        val_bleu = val_metrics['bleu_score']
        avg_val_loss += val_loss
        avg_val_acc += val_acc
        avg_val_bleu += val_bleu
        print(f"{epoch+1}/{final_epoch} -> Validation Acc: {val_acc:.3f} \tValidation Bleu: {val_bleu:.3f}")

    avg_val_loss /= 3
    avg_val_acc /= 3
    avg_val_bleu /= 3

    loss_log.append((train_loss, avg_val_loss))
    val_metrics_log.append((avg_val_acc, avg_val_bleu))

    print('Saving model...')
    checkpoint_path = f'results/cgm_coarse/t5_base_lora_cgm_epoch{epoch+1}.pth'
    cgm_weights = lora.cgm_state_dict(model)
    torch.save(cgm_weights, checkpoint_path)

    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        best_model_path = checkpoint_path
        # Write the "best" file right away: it then survives an interrupted
        # run, and nothing has to be re-loaded from disk after the loop.
        torch.save(cgm_weights, best_checkpoint_path)

    print(f"{epoch+1}/{final_epoch} -> Train loss: {train_loss} \tValidation loss: {avg_val_loss} "
          f"\tValidation Acc: {avg_val_acc:.3f} \tValidation Bleu: {avg_val_bleu:.3f}")

    if early_stopping.early_stop(avg_val_acc):
        break

Validation batches:   0%|          | 0/153 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
# Path of the per-epoch checkpoint with the highest average validation accuracy so far
best_model_path

'results/cgm_coarse/t5_base_lora_cgm_epoch151.pth'

In [None]:
# Number of additional parameters: 445104
# Metrics can be computed as the accuracy delta relative to the fine-tuned models; report the average accuracy delta

In [23]:
# Qualitative check: generate the answer for one example of the medical validation set
i = 10
data = medical_val_dataset[i]
# Add a batch dimension of size 1
input_ids = data['input_ids'].to(device).unsqueeze(0)
attention_mask = data['attention_mask'].to(device).unsqueeze(0)

# The training loop may have been interrupted while the model was in train
# mode; switch to eval mode so that dropout does not affect the generation.
model.eval()
# Allow answers up to the target length used for training instead of the
# default generation length.
outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=t_len)
predicted_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(data['question'], '\n')
print('predicted_answer:', predicted_answer)
print('correct_answer:', data['ref_answer'])

A 42-year-old male presents to his primary care physician complaining of abdominal pain. He reports a 5-month history of epigastric pain that improves with meals. He has lost 15 pounds since the pain started. His past medical history is significant for a prolactinoma for which he underwent transphenoidal resection. He drinks alcohol socially and has a 10 pack-year smoking history. His family history is notable for a maternal uncle with a parathyroid adenoma. His temperature is 98.8°F (37.1°C), blood pressure is 125/80 mmHg, pulse is 85/min, and respirations are 18/min. After further workup, the patient is started on octreotide, an analogue of an endogenously produced hormone. When this hormone is produced by the hypothalamus, it has which of the following effects? 
 (A) Decrease production of growth hormone (B) Decrease production of cholecystokinin (C) Decrease production of prolactin (D) Decrease production of gastrin (E) Decrease production of thyrotropin-releasing hormone 

predict