In [1]:
import json
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoModel, AutoTokenizer, AdamW
from sklearn.model_selection import train_test_split
from collections import Counter

from tqdm import tqdm
from utils import *
from contrastive_utils import *
import random
# from class_balanced_loss import *



%load_ext autoreload
%autoreload 2

In [2]:
# Single seed for every RNG this notebook draws from.
seed = 0
# Python's RNG (`random` is imported above); presumably used by helpers such as
# the text augmentations — TODO confirm what `augment` draws from.
random.seed(seed)
# Seeds the CPU generator and every CUDA device, so separate
# torch.cuda.manual_seed / manual_seed_all calls are unnecessary.
torch.manual_seed(seed)

In [3]:
# Pick the compute device once; later cells move models and batches to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) |available.
We will use the GPU: NVIDIA GeForce GTX 1080 Ti


In [4]:
def train_contrastive(model, dataloader, tokenizer, optimizer, device, epochs, augment, generation_loss_fn):
    """Pre-train ``model`` with a contrastive loss between two augmented views of
    each batch, plus an auxiliary long-answer generation loss on both views.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns
            ``(projection, long_answer_logits)``.
        dataloader: yields dicts with ``'input_ids'``, ``'attention_mask'`` and
            ``'long_answer_ids'`` tensors.
        tokenizer: passed through unchanged to ``augment``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device every batch tensor is moved to.
        epochs: number of passes over ``dataloader``.
        augment: callable ``augment(input_ids, attention_mask, tokenizer)``
            returning ``(input_ids, attention_mask)``; called twice per batch to
            build the two views.
        generation_loss_fn: token-level loss applied to flattened long-answer
            logits vs. flattened ``long_answer_ids``.

    Returns:
        list[float]: mean training loss of each epoch (0.0 for an empty dataloader).
    """

    def _flat_generation_loss(logits, target_ids):
        # Flatten to (N, C) logits vs. (N,) targets; reshape (unlike view) also
        # works when the logits are not contiguous.
        return generation_loss_fn(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            long_answer_ids = batch['long_answer_ids'].to(device)

            # Two independently augmented views of the same batch.
            aug_input_ids_1, aug_attention_mask_1 = augment(input_ids, attention_mask, tokenizer)
            aug_input_ids_2, aug_attention_mask_2 = augment(input_ids, attention_mask, tokenizer)

            proj_1, long_answer_logits_1 = model(aug_input_ids_1, aug_attention_mask_1)
            proj_2, long_answer_logits_2 = model(aug_input_ids_2, aug_attention_mask_2)

            generation_loss = (
                _flat_generation_loss(long_answer_logits_1, long_answer_ids)
                + _flat_generation_loss(long_answer_logits_2, long_answer_ids)
            )
            # Combine out-of-place rather than mutating contrastive_loss's output
            # in place with `+=`.
            loss = contrastive_loss(proj_1, proj_2) + generation_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once: .item() on a GPU tensor forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({'loss': batch_loss})

        # Count batches instead of len(dataloader): avoids ZeroDivisionError on an
        # empty loader and also works for loaders without __len__.
        mean_loss = total_loss / max(num_batches, 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    return epoch_losses

# train_contrastive(model, train_loader, tokenizer, optimizer, device, epochs)

In [5]:
# NOTE(review): `load_pubmedqa_data` and `data_path` are not defined in this
# notebook; presumably both come from `from utils import *` — confirm.
(
    expert_train_processed,
    artificial_train_processed,
    unlabeled_processed,
    expert_test_processed,
) = load_pubmedqa_data(data_path)

In [6]:
# Backbone checkpoint. Other checkpoints that can be swapped in:
#   "dmis-lab/biobert-base-cased-v1.1", "nlpie/bio-mobilebert", "nlpie/bio-tinybert"
model_name = 'nlpie/tiny-biobert'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some encoder weights (the pooler, per the warning below) are not in the
# checkpoint and are freshly initialised, so construction draws from torch's RNG.
model = PubMedQAContrastive(model_name)
model = model.to(device)

Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Shared `max_length` for every split (presumably the tokenised sequence length
# PubMedQADataset pads/truncates to — confirm), defined once instead of repeated.
max_length = 400
# Training data = expert-labelled + artificially labelled examples.
train_dataset = PubMedQADataset(
    expert_train_processed + artificial_train_processed, tokenizer, max_length=max_length
)
unlabeled_dataset = PubMedQADataset(unlabeled_processed, tokenizer, max_length=max_length)
test_dataset = PubMedQADataset(expert_test_processed, tokenizer, max_length=max_length)

In [9]:
# Number of trainable (requires_grad) parameters in the contrastive model.
sum(
    map(torch.Tensor.numel, filter(lambda param: param.requires_grad, model.parameters()))
)

23087604

In [10]:
def compute_class_counts(dataset, progress=True):
    """Count how many items of each label ``dataset`` contains.

    Every item is materialised through normal iteration, so any per-item work
    the dataset does in ``__getitem__`` (presumably tokenisation for
    PubMedQADataset — confirm) is paid for every example, which makes this
    slow on large datasets.

    Args:
        dataset: iterable of dict-like items with a ``'label'`` entry that is a
            one-element tensor or a plain integer.
        progress: wrap the iteration in a tqdm progress bar (default, as before).

    Returns:
        collections.Counter mapping integer label -> number of items.
    """
    items = tqdm(dataset) if progress else dataset
    # int() accepts both one-element tensors and plain Python ints, whereas
    # .item() only works on tensors.
    return Counter(int(item['label']) for item in items)

In [None]:

# Full pass over the training set to count labels. This is slow, because every
# example is fetched through the dataset; the next cell hard-codes counts,
# presumably copied from an earlier run of this cell.
class_counts = compute_class_counts(dataset=train_dataset)

In [11]:
# Per-class training-example counts, indexed by integer label (the same order the
# commented-out derivation in the next cell uses: class_counts[i] for i in range(...)).
# Hard-coded, presumably from an earlier run of compute_class_counts(train_dataset),
# because that pass is slow. The label -> answer mapping lives in PubMedQADataset.
class_count_list = [196420, 15294, 55]
# Cheap guard: fail loudly if the data changed and these counts are now stale.
assert sum(class_count_list) == len(train_dataset), (
    "class_count_list is stale; recompute it with compute_class_counts(train_dataset)"
)

In [12]:
# Inverse-frequency class weights, scaled so the most frequent class gets 1.0.
# To derive them from freshly computed counts instead, set
# class_count_list = [class_counts[i] for i in range(len(class_counts))] first.
most_frequent_count = max(class_count_list)
class_weights = torch.tensor(
    [most_frequent_count / count for count in class_count_list],
    dtype=torch.float,
    # Use the detected device; a hard-coded 'cuda' breaks the CPU fallback.
    device=device,
)

In [13]:
# Display the per-class loss weights (largest class count / class count, so the
# most frequent class has weight 1.0).
class_weights

tensor([1.0000e+00, 1.2843e+01, 3.5713e+03], device='cuda:0')

In [14]:
# Shuffled loaders for fine-tuning (labeled) and contrastive pre-training
# (unlabeled). The test loader keeps a fixed order and is only used for
# evaluation via get_acc below.
labeled_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
unlabeled_loader = DataLoader(dataset=unlabeled_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [15]:
# AdamW over every parameter of the contrastive model for the pre-training phase.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.00005)

In [17]:
# Number of contrastive pre-training epochs passed to train_contrastive below.
# NOTE(review): the recorded output of that training cell shows "Epoch 1/2", so
# it came from a run with a different value; re-run to refresh it.
num_epochs = 4

In [18]:
# Token-level cross-entropy for the long-answer head; positions holding the
# tokenizer's pad id are excluded from the loss.
generation_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [32]:
# Phase 1: contrastive pre-training on the unlabeled split.
train_contrastive(
    model=model,
    dataloader=unlabeled_loader,
    tokenizer=tokenizer,
    optimizer=optimizer,
    device=device,
    epochs=num_epochs,
    augment=augment,
    generation_loss_fn=generation_loss_fn,
)

Epoch 1/2: 100%|██████████| 7657/7657 [23:53<00:00,  5.34it/s, loss=3.81e-6]


Epoch 1/2, Loss: 0.7038


Epoch 2/2: 100%|██████████| 7657/7657 [23:54<00:00,  5.34it/s, loss=0]    

Epoch 2/2, Loss: 0.6967





In [31]:
# Checkpoint the contrastively pre-trained model with its optimizer state and the
# hyper-parameters that were actually used (rather than hard-coded copies).
state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # Epochs of the most recent train_contrastive call (it is passed num_epochs).
    'epochs': num_epochs,
    'lr': optimizer.param_groups[0]['lr'],
}

# torch.save does not create missing directories.
os.makedirs("weights", exist_ok=True)
# [-1] still works when model_name has no "org/" prefix.
torch.save(state, f"weights/{model_name.split('/')[-1]}_contrastive_model.pt")

In [30]:

# Load the pre-trained contrastive model from disk into a fresh instance.
model = PubMedQAContrastive(model_name).to(device)

checkpoint = torch.load(
    f"weights/{model_name.split('/')[-1]}_contrastive_model.pt",
    # Remap tensors in case the checkpoint was written on a different device.
    map_location=device,
)
model.load_state_dict(checkpoint['state_dict'])

# The existing `optimizer` still references the previous model's parameters.
# Loading state into it would update a stale optimizer and keep the old model
# alive in memory, so rebuild it over the new model first.
optimizer = torch.optim.AdamW(model.parameters(), lr=checkpoint['lr'])
optimizer.load_state_dict(checkpoint['optimizer'])


Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [31]:
# Re-seed right before building the classifier so any layers it initialises (and
# the training that follows) do not depend on which cells ran before.
torch.manual_seed(seed)  # seeds the CPU generator and every CUDA device
random.seed(seed)        # `random` is imported at the top; seeded in case helpers draw from it

# Create the classifier on top of the pre-trained contrastive model.
classifier = PubMedQAClassifier(model).to(device)

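Test accuracy before fine-tuning, without contrastive pre-training = 0.398


After contrastive pre-training for 2 epochs = 0.522

After contrastive pre-training for 4 epochs = 0.524

Pre-training for more than 4 epochs lowers the accuracy.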

In [32]:
# Baseline test accuracy of the freshly built classifier, before fine-tuning.
acc_before_finetuning = get_acc(classifier, test_loader, device)
print(f"Test Accuracy before finetuning : {acc_before_finetuning}")


100%|██████████| 16/16 [00:02<00:00,  6.18it/s]

Test Accuracy before finetuning : 0.524





Phase 2: fine-tuning on the labeled data

In [33]:
# Fine-tuning losses: class-weighted cross-entropy for the label head, and the
# same pad-masked token cross-entropy as in pre-training for the long-answer head.
classification_loss_fn = nn.CrossEntropyLoss(weight=class_weights)
generation_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# New AdamW over all classifier parameters, with a smaller learning rate than pre-training.
classifier_optimizer = torch.optim.AdamW(params=classifier.parameters(), lr=0.00002)

In [34]:
def train_classifier(model, dataloader, testloader, optimizer, classification_loss_fn, generation_loss_fn, device, epochs):
    """Fine-tune ``model`` on labeled data with a classification + generation loss.

    After every epoch the mean training loss and the test accuracy reported by
    ``get_acc(model, testloader, device)`` are printed.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns
            ``(classification_logits, long_answer_logits)``.
        dataloader: yields dicts with ``'input_ids'``, ``'attention_mask'``,
            ``'label'`` and ``'long_answer_ids'`` tensors.
        testloader: loader handed to ``get_acc`` for the per-epoch evaluation.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        classification_loss_fn: loss on ``(classification_logits, label)``.
        generation_loss_fn: token-level loss on flattened long-answer logits vs.
            flattened ``long_answer_ids``.
        device: device every batch tensor is moved to.
        epochs: number of passes over ``dataloader``.

    Returns:
        list[float]: mean training loss of each epoch (0.0 for an empty dataloader).
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            label = batch['label'].to(device)
            long_answer_ids = batch['long_answer_ids'].to(device)

            classification_logits, long_answer_logits = model(input_ids, attention_mask)

            classification_loss = classification_loss_fn(classification_logits, label)
            # Flatten to (N, C) logits vs. (N,) targets; reshape (unlike view)
            # also works when the logits are not contiguous.
            generation_loss = generation_loss_fn(
                long_answer_logits.reshape(-1, long_answer_logits.size(-1)),
                long_answer_ids.reshape(-1),
            )
            loss = classification_loss + generation_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once: .item() on a GPU tensor forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({'loss': batch_loss})

        # Count batches instead of len(dataloader): avoids ZeroDivisionError on an
        # empty loader and also works for loaders without __len__.
        mean_loss = total_loss / max(num_batches, 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
        print(f"Test Accuracy : {get_acc(model,testloader,device)}")
        # NOTE(review): get_acc presumably switches the model to eval mode;
        # restore training mode before the next epoch.
        model.train()
    return epoch_losses



In [35]:
# Fine-tune for a single epoch; train_classifier prints test accuracy after each epoch.
train_classifier(
    model=classifier,
    dataloader=labeled_loader,
    testloader=test_loader,
    optimizer=classifier_optimizer,
    classification_loss_fn=classification_loss_fn,
    generation_loss_fn=generation_loss_fn,
    device=device,
    epochs=1,
)

Epoch 1/1: 100%|██████████| 26472/26472 [45:56<00:00,  9.60it/s, loss=7.45] 


Epoch 1/1, Loss: 7.2657


100%|██████████| 16/16 [00:02<00:00,  6.16it/s]

Test Accuracy : 0.614





In [36]:
# Checkpoint the fine-tuned classifier with its optimizer state and training settings.
state = {
    'state_dict': classifier.state_dict(),
    'optimizer': classifier_optimizer.state_dict(),
    # The fine-tuning call above passes 1 epoch to train_classifier (this was
    # previously recorded as 2); keep in sync with that argument.
    'epochs': 1,
    'lr': classifier_optimizer.param_groups[0]['lr'],
}

# torch.save does not create missing directories.
os.makedirs("weights", exist_ok=True)
# [-1] still works when model_name has no "org/" prefix.
torch.save(state, f"weights/{model_name.split('/')[-1]}_contrastive_QA_model.pt")