In [19]:
import pandas as pd
import numpy as np
import torch
import json
import jsonlines
from pathlib import Path
from barbar import Bar
import random
from torch import nn

In [20]:
with open('../preproc_datasets/BioASQ-train-yesno-8b-snippet.json', 'rb') as f:
    bio_yn_raw = json.load(f)['data'][0]['paragraphs']
bio_yn = [q['qas'][0] for q in bio_yn_raw]
for i in range(len(bio_yn)):
    bio_yn[i]['context'] = bio_yn_raw[i]['context']
bio_yn_df = pd.DataFrame.from_dict(bio_yn)
bio_yn_df.head()

Unnamed: 0,id,question,is_impossible,answers,context
0,54e25eaaae9738404b000017_001,Is the protein Papilin secreted?,False,yes,"Using expression analysis, we identify three g..."
1,54e25eaaae9738404b000017_002,Is the protein Papilin secreted?,False,yes,We found that mig-6 encodes long (MIG-6L) and ...
2,54e25eaaae9738404b000017_003,Is the protein Papilin secreted?,False,yes,"apilins are homologous, secreted extracellular..."
3,54e25eaaae9738404b000017_004,Is the protein Papilin secreted?,False,yes,The TSR superfamily is a diverse family of ext...
4,54e25eaaae9738404b000017_005,Is the protein Papilin secreted?,False,yes,Papilins are extracellular matrix proteins


In [21]:
no_size = bio_yn_df[bio_yn_df.answers == 'no'].shape[0]
no_size

1637

In [22]:
yes_index = bio_yn_df[bio_yn_df.answers == 'yes'].index
random_index = np.random.choice(yes_index, no_size, replace=False)
yes_sample = bio_yn_df.loc[random_index]
bio_yn_balanced = pd.concat([yes_sample,bio_yn_df[bio_yn_df.answers == 'no']])
bio_yn_balanced

Unnamed: 0,id,question,is_impossible,answers,context
467,530e42e65937551c09000007_019,Is fatigue prevalent in patients receiving tre...,False,yes,The most common atrasentan-related toxicities ...
7831,589a246878275d0c4a000030_059,Is vortioxetine effective for treatment of dep...,False,yes,"In this study of adults with MDD, 5 mg vortiox..."
11317,5c5217fd7e3cb0e231000005_021,Is there any association between suicide and a...,False,yes,Although the suicide risk of autism spectrum d...
5134,570908e3cf1c325851000012_017,Is EZH2 associated with prostate cancer?,False,yes,ChIP data on prostate cancer tissue specimens ...
2996,55016397e9bde69634000006_026,Is Sarcolipin a regulatory/inhibitory protein ...,False,yes,The sarco(endo)plasmic reticulum calcium ATPas...
...,...,...,...,...,...
11473,5cb37f76ecadf2e73f00005c_002,Is Pim-1 a protein phosphatase?,True,no,The Pim1 serine/threonine kinase is associated...
11474,5cb38a56ecadf2e73f00005e_001,Is myc a tumour suppressor gene?,True,no,"oncogenic Myc, a master transcription factor t..."
11475,5cb38a56ecadf2e73f00005e_002,Is myc a tumour suppressor gene?,True,no,he MYC oncogene
11476,5cb38a56ecadf2e73f00005e_003,Is myc a tumour suppressor gene?,True,no,the proto-oncogene protein c-MYC


In [23]:
train_a = list(bio_yn_balanced.question)
train_b = list(bio_yn_balanced.context)
train_labels = [int(answer == 'yes') for answer in bio_yn_balanced.answers]

In [24]:
from transformers import BertTokenizer
# Load the BERT tokenizer.
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1', 
                                          do_lower_case=True)

In [25]:
train_tokens = tokenizer(train_a,train_b, 
                       add_special_tokens=True,
                       max_length=500,
                       truncation=True, padding=True)
train_tokens['labels'] = train_labels

In [26]:
from torch.utils.data import Dataset, DataLoader

class MnliDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        #print(self.encodings['start_positions'][idx])
        #{key: torch.tensor(val[idx], dtype = torch.long) for key, val in self.encodings.items()}
        return {'input_ids': torch.tensor(self.encodings['input_ids'][idx], dtype = torch.long),
                'attention_mask': torch.tensor(self.encodings['attention_mask'][idx], dtype = torch.long),
                'token_type_ids': torch.tensor(self.encodings['token_type_ids'][idx], dtype = torch.long),
                'labels': torch.tensor(self.encodings['labels'][idx], dtype = torch.long)
               }

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = MnliDataset(train_tokens)

In [27]:
# freeze all the parameters
for param in model.parameters():
    param.requires_grad = False

In [28]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1", num_labels = 4)
checkpoint = torch.load('../checkpoints/checkpoint_mnli_3epochs_seed.pt',map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

<All keys matched successfully>

In [29]:
class BERT_Arch(nn.Module):

    def __init__(self, model):
      
        super(BERT_Arch, self).__init__()

        self.model = model
        
        # dropout layer
        self.dropout = nn.Dropout(0.1)
        
        # relu activation function
        self.relu =  nn.ReLU()
        # dense layer 1
        self.fc1 = nn.Linear(4,512)
        
        # dense layer 2 (Output layer)
        self.fc2 = nn.Linear(512,2)
        #softmax activation function
        self.softmax = nn.LogSoftmax(dim=1)

    #define the forward pass
    def forward(self, input_ids,
            attention_mask,
            token_type_ids,labels):

        #pass the inputs to the model  
        outputs = self.model(input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,labels = labels)
        
        cls_hs = outputs.logits
        
        x = self.fc1(cls_hs)

        x = self.relu(x)

        x = self.dropout(x)

        # output layer
        x = self.fc2(x)
        
        # apply softmax activation
        x = self.softmax(x)

        return x

In [30]:
model_full = BERT_Arch(model)

In [31]:
from torch.utils.data import DataLoader
from transformers import AdamW

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model_full.to(device)
model_full.train()

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

optim = AdamW(model.parameters(), lr=5e-5)

In [32]:
cross_entropy  = nn.NLLLoss() 
for epoch in range(3):
    for i,batch in enumerate(Bar(train_loader)):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device, dtype = torch.long)
        attention_mask = batch['attention_mask'].to(device, dtype = torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype = torch.long)
        labels = batch['labels'].to(device, dtype = torch.long)
        outputs = model_full(input_ids, 
                        attention_mask=attention_mask, 
                        token_type_ids = token_type_ids,
                        labels = labels)
        #loss = outputs.loss
        loss = cross_entropy(outputs, labels)
        loss.backward()
        optim.step()
model_full.eval()



BERT_Arch(
  (model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
 

In [33]:
torch.save({
            'epoch': 3,
            'model_state_dict': model_full.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': loss,
            },'../checkpoints/checkpoint_bio_yn_balanced_seed.pt')