Natural Language Inference (NLI) using BERT, on the SNLI dataset: [https://nlp.stanford.edu/projects/snli/]

The task is: given two sentences, a premise and a hypothesis, classify the relation between them. Three classes describe this relationship.

1. Entailment: the hypothesis follows from the fact that the premise is true
2. Contradiction: the hypothesis contradicts the fact that the premise is true
3. Neutral: There is no definite relationship between premise and hypothesis (the premise neither entails nor contradicts the hypothesis)

In [1]:
# Imports for most of the notebook
from typing import Dict, List, Mapping, Tuple

import torch
from transformers import AutoTokenizer
from transformers import BertModel

from util import get_snli

In [2]:
# Load SNLI through the HuggingFace `datasets` library and take a first look.
from datasets import load_dataset
dataset = load_dataset("snli")
# NOTE(review): the second number of each shape looks like the column count
# (premise, hypothesis, label) rather than a number of labels — confirm
# against the `datasets` documentation before trusting the printed header.
print("Split sizes (num_samples, num_labels):\n", dataset.shape)
# Show one raw training example: premise, hypothesis and integer label.
print("\nExample:\n", dataset['train'][0])

Found cached dataset snli (/home/jovyan/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Split sizes (num_samples, num_labels):
 {'test': (10000, 3), 'train': (550152, 3), 'validation': (10000, 3)}

Example:
 {'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}


In [3]:
# NOTE(review): get_snli is already imported in the first cell; this repeat
# is harmless but redundant.
from util import get_snli

# The result of get_snli() is unpacked as (train, validation, test).
# NOTE(review): the sizes printed in the next cell are far smaller than the
# 550152 rows read here, so get_snli presumably subsamples — confirm in util.
train_dataset, validation_dataset, test_dataset = get_snli()

Found cached dataset snli (/home/jovyan/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Reading Datapoints: 100%|██████████| 550152/550152 [00:15<00:00, 36654.53it/s]
Reading Datapoints: 100%|██████████| 10000/10000 [00:00<00:00, 40054.13it/s]
Reading Datapoints: 100%|██████████| 10000/10000 [00:00<00:00, 40218.47it/s]


In [4]:
## Subset statistics
from collections import Counter

splits = (train_dataset, validation_dataset, test_dataset)

# Number of samples per split, in train / validation / test order
print(*map(len, splits))

# Label distribution per split, one Counter per line
for split in splits:
    print(Counter(example['label'] for example in split))

# Every label occurs equally often, i.e. the dataset is perfectly balanced

9999 999 999
Counter({0: 3333, 1: 3333, 2: 3333})
Counter({0: 333, 1: 333, 2: 333})
Counter({0: 333, 1: 333, 2: 333})


In [5]:
class BatchTokenizer:
    """Tokenizes and pads a batch of premise/hypothesis sentence pairs."""

    def __init__(self, model_name: str = "prajjwal1/bert-small"):
        """Load the pretrained HuggingFace tokenizer.

        Args:
            model_name (str): Checkpoint name or path handed to
                ``AutoTokenizer.from_pretrained``. Defaults to the BERT-small
                checkpoint used throughout this notebook.
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_sep_token(self) -> str:
        """Return the separator token of the wrapped tokenizer."""
        return self.hf_tokenizer.sep_token

    def __call__(
        self, prem_batch: List[str], hyp_batch: List[str]
    ) -> Mapping[str, torch.Tensor]:
        """Uses the huggingface tokenizer to tokenize and pad a batch of pairs.

        The i-th premise is paired with the i-th hypothesis, so both lists
        are expected to have the same length.

        Args:
            prem_batch (List[str]): A List of premise sentence strings
            hyp_batch (List[str]): A List of hypothesis sentence strings

        Returns:
            Mapping[str, torch.Tensor]: The dictionary-like encoding provided
            by HuggingFace, holding ``input_ids`` and ``attention_mask``.
            ``token_type_ids`` are explicitly not requested.
        """
        # The HF tokenizer will PAD for us, and additionally combine
        # the two sentences delimited by the [SEP] token.
        return self.hf_tokenizer(
            prem_batch,
            hyp_batch,
            padding=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
    

# Demo: tokenize two premise/hypothesis pairs, then decode the ids back to text
tokenizer = BatchTokenizer()
demo_premises = ["this is the premise.", "This is also a premise"]
demo_hypotheses = ["this is the hypothesis", "This is a second hypothesis"]
x = tokenizer(demo_premises, demo_hypotheses)
print(x)
tokenizer.hf_tokenizer.batch_decode(x["input_ids"])

{'input_ids': tensor([[  101,  2023,  2003,  1996, 18458,  1012,   102,  2023,  2003,  1996,
         10744,   102,     0],
        [  101,  2023,  2003,  2036,  1037, 18458,   102,  2023,  2003,  1037,
          2117, 10744,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


['[CLS] this is the premise. [SEP] this is the hypothesis [SEP] [PAD]',
 '[CLS] this is also a premise [SEP] this is a second hypothesis [SEP]']

In [8]:
def generate_pairwise_input(
    dataset: List[Dict],
) -> Tuple[List[str], List[str], List[int]]:
    """Split a list of examples into parallel premise/hypothesis/label lists.

    Args:
        dataset (List[Dict]): Examples, each providing the keys
            ``'premise'``, ``'hypothesis'`` and ``'label'``.

    Returns:
        Tuple[List[str], List[str], List[int]]: ``(premises, hypotheses,
        labels)`` in dataset order; position ``i`` of every list comes from
        the same example.

    Raises:
        KeyError: If an example lacks one of the three keys.
    """
    premises, hypotheses, labels = [], [], []
    # A single pass keeps the three lists aligned and also works when
    # `dataset` can only be iterated once.
    for example in dataset:
        labels.append(example['label'])
        hypotheses.append(example['hypothesis'])
        premises.append(example['premise'])
    return premises, hypotheses, labels

In [9]:
# Turn every split into parallel premise / hypothesis / label lists.
(
    (train_premises, train_hypotheses, train_labels),
    (validation_premises, validation_hypotheses, validation_labels),
    (test_premises, test_hypotheses, test_labels),
) = (
    generate_pairwise_input(split)
    for split in (train_dataset, validation_dataset, test_dataset)
)

In [11]:
def chunk(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    starts = range(0, len(lst), n)
    for begin in starts:
        end = begin + n
        yield lst[begin:end]

def chunk_multi(lst1, lst2, n):
    """Yield aligned n-sized chunks from two parallel sequences.

    Args:
        lst1: First sequence (e.g. premises).
        lst2: Second sequence (e.g. hypotheses); must be as long as ``lst1``.
        n: Maximum number of items per chunk.

    Yields:
        Tuples ``(lst1[i:i + n], lst2[i:i + n])`` for ``i = 0, n, 2n, ...``.

    Raises:
        ValueError: If the two sequences differ in length. Raised when
            iteration starts, since this is a generator.
    """
    # Without this check a shorter lst2 would silently produce short or
    # empty chunks, and a longer lst2 would have its tail dropped.
    if len(lst1) != len(lst2):
        raise ValueError(
            f"chunk_multi needs equally long inputs, got {len(lst1)} and {len(lst2)}"
        )
    for start in range(0, len(lst1), n):
        stop = start + n
        yield lst1[start:stop], lst2[start:stop]
        
# With huggingface, tokenizing and encoding happen in a single call, so each
# premise/hypothesis chunk is turned into padded tensors in one step.
tokenizer = BatchTokenizer()
batch_size = 32
# Tokenize + encode, one batch_size-sized chunk at a time
train_input_batches = [
    tokenizer(premise_chunk, hypothesis_chunk)
    for premise_chunk, hypothesis_chunk in chunk_multi(
        train_premises, train_hypotheses, batch_size
    )
]

In [12]:
def encode_labels(labels: List[int]) -> torch.Tensor:
    """Turns the batch of labels into an integer tensor

    Args:
        labels (List[int]): List of all labels in the batch

    Returns:
        torch.Tensor: 1-D ``torch.long`` (int64) tensor holding the labels
        in their original order
    """
    # Each entry is passed through int() before the tensor is built, and the
    # dtype is set explicitly so the result is always int64.
    return torch.tensor([int(label) for label in labels], dtype=torch.long)


# Chunk the labels with the same batch_size as the inputs and encode each chunk.
train_label_batches = [
    encode_labels(label_chunk)
    for label_chunk in chunk(train_labels, batch_size)
]