#### IWSLT English MLM

This notebook shows a simple example of how to use the transformer provided by this repo for masked language modeling (MLM).

We will use the English side of the IWSLT 2016 German-English (de-en) dataset.

This is similar to BERT, but omits some of BERT's other training tricks, such as next sentence prediction (NSP).

In [1]:
import numpy as np
from torchtext import data, datasets
from torchtext.data import get_tokenizer
import spacy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import sys
sys.path.append("..")

from model.EncoderDecoder import TransformerEncoder
from model.utils import device, Batch, BasicIterator
from model.opt import NoamOpt
from model.Layers import Linear

import time
from collections import Counter

##### The cell below does some basic preprocessing: it tokenizes with the Moses tokenizer, sets the special tokens, and builds a vocabulary that drops words seen fewer than MIN_FREQ = 4 times.

In [2]:
tok = get_tokenizer("moses")
PAD = "<pad>"
SOS = "<sos>"
EOS = "<eos>"

en_field = data.Field(tokenize=tok, pad_token=PAD, init_token=SOS, eos_token=EOS)
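# Newline as the CSV delimiter keeps each line of train.en as a single field, so commas inside sentences don't split it
d = data.TabularDataset(".data/iwslt/de-en/train.en", format="csv", fields=[("text", en_field)],
                    csv_reader_params={"delimiter":'\n'})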
MIN_FREQ = 4
en_field.build_vocab(d.text, min_freq=MIN_FREQ, specials=["<mask>"])
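To see why pad_idx and mask_idx end up as 1 and 4 further down, here is a minimal stdlib sketch of what build_vocab does with min_freq and specials. It assumes legacy torchtext's ordering (the unk, pad, init and eos tokens first, then any extra specials, then the remaining tokens by descending frequency with ties broken alphabetically); build_vocab_sketch is a hypothetical helper, not a torchtext function.

```python
from collections import Counter


def build_vocab_sketch(sentences, min_freq, specials):
    """Hypothetical stand-in for Field.build_vocab (not torchtext's code).

    Specials get the first indices, in the order given; every other token
    seen at least min_freq times follows, most frequent first, ties broken
    alphabetically. Rarer tokens get no index (they would map to <unk>).
    """
    counts = Counter(tok for sent in sentences for tok in sent)
    itos = list(specials)
    kept = [t for t, c in counts.items() if c >= min_freq and t not in specials]
    kept.sort(key=lambda t: (-counts[t], t))
    itos.extend(kept)
    stoi = {t: i for i, t in enumerate(itos)}
    return itos, stoi


# Assumed legacy order: unk, pad, init (sos), eos, then specials=["<mask>"]
specials = ["<unk>", "<pad>", "<sos>", "<eos>", "<mask>"]
itos, stoi = build_vocab_sketch([["the", "cat", "the"], ["the", "dog", "cat"]],
                                min_freq=2, specials=specials)
print(stoi["<pad>"], stoi["<mask>"], "dog" in stoi)  # → 1 4 False
```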

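##### The batch_size_fn lets the torchtext iterator build dynamically sized batches: it measures a batch by its size after padding, (number of sentences) × (longest sentence), so batch_size=1100 below is a budget of about 1100 padded tokens per batch rather than 1100 sentences.

##### The BasicIterator class also helps with dynamic batching: sorting by length groups sentences of similar length, so batches contain minimal padding.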

In [3]:
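# Longest sentence in the batch currently being built
max_text_in_batch = 0
def batch_size_fn(new, count, _):
    global max_text_in_batch
    # count == 1 means a new batch has just been started
    if count == 1:
        max_text_in_batch = 0
    max_text_in_batch = max(max_text_in_batch, len(new.text))
    # Size of the batch after padding every sentence to the longest one
    return count * max_text_in_batch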

train_loader = BasicIterator(d, batch_size=1100, device=device,
                   repeat=False, sort_key=lambda x: (len(x.text)),
                   batch_size_fn=batch_size_fn, train=True)

##### One pass over the entire dataset, with gradient accumulation over 20 batches per optimizer step so that the effective batch is large enough for stable training.

In [4]:
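def train_step(dataloader, accum_steps=20):
    # Uses the global transformer, optimizer and criterion defined further down
    i = 0
    total_loss = 0
    optimizer.zero_grad()
    for batch in dataloader:
        loss, _, _ = transformer.mask_forward_and_return_loss(criterion, batch.text, .15)
        # backward() adds into the existing gradients until zero_grad() is called
        loss.backward()
        total_loss += loss.item()
        i += 1
        # Only take a step every accum_steps-th batch
        if i % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / i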
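Why accumulation works: backward() adds into .grad, so several small batches behave like one large batch. This toy stdlib sketch uses a hypothetical 1-D least-squares loss (not the MLM loss) to show that averaging the gradients of equal-sized micro-batches reproduces the full-batch gradient. Note that train_step sums the 20 gradients without dividing by 20; with Adam underneath NoamOpt, that constant factor is largely normalized away in the update, apart from the effect of eps.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over one (micro-)batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Split the batch into accum_steps equal micro-batches, add up their
    gradients (as repeated loss.backward() calls do), then average."""
    assert len(xs) % accum_steps == 0, "micro-batches must be equal-sized"
    size = len(xs) // accum_steps
    total = 0.0
    for k in range(accum_steps):
        chunk = slice(k * size, (k + 1) * size)
        total += grad_mse(w, xs[chunk], ys[chunk])
    return total / accum_steps


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.1, 5.9, 8.2]
print(grad_mse(1.5, xs, ys), accumulated_grad(1.5, xs, ys, accum_steps=2))
```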

#### Creating the pseudoBERT:

Subclassing the TransformerEncoder class lets us add a mask_forward_and_return_loss method easily; nothing else is needed for the model to be fully functional.

The TransformerEncoder class handles the embeddings and the transformer encoder layers itself, so we only need to follow it with a single Linear layer that maps each position onto the vocabulary. The masking logic is a bit involved, but the comments in the code below walk through it.

The goal of MLM is to randomly mask tokens, then train the model to predict the original token at each masked position from the context on both sides. This is a hard task that requires the model to learn a good deal about the language itself.

We use the model.utils.Batch object to create padding masks automatically.
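The MLM objective described above can be sketched without any tensors. This hypothetical stdlib helper, mask_tokens (with None standing in for an ignored target), masks each non-padding token with probability mask_rate and records the original token as the target only at masked positions. The model class below does the same job with tensors, using pad_idx as the ignored target and a padding-adjusted rate.

```python
import random

MASK, PAD = "<mask>", "<pad>"


def mask_tokens(tokens, mask_rate, rng):
    """Replace each non-padding token with MASK with probability mask_rate.

    Returns (masked, targets): targets holds the original token at masked
    positions and None elsewhere, so a loss would only score masked slots.
    """
    masked, targets = [], []
    for tok in tokens:
        if tok != PAD and rng.random() < mask_rate:
            masked.append(MASK)
            targets.append(tok)
        else:
            masked.append(tok)
            targets.append(None)
    return masked, targets


tokens = ["Not", "all", "who", "wander", "are", "lost", ".", PAD, PAD]
masked, targets = mask_tokens(tokens, 0.3, random.Random(0))
print(masked)
print(targets)
```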

In [5]:
class MLM(TransformerEncoder):
    def __init__(self, input_vocab_size,embedding_dim,
        n_layers,hidden_dim,n_heads,dropout_rate,
        pad_idx,mask_idx,):
        
        super(MLM, self).__init__(input_vocab_size,embedding_dim, n_layers,
                                  hidden_dim,n_heads,dropout_rate,pad_idx,)
        
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        
        self.fc1 = Linear(embedding_dim, input_vocab_size)
        
    def mask_forward_and_return_loss(self, criterion, seq, mask_rate):
        """
        Randomly mask non-padding tokens, pass the result through the transformer
        encoder and return the MLM loss. Handles masking for both MLM and padding
        automagically.
        Args:
            criterion: torch.nn.functional loss function that accepts ignore_index,
                e.g. F.cross_entropy
            seq: token ids, [seq_len, bs]
            mask_rate: base masking rate, rescaled below to account for padding

        Returns:
            loss, transformer output [seq_len, bs, vocab], mask [seq_len, bs]
        """
        seq = seq.to(device)
        not_pad = seq != self.pad_idx
        # Every position (padding included) gets a Bernoulli draw, but only
        # non-padding tokens are ever masked. Scaling the rate by
        # 1 / (1 - pad_fraction) means each real token is masked with probability
        # mask_rate / (1 - pad_fraction), so the expected number of masked tokens
        # is mask_rate * seq.numel(). Clamp to 1 in case padding makes up more
        # than (1 - mask_rate) of the batch.
        pad_fraction = (~not_pad).sum().float() / seq.numel()
        true_masking_rate = torch.clamp(mask_rate / (1 - pad_fraction), 0, 1)
        probabilities = torch.full(seq.shape, true_masking_rate.item(), device=seq.device)
        # 1 where a non-padding token is replaced by <mask>, 0 everywhere else
        masking_mask = (torch.bernoulli(probabilities).bool() & not_pad).long()
        masked_seq = seq.masked_fill(masking_mask.bool(), self.mask_idx)

        batch = Batch(masked_seq, None, self.pad_idx)
        out = self.forward(batch.src.to(device), batch.src_mask.to(device))
        out = self.fc1(out.transpose(0, 1)).transpose(0, 1)
        # zeroing out token predictions on non-masked tokens
        out = out * masking_mask.unsqueeze(-1)

        # Targets keep the original token at masked positions and are pad_idx
        # elsewhere, so ignore_index makes the loss depend only on masked
        # tokens, like BERT
        targets = seq.masked_fill(masking_mask == 0, self.pad_idx)
        loss = criterion(
            out.contiguous().view(-1, out.size(-1)),
            targets.contiguous().view(-1),
            ignore_index=self.pad_idx,
        )

        return loss, out, masking_mask
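A worked example of the rate adjustment at the top of mask_forward_and_return_loss, as a small stdlib sketch (masking_stats is a hypothetical helper that mirrors the arithmetic, not part of the repo). Because every position gets its own Bernoulli draw, each real token is masked with probability mask_rate / (1 - pad_fraction), and the expected number of masked tokens is mask_rate times all tokens, padding included. With 20 padding tokens out of 100 and mask_rate = 0.15, that is a per-token rate of 0.1875 and 15 expected masks among the 80 real tokens, i.e. 18.75% of the real tokens.

```python
def masking_stats(num_tokens, num_pad, mask_rate):
    """Mirror the rescaling in mask_forward_and_return_loss.

    Returns (masking probability for each non-padding token,
             expected number of masked tokens).
    """
    num_real = num_tokens - num_pad
    if num_real == 0:
        return 1.0, 0.0  # the clamp caps the rate; nothing can be masked
    pad_fraction = num_pad / num_tokens
    p = min(max(mask_rate / (1 - pad_fraction), 0.0), 1.0)
    return p, p * num_real


p, expected = masking_stats(num_tokens=100, num_pad=20, mask_rate=0.15)
print(round(p, 4), round(expected, 2))  # → 0.1875 15.0
```

If the goal is instead to mask mask_rate of the real tokens in expectation, drawing with probability mask_rate and keeping the not_pad filter is enough. Since BasicIterator keeps padding low, the difference during training should be small.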

##### Here we instantiate the model and set hyperparameters. Note: this MLM model is deliberately very small (4 layers, 512-dimensional embeddings) so the experiment is easy to reproduce.

In [6]:
input_vocab_size = len(en_field.vocab)
embedding_dim = 512
n_layers = 4
hidden_dim = 1024
n_heads = 4
dropout_rate = .1
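# Look the special-token ids up instead of hard-coding them (they are 1 and 4 here)
pad_idx = en_field.vocab.stoi[PAD]
mask_idx = en_field.vocab.stoi["<mask>"]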
transformer = MLM(input_vocab_size, embedding_dim, n_layers, hidden_dim,
           n_heads, dropout_rate, pad_idx, mask_idx).to(device)

adamopt = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = NoamOpt(embedding_dim, 1, 2000, adamopt)
criterion = F.cross_entropy

# optimization is unstable without this step
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
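NoamOpt(embedding_dim, 1, 2000, adamopt) wraps Adam in a warmup schedule. Assuming the repo's NoamOpt follows the Annotated Transformer's (model_size, factor, warmup, optimizer) signature, which the call above suggests but this notebook does not confirm, the learning rate at optimizer step s is factor · model_size^-0.5 · min(s^-0.5, s · warmup^-1.5): linear warmup for 2000 steps, then inverse-square-root decay. That is presumably why Adam is constructed with lr=0, since the schedule overwrites it at every step. Also, because train_step only calls optimizer.step() every 20 batches, 2000 warmup steps would span about 40,000 batches. The sketch below just evaluates the formula.

```python
def noam_rate(step, model_size=512, factor=1.0, warmup=2000):
    """Transformer warmup schedule (Vaswani et al., 2017):
    factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)."""
    step = max(step, 1)  # guard against 0 ** -0.5 at step 0
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for s in (1, 500, 2000, 8000):
    print(s, f"{noam_rate(s):.2e}")
```

The peak rate, reached at step = warmup, is (512 · 2000)^-0.5, roughly 1e-3.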

#### Let's run 5 epochs over the entire dataset, printing loss once per epoch.

In [7]:
true_start = time.time()
for i in range(5):
    transformer.train()
    t = time.time()
    
    loss = train_step(train_loader)
    print("Epoch {}. Loss: {}, ".format(i + 1, str(loss)[:5]))
    print("Total time (s): {}, Last epoch time (s): {}".format(int(time.time()- true_start), int(time.time() - t)))

Epoch 1. Loss: 6.848, 
Total time (s): 373, Last epoch time (s): 373
Epoch 2. Loss: 5.246, 
Total time (s): 746, Last epoch time (s): 373
Epoch 3. Loss: 4.379, 
Total time (s): 1123, Last epoch time (s): 377
Epoch 4. Loss: 4.067, 
Total time (s): 1499, Last epoch time (s): 376
Epoch 5. Loss: 3.876, 
Total time (s): 1877, Last epoch time (s): 377
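Since criterion is F.cross_entropy (natural log), the per-epoch losses above can be read as perplexities over masked tokens via exp(loss). Here is a quick stdlib calculation on the logged values; note these are averages of training-mode losses (dropout on) across each epoch, not held-out numbers.

```python
import math

# Mean training loss per epoch, copied from the log above
epoch_losses = [6.848, 5.246, 4.379, 4.067, 3.876]
perplexities = [math.exp(loss) for loss in epoch_losses]
for epoch, ppl in enumerate(perplexities, start=1):
    print(f"Epoch {epoch}: masked-token perplexity ≈ {ppl:.1f}")
```

Roughly speaking, by the last epoch the model is about as uncertain as a uniform choice among 48 words at each masked position, down from about 940 after the first epoch.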


In [8]:
torch.save(transformer, "basic_MLM.pt")



##### Let's process a few example sentences that are not in the training data and visualize the results.

In [9]:
transformer.eval()
inp = en_field.process([tok("Now, if we all played football, this wouldn't be an issue."), 
                        tok("I don't really agree with you, honestly"),
                       tok("Not all who wander are lost.")]).to(device)
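# Tokens are masked at random (base rate .20), so each run masks different words
with torch.no_grad():
    _, pred, mask = transformer.mask_forward_and_return_loss(criterion, inp, .20)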
pred = pred.transpose(0, 1)
mask = mask.transpose(0, 1)

##### Simple code for visualization. Let's check out how our model did.

In [10]:
def visualize_model_predictions(inp, pred, mask):
    print("Sentence:", end=" ")
    for i in range(len(inp)):
        if en_field.vocab.itos[inp[i]] == "<eos>":
            break
        if mask[i] == 1:
            print("<" + en_field.vocab.itos[pred[i]] +  " | " + en_field.vocab.itos[inp[i]] + ">", end = " ")
        else:
            print(en_field.vocab.itos[inp[i]], end = " ")
    print("\n")

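##### Masked tokens are shown as <prediction | ground truth>: the word to the left of the | is the model's prediction and the word to the right is the original token. (Special tokens such as <sos> also appear in angle brackets, and the Moses tokenizer escapes apostrophes as &apos;.)

##### Despite the model's small size, the predictions are reasonable: in the run below, 4 of the 8 masked tokens are recovered exactly, and most of the misses (e.g. "just" for "really") still fit the context.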

In [11]:
for i in range(len(inp.T)):
    visualize_model_predictions(
        inp.transpose(0, 1)[i].tolist(), 
        torch.argmax(pred[i], dim=-1).tolist(), 
        mask[i])

Sentence: <sos> Now <, | ,> if we all played football <, | ,> this wouldn <&apos;t | &apos;t> be an issue . 

Sentence: <sos> I <don | don> &apos;t <just | really> agree with you <. | ,> honestly 

Sentence: <sos> <And | Not> all who wander are <there | lost> . 
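To put a number on these three examples, this sketch parses the <prediction | ground truth> pairs out of the printed lines (copied from the output above) and counts exact matches. It only scores this one random masking, so treat it as an illustration rather than an evaluation.

```python
import re

# Printed sentences from the run above
outputs = [
    "<sos> Now <, | ,> if we all played football <, | ,> this wouldn <&apos;t | &apos;t> be an issue .",
    "<sos> I <don | don> &apos;t <just | really> agree with you <. | ,> honestly",
    "<sos> <And | Not> all who wander are <there | lost> .",
]
# A masked slot looks like "<pred | gold>"; plain tokens such as <sos> don't match
PAIR = re.compile(r"<(\S+) \| (\S+)>")
pairs = [m.groups() for line in outputs for m in PAIR.finditer(line)]
correct = sum(pred == gold for pred, gold in pairs)
print(f"{correct}/{len(pairs)} masked tokens predicted exactly")  # → 4/8 masked tokens predicted exactly
```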

