In this notebook, we will show how to fine-tune BERT to perform **natural language inference (NLI)**. Although we will make a few simplifying assumptions, the model we train is conceptually very close to what NLP engineers and researchers implement in practice. The primary difference is that we will not fine-tune our model for very long, to keep the demo brief.

One of the takeaways from this notebook (and something that should be exciting to you!) is that pretraining plus fine-tuning lets us build a high-end NLI model with very little work. Indeed, this whole notebook is on the order of 20 to 30 lines of code, thanks to the `transformers` and `datasets` libraries.

As always, let's set up the runtime.

In [1]:
!pip install transformers datasets >/dev/null

In [2]:
from torch import cuda

device = 'cuda' if cuda.is_available() else 'cpu'

# Load Dataset

We'll use the Stanford NLI (SNLI) dataset, a well-known benchmark for the NLI task. Its training split contains roughly 550k sentence pairs, far more than we can get through in a short demo, so we'll discard the training split entirely and train on the much smaller validation split instead.

In [3]:
import datasets

dataset = datasets.load_dataset('snli')
del dataset['train']


Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c...


Dataset snli downloaded and prepared to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c. Subsequent calls will reuse this data.


Each sample in the dataset includes a **hypothesis** (left sentence in the slides), a **premise** (right sentence in the slides), and an entailment label, which takes one of the following values:
- 0 (entailment), indicating that the premise entails the hypothesis
- 1 (neutral), indicating that the premise and hypothesis neither entail nor contradict each other
- 2 (contradiction), indicating that the premise contradicts the hypothesis.
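
It can help to keep an explicit mapping from label IDs to names while inspecting data. The sketch below is a minimal illustration; `ID2LABEL`, `LABEL2ID`, and `describe` are hypothetical names of our own, not part of `datasets`. If you like, the same dictionaries can also be passed to `from_pretrained` through its `id2label` and `label2id` config arguments, so the model's config carries human-readable label names.

```python
# Hypothetical helpers; SNLI's integer labels follow this order.
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}

def describe(sample):
    """Return a one-line, human-readable summary of an SNLI sample dict."""
    label = ID2LABEL.get(sample["label"], "no gold label")
    return f"{label}: {sample['premise']!r} -> {sample['hypothesis']!r}"

example = {
    "premise": "Two women are embracing while holding to go packages.",
    "hypothesis": "Two woman are holding packages.",
    "label": 0,
}
print(describe(example))
```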

Let's look at an example.

In [5]:
dataset['validation'][1]

{'hypothesis': 'Two woman are holding packages.',
 'label': 0,
 'premise': 'Two women are embracing while holding to go packages.'}

# Load Pretrained Language Model

Now we need a large, pretrained language model that we can fine-tune for the NLI task. We'll use the base (uncased) BERT model provided by the Hugging Face `transformers` library.

**Pro tip**: If you want to use a well-known transformer in your final project, check out the `transformers` library. It supports most popular architectures and hosts pretrained models for all of them!

In [7]:
import transformers

tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
model = transformers.BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=3).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

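You'll see a warning message above. It tells us that the BERT encoder was initialized from pretrained weights, but the classification head was not. This is what we want, since we are going to train the classifier from scratch. (The unused `cls.*` weights it lists belong to the masked-language-modeling and next-sentence-prediction heads used during pretraining, which we don't need.)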

Before moving on, we can check that the model has all the parts we expect. This output is a little overwhelming, but the important thing to notice is that `BertForSequenceClassification` has the two subcomponents we care about: the BERT encoder (`model.bert`) and a linear classifier (`model.classifier`) with an output dimension of 3.
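
For reference, here is roughly what the head computes, assuming the standard `BertForSequenceClassification` design: BERT's pooler (a 768-to-768 dense layer followed by tanh, which lives inside `model.bert`) is applied to the first token's final hidden state $\mathbf{h}_1$, and `model.classifier` maps the result to 3 logits:

$$
\mathbf{p} = \tanh\!\left(W_p\,\mathbf{h}_1 + \mathbf{b}_p\right), \qquad
\mathbf{z} = W\,\operatorname{dropout}(\mathbf{p}) + \mathbf{b}, \qquad
W \in \mathbb{R}^{3 \times 768},\ \mathbf{b} \in \mathbb{R}^{3}.
$$

The classifier itself therefore has only $3 \times 768 + 3 = 2307$ parameters, a tiny fraction of BERT-base's roughly 110M, which is why training only the head is so cheap.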

In [8]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

Oh, and remember all that tokenization boilerplate from the recitation 2 notebook? Hugging Face does that for us as well.

In [10]:
tokenizer('red bird', return_tensors='pt')

{'input_ids': tensor([[ 101, 2417, 4743,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

The tokenizer splits the sentence into WordPiece tokens (whole words or subword pieces), maps each token to an integer ID from BERT's vocabulary, and returns the full ID sequence as a tensor under the `input_ids` key. The `attention_mask` is a binary mask indicating which tokens are real (1) and which are padding (0). It isn't needed when we have only one sentence, as here, but if we batched many sentences of different lengths, we would need to pad the shorter ones to a common length and then tell BERT to ignore the padding tokens; `attention_mask` is what lets us do that. Finally, `token_type_ids` records which segment each token belongs to (0 for the first sentence, 1 for a second one), which will matter once we feed BERT sentence pairs.
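
In practice, passing `padding=True` to the tokenizer pads a batch for you; the plain-Python sketch below only illustrates the idea. `pad_and_mask` is a hypothetical helper, and we assume the pad ID is 0, which matches `padding_idx=0` on the word embeddings printed earlier.

```python
PAD_ID = 0  # assumed [PAD] id for bert-base-uncased (padding_idx=0 in the embedding)

def pad_and_mask(sequences, pad_id=PAD_ID):
    """Right-pad token-ID lists to a common length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)              # real ids, then padding
        attention_mask.append([1] * len(seq) + [0] * n_pad)   # 1 = attend, 0 = ignore
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# "red bird" from above, plus a shorter sequence containing only "bird".
batch = pad_and_mask([[101, 2417, 4743, 102], [101, 4743, 102]])
print(batch["input_ids"])       # [[101, 2417, 4743, 102], [101, 4743, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```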

**Spot check**: Notice the sentence "red bird" has two words, but `input_ids` has four tokens. Without peeking, what do you think tokens 101 and 102 are?

# Prepare Data

Let's get the data ready for training. We have two preprocessing steps:
1. Some samples have a label of -1. These are pairs for which the annotators did not agree on a gold label. Let's remove them, since -1 is not a valid class index and would cause errors when computing the loss.
2. We will pretokenize all sentences, passing the hypothesis and premise to the tokenizer as a pair so that it joins them with the special [SEP] token. This allows BERT to look at the hypothesis and premise in one forward pass.
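
When the tokenizer receives two texts, as in `tokenizer(sample['hypothesis'], sample['premise'], ...)` in the next cell, it packs them into one sequence of the form `[CLS] A [SEP] B [SEP]` and sets `token_type_ids` to 0 for the first segment and 1 for the second. The plain-Python sketch below mimics only that layout on already-split tokens; `pair_layout` and the example tokens are hypothetical, and the real tokenizer additionally does WordPiece splitting, ID lookup, truncation, and padding.

```python
def pair_layout(hypothesis_tokens, premise_tokens):
    """Lay out a sentence pair the way BERT expects: [CLS] A [SEP] B [SEP].

    Returns the token list and segment (token_type) IDs: 0 for [CLS], the
    first segment, and its [SEP]; 1 for the second segment and final [SEP].
    """
    tokens = ["[CLS]", *hypothesis_tokens, "[SEP]", *premise_tokens, "[SEP]"]
    first_len = len(hypothesis_tokens) + 2       # [CLS] + A + [SEP]
    second_len = len(premise_tokens) + 1         # B + [SEP]
    token_type_ids = [0] * first_len + [1] * second_len
    return tokens, token_type_ids

tokens, segments = pair_layout(["two", "women"], ["they", "hug"])
print(tokens)    # ['[CLS]', 'two', 'women', '[SEP]', 'they', 'hug', '[SEP]']
print(segments)  # [0, 0, 0, 0, 1, 1, 1]
```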

In [None]:
dataset = dataset\
    .filter(lambda sample: sample['label'] != -1)\
    .map(lambda sample: tokenizer(sample['hypothesis'],
                                  sample['premise'],
                                  truncation=True,
                                  max_length=50,
                                  padding='max_length'),
         batched=True)

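Now we can tell `datasets` to return PyTorch tensors (already placed on `device`) for the columns we need. The resulting splits behave like map-style `torch.utils.data.Dataset`s, so we can hand one straight to PyTorch's `DataLoader` for batching. This just requires an extra function call.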

In [None]:
from torch.utils import data

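dataset.reset_format()
# token_type_ids tells BERT which tokens belong to the hypothesis vs. the premise.
dataset.set_format(type='torch',
                   columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'],
                   device=device)
# Shuffle so each batch mixes examples from across the split.
loader = data.DataLoader(dataset['validation'], batch_size=64, shuffle=True)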

# Train Model

Now we can fine-tune BERT for NLI! The training loop here is especially simple, because `BertForSequenceClassification` computes the loss for us when we pass the target labels in the forward pass.
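
With three classes and integer labels, that loss is (to our understanding of the library) a standard cross-entropy over the 3 logits, averaged over the batch. The plain-Python sketch below shows the computation; `cross_entropy` and `batch_loss` are hypothetical helper names, not library functions.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target]).

    Subtracts the max logit before exponentiating (log-sum-exp trick)
    so large logits don't overflow.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def batch_loss(batch_logits, targets):
    """Mean of per-example losses, mirroring a 'mean' reduction."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(round(cross_entropy([0.0, 0.0, 0.0], target=1), 4))  # 1.0986
```

With all-equal logits the loss is $\ln 3 \approx 1.10$, so a freshly initialized head should show a loss roughly in that neighborhood at the start of training.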

We will only train the model for one epoch through the validation split, which is **EGREGIOUSLY LITTLE** and will not produce a good model. But it shows that all the pieces fit together, and it runs in the span of a short demo :)

Feel free to modify the code below to train for a reasonable number of epochs on the training split (you'll also need to keep that split when loading the dataset instead of deleting it).

In [None]:
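import torch
from torch import optim
from tqdm.notebook import tqdm

# Option 1 (default): fine-tune *the entire model*.
# Full fine-tuning needs a small learning rate (the BERT authors suggest
# 2e-5 to 5e-5); a rate like 1e-3 tends to destabilize training.
# AdamW (Adam with decoupled weight decay) is the usual choice for BERT.
model.train()
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

# Option 2: comment out the two lines above and uncomment the block below
# to train *only the classifier* head. A larger learning rate is fine here,
# since only the small, randomly initialized head is updated.
# model.train()
# model.bert.eval()  # keep dropout off inside the frozen encoder
# for parameter in model.bert.parameters():
#     parameter.requires_grad_(False)
# optimizer = optim.AdamW(model.classifier.parameters(), lr=1e-3)

# Model inputs to forward; any key missing from the batch is skipped.
input_keys = ('input_ids', 'token_type_ids', 'attention_mask')

pbar = tqdm(loader)
for batch in pbar:
  inputs = {key: batch[key] for key in input_keys if key in batch}
  output = model(**inputs, labels=batch['label'])  # loss is computed for us
  output.loss.backward()
  optimizer.step()
  optimizer.zero_grad()
  pbar.set_description(f'loss {output.loss.item():.3f}')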


Finally, we can check the model's prediction on an example from the test set.

In [None]:
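import torch

model.eval()  # disable dropout for inference
item = dataset['test'][6]
# Add a batch dimension of size 1 to each input tensor.
inputs = {key: item[key][None] for key in ('input_ids', 'token_type_ids', 'attention_mask')
          if key in item}
with torch.no_grad():
  output = model(**inputs)
print(output.logits.argmax().item(), item['label'].item())  # predicted vs. gold label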
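
A single test example says little about model quality. To score the full test split, you could loop over a `DataLoader` for `dataset['test']` with the model in eval mode and under `torch.no_grad()`, collect the logits, and compare their argmax with the gold labels. The plain-Python sketch below shows only that final scoring step; `argmax`, `accuracy`, and the toy logits are hypothetical stand-ins, not part of any library.

```python
# Hypothetical helpers; SNLI's label order is 0=entailment, 1=neutral, 2=contradiction.
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

def accuracy(batch_logits, targets):
    """Fraction of rows whose highest-scoring class equals the gold label."""
    correct = sum(argmax(row) == gold for row, gold in zip(batch_logits, targets))
    return correct / len(targets)

# Toy logits standing in for output.logits.tolist() from two examples.
toy_logits = [[2.0, 0.1, 0.3], [0.0, 1.0, 0.5]]
toy_labels = [0, 2]
print([ID2LABEL[argmax(row)] for row in toy_logits])  # ['entailment', 'neutral']
print(accuracy(toy_logits, toy_labels))               # 0.5
```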