In [None]:
import pandas as pd
import numpy as np
import torch
import json
import jsonlines
from pathlib import Path
from barbar import Bar
import random

In [None]:
import zipfile
# Extract the MNLI JSON Lines files next to the archive (the paths the loading cell reads).
# Pure-Python extraction never stops at an interactive "replace?" prompt on re-runs
# (which `!unzip` without -o does) and skips the macOS resource-fork entries (__MACOSX/).
with zipfile.ZipFile('/content/mnli_data.zip') as mnli_zip:
    mnli_zip.extractall(
        '/content',
        members=[name for name in mnli_zip.namelist() if not name.startswith('__MACOSX/')],
    )

Archive:  /content/mnli_data.zip
  inflating: multinli_1.0_train.json  
  inflating: __MACOSX/._multinli_1.0_train.json  
  inflating: multinli_1.0_dev_matched.json  
  inflating: __MACOSX/._multinli_1.0_dev_matched.json  


In [None]:
#!pip install transformers
#!pip install barbar
#!pip install jsonlines

In [None]:
# Seed every RNG that influences training: Python's `random`, NumPy, and torch.
# torch's RNG drives the classification-head initialisation and DataLoader shuffling,
# so seeding only `random` would not make runs reproducible.
SEED = 1995
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def parse_mnli(path, drop_unlabeled=False):
    """Read an MNLI JSON Lines file into parallel premise/hypothesis/label lists.

    Args:
        path: path to a ``multinli_1.0_*.json`` file containing one JSON object per line,
            each with ``sentence1``, ``sentence2`` and ``gold_label`` fields.
        drop_unlabeled: if True, skip examples whose ``gold_label`` is ``'-'``
            (in MNLI this marks pairs without an annotator-consensus label — confirm
            against the dataset README). Defaults to False, which keeps every example.

    Returns:
        Tuple ``(sentences_a, sentences_b, labels)`` of lists aligned by index:
        the ``sentence1`` values, the ``sentence2`` values and the ``gold_label`` strings.
    """
    sentences_a, sentences_b, labels = [], [], []
    # Plain read mode: the previous "r+" needlessly required write permission on the file.
    with open(path, "r", encoding="utf8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing to parse them
            item = json.loads(line)
            if drop_unlabeled and item['gold_label'] == '-':
                continue
            sentences_a.append(item['sentence1'])
            sentences_b.append(item['sentence2'])
            labels.append(item['gold_label'])

    return sentences_a, sentences_b, labels

In [None]:
# Load the MNLI training split and the matched dev split (used here for validation).
(train_a, train_b, train_labels), (val_a, val_b, val_labels) = (
    parse_mnli('/content/multinli_1.0_train.json'),
    parse_mnli('/content/multinli_1.0_dev_matched.json'),
)

In [None]:
# Integer class id for every gold-label string; '-' is kept as a class of its own.
# Insertion order fixes the ids: contradiction=0, '-'=1, neutral=2, entailment=3.
label_encode = {
    label_name: class_id
    for class_id, label_name in enumerate(['contradiction', '-', 'neutral', 'entailment'])
}
# An unknown label raises KeyError rather than being silently mapped.
train_labels_encoding = list(map(label_encode.__getitem__, train_labels))
val_labels_encoding = list(map(label_encode.__getitem__, val_labels))

In [None]:
from transformers import BertTokenizer 
# Load the BioBERT tokenizer. The checkpoint is *cased*, so lowercasing must be off;
# with do_lower_case=True, inputs would no longer match the cased vocabulary
# the model was pre-trained with.
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1',
                                          do_lower_case=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [None]:
def encode_pairs(premises, hypotheses, max_length=500):
    """Tokenize aligned premise/hypothesis lists as sentence pairs for BERT.

    Special tokens are added, each pair is truncated to ``max_length`` tokens, and all
    pairs in the call are padded to the longest encoded pair. Returns the tokenizer's
    encoding (``input_ids``, ``token_type_ids``, ``attention_mask``).
    """
    # NOTE(review): padding=True pads every example to the longest pair in the whole split,
    # so most batches carry many pad tokens. Dynamic per-batch padding (a DataLoader
    # collate_fn) would be much faster but needs the Dataset/DataLoader cells changed too.
    return tokenizer(premises, hypotheses,
                     add_special_tokens=True,
                     max_length=max_length,
                     truncation=True, padding=True)

val_tokens = encode_pairs(val_a, val_b)
train_tokens = encode_pairs(train_a, train_b)

In [None]:
# Store each split's integer targets alongside its token encodings under the 'labels' key.
for split_tokens, split_label_ids in ((train_tokens, train_labels_encoding),
                                      (val_tokens, val_labels_encoding)):
    split_tokens['labels'] = split_label_ids

In [None]:
from torch.utils.data import Dataset, DataLoader

class MnliDataset(Dataset):
    """Map-style dataset over one tokenized MNLI split.

    ``encodings`` must support subscript access by key (a plain ``dict`` or a tokenizer
    ``BatchEncoding``) and hold one per-example sequence/value under each of
    ``input_ids``, ``attention_mask``, ``token_type_ids`` and ``labels``.
    """

    # Keys emitted for every example; each value is converted to an int64 tensor.
    FIELDS = ('input_ids', 'attention_mask', 'token_type_ids', 'labels')

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(self.encodings[key][idx], dtype=torch.long)
                for key in self.FIELDS}

    def __len__(self):
        # Subscript (not attribute) access so plain dicts work as well as BatchEncoding.
        return len(self.encodings['input_ids'])

# Wrap the tokenized splits so a DataLoader can batch them.
train_dataset, val_dataset = MnliDataset(train_tokens), MnliDataset(val_tokens)

In [None]:
from transformers import BertForSequenceClassification

# Load BioBERT with a sequence-classification head. The head size is taken from the
# label mapping so the number of outputs cannot drift from the label encoding.
model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1",
                                                      num_labels=len(label_encode))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=313.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435780550.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

In [None]:
from torch.utils.data import DataLoader
from transformers import AdamW

# Train on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)
model.train()

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# transformers.AdamW is deprecated (and dropped in recent releases). torch.optim.AdamW with
# eps=1e-6 and weight_decay=0.0 reproduces the defaults the transformers optimizer used
# (torch's own defaults are eps=1e-8, weight_decay=0.01).
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
# Train for 5 epochs over the shuffled training set.
# NOTE(review): val_dataset is built earlier but no validation pass runs here — consider
# evaluating it after each epoch.
# This cell ends with model.eval(), so re-enter training mode (dropout on) explicitly;
# otherwise re-running the cell would silently train in eval mode.
model.train()
for epoch in range(5):
    for batch in Bar(train_loader):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device, dtype=torch.long)
        attention_mask = batch['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        labels = batch['labels'].to(device, dtype=torch.long)

        # Passing labels makes the model return the classification loss in outputs.loss.
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels)
        loss = outputs.loss
        loss.backward()
        optim.step()
# Disable dropout for whatever follows; the trailing ';' suppresses the model repr output.
model.eval();

 30112/392702: [==>.............................] - ETA 10306.1sBuffered data was truncated after reaching the output size limit.

In [None]:
# Save model and optimizer state so training can be resumed or the model reloaded.
torch.save({
            'epoch': 5,  # number of completed training epochs
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            # Loss of the final mini-batch, detached from the autograd graph and moved to
            # CPU so the checkpoint stores a plain value rather than a graph-bearing tensor.
            'loss': loss.detach().cpu(),
            }, 'checkpoint_mnli_5epochs_seed.pt')

In [None]:
#from google.colab import files
#files.download('checkpoint_mnli_5epochs_seed.pt') 