# Transformer for Penn TreeBank

In [22]:
import torch.nn.functional as F
from torch import nn, Tensor
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
import math
from typing import Tuple, List, Union, Dict
import numpy as np
import grok
from grok.training import TrainableTransformer

In [3]:
class TransformerModel(nn.Module):
    """Token-level language model built around ``nn.TransformerEncoder``.

    Token ids are embedded and scaled by ``sqrt(d_model)``, positional
    encodings are added, the result is run through the encoder stack under
    the supplied attention mask, and a linear layer projects every position
    to vocabulary logits.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        """
        Args:
            ntoken: vocabulary size (embedding rows and number of logits)
            d_model: embedding / model width
            nhead: attention heads in each encoder layer
            d_hid: feed-forward width inside each encoder layer
            nlayers: number of stacked encoder layers
            dropout: dropout probability given to the positional encoding
                and to every encoder layer
        """
        super().__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model

        # Sub-modules are created in the same order as before so parameter
        # ordering and random-number consumption stay the same.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        layer = TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_hid, dropout=dropout
        )
        self.transformer_encoder = TransformerEncoder(layer, num_layers=nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise the embedding and output projection.

        Both weight matrices are drawn uniformly from ``[-0.1, 0.1]`` and the
        output bias is set to zero.
        """
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        embedded = self.encoder(src) * math.sqrt(self.d_model)
        encoded = self.transformer_encoder(self.pos_encoder(embedded), src_mask)
        return self.decoder(encoded)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build the ``sz`` x ``sz`` causal attention mask.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    return blocked.triu(diagonal=1)



In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    A table of shape ``[max_len, 1, d_model]`` is computed once in the
    constructor and stored as the buffer ``pe``. Even feature indices hold
    sines and odd feature indices hold cosines of ``position * div_term``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: feature width of the inputs (even or odd)
            dropout: dropout probability applied after the encodings are added
            max_len: number of positions in the precomputed table
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-indexed column than
        # even-indexed column, so drop the surplus frequency; for even
        # d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``: ``x`` plus the first
            ``seq_len`` rows of the table (broadcast over the batch axis),
            passed through dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

## Load and batch data 

In [5]:
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [19]:
class PTBIterator:
    """Batch iterator over the Penn Treebank corpus for next-token prediction.

    The requested split is tokenized, flattened into a single token stream and
    folded into a 2-D tensor by :meth:`batchify`; only the leading
    ``train_pct`` percent of its rows are kept. Every batch is a dict with
    ``"text"`` (all but the last column of the selected rows) and
    ``"target"`` (the same rows shifted one token to the left).
    """

    def __init__(self,train_pct, batchsize_hint, device,split="train", data_dir="../data",shuffle:bool = True) -> None:
        """
        :param train_pct: percentage (0-100) of the folded rows to keep
        :param batchsize_hint: used as-is both as the number of rows the token
                               stream is folded into and as the number of rows
                               per batch (it is not passed through
                               :meth:`calculate_batchsize`)
        :param device: handed to ``Tensor.to`` for every batch
        :param split: Penn Treebank split to iterate over
        :param data_dir: root directory given to ``PennTreebank``
        :param shuffle: whether rows are visited in random order; the same
                        setting is reused when the iterator resets itself at
                        the end of a pass
        """
        self.device = device
        # Remembered so the automatic reset in __next__ honours the caller's
        # choice instead of always reshuffling.
        self.shuffle = shuffle

        # The vocabulary is always built from the training split, whichever
        # split is iterated over.
        vocab_source = PennTreebank(root=data_dir, split="train")
        self.tokenizer = get_tokenizer("basic_english")
        self.vocab = build_vocab_from_iterator(map(self.tokenizer, vocab_source), specials=["<unk>"])
        self.vocab.set_default_index(self.vocab["<unk>"])

        self.make_dataset(train_pct, batchsize_hint, data_dir, split)
        self.reset_iteration(shuffle=shuffle)

    def make_dataset(self,train_pct:float, batchsize, data_dir, split):
        """
        Loads ``split``, folds it into rows and keeps the leading share.

        Sets ``self.batchsize`` and ``self.dataset``.

        :param train_pct: percentage (0-100) of rows to keep
        :param batchsize: number of rows to fold the token stream into; also
                          stored as the per-batch row count
        :param data_dir: root directory given to ``PennTreebank``
        :param split: Penn Treebank split to load
        """
        raw_lines = PennTreebank(root=data_dir, split=split)
        folded = self.batchify(self.data_process(raw_lines), batchsize)
        self.batchsize = batchsize
        kept_rows, _ = self.calc_split_len(train_pct, folded.shape[0])
        self.dataset = folded[:kept_rows]

    @staticmethod
    def calculate_batchsize(ds_size: int, batchsize_hint: int = 0) -> int:
        """
        Calculates which batch size to use

        :param ds_size: the number of items in the dataset
        :param batchsize_hint: * 0 means we use a default batchsize
                               * -1 means the entire dataset
                               * float between 0 and 1 means each batch is
                                 that fraction of the DS
                               * int > 1 means that specific batch size
        :returns: the actual batchsize to use
        :raises ValueError: for any other hint (exactly 1, or a negative
                            value other than -1)
        """

        if batchsize_hint == -1:
            return ds_size
        elif batchsize_hint == 0:
            return min(512, math.ceil(ds_size / 2.0))
        elif (batchsize_hint > 0) and (batchsize_hint < 1):
            return math.ceil(ds_size * batchsize_hint)
        elif batchsize_hint > 1:
            return min(batchsize_hint, ds_size)
        else:
            raise ValueError("batchsize_hint must be >= -1")

    def reset_iteration(self, shuffle=True):
        """
        Rewinds to the first batch and rebuilds the row visiting order.

        :param shuffle: random permutation of the rows when true, natural
                        order otherwise
        """
        self.index = 0
        row_count = self.dataset.shape[0]
        if shuffle:
            self.permutation = torch.randperm(row_count)
        else:
            self.permutation = torch.arange(row_count)

    def __iter__(self):
        """
        :returns: this iterator
        """
        return self

    def __next__(self) -> Dict[str, Tensor]:
        """
        Returns one batch of data.

        :raises: StopIteration when we're out of data; the iterator is reset
                 first (using the shuffle setting given at construction) so
                 it can be iterated again
        :returns: dict with "text" and "target" tensors, each holding up to
                  ``self.batchsize`` rows and one column fewer than the
                  dataset; "target" is "text" shifted left by one token
        """

        batch_begin = self.index * self.batchsize
        if batch_begin > len(self.dataset) - 1:
            self.reset_iteration(shuffle=self.shuffle)
            raise StopIteration
        indices = self.permutation[batch_begin : batch_begin + self.batchsize]
        text = self.dataset[indices, :-1]
        target = self.dataset[indices, 1:]
        batch = {"text": text.to(self.device), "target": target.to(self.device)}
        self.index += 1
        return batch

    def __len__(self) -> int:
        """
        :returns: the total number of batches produced by one full pass
        """
        return math.ceil(self.dataset.shape[0] / self.batchsize)

    def calc_split_len(self, train_pct, ds_len):
        """
        Splits ``ds_len`` rows into a kept part and a remainder.

        :param train_pct: percentage (0-100) of rows to keep
        :param ds_len: total number of rows
        :returns: (kept rows, remaining rows)
        """
        train_rows = round(ds_len * (train_pct / 100.0))
        val_rows = ds_len - train_rows
        return train_rows, val_rows

    def data_process(self, raw_text_iter: dataset.IterableDataset) -> Tensor:
        """Convert raw text to flat tensor

        Each item is tokenized and mapped to vocabulary ids; items that yield
        no tokens are skipped and the rest are concatenated in order.
        """
        encoded = (torch.tensor(self.vocab(self.tokenizer(item))) for item in raw_text_iter)
        return torch.cat([ids for ids in encoded if ids.numel() > 0])

    def batchify(self, data, bsz):
        """
        Folds a flat token tensor into ``bsz`` equally long rows.

        Trailing tokens that do not fill a complete row are dropped. The
        total token count is printed as a diagnostic.

        :param data: 1-D tensor of token ids
        :param bsz: number of rows
        :returns: tensor of shape (bsz, len(data) // bsz)
        """
        print(data.shape[0])
        seq_len = data.shape[0] // bsz
        data = data[:seq_len * bsz]
        return data.view(bsz, seq_len).contiguous()
        

In [20]:
# Smoke test: keep the first 50% of the folded rows, 100 rows per batch,
# leave tensors on their current device, and display one batch.
train_iter = PTBIterator(train_pct=50, batchsize_hint=100, device=None)
next(iter(train_iter))

924412


{'text': tensor([[2745,   97,  346,  ...,   50,    6,   61],
         [   2,    4,  182,  ...,    6,    1,  803],
         [1567,    7, 5504,  ...,    2,    7,    2],
         ...,
         [ 110, 9203,  747,  ...,   77,   12,    1],
         [1925,    3,    2,  ..., 1086,   15, 6607],
         [  31,  871,    0,  ...,   34,   26,  972]]),
 'target': tensor([[  97,  346,    4,  ...,    6,   61,  746],
         [   4,  182,   38,  ...,    1,  803,  277],
         [   7, 5504, 4465,  ...,    7,    2,    2],
         ...,
         [9203,  747,   52,  ...,   12,    1,  162],
         [   3,    2,  171,  ...,   15, 6607,    1],
         [ 871,    0, 1192,  ...,   26,  972,   12]])}

In [21]:
# Pull one batch from grok's arithmetic iterator and display it.
train_dataset, _ = grok.data.ArithmeticDataset.splits(5, "+")
iterator = grok.data.ArithmeticIterator(train_dataset, None, batchsize_hint=12)  # type: ignore
d = next(iter(iterator))
d

{'text': tensor([[  0,  73,   6, 107,   1,  61],
         [  0,  81,   6, 103,   1,  65],
         [  0,  51,   6,  58,   1,  87],
         [  0,  70,   6,  98,   1,  49],
         [  0,  56,   6, 100,   1,  37],
         [  0,  40,   6,  88,   1, 106],
         [  0, 118,   6, 101,   1, 100],
         [  0,  43,   6,  66,   1,  87],
         [  0,  73,   6, 113,   1,  67],
         [  0,  29,   6,  26,   1,  33],
         [  0, 118,   6,  68,   1,  67],
         [  0,  31,   6,  87,   1,  96]]),
 'target': tensor([[ 73,   6, 107,   1,  61,   0],
         [ 81,   6, 103,   1,  65,   0],
         [ 51,   6,  58,   1,  87,   0],
         [ 70,   6,  98,   1,  49,   0],
         [ 56,   6, 100,   1,  37,   0],
         [ 40,   6,  88,   1, 106,   0],
         [118,   6, 101,   1, 100,   0],
         [ 43,   6,  66,   1,  87,   0],
         [ 73,   6, 113,   1,  67,   0],
         [ 29,   6,  26,   1,  33,   0],
         [118,   6,  68,   1,  67,   0],
         [ 31,   6,  87,   1,  96,   

## Instantiate the model

In [25]:
# NOTE(review): `vocab` and `device` are not defined in any visible cell;
# presumably they come from earlier kernel state - confirm before re-running.
ntokens = len(vocab)  # vocabulary size
emsize = 200          # embedding dimension
d_hid = 200           # feed-forward width inside each encoder layer
nlayers = 2           # number of stacked encoder layers
nhead = 2             # attention heads per encoder layer
dropout = 0.2         # dropout probability
model = TransformerModel(
    ntoken=ntokens,
    d_model=emsize,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
).to(device)

## Run the model

In [26]:
import copy
import time

lr = 5.0  # initial learning rate
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Scale the learning rate by 0.95 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    """Run one training pass over ``train_data`` in ``bptt``-sized chunks.

    Relies on notebook-level names: ``train_data``, ``bptt``, ``device``,
    ``get_batch``, ``criterion``, ``ntokens``, ``optimizer``, ``scheduler``
    and the current ``epoch`` (read only for the progress line). A progress
    line is printed every 200 batches.
    """
    model.train()  # switch to training mode
    log_interval = 200
    running_loss = 0.
    window_start = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    chunk_starts = range(0, train_data.size(0) - 1, bptt)
    for batch, offset in enumerate(chunk_starts):
        inputs, targets = get_batch(train_data, offset)
        seq_len = inputs.size(0)
        if seq_len != bptt:
            # A shorter chunk needs a correspondingly smaller mask.
            src_mask = src_mask[:seq_len, :seq_len]
        logits = model(inputs, src_mask)
        loss = criterion(logits.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch == 0 or batch % log_interval != 0:
            continue

        # Report the window that just ended, then start a new one.
        lr = scheduler.get_last_lr()[0]
        ms_per_batch = (time.time() - window_start) * 1000 / log_interval
        cur_loss = running_loss / log_interval
        ppl = math.exp(cur_loss)
        print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
              f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
              f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
        running_loss = 0
        window_start = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    """Return the length-weighted mean loss of ``model`` over ``eval_data``.

    The data is walked in ``bptt``-sized chunks with gradients disabled. Each
    chunk's loss is weighted by its first dimension, and the sum is divided
    by ``len(eval_data) - 1``. Relies on the notebook-level names ``bptt``,
    ``device``, ``get_batch``, ``criterion`` and ``ntokens``.
    """
    model.eval()  # switch to evaluation mode
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    weighted_loss = 0.
    with torch.no_grad():
        for offset in range(0, eval_data.size(0) - 1, bptt):
            inputs, targets = get_batch(eval_data, offset)
            seq_len = inputs.size(0)
            if seq_len != bptt:
                # A shorter chunk needs a correspondingly smaller mask.
                src_mask = src_mask[:seq_len, :seq_len]
            logits = model(inputs, src_mask)
            chunk_loss = criterion(logits.view(-1, ntokens), targets).item()
            weighted_loss += seq_len * chunk_loss
    return weighted_loss / (len(eval_data) - 1)

In [27]:
epochs = 3
best_model = None
best_val_loss = float('inf')

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time

    # Summary framed by two 89-character rules.
    print('-' * 89,
          f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}',
          '-' * 89,
          sep='\n')

    # Keep a snapshot of the model with the lowest validation loss so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

| epoch   1 |   200/ 1320 batches | lr 5.00 | ms/batch 255.35 | loss  7.25 | ppl  1401.81
| epoch   1 |   400/ 1320 batches | lr 5.00 | ms/batch 157.64 | loss  6.12 | ppl   454.28
| epoch   1 |   600/ 1320 batches | lr 5.00 | ms/batch 157.42 | loss  5.87 | ppl   352.69
| epoch   1 |   800/ 1320 batches | lr 5.00 | ms/batch 188.83 | loss  5.68 | ppl   292.56
| epoch   1 |  1000/ 1320 batches | lr 5.00 | ms/batch 152.71 | loss  5.60 | ppl   270.14
| epoch   1 |  1200/ 1320 batches | lr 5.00 | ms/batch 161.84 | loss  5.49 | ppl   243.05
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 251.30s | valid loss  5.49 | valid ppl   242.27
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1320 batches | lr 4.75 | ms/batch 116.82 | loss  5.43 | ppl   228.55
| epoch   2 |   400/ 1320 batches | lr 4.75 | ms/batch 109.64 | loss  5.37 | ppl   215.61
| epoch   2 |   600/ 1320

## Evaluate the best model on the test dataset

In [28]:
# Score the best validation snapshot on the held-out test data.
test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89,
      f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}',
      '=' * 89,
      sep='\n')

| End of training | test loss  5.23 | test ppl   186.91
