In [1]:
import torch 
from torch import nn
from transformers import AutoModel,AutoTokenizer
from utils import *
from config import data_path,save_path
import torch.nn.functional as F



%load_ext autoreload
%autoreload 2

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x7fada9b85f90>

In [3]:
if torch.cuda.is_available():    
    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    print('There are %d GPU(s) |available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) |available.
We will use the GPU: NVIDIA GeForce GTX 1080 Ti


In [4]:

def virtual_adversarial_training(model, input_ids, attention_mask, epsilon=1.0, alpha=1.0, n_iterations=1):
    # Initial forward pass
    with torch.no_grad():
        initial_logits, _ = model(input_ids, attention_mask)
        initial_prob = F.softmax(initial_logits, dim=1)

    # Initialize perturbation
    d = torch.randn_like(input_ids, dtype=torch.float).to(input_ids.device)
    d = F.normalize(d, dim=-1, p=2)
    d.requires_grad_()

    # print(epsilon)
    # print(d)

    for _ in range(n_iterations):
        # Forward pass with perturbed input
        perturbed_ids = input_ids.float() + epsilon * d
        logits_perturbed, _ = model(perturbed_ids.long(), attention_mask)
        
        # Compute KL divergence
        loss = F.kl_div(F.log_softmax(logits_perturbed, dim=1),
                        initial_prob,
                        reduction='batchmean')
        
        # Compute gradients
        loss.backward()
        
        # Update perturbation
        if d.grad is not None:
            d = d.grad.detach()
            d = F.normalize(d, dim=-1, p=2)
            d.requires_grad_()
        else:
            # If gradient is None, reinitialize d
            d = torch.randn_like(input_ids, dtype=torch.float).to(input_ids.device)
            d = F.normalize(d, dim=-1, p=2)
            d.requires_grad_()
        
        model.zero_grad()

    # Final forward pass with adversarial perturbation
    with torch.no_grad():
        adv_ids = (input_ids.float() + epsilon * d).long()
        logits_adv, _ = model(adv_ids, attention_mask)
    
    # Compute VAT loss
    vat_loss = F.kl_div(F.log_softmax(logits_adv, dim=1),
                        initial_prob,
                        reduction='batchmean')
    
    return alpha * vat_loss


In [5]:
class MultiTaskPubMedQA(nn.Module):
    def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1', num_labels=3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.long_answer_generator = nn.Linear(self.bert.config.hidden_size, self.bert.config.vocab_size)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        
        # Task 1: Yes/No/Maybe Classification
        classification_logits = self.classifier(pooled_output)
        
        # Task 2: Long Answer Generation
        sequence_output = outputs.last_hidden_state
        long_answer_logits = self.long_answer_generator(sequence_output)
        
        return classification_logits, long_answer_logits

In [6]:
model_name = 'nlpie/tiny-biobert'
model = MultiTaskPubMedQA(model_name = model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
model = model.to(device)

In [8]:
expert_train_processed,artificial_train_processed,unlabeled_processed,expert_test_processed = load_pubmedqa_data(data_path)

In [9]:
train_dataset = PubMedQADataset(expert_train_processed + artificial_train_processed, tokenizer,max_length = 256)
# val_dataset = PubMedQADataset(expert_val, tokenizer)
unlabeled_dataset = PubMedQADataset(unlabeled_processed, tokenizer,max_length = 256)
test_dataset = PubMedQADataset(expert_test_processed,tokenizer,max_length = 256)

In [10]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=16)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=32, shuffle=True)

In [11]:
class_count_list = [196420, 15294, 55]
class_weights = [max(class_count_list) / count for count in class_count_list]
class_weights = torch.tensor(class_weights, dtype=torch.float).to('cuda')

In [12]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
classification_loss_fn = nn.CrossEntropyLoss(weight = class_weights)
generation_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [13]:
num_epochs = 2

In [14]:
next(iter(unlabeled_loader))   

{'input_ids': tensor([[  101,  2372, 16979,  ...,  1833,  2366,   102],
         [  101, 12120, 20954,  ...,  1121,  1103,   102],
         [  101,  8274,   118,  ...,  8167, 10721,   102],
         ...,
         [  101,  7187,  1425,  ...,  1127,  2382,   102],
         [  101,  2181, 11019,  ...,   119,   121,   102],
         [  101,   140,  3161,  ...,   110,   117,   102]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'label': tensor([-1, -1, -1, -1, -1, -1, -1, -1]),
 'long_answer_ids': tensor([[101, 102,   0,  ...,   0,   0,   0],
         [101, 102,   0,  ...,   0,   0,   0],
         [101, 102,   0,  ...,   0,   0,   0],
         ...,
         [101, 102,   0,  ...,   0,   0,   0],
         [101, 102,   0,  ...,   0,   0,   0],
         [101, 102,   0,  ...,   0,   0,   0]]),
 'lo

In [14]:
torch.cuda.empty_cache()
# import gc
# del variables
# gc.collect()

In [15]:
print(f"Test Accuracy before phase 1 : {get_acc(model,test_loader,device)}")

100%|██████████| 16/16 [00:02<00:00,  6.93it/s]

Test Accuracy before phase 1 : 0.338





In [18]:
model.train()


total_start_time = time.time()

for epoch in range(num_epochs):
    
    epoch_loss = 0

    for step,batch in enumerate(unlabeled_loader):

        # input_ids, attention_mask, label, long_answer, = batch
        # batch = batch.to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label = batch['label'].to(device)
        long_answer_ids = batch['long_answer_ids'].to(device)

        
        classification_logits, long_answer_logits = model(input_ids, attention_mask)
        
        # Compute losses
        # classification_loss = classification_loss_fn(classification_logits, label)
        # In your training loop
        vat_loss = virtual_adversarial_training(model, input_ids, attention_mask,n_iterations=5)

        generation_loss = generation_loss_fn(long_answer_logits.view(-1, long_answer_logits.size(-1)), long_answer_ids.view(-1))

        # Combine losses
        total_loss = vat_loss + generation_loss
        
        optimizer.zero_grad()
        total_loss.backward()

        if step % 500 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            
            # Report progress.
            total_time = time.time() - total_start_time
            print('  Batch {:>5,}  of  {:>5,} time elapsed {}'.format(step, len(unlabeled_loader),format_time(total_time)))
        # generation_loss.backward()
        epoch_loss += total_loss.item()
        
        optimizer.step()
    
    print(f"Epoch {epoch} loss : {epoch_loss}")

    
print(f"Test Accuracy after phase 1 : {get_acc(model,test_loader,device)}")

  Batch   500  of  7,657 time elapsed 0:02:00
  Batch 1,000  of  7,657 time elapsed 0:04:01
  Batch 1,500  of  7,657 time elapsed 0:06:02
  Batch 2,000  of  7,657 time elapsed 0:08:04
  Batch 2,500  of  7,657 time elapsed 0:10:05
  Batch 3,000  of  7,657 time elapsed 0:12:06
  Batch 3,500  of  7,657 time elapsed 0:14:07
  Batch 4,000  of  7,657 time elapsed 0:16:08
  Batch 4,500  of  7,657 time elapsed 0:18:09
  Batch 5,000  of  7,657 time elapsed 0:20:10
  Batch 5,500  of  7,657 time elapsed 0:22:11
  Batch 6,000  of  7,657 time elapsed 0:24:12
  Batch 6,500  of  7,657 time elapsed 0:26:13
  Batch 7,000  of  7,657 time elapsed 0:28:15
  Batch 7,500  of  7,657 time elapsed 0:30:16
Epoch 0 loss : 166.08718601762666
  Batch   500  of  7,657 time elapsed 0:32:55
  Batch 1,000  of  7,657 time elapsed 0:34:56
  Batch 1,500  of  7,657 time elapsed 0:36:57
  Batch 2,000  of  7,657 time elapsed 0:38:58
  Batch 2,500  of  7,657 time elapsed 0:40:59
  Batch 3,000  of  7,657 time elapsed 0:43:00


100%|██████████| 16/16 [00:01<00:00,  8.86it/s]

Test Accuracy after phase 1 : 0.526





In [None]:
1844, 0.552

In [19]:
PATH = f"weights/phase_1_model_{num_epochs}.pt"

In [20]:
state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epochs': num_epochs,
    'lr':3e-4
}
torch.save(state, PATH)

In [17]:
print(f"Test Accuracy after phase 1 : {get_acc(model,test_loader,device)}")

100%|██████████| 16/16 [00:01<00:00,  8.79it/s]

Test Accuracy after phase 1 : 0.526





In [19]:

model.train()

total_start_time = time.time()

for epoch in range(num_epochs):
    epoch_loss = 0

    for step,batch in tqdm(enumerate(train_loader)):

        # batch = batch.to(device)

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label = batch['label'].to(device)
        long_answer_ids = batch['long_answer_ids'].to(device)
        # long_answer_mask = batch['long_answer_mask']

        # 'input_ids' : inputs['input_ids'].squeeze(),
        #     'attention_mask': inputs['attention_mask'].squeeze(),
        #     'label':label,
        #     'long_answer_ids': long_answer_encoding['input_ids'].squeeze(),
        #     'long_answer_mask': long_answer_encoding['attention_mask'].squeeze()
        # input_ids, attention_mask = prepare_data(question, context)
        
        classification_logits, long_answer_logits = model(input_ids, attention_mask)
        
        # Compute losses
        classification_loss = classification_loss_fn(classification_logits, label)
        vat_loss = virtual_adversarial_training(model, input_ids, attention_mask,n_iterations=5)

        generation_loss = generation_loss_fn(long_answer_logits.view(-1, long_answer_logits.size(-1)), long_answer_ids.view(-1))

        # Combine losses
        total_loss = classification_loss + generation_loss + vat_loss
        
        optimizer.zero_grad()
        total_loss.backward()

        if step % 3000 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            
            # Report progress.
            total_time = time.time() - total_start_time
            print('  Batch {:>5,}  of  {:>5,} time elapsed {}'.format(step, len(train_loader),format_time(total_time)))
        optimizer.step()
        epoch_loss += total_loss.item()
    
    print(f"Epoch {epoch} loss : {epoch_loss}")
    


3001it [12:27,  4.02it/s]

  Batch 3,000  of  26,472 time elapsed 0:12:27


6001it [24:56,  3.99it/s]

  Batch 6,000  of  26,472 time elapsed 0:24:56


9001it [37:25,  4.02it/s]

  Batch 9,000  of  26,472 time elapsed 0:37:25


12001it [49:53,  4.03it/s]

  Batch 12,000  of  26,472 time elapsed 0:49:53


15001it [1:02:22,  4.02it/s]

  Batch 15,000  of  26,472 time elapsed 1:02:21


18001it [1:14:50,  4.02it/s]

  Batch 18,000  of  26,472 time elapsed 1:14:50


21001it [1:27:19,  4.01it/s]

  Batch 21,000  of  26,472 time elapsed 1:27:19


24001it [1:39:48,  4.01it/s]

  Batch 24,000  of  26,472 time elapsed 1:39:48


26472it [1:50:05,  4.01it/s]

Epoch 2 loss : 197524.51847219467





In [20]:
print(f"Test Accuracy after phase 2 : {get_acc(model,test_loader,device)}")

100%|██████████| 16/16 [00:01<00:00,  8.80it/s]

Test Accuracy after phase 2 : 0.496





In [20]:
state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epochs': 2,
    'lr':3e-4
}
torch.save(state, "weights/phase_2_model_2.pt")

In [None]:
199071.09152078629, 0.552
197524.51847219467, 0.496

# 171688.996986866, 0.58

In [16]:
checkpoint = torch.load(f"weights/phase_1_model_{num_epochs}.pt")
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

In [16]:
checkpoint = torch.load(f"weights/phase_2_model_{num_epochs}.pt")
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])