# Masked language model after BERT
Heavily influenced by the [BERT paper](https://arxiv.org/abs/1810.04805) and [codertimo's PyTorch implementation](https://github.com/codertimo/BERT-pytorch), which itself borrows heavily from the code in [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html).

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (20, 14)

In [None]:
import math
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import io
import string
import random
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import spacy
nlp = spacy.load('en')

First, define the multi-headed attention mechanism.

In [None]:
class Attention(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Returns a pair (attended values, attention weights).
    """
    def forward(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = query @ key.transpose(-2, -1)
        scores = scores / math.sqrt(d_k)

        if mask is not None:
            # entries where the mask is 0/False get a huge negative score, so
            # they receive (numerically) zero weight after the softmax
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return weights @ value, weights
    

class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention with ``h`` heads, each of width ``d_model // h``.

    Queries, keys and values are each linearly projected, split into heads,
    attended independently, re-joined and passed through an output projection.
    """
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0

        # per-head width; values use the same width as keys (d_v == d_k)
        self.d_k = d_model // h
        self.h = h

        # projections for query, key and value, in that order
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        # applied to the attention weights
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        n_batch = query.size(0)

        def split_heads(projection, tensor):
            # project, then reshape (batch, seq, d_model) -> (batch, h, seq, d_k)
            return projection(tensor).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        q, k, v = (split_heads(projection, tensor)
                   for projection, tensor in zip(self.linear_layers, (query, key, value)))

        context, _ = self.attention(q, k, v, mask=mask, dropout=self.dropout)

        # re-join heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.output_linear(merged)

Next come a few utilities — `LayerNorm()`, `SublayerConnection()`, `GELU()` and `PositionwiseFeedForward()` — which make up the transformer block. Most of this comes from [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html), with a few modifications.

In [None]:
class LayerNorm(nn.Module):
    "Layer normalisation over the last dimension, with a learnable gain (a_2) and bias (b_2)."
    def __init__(self, wv_size, epsilon=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(wv_size))
        self.b_2 = nn.Parameter(torch.zeros(wv_size))
        self.epsilon = epsilon

    def forward(self, x):
        centred = x - x.mean(-1, keepdim=True)
        # torch's std is the unbiased estimate; epsilon is added to the std
        # (not the variance) to avoid division by zero
        scale = x.std(-1, keepdim=True) + self.epsilon
        return self.a_2 * centred / scale + self.b_2


class SublayerConnection(nn.Module):
    """
    Pre-norm residual connection: x + dropout(sublayer(norm(x))).

    The layer norm is applied to the sublayer's input, not after the
    residual sum.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))

In [None]:
class GELU(nn.Module):
    "tanh approximation of GELU; paper: https://arxiv.org/abs/1606.08415"
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network applied over the last dimension:
    Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model).

    d_ff is a free parameter; TransformerBlock passes 2 * d_model so the
    hidden layer has room to move.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        expand = nn.Linear(d_model, d_ff)
        activation = GELU()
        regularise = nn.Dropout(dropout)
        contract = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(expand, activation, regularise, contract)

    def forward(self, x):
        return self.feed_forward(x)

In [None]:
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: number of attention heads (must divide ``hidden``)
        :param dropout: dropout rate, used inside attention, the feed-forward
            network, both sublayer connections and on the block output
        """
        super(TransformerBlock, self).__init__()
        # pass the block's dropout rate through, otherwise attention would
        # always use MultiHeadedAttention's own default regardless of `dropout`
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=hidden*2, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        # call submodules through __call__ (not .forward) so registered hooks run
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

The full BERT model:

In [None]:
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.

    Encodes a batch of token-index sequences into hidden vectors. Tokens are
    looked up in a pretrained embedding matrix, projected to ``hidden``
    dimensions and passed through ``n_layers`` transformer blocks. No
    positional encoding is added, so nothing in this model tells the blocks
    where a token sits in the sequence.
    """
    def __init__(self, pretrained_embedding, hidden=768, n_layers=12, attn_heads=12, dropout=0.1,
                 pad_index=0, freeze_embedding=True):
        """
        :param pretrained_embedding: 2-D float tensor of pre-trained word embeddings,
            one row per vocabulary index
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads (must divide ``hidden``)
        :param dropout: dropout rate
        :param pad_index: vocabulary index treated as padding; positions holding it
            are masked out as attention keys. Defaults to 0. NOTE: in this notebook
            ``<PAD>`` is not index 0 (index 0 is the first fastText word), so pass
            ``word_to_index['<PAD>']`` to mask padding correctly.
        :param freeze_embedding: keep the pretrained embedding fixed during training
            (the ``nn.Embedding.from_pretrained`` default); set False to fine-tune it
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.pad_index = pad_index

        # Deviation from the paper: rather than learning token/segment/position
        # embeddings, reuse established pretrained (fastText) word vectors and
        # project them linearly from their own width up to the hidden size.
        self.embedding = nn.Embedding.from_pretrained(pretrained_embedding, freeze=freeze_embedding)
        self.emb_to_hidden = nn.Linear(pretrained_embedding.shape[1], hidden)

        # stack of transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, dropout)
             for _ in range(n_layers)])

    def forward(self, x):
        """
        :param x: LongTensor of token indexes, shape (batch, seq_len)
        :return: hidden states, shape (batch, seq_len, hidden)
        """
        # Boolean key-padding mask of shape (batch, 1, 1, seq_len); it broadcasts
        # across heads and query positions when the attention scores are masked.
        mask = (x != self.pad_index).unsqueeze(1).unsqueeze(1)

        # index -> pretrained vector -> hidden-size vector
        x = self.embedding(x)
        x = self.emb_to_hidden(x)

        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        # the output stays in the hidden space; nothing projects it back to
        # the word-vector space here
        return x

which is then used to build a masked language model and a next-sentence predictor.

In [None]:
class MaskedLanguageModel(nn.Module):
    """
    Per-token vocabulary classifier for the masked-language-model task:
    maps every hidden vector to log-probabilities over ``vocab_size`` classes.
    """
    def __init__(self, hidden_size, vocab_size):
        """
        :param hidden_size: output size of the BERT model
        :param vocab_size: total vocab size (number of output classes)
        """
        super(MaskedLanguageModel, self).__init__()
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        logits = self.linear(x)
        return self.softmax(logits)


class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next

    Concatenates the encoder vector at position 0 of each of two sentences and
    returns log-probabilities over the two classes.
    """
    def __init__(self, hidden_size):
        """
        :param hidden_size: BERT model output size
        """
        super(NextSentencePrediction, self).__init__()
        self.linears = nn.Sequential(nn.Linear(hidden_size*2, hidden_size//2),
                                     nn.ReLU(),
                                     nn.Dropout(0.2),
                                     nn.Linear(hidden_size//2, 2))
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x_1, x_2):
        """
        :param x_1: encoder output for the first sentence, indexed as (batch, position, hidden)
        :param x_2: encoder output for the second sentence, same layout
        :return: log-probabilities of shape (batch, 2)
        """
        # NOTE(review): BERT puts a [CLS] token at position 0; sentence_to_indexes
        # adds no such token and pad_sequence left-pads, so position 0 may be
        # padding here — confirm this is the intended summary position.
        concatenated = torch.cat([x_1[:, 0],
                                  x_2[:, 0]], 1)
        classification = self.linears(concatenated)
        return self.softmax(classification)
    

# word vectors

In [None]:
wv_path = '/Users/pimh/datasets/crawl-300d-2M.vec'

# Stream the .vec file line by line instead of list()-ing every line into
# memory, and close it when done. The first line is the "<n_words> <dim>"
# header. Vectors are parsed to float32 here so the dict holds compact numeric
# arrays rather than arrays of strings.
with io.open(wv_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as wv_file:
    next(wv_file, None)  # skip the header line
    fasttext = {}
    for line in tqdm(wv_file):
        # split on single spaces (as fastText's own loader does) so tokens that
        # contain other Unicode whitespace, e.g. U+00A0, stay intact
        word, *vector = line.rstrip().split(' ')
        fasttext[word] = np.array(vector, dtype=np.float32)

In [None]:
word_to_index = {word: index for index, word in enumerate(fasttext.keys())}

# Give every special token its own index, directly after the fastText words.
# Assigning len(fasttext) to all three made <UNK>, <MASK> and <PAD>
# indistinguishable. These indexes line up with the three extra rows appended
# to the embedding matrix further down.
for special_token in ['<UNK>', '<MASK>', '<PAD>']:
    word_to_index[special_token] = len(word_to_index)

index_to_word = {index: word for word, index in word_to_index.items()}

In [None]:
# np.float was only an alias for the builtin float and was removed in NumPy 1.24.
# torch.Tensor stores float32 anyway, so convert straight to float32, which
# also halves the size of the temporary NumPy array.
fasttext_embedding_matrix = torch.Tensor(np.array(list(fasttext.values())).astype(np.float32))

In [None]:
# average word vector, kept 2-D with shape (1, dim) so it can be appended as a row
mean_wv = fasttext_embedding_matrix.mean(dim=0, keepdim=True)

In [None]:
# append three copies of the mean vector as extra rows
# (intended for the <UNK>, <MASK> and <PAD> tokens)
fasttext_embedding_matrix = torch.cat([fasttext_embedding_matrix, mean_wv.repeat(3, 1)])

# books

# building a dataset

# three types of dataset and dataloader

In [None]:
def sentence_to_indexes(sentence):
    '''
    Turn a spacy sentence (any iterable of tokens) into a list of vocabulary
    indexes; tokens missing from word_to_index map to the <UNK> index.
    '''
    def lookup(text):
        if text in word_to_index:
            return word_to_index[text]
        return word_to_index['<UNK>']

    return [lookup(str(token)) for token in sentence]


def pad_sequence(sentences, pad_length=None, pad_value=None):
    '''
    Left-pad a batch of index sequences into one 2-D integer array.

    :param sentences: sequence of index lists
    :param pad_length: row length; defaults to the longest sequence (0 for an
        empty batch)
    :param pad_value: fill value; defaults to word_to_index['<PAD>']
    :return: np.ndarray of shape (len(sentences), pad_length). Each sequence is
        right-aligned, so padding fills the leading positions.
    :raises ValueError: if a sequence is longer than pad_length
    '''
    if pad_length is None:
        pad_length = max((len(sent) for sent in sentences), default=0)
    if pad_value is None:
        pad_value = word_to_index['<PAD>']

    padded = np.full((len(sentences), pad_length), pad_value)
    for i, sentence in enumerate(sentences):
        if len(sentence) > pad_length:
            raise ValueError('sequence {} has length {}, longer than pad_length={}'
                             .format(i, len(sentence), pad_length))
        padded[i][pad_length - len(sentence):] = sentence
    return padded


def custom_collate_fn(batch):
    '''
    Collate a list of (masked, target) index-list pairs into two left-padded
    LongTensors of identical width (the longest target in the batch).
    '''
    masked_seqs, target_seqs = zip(*batch)
    width = max(len(seq) for seq in target_seqs)

    masked_tensor = torch.LongTensor(pad_sequence(masked_seqs, pad_length=width))
    target_tensor = torch.LongTensor(pad_sequence(target_seqs, pad_length=width))
    return masked_tensor, target_tensor

### Masked Language

In [None]:
def random_mask(index_sequence):
    '''
    Return a copy of `index_sequence` in which one uniformly chosen position
    has been corrupted, BERT-style:
        - with p=0.8 : replaced with the <MASK> token
        - with p=0.1 : replaced with a random index from the vocabulary
        - with p=0.1 : left unchanged
    The input list is not modified.
    '''
    corrupted = index_sequence.copy()
    position = random.randint(0, len(index_sequence) - 1)

    draw = random.random()
    if draw < 0.8:
        corrupted[position] = word_to_index['<MASK>']
    elif draw > 0.9:
        corrupted[position] = random.randint(0, len(word_to_index) - 1)
    # else (0.8 <= draw <= 0.9): keep the original token

    return corrupted

In [None]:
class MLMDataset(Dataset):
    '''
    Masked-language-model dataset over pre-indexed sentences.

    Each item is (masked, target): target is the stored index list and masked
    is a freshly randomised copy from random_mask, so every access draws a
    new mask.
    '''
    def __init__(self, index_lists):
        self.index_lists = index_lists

    def __len__(self):
        return len(self.index_lists)

    def __getitem__(self, index):
        original = self.index_lists[index]
        return random_mask(original), original

In [None]:
# Keep sentences of 6-49 tokens, de-duplicated, in first-seen (document) order.
# dict.fromkeys preserves insertion order, whereas a set iterates in hash order,
# which may differ between runs (e.g. when element hashes are identity-based)
# and would make the train/test split below irreproducible. The set is kept for
# fast membership tests in later cells.
ordered_acceptable_sentences = list(dict.fromkeys(s for s in sentences if 5 < len(s) < 50))
acceptable_sentences = set(ordered_acceptable_sentences)
index_list = [sentence_to_indexes(s) for s in ordered_acceptable_sentences]

In [None]:
# deterministic split on the existing order of index_list (no shuffling)
split_ratio = 0.8
train_size = int(split_ratio * len(index_list))
train_indexes = index_list[:train_size]
test_indexes = index_list[train_size:]

In [None]:
# wrap the train / test index lists; a new mask is drawn on each item access
train_mlm_dataset, test_mlm_dataset = MLMDataset(train_indexes), MLMDataset(test_indexes)

In [None]:
batch_size = 5

# Settings shared by both loaders; only the training loader shuffles.
# NOTE(review): with num_workers > 0 under the 'spawn' start method (the
# macOS/Windows default), functions defined in a notebook such as
# custom_collate_fn may fail to pickle — use num_workers=0 if that happens.
loader_settings = dict(batch_size=batch_size,
                       num_workers=5,
                       collate_fn=custom_collate_fn)

train_mlm = DataLoader(dataset=train_mlm_dataset, shuffle=True, **loader_settings)
test_mlm = DataLoader(dataset=test_mlm_dataset, **loader_settings)

### Next Sentences

In [None]:
# Do not rebind `sentences` here: the next cell pairs each sentence with the
# one that follows it, so `sentences` must stay in document order. A list built
# from the set is in arbitrary hash order, and since every element is in
# acceptable_sentences, every pair would be labelled as a true next sentence.
acceptable_sentence_list = list(acceptable_sentences)

In [None]:
nsp = {}
i = 0

# Build the pool of negative samples once. Rebuilding list(acceptable_sentences)
# for every negative example made this loop quadratic in the number of sentences.
negative_pool = list(acceptable_sentences)

for s1, s2 in tqdm(zip(sentences[:-1], sentences[1:]), total=max(len(sentences) - 1, 0)):
    if s1 in acceptable_sentences and s2 in acceptable_sentences:
        # consecutive and both usable: a positive (is-next) example
        nsp[i] = {'s1': sentence_to_indexes(s1),
                  's2': sentence_to_indexes(s2),
                  'label': 1}
        i += 1
    else:
        # pair whichever side is usable with a random sentence: a negative example
        random_sentence = random.choice(negative_pool)
        indexed_random_sentence = sentence_to_indexes(random_sentence)

        if s1 in acceptable_sentences:
            nsp[i] = {'s1': sentence_to_indexes(s1),
                      's2': indexed_random_sentence,
                      'label': 0}
            i += 1
        if s2 in acceptable_sentences:
            nsp[i] = {'s1': indexed_random_sentence,
                      's2': sentence_to_indexes(s2),
                      'label': 0}
            i += 1

### Short to long equivalence
Iterate over the sentences, find nouns, adjectives, and adjective–noun pairs, and pair each of these short phrases with the full sentence it came from.

In [None]:
stle = {}
# Restart the key counter. `i` would otherwise carry over from the NSP cell,
# so stle keys would not start at 0 (and the cell would fail if that one had
# not been run).
i = 0

# assumes each sentence is a spaCy span: tokens expose .pos_ and support indexing
for sentence in tqdm(list(acceptable_sentences)):
    long_indexes = sentence_to_indexes(sentence)

    # single nouns / adjectives -> their full sentence
    for word in sentence:
        if word.pos_ in ['NOUN', 'ADJ']:
            stle[i] = {'short': sentence_to_indexes([str(word)]),
                       'long': list(long_indexes)}
            i += 1

    # adjacent adjective-noun pairs -> their full sentence; j runs up to
    # len(sentence) - 2 so the final pair of tokens is included
    for j in range(len(sentence) - 1):
        word_1, word_2 = sentence[j], sentence[j + 1]
        if word_1.pos_ == 'ADJ' and word_2.pos_ == 'NOUN':
            stle[i] = {'short': sentence_to_indexes([str(word_1), str(word_2)]),
                       'long': list(long_indexes)}
            i += 1

# First, try training on just one task

In [None]:
# NLLLoss needs exactly one weight per class (len(word_to_index) here).
# value_counts() only covers indexes that occur in the data, so count into a
# full-length vector instead. Classes that never occur (e.g. <PAD>) get weight 0.
token_counts = np.bincount(np.concatenate(index_list), minlength=len(word_to_index))
class_weights = token_counts / token_counts.sum()
# NOTE(review): these are relative frequencies, so common tokens are weighted
# *up*; inverse-frequency weighting may have been intended — confirm.

class_weights = torch.Tensor(class_weights)#.cuda()

In [None]:
losses = []

def train(model, train_loader, loss_function, optimiser, n_epochs):
    '''
    Train `model` on (masked, target) batches from `train_loader`.

    :param model: expected to map a (batch, seq) LongTensor to
        (batch, seq, vocab) log-probabilities (BERT + MaskedLanguageModel here)
    :param train_loader: iterable of (masked, target) index tensors, each (batch, seq)
    :param loss_function: e.g. nn.NLLLoss
    :param optimiser: optimiser over the model's trainable parameters
    :param n_epochs: number of passes over train_loader
    Each batch's loss value is appended to the module-level `losses` list.
    '''
    model.train()
    for epoch in range(n_epochs):
        loop = tqdm(train_loader)
        for masked, target in loop:
            masked = torch.as_tensor(masked, dtype=torch.long)#.cuda(non_blocking=True)
            target = torch.as_tensor(target, dtype=torch.long)#.cuda(non_blocking=True)

            optimiser.zero_grad()
            preds = model(masked)
            # NLLLoss expects the class dimension second, (batch, vocab, seq),
            # while the model emits (batch, seq, vocab)
            loss = loss_function(preds.transpose(1, 2), target)
            loss.backward()
            optimiser.step()

            batch_loss = loss.item()
            loop.set_description('Epoch {}/{}'.format(epoch + 1, n_epochs))
            loop.set_postfix(loss=batch_loss)
            losses.append(batch_loss)

In [None]:
# sizes shared by the encoder and the MLM head; they must exist before the
# model is built
hidden_size = 768
vocab_size = len(word_to_index)

model = nn.Sequential(BERT(fasttext_embedding_matrix, hidden=hidden_size),
                      MaskedLanguageModel(hidden_size, vocab_size))

In [None]:
torch.backends.cudnn.benchmark = True

# optimise only parameters that require gradients (the pretrained embedding
# created by nn.Embedding.from_pretrained is frozen by default)
trainable_parameters = [param for param in model.parameters() if param.requires_grad]
optimiser = optim.Adam(trainable_parameters, lr=0.001)
loss_function = nn.NLLLoss(weight=class_weights)

In [None]:
# single-task (MLM) run: one pass over the training loader
train(model, train_mlm, loss_function, optimiser, n_epochs=1)

# one universal training loop for all tasks

In [None]:
# model dimensions used by the multi-task loop
hidden_size = 768
vocab_size = len(word_to_index)

In [None]:
def train(model, train_loader, loss_function, optimiser, n_epochs, batches_per_epoch,
          tasks=('MLM',)):
    '''
    Multi-task training loop (work in progress).

    Each epoch picks one task at random from `tasks` and trains on up to
    `batches_per_epoch` batches of it (all batches if None). Only 'MLM' is
    implemented: `model` maps a (batch, seq) LongTensor to (batch, seq, vocab)
    log-probabilities and `train_loader` yields (masked, target) pairs.

    :return: list with the mean loss of each epoch (nan if it saw no batches)
    :raises NotImplementedError: if a task other than 'MLM' is chosen
    '''
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    epoch_losses = []
    model.train()
    for epoch in range(n_epochs):
        task_choice = str(np.random.choice(tasks))

        if task_choice != 'MLM':
            # TODO: NSP and STLE need their own data loaders, heads and losses
            raise NotImplementedError('task {} is not implemented yet'.format(task_choice))

        batch_losses = []
        for batch_number, (masked, target) in enumerate(train_loader):
            if batches_per_epoch is not None and batch_number >= batches_per_epoch:
                break
            masked = torch.as_tensor(masked, dtype=torch.long).to(device, non_blocking=True)
            target = torch.as_tensor(target, dtype=torch.long).to(device, non_blocking=True)

            optimiser.zero_grad()
            preds = model(masked)
            # NLLLoss wants (batch, vocab, seq); the model emits (batch, seq, vocab)
            loss = loss_function(preds.transpose(1, 2), target)
            loss.backward()
            optimiser.step()
            batch_losses.append(loss.item())

        epoch_losses.append(sum(batch_losses) / len(batch_losses) if batch_losses else float('nan'))
    return epoch_losses
        

In [None]:
# Scratch cell: display `sentence` (presumably the last value bound by an
# earlier loop over sentences).
sentence

In [None]:
# Scratch cell: show the vocabulary indexes for `sentence` (unknown tokens -> <UNK>).
sentence_to_indexes(sentence)