In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://packagecloud.io/github/git-lfs/pypi/simple


In [None]:
# Mount Google Drive; the dataset and checkpoint paths used below live under
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from transformers import BertTokenizer, BertModel
import torch.optim as optim
import pandas as pd
import numpy as np
from torchtext.legacy import data
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE(review): `max_len` is not referenced by any cell visible here; only
# `MAX_LEN` is read (by split_and_cut). Confirm whether it can be removed.
max_len = 128
# Maximum number of tokens split_and_cut keeps per field value.
MAX_LEN = 256

In [None]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def split_and_cut(sentence, max_len=None):
  """Split a whitespace-delimited string into tokens and truncate the result.

  Args:
    sentence: String whose tokens are separated by whitespace.
    max_len: Maximum number of tokens to keep. When omitted, the
      notebook-level ``MAX_LEN`` is used.

  Returns:
    A list of at most ``max_len`` non-empty token strings. An empty or
    all-whitespace input gives an empty list.
  """
  if max_len is None:
    max_len = MAX_LEN
  # split() without an argument collapses runs of whitespace, so a doubled
  # space can no longer produce an empty-string token (which would make
  # int('') raise in convert_to_int and add a spurious token elsewhere).
  return sentence.split()[:max_len]

def convert_to_int(tok_ids):
  """Cast every element of ``tok_ids`` to ``int`` and return them as a new list."""
  return list(map(int, tok_ids))

# Ids of the special tokens, as reported by the loaded tokenizer.
cls_token_idx, sep_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
    tokenizer.unk_token_id,
)

In [None]:
# Token sequence: the column is split into token strings by split_and_cut and
# each string is mapped to its vocabulary id by the tokenizer. Batches are
# padded with the tokenizer's own pad id.
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)
# Attention mask: the column holds integers stored as text.
# Padded positions must be masked out, so the pad value is a literal 0. The
# tokenizer's pad id is a vocabulary index, not a mask value; using it here
# only worked because that id is 0 for bert-base-uncased.
ATTENTION = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = 0)
# Token type (segment) ids: the column holds integers stored as text.
# NOTE(review): padding uses segment id 1 rather than 0. Left unchanged so the
# inputs stay identical to what the saved checkpoint was presumably trained
# on - confirm before changing.
TTYPE = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = 1)

# Gold label; its vocabulary is built later with LABEL.build_vocab.
LABEL = data.LabelField()

In [None]:
# (attribute name, Field) pairs; the attribute names are what the batches
# expose later (batch.tokens, batch.attention_sent, ...).
fields = [
    ('tokens', TEXT),
    ('attention_sent', ATTENTION),
    ('token_ids', TTYPE),
    ('gold_label', LABEL),
]

In [None]:
# CSV file for each split, keyed by the keyword TabularDataset.splits takes.
split_files = {
    'train': 'updated_train.csv',
    'validation': 'updated_val.csv',
    'test': 'updated_test.csv',
}

# Load the three splits from Drive; the header row of each CSV is skipped.
train_data, valid_data, test_data = data.TabularDataset.splits(
    path = '/content/drive/MyDrive/snli_1.0/snli_1.0/',
    format = 'csv',
    fields = fields,
    skip_header = True,
    **split_files)

In [None]:
# Build the label vocabulary from the training split only.
LABEL.build_vocab(train_data)

In [None]:
BATCH_SIZE = 32

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# One iterator per split. Length-based sorting is switched off both globally
# and within batches, and every batch is created on `device`.
all_splits = (train_data, valid_data, test_data)
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    all_splits,
    batch_size = BATCH_SIZE,
    sort = False,
    sort_within_batch = False,
    device = device)

In [None]:
# Checkpoint locations on the mounted Drive.
# NOTE(review): PATH is not used by any cell visible here; PATH2 is the
# checkpoint that gets loaded into nli_model.
PATH = '/content/drive/MyDrive/snli_1.0/latest_10000'
PATH2 = "/content/drive/MyDrive/snli_1.0/final_model"

In [None]:
# NOTE(review): `nli_model` is not defined in any cell visible here; it must
# already exist in the kernel (model class + instantiation) before this runs.
# Put the model on the same device the iterators create batches on, and remap
# the checkpoint's tensors to that device so a file saved from a GPU session
# also loads on a CPU-only machine.
nli_model.to(device)
nli_model.load_state_dict(torch.load(PATH2, map_location=device))

<All keys matched successfully>

In [None]:
def get_label(pred_y):
  """Return the index of the highest score along axis 1 of ``pred_y``.

  Args:
    pred_y: Tensor of scores with at least two dimensions
      (assumed to be (batch, n_classes) - confirm against the model).

  Returns:
    NumPy integer array holding, for each row, the position of its maximum.
  """
  # .cpu() is a no-op for CPU tensors and is required before .numpy() when
  # the tensor lives on a GPU, where .numpy() would otherwise raise.
  scores = pred_y.detach().cpu().numpy()
  return np.argmax(scores, axis=1)

In [None]:
def calc_acc(pred_y, true_y, corr_pred):
  """Add this batch's number of correct predictions to a running count.

  Args:
    pred_y: Tensor of scores; the predicted class is the argmax along dim 1.
    true_y: Tensor of gold class indices, one per row of ``pred_y``.
    corr_pred: Number of correct predictions accumulated so far.

  Returns:
    ``corr_pred`` plus the number of rows whose argmax equals the gold index.
  """
  # Compare as tensors so this works wherever the batch lives (CPU or GPU);
  # calling .numpy() directly on a CUDA tensor would raise.
  pred_labels = pred_y.detach().argmax(dim=1)
  gold = true_y.detach().to(pred_labels.device)
  # .item() yields a plain Python int, so only one scalar leaves the device.
  n_correct = (pred_labels == gold).sum().item()

  return corr_pred + n_correct

In [None]:
def get_predictions(iterator, data):
    """Compute the accuracy of ``nli_model`` over every batch of ``iterator``.

    Despite its name, this returns a single accuracy value rather than the
    predictions themselves.

    Args:
      iterator: Iterable of batches exposing ``tokens``, ``attention_sent``,
        ``token_ids`` and ``gold_label`` attributes.
      data: Sized collection the iterator covers; only ``len(data)`` is used,
        so it must be non-empty. Inside this function the name shadows the
        ``torchtext.legacy.data`` module.

    Returns:
      Number of correct predictions divided by ``len(data)``.
    """
    corr_pred = 0

    # Inference only: force eval mode and skip autograd bookkeeping, then put
    # the model back into whatever mode it was in before the call.
    was_training = nli_model.training
    nli_model.eval()
    try:
      with torch.no_grad():
        for batch in iterator:

          sequence = batch.tokens
          attn_mask = batch.attention_sent
          token_type = batch.token_ids
          label = batch.gold_label

          prediction = nli_model(sequence, attn_mask, token_type)
          corr_pred = calc_acc(prediction, label, corr_pred)
    finally:
      nli_model.train(was_training)

    acc = corr_pred/len(data)
    return acc

In [None]:
# Accuracy of the loaded model on the training split.
train_acc = get_predictions(train_iterator, train_data)
print(train_acc)

0.9919797728633812


In [None]:
# Accuracy of the loaded model on the validation split.
val_acc = get_predictions(valid_iterator, valid_data)
print(val_acc)

0.8918918918918919


In [None]:
# Accuracy of the loaded model on the test split.
test_acc = get_predictions(test_iterator, test_data)
print(test_acc)

0.8934242671009772


In [None]:
# Label string -> class index, as built by LABEL.build_vocab.
dct = LABEL.vocab.stoi

# Inverse mapping (class index -> label string), used to turn predicted
# indices back into readable label names.
reverse_dct = {index: label_name for label_name, index in dct.items()}

In [None]:
def get_sentence(lst):
  """Join the strings in ``lst`` into a single space-separated string."""
  return " ".join(lst)

In [None]:
# Print the decoded input and the predicted label for the first few test batches.
N_PREVIEW_BATCHES = 2  # stop after this many batches
cnt = 0                # number of examples printed so far

# Inference only: eval mode and no autograd bookkeeping.
nli_model.eval()
with torch.no_grad():
  for batch_idx, batch in enumerate(test_iterator):

    sequence = batch.tokens
    attn_mask = batch.attention_sent
    token_type = batch.token_ids

    predictions = nli_model(sequence, attn_mask, token_type)
    pred_labels = get_label(predictions)

    for token_ids, pred in zip(sequence, pred_labels):
      # tolist() hands the tokenizer plain Python ints, wherever the tensor lives.
      sentence = get_sentence(tokenizer.convert_ids_to_tokens(token_ids.tolist()))
      print(sentence, end=" - ")
      print(reverse_dct[pred])

    cnt += len(batch)

    # Count batches explicitly; the previous test (cnt >= 2*len(batch)) depended
    # on the size of whichever batch happened to be current. Checking at the
    # bottom avoids pulling an extra batch from the iterator.
    if batch_idx + 1 >= N_PREVIEW_BATCHES:
      break

[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church has cracks in the ceiling . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] - neutral
[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] the church is filled with song . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] - entailment
[CLS] this church choir sings to the masses as they sing joy ##ous songs from the book at a church . [SEP] a choir singing at a baseball game . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] - contradiction
[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is young . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] - neutral
[CLS] a woman with a green heads ##car ##f , blue shirt and a very big grin . [SEP] the woman is very happy . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 