<a href="https://colab.research.google.com/github/antgr/pytorch-nli/blob/master/1b_With_Bert_Simple_NLI_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# NOTE(review): installs the unpinned transformers master branch; the recorded run
# built version 2.1.1 from git. Pin a release/commit that provides the APIs used
# below (e.g. tokenizer.cls_token_id) so the notebook is reproducible.
!pip install git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-vp0dk5aj
  Running command git clone -q https://github.com/huggingface/transformers /tmp/pip-req-build-vp0dk5aj
Building wheels for collected packages: transformers
  Building wheel for transformers (setup.py) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-2.1.1-cp36-none-any.whl size=327209 sha256=4416d4d6a54a25db70598d6528e9c90fb5492cd73b343078ec61c35d25ede81c
  Stored in directory: /tmp/pip-ephem-wheel-cache-vz51xvm8/wheels/70/d3/52/b3fa4f8b8ef04167ac62e5bb2accb62ae764db2a378247490e
Successfully built transformers


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchtext import data
from torchtext import datasets

import random
import numpy as np

import time


# Seed the Python, NumPy and PyTorch RNGs so runs are repeatable.
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN to use deterministic algorithms (may be slower).
torch.backends.cudnn.deterministic = True


In [3]:
from transformers import BertTokenizer

# Tokenizer matching the pretrained 'bert-base-uncased' checkpoint loaded later.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
# Size of the pretrained tokenizer's vocabulary.
len(tokenizer.vocab)


30522

In [5]:
# BERT's special tokens as strings: sequence start ([CLS]), separator used as the
# end-of-sequence marker ([SEP]), padding and unknown.
init_token, eos_token, pad_token, unk_token = (
    tokenizer.cls_token,
    tokenizer.sep_token,
    tokenizer.pad_token,
    tokenizer.unk_token,
)
print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [6]:
# Vocabulary ids of the same special tokens. TEXT below works directly on ids
# (use_vocab=False), so these ids, not the strings, are what it needs.
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
    tokenizer.unk_token_id,
)
print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [0]:
# Maximum input length registered for 'bert-base-uncased'; tokenize_and_cut uses
# it to truncate sentences.
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

In [0]:
def tokenize_and_cut(sentence):
    """Tokenize ``sentence`` with the BERT tokenizer and truncate the result.

    Two positions are kept free for the start/end special tokens that the TEXT
    field adds (init_token / eos_token), so the final sequence fits in
    ``max_input_length``.
    """
    budget = max_input_length - 2  # reserve room for the two special tokens
    return tokenizer.tokenize(sentence)[:budget]

In [0]:
from torchtext import data

# Field for premise/hypothesis text. Sentences are tokenized with the BERT
# tokenizer and converted to BERT ids by `preprocessing`. No torchtext vocab is
# built (use_vocab=False), so the special tokens are given directly as ids.
# batch_first=True -> batch tensors are [batch size, sent len].
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)


# Labels arrive as strings (e.g. 'contradiction'). LABEL.build_vocab maps them
# to integer class indices stored as torch.long.
LABEL = data.LabelField(dtype = torch.long)

In [0]:
# Load the SNLI train/validation/test splits using the TEXT and LABEL fields.
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

In [0]:
# Report the number of examples in each split, one per line.
print(f"Number of training examples: {len(train_data)}",
      f"Number of validation examples: {len(valid_data)}",
      f"Number of testing examples: {len(test_data)}",
      sep="\n")

Number of training examples: 549367
Number of validation examples: 9842
Number of testing examples: 9824


In [0]:
# Hypothesis of training example 6 as BERT token ids (TEXT's preprocessing output).
vars(train_data.examples[6])['hypothesis']

[1996, 2879, 17260, 2015, 2091, 1996, 11996, 1012]

In [0]:
# NOTE(review): identical to the previous cell. It was probably meant to decode the
# ids, e.g. tokenizer.convert_ids_to_tokens(vars(train_data.examples[6])['hypothesis']).
vars(train_data.examples[6])['hypothesis']

[1996, 2879, 17260, 2015, 2091, 1996, 11996, 1012]

In [0]:
# Premise of training example 6 as BERT token ids.
vars(train_data.examples[6])['premise']

[1037,
 2879,
 2003,
 8660,
 2006,
 17260,
 6277,
 1999,
 1996,
 2690,
 1997,
 1037,
 2417,
 2958,
 1012]

In [0]:
# Gold label of training example 6 (still a string before LABEL.build_vocab).
vars(train_data.examples[6])['label']

'contradiction'

In [0]:
# Map label strings to class indices using the labels seen in the training split.
LABEL.build_vocab(train_data)

In [0]:
# Index -> label string mapping; its length is the model's OUTPUT_DIM.
print(LABEL.vocab.itos)

In [0]:
# A full processed example: premise/hypothesis id lists plus the label string.
print(vars(train_data.examples[0]))

{'premise': [1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012], 'hypothesis': [1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012], 'label': 'neutral'}


In [0]:
BATCH_SIZE = 512

# Use the GPU when available; the iterators place batch tensors on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One bucketed iterator per split. Sequences in a batch are padded with
# TEXT's pad_token (pad_token_idx).
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

In [0]:
# NOTE(review): BertTokenizer was already imported earlier; only BertModel is new here.
from transformers import BertTokenizer, BertModel

# Pretrained encoder for the same 'bert-base-uncased' checkpoint as the tokenizer.
bert = BertModel.from_pretrained('bert-base-uncased')

In [0]:
class NLISum(nn.Module):
    """NLI classifier over frozen BERT token outputs pooled by a padding-aware sum.

    Each sentence is encoded by ``bert`` under ``torch.no_grad()``. The token
    vectors at non-padding positions are summed, then projected to
    ``hidden_dim`` by a shared linear layer + ReLU. The premise and hypothesis
    vectors are concatenated and passed through ``fc_layers`` blocks
    (Linear -> ReLU -> Dropout) and a final linear layer that produces class
    logits.

    Args:
        bert: encoder module. It must expose ``config.to_dict()['hidden_size']``,
            accept an ``attention_mask=`` keyword and return a sequence whose
            first element is the per-token output
            ``[batch size, sent len, hidden size]``.
        hidden_dim: size of each projected sentence vector.
        fc_layers: number of hidden fully-connected blocks.
        output_dim: number of output classes.
        dropout: dropout probability applied after each hidden block.
        PAD_IDX: token id used for padding. Those positions are excluded from
            BERT attention and from the sum.
    """

    def __init__(self,
                 bert,
                 hidden_dim,
                 fc_layers,
                 output_dim,
                 dropout,
                 PAD_IDX):

        super().__init__()

        self.bert = bert
        self.pad_idx = PAD_IDX

        self.hidden_dim = hidden_dim
        self.embedding_dim = bert.config.to_dict()['hidden_size']

        # Shared projection (BERT hidden size -> hidden_dim) for both sentences.
        self.translation = nn.Linear(self.embedding_dim, self.hidden_dim)

        self.fcs = nn.ModuleList([nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2)
                                  for _ in range(fc_layers)])

        self.fc_out = nn.Linear(self.hidden_dim * 2, output_dim)

        self.dropout = nn.Dropout(dropout)

    def _encode(self, tokens):
        """Sum BERT outputs over non-padding positions.

        tokens = [batch size, sent len]  ->  [batch size, embedding dim]
        """
        mask = tokens != self.pad_idx  # [batch size, sent len]
        # BERT is a frozen feature extractor here: no gradients flow into it.
        with torch.no_grad():
            embedded = self.bert(tokens, attention_mask=mask.long())[0]
        # Zero padded positions so the amount of batch padding does not change the sum.
        return (embedded * mask.unsqueeze(-1).to(embedded.dtype)).sum(dim=1)

    def forward(self, prem, hypo):
        """Return class logits ``[batch size, output dim]``.

        prem = [batch size, prem sent len]
        hypo = [batch size, hypo sent len]
        """
        translated_prem = F.relu(self.translation(self._encode(prem)))
        translated_hypo = F.relu(self.translation(self._encode(hypo)))
        # translated_prem / translated_hypo = [batch size, hidden dim]

        hidden = torch.cat((translated_prem, translated_hypo), dim=1)
        # hidden = [batch size, hidden dim * 2]

        for fc in self.fcs:
            hidden = self.dropout(F.relu(fc(hidden)))

        return self.fc_out(hidden)

In [0]:
class NLIRNN(nn.Module):
    """NLI classifier: an LSTM over frozen BERT token outputs, one sentence at a time.

    Premise and hypothesis are each encoded by ``bert`` (no gradients) and then
    by the same LSTM. Each sentence's final LSTM hidden state, taken at its last
    non-padding token, is projected by a shared ``translation`` layer + ReLU.
    The premise and hypothesis vectors are concatenated and classified by
    ``fc_layers`` blocks (Linear -> ReLU -> Dropout) and an output layer.

    Args:
        bert: encoder module. It must expose ``config.to_dict()['hidden_size']``,
            accept an ``attention_mask=`` keyword and return a sequence whose
            first element is the per-token output
            ``[batch size, sent len, hidden size]``.
        hidden_dim: LSTM hidden size and projected sentence-vector size.
        fc_layers: number of hidden fully-connected blocks.
        output_dim: number of output classes.
        dropout: dropout probability applied after each hidden block.
        PAD_IDX: token id used for padding. It is masked in BERT attention and
            excluded from the LSTM via packing.
    """

    def __init__(self,
                 bert,
                 hidden_dim,
                 fc_layers,
                 output_dim,
                 dropout,
                 PAD_IDX):

        super().__init__()

        self.bert = bert
        self.pad_idx = PAD_IDX

        self.hidden_dim = hidden_dim
        self.embedding_dim = bert.config.to_dict()['hidden_size']

        # Projects the LSTM's final hidden state (hidden_dim -> hidden_dim).
        self.translation = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Single-layer, unidirectional, sequence-first LSTM (fed packed sequences below).
        self.rnn = nn.LSTM(self.embedding_dim, self.hidden_dim)

        self.fcs = nn.ModuleList([nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2)
                                  for _ in range(fc_layers)])

        self.fc_out = nn.Linear(self.hidden_dim * 2, output_dim)

        self.dropout = nn.Dropout(dropout)

    def _encode(self, tokens):
        """Final LSTM hidden state for each sentence.

        tokens = [batch size, sent len]  ->  [batch size, hidden dim]
        """
        mask = tokens != self.pad_idx  # [batch size, sent len]
        # BERT is a frozen feature extractor here: no gradients flow into it.
        with torch.no_grad():
            embedded = self.bert(tokens, attention_mask=mask.long())[0]
        # embedded = [batch size, sent len, embedding dim]

        # Pack so the LSTM stops at each sentence's last real token instead of
        # running over batch padding. Assumes padding is trailing (TEXT does not
        # set pad_first — TODO confirm if the field changes). clamp guards
        # all-padding rows, and lengths must be a CPU tensor.
        lengths = mask.sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.rnn(packed)
        # hidden = [1, batch size, hidden dim], already restored to the input batch order.
        return hidden[-1]

    def forward(self, prem, hypo):
        """Return class logits ``[batch size, output dim]``.

        prem = [batch size, prem sent len]
        hypo = [batch size, hypo sent len]
        """
        # Keep the premise and hypothesis states separate (previously the
        # hypothesis state overwrote the premise one).
        encoded_prem = self._encode(prem)
        encoded_hypo = self._encode(hypo)

        translated_prem = F.relu(self.translation(encoded_prem))
        translated_hypo = F.relu(self.translation(encoded_hypo))
        # translated_prem / translated_hypo = [batch size, hidden dim]

        hidden = torch.cat((translated_prem, translated_hypo), dim=1)
        # hidden = [batch size, hidden dim * 2]

        for fc in self.fcs:
            hidden = self.dropout(F.relu(fc(hidden)))

        return self.fc_out(hidden)

In [0]:
class NLISumRNN(nn.Module):
    """NLI classifier combining two views of frozen BERT outputs per sentence.

    For the premise and the hypothesis separately, it computes:
      (a) the padding-aware sum of BERT token vectors, projected by
          ``translation`` + ReLU;
      (b) the final hidden state of an LSTM run over the BERT token vectors,
          projected by ``translation1`` + ReLU.
    The four vectors are concatenated as
    ``[sum(prem), sum(hypo), lstm(prem), lstm(hypo)]`` and classified by
    ``fc_layers`` blocks (Linear -> ReLU -> Dropout) and an output layer.

    Args:
        bert: encoder module. It must expose ``config.to_dict()['hidden_size']``,
            accept an ``attention_mask=`` keyword and return a sequence whose
            first element is the per-token output
            ``[batch size, sent len, hidden size]``.
        hidden_dim: LSTM hidden size and size of each projected vector.
        fc_layers: number of hidden fully-connected blocks.
        output_dim: number of output classes.
        dropout: dropout probability applied after each hidden block.
        PAD_IDX: token id used for padding. It is masked in BERT attention,
            excluded from the sum and excluded from the LSTM via packing.
    """

    def __init__(self,
                 bert,
                 hidden_dim,
                 fc_layers,
                 output_dim,
                 dropout,
                 PAD_IDX):

        super().__init__()

        self.bert = bert
        self.pad_idx = PAD_IDX

        self.hidden_dim = hidden_dim
        self.embedding_dim = bert.config.to_dict()['hidden_size']

        # Projects the LSTM's final hidden state (hidden_dim -> hidden_dim).
        self.translation1 = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Projects the summed BERT outputs (embedding_dim -> hidden_dim).
        self.translation = nn.Linear(self.embedding_dim, self.hidden_dim)

        # Single-layer, unidirectional, sequence-first LSTM (fed packed sequences below).
        self.rnn = nn.LSTM(self.embedding_dim, self.hidden_dim)

        self.fcs = nn.ModuleList([nn.Linear(self.hidden_dim * 4, self.hidden_dim * 4)
                                  for _ in range(fc_layers)])

        self.fc_out = nn.Linear(self.hidden_dim * 4, output_dim)

        self.dropout = nn.Dropout(dropout)

    def _encode(self, tokens):
        """Return ``(masked sum of BERT outputs, final LSTM hidden state)``.

        tokens = [batch size, sent len]
        returns ([batch size, embedding dim], [batch size, hidden dim])
        """
        mask = tokens != self.pad_idx  # [batch size, sent len]
        # BERT is a frozen feature extractor here: no gradients flow into it.
        with torch.no_grad():
            embedded = self.bert(tokens, attention_mask=mask.long())[0]
        # embedded = [batch size, sent len, embedding dim]

        # Zero padded positions so the amount of batch padding does not change the sum.
        summed = (embedded * mask.unsqueeze(-1).to(embedded.dtype)).sum(dim=1)

        # Pack so the LSTM stops at each sentence's last real token. Assumes
        # trailing padding (TEXT does not set pad_first — TODO confirm if the
        # field changes). clamp guards all-padding rows, and lengths must be a
        # CPU tensor.
        lengths = mask.sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.rnn(packed)
        # hidden = [1, batch size, hidden dim], already restored to the input batch order.
        return summed, hidden[-1]

    def forward(self, prem, hypo):
        """Return class logits ``[batch size, output dim]``.

        prem = [batch size, prem sent len]
        hypo = [batch size, hypo sent len]
        """
        # Keep the premise and hypothesis LSTM states separate (previously the
        # hypothesis state overwrote the premise one).
        summed_prem, last_prem = self._encode(prem)
        summed_hypo, last_hypo = self._encode(hypo)

        translated_prem = F.relu(self.translation(summed_prem))
        translated_hypo = F.relu(self.translation(summed_hypo))
        translated_prem1 = F.relu(self.translation1(last_prem))
        translated_hypo1 = F.relu(self.translation1(last_hypo))
        # each translated_* = [batch size, hidden dim]

        # Same layout as before: [sum prem, sum hypo, lstm prem, lstm hypo].
        hidden = torch.cat((translated_prem, translated_hypo,
                            translated_prem1, translated_hypo1), dim=1)
        # hidden = [batch size, hidden dim * 4]

        for fc in self.fcs:
            hidden = self.dropout(F.relu(fc(hidden)))

        return self.fc_out(hidden)

In [0]:
# Model hyper-parameters.
HIDDEN_DIM = 300
FC_LAYERS = 3
OUTPUT_DIM = len(LABEL.vocab)   # one logit per label class
DROPOUT = 0.25
PAD_IDX = pad_token_idx         # BERT's [PAD] id, the same id TEXT pads with

model = NLISumRNN(bert=bert,
                  hidden_dim=HIDDEN_DIM,
                  fc_layers=FC_LAYERS,
                  output_dim=OUTPUT_DIM,
                  dropout=DROPOUT,
                  PAD_IDX=PAD_IDX)

In [16]:
def count_parameters(model):
    """Return the total number of elements in parameters that require gradients."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(p.numel() for p in trainable)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 115,414,443 trainable parameters


In [0]:
# Freeze BERT. Parameters of the registered `bert` submodule have names starting
# with 'bert', so after this only the classifier head (translation/rnn/fc layers)
# is trainable.
for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

In [18]:
# Re-count trainable parameters now that BERT is frozen. `count_parameters` is
# already defined in an earlier cell; redefining it here only risks the two
# definitions drifting apart.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 5,932,203 trainable parameters


In [19]:
# List the parameters that are still trainable after freezing BERT.
for name, param in model.named_parameters():                
    if param.requires_grad:
        print(name)

translation1.weight
translation1.bias
translation.weight
translation.bias
rnn.weight_ih_l0
rnn.weight_hh_l0
rnn.bias_ih_l0
rnn.bias_hh_l0
fcs.0.weight
fcs.0.bias
fcs.1.weight
fcs.1.bias
fcs.2.weight
fcs.2.bias
fc_out.weight
fc_out.bias


In [0]:
# NOTE(review): Adam receives every parameter, including the frozen BERT ones.
# Those never get a .grad, so Adam skips them. Passing only parameters with
# requires_grad=True would make this explicit.
optimizer = optim.Adam(model.parameters())

In [0]:
# Multi-class cross-entropy on raw logits (the model's output layer applies no softmax).
criterion = nn.CrossEntropyLoss()

In [0]:
# Move the model (including its BERT submodule) and the loss to the selected device.
model = model.to(device)
criterion = criterion.to(device)

In [0]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8

    Args:
        preds: [batch size, n classes] scores/logits.
        y: [batch size] integer class labels.

    Returns:
        A 0-dim float tensor on the same device as ``preds``.
    """
    max_preds = preds.argmax(dim = 1)  # index of the highest score per row
    correct = max_preds.eq(y)
    # mean() stays on the input's device. The previous division by a CPU
    # FloatTensor mixed devices when the batch lived on the GPU.
    return correct.float().mean()

In [0]:
def train(model, iterator, optimizer, criterion):
    """Run one training epoch over ``iterator``.

    Every batch must provide ``premise``, ``hypothesis`` and ``label`` tensors.
    Returns ``(mean loss, mean accuracy)``, each averaged over batches, so every
    batch counts equally regardless of its size.
    """
    total_loss, total_acc = 0, 0

    model.train()

    for batch in iterator:
        optimizer.zero_grad()

        # premise / hypothesis = [batch size, sent len] (TEXT uses batch_first=True)
        logits = model(batch.premise, batch.hypothesis)  # [batch size, output dim]
        targets = batch.label                             # [batch size]

        loss = criterion(logits, targets)
        acc = categorical_accuracy(logits, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += acc.item()

    n_batches = len(iterator)
    return total_loss / n_batches, total_acc / n_batches

In [0]:
def evaluate(model, iterator, criterion):
    """Score ``model`` on every batch of ``iterator`` without updating it.

    Runs in eval mode and under ``torch.no_grad()``. Returns
    ``(mean loss, mean accuracy)``, each averaged over batches.
    """
    running_loss = 0
    running_acc = 0

    model.eval()

    with torch.no_grad():
        for batch in iterator:
            logits = model(batch.premise, batch.hypothesis)
            targets = batch.label

            running_loss += criterion(logits, targets).item()
            running_acc += categorical_accuracy(logits, targets).item()

    n_batches = len(iterator)
    return running_loss / n_batches, running_acc / n_batches

In [0]:
def epoch_time(start_time, end_time):
    """Split the seconds between two timestamps into whole ``(minutes, seconds)``."""
    seconds_total = end_time - start_time
    whole_minutes = int(seconds_total / 60)
    leftover_seconds = int(seconds_total - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [0]:
N_EPOCHS = 1

# Lowest validation loss seen so far; the first epoch always beats it.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Checkpoint whenever validation loss improves. model.state_dict() also holds
    # the frozen BERT submodule (~109M of the ~115M parameters), so the file is large.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

In [0]:
# Evaluate the checkpoint with the lowest validation loss (written by the
# training loop), not whatever weights the final epoch left in `model`.
model.load_state_dict(torch.load('tut1-model.pt', map_location=device))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')