# Masked Language Modeling

In this lab, we give an overview of the **masked language modeling** objective and of how the **Transformer** architecture is used for large-scale masked language modeling.


In [1]:
%pylab inline
import os, sys, glob, json, math
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict
import torch
import torch.nn as nn

%load_ext autoreload
%autoreload 2
pd.set_option('display.max_colwidth', None)  # None disables truncation; -1 is deprecated in recent pandas

Populating the interactive namespace from numpy and matplotlib


Let's set the random seeds for reproducibility.

In [27]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Background

In 2018, Devlin et al. published [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).


**B**idirectional

**E**ncoder

**R**epresentations from

**T**ransformers


#### Goal: 
1. **pre-train** a model that produces language representations. 
2. **fine-tune** the model on a task.
    


## Masked Language Model Objective

Randomly mask some of the input tokens, then predict the original vocabulary id of each masked token.

- Given sequence $x_1,\ldots,x_N$.

- Form **mask** $m_1,\ldots,m_N$ where $m_i\in \{0,1\}$.
    - E.g. $m_i=1$ with probability 0.15
    
- Form **masked sequence** $\tilde{x}_1,\ldots,\tilde{x}_N$.
    - $\tilde{x}_i=\begin{cases} x_i & m_i=0\\ \texttt{[MASK]} & m_i=1\end{cases}$


#### $$\mathcal{L}_{\text{MLM}}=-\sum_{\underbrace{i | m_i=1}_{\text{MASKED POSITIONS}}}\log p_{\theta}(\underbrace{x_i}_{\text{TRUE TOKEN}}|\underbrace{\tilde{x}_1,\ldots,\tilde{x}_N}_{\text{MASKED SEQUENCE}})$$
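
As a rough illustration of the loss above, the sketch below evaluates it by hand on a toy example and compares it with PyTorch's built-in cross-entropy. The sizes, the fixed mask, and the random `logits` (standing in for a model's output on the masked sequence) are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, V = 6, 10                      # toy sequence length and vocabulary size (illustrative)
x = torch.randint(V, (N,))        # true tokens x_1..x_N
m = torch.tensor([0, 1, 0, 0, 1, 0], dtype=torch.bool)  # mask m_i (fixed here instead of sampled)
logits = torch.randn(N, V)        # stand-in for a model's scores given the masked sequence

# log p_theta(x_i | masked sequence) for every position i
log_probs = F.log_softmax(logits, dim=-1)
true_token_log_probs = log_probs[torch.arange(N), x]

# L_MLM: negative log-likelihood summed over the masked positions only
loss_manual = -true_token_log_probs[m].sum()

# The same quantity via cross_entropy, with unmasked positions marked as ignored
labels = x.clone()
labels[~m] = -1
loss_builtin = F.cross_entropy(logits, labels, ignore_index=-1, reduction='sum')
```

Note that `reduction='sum'` matches the summed form of the equation; the default `'mean'` would divide by the number of masked positions.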


<!-- Below, we will discuss the exact form of $\tilde{x}_i$ that the BERT authors used. -->


<!-- #### Diagram of BERT Implementation -->
<!-- ![](bert_overview.png) -->

## Transformers

So far we have modeled a sequence by factorizing the joint distribution into conditionals, and **parameterizing each conditional with a recurrent network**:


#### $$p_{\theta}(x_1,\ldots,x_T)=\prod_{t=1}^T p_{\theta}(x_t | x_{<t})$$
\begin{align}
h_t &= \text{RNN}(x_{t-1}, h_{t-1})\\
p_{\theta}(x_t | x_{<t}) &=\text{softmax}\left(Wh_t+b\right),
\end{align}

where $\theta$ are the model parameters (RNN parameters, $W, b$, embedding matrix).
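
The factorization above can be sketched with a small recurrent model. This is a minimal illustration with made-up sizes, an untrained `nn.RNN`, and the first token of `x` playing the role of a start symbol; none of these choices come from the lab itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, D, T = 7, 8, 5                           # toy vocabulary size, hidden size, length (illustrative)
embed = nn.Embedding(V, D)                  # embedding matrix
rnn = nn.RNN(D, D)                          # expects Length x Batch x Dim
proj = nn.Linear(D, V)                      # W and b

x = torch.randint(V, (T + 1, 1))            # x_0 acts as a start token; x_1..x_T are predicted
inputs, targets = x[:-1], x[1:]             # step t reads x_{t-1} and predicts x_t

h, _ = rnn(embed(inputs))                   # h_t summarizes x_{<t}
log_p = torch.log_softmax(proj(h), dim=-1)  # log p(x_t | x_{<t}) for every t

# log p(x_1..x_T) is the sum of the per-step conditional log-probabilities
token_log_p = log_p.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
sequence_log_prob = token_log_p.sum()
```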


#### Alternative

An alternative proposed in [[Vaswani et al 2017](https://arxiv.org/pdf/1706.03762.pdf)] is to parameterize each conditional with a **particular feed-forward architecture** called the **Transformer**. With this model, it is possible to compute all conditionals with a **single feed-forward pass**:
\begin{align}
(h_1,\ldots,h_T) &= Transformer(x)\\
p_{\theta}(x_t | x_{<t}) &= \text{softmax}\left(Wh_t + b\right)
\end{align}

We will briefly discuss the key ideas, the overall **Transformer architecture (encoder only)**, and how it is used in PyTorch.

### High-Level View

We can view the Transformer encoder as mapping a sequence to a sequence of vectors.

<img src="img/high1.png" alt="Drawing" style="width: 35%;"/>

Let's step through the key ideas of how this mapping is designed, and discuss some of its resulting properties.

### Key Idea 1: Position Embeddings

**Unlike RNNs which can learn positional information via the hidden state over time, the Transformer has no notion of time**.

Thus we encode inputs with **position** as well as **token** embeddings:

<img src="img/high2.png" alt="Drawing" style="width: 35%;"/>
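
To see why position information has to be added explicitly, the sketch below feeds an encoder layer a sequence with no position embeddings and then the same sequence in a shuffled order. The sizes and the particular permutation are arbitrary choices for illustration; dropout is set to 0 so the two passes are comparable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0)

x = torch.randn(5, 1, 8)                # Length x Batch x Dim, no position information added
perm = torch.tensor([3, 0, 4, 1, 2])    # an arbitrary reordering of the 5 positions

out = layer(x)
out_perm = layer(x[perm])

# Shuffling the inputs only shuffles the outputs in the same way
same_up_to_order = torch.allclose(out[perm], out_perm, atol=1e-5)
print(same_up_to_order)
```

In other words, without position embeddings the layer treats its input as an unordered set of vectors.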

In [28]:
input_sequence = ['<s>', 'my', 'pet', '[M]', '<s>']

max_len = 10

vocab = {'<s>': 0, 'my': 1, 'pet': 2, 'dog': 3, 'cat': 4, 'lion': 5, '[M]': 6}

dim = 6

token_embed = nn.Embedding(len(vocab), embedding_dim=dim)  # an embedding for each token
position_embed = nn.Embedding(max_len, embedding_dim=dim)  # an embedding for each position in the sequence

In [29]:
input_vector = torch.tensor([vocab[x] for x in input_sequence]).unsqueeze(1)   # get the numerical representation of the token

input_embeddings = token_embed(input_vector) + position_embed(torch.arange(len(input_vector))).unsqueeze(1)  # add the input embedding to the position embedding
input_embeddings.size()

torch.Size([5, 1, 6])

**Warning!!** By default, the PyTorch Transformer classes expect input of shape `Length x Batch x Dim`.
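
If your data arrives as `Batch x Length x Dim`, a transpose is enough to match the expected layout. The sketch below uses made-up sizes and assumes the default (length-first) layer configuration.

```python
import torch
import torch.nn as nn

batch_first = torch.randn(4, 5, 6)          # Batch x Length x Dim, as many data pipelines produce
length_first = batch_first.transpose(0, 1)  # Length x Batch x Dim, the default expected layout

layer = nn.TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=16, dropout=0.0)
out = layer(length_first)                   # the output keeps the Length x Batch x Dim layout
print(out.shape)
```

Recent PyTorch releases also accept a `batch_first=True` argument on these modules; check the documentation of the version you have installed.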

### Key Idea 2: Modularity
The Transformer (encoder) is composed of a stack of **N identical layers**.

<img src="img/layers.png" alt="Drawing" style="width: 35%;"/>

In [30]:
import torch.nn as nn
nn.TransformerEncoder?

Init signature: nn.TransformerEncoder(encoder_layer, num_layers, norm=None)
Docstring:     
TransformerEncoder is a stack of N encoder layers

Args:
    encoder_layer: an instance of the TransformerEncoderLayer() class (required).
    num_layers: the number of sub-encoder-layers in the encoder (required).
    norm: the layer normalization component (optional).

Examples::
    >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    >>> src = torch.rand(10, 32, 512)
    >>> out = transformer_encoder(src)
Init docstring: Initializes internal Module state, shared by both nn.Module and ScriptModule.
File:           ~/anaconda3/envs/aims/lib/python3.7/site-packages/torch/nn/modules/transformer.py
Type:

#### The `forward` method passes the input through the N layers, then applies the final `norm` if one was given:

**Warning!!** The forward function accepts input as `Length x Batch x Dim`
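
A minimal sketch of what this forward pass amounts to, assuming the `layers` and `norm` attributes of `nn.TransformerEncoder` and using toy sizes with dropout disabled so the comparison is deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder_layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, norm=nn.LayerNorm(8))

x = torch.randn(5, 1, 8)   # Length x Batch x Dim

# Manual version: feed the output of each layer into the next, then apply the final norm
h = x
for layer in encoder.layers:
    h = layer(h)
h = encoder.norm(h)

matches = torch.allclose(h, encoder(x), atol=1e-5)
print(matches)
```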

In [31]:
# nn.TransformerEncoder.forward??

In [48]:
# For comparison, "Attention Is All You Need" uses a much larger configuration, e.g. nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dim_feedforward=64, dropout=0.1)

encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

In [49]:
outputs = encoder(input_embeddings)

print("input size: \t%s" % str(tuple(input_embeddings.shape)))
print("output size:\t%s" % str(tuple(outputs.shape)))             # contextualized embedding of each token
outputs

input size: 	(5, 1, 6)
output size:	(5, 1, 6)


tensor([[[ 0.5391, -1.9445,  1.0577, -0.4564,  0.0294,  0.7747]],

        [[-0.7556,  0.5501, -1.3479,  1.5129,  0.7344, -0.6939]],

        [[ 0.7792, -1.9722,  0.9250,  0.5844,  0.1997, -0.5160]],

        [[-0.9059, -1.2809,  0.1254, -0.3488,  1.7034,  0.7068]],

        [[ 1.1515, -1.2860,  0.2149,  1.3268, -0.3691, -1.0381]]],
       grad_fn=<NativeLayerNormBackward>)

In [50]:
print(token_embed(input_vector)) # the original token embedding

tensor([[[-0.1117, -0.4966,  0.1631, -0.8817,  0.0539,  0.6684]],

        [[-0.0597, -0.4675, -0.2153,  0.8840, -0.7584, -0.3689]],

        [[-0.3424, -1.4020,  0.3206, -1.0219,  0.7988, -0.0923]],

        [[-0.7690, -1.5606, -0.5309,  0.2178,  1.3232,  1.8169]],

        [[-0.1117, -0.4966,  0.1631, -0.8817,  0.0539,  0.6684]]],
       grad_fn=<EmbeddingBackward>)


#### Each layer has two parts, **self-attention** and a feed-forward transformation:

<img src="img/layer.png" alt="Drawing" style="width: 65%;"/>
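
The two parts can be traced by hand using the submodules of an `nn.TransformerEncoderLayer`. The sketch below assumes the layer's default settings (ReLU activation, normalization applied after each residual connection) and disables dropout; the sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0)
x = torch.randn(5, 1, 8)   # Length x Batch x Dim

# Part 1: self-attention, followed by a residual connection and layer normalization
attn_out, _ = layer.self_attn(x, x, x)
h = layer.norm1(x + attn_out)

# Part 2: position-wise feed-forward network, again with residual + normalization
ff_out = layer.linear2(F.relu(layer.linear1(h)))
h = layer.norm2(h + ff_out)

matches = torch.allclose(h, layer(x), atol=1e-5)
print(matches)
```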

In [28]:
# nn.TransformerEncoderLayer??

In [29]:
# nn.TransformerEncoderLayer.forward??

### Key Idea 3: Self-Attention

In the RNN, the hidden state contains information about previous tokens.
The Transformer instead performs **attention** over all inputs at a given layer. 'Attention' computes an output vector by taking a weighted sum of input vectors. The weights are 'attention weights'. The Transformer uses **scaled dot-product attention**:
#### $$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

and 'Multi-head Attention' refers to applying several of these operations in parallel.
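
The formula can be written out in a few lines. This is a single-head sketch on random 2-D inputs, without the learned projections that `nn.MultiheadAttention` applies to the queries, keys, and values.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for inputs of shape (length, d_k)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # similarity of every query with every key
    weights = torch.softmax(scores, dim=-1)             # attention weights: each row sums to 1
    return weights @ V, weights                         # weighted sum of the value vectors

torch.manual_seed(0)
Q = torch.randn(5, 4)
K = torch.randn(5, 4)
V = torch.randn(5, 4)
output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape, weights.shape)
```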

#### *Key Property*: Each output vector of a layer $n$ can use information from **all** inputs to the layer $n$.

Thus each **final output vector** can incorporate information from **all input words**.

(If we want to prevent information flow such as in left-to-right language modeling, we can use masking).
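
As a sketch of such masking, the code below passes an additive mask with `-inf` above the diagonal to the encoder, so that position $t$ can attend only to positions up to $t$. The sizes are toy values and dropout is disabled; the check perturbs the last input and confirms that earlier outputs do not move.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, D = 5, 8
layer = nn.TransformerEncoderLayer(d_model=D, nhead=2, dim_feedforward=16, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2)

# -inf above the diagonal: position t may attend only to positions <= t
causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

x = torch.randn(T, 1, D)
x_changed = x.clone()
x_changed[-1] = torch.randn(1, D)          # replace only the last position

out = encoder(x, mask=causal_mask)
out_changed = encoder(x_changed, mask=causal_mask)

# Earlier positions cannot see the last one, so their outputs are unaffected
earlier_unchanged = torch.allclose(out[:-1], out_changed[:-1], atol=1e-5)
last_changed = not torch.allclose(out[-1], out_changed[-1], atol=1e-5)
print(earlier_unchanged, last_changed)
```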

In [17]:
nn.MultiheadAttention?

Init signature:
nn.MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,
    vdim=None,
)
Docstring:     
Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need

.. math::
    \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
    \text{w

In [54]:
attn = nn.MultiheadAttention(dim, 2, dropout=0.0)

attn_outputs, attn_weights = attn(query=outputs, key=outputs, value=outputs)  # call the module itself rather than .forward

print("input shape: %s" % (str(tuple(outputs.size()))))
print("output shape: %s" % (str(tuple(attn_outputs.size()))))
print(outputs)

print("\nattn weights shape: %s" % (str(tuple(attn_weights.size()))))
print(attn_weights)

input shape: (5, 1, 6)
output shape: (5, 1, 6)
tensor([[[ 0.5391, -1.9445,  1.0577, -0.4564,  0.0294,  0.7747]],

        [[-0.7556,  0.5501, -1.3479,  1.5129,  0.7344, -0.6939]],

        [[ 0.7792, -1.9722,  0.9250,  0.5844,  0.1997, -0.5160]],

        [[-0.9059, -1.2809,  0.1254, -0.3488,  1.7034,  0.7068]],

        [[ 1.1515, -1.2860,  0.2149,  1.3268, -0.3691, -1.0381]]],
       grad_fn=<NativeLayerNormBackward>)

attn weights shape: (1, 5, 5)
tensor([[[0.1703, 0.2043, 0.1716, 0.2744, 0.1794],
         [0.2384, 0.1427, 0.2343, 0.1328, 0.2517],
         [0.1686, 0.1739, 0.1843, 0.2152, 0.2581],
         [0.1906, 0.1497, 0.2084, 0.2162, 0.2351],
         [0.1768, 0.1457, 0.1723, 0.2346, 0.2706]]], grad_fn=<DivBackward0>)


#### Summary

In [55]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, max_len, dim=8, num_layers=4, nhead=2):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, dim)
        self.position_embed = nn.Embedding(max_len, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nhead, dim_feedforward=64, dropout=0.0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.projection = nn.Linear(dim, vocab_size)
    
    def features(self, token_indices):
        pos = torch.arange(len(token_indices), device=token_indices.device).unsqueeze(1)
        x = self.token_embed(token_indices) + self.position_embed(pos)
        x = self.encoder(x)
        return x
    
    def forward(self, token_indices):
        x = self.features(token_indices)
        x = self.projection(x)
        return x

In [56]:
input_vector.size()

torch.Size([5, 1])

In [66]:
model = Transformer(len(vocab), max_len=100)

model.features(input_vector)

tensor([[[-0.1841,  0.3811,  1.4767,  0.3744, -2.2709, -0.2429, -0.0631,
           0.5290]],

        [[ 0.6175,  0.9034,  1.0025, -0.8787,  1.1249, -1.4216, -0.0295,
          -1.3185]],

        [[-0.3699,  0.7491,  0.4773, -0.6109,  1.7609,  0.3687, -0.6204,
          -1.7548]],

        [[-1.1181,  0.9639, -0.8665,  0.5216, -1.5619,  0.5622,  0.0716,
           1.4273]],

        [[-0.1518, -0.3191,  1.9927,  0.0086, -1.1681, -1.3354,  0.1103,
           0.8627]]], grad_fn=<NativeLayerNormBackward>)

In [67]:
model.features(input_vector).shape

torch.Size([5, 1, 8])

In [68]:
model

Transformer(
  (token_embed): Embedding(7, 8)
  (position_embed): Embedding(100, 8)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=8, out_features=8, bias=True)
        )
        (linear1): Linear(in_features=8, out_features=64, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=64, out_features=8, bias=True)
        (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=8, out_features=8, bias=True)
        )
        (linear1): Linear(in_features=8, out_features=64, bias=True)
        (dropout): Dropout(p=0.0, inpla

In [70]:
model.token_embed(input_vector)

tensor([[[-1.1013, -1.4392, -0.7411, -1.5282,  0.5737, -1.4641, -0.3159,
           0.9548]],

        [[-0.0104,  0.0784,  0.5036, -0.7424,  1.9766, -0.9712,  0.1476,
          -1.6098]],

        [[-0.7888, -0.8349,  1.3856, -2.7129,  1.0593,  1.2605,  0.0641,
          -0.4017]],

        [[ 0.0079, -1.4143, -2.0439, -0.7530,  0.1951,  1.4279, -0.3208,
           1.5795]],

        [[-1.1013, -1.4392, -0.7411, -1.5282,  0.5737, -1.4641, -0.3159,
           0.9548]]], grad_fn=<EmbeddingBackward>)

## Back to Masked Language Modeling

Recall the **key property** of Transformers: due to self-attention, each output vector can incorporate information from *all* input tokens.

<img src="img/mlm.png" alt="Drawing" style="width: 45%;"/>

This is useful for masked language modeling, where we want to use information from the entire context when predicting the masked token(s).

#### MLM on Persona-Chat

In [71]:
import utils
raw_datasets, datasets, vocab = utils.load_personachat()

100%|██████████| 133176/133176 [00:31<00:00, 4251.03it/s]
100%|██████████| 133176/133176 [00:00<00:00, 168611.57it/s]
100%|██████████| 16181/16181 [00:00<00:00, 169260.05it/s]


Vocab size: 19157


In [72]:
from torch.utils.data.dataloader import DataLoader

trainloader = DataLoader(datasets['train'], batch_size=4, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))
validloader = DataLoader(datasets['valid'], batch_size=4, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))

In [73]:
batch = next(iter(trainloader))
batch

tensor([[ 0,  0,  0,  0],
        [ 4,  4,  4, 22],
        [ 5,  5, 18, 23],
        [ 6, 13, 17, 24],
        [ 7, 14, 19, 15],
        [ 8, 15, 13, 25],
        [ 9, 16, 20, 26],
        [10, 17, 21, 27],
        [11, 12, 12, 28],
        [12,  0,  0, 29],
        [ 0,  2,  2, 30],
        [ 2,  2,  2, 24],
        [ 2,  2,  2,  4],
        [ 2,  2,  2, 31],
        [ 2,  2,  2, 32],
        [ 2,  2,  2, 27],
        [ 2,  2,  2, 33],
        [ 2,  2,  2, 34],
        [ 2,  2,  2, 35],
        [ 2,  2,  2, 36],
        [ 2,  2,  2, 24],
        [ 2,  2,  2,  0]])

In [37]:
def mask_tokens(inputs, mask_prob, pad_token_id, mask_token_id, vsize):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original."""
    inputs = inputs.clone()
    labels = inputs.clone()
    # Sample tokens in each sequence for masked-LM training
    masked_indices = torch.bernoulli(torch.full(labels.shape, mask_prob)).bool()
    masked_indices = masked_indices & (inputs != pad_token_id)
    labels[~masked_indices] = -1  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with the mask token (mask_token_id)
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of the time, we replace masked input tokens with a random word
    # (probability 0.5 applied to the remaining 20% gives 10% overall)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vsize, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
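
The 80% / 10% / 10% split comes from two nested coin flips. The standalone simulation below reproduces only that sampling logic (it does not call `mask_tokens`) and estimates the three fractions empirically; the sample size is an arbitrary choice.

```python
import torch

torch.manual_seed(0)
n = 200_000                                   # number of selected positions to simulate

# Stage 1: 80% of the selected positions become the mask token
replaced = torch.bernoulli(torch.full((n,), 0.8)).bool()
# Stage 2: half of the remaining 20% get a random token, i.e. 0.2 * 0.5 = 10% overall
randomized = torch.bernoulli(torch.full((n,), 0.5)).bool() & ~replaced
# Whatever is left (about 10%) keeps its original token
kept = ~replaced & ~randomized

frac_mask = replaced.float().mean().item()
frac_random = randomized.float().mean().item()
frac_kept = kept.float().mean().item()
print("mask %.3f  random %.3f  kept %.3f" % (frac_mask, frac_random, frac_kept))
```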

In [38]:
inputs, labels = mask_tokens(batch, mask_prob=0.15, mask_token_id=vocab.get_id('[M]'), pad_token_id=vocab.get_id('<pad>'), vsize=len(vocab))
print("Mask token id: %d" % vocab.get_id('[M]'))
inputs

Mask token id: 1


tensor([[ 0,  0,  0,  0],
        [ 1,  4,  4, 22],
        [ 5,  5, 18, 23],
        [ 6, 13, 17,  1],
        [ 7,  1, 19, 15],
        [ 8, 15, 13,  1],
        [ 9, 16, 20, 26],
        [10, 17, 21, 27],
        [ 1, 12, 12, 28],
        [12,  0,  0, 29],
        [ 0,  2,  2, 30],
        [ 2,  2,  2, 24],
        [ 2,  2,  2,  4],
        [ 2,  2,  2, 31],
        [ 2,  2,  2, 32],
        [ 2,  2,  2, 27],
        [ 2,  2,  2,  1],
        [ 2,  2,  2, 34],
        [ 2,  2,  2, 35],
        [ 2,  2,  2, 36],
        [ 2,  2,  2, 24],
        [ 2,  2,  2,  0]])

In [39]:
labels

tensor([[-1, -1, -1, -1],
        [ 4, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, 24],
        [-1, 14, -1, -1],
        [-1, -1, -1, 25],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [11, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, 33],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1]])

In [40]:
model = Transformer(len(vocab), max_len=200)

In [41]:
logits = model(inputs)
logits.size()

torch.Size([22, 4, 19157])

In [42]:
labels.size()

torch.Size([22, 4])

In [43]:
criterion = nn.CrossEntropyLoss(ignore_index=-1)

In [44]:
logits_ = logits.view(-1, logits.size(2))
labels_ = labels.view(-1)

criterion(logits_, labels_)

tensor(10.2056, grad_fn=<NllLossBackward>)
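
A quick sanity check on this number: a model that spreads probability uniformly over the 19157-token vocabulary would have a loss of $\log 19157 \approx 9.86$, so a value near 10 is what we would expect from an untrained model. The sketch below confirms the uniform baseline, using all-zero logits as a stand-in for a uniform predictor; the label values are arbitrary.

```python
import math
import torch
import torch.nn as nn

V = 19157                                   # vocabulary size printed earlier in this lab
criterion = nn.CrossEntropyLoss(ignore_index=-1)

uniform_logits = torch.zeros(3, V)          # equal scores give a uniform distribution over tokens
labels = torch.tensor([5, -1, 42])          # the middle position is ignored

uniform_loss = criterion(uniform_logits, labels).item()
print("uniform loss %.4f, log(V) %.4f" % (uniform_loss, math.log(V)))
```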

In [45]:
if False:
    import torch.optim as optim
    from tqdm import tqdm, trange
    from collections import defaultdict
    from torch.utils.data.dataloader import DataLoader

    trainloader = DataLoader(datasets['train'], batch_size=64, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))
    validloader = DataLoader(datasets['valid'], batch_size=64, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = Transformer(len(vocab), max_len=65, dim=256, nhead=8).to(device)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(model_parameters, lr=0.001)

    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)

    stats = defaultdict(list)

    for epoch in range(50):
        for step, batch in enumerate(trainloader):
            model.train()        
            # Mask the batch
            inputs, labels = mask_tokens(batch, mask_prob=0.15, 
                                         pad_token_id=vocab.get_id('<pad>'),
                                         mask_token_id=vocab.get_id('[M]'), 
                                         vsize=len(vocab))
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            logits_ = logits.view(-1, logits.size(2))
            labels_ = labels.view(-1)

            optimizer.zero_grad()
            loss = criterion(logits_, labels_)

            loss.backward()
            optimizer.step()

            stats['train_loss'].append(loss.item())
            stats['train_loss_log'].append(loss.item())
            if (step % 500) == 0:
                avg_loss = sum(stats['train_loss_log']) / len(stats['train_loss_log'])
                print("Epoch %d Step %d\tTrain Loss %.3f" % (epoch, step, avg_loss))
                stats['train_loss_log'] = []

        model.eval()
        epoch_valid_losses = []
        for batch in validloader:
            with torch.no_grad():
                # Mask the batch
                inputs, labels = mask_tokens(batch, mask_prob=0.15, 
                                             pad_token_id=vocab.get_id('<pad>'),
                                             mask_token_id=vocab.get_id('[M]'), 
                                             vsize=len(vocab))
                inputs = inputs.to(device)
                labels = labels.to(device)

                logits = model(inputs)
                logits_ = logits.view(-1, logits.size(2))
                labels_ = labels.view(-1)

                loss = criterion(logits_, labels_)
                epoch_valid_losses.append(loss.item())
        # Report the mean over all validation batches rather than only the last batch
        stats['valid_loss'].append(sum(epoch_valid_losses) / len(epoch_valid_losses))
        print("=== Epoch %d\tValid Loss %.3f" % (epoch, stats['valid_loss'][-1]))

### Example Conditionals

#### Load model  

In [46]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

checkpoint = utils.load('model', 'model', best=True)
options = checkpoint['options']
stats = checkpoint['stats']


model = utils.Transformer(len(vocab), options['max_len'], 
                          dim=options['dim'], 
                          nhead=options['nhead'])
model.load_state_dict(checkpoint['model_dict'])

<All keys matched successfully>

In [47]:
model.eval()
model = model.to(device)

In [48]:
sentences = [['<s>', 'i', 'have', 'a', 'pet', '[M]', '.', '<s>'],
             ['<s>', 'i', 'have', 'two', 'pet', '[M]', '.', '<s>'],
             ['<s>', 'my', '[M]', 'is', 'a', 'lawyer', '.', '<s>'],
             ['<s>', 'my', '[M]', 'is', 'a', '[M]', '.', '<s>'],
             ['<s>', 'i', '[M]', '[M]', '[M]', 'sometimes', '.' , '<s>']]


def get_top_masked_tokens(tokens, vocab, device, top=10):
    # Note: relies on the global `model` loaded above
    ids = torch.tensor([vocab.get_id(x) for x in tokens], device=device).unsqueeze(1)
    masked = ids == vocab.get_id('[M]')

    with torch.no_grad():  # inference only, no gradients needed
        logits = model(ids)[masked]
    probs = torch.softmax(logits, -1)

    print(' '.join(tokens))
    for ps in probs:
        # Sort this masked position's distribution (renamed to avoid shadowing `probs`)
        sorted_probs, idxs = ps.sort(descending=True)

        for i in range(top):
            print("\t%s (%.4f)" % (vocab.get_token(idxs[i].item()),
                                   sorted_probs[i].item()))
        print()

In [49]:
for s in sentences:
    get_top_masked_tokens(s, vocab, device)

<s> i have a pet [M] . <s>
	cat (0.0707)
	dog (0.0533)
	sibling (0.0342)
	puppy (0.0340)
	sister (0.0302)
	retriever (0.0265)
	daughter (0.0264)
	shepard (0.0232)
	named (0.0213)
	brother (0.0208)

<s> i have two pet [M] . <s>
	cats (0.1525)
	dogs (0.0874)
	girls (0.0748)
	boys (0.0501)
	brothers (0.0499)
	wives (0.0420)
	children (0.0386)
	kids (0.0377)
	sisters (0.0333)
	, (0.0219)

<s> my [M] is a lawyer . <s>
	mother (0.2872)
	dad (0.2481)
	mom (0.1561)
	husband (0.0864)
	father (0.0363)
	brother (0.0230)
	job (0.0144)
	sister (0.0143)
	parents (0.0131)
	wife (0.0104)

<s> my [M] is a [M] . <s>
	mother (0.2330)
	dad (0.2212)
	mom (0.1373)
	husband (0.1110)
	brother (0.0364)
	father (0.0357)
	sister (0.0285)
	job (0.0145)
	wife (0.0141)
	parents (0.0127)

	teacher (0.0899)
	lawyer (0.0456)
	nurse (0.0426)
	cop (0.0414)
	mechanic (0.0386)
	doctor (0.0259)
	pilot (0.0195)
	journalist (0.0163)
	dancer (0.0148)
	hairdresser (0.0123)

<s> i [M] [M] [M] sometimes . <s>
	am (0.1669)
	love 

## Back to *BERT*

**B**idirectional

**E**ncoder

**R**epresentations from

**T**ransformers

#### - Masked Language Modeling at scale

#### - Learned representations are useful downstream

<img src="img/bert_citations.png" alt="Drawing" style="width: 45%;"/>

#### Great implementation in [transformers](https://github.com/huggingface/transformers):

In [59]:
# !pip install transformers

In [52]:
import torch
from transformers import (
    BertForMaskedLM,
    BertTokenizer
)

### Details -- Model Variants

- $\text{BERT}_{\text{BASE}}$: 12 layers, hidden dimension 768, 12 attention heads (**110 million parameters**)
- $\text{BERT}_{\text{LARGE}}$: 24 layers, hidden dimension 1024, 16 attention heads (**340 million parameters**)
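
The quoted parameter counts can be roughly reproduced with some arithmetic. The sketch below is a back-of-the-envelope estimate under assumptions that are not stated in this lab: a vocabulary of 30522 tokens, 512 positions, 2 segment types, a feed-forward width of 4 times the hidden size, and a pooling layer on top. Under these assumptions it gives about 109M and 335M, close to the rounded figures above.

```python
def bert_param_count(num_layers, hidden, vocab_size=30522, max_positions=512, num_segments=2):
    """Rough parameter count for a BERT-style encoder (assumed layout, see the text above)."""
    # Token, position and segment embeddings, plus one LayerNorm (scale + shift)
    embeddings = (vocab_size + max_positions + num_segments) * hidden + 2 * hidden
    # Self-attention: query, key, value and output projections, each with a bias
    attention = 4 * (hidden * hidden + hidden)
    # Feed-forward: hidden -> 4*hidden -> hidden, each with a bias
    feed_forward = (hidden * 4 * hidden + 4 * hidden) + (4 * hidden * hidden + hidden)
    # Two LayerNorms per layer
    layer_norms = 2 * 2 * hidden
    per_layer = attention + feed_forward + layer_norms
    # Pooler applied to the first token's output
    pooler = hidden * hidden + hidden
    return embeddings + num_layers * per_layer + pooler

base = bert_param_count(num_layers=12, hidden=768)
large = bert_param_count(num_layers=24, hidden=1024)
print("base: %.1fM  large: %.1fM" % (base / 1e6, large / 1e6))
```

The number of attention heads does not enter the count, because the heads split the hidden dimension rather than adding to it.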

In [53]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
model = BertForMaskedLM.from_pretrained('bert-base-cased', output_attentions=True)

if torch.cuda.is_available():
    model.cuda()


### Details -- Input Implementation


- `[CLS]` token: starts each sequence. Used as aggregate sequence representation.
- `[SEP]` token: separates two segments (e.g. two sentences).
- **Segment embedding**: learned embedding for every token indicating whether it belongs
to sentence A or sentence B.
- **Position embedding**: learned.
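
The input construction described above can be sketched as a sum of three embedding lookups. The vocabulary, sizes, and example sentence pair below are hypothetical and chosen only for illustration; this is not the BERT implementation itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy vocabulary and sizes, for illustration only
vocab = {'[CLS]': 0, '[SEP]': 1, 'my': 2, 'dog': 3, 'is': 4, 'cute': 5, 'he': 6, 'barks': 7}
dim, max_len = 8, 16

token_embed = nn.Embedding(len(vocab), dim)
segment_embed = nn.Embedding(2, dim)          # sentence A = 0, sentence B = 1
position_embed = nn.Embedding(max_len, dim)   # learned, one vector per position

tokens = ['[CLS]', 'my', 'dog', 'is', 'cute', '[SEP]', 'he', 'barks', '[SEP]']
segments = [0, 0, 0, 0, 0, 0, 1, 1, 1]        # the first [SEP] closes A, the second closes B

token_ids = torch.tensor([vocab[t] for t in tokens])
segment_ids = torch.tensor(segments)
position_ids = torch.arange(len(tokens))

# The encoder input is the element-wise sum of the three embeddings
input_embeddings = token_embed(token_ids) + segment_embed(segment_ids) + position_embed(position_ids)
print(input_embeddings.shape)
```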


<img src="img/bert_inputs.png" alt="Drawing" style="width: 75%;"/>

**Exercise:** Which downstream tasks would two sequences be useful for?

### Tokenization

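#### BERT represents text using **subword** (WordPiece) tokens with a 30k token vocabulary.  

(more info on subword tokenization [here](https://github.com/google/sentencepiece) and in the papers mentioned there; note that SentencePiece is a related subword tokenizer, not the WordPiece tokenizer that BERT itself uses)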

<!-- - **Token embedding**: WordPiece embeddings with 30k token vocabulary. -->

In [54]:
tokenizer.tokenize("Pretraining is cool.")

['Pre', '##tra', '##ining', 'is', 'cool', '.']

In [55]:
tokenizer.tokenize("BERT represents text using subwords.")

['B', '##ER', '##T', 'represents', 'text', 'using', 'sub', '##words', '.']
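
The `##` pieces in the outputs above mark continuations of a word. WordPiece tokenization is commonly described as a greedy longest-match-first search over the vocabulary; the sketch below follows that description with a hypothetical toy vocabulary and is not the tokenizer's actual code.

```python
def wordpiece_tokenize(word, vocab, unk_token='[UNK]'):
    """Greedy longest-match-first split of a single word into subword pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until one is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate   # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]                 # no piece matched: treat the whole word as unknown
        pieces.append(piece)
        start = end
    return pieces

# Hypothetical toy vocabulary, chosen to reproduce the splits shown above
toy_vocab = {'Pre', '##tra', '##ining', 'sub', '##words', 'is', 'cool'}
print(wordpiece_tokenize('Pretraining', toy_vocab))
print(wordpiece_tokenize('subwords', toy_vocab))
```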

### Examining Learned Conditionals (& Representations)

**Probing tasks** can be used to examine aspects of what the model has learned. 

Following [Petroni et al 2019](https://arxiv.org/pdf/1909.01066.pdf) we probe for '**knowledge**' that the model has learned by querying for masked out objects, e.g.:

<img src="img/bert_kb.png" alt="Drawing" style="width: 75%;"/>

The task also illustrates some aspects of the **conditional distributions** and **contextualized representations** that the model has learned.

(image from [Petroni et al 2019])


**Exercise:** The authors only consider *single-token* prediction. Why?

#### Probing Task

We use a dataset from [Petroni et al 2019](https://github.com/facebookresearch/LAMA).

In [56]:
import utils
data = utils.load_lama_squad(download=True)
data[0]

{'masked_sentences': ['To emphasize the 50th anniversary of the Super Bowl the [MASK] color was used.'],
 'obj_label': 'gold',
 'id': '56be4db0acb8001400a502f0_0',
 'sub_label': 'Squad'}

In [57]:
results = []

model.eval()
for example in tqdm(data, total=len(data)):
    sentence, label = example['masked_sentences'][0], example['obj_label']
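    # [CLS] and [SEP] are added by hand, so stop encode() from adding them a second time
    inp = torch.tensor([
        [tokenizer.cls_token_id] + 
        tokenizer.encode(sentence, add_special_tokens=False) + 
        [tokenizer.sep_token_id]
    ], device=device)
    
    mask = (inp == tokenizer.vocab[tokenizer.mask_token])
    with torch.no_grad():  # inference only; keeps autograd graphs out of `results`
        out, attn = model(inp)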
    
    probs, token_ids = out[mask].softmax(1).topk(10)
    probs = probs[0].tolist()
    token_ids = token_ids[0].tolist()

    tokens = [tokenizer.ids_to_tokens[i] for i in token_ids]

    results.append({
        'sentence': sentence,
        'label': label,
        'top_tokens': tokens,
        'top_probs': probs,
        'correct@1': tokens[0] == label,
        'attn': attn
    })

print("correct@1: %.3f" % (
    len([r for r in results if r['correct@1']]) / len(results)
))

100%|██████████| 305/305 [00:07<00:00, 38.36it/s]

correct@1: 0.121





In [58]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

correct = [r for r in results if r['correct@1']]
wrong = [r for r in results if not r['correct@1']]

def show(idx=0, attn_layer=0, is_correct=True):
    result = correct[idx] if is_correct else wrong[idx]

    # --- format the result into a string
    top_str = '\n\t'.join([
        ('\t%s\t(%.4f)' % (tokens, probs)) 
        for tokens, probs in zip(result['top_tokens'], result['top_probs'])
    ])
    print("%s\n\tlabel:\t%s\n\n\ttop:%s" % (
        result['sentence'], 
        result['label'], 
        top_str
    ))

    # --- visualize attention
    print("Attention weights (12 heads) from layer %d:" % attn_layer)
    fig, axs = plt.subplots(3, 4, figsize=(18, 12))

    toks = ['[CLS]'] + tokenizer.tokenize(result['sentence']) + ['[SEP]']
    for i, ax in enumerate(axs.reshape(-1)):
        ax.matshow(result['attn'][attn_layer][0][i].data.cpu().numpy(), cmap='gray')

        ax.set_xticks(range(len(toks)))
        ax.set_xticklabels(toks, rotation=90, fontsize=15)
        ax.set_yticks(range(len(toks)))
        ax.set_yticklabels(toks, fontsize=15)
    plt.tight_layout()
    
interactive(
    show, 
    idx=(0, min(len(correct), len(wrong))-1), 
    attn_layer=range(12), 
    is_correct=True
)
