Working on Natural Language Inference (NLI) using BERT. On [https://nlp.stanford.edu/projects/snli/]

The task is, give two sentences: a premise and a hypothesis, to classify the relation between them. We have three classes to describe this relationship.

1. Entailment: the hypothesis follows from the fact that the premise is true
2. Contradiction: the hypothesis contradicts the fact that the premise is true
3. Neutral: There is not relationship between premise and hypothesis

In [1]:
# Imports for most of the notebook
import torch
from transformers import BertModel
from transformers import AutoTokenizer
from typing import Dict, List
from util import get_snli

In [2]:
from datasets import load_dataset
dataset = load_dataset("snli")
print("Split sizes (num_samples, num_labels):\n", dataset.shape)
print("\nExample:\n", dataset['train'][0])

Downloading builder script:   0%|          | 0.00/3.82k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/14.1k [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/94.6M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/550152 [00:00<?, ? examples/s]

Dataset snli downloaded and prepared to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Split sizes (num_samples, num_labels):
 {'test': (10000, 3), 'validation': (10000, 3), 'train': (550152, 3)}

Example:
 {'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}


In [3]:
from util import get_snli

train_dataset, validation_dataset, test_dataset = get_snli()



  0%|          | 0/3 [00:00<?, ?it/s]

Reading Datapoints: 100%|██████████| 550152/550152 [00:18<00:00, 30413.83it/s]
Reading Datapoints: 100%|██████████| 10000/10000 [00:00<00:00, 30623.56it/s]
Reading Datapoints: 100%|██████████| 10000/10000 [00:00<00:00, 29938.31it/s]


In [4]:
## sub set stats
from collections import Counter

# num sample stats
print(len(train_dataset), len(validation_dataset), len(test_dataset))

# label distribution
print(Counter([t['label'] for t in train_dataset]))
print(Counter([t['label'] for t in validation_dataset]))
print(Counter([t['label'] for t in test_dataset]))

# We have a perfectly balanced dataset

9999 999 999
Counter({0: 3333, 1: 3333, 2: 3333})
Counter({0: 333, 1: 333, 2: 333})
Counter({0: 333, 1: 333, 2: 333})


In [5]:
class BatchTokenizer:
    """Tokenizes and pads a batch of input sentences."""

    def __init__(self):
        #Initializes the tokenizer
        
        self.hf_tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-small")
    
    def get_sep_token(self,):
        return self.hf_tokenizer.sep_token
    
    def __call__(self, prem_batch: List[str], hyp_batch: List[str]) -> List[List[str]]:
        """Uses the huggingface tokenizer to tokenize and pad a batch.

        We return a dictionary of tensors per the huggingface model specification.

        Args:
            batch (List[str]): A List of sentence strings

        Returns:
            Dict: The dictionary of token specifications provided by HuggingFace
        """
        # The HF tokenizer will PAD for us, and additionally combine 
        # The two sentences deimited by the [SEP] token.
        enc = self.hf_tokenizer(
            prem_batch,
            hyp_batch,
            padding=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )

        return enc
    

# Using tokenizer: example
tokenizer = BatchTokenizer()
x = tokenizer(*[["this is the premise.", "This is also a premise"], ["this is the hypothesis", "This is a second hypothesis"]])
print(x)
tokenizer.hf_tokenizer.batch_decode(x["input_ids"])

Downloading:   0%|          | 0.00/286 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

{'input_ids': tensor([[  101,  2023,  2003,  1996, 18458,  1012,   102,  2023,  2003,  1996,
         10744,   102,     0],
        [  101,  2023,  2003,  2036,  1037, 18458,   102,  2023,  2003,  1037,
          2117, 10744,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


['[CLS] this is the premise. [SEP] this is the hypothesis [SEP] [PAD]',
 '[CLS] this is also a premise [SEP] this is a second hypothesis [SEP]']

In [6]:
def generate_pairwise_input(dataset: List[Dict]) -> (List[str], List[str], List[int]):
    labels = []
    hyp = []
    prem = []
    for i in dataset:
        labels.append(i['label'])
        hyp.append(i['hypothesis'])
        prem.append(i['premise'])
    return (prem, hyp, labels)

In [7]:
train_premises, train_hypotheses, train_labels = generate_pairwise_input(train_dataset)
validation_premises, validation_hypotheses, validation_labels = generate_pairwise_input(validation_dataset)
test_premises, test_hypotheses, test_labels = generate_pairwise_input(test_dataset)

In [8]:
def chunk(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def chunk_multi(lst1, lst2, n):
    for i in range(0, len(lst1), n):
        yield lst1[i: i + n], lst2[i: i + n]
        
# Notice that since we use huggingface, we tokenize and
# encode in all at once
tokenizer = BatchTokenizer()
batch_size = 32
train_input_batches = [b for b in chunk_multi(train_premises, train_hypotheses, batch_size)]
# Tokenize + encode
train_input_batches = [tokenizer(*batch) for batch in train_input_batches]

In [9]:
def encode_labels(labels: List[int]) -> torch.FloatTensor:
    """Turns the batch of labels into a tensor

    Args:
        labels (List[int]): List of all labels in the batch

    Returns:
        torch.FloatTensor: Tensor of all labels in the batch
    """
    return torch.LongTensor([int(l) for l in labels])


train_label_batches = [b for b in chunk(train_labels, batch_size)]
train_label_batches = [encode_labels(batch) for batch in train_label_batches]

In [10]:
class NLIClassifier(torch.nn.Module):
    def __init__(self, output_size: int, hidden_size: int):
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        
        self.bert = BertModel.from_pretrained("prajjwal1/bert-small")
        
        self.bert_hidden_dimension = self.bert.config.hidden_size #512 
                
        self.hidden_layer = torch.nn.Linear(self.bert_hidden_dimension, self.hidden_size)
        
        self.relu = torch.nn.ReLU()

        self.classifier = torch.nn.Linear(self.hidden_size, self.output_size)
        
        self.log_softmax = torch.nn.LogSoftmax(dim=2)

    def encode_text(
        self,
        symbols: Dict
    ) -> torch.Tensor:
        """Encode the (batch of) sequence(s) of token symbols.
            Then, get the last (non-padded) hidden state for each symbol and return that.

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: The final hiddens state, which represents an encoding of
                the entire sentence
        """
        # First we get the contextualized embedding for each input symbol
        # We don't need an LSTM, since BERT encodes context and 
        # gives us a single vector describing the sequence in the form of the [CLS] token.
        encoded_sequence = self.bert(**symbols)
        
        cls_pooler_output = encoded_sequence[1]
        
        out_dimensions = torch.unsqueeze(cls_pooler_output, 1) # 300, 1, 512
        
        return out_dimensions 


    def forward(
        self,
        symbols: Dict,
    ) -> torch.Tensor:
        """_summary_

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: _description_
        """
        encoded_sents = self.encode_text(symbols)
        output = self.hidden_layer(encoded_sents)
        output = self.relu(output)
        output = self.classifier(output)
        return self.log_softmax(output)

In [11]:
# For making predictions at test time
def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:
    logits = model(sents)
    return list(torch.argmax(logits, axis=2).squeeze().numpy())

In [12]:
import numpy as np


def precision(predicted_labels, true_labels, which_label=1):
    """
    Precision is True Positives / All Positives Predictions
    """
    pred_which = np.array(predicted_labels) == which_label
    true_which = np.array(true_labels) == which_label
    denominator = np.sum(pred_which)
    if denominator:
        return np.sum(np.logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def recall(predicted_labels, true_labels, which_label=1):
    """
    Recall is True Positives / All Positive Labels
    """
    pred_which = np.array(predicted_labels) == which_label
    true_which = np.array(true_labels) == which_label
    denominator = np.sum(true_which)
    if denominator:
        return np.sum(np.logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def f1_score(
    predicted_labels: List[int],
    true_labels: List[int],
    which_label: int
    ):
    """
    F1 score is the harmonic mean of precision and recall
    """
    P = precision(predicted_labels, true_labels, which_label=which_label)
    R = recall(predicted_labels, true_labels, which_label=which_label)
    if P and R:
        return 2*P*R/(P+R)
    else:
        return 0.


def macro_f1(
    predicted_labels: List[int],
    true_labels: List[int],
    possible_labels: List[int]
    ):
    scores = [f1_score(predicted_labels, true_labels, l) for l in possible_labels]
    # Macro, so we take the uniform avg.
    return sum(scores) / len(scores)

In [18]:
import random 
from tqdm import tqdm_notebook as tqdm

def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    model,
    possible_label_values
):
    print("Training...")
    loss_func = torch.nn.NLLLoss()
    batches = list(zip(train_features, train_labels))
    random.shuffle(batches)
    for i in range(num_epochs):
        losses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            optimizer.zero_grad()
            preds = model(features).squeeze(1)
            loss = loss_func(preds, labels)
            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        
        print(f"epoch {i}, loss: {sum(losses)/len(losses)}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            pred = predict(model, sents)
            all_preds.extend(pred)
            all_labels.extend(list(labels.numpy()))

        dev_f1 = macro_f1(all_preds, all_labels, possible_label_values)
        print(f"Dev F1 {dev_f1}")
        
    # Return the trained model
    return model

In [19]:
epochs = 3

LR = 0.000025

possible_labels = len(set(train_labels))
possible_label_values = set(train_labels)

model = NLIClassifier(output_size=possible_labels, hidden_size=340)
optimizer = torch.optim.AdamW(model.parameters(), LR)

validation_input_batches = [b for b in chunk_multi(validation_premises, validation_hypotheses, batch_size)]

validation_encoded_sents = [tokenizer(*batch) for batch in validation_input_batches]
validation_batch_labels = [b for b in chunk(validation_labels, batch_size)]
validation_encoded_labels = [encode_labels(batch) for batch in validation_batch_labels]

training_loop(
    epochs,
    train_input_batches,
    train_label_batches,
    validation_encoded_sents,
    validation_encoded_labels,
    optimizer,
    model,
    possible_label_values
)

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for features, labels in tqdm(batches):


  0%|          | 0/313 [00:00<?, ?it/s]

epoch 0, loss: 1.064753565258873
Evaluating dev...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):


  0%|          | 0/32 [00:00<?, ?it/s]

Dev F1 0.6412745263093064


  0%|          | 0/313 [00:00<?, ?it/s]

epoch 1, loss: 0.73752792170063
Evaluating dev...


  0%|          | 0/32 [00:00<?, ?it/s]

Dev F1 0.7080991763420129


  0%|          | 0/313 [00:00<?, ?it/s]

epoch 2, loss: 0.5034639345714078
Evaluating dev...


  0%|          | 0/32 [00:00<?, ?it/s]

Dev F1 0.7183785291407564


NLIClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True

In [20]:
test_input_batches = [b for b in chunk_multi(test_premises, test_hypotheses, batch_size)]

test_encoded_sents = [tokenizer(*batch) for batch in test_input_batches]
test_batch_labels = [b for b in chunk(test_labels, batch_size)]
test_encoded_labels = [encode_labels(batch) for batch in test_batch_labels]

all_preds = []
all_labels = []

for sents, labels in tqdm(zip(test_encoded_sents, test_encoded_labels), total=len(test_encoded_sents)):
  pred = predict(model, sents)
  all_preds.extend(pred)
  all_labels.extend(list(labels.numpy()))

test_f1 = macro_f1(all_preds, all_labels, possible_label_values)
print(f"Test F1 {test_f1}")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for sents, labels in tqdm(zip(test_encoded_sents, test_encoded_labels), total=len(test_encoded_sents)):


  0%|          | 0/32 [00:00<?, ?it/s]

Test F1 0.7125177296109021
