In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://packagecloud.io/github/git-lfs/pypi/simple


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# later in this notebook all live under that mount point.
drive.mount('/content/drive')

In [None]:
from transformers import BertTokenizer, BertModel
import torch.optim as optim
import pandas as pd
import numpy as np
from torchtext.legacy import data
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE(review): `max_len` is not referenced by any cell visible here; only
# MAX_LEN is used. Confirm nothing else needs it before removing.
max_len = 128
# Maximum number of tokens split_and_cut keeps per example.
MAX_LEN = 256

In [None]:
# Load the pretrained 'bert-base-uncased' tokenizer and encoder weights.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

In [None]:
def split_and_cut(sentence, max_len=None):
  """Split a whitespace-delimited string into tokens and truncate the result.

  Args:
    sentence: String whose tokens are separated by whitespace.
    max_len: Maximum number of tokens to keep. When None (the default, which
      is what a one-argument tokenizer call gets), the module-level MAX_LEN
      is used.

  Returns:
    A list of at most `max_len` non-empty token strings; [] for blank input.
  """
  if max_len is None:
    max_len = MAX_LEN
  # split() with no separator collapses runs of whitespace and returns [] for
  # blank input. split(" ") would instead produce '' entries, which make
  # int('') in convert_to_int raise ValueError.
  return sentence.split()[:max_len]

def convert_to_int(tok_ids):
  """Return a new list holding int(x) for every element x of `tok_ids`."""
  return list(map(int, tok_ids))

# Vocabulary ids of the tokenizer's special tokens ([CLS], [SEP], padding,
# unknown), read once so the Field definitions can reuse them.
cls_token_idx, sep_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
    tokenizer.unk_token_id,
)

In [None]:
# Token sequence: the raw string is split and truncated by split_and_cut, then
# the tokenizer maps each token to its vocabulary id. Batches are padded with
# the tokenizer's own pad id.
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)
# Attention mask: the raw string is split and each entry cast to int.
# Padded positions must be 0 so the encoder ignores them. Use the literal 0
# rather than the tokenizer's pad id: the two only coincide because BERT's pad
# id happens to be 0, and a tokenizer with a non-zero pad id would silently
# mark padding as "attend".
ATTENTION = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = 0)
# Token type (segment) ids: the raw string is split and each entry cast to int.
# NOTE(review): padding uses segment id 1, whereas Hugging Face tokenizers pad
# token_type_ids with 0. Padded positions are masked out by the attention
# mask, so this is left unchanged - confirm whether 1 is intentional.
TTYPE = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = split_and_cut,
                  preprocessing = convert_to_int,
                  pad_token = 1)

# Gold label column; its vocabulary is built separately from the training data.
LABEL = data.LabelField()

In [None]:
# (attribute name, Field) pairs handed to TabularDataset; the attribute names
# are how each column is accessed on a batch (batch.tokens, batch.gold_label, ...).
fields = [
    ('tokens', TEXT),
    ('attention_sent', ATTENTION),
    ('token_ids', TTYPE),
    ('gold_label', LABEL),
]

In [None]:
# Read the train / validation / test CSV files from the SNLI folder on Drive.
# The header row of every file is skipped and columns are parsed via `fields`.
train_data, valid_data, test_data = data.TabularDataset.splits(
    path='/content/drive/MyDrive/snli_1.0/snli_1.0/',
    format='csv',
    skip_header=True,
    fields=fields,
    train='updated_train.csv',
    validation='updated_val.csv',
    test='updated_test.csv',
)

In [None]:
# Build the label vocabulary from the training split only.
LABEL.build_vocab(train_data)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32

# One iterator per split, all yielding batches on `device`. Sorting is turned
# off both across and within batches, so no sort_key is supplied.
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    device=device,
    batch_size=BATCH_SIZE,
    sort=False,
    sort_within_batch=False,
)

In [None]:
class NLIModel(nn.Module):
  """Encoder followed by a single linear layer applied to its pooled output.

  Args:
    model: Encoder module. It must expose `config` with a hidden size and be
      callable with `input_ids`, `attention_mask` and `token_type_ids`
      keyword arguments.
    hidden_neurons1: Accepted for signature compatibility; not used.
    hidden_neurons2: Accepted for signature compatibility; not used.
    output_neurons: Number of output units of the linear layer.
  """

  def __init__(self, model, hidden_neurons1, hidden_neurons2, output_neurons):
    super().__init__()
    self.model = model
    # Width of the encoder's hidden representation sizes the linear layer.
    encoder_width = model.config.hidden_size
    self.dense = nn.Linear(encoder_width, output_neurons)

  def forward(self, tokens, attention_mask, token_type):
    """Run the encoder and map element 1 of its output through the linear layer.

    For a Hugging Face BertModel, element 1 of the output is the pooled
    representation.
    """
    encoder_out = self.model(
        input_ids=tokens,
        attention_mask=attention_mask,
        token_type_ids=token_type,
    )
    pooled = encoder_out[1]
    return self.dense(pooled)

In [None]:
# Wrap the pretrained encoder with a 3-output linear head. The two hidden-size
# arguments (1024, 128) are accepted by NLIModel but not used by it.
# The model is moved to `device` before the optimizer is built: the iterators
# place every batch on `device`, so a model left on the CPU fails with a
# device-mismatch error as soon as CUDA is available.
nli_model = NLIModel(bert_model, 1024, 128, 3).to(device)
opt = optim.Adam(nli_model.parameters(), lr=2e-5, eps=1e-6)
loss_fnc = nn.CrossEntropyLoss().to(device)

In [None]:
# Rolling checkpoint written during training; each save overwrites the last.
PATH = '/content/drive/MyDrive/snli_1.0/snli_1.0/latest_10000' ## Path where you want to store the model
# Destination of the model state saved once training has finished.
PATH2 = "/content/drive/MyDrive/snli_1.0/snli_1.0/final_model"

In [None]:
lossHist = []      # mean training loss of each epoch
tot_example = 0    # number of training examples seen so far, across epochs

NUM_EPOCHS = 6
CHECKPOINT_EVERY = 30000  # write a checkpoint about every this many examples
next_checkpoint = CHECKPOINT_EVERY

# from_pretrained() returns the encoder in evaluation mode (dropout disabled).
# Switch to training mode so fine-tuning actually uses dropout.
nli_model.train()

for epoch in range(NUM_EPOCHS):
  epoch_loss = 0.0
  num_batches = 0

  for batch in train_iterator:
    opt.zero_grad()
    # NOTE(review): emptying the CUDA cache on every step is kept from the
    # original; it slows training and can usually be dropped.
    torch.cuda.empty_cache()

    sequence = batch.tokens
    attn_mask = batch.attention_sent
    token_type = batch.token_ids
    label = batch.gold_label

    predictions = nli_model(sequence, attn_mask, token_type)
    loss = loss_fnc(predictions, label)

    loss.backward()
    opt.step()

    epoch_loss += loss.item()
    num_batches += 1
    tot_example += len(batch)

    # Checkpoint when the running count crosses the next threshold. An exact
    # `tot_example % 30000 == 0` test almost never fires: the count advances
    # by the batch size (32), and a partial final batch shifts it further.
    if tot_example >= next_checkpoint:
      torch.save(nli_model.state_dict(), PATH)
      next_checkpoint += CHECKPOINT_EVERY

  # Record the mean loss over the epoch instead of the last batch's loss,
  # which is a noisy 32-example estimate. Skip epochs with no batches.
  if num_batches:
    lossHist.append(epoch_loss / num_batches)

In [None]:
# Show the per-epoch loss values collected during training.
print(lossHist)

[0.31172987818717957, 0.15210573375225067, 0.1463630497455597, 0.02814902365207672, 0.020347703248262405, 0.1457570195198059]


In [None]:
# Persist the trained weights (state_dict only, not the full module) to PATH2.
torch.save(nli_model.state_dict(), PATH2)