# Masked Language Modeling

In this lab, we will give an overview of the **masked language modeling** objective and show how the **Transformer** architecture is used for large-scale masked language modeling.


In [None]:
%matplotlib inline
import os, sys, glob, json, math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict
import torch
import torch.nn as nn

%load_ext autoreload
%autoreload 2
# None (not -1) means "no limit" in current pandas
pd.set_option('display.max_colwidth', None)

## Background

In 2018, Devlin et al. published [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).


**B**idirectional

**E**ncoder

**R**epresentations from

**T**ransformers


#### Goal: 
1. **pre-train** a model that produces language representations. 
2. **fine-tune** the model on a task.
    


## Masked Language Model Objective

Randomly mask some of the tokens in the input, then predict the original vocabulary id of each masked token.

- Given sequence $x_1,\ldots,x_N$.

- Form **mask** $m_1,\ldots,m_N$ where $m_i\in \{0,1\}$.
    - E.g. $m_i=1$ with probability 0.15
    
- Form **masked sequence** $\tilde{x}_1,\ldots,\tilde{x}_N$.
    - $\tilde{x}_i=\begin{cases} x_i & m_i=0\\ \texttt{[MASK]} & m_i=1\end{cases}$
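The masking step above can be sketched in a few lines of PyTorch. This is a minimal illustration, assuming the sequence is a 1-D tensor of token ids and that `MASK_ID` (a hypothetical id chosen here) stands for `[MASK]`. It implements only the simple scheme described so far, where every selected position is replaced by `[MASK]`.

```python
import torch

torch.manual_seed(0)

MASK_ID = 99      # hypothetical id for the [MASK] token
MASK_PROB = 0.15  # probability that m_i = 1

# A toy sequence x_1, ..., x_N of token ids (all below MASK_ID)
x = torch.randint(low=0, high=50, size=(20,))

# m_i ~ Bernoulli(0.15), stored as a boolean mask
m = torch.bernoulli(torch.full(x.shape, MASK_PROB)).bool()

# x~_i = x_i where m_i = 0, and [MASK] where m_i = 1
x_tilde = torch.where(m, torch.full_like(x, MASK_ID), x)

print(x.tolist())
print(m.int().tolist())
print(x_tilde.tolist())
```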


#### $$\mathcal{L}_{\text{MLM}}=-\sum_{\underbrace{i | m_i=1}_{\text{MASKED POSITIONS}}}\log p_{\theta}(\underbrace{x_i}_{\text{TRUE TOKEN}}|\underbrace{\tilde{x}_1,\ldots,\tilde{x}_N}_{\text{MASKED SEQUENCE}})$$


<!-- Below, we will discuss the exact form of $\tilde{x}_i$ that the BERT authors used. -->


<!-- #### Diagram of BERT Implementation -->
<!-- ![](bert_overview.png) -->

## Transformers

So far we have modeled a sequence by factorizing the joint distribution into conditionals, and **parameterizing each conditional with a recurrent network**:


#### $$p_{\theta}(x_1,\ldots,x_T)=\prod_{t=1}^T p_{\theta}(x_t | x_{<t})$$
\begin{align}
h_t &= \text{RNN}(x_{t-1}, h_{t-1})\\
p_{\theta}(x_t | x_{<t}) &=\text{softmax}\left(Wh_t+b\right),
\end{align}

where $\theta$ are the model parameters (RNN parameters, $W, b$, embedding matrix).
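As a concrete sketch of this factorization (a minimal illustration with arbitrary sizes, using `nn.GRU` as the RNN and a hypothetical start-token id 0 standing in for $x_0$), the model reads the previous tokens $x_0,\ldots,x_{T-1}$ and produces the distribution over $x_t$ at every step, one step per token:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, emb_dim, hidden_dim = 10, 8, 16
START_ID = 0  # hypothetical id playing the role of x_0 (the "previous token" at t = 1)

embed = nn.Embedding(vocab_size, emb_dim)
rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
proj = nn.Linear(hidden_dim, vocab_size)  # W and b

x = torch.tensor([[3, 7, 2, 5]])  # one sequence x_1..x_T, shape (1, T)
T = x.size(1)

# The input at each step is the previous token: x_0, ..., x_{T-1}
start = torch.full((1, 1), START_ID, dtype=torch.long)
prev = torch.cat([start, x[:, :-1]], dim=1)

# nn.GRU applies h_t = RNN(x_{t-1}, h_{t-1}) for every t in one call
h, _ = rnn(embed(prev))                         # (1, T, hidden_dim)
log_probs = torch.log_softmax(proj(h), dim=-1)  # log p(x_t | x_<t), (1, T, vocab_size)
probs = log_probs.exp()

# log p(x_1, ..., x_T) = sum_t log p(x_t | x_<t)
log_p = log_probs.gather(-1, x.unsqueeze(-1)).sum()
print(tuple(probs.shape), log_p.item())
```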


#### Alternative

An alternative proposed in [[Vaswani et al 2017](https://arxiv.org/pdf/1706.03762.pdf)] is to parameterize each conditional with a **particular feed-forward architecture** called the **Transformer**. With this model, it is possible to compute all conditionals with a **single feed-forward pass**:
\begin{align}
(h_1,\ldots,h_T) &= \text{Transformer}(x)\\
p_{\theta}(x_t | x_{<t}) &= \text{softmax}\left(Wh_t + b\right)
\end{align}

We will briefly discuss the key ideas behind the Transformer, its overall architecture (encoder only), and how to use it in PyTorch.

### High-Level View

We can view the Transformer encoder as mapping a sequence to a sequence of vectors.

<img src="img/high1.png" alt="Drawing" style="width: 35%;"/>

Let's step through the key ideas of how this mapping is designed, and discuss some of its resulting properties.

### Key Idea 1: Position Embeddings

Unlike an RNN, which can pick up positional information through its hidden state as it moves through the sequence, the Transformer has no built-in notion of order: self-attention treats its inputs as an unordered set.

Thus we encode inputs with **position** as well as **token** embeddings:

<img src="img/high2.png" alt="Drawing" style="width: 35%;"/>
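To see why position embeddings are needed, the sketch below (arbitrary sizes; dropout disabled so the check is deterministic) feeds a permuted sequence through an encoder *without* position embeddings. The outputs come back permuted in exactly the same way, so the encoder on its own cannot distinguish different orderings of the same tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, seq_len = 8, 5
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dim_feedforward=32, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()

# Token embeddings only, no position embeddings: (Length, Batch=1, Dim)
x = torch.randn(seq_len, 1, dim)
perm = torch.randperm(seq_len)

with torch.no_grad():
    out = encoder(x)
    out_perm = encoder(x[perm])

# Permuting the input only permutes the rows of the output
print(torch.allclose(out[perm], out_perm, atol=1e-5))
```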

In [None]:
input_sequence = ['<s>', 'my', 'pet', '[M]', '<s>']

max_len = 10

vocab = {'<s>': 0, 'my': 1, 'pet': 2, 'dog': 3, 'cat': 4, 'lion': 5, '[M]': 6}

dim = 6

token_embed = nn.Embedding(len(vocab), embedding_dim=dim)
position_embed = nn.Embedding(max_len, embedding_dim=dim)

In [None]:
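# Token ids with shape (Length, Batch=1)
input_vector = torch.tensor([vocab[x] for x in input_sequence]).unsqueeze(1)

# Give the positions the same (Length, 1) shape so the two embeddings add elementwise
# (without unsqueeze, (5, 1, 6) + (5, 6) would broadcast to (5, 5, 6))
positions = torch.arange(len(input_vector)).unsqueeze(1)
input_embeddings = token_embed(input_vector) + position_embed(positions)
input_embeddings.size()  # (Length, Batch, Dim)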

**Warning!!** By default (`batch_first=False`), the PyTorch Transformer classes expect input of shape `Length x Batch x Dim`.

### Key Idea 2: Modularity
The Transformer (encoder) is composed of a stack of **N identical layers**.

<img src="img/layers.png" alt="Drawing" style="width: 35%;"/>
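Note that "identical" refers to the architecture, not the weights. In PyTorch's `nn.TransformerEncoder`, the layer passed in is copied `num_layers` times into `encoder.layers`, and each copy has its own parameters. A small sketch (arbitrary sizes) that checks this:

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=32)
encoder = nn.TransformerEncoder(layer, num_layers=4)

print(len(encoder.layers))  # number of stacked layers

# Each layer holds its own copy of the weights
w0 = encoder.layers[0].linear1.weight
w1 = encoder.layers[1].linear1.weight
print(w0 is w1, w0.data_ptr() == w1.data_ptr())

# So the parameter count grows linearly with the number of layers
per_layer = sum(p.numel() for p in layer.parameters())
total = sum(p.numel() for p in encoder.parameters())
print(per_layer, total)
```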

In [None]:
import torch.nn as nn
nn.TransformerEncoder?

#### The `forward` method passes the input through the N layers, then applies the optional final `norm` (if one was given):

**Warning!!** The forward function accepts input as `Length x Batch x Dim`

In [None]:
# nn.TransformerEncoder.forward??
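Conceptually, the encoder's forward pass is a loop over its layers (followed by the optional `norm`, which is `None` here). The sketch below checks this by applying the layers one at a time; it assumes no attention masks and disables dropout so that both paths are deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim = 8
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dim_feedforward=32, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=3)  # norm=None
encoder.eval()

x = torch.randn(5, 2, dim)  # (Length, Batch, Dim)

with torch.no_grad():
    expected = encoder(x)

    # Manual version: pass the input through each layer in turn
    h = x
    for mod in encoder.layers:
        h = mod(h)

print(torch.allclose(expected, h, atol=1e-6))
```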

In [None]:
encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dim_feedforward=64, dropout=0.1)

encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

In [None]:
outputs = encoder(input_embeddings)

print("input size: \t%s" % str(tuple(input_embeddings.shape)))
print("output size:\t%s" % str(tuple(outputs.shape)))
outputs

#### Each layer has two parts, **self-attention** and a feed-forward transformation:

<img src="img/layer.png" alt="Drawing" style="width: 65%;"/>

In [None]:
# nn.TransformerEncoderLayer??

In [None]:
# nn.TransformerEncoderLayer.forward??

### Key Idea 3: Self-Attention

In the RNN, the hidden state contains information about previous tokens.
The Transformer instead performs **attention** over all inputs at a given layer. Attention computes each output vector as a weighted sum of *value* vectors (rows of $V$); the weights, called *attention weights*, come from comparing a *query* vector (a row of $Q$) with every *key* vector (a row of $K$), and $d_k$ is the key dimension. The Transformer uses **scaled dot-product attention**:
#### $$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

and 'Multi-head Attention' refers to applying several of these operations in parallel.
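The formula translates almost line for line into code. Below is a minimal single-head sketch, assuming `Q`, `K`, `V` are already-projected matrices with one row per position (inside `nn.MultiheadAttention`, learned linear projections of the layer input produce them):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (L_q, L_k)
    weights = torch.softmax(scores, dim=-1)           # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
L, d_k, d_v = 4, 8, 8
Q, K, V = torch.randn(L, d_k), torch.randn(L, d_k), torch.randn(L, d_v)

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)
print(weights.sum(dim=-1))  # each entry is (numerically) 1
```

Each output row is a convex combination of the rows of `V`, weighted by how well that query matches each key.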

#### *Key Property*: Each output vector of layer $n$ can use information from **all** inputs to layer $n$.

Thus each **final output vector** can incorporate information from **all input words**.

(If we want to prevent some information flow, e.g. from future tokens in left-to-right language modeling, we can use an attention mask.)
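For example, left-to-right masking can be written as an additive attention mask with $-\infty$ above the diagonal, passed as the `mask` argument of `nn.TransformerEncoder`. The sketch below (arbitrary sizes, dropout disabled) checks that with this mask, changing only the *last* token leaves the outputs at all earlier positions unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, T = 8, 5
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=2, dim_feedforward=32, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()

# Position t may attend only to positions <= t: -inf above the diagonal, 0 elsewhere
causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

x = torch.randn(T, 1, dim)           # (Length, Batch, Dim)
x_changed = x.clone()
x_changed[-1] = torch.randn(1, dim)  # change only the last position

with torch.no_grad():
    out = encoder(x, mask=causal_mask)
    out_changed = encoder(x_changed, mask=causal_mask)

# Earlier positions are unaffected; the last one changes
print(torch.allclose(out[:-1], out_changed[:-1], atol=1e-6))
print(torch.allclose(out[-1], out_changed[-1], atol=1e-6))
```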

In [None]:
attn = nn.MultiheadAttention(dim, 2, dropout=0.0)

# Self-attention: queries, keys and values all come from the same sequence
attn_outputs, attn_weights = attn(query=outputs, key=outputs, value=outputs)

print("input shape: %s" % (str(tuple(outputs.size()))))
print("output shape: %s" % (str(tuple(attn_outputs.size()))))
print(attn_outputs)

# The returned weights are averaged over heads: (Batch, Length, Length)
print("\nattn weights shape: %s" % (str(tuple(attn_weights.size()))))
print(attn_weights)

#### Summary

In [None]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, max_len, dim=8, num_layers=4, nhead=2):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, dim)
        self.position_embed = nn.Embedding(max_len, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nhead, dim_feedforward=64, dropout=0.0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.projection = nn.Linear(dim, vocab_size)
    
    def features(self, token_indices):
        pos = torch.arange(len(token_indices), device=token_indices.device).unsqueeze(1)
        x = self.token_embed(token_indices) + self.position_embed(pos)
        x = self.encoder(x)
        return x
    
    def forward(self, token_indices):
        x = self.features(token_indices)
        x = self.projection(x)
        return x

In [None]:
input_vector.size()

In [None]:
model = Transformer(len(vocab), max_len=100)

model.features(input_vector)

## Back to Masked Language Modeling

Recall the **key property** of Transformers: due to self-attention, each output vector can incorporate information from *all* input tokens.

<img src="img/mlm.png" alt="Drawing" style="width: 45%;"/>

This is useful for masked language modeling, where we want to use information from the entire context when predicting the masked token(s).

#### MLM on Persona-Chat

In [None]:
import utils
raw_datasets, datasets, vocab = utils.load_personachat()

In [None]:
from torch.utils.data.dataloader import DataLoader

trainloader = DataLoader(datasets['train'], batch_size=4, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))
validloader = DataLoader(datasets['valid'], batch_size=4, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))

In [None]:
batch = next(iter(trainloader))
batch

In [None]:
def mask_tokens(inputs, mask_prob, pad_token_id, mask_token_id, vsize):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original."""
    inputs = inputs.clone()
    labels = inputs.clone()
    # Sample tokens in each sequence for masked-LM training
    masked_indices = torch.bernoulli(torch.full(labels.shape, mask_prob)).bool()
    masked_indices = masked_indices & (inputs != pad_token_id)
    labels[~masked_indices] = -1  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with the mask token (mask_token_id)
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vsize, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
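The 80/10/10 split comes from two chained Bernoulli draws: a masked position gets the mask token with probability 0.8; of the remaining 20%, half (0.2 × 0.5 = 0.1) get a random token, and the rest (0.1) keep the original token. A quick simulation of just those two draws (independent of the data above) confirms the proportions:

```python
import torch

torch.manual_seed(0)
n = 1_000_000  # number of already-selected (masked) positions to simulate

replaced = torch.bernoulli(torch.full((n,), 0.8)).bool()
random_ = torch.bernoulli(torch.full((n,), 0.5)).bool() & ~replaced
kept = ~replaced & ~random_

for name, sel in [("[MASK]", replaced), ("random", random_), ("original", kept)]:
    print("%-9s %.3f" % (name, sel.float().mean().item()))
```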

In [None]:
inputs, labels = mask_tokens(batch, mask_prob=0.15, mask_token_id=vocab.get_id('[M]'), pad_token_id=vocab.get_id('<pad>'), vsize=len(vocab))
print("Mask token id: %d" % vocab.get_id('[M]'))
inputs

In [None]:
labels

In [None]:
model = Transformer(len(vocab), max_len=200)

In [None]:
logits = model(inputs)
logits.size()

In [None]:
labels.size()

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=-1)

In [None]:
logits_ = logits.view(-1, logits.size(2))
labels_ = labels.view(-1)

criterion(logits_, labels_)
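`nn.CrossEntropyLoss(ignore_index=-1)` skips every position labelled `-1` (the unmasked ones) and, with the default `reduction='mean'`, *averages* over the masked positions, whereas $\mathcal{L}_{\text{MLM}}$ as written above *sums* over them. A small sketch with random logits and labels (sizes chosen arbitrarily) makes the relationship explicit:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_positions, vocab_size = 12, 7
logits = torch.randn(num_positions, vocab_size)
labels = torch.randint(vocab_size, (num_positions,))
labels[::3] = -1  # mark some positions as unmasked (ignored by the loss)

masked = labels != -1
log_probs = torch.log_softmax(logits, dim=-1)

# L_MLM: minus the sum of log p(true token) over masked positions only
l_mlm = -log_probs[masked].gather(1, labels[masked].unsqueeze(1)).sum()

loss_sum = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')(logits, labels)
loss_mean = nn.CrossEntropyLoss(ignore_index=-1)(logits, labels)

print(l_mlm.item(), loss_sum.item(), loss_mean.item() * masked.sum().item())
```

Averaging keeps the loss scale independent of how many tokens happened to be masked in a batch.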

In [None]:
if False:
    import torch.optim as optim
    from tqdm import tqdm, trange
    from collections import defaultdict
    from torch.utils.data.dataloader import DataLoader

    trainloader = DataLoader(datasets['train'], batch_size=64, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))
    validloader = DataLoader(datasets['valid'], batch_size=64, collate_fn=lambda x: utils.pad_collate_fn(vocab.get_id('<pad>'), x))

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = Transformer(len(vocab), max_len=65, dim=256, nhead=8).to(device)

    model_parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(model_parameters, lr=0.001)

    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)

    stats = defaultdict(list)

    for epoch in range(50):
        for step, batch in enumerate(trainloader):
            model.train()        
            # Mask the batch
            inputs, labels = mask_tokens(batch, mask_prob=0.15, 
                                         pad_token_id=vocab.get_id('<pad>'),
                                         mask_token_id=vocab.get_id('[M]'), 
                                         vsize=len(vocab))
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            logits_ = logits.view(-1, logits.size(2))
            labels_ = labels.view(-1)

            optimizer.zero_grad()
            loss = criterion(logits_, labels_)

            loss.backward()
            optimizer.step()

            stats['train_loss'].append(loss.item())
            stats['train_loss_log'].append(loss.item())
            if (step % 500) == 0:
                avg_loss = sum(stats['train_loss_log']) / len(stats['train_loss_log'])
                print("Epoch %d Step %d\tTrain Loss %.3f" % (epoch, step, avg_loss))
                stats['train_loss_log'] = []

        model.eval()
        valid_losses = []
        for batch in validloader:
            with torch.no_grad():
                # Mask the batch
                inputs, labels = mask_tokens(batch, mask_prob=0.15, 
                                             pad_token_id=vocab.get_id('<pad>'),
                                             mask_token_id=vocab.get_id('[M]'), 
                                             vsize=len(vocab))
                inputs = inputs.to(device)
                labels = labels.to(device)

                logits = model(inputs)
                logits_ = logits.view(-1, logits.size(2))
                labels_ = labels.view(-1)

                loss = criterion(logits_, labels_)
                valid_losses.append(loss.item())
        # Report the average over the whole validation set, not just the last batch
        stats['valid_loss'].append(sum(valid_losses) / len(valid_losses))
        print("=== Epoch %d\tValid Loss %.3f" % (epoch, stats['valid_loss'][-1]))

### Example Conditionals

#### Load model  

In [None]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

checkpoint = utils.load('model', 'model', best=True)
options = checkpoint['options']
stats = checkpoint['stats']


model = utils.Transformer(len(vocab), options['max_len'], 
                          dim=options['dim'], 
                          nhead=options['nhead'])
model.load_state_dict(checkpoint['model_dict'])

In [None]:
model.eval()
model = model.to(device)

In [None]:
sentences = [['<s>', 'i', 'have', 'a', 'pet', '[M]', '.', '<s>'],
             ['<s>', 'i', 'have', 'two', 'pet', '[M]', '.', '<s>'],
             ['<s>', 'my', '[M]', 'is', 'a', 'lawyer', '.', '<s>'],
             ['<s>', 'my', '[M]', 'is', 'a', '[M]', '.', '<s>'],
             ['<s>', 'i', '[M]', '[M]', '[M]', 'sometimes', '.' , '<s>']]


def get_top_masked_tokens(tokens, vocab, device, top=10):
    # Token ids with shape (Length, Batch=1)
    ids = torch.tensor([vocab.get_id(x) for x in tokens], device=device).unsqueeze(1)
    masked = ids == vocab.get_id('[M]')

    with torch.no_grad():
        logits = model(ids)[masked]  # (num_masked, vocab_size)
    probs = torch.softmax(logits, -1)

    print(' '.join(tokens))
    for ps in probs:
        sorted_probs, idxs = ps.sort(descending=True)

        for i in range(top):
            print("\t%s (%.4f)" % (vocab.get_token(idxs[i].item()),
                                   sorted_probs[i].item()))
        print()

In [None]:
for s in sentences:
    get_top_masked_tokens(s, vocab, device)

## Back to *BERT*

**B**idirectional

**E**ncoder

**R**epresentations from

**T**ransformers

#### - Masked Language Modeling at scale

#### - Learned representations are useful downstream

<img src="img/bert_citations.png" alt="Drawing" style="width: 45%;"/>

#### Great implementation in [transformers](https://github.com/huggingface/transformers):

In [None]:
!pip install transformers

In [None]:
import torch
from transformers import (
    BertForMaskedLM,
    BertTokenizer
)

### Details -- Model Variants

- $\text{BERT}_{\text{BASE}}$: 12 layers, hidden dimension 768, 12 attention heads (**110 million parameters**)
- $\text{BERT}_{\text{LARGE}}$: 24 layers, hidden dimension 1024, 16 attention heads (**340 million parameters**)
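These parameter counts can be reproduced approximately by hand. The sketch below assumes the standard BERT layout: a 30,522-token WordPiece vocabulary (as in `bert-base-uncased`), 512 positions, 2 segment types, a feed-forward size of 4H, a pooler layer, and no MLM output head. The number of attention heads does not change the count. Treat the result as an estimate under those assumptions, not an official figure.

```python
def bert_params(num_layers, hidden, vocab=30522, max_pos=512, segments=2):
    ff = 4 * hidden                                    # feed-forward size
    layer_norm = 2 * hidden                            # gain + bias
    embeddings = (vocab + max_pos + segments) * hidden + layer_norm
    attention = 4 * (hidden * hidden + hidden)         # Q, K, V and output projections
    feed_forward = hidden * ff + ff + ff * hidden + hidden
    per_layer = attention + layer_norm + feed_forward + layer_norm
    pooler = hidden * hidden + hidden
    return embeddings + num_layers * per_layer + pooler

print("BERT-base:  %.1fM" % (bert_params(12, 768) / 1e6))
print("BERT-large: %.1fM" % (bert_params(24, 1024) / 1e6))
```

Both land close to the rounded 110M and 340M figures; most of the parameters sit in the per-layer attention and feed-forward weights, with the token embedding matrix the largest single piece.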

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
model = BertForMaskedLM.from_pretrained('bert-base-cased', output_attentions=True)

if torch.cuda.is_available():
    model.cuda()

### Details -- Input Implementation


- `[CLS]` token: starts each sequence. Its final hidden state is used as the aggregate sequence representation.
- `[SEP]` token: separates two segments (e.g. two sentences).
- **Segment embedding**: a learned embedding added to every token, indicating whether it belongs to sentence A or sentence B.
- **Position embedding**: learned, rather than the fixed sinusoidal encodings of Vaswani et al.


<img src="img/bert_inputs.png" alt="Drawing" style="width: 75%;"/>
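In code, the input representation is the sum of three embedding lookups. A minimal sketch with made-up sizes and token ids (`CLS_ID`, `SEP_ID` and the word ids are hypothetical, not BERT's real vocabulary ids; real BERT additionally applies LayerNorm and dropout to the sum):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, max_len, dim = 100, 16, 8
CLS_ID, SEP_ID = 1, 2  # hypothetical special-token ids

token_embed = nn.Embedding(vocab_size, dim)
segment_embed = nn.Embedding(2, dim)  # sentence A = 0, sentence B = 1
position_embed = nn.Embedding(max_len, dim)

# [CLS] a1 a2 [SEP] b1 b2 b3 [SEP]
sent_a, sent_b = [10, 11], [20, 21, 22]
tokens = torch.tensor([CLS_ID] + sent_a + [SEP_ID] + sent_b + [SEP_ID])
segments = torch.tensor([0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1))
positions = torch.arange(len(tokens))

x = token_embed(tokens) + segment_embed(segments) + position_embed(positions)
print(tokens.tolist())
print(segments.tolist())
print(x.shape)
```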

**Exercise:** For which downstream tasks would a pair of input sequences be useful?

### Tokenization

#### BERT represents text using **subword** (WordPiece) tokens with a 30k-token vocabulary.



(more on subword tokenization [here](https://github.com/google/sentencepiece) and in the papers referenced there)
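WordPiece splits each word greedily, always taking the longest prefix found in the vocabulary and marking word-internal pieces with `##`. Below is a simplified sketch of that longest-match-first step with a tiny made-up vocabulary; the real tokenizer also does basic pre-tokenization (whitespace and punctuation splitting), casing handling, and a maximum word length.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of a single word into subword pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until a match
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # marks a word-internal piece
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word becomes [UNK]
        pieces.append(piece)
        start = end
    return pieces

toy_vocab = {"pre", "##train", "##ing", "is", "cool", "sub", "##word", "##s"}
print(wordpiece("pretraining", toy_vocab))
print(wordpiece("subwords", toy_vocab))
print(wordpiece("xyz", toy_vocab))
```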

<!-- - **Token embedding**: WordPiece embeddings with 30k token vocabulary. -->

In [None]:
tokenizer.tokenize("Pretraining is cool.")

In [None]:
tokenizer.tokenize("BERT represents text using subwords.")

### Examining Learned Conditionals (& Representations)

**Probing tasks** can be used to examine aspects of what the model has learned. 

Following [Petroni et al 2019](https://arxiv.org/pdf/1909.01066.pdf) we probe for '**knowledge**' that the model has learned by querying for masked out objects, e.g.:

<img src="img/bert_kb.png" alt="Drawing" style="width: 75%;"/>

The task also illustrates some aspects of the **conditional distributions** and **contextualized representations** that the model has learned.

(image from [Petroni et al 2019](https://arxiv.org/pdf/1909.01066.pdf))


**Exercise:** The authors only consider *single-token* prediction. Why?

#### Probing Task

We use the SQuAD subset of the LAMA dataset from [Petroni et al 2019](https://github.com/facebookresearch/LAMA).

In [None]:
import utils
data = utils.load_lama_squad(download=True)
data[0]

In [None]:
results = []

model.eval()
for example in tqdm(data, total=len(data)):
    sentence, label = example['masked_sentences'][0], example['obj_label']
    # [CLS]/[SEP] are added explicitly, so encode() must not add them a second time
    inp = torch.tensor([
        [tokenizer.cls_token_id] + 
        tokenizer.encode(sentence, add_special_tokens=False) + 
        [tokenizer.sep_token_id]
    ], device=device)
    
    mask = (inp == tokenizer.mask_token_id)
    with torch.no_grad():
        outputs = model(inp)
    # Recent versions of transformers return an output object, not a plain tuple
    out, attn = outputs.logits, outputs.attentions
    
    # Top-10 predictions for the first masked position
    probs, token_ids = out[mask].softmax(1).topk(10)
    probs = probs[0].tolist()
    token_ids = token_ids[0].tolist()

    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    results.append({
        'sentence': sentence,
        'label': label,
        'top_tokens': tokens,
        'top_probs': probs,
        'correct@1': tokens[0] == label,
        'attn': [a.cpu() for a in attn]  # keep attention maps off the GPU
    })

print("correct@1: %.3f" % (
    len([r for r in results if r['correct@1']]) / len(results)
))

In [None]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

correct = [r for r in results if r['correct@1']]
wrong = [r for r in results if not r['correct@1']]

def show(idx=0, attn_layer=0, is_correct=True):
    result = correct[idx] if is_correct else wrong[idx]

    # --- format the result into a string
    top_str = '\n\t'.join([
        ('\t%s\t(%.4f)' % (tokens, probs)) 
        for tokens, probs in zip(result['top_tokens'], result['top_probs'])
    ])
    print("%s\n\tlabel:\t%s\n\n\ttop:%s" % (
        result['sentence'], 
        result['label'], 
        top_str
    ))

    # --- visualize attention
    print("Attention weights (12 heads) from layer %d:" % attn_layer)
    fig, axs = plt.subplots(3, 4, figsize=(18, 12))

    toks = ['[CLS]'] + tokenizer.tokenize(result['sentence']) + ['[SEP]']
    for i, ax in enumerate(axs.reshape(-1)):
        ax.matshow(result['attn'][attn_layer][0][i].data.cpu().numpy(), cmap='gray')

        ax.set_xticks(range(len(toks)))
        ax.set_xticklabels(toks, rotation=90, fontsize=15)
        ax.set_yticks(range(len(toks)))
        ax.set_yticklabels(toks, fontsize=15)
    plt.tight_layout()
    
interactive(
    show, 
    idx=(0, min(len(correct), len(wrong))-1), 
    attn_layer=range(12), 
    is_correct=True
)