In [15]:
import pandas as pd
import numpy as np
import torch
import json
import jsonlines
from pathlib import Path
import random
from tqdm import tqdm

In [2]:
def parse_mnli(path):
    sentences_a = []
    sentences_b = []
    labels = []
    with open(path, "r+", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            sentences_a.append(item['sentence1'])
            sentences_b.append(item['sentence2'])
            labels.append(item['gold_label'])
    
    return sentences_a,sentences_b,labels

In [3]:
val_a, val_b, val_labels = parse_mnli('./preproc_datasets/multinli_1.0_dev_matched.json')

In [4]:
label_encode = {'contradiction': 0,
                '-': 1,
                'neutral': 2,
                'entailment': 3}
val_labels_encoding = [label_encode[label] for label in val_labels]

In [5]:
from transformers import BertTokenizer
# Load the BERT tokenizer.
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1', 
                                          do_lower_case=True)

In [6]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1", num_labels = 4)
checkpoint = torch.load('checkpoint_mnli_3epochs_seed.pt',map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

<All keys matched successfully>

In [7]:
val_tokens = tokenizer(val_a,val_b, 
                       add_special_tokens=True,
                       max_length=500,
                       truncation=True, padding=True,return_tensors='pt')
val_tokens['labels'] = val_labels_encoding

In [8]:
val_tokens['input_ids']

tensor([[  101,  1103,  1207,  ...,     0,     0,     0],
        [  101,  1142,  1751,  ...,     0,     0,     0],
        [  101, 14863,   178,  ...,     0,     0,     0],
        ...,
        [  101,  1921,   117,  ...,     0,     0,     0],
        [  101,   172,  3161,  ...,     0,     0,     0],
        [  101,   178,   112,  ...,     0,     0,     0]])

In [None]:
outputs = model(**val_tokens)

In [16]:
val_predictions = []
for i in tqdm(range(len(val_a))):    
    inputs = tokenizer(val_a[i], val_b[i], 
                           add_special_tokens=True,
                           max_length=500,
                           truncation=True, padding=True,return_tensors='pt')
    output = model(**inputs)
    pred = torch.argmax(output.logits)
    val_predictions.append(int(pred))

100%|██████████| 10000/10000 [23:37<00:00,  7.06it/s]


In [20]:
with open('pred_mnli_3epochs.txt', 'w') as f:
    for item in val_predictions:
        f.write("%s\n" % item)

In [22]:
from sklearn.metrics import confusion_matrix

confusion_matrix(val_predictions,val_labels_encoding)

array([[2483,   42,  374,  188],
       [   0,    0,    0,    0],
       [ 485,   79, 2394,  519],
       [ 245,   64,  355, 2772]])