# My Transformer

I wrote this to help me understand _Attention Is All You Need_: https://arxiv.org/abs/1706.03762

In [None]:
import logging
logging.basicConfig(level=logging.WARN)


from argparse import ArgumentParser
from collections import defaultdict, Counter
import multiprocessing.pool
import math
import os
import random
import sys
import string
import time
import numpy as np
from numpy import sqrt, sin, cos, floor, mean, exp

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.checkpoint import checkpoint
from torch.optim.lr_scheduler import LambdaLR
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
import torchtext.datasets
from torchtext.datasets.language_modeling import LanguageModelingDataset


import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.profiler import AdvancedProfiler
#from pytorch_lightning.callbacks.lr_logger import LearningRateLogger

## Model Architecture

### Section 3.2.1: Scaled Dot-Product Attention

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head (paper section 3.2.1).

    Queries, keys and values are each projected from ``d_model`` to ``d_key``
    features by a bias-free linear layer and combined as
    ``softmax(Q K^T / sqrt(d_key)) V``.
    """

    def __init__(self, d_model, d_key):
        """
        Args:
            d_model: feature size of the incoming query/key/value tensors.
            d_key: feature size of the projections and of the head's output.
        """
        super().__init__()

        self.d_key = d_key

        # Input projections for this head; none of them carries a bias.
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        # Normalises the scores over the last (key) dimension.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values, mask=None):
        """Attend from ``queries`` to ``keys`` and return the weighted values.

        Args:
            queries, keys, values: tensors whose last dimension is ``d_model``.
            mask: optional tensor broadcastable to the score matrix; every
                position where ``mask == 0`` is excluded from attention.

        Returns:
            Tensor whose last dimension is ``d_key``.
        """
        q = self.Wq(queries)
        k = self.Wk(keys)
        v = self.Wv(values)

        # Compatibility of each query with each key, scaled by sqrt(d_key).
        logits = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_key)

        # Masked positions get -inf so that softmax assigns them zero weight.
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        weights = self.softmax(logits)

        # Weighted sum of the value vectors.
        return torch.matmul(weights, v)

### 3.2.2 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (paper section 3.2.2).

    Runs ``heads`` independent ``AttentionHead`` modules, each producing
    ``d_model // heads`` features, concatenates their outputs along the last
    dimension and mixes them with a bias-free output projection ``Wo``.
    """

    def __init__(self, d_model, heads):
        """
        Args:
            d_model: feature size of the inputs and of the output.
            heads: number of attention heads; must be positive and divide
                ``d_model`` evenly.

        Raises:
            ValueError: if ``heads`` is not positive or does not divide
                ``d_model``. Without this check the concatenated head outputs
                would not match the input size of ``Wo`` and the failure
                would only surface as a shape error in ``forward``.
        """
        super().__init__()
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                'heads=%s must be positive and evenly divide d_model=%s'
                % (heads, d_model))
        d_key = d_model // heads

        attn_heads = [AttentionHead(d_model, d_key) for _ in range(heads)]
        self.attn_heads = nn.ModuleList(attn_heads)
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, queries, keys, values, mask=None):
        """Apply every head to the same inputs and project the concatenation.

        Args:
            queries, keys, values: tensors whose last dimension is ``d_model``.
            mask: optional attention mask handed unchanged to every head.

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        head_attns = [h(queries=queries, keys=keys, values=values, mask=mask)
                      for h in self.attn_heads]
        # heads * d_key == d_model features after concatenation.
        head_attn = torch.cat(head_attns, dim=-1)
        return self.Wo(head_attn)

### 3.3 Position-wise Feed-Forward Networks

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network (paper section 3.3).

    Two bias-free linear layers with a ReLU in between:
    ``d_model -> multiplier * d_model -> d_model``.
    """

    def __init__(self, d_model, multiplier=4):
        """
        Args:
            d_model: input and output feature size.
            multiplier: ratio of the hidden width to ``d_model``; the hidden
                width is ``int(multiplier * d_model)``.
        """
        super().__init__()

        hidden = int(multiplier * d_model)

        expand = nn.Linear(d_model, hidden, bias=False)
        contract = nn.Linear(hidden, d_model, bias=False)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``d_model``)."""
        return self.ffn(x)

### 3.1 Encoder and Decoder Stacks

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer (paper section 3.1).

    Self-attention followed by a feed-forward network; each sub-layer is
    wrapped as ``LayerNorm(input + Dropout(sublayer(input)))``.
    """

    def __init__(self, d_model, heads, dropout):
        """
        Args:
            d_model: feature size of the layer's input and output.
            heads: number of attention heads.
            dropout: dropout probability applied to each sub-layer output.
        """
        super().__init__()

        # Self-attention sub-layer.
        self.attn = MultiHeadAttention(d_model=d_model, heads=heads)
        self.attn_drop = nn.Dropout(p=dropout)
        self.attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return the result."""
        # Unmasked self-attention, then residual connection and layer norm.
        attended = self.attn_drop(self.attn(x, x, x))
        hidden = self.attn_norm(x + attended)

        # Feed-forward, then residual connection and layer norm.
        transformed = self.ffn_drop(self.ffn(hidden))
        return self.ffn_norm(hidden + transformed)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer (paper section 3.1).

    Masked self-attention, an optional attention over the encoder output,
    and a feed-forward network. Each sub-layer is wrapped as
    ``LayerNorm(input + Dropout(sublayer(input)))``.
    """

    def __init__(self, d_model, heads, dropout, encoder_attention=True):
        """
        Args:
            d_model: feature size of the layer's input and output.
            heads: number of attention heads.
            dropout: dropout probability applied to each sub-layer output.
            encoder_attention: when False the encoder-attention sub-layer is
                not created and ``forward`` skips it.
        """
        super().__init__()

        # Self-attention sub-layer.
        self.self_attn = MultiHeadAttention(d_model, heads)
        self.self_attn_drop = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Encoder-attention sub-layer, present only when requested.
        if not encoder_attention:
            self.enc_attn = None
        else:
            self.enc_attn = MultiHeadAttention(d_model, heads)
            self.enc_attn_drop = nn.Dropout(p=dropout)
            self.enc_attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out=None, self_attn_mask=None):
        """Run the layer.

        Args:
            x: decoder-side input.
            encoder_out: encoder output used as keys and values by the
                encoder-attention sub-layer; only read when that sub-layer
                exists.
            self_attn_mask: optional mask passed to the self-attention.

        Returns:
            Tensor produced by the final layer norm.
        """
        # Self-attention + residual + norm.
        attended = self.self_attn_drop(self.self_attn(x, x, x, self_attn_mask))
        hidden = self.self_attn_norm(x + attended)

        # Encoder attention + residual + norm (skipped for decoder-only use).
        if self.enc_attn is not None:
            context = self.enc_attn(hidden, encoder_out, encoder_out)
            context = self.enc_attn_drop(context)
            hidden = self.enc_attn_norm(hidden + context)

        # Feed-forward + residual + norm.
        transformed = self.ffn_drop(self.ffn(hidden))
        return self.ffn_norm(hidden + transformed)

In [None]:
class Encoder(nn.Module):
    """A stack of ``num_blocks`` EncoderBlock layers applied in sequence."""

    def __init__(self, d_model, heads, num_blocks, dropout):
        """
        Args:
            d_model: feature size handled by every block.
            heads: attention heads per block.
            num_blocks: number of blocks; 0 makes ``forward`` the identity.
            dropout: dropout probability used inside every block.
        """
        super().__init__()

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(EncoderBlock(d_model, heads, dropout))

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

In [None]:
class Decoder(nn.Module):
    """A stack of ``num_blocks`` DecoderBlock layers applied in sequence."""

    def __init__(self, d_model, heads, num_blocks, dropout, encoder_attention=True):
        """
        Args:
            d_model: feature size handled by every block.
            heads: attention heads per block.
            num_blocks: number of blocks; 0 makes ``forward`` return its
                input unchanged.
            dropout: dropout probability used inside every block.
            encoder_attention: forwarded to every DecoderBlock.
        """
        super().__init__()

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(
                DecoderBlock(d_model, heads, dropout, encoder_attention))

    def forward(self, decoder_in, encoder_out=None, self_attn_mask=None):
        """Feed ``decoder_in`` through every block in order.

        ``encoder_out`` and ``self_attn_mask`` are handed unchanged to each
        block.
        """
        hidden = decoder_in
        for layer in self.blocks:
            hidden = layer(hidden, encoder_out, self_attn_mask)
        return hidden

### Section 3 Model Architecture

In [None]:
class Transformer(pl.LightningModule):
    """Transformer language model packaged as a LightningModule.

    Bundles the token embedding and sinusoidal position encoding, the
    encoder and decoder stacks, the projection onto the vocabulary, the
    train/validation/test steps, the optimizer with its learning-rate
    schedule, and the torchtext dataset loading.
    """

    def __init__(self, hparams):
        """Load the dataset and build every sub-module.

        Args:
            hparams: namespace carrying the options declared by
                ``add_model_specific_args`` (``attn_heads``, ``bptt_len``,
                ``d_model``, ``dataset``, ``dropout``, ``encoder_layers``,
                ``decoder_layers``, ``max_vocab_size``, ``minibatch_size``,
                ``warmup_steps``, ``lr_multiplier``).
        """
        super().__init__()

        self.hparams = hparams
        bptt_len = hparams.bptt_len
        d_model = hparams.d_model
        decoder_layers = hparams.decoder_layers
        dropout = hparams.dropout
        encoder_layers = hparams.encoder_layers
        # Decoder blocks only get an encoder-attention sub-layer when there
        # is an encoder to attend to.
        encoder_decoder_attention = encoder_layers > 0
        heads = hparams.attn_heads

        # The vocabulary size is needed to size the embedding and the output
        # projection, so the datasets are loaded before building the layers.
        self.prepare_data()

        self.vocab = self.train_dataset.fields['text'].vocab
        vocab_len = len(self.vocab)
        pad_token = '<pad>'
        pad_index = self.vocab.stoi[pad_token]

        self.embedding = nn.Embedding(vocab_len, d_model, padding_idx=pad_index)
        # Registered as buffers: fixed tensors that are not trained.
        self.register_buffer('position_encoding', self._position_encoding(bptt_len, d_model))
        self.register_buffer('self_attn_mask', self._make_mask(bptt_len))

        self.encoder = Encoder(d_model, heads, encoder_layers, dropout)
        self.decoder = Decoder(d_model, heads, decoder_layers, dropout, encoder_decoder_attention)

        self.linear = nn.Linear(d_model, vocab_len, bias=False)

        num_params = sum(p.numel() for p in self.parameters())
        print(f"Model has {num_params:,d} parameters")

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--attn_heads', type=int, default=8)
        parser.add_argument('--bptt_len', type=int, default=360)
        parser.add_argument('--d_model', type=int, default=512)
        parser.add_argument('--dataset', type=str, default='WikiText103')
        #parser.add_argument('--dataset', type=str, default='WikiText2')
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--encoder_layers', type=int, default=0)
        parser.add_argument('--decoder_layers', type=int, default=6)
        parser.add_argument('--max_vocab_size', type=int, default=60000)
        parser.add_argument('--minibatch_size', type=int, default=16)
        parser.add_argument('--warmup_steps', type=int, default=2500)
        parser.add_argument('--lr_multiplier', type=float, default=1)
        return parser

    @classmethod
    def _make_mask(cls, bptt_len):
        """Return a ``(bptt_len, bptt_len)`` lower-triangular matrix of ones.

        Row ``i`` is 1 for columns ``<= i`` and 0 elsewhere, so position
        ``i`` may only attend to itself and earlier positions.
        """
        return torch.ones([bptt_len, bptt_len]).tril()

    @classmethod
    def _position_encoding(cls, bptt_len, d_model):
        """Return the ``(bptt_len, d_model)`` sinusoidal position encoding.

        Even feature ``i`` holds ``sin(pos / 10000**(i / d_model))`` and odd
        feature ``i`` holds ``cos(pos / 10000**((i - 1) / d_model))``.
        """
        rows = [tensor([sin(pos / (10000 ** (i / d_model)))
                        if i % 2 == 0
                        else cos(pos / (10000 ** ((i - 1) / d_model)))
                        for i in range(d_model)])
                for pos in range(bptt_len)]
        # One row per position.
        return torch.stack(rows, dim=0)

    def embed(self, indices):
        """Implements the embedding from Section 3.4 Embeddings and Softmax.

        Looks up the token embeddings for ``indices`` and adds the position
        encoding for as many positions as the last dimension of ``indices``.
        """
        this_bptt_len = indices.shape[-1]
        pe = self.position_encoding[:this_bptt_len, :]

        embedded = self.embedding(indices)

        return pe + embedded

    def forward(self, decoder_in=None, encoder_in=None, pos=None):
        """Compute vocabulary logits.

        Args:
            decoder_in: tensor of vocab indices fed to the decoder, sequence
                along the last dimension. Required when
                ``hparams.decoder_layers > 0``.
            encoder_in: tensor of vocab indices fed to the encoder. Required
                when ``hparams.encoder_layers > 0``.
            pos: optional position; when given only ``decoded[:, pos, :]`` is
                projected onto the vocabulary.

        Returns:
            Logits whose last dimension is the vocabulary size.
        """
        # Encode
        encoded = None
        if self.hparams.encoder_layers > 0:
            encoded = self.encoder(self.embed(encoder_in))

        # Decode
        if self.hparams.decoder_layers > 0:
            # Only inspect decoder_in here so an encoder-only configuration
            # can be called without it.
            this_bptt_len = decoder_in.shape[-1]
            di_embedded = self.embed(decoder_in)
            self_attn_mask = self.self_attn_mask[:this_bptt_len, :this_bptt_len]
            decoded = self.decoder(di_embedded, encoded, self_attn_mask)
        else:
            decoded = encoded

        # Return predictions for next token
        if pos is not None:
            decoded = decoded[:, pos, :]

        return self.linear(decoded)

    def _accuracy(self, output, expected_indices):
        """Return the percentage of positions whose top-scoring class matches.

        Args:
            output: scores with the class dimension second-to-last.
            expected_indices: target indices, same shape as ``output`` with
                the class dimension removed.
        """
        predicted = torch.argmax(output, dim=-2)
        # No squeeze(): collapsing size-1 dimensions would make the comparison
        # broadcast (e.g. (batch,) against (batch, 1)) and miscount matches.
        return (predicted == expected_indices).float().mean() * 100

    def _step(self, batch, batch_idx):
        """Run the model on one batch and return ``(loss, accuracy)``."""
        x = batch.text.T
        y = batch.target.T
        y_hat = self(decoder_in=x)
        # Swap the last two dimensions so the class scores sit where
        # F.cross_entropy and _accuracy expect them.
        y_hat = y_hat.transpose(-2, -1)
        loss = F.cross_entropy(y_hat, y)
        acc = self._accuracy(y_hat, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Return the training loss plus loss/perplexity/accuracy/lr logs."""
        loss, accuracy = self._step(batch, batch_idx)
        perplexity = torch.exp(loss)
        # Read the current learning rate back from the scheduler's optimizer
        # purely for logging.
        schedulers = self.trainer.lr_schedulers[0]
        lr = schedulers['scheduler'].optimizer.param_groups[0]['lr']
        logs = {'train_loss': loss,
                'train_perplexity': perplexity,
                'train_accuracy': accuracy,
                'learning_rate': lr,
               }
        return {'loss': loss, 'log': logs}

    def _scheduler_lr(self, step):
        """Return the schedule value for ``step`` (0-based).

        Computes ``d_model**-0.5 * min(s**-0.5, s * warmup_steps**-1.5)``
        with ``s = step + 1``; passed to ``LambdaLR`` as ``lr_lambda``.
        """
        step = step + 1  # handle step 0
        lr = (self.hparams.d_model**-.5) * min(step**-.5, step * (self.hparams.warmup_steps**-1.5))
        return lr

    def configure_optimizers(self):
        """Return Adam plus a per-step ``LambdaLR`` driven by ``_scheduler_lr``.

        ``hparams.lr_multiplier`` is the optimizer's base learning rate.
        """
        optimizer = torch.optim.Adam(self.parameters(), betas=(0.9, 0.98), eps=1e-9, lr=self.hparams.lr_multiplier)
        schedulers = [{
                         'scheduler': LambdaLR(optimizer, lr_lambda=self._scheduler_lr),
                         'interval': 'step',
                         'frequency': 1
                      }]
        return [optimizer], schedulers

    def validation_step(self, batch, batch_idx):
        """Return the loss and accuracy of one validation batch."""
        loss, accuracy = self._step(batch, batch_idx)
        logs = {'val_loss': loss, 'val_accuracy': accuracy}
        return logs

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics and derive perplexity."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        perplexity = torch.exp(avg_loss)
        avg_accuracy = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        logs = {'val_loss': avg_loss, 'val_accuracy': avg_accuracy, 'val_perplexity': perplexity}
        return {'val_loss': avg_loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        """Return the loss and accuracy of one test batch."""
        loss, accuracy = self._step(batch, batch_idx)
        logs = {'test_loss': loss, 'test_accuracy': accuracy}
        return logs

    def test_epoch_end(self, outputs):
        """Average the per-batch test metrics and derive perplexity."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        perplexity = torch.exp(avg_loss)
        avg_accuracy = torch.stack([x['test_accuracy'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss, 'test_accuracy': avg_accuracy, 'test_perplexity': perplexity}
        return {'test_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        """Load the train/val/test splits and build the vocabulary.

        The dataset class is looked up by name (``hparams.dataset``) in
        ``torchtext.datasets``; the vocabulary is built from the training
        split and capped at ``hparams.max_vocab_size``.
        """
        dataset_cls = getattr(torchtext.datasets, self.hparams.dataset)
        TEXT = torchtext.data.Field()
        self.train_dataset, self.val_dataset, self.test_dataset = dataset_cls.splits(TEXT)
        TEXT.build_vocab(self.train_dataset, max_size=self.hparams.max_vocab_size)

    def _make_bptt_iterator(self, dataset):
        """Wrap ``dataset`` in a BPTTIterator placed on the model's device."""
        device = self.embedding.weight.device
        return torchtext.data.BPTTIterator(dataset,
                                           batch_size=self.hparams.minibatch_size,
                                           bptt_len=self.hparams.bptt_len,
                                           device=device)

    def train_dataloader(self):
        """Iterator over the training split."""
        return self._make_bptt_iterator(self.train_dataset)

    def val_dataloader(self):
        """Iterator over the validation split."""
        return self._make_bptt_iterator(self.val_dataset)

    def test_dataloader(self):
        """Iterator over the test split."""
        return self._make_bptt_iterator(self.test_dataset)

## Parse Arguments

In [None]:
# Command-line options: the random seed plus everything the model declares.
parser = ArgumentParser()
parser.add_argument('--random_seed', type=int, default=8)
#parser.add_argument('--cuda_device_id', type=int, default=0)
parser = Transformer.add_model_specific_args(parser)
#parser = Trainer.add_argparse_args(parser)
# parse_known_args ignores unrecognised arguments instead of exiting on them.
args, _ = parser.parse_known_args()

# Make sure d_model, heads, and d_key are compatible. Raised explicitly
# rather than asserted so the check still runs under `python -O`.
if args.d_model % args.attn_heads != 0:
    raise ValueError(
        f'attn_heads={args.attn_heads} does not evenly divide '
        f'd_model={args.d_model}')
# Integer division: d_key is a layer width, not a float.
args.d_key = args.d_model // args.attn_heads

In [None]:
# Seed the NumPy and PyTorch (CPU and CUDA) generators and request
# deterministic, non-benchmarked cuDNN behaviour so runs are repeatable.
# A falsy seed (0 or None) skips all of this.
if args.random_seed:
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Train the Model

In [None]:
#PATH = 'perplexity-47_heads-8_decoder-layers-6_dmodel-512_bptt-360.ckpt'
#PATH = 'pl_checkpoints-expt2/_ckpt_epoch_4.ckpt'
#model = Transformer.load_from_checkpoint(checkpoint_path=PATH)

In [None]:
# Build the model from the parsed arguments and cast its floating-point
# parameters and buffers to float32.
model = Transformer(args).float()

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir tb_logs/my_transformer --bind_all serve 

In [None]:
# Metrics go to the Weights & Biases project 'my-transformer'.
logger = WandbLogger(project='my-transformer')

# Checkpoints are written to a directory named after the W&B run id.
checkpoint_dir = './checkpoints/wandb-%s/' % logger.experiment.id
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_dir)
# NOTE(review): patience=0 presumably stops at the first validation check
# that fails to improve — confirm this is intended.
early_stop_callback = EarlyStopping(patience=0)

trainer = Trainer(
    gpus=1,
    #distributed_backend='dp',
    max_epochs=10,
    val_check_interval=0.05,
    profiler=True,
    #profiler=AdvancedProfiler(),
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stop_callback,
    logger=logger,
    #logger=TensorBoardLogger("tb_logs", name="my_transformer")
)

trainer.fit(model=model)

In [None]:
# Show the timing report gathered by the profiler the Trainer was created
# with (profiler=True).
print(trainer.profiler.summary())

In [None]:
#trainer.save_checkpoint("6-layers_512-d_model_bptt-60.ckpt")

## Sampling Helper Functions

In [None]:
# Special vocabulary tokens. sample() stops generating as soon as the model
# emits eos_token or pad_token.
unk_token = '<unk>'
pad_token = '<pad>'
eos_token = '<eos>'

def numericalize(string, model, pad=True):
    """Take a string and return a tensor of vocabulary indices for its tokens.

    The text is tokenized with the model's training ``text`` field, cut down
    to its last ``bptt_len`` tokens, converted to indices and, when ``pad``
    is true, right-padded with the field's pad index up to ``bptt_len``
    positions.

    Args:
        string: the text to convert. (The parameter name shadows the
            ``string`` module inside this function.)
        model: object exposing ``hparams.bptt_len`` and
            ``train_dataset.fields['text']``.
        pad: whether to right-pad the result to ``bptt_len`` positions.

    Returns:
        ``(t, num_tokens)`` where ``t`` has shape ``(1, length)`` and
        ``num_tokens`` is the number of real (non-padding) tokens in it.
    """
    # Use the model's own window size so the input always matches the
    # position encoding the model was built with.
    bptt_len = model.hparams.bptt_len
    text_field = model.train_dataset.fields['text']
    tokens = text_field.tokenize(string)[-bptt_len:]
    t = text_field.numericalize([tokens]).T
    num_tokens = t.shape[1]
    if pad and num_tokens < bptt_len:
        # Pad with the vocabulary's real pad index rather than assuming 0.
        pad_index = text_field.vocab.stoi[text_field.pad_token]
        padding = torch.full((1, bptt_len - num_tokens), pad_index,
                             dtype=t.dtype, device=t.device)
        t = torch.cat((t, padding), dim=1)
    return t, num_tokens

def tokenize(indices, vocab):
    """Take a tensor of token indices and return the tokens as one string.

    Args:
        indices: integer tensor of vocabulary indices; it is flattened, so
            any shape (including a single element) is accepted.
        vocab: vocabulary whose ``itos`` sequence maps index -> token.

    Returns:
        The tokens joined by single spaces.
    """
    # Use the vocabulary that was passed in instead of reaching for the
    # global model, and flatten rather than squeeze so that a one-element
    # tensor still iterates.
    tokens = [vocab.itos[i] for i in indices.flatten().tolist()]
    return ' '.join(tokens)

def get_next_token(decoder_in, vocab, pos, model=model, deterministic=False):
    """Run one step of auto-regression and return the token for ``pos``.

    Args:
        decoder_in: tensor of vocab indices; only its first row is used.
        vocab: vocabulary whose ``itos`` sequence maps index -> token.
        pos: position in ``decoder_in`` whose prediction is wanted.
        model: model to query. The default is whatever ``model`` referred to
            when this function was defined.
        deterministic: pick the highest-scoring token when true; otherwise
            sample from the softmax distribution over the vocabulary.

    Returns:
        ``(next_index, next_token)``: the chosen vocabulary index and its
        string form.
    """
    # Ask the model for position `pos` only, so the vocabulary projection,
    # softmax and sampling run on one position instead of the whole window.
    logits = model(decoder_in=decoder_in, pos=pos)[0]

    if deterministic:
        next_index = int(torch.argmax(logits, dim=-1))
    else:
        probs = nn.functional.softmax(logits.float(), dim=-1)
        next_index = int(torch.multinomial(probs, num_samples=1))

    return next_index, vocab.itos[next_index]

def sample(prompt, model, deterministic=False, prnt=True):
    """Continue ``prompt`` by auto-regressive decoding.

    The prompt is numericalized into the decoder input; tokens are then
    generated one at a time and written back into that input until the
    model emits ``eos_token`` or ``pad_token`` or ``2 * bptt_len`` tokens
    have been produced. Once the input window is full it is shifted left by
    one position for every new token.

    Args:
        prompt: text to continue.
        model: the model to sample from. It is switched to eval mode and
            left there.
        deterministic: forwarded to ``get_next_token`` (greedy vs. sampled).
        prnt: when true, print the prompt followed by the continuation.

    Returns:
        The generated tokens joined by single spaces.
    """
    vocab = model.train_dataset.fields['text'].vocab
    di, num_tokens = numericalize(prompt, model=model)
    # Put the input on the same device as the model weights; after GPU
    # training a CPU tensor would otherwise trigger a device mismatch.
    di = di.to(next(model.parameters()).device)
    max_new_tokens = 2 * model.hparams.bptt_len

    out = []
    with torch.no_grad():
        eval_model = model.eval()
        # Index of the last real prompt token; its prediction is the first
        # generated token.
        pos = num_tokens - 1
        for _ in range(max_new_tokens):
            next_index, next_token = get_next_token(
                di, vocab, pos=pos, model=eval_model,
                deterministic=deterministic)
            if next_token in (eos_token, pad_token):
                break
            out.append(next_token)
            pos += 1
            if pos >= di.shape[1]:
                # Window is full: shift everything left by one and reuse the
                # last slot for the new token.
                di = torch.roll(di, -1, -1)
                pos -= 1
            di[0, pos] = next_index

    out = ' '.join(out)
    if prnt:
        print(prompt + '\n --> \n' + out)
    return out

## Sample the model

In [None]:
# Prompt for the completions sampled in the next cell.
prompt = """Born in Omaha , Nebraska , Malcolm X spent his teenage years living in a series of foster homes after his father 's death and his mother 's hospitalization . He engaged in several illicit activities there , eventually being sentenced"""

In [None]:
# Draw five independent stochastic completions of the current prompt.
model = model.float()
for i in range(5):
    print()
    print('Completion #%s:' % i)
    out = sample(prompt, model=model, deterministic=False, prnt=False)
    print(prompt, '-->', out)

In [None]:
# Prompt for the completions sampled in the next cell.
prompt = """Albert Einstein ( 14 March 1879 @–@ 18 April 1955 ) was a German @-@ born theoretical physicist who developed the theory of relativity , one of the two pillars of modern physics ( alongside quantum mechanics ) . His work"""

In [None]:
# Draw five independent stochastic completions of the current prompt.
model = model.float()
for i in range(5):
    print()
    print('Completion #%s:' % i)
    out = sample(prompt, model=model, deterministic=False, prnt=False)
    print(prompt, '-->', out)

In [None]:
# One stochastic completion for each of a handful of one-word prompts.
# NOTE(review): 'The' is listed twice — confirm that is intentional.
prompts = ['The',
           'Of',
           'To',
           'In',
           'A',
           'Was',
           'The',
           'On',
           'That',
           'For',
           'As',
           'With']
for prompt in prompts:
    print()
    out = sample(prompt, model=model, deterministic=False, prnt=False)
    print(prompt, '-->', out)

In [None]:
# Greedy (deterministic) completion of a short prompt; prnt=True makes
# sample() print the prompt and the continuation itself.
prompt = """5 m ("""
sample(prompt, model=model, deterministic=True, prnt=True)

In [None]:
# Five stochastic completions of a longer prompt.
prompt = """Rollerblade is a brand of inline skates owned by Nordica, part of the Tecnica Group of Giavera del Montello, Treviso, Italy.[4][5]The company was started by Scott Olson (b. 1960) and Brennan Olson (b. 1964) in Minneapolis as Ole's Innovative Sports; when they sold the company, it became Rollerblade, Inc.[6] and has changed hands over time between Nordica, Benetton Group and Tecnica.[7]Even though the long established Roces company was the first to"""

for i in range(5):
    print()
    print('Completion #%s:' % i)
    out = sample(prompt, model=model, deterministic=False, prnt=False)
    print(prompt, '-->', out)