In [1]:
import torch 
import random
from torch import nn
from transformers import AutoModel,AutoTokenizer
from utils import *
from config import data_path,save_path
import torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm


%load_ext autoreload
%autoreload 2

In [2]:
seed = 0
torch.manual_seed(seed)
# torch.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
if torch.cuda.is_available():    
    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    print('There are %d GPU(s) |available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) |available.
We will use the GPU: NVIDIA GeForce GTX 1080 Ti


In [4]:
class MixMatch:
    def __init__(self, model, tokenizer, num_classes, T=0.5, K=2, alpha=0.75, lambda_u=75):
        self.model = model
        self.tokenizer = tokenizer
        self.num_classes = num_classes
        self.T = T
        self.K = K
        self.alpha = alpha
        self.lambda_u = lambda_u

    def temperature_sharpening(self, probs):
        probs = probs.pow(1/self.T)
        return probs / probs.sum(dim=1, keepdim=True)

    def mixup(self, x1, x2, y1, y2):
        lambda_ = np.random.beta(self.alpha, self.alpha)
        lambda_ = max(lambda_, 1 - lambda_)
        
        # Perform mixup on embeddings instead of input_ids
        with torch.no_grad():
            embed1 = self.model.bert.embeddings.word_embeddings(x1)
            embed2 = self.model.bert.embeddings.word_embeddings(x2)
        
        mixed_embed = lambda_ * embed1 + (1 - lambda_) * embed2
        mixed_y = lambda_ * y1 + (1 - lambda_) * y2
        
        return mixed_embed, mixed_y

    def augment(self, input_ids, attention_mask, p=0.1):
        # Random deletion augmentation
        batch_size, seq_length = input_ids.size()
        augmented_input_ids = input_ids.clone()
        augmented_attention_mask = attention_mask.clone()
        
        for i in range(batch_size):
            # Find the actual sequence length for this example
            actual_length = attention_mask[i].sum().item()
            
            # Only augment if the sequence is long enough
            if actual_length > 2:  # We need at least 3 tokens to perform deletion
                tokens = input_ids[i][:actual_length]
                n_to_delete = max(1, int(p * (actual_length - 2)))  # Ensure we keep at least 2 tokens
                
                # Randomly choose tokens to delete, excluding the first and last tokens
                indices_to_keep = [0] + random.sample(range(1, actual_length - 1), actual_length - 2 - n_to_delete) + [actual_length - 1]
                indices_to_keep.sort()
                
                # Create the new sequence
                new_tokens = tokens[indices_to_keep]
                
                # Pad the sequence to maintain original length
                padding_length = seq_length - len(new_tokens)
                new_tokens = torch.cat([new_tokens, torch.full((padding_length,), self.tokenizer.pad_token_id, device=new_tokens.device)])
                
                # Update input_ids and attention_mask
                augmented_input_ids[i] = new_tokens
                augmented_attention_mask[i] = torch.cat([torch.ones(len(indices_to_keep), device=attention_mask.device), 
                                                         torch.zeros(padding_length, device=attention_mask.device)])
        
        return augmented_input_ids, augmented_attention_mask

    def process_batch(self, labeled_batch, unlabeled_batch):
        x_labeled, y_labeled = labeled_batch['input_ids'], labeled_batch['label']
        x_unlabeled, attention_mask_unlabeled = unlabeled_batch['input_ids'], unlabeled_batch['attention_mask']
    
        with torch.no_grad():
            # Generate pseudo-labels for unlabeled data
            qb_unlabeled = torch.zeros(x_unlabeled.size(0), self.num_classes).to(x_labeled.device)
            for _ in range(self.K):
                x_augmented, attention_mask_augmented = self.augment(x_unlabeled, attention_mask_unlabeled)
                logits = self.model(x_augmented, attention_mask_augmented)
                qb_unlabeled += F.softmax(logits, dim=1)
            qb_unlabeled /= self.K
    
            # Apply temperature sharpening
            qb_unlabeled = self.temperature_sharpening(qb_unlabeled)
    
        # Concatenate labeled and unlabeled data
        x_all = torch.cat([x_labeled, x_unlabeled], dim=0)
        y_all = torch.cat([F.one_hot(y_labeled, num_classes=self.num_classes).float(), qb_unlabeled], dim=0)
        attention_mask_all = torch.cat([labeled_batch['attention_mask'], attention_mask_unlabeled], dim=0)
    
        # Shuffle for MixUp
        indices = torch.randperm(x_all.size(0))
        x_shuffled = x_all[indices]
        y_shuffled = y_all[indices]
        attention_mask_shuffled = attention_mask_all[indices]
    
        # Apply MixUp on embeddings
        mixed_embed, y_mixed = self.mixup(x_all, x_shuffled, y_all, y_shuffled)
        attention_mask_mixed = attention_mask_all  # Attention mask doesn't change in MixUp
    
        # Split mixed data back into labeled and unlabeled
        batch_size = x_labeled.size(0)
        embed_labeled_mixed = mixed_embed[:batch_size]
        y_labeled_mixed = y_mixed[:batch_size]
        embed_unlabeled_mixed = mixed_embed[batch_size:]
        y_unlabeled_mixed = y_mixed[batch_size:]
        attention_mask_labeled_mixed = attention_mask_mixed[:batch_size]
        attention_mask_unlabeled_mixed = attention_mask_mixed[batch_size:]
    
        return embed_labeled_mixed, y_labeled_mixed, embed_unlabeled_mixed, y_unlabeled_mixed, attention_mask_labeled_mixed, attention_mask_unlabeled_mixed
    
    def compute_loss(self, embed_labeled, y_labeled, embed_unlabeled, y_unlabeled, attention_mask_labeled, attention_mask_unlabeled):
        # Compute loss for labeled data
        logits_labeled = self.model.bert(inputs_embeds=embed_labeled, attention_mask=attention_mask_labeled).last_hidden_state[:, 0, :]
        logits_labeled = self.model.classifier(logits_labeled)
        class_weights = torch.tensor([1.0000e+00, 1.2843e+01, 3.5713e+03], dtype=torch.float).to('cuda')
        loss_labeled = F.cross_entropy(logits_labeled, y_labeled.argmax(dim=1),weight = class_weights )
    
        # Compute loss for unlabeled data
        logits_unlabeled = self.model.bert(inputs_embeds=embed_unlabeled, attention_mask=attention_mask_unlabeled).last_hidden_state[:, 0, :]
        logits_unlabeled = self.model.classifier(logits_unlabeled)
        loss_unlabeled = F.mse_loss(F.softmax(logits_unlabeled, dim=1), y_unlabeled)
    
        # Combine losses
        total_loss = loss_labeled + self.lambda_u * loss_unlabeled
    
        return total_loss

In [5]:
class MixMatchDataLoader:
    def __init__(self, labeled_loader, unlabeled_loader):
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        self.reset()

    def reset(self):
        self.labeled_iter = iter(self.labeled_loader)
        self.unlabeled_iter = iter(self.unlabeled_loader)
        self.num_batches = max(len(self.labeled_loader), len(self.unlabeled_loader))
        self.current_batch = 0

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        if self.current_batch >= self.num_batches:
            raise StopIteration

        try:
            labeled_batch = next(self.labeled_iter)
        except StopIteration:
            self.labeled_iter = iter(self.labeled_loader)
            labeled_batch = next(self.labeled_iter)

        try:
            unlabeled_batch = next(self.unlabeled_iter)
        except StopIteration:
            self.unlabeled_iter = iter(self.unlabeled_loader)
            unlabeled_batch = next(self.unlabeled_iter)

        self.current_batch += 1
        return labeled_batch, unlabeled_batch

    def __len__(self):
        return self.num_batches

In [6]:
def train(model, mixmatch, train_loader,test_loader,optimizer, device, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for labeled_batch, unlabeled_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            labeled_batch = {k: v.to(device) for k, v in labeled_batch.items()}
            unlabeled_batch = {k: v.to(device) for k, v in unlabeled_batch.items()}
            
            embed_labeled_mixed, y_labeled_mixed, embed_unlabeled_mixed, y_unlabeled_mixed, attention_mask_labeled_mixed, attention_mask_unlabeled_mixed = mixmatch.process_batch(labeled_batch, unlabeled_batch)
            
            loss = mixmatch.compute_loss(
                embed_labeled_mixed, y_labeled_mixed, embed_unlabeled_mixed, y_unlabeled_mixed,
                attention_mask_labeled_mixed, attention_mask_unlabeled_mixed
            )
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")
        print(f"Test Accuracy after training : {get_acc(model,test_loader,device)}")
        model.train()



In [7]:
class PubMedQAModel(nn.Module):
    def __init__(self, pretrained_model_name='microsoft/biobert-base-cased-v1.1'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 3)  # 3 classes: yes, no, maybe

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        if input_ids is not None:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        else:
            outputs = self.bert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        logits = self.classifier(outputs.pooler_output)
        return logits

In [8]:
model_name = 'nlpie/tiny-biobert'

In [9]:
model = PubMedQAModel(pretrained_model_name = model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
model = model.to(device)

In [11]:
expert_train_processed,artificial_train_processed,unlabeled_processed,expert_test_processed = load_pubmedqa_data(data_path)

In [12]:
train_dataset = PubMedQADataset(expert_train_processed + artificial_train_processed, tokenizer,max_length = 400)
# val_dataset = PubMedQADataset(expert_val, tokenizer)
unlabeled_dataset = PubMedQADataset(unlabeled_processed, tokenizer,max_length = 400)
test_dataset = PubMedQADataset(expert_test_processed,tokenizer,max_length = 400)

In [13]:
labeled_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=32, shuffle=True)

In [14]:
mixmatch_loader = MixMatchDataLoader(labeled_loader, unlabeled_loader)
mixmatch = MixMatch(model, tokenizer, num_classes=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)


In [15]:
len(labeled_loader), len(unlabeled_loader)

(26472, 7657)

In [16]:
len(mixmatch_loader)

26472

In [17]:
num_epochs = 2

In [18]:
print(f"Test Accuracy before phase 1 : {get_acc(model,test_loader,device)}")

100%|██████████| 16/16 [00:03<00:00,  5.12it/s]

Test Accuracy before phase 1 : 0.338





In [19]:
train(model, mixmatch, mixmatch_loader, test_loader,optimizer, device, num_epochs)

Epoch 1/2:   0%|          | 0/26472 [00:00<?, ?it/s]

Epoch 1/2, Loss: 1.0637


100%|██████████| 16/16 [00:02<00:00,  5.99it/s]

Test Accuracy after training : 0.336





Epoch 2/2:   0%|          | 0/26472 [00:00<?, ?it/s]

Epoch 2/2, Loss: 0.9089


100%|██████████| 16/16 [00:02<00:00,  6.01it/s]

Test Accuracy after training : 0.338





In [15]:
torch.cuda.empty_cache()
# import gc
# del variables
# gc.collect()

In [20]:
print(f"Test Accuracy after phase 1 : {get_acc(model,test_loader,device)}")

Test Accuracy after phase 1 : 0.332


In [None]:
1844, 0.552

In [18]:
PATH = f"mixup_{num_epochs}.pt"

In [19]:
state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epochs': num_epochs,
    'lr':2e-5
}
torch.save(state, PATH)

In [17]:
len(mixmatch_loader)

TypeError: object of type 'MixMatchDataLoader' has no len()

In [None]:
print(f"Test Accuracy after phase 2 : {get_acc(model,test_loader,device)}")

In [22]:
state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(state, "phase_2_model_2_epochs.pt")

In [None]:
175503, 0.576
171688.996986866, 0.58

In [15]:
checkpoint = torch.load("phase_2_model.pt")
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])