In [1]:
import json
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoModel, AutoTokenizer, AdamW
from sklearn.model_selection import train_test_split
from collections import Counter

from tqdm import tqdm
from utils import *
from contrastive_utils import *
import random
# from class_balanced_loss import *



%load_ext autoreload
%autoreload 2

In [2]:
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
if torch.cuda.is_available():    
    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    print('There are %d GPU(s) |available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) |available.
We will use the GPU: NVIDIA GeForce GTX 1080 Ti


In [4]:
expert_train_processed,artificial_train_processed,unlabeled_processed,expert_test_processed = load_pubmedqa_data(data_path)

In [5]:
# model_name = "dmis-lab/biobert-base-cased-v1.1"
# model_name = "nlpie/bio-mobilebert"
# model_name = 'nlpie/bio-tinybert'
model_name = "nlpie/tiny-biobert"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = PubMedQAContrastive(model_name).to(device)

Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
train_dataset = PubMedQADataset(expert_train_processed + artificial_train_processed, tokenizer,max_length = 512)
unlabeled_dataset = PubMedQADataset(unlabeled_processed, tokenizer,max_length = 512)
test_dataset = PubMedQADataset(expert_test_processed,tokenizer,max_length = 512)

In [8]:
test_set_path = os.path.join(data_path,"test_set.json")
df_test = pd.read_json(test_set_path).T

In [9]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

23087604

In [10]:
def compute_class_counts(dataset):
    class_counts = Counter()
    for data in tqdm(dataset):
        label = data['label'].item()
        class_counts[label] += 1
    return class_counts

In [11]:

# Compute class counts and weights
class_counts = compute_class_counts(train_dataset)

100%|██████████| 211769/211769 [07:02<00:00, 501.26it/s]


In [30]:
class_count_list = [196420, 15294, 55]

In [31]:
# class_count_list = [class_counts[i] for i in range(len(class_counts))]
class_weights = [max(class_count_list) / count for count in class_count_list]
class_weights = torch.tensor(class_weights, dtype=torch.float).to('cuda')

In [32]:
class_weights

tensor([1.0000e+00, 1.2843e+01, 3.5713e+03], device='cuda:0')

In [11]:
labeled_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=32, shuffle=False)

In [12]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [13]:
num_epochs = 2

In [14]:
generation_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [15]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Create and train the classifier
classifier = PubMedQAClassifier(model).to(device)

In [16]:
print(f"Test Accuracy before finetuning : {get_acc(classifier,test_loader,device)}")


100%|██████████| 16/16 [00:03<00:00,  4.60it/s]

Test Accuracy before finetuning : 0.39





## PHASE 2 finetuning

In [33]:
classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=2e-5)
classification_loss_fn = nn.CrossEntropyLoss(weight = class_weights)
generation_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [17]:
def train_classifier(model, dataloader, testloader, optimizer,classification_loss_fn,generation_loss_fn, device, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            label = batch['label'].to(device)
            long_answer_ids = batch['long_answer_ids'].to(device)


            classification_logits,long_answer_logits = model(input_ids, attention_mask)
            
            # Compute losses
            classification_loss = classification_loss_fn(classification_logits, label)
            # classification_loss = CB_loss(label.to('cpu'), classification_logits.to('cpu'), class_count_list, num_classes,loss_type, beta, gamma)
            generation_loss = generation_loss_fn(long_answer_logits.view(-1, long_answer_logits.size(-1)), long_answer_ids.view(-1))
    
            # Combine losses
            loss = classification_loss + generation_loss
      
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
                
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
            
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")
        print(f"Test Accuracy : {get_acc(model,testloader,device)}")
        model.train()



In [21]:
train_classifier(classifier,labeled_loader,test_loader,classifier_optimizer,classification_loss_fn,
                 generation_loss_fn,device,1)

Epoch 1/1: 100%|██████████| 26472/26472 [55:40<00:00,  7.92it/s, loss=6.27]


Epoch 1/1, Loss: 6.7563


100%|██████████| 16/16 [00:03<00:00,  5.20it/s]

Test Accuracy : 0.546





In [22]:
state = {
    'state_dict': classifier.state_dict(),
    'optimizer': classifier_optimizer.state_dict(),
    'epochs': 1,
    'lr':2e-5
}

torch.save(state, f"weights/{model_name.split('/')[1]}_512_contrastive_labelled_only_QA_model.pt")

In [25]:
checkpoint = torch.load(f"weights/{model_name.split('/')[1]}_512_contrastive_labelled_only_QA_model.pt")
classifier.load_state_dict(checkpoint['state_dict'])
classifier_optimizer.load_state_dict(checkpoint['optimizer'])


In [26]:
get_acc(classifier,test_loader,device)

100%|██████████| 16/16 [00:03<00:00,  5.25it/s]


0.546

In [27]:
save_preds(df_test.index.to_list(),
           get_pred(classifier,test_loader,device),
           "tinybiobert_phase_2_labelled_only",
          pred_dir=save_path)

100%|██████████| 16/16 [00:03<00:00,  5.18it/s]


## Without artificial data

In [18]:
expert_train = PubMedQADataset(expert_train_processed, tokenizer,max_length = 512)
expert_train_loader = DataLoader(expert_train,batch_size=8, shuffle=False)

In [19]:
class_counts = compute_class_counts(expert_train)

100%|██████████| 500/500 [00:01<00:00, 497.39it/s]


In [20]:
class_count_list = [class_counts[i] for i in range(len(class_counts))]
class_weights = [max(class_count_list) / count for count in class_count_list]
class_weights = torch.tensor(class_weights, dtype=torch.float).to('cuda')

In [21]:
classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=2e-5)
classification_loss_fn = nn.CrossEntropyLoss(weight = class_weights)
generation_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [22]:
train_classifier(classifier,expert_train_loader,test_loader,classifier_optimizer,classification_loss_fn,
                 generation_loss_fn,device,1)

Epoch 1/1: 100%|██████████| 63/63 [00:08<00:00,  7.79it/s, loss=12]  


Epoch 1/1, Loss: 11.4237


100%|██████████| 16/16 [00:03<00:00,  5.24it/s]

Test Accuracy : 0.552





In [24]:
save_preds(df_test.index.to_list(),
           get_pred(classifier,test_loader,device),
           "tinybiobert_phase_2_expert_only",
          pred_dir=save_path)

100%|██████████| 16/16 [00:03<00:00,  5.15it/s]
