In [2]:
%reload_ext autoreload
%autoreload 2

# Language model (text generation, autoregressive model)

This experiment introduce a language model based on Transformers modules. The main idea is to introduce the mask to the delf attention module. Focusing on the process of each elelement of the sequence, the mask avoids to see forward words (avoids taking into account future words of the seqence).
The goal is to predict the next character given a seqence (input is a seqence of characters, output prediction should be the same character sequence shifted to the left)

In [69]:
import warnings
warnings.filterwarnings('ignore')
from _context import src
from src.models.model_utils import device_selection
from src.models.predict_model import GenerationCharacterTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
import torch.distributions as dist

import numpy as np
import random, tqdm, sys, math, gzip


In [70]:
#Params
# NB, the enwik8 data contains tokens from 9 to 240, but well round up to the nearest
# power of two.
NUM_TOKENS = 256

context=256
embeding_size = 128
transformer_heads = 8
depth = 4
lr = 1e-4
warm = 100
iterations = 1e6
batch_size = 8

train_num = int(50e6)
val_num = int(5e6)
test_num = int(5e6)
total_num = train_num + val_num + test_num

# Dataset
Dataset is named enwik8 and contains $10^8$ characters of wikipedia text

In [5]:
# enwik8
# http://mattmahoney.net/dc/enwik8.zip

path = "./dataset/enwik8"
with open(path) as file:
    arr = np.fromstring(file.read(total_num), dtype=np.uint8)
    train_ds, val_ds, test_ds = np.split(arr, [train_num, val_num+ test_num])
    train_ds = torch.from_numpy(train_ds)
    val_ds = torch.from_numpy(val_ds)
    test_ds = torch.from_numpy(test_ds)

In [12]:
train_ds , train_ds.size()

(tensor([ 60, 109, 101,  ...,  61,  61,  61], dtype=torch.uint8),
 torch.Size([50000000]))

# Model
Provide a character sequence and the model generates for each character of the seqeuence the next character. hence the last prediction of the putput sequence shuold be the next unseen letter (output is one character shiftted)

<img src="images/transformer_model_generator.svg"  width="500" height="600">

It is necessary that the transformer model can not see further characters in order to perform the attention mechanism. For that reason the weighted matrix is masked with an upper diagonal of $-\inf$ (after softmax it gets a 0)
* Rimind that each row of th eattention weight matrix are the weight of each row sequence vector (each element of the weight vector is the weight of the row sequence vector)
* E.g: the first row of th eweight matrix contains only the fisrt element, the rest are $-inf$. Then, it means that for the first element of the sequence the only vector that counts is the first sequence vector, the others are ignored
<img src="images/mask.svg"  width="500" height="600">

In [63]:
model = GenerationCharacterTransformer(
                            embedding_size=128,
                            transformer_heads=10,
                            depth=5,
                            max_sequence=256, #context
                            token_size=NUM_TOKENS)

In [64]:
#Fast.check
x = torch.rand(8,256)
x = x.long()
o = model(x)
o.size(), o

(torch.Size([8, 256, 256]),
 tensor([[[-4.9760, -5.2845, -5.2069,  ..., -5.2880, -5.3604, -5.9527],
          [-5.4618, -6.0335, -5.3880,  ..., -4.5481, -5.9435, -5.3915],
          [-5.5381, -5.7659, -4.4586,  ..., -5.0764, -4.9413, -5.0837],
          ...,
          [-5.6106, -5.8205, -4.6277,  ..., -5.7351, -5.3977, -6.3035],
          [-5.0250, -6.0898, -4.6109,  ..., -5.4362, -5.0174, -5.7912],
          [-4.6288, -5.6312, -5.7336,  ..., -5.3678, -4.9244, -5.3985]],
 
         [[-4.9760, -5.2845, -5.2069,  ..., -5.2880, -5.3604, -5.9527],
          [-5.4618, -6.0335, -5.3880,  ..., -4.5481, -5.9435, -5.3915],
          [-5.5381, -5.7659, -4.4586,  ..., -5.0764, -4.9413, -5.0837],
          ...,
          [-5.6106, -5.8205, -4.6277,  ..., -5.7351, -5.3977, -6.3035],
          [-5.0250, -6.0898, -4.6109,  ..., -5.4362, -5.0174, -5.7912],
          [-4.6288, -5.6312, -5.7336,  ..., -5.3678, -4.9244, -5.3985]],
 
         [[-4.9760, -5.2845, -5.2069,  ..., -5.2880, -5.3604, -5.9527],


# Train

In [65]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [72]:
#One batch overfit
tensorboard = SummaryWriter(log_dir=".")

model = GenerationCharacterTransformer(
                            embedding_size=256,
                            transformer_heads=10,
                            depth=8,
                            max_sequence=context, #context
                            token_size=NUM_TOKENS)

optimizer = torch.optim.Adam(lr=lr, params=model.parameters())
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i:min(i/(warm/batch_size), 1.0))

num_characters = train_ds.size(0)
ixs = torch.randint(size=(batch_size,), low=0, high=num_characters - context -1)
ixs, ixs.size()
batch_input = [train_ds[i:i+context][None, :]for i in ix]
batch_target = [train_ds[i+1:i+context+1][None, :]for i in ix]
batch_input = torch.cat(batch_input, dim=0).long()
batch_target = torch.cat(batch_target, dim=0).long()

summary_loss = AverageMeter()
for i in tqdm.tqdm(range(300)):
    optimizer.zero_grad()
    pred = model(batch_input)
    #pred must be batch, classes, seq for using nll loss. So we transpose clsses by seq
    loss = F.nll_loss(pred.transpose(2,1), batch_target, reduction="mean")
    loss_val = float(loss.detach().item())
    if i%100 ==0:
        print("Losss ", loss_val)
    summary_loss.update(loss_val, n=batch_size)
    tensorboard.add_scalar("generation/train_loss", loss_val)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

  0%|          | 0/300 [00:00<?, ?it/s]

Losss  5.660549640655518


 33%|███▎      | 100/300 [03:44<07:19,  2.20s/it]

Losss  0.2959350049495697


 67%|██████▋   | 200/300 [07:24<03:40,  2.20s/it]

Losss  0.01608143374323845


100%|██████████| 300/300 [11:10<00:00,  2.23s/it]


torch.Size([8, 128])