In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import random
from sklearn.model_selection import train_test_split
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import normalizers
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers import decoders
from sklearn.metrics import accuracy_score

# Load data

In [22]:
# Read the corpus as UTF-8 explicitly: the text is Vietnamese, and the
# platform default encoding (e.g. cp1252 on Windows) would garble or reject it.
with open('qn_sequences.txt', 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')

# Drop empty and single-character lines.
lines = [ln for ln in lines if len(ln) > 1]

# 90% train / 10% held out, then the held-out part is split 80/20 into
# validation / test. A fixed random_state makes the split reproducible across
# kernel restarts.
SPLIT_SEED = 42
train_lines, val_lines = train_test_split(lines, test_size=0.1, shuffle=True, random_state=SPLIT_SEED)
val_lines, test_lines = train_test_split(val_lines, test_size=0.2, shuffle=True, random_state=SPLIT_SEED)

len(lines), len(train_lines), len(val_lines), len(test_lines)

(4145210, 3730689, 331616, 82905)

# Build Tokenizer

In [23]:
# Train a WordPiece tokenizer on the corpus. Text is only lowercased
# (StripAccents is not applied), and every encoding is padded/truncated to 128 ids.
special_tokens = ["[UNK]", "[PAD]", "[MASK]", "[SEP]"]

bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
bert_tokenizer.normalizer = normalizers.Sequence([Lowercase()])
bert_tokenizer.pre_tokenizer = Whitespace()
bert_tokenizer.decoder = decoders.WordPiece()

trainer = WordPieceTrainer(special_tokens=special_tokens, vocab_size=8192)
bert_tokenizer.train_from_iterator(lines, trainer)

pad_token_id = bert_tokenizer.token_to_id('[PAD]')
bert_tokenizer.enable_padding(pad_id=pad_token_id, length=128)
bert_tokenizer.enable_truncation(128)

tokenizer_dir = Path('mlm-baby-bert/tokenizer')
tokenizer_dir.mkdir(parents=True, exist_ok=True)
bert_tokenizer.save(str(tokenizer_dir / 'qn_sequences.json'))

# Load tokenizer

In [3]:
# Reload the saved tokenizer so later cells work without retraining it.
tokenizer_path = Path('mlm-baby-bert') / 'tokenizer' / 'qn_sequences.json'
tokenizer = Tokenizer.from_file(str(tokenizer_path))

In [4]:
# Round-trip a sample sentence: print each "id:token" pair, then decode back to text.
encoding = tokenizer.encode('nam quốc sơn hà nam đế cư')
for tok_id, tok in zip(encoding.ids, encoding.tokens):
    print(f'{tok_id}:{tok} ', end=' ')
tokenizer.decode(encoding.ids)

368:nam  607:quốc  224:sơn  189:hà  368:nam  851:đế  596:cư  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[PAD]  1:[

'nam quốc sơn hà nam đế cư'

In [5]:
# Ids of the special tokens; later cells hard-code these values.
tuple(tokenizer.token_to_id(tok) for tok in ("[UNK]", "[PAD]", "[MASK]", "[SEP]"))

(0, 1, 2, 3)

# Prepare the masked-LM dataset

In [6]:
class MLMDataset:
    """Map-style dataset of tokenized lines for masked-language modelling.

    Each item is ``(ids, labels)``, where ``labels`` is a copy of ``ids``.
    Masking is applied later, in the collate function.

    :param lines: sequence of raw text lines.
    :param tokenizer: object with an ``encode(text)`` method returning something
        with an ``.ids`` list. When omitted, the notebook-level ``tokenizer`` is
        looked up at access time, so existing ``MLMDataset(lines)`` calls behave
        as before.
    """

    def __init__(self, lines, tokenizer=None):
        self.lines = lines
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        # Fall back to the global tokenizer only when none was injected.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        ids = tok.encode(self.lines[idx]).ids
        # Separate list so later in-place edits to one don't touch the other.
        return ids, ids.copy()

In [7]:
def collate_fn(batch, mlm_prob=0.15, pad_id=1, mask_id=2, unk_id=0):
    """Stack a batch of ``(ids, labels)`` pairs and apply random MLM masking.

    Each position that is neither [PAD] nor [UNK] is independently selected
    with probability ``mlm_prob``. Selected positions are replaced by
    ``mask_id`` in the inputs and keep their original id in the labels. Every
    other label is set to -100, the default ``ignore_index`` of
    ``F.cross_entropy``.

    :param batch: list of ``(ids, labels)`` pairs of equal-length int lists.
    :param mlm_prob: per-token masking probability.
    :param pad_id: id of [PAD]; never masked.
    :param mask_id: id of [MASK]; written into the selected input positions.
    :param unk_id: id of [UNK]; never masked. The old implementation also
        never masked id 0, but only by accident (see below).
    :returns: ``(input_ids, labels)`` int64 tensors of shape (batch, seq_len).
    """
    input_ids = torch.stack([torch.tensor(ids) for ids, _ in batch])
    labels = torch.stack([torch.tensor(lbl) for _, lbl in batch])

    # Work with the boolean selection directly. The previous
    # ``input_ids * mask == 0`` test treated token id 0 as "not selected",
    # which only worked because [UNK] happens to have id 0.
    maskable = (input_ids != pad_id) & (input_ids != unk_id)
    selected = (torch.rand(input_ids.shape) < mlm_prob) & maskable

    labels[~selected] = -100  # loss only on masked positions
    input_ids[selected] = mask_id
    return input_ids, labels

In [8]:
# Tiny loader over the full corpus, used only to eyeball collate_fn's output.
ds = MLMDataset(lines)
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)

In [9]:
# First ten input ids of the first line (labels are an identical copy).
ds[0][0][:10]

[360, 296, 429, 235, 415, 296, 385, 211, 1363, 950]

In [10]:
# One masked example: [MASK] (id 2) appears in the inputs, and the labels are
# -100 everywhere except at the masked positions.
sample_inputs, sample_labels = next(iter(dl))
print(sample_inputs[0])
print(sample_labels[0])

tensor([387, 529, 484,   2, 330,   2,   2,   2,   2, 379,   2, 834,   2, 449,
        396,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1])
tensor([-100, -100, -100,  273, -100,  385,  277,  738,  199, -100,  538, -100,
        1500, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -

# Model construction

- A straightforward, simple implementation.

- `nn.LayerNorm` is replaced with RMSNorm, which many recent models prefer.

- It looks like BERT but is not BERT; the real BERT is more complicated than this.

- Only the MLM part of BERT is implemented, so the [CLS] and [SEP] tokens are not needed.

- Learned positional embeddings are used, as in BERT (the original Transformer used fixed sinusoidal encodings instead).

- The encoder self-attention uses a mask over the [PAD] tokens, so the attention layers ignore the padding.

- For inference currently only supports batch size of 1.

- After the encoder outputs pass through the dim->vocab Linear layer, the logits at the masked position are softmaxed, and the argmax gives the token predicted to belong there.

```
logits: 1 x 128 x 8192   (batch x max_len x vocab_size)

if the input sequence for inference was masked at position 4, we extract the 1 x 8192 slice at index 4:

preds: logits[:,4,:]

softmax -> argmax

preds: predicted token
```

In [11]:
# https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
            Root Mean Square Layer Normalization
        :param d: model size (size of the last input dimension)
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled).
            When enabled, the RMS is computed over the first ``int(d * p)``
            features only, but all ``d`` features are rescaled.
        :param eps: epsilon added to the RMS (outside the square root), default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        :raises ValueError: if partial RMSNorm is enabled but ``int(d * p) == 0``
            (forward would otherwise divide by zero).
        """
        super().__init__()

        if 0. <= p <= 1. and int(d * p) == 0:
            raise ValueError(f'partial RMSNorm needs int(d * p) >= 1, got d={d}, p={p}')

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter to a Module attribute already registers it,
        # so no explicit register_parameter() call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Normalize ``x`` by its RMS over the last dimension and apply the learned gain (and offset)."""
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # RMS = ||x||_2 / sqrt(d_x)
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

In [12]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional key mask.

    :param dim: model width; must be divisible by ``n_heads``.
    :param n_heads: number of attention heads.
    :param dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, n_heads, dropout=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        assert dim % n_heads == 0, 'dim should be div by n_heads'
        self.head_dim = self.dim // self.n_heads
        # Fused Q/K/V projection; split into three tensors in forward.
        self.in_proj = nn.Linear(dim, dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """
        :param x: input of shape (batch, seq_len, dim).
        :param mask: optional tensor broadcastable to
            (batch, n_heads, seq_len, seq_len); key positions where it is
            0/False receive no attention.
        :returns: tensor of shape (batch, seq_len, dim).
        """
        b, t, c = x.shape
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        # (b, t, c) -> (b, n_heads, t, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(b, t, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if mask is not None:
            keep = mask.to(device=scores.device) != 0
            # Fill with the most negative finite value instead of -inf. A row
            # whose keys are all masked then gets uniform weights instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)

        weights = F.softmax(scores, dim=-1)
        # Dropout belongs on the normalized attention probabilities. Applying it
        # to the raw scores (as before) only perturbed the logits, so the
        # softmax re-normalized the effect away.
        weights = self.attn_dropout(weights)

        attn = torch.matmul(weights, v)
        attn = attn.permute(0, 2, 1, 3).contiguous().view(b, t, c)
        return self.out_proj(attn)

In [13]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim, 4*dim) -> GELU -> Dropout -> Linear(4*dim, dim).

    :param dim: model width.
    :param dropout: dropout probability on the hidden activations.
    """

    def __init__(self, dim, dropout=0.):
        super().__init__()
        # Dropout now follows the activation (the usual placement). Dropping and
        # rescaling by 1/(1-p) *before* GELU distorts the nonlinearity's input.
        # The two Linear layers keep indices 0 and 3, so existing state_dict
        # checkpoints still load.
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        return self.feed_forward(x)

In [14]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    x -> x + attn(norm(x)) -> x + ffd(norm(x)). The model applies a final
    RMSNorm (``ln_f``) after the stack, as pre-norm designs require.

    :param dim: model width.
    :param n_heads: number of attention heads.
    :param attn_dropout: dropout on the attention weights.
    :param mlp_dropout: dropout inside the feed-forward network.
    """

    def __init__(self, dim, n_heads, attn_dropout=0., mlp_dropout=0.):
        super().__init__()
        self.attn = MultiheadAttention(dim, n_heads, attn_dropout)
        self.ffd = FeedForward(dim, mlp_dropout)
        self.ln_1 = RMSNorm(dim)
        self.ln_2 = RMSNorm(dim)

    def forward(self, x, mask=None):
        # Normalize only the sublayer *input* and add the sublayer output back
        # onto the un-normalized residual stream. The previous version replaced
        # x with ln(x) before the residual add, which re-normalized the stream
        # in every block and removed the identity path between blocks.
        # NOTE: parameter names are unchanged, so old checkpoints still load,
        # but weights trained with the old wiring should be retrained.
        x = x + self.attn(self.ln_1(x), mask)
        x = x + self.ffd(self.ln_2(x))
        return x

In [15]:
class Embedding(nn.Module):
    """Token embedding plus a learned absolute position embedding.

    :param vocab_size: number of token ids.
    :param max_len: number of learned positions (the longest supported sequence).
    :param dim: embedding width.
    """

    def __init__(self, vocab_size, max_len, dim):
        super().__init__()
        self.max_len = max_len
        self.class_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Embedding(max_len, dim)

    def forward(self, x):
        # x holds integer token ids; axis 1 is the position axis.
        positions = torch.arange(x.size(1), device=x.device)
        return self.class_embedding(x) + self.pos_embedding(positions)

In [16]:
class MLMBERT(nn.Module):
    """Small BERT-style encoder trained only with the masked-language-model objective.

    ``config`` keys read here: vocab_size, max_len, dim, depth, n_heads,
    attn_dropout, mlp_dropout, pad_token_id, mask_token_id, sep_token_id.
    """

    def __init__(self, config):

        super().__init__()

        self.embedding = Embedding(config['vocab_size'], config['max_len'], config['dim'])

        self.depth = config['depth']
        self.encoders = nn.ModuleList([
            EncoderBlock(
                dim=config['dim'],
                n_heads=config['n_heads'],
                attn_dropout=config['attn_dropout'],
                mlp_dropout=config['mlp_dropout']
            ) for _ in range(self.depth)
        ])

        # Final norm on the encoder output, before the vocabulary projection.
        self.ln_f = RMSNorm(config['dim'])

        self.mlm_head = nn.Linear(config['dim'], config['vocab_size'], bias=False)

        # Weight tying: token embedding and output projection share one matrix.
        self.embedding.class_embedding.weight = self.mlm_head.weight

        self.pad_token_id = config['pad_token_id']
        self.mask_token_id = config['mask_token_id']
        self.sep_token_id = config['sep_token_id']

        # Runs after tying, so the shared matrix is initialized like the rest.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights; zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def create_src_mask(self, src):
        """Key-padding mask: True where ``src`` is not [PAD]."""
        return (src != self.pad_token_id).unsqueeze(1).unsqueeze(2)  # N, 1, 1, src_len

    def forward(self, input_ids, labels=None):
        """
        :param input_ids: integer token ids of shape (batch, seq_len), with
            seq_len <= config['max_len'].
        :param labels: optional ids of the same shape. Positions set to -100
            (``F.cross_entropy``'s default ``ignore_index``) are excluded from
            the loss.
        :returns: ``{'loss', 'logits'}`` when ``labels`` is given. Otherwise
            ``{'mask_predictions'}``, a 1-D tensor with one predicted id per
            [MASK] position, in row-major (batch, position) order.
        """
        src_mask = self.create_src_mask(input_ids)
        enc_out = self.embedding(input_ids)
        for layer in self.encoders:
            enc_out = layer(enc_out, mask=src_mask)

        enc_out = self.ln_f(enc_out)

        logits = self.mlm_head(enc_out)

        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {'loss': loss, 'logits': logits}

        # Predict every [MASK] position. The old code required exactly one
        # [MASK] in a batch of one. For that case this still returns a
        # shape-(1,) tensor, but batches and multiple masks now work too.
        # Softmax is monotonic, so argmax over the raw logits picks the same token.
        is_mask = input_ids == self.mask_token_id
        mask_preds = logits[is_mask].argmax(dim=-1)
        return {'mask_predictions': mask_preds}

# Training Preparation

In [19]:
config = {
    'dim': 256,
    'n_heads': 8,
    'attn_dropout': 0.1,
    'mlp_dropout': 0.1,
    'depth': 6,
    'vocab_size': 8192,
    'max_len': 128,
    'pad_token_id': 1,
    'mask_token_id': 2,
    'sep_token_id' : 3
}

# The special ids above (and the 1/2 literals used by collate_fn and the
# test-set cells) must match the tokenizer, and every token id must fit in the
# embedding table / output head.
assert (config['pad_token_id'], config['mask_token_id'], config['sep_token_id']) == tuple(
    tokenizer.token_to_id(tok) for tok in ('[PAD]', '[MASK]', '[SEP]')
), 'special token ids in config do not match the tokenizer'
assert tokenizer.get_vocab_size() <= config['vocab_size'], 'config vocab_size is smaller than the tokenizer vocabulary'

In [None]:
model = MLMBERT(config).to('cuda')
# parameters() yields each tensor once, so the tied embedding/head matrix is counted once.
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print('trainable:', n_trainable)

In [None]:
# Training / validation datasets. The test split is handled separately by the
# single-mask evaluation below.
train_ds, val_ds = MLMDataset(train_lines), MLMDataset(val_lines)

In [None]:
# Validation is not shuffled, but collate_fn still draws fresh random masks on every pass.
loader_kwargs = dict(batch_size=128, collate_fn=collate_fn)
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Single Masking Test

In [None]:
# TEST : SINGLE TOKEN MASKING
# For each test line, replace exactly one randomly chosen token before the
# first special token (with padding, the first [PAD]) by [MASK], and remember
# the original id. A fixed-seed generator keeps this test set identical across
# kernel restarts, so accuracies from different runs are comparable.
test_rng = torch.Generator().manual_seed(0)

test_actuals = []
test_batches = []
n_skipped = 0
for ln in tqdm(test_lines):
    tokenized = tokenizer.encode(ln)
    special = tokenized.special_tokens_mask
    fi = special.index(1) if 1 in special else len(tokenized.ids)  # ignore [PAD]
    if fi == 0:
        # Nothing to mask (the line produced no regular tokens); randint would
        # raise on an empty range.
        n_skipped += 1
        continue
    m = torch.randint(0, fi, (1,), generator=test_rng).item()  # select random token to mask
    masked_ids = torch.tensor(tokenized.ids)
    test_actuals.append(masked_ids[m].item())
    masked_ids[m] = 2  # replace with [MASK]
    test_batches.append(masked_ids)

if n_skipped:
    print(f'skipped {n_skipped} test lines with no maskable tokens')

# Training

In [None]:
# Training history, filled in by the training loop below.
epochs = 100
train_losses = []     # mean training loss per epoch
valid_losses = []     # mean validation loss per epoch
test_accuracies = []  # single-mask test accuracy; appended only when the validation loss improves
best_val_loss = float('inf')  # any finite loss beats this on the first epoch

In [None]:
# OneCycleLR starts from max_lr / div_factor (div_factor defaults to 25), so the
# optimizer's initial lr below matches the schedule's starting point.
max_lr = 6e-4
optim = torch.optim.Adam(model.parameters(), lr=max_lr / 25.)
sched = torch.optim.lr_scheduler.OneCycleLR(
    optim,
    max_lr=max_lr,
    steps_per_epoch=len(train_dl),
    epochs=epochs,
)

In [None]:
# Main loop. Each epoch runs one pass over train_dl and then a validation pass.
# Whenever the mean validation loss improves, the fixed single-[MASK] test set
# is scored and the checkpoint at ./mlm-baby-bert/model.pt is overwritten.
for ep in tqdm(range(epochs)):
    model.train()  # enable dropout for training
    trl = 0.  # running sum of per-batch training losses
    tprog = tqdm(enumerate(train_dl),total=len(train_dl))
    for i, (input_ids, labels) in tprog:
        input_ids = input_ids.to('cuda')
        labels = labels.to('cuda')
        loss = model(input_ids,labels)['loss']
        loss.backward()
        optim.step()
        optim.zero_grad()
        # OneCycleLR was built with steps_per_epoch=len(train_dl), so it is stepped once per batch.
        sched.step()
        trl += loss.item()
        tprog.set_description(f'train step loss: {loss.item():.4f}')
    train_losses.append(trl/len(train_dl))  # mean batch loss for this epoch
        
    model.eval()  # disable dropout for validation and test prediction
    with torch.no_grad():
        # collate_fn draws new random masks on every pass, so each epoch's
        # validation loss is measured on different masked positions.
        vrl = 0.
        vprog = tqdm(enumerate(val_dl),total=len(val_dl))
        for i, (input_ids, labels) in vprog:

            input_ids = input_ids.to('cuda')
            labels = labels.to('cuda')
            loss = model(input_ids,labels)['loss']
            vrl += loss.item()
            vprog.set_description(f'valid step loss: {loss.item():.4f}')
        vloss = vrl/len(val_dl)
        valid_losses.append(vloss)
        print(f'epoch {ep} | train_loss: {train_losses[-1]:.4f} valid_loss: {valid_losses[-1]:.4f}')
        
        if vloss < best_val_loss:
            best_val_loss = vloss
            print('PREDICTING!')
            # Test accuracy is only computed here, so test_accuracies gets one
            # entry per improving epoch, not one per epoch.
            test_predictions = []
            # Each prepared test sequence is fed on its own as a batch of 1.
            for input_ids in tqdm(test_batches):
                input_ids = input_ids.unsqueeze(0)
                input_ids = input_ids.to('cuda')
                mask_preds = model(input_ids)['mask_predictions']
                test_predictions.extend(list(mask_preds.detach().cpu().flatten().numpy()))
            
            tacc = accuracy_score(test_actuals, test_predictions)
            test_accuracies.append(tacc)
            print(f'SINGLE MASK TOKEN PREDICTION ACCURACY: {tacc:.4f}')
            print('saving best model...')
            # Overwrites the previous best checkpoint at the same path.
            sd = model.state_dict()
            torch.save(sd,'./mlm-baby-bert/model.pt')

In [None]:
# Mean losses per epoch. The validation loss uses fresh random masks each epoch
# (see collate_fn), so some epoch-to-epoch noise is expected.
fig, ax = plt.subplots()
ax.plot(train_losses, color='red', label='train loss')
ax.plot(valid_losses, color='orange', label='valid loss')
ax.set(xlabel='epoch', ylabel='mean MLM cross-entropy', title='MLM training / validation loss')
ax.legend()
plt.show()

In [None]:
# test_accuracies gets one entry per epoch whose validation loss improved (that
# is, per saved checkpoint), not one per epoch. The x-axis is labelled accordingly.
fig, ax = plt.subplots()
ax.plot(test_accuracies, marker='o')
ax.set(
    xlabel='checkpoint # (epochs with improved valid loss)',
    ylabel='accuracy',
    title='single mask token prediction accuracy',
)
plt.show()

In [None]:
# best model
# Restore the best checkpoint saved during training. map_location keeps this
# working if the checkpoint was written on a different device. eval() disables
# dropout for the inference cells below, because a freshly built model
# (e.g. after a kernel restart) is in training mode.
model.eval()
sd = torch.load('./mlm-baby-bert/model.pt', map_location='cuda')
model.load_state_dict(sd)

In [None]:
def predict_mask(sentence):
    """Mask one random token of ``sentence``, predict it, and print a report.

    Uses the notebook-level ``tokenizer`` and ``model``. The model's train/eval
    mode is restored on exit.

    :param sentence: raw text line.
    :returns: 1 if the decoded prediction equals the decoded masked token, else 0.
    :raises ValueError: if the sentence has no tokens before the first special
        token (nothing to mask).
    """
    x = tokenizer.encode(sentence)

    # Candidate positions are [0, fi): stop at the first special token, which
    # with padding enabled is the first [PAD].
    special = x.special_tokens_mask
    fi = special.index(1) if 1 in special else len(x.ids)
    if fi == 0:
        raise ValueError(f'nothing to mask in sentence: {sentence!r}')
    idx = torch.randint(0, fi, (1,)).item()

    input_ids = x.ids.copy()
    masked_token = tokenizer.decode([input_ids[idx]])
    input_ids[idx] = 2  # idx -> [MASK]
    masked_sentence = input_ids.copy()

    # Inference only: disable dropout and autograd, then restore the caller's
    # mode so this helper has no lasting side effect on `model`.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to('cuda')
            pred_id = model(batch)['mask_predictions'].item()
    finally:
        model.train(was_training)

    predicted = x.ids.copy()
    predicted[idx] = pred_id
    predicted_token = tokenizer.decode([pred_id])

    print(f'masked: {masked_token} predicted: {predicted_token}')
    masked_sentence = tokenizer.decode(masked_sentence, skip_special_tokens=False)
    masked_sentence = masked_sentence.replace('[PAD]', '')
    print('ACTUAL:', sentence)
    print('MASKED:', masked_sentence)
    print(' MODEL:', tokenizer.decode(predicted))

    return int(masked_token == predicted_token)

In [None]:
# Spot-check random lines, sampled with replacement from train + test, and
# count how often the masked token is recovered exactly.
n_checks = 100
hits = 0
for sentence in random.choices(train_lines + test_lines, k=n_checks):
    hits += predict_mask(sentence)
    print('\n\n')
print(f'CORRECT:{hits}/{n_checks}')