# My Transformer

I wrote this to help me understand _Attention Is All You Need_: https://arxiv.org/abs/1706.03762

I cut out the Encoder, and I'm using the remaining Decoder stack to generate English words

In [None]:
from collections import defaultdict, Counter
import multiprocessing.pool
from math import sqrt, sin, cos
import os
import random
import sys
import string
import time
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2
from torchtext.datasets.language_modeling import LanguageModelingDataset


## Setup
 

In [None]:
# Basic Config
# Hyper-parameters and run settings; the whole dict is also handed to wandb
# as the run configuration.
conf = {
    'attn_heads': 8,
    'bptt_len': 40,
    #'cuda_device_ids': [3, 2, 1, 0],  # I need better GPU cooling first
    'cuda_device_ids': [1],
    'd_model': 512,
    #'datafile': './city_names.txt', # from: https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068
    #'datafile': './corncob_lowercase.txt',  # from: http://www.mieliestronk.com/corncob_lowercase.txt
    #'datafile': './alphabet_short.txt',
    #'datafile': './dummy_data.txt',
    'dataset': 'WikiText2',
    #'dataset': 'WikiText103',
    'dropout': 0.1,
    'final_softmax': False,
    #'learning_rate': 0.01,
    'epochs_per_loop': 1,
    'total_training_loops': 6,
    'num_blocks_encoder': 0,
    'num_blocks_decoder': 24,
    #'minibatch_size': 32 * 16,
    'minibatch_size': 16,
    'optimizer': 'Adam',
    #'optimizer': 'SGD',
    'random_seed': 0,
    'warmup_steps': 100,
}


# debugging
#conf['attn_heads'] = 1
#conf['d_model'] = 1
#conf['bptt_len'] = 2
#conf['datafile'] = './dummy_data.txt'
#conf['num_blocks_decoder'] = 1
#conf['minibatch_size'] = 1
#conf['epochs_per_loop'] = 1


# Make sure d_model, heads, and d_key are compatible
assert conf['d_model'] % conf['attn_heads'] == 0, (
    f"attn_heads={conf['attn_heads']} does not evenly divide "
    f"d_model={conf['d_model']}")
# Floor division keeps d_key an int. True division would produce a float
# (64.0), which is not a valid layer size when d_key is later passed to
# nn.Linear as a default argument.
conf['d_key'] = conf['d_model'] // conf['attn_heads']

# Seed the RNGs for repeatability and force deterministic cuDNN kernels.
# NOTE(review): the guard is a truthiness test, so a random_seed of 0 (the
# value currently in conf) skips seeding entirely — confirm that a zero seed
# is meant to mean "do not seed".
if conf['random_seed']:
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seed_fn(conf['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Report the configured GPUs; the first listed id becomes the primary device
print("Using", len(conf['cuda_device_ids']), "GPU(s):")
for gpu_id in conf['cuda_device_ids']:
    print(f"    cuda:{gpu_id}:", torch.cuda.get_device_name(gpu_id))

device = torch.device(f"cuda:{conf['cuda_device_ids'][0]}")

print()

# FIXME: module-level shortcut; later cells read bptt_len directly instead of
# going through conf
bptt_len = conf['bptt_len']

# Logging: start a wandb run that records conf as its configuration.
# (Replace the call with `wandb = None` to switch logging off; train() only
# logs when `wandb` is truthy.)
wandb.init(project="my-transformer", config=conf)

## Model Architecture

### Section 3.2.1: Scaled Dot-Product Attention

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head (section 3.2.1).

    Projects queries, keys and values from d_model down to d_key, scores
    queries against keys, optionally hides future positions with a
    lower-triangular mask, and returns the softmax-weighted sum of values.

    Parameters:
        mask:     when True, position i may only attend to positions <= i.
        d_model:  width of the incoming vectors.
        d_key:    width of the projected queries/keys/values. Coerced to int
                  because a value produced by `/` division (e.g. 64.0) is not
                  accepted by nn.Linear.
        bptt_len: largest sequence length the mask has to cover.
    """
    def __init__(self, mask=False, d_model=conf['d_model'], d_key=conf['d_key'],
                 bptt_len=conf['bptt_len']):
        super().__init__()

        d_key = int(d_key)

        self.d_model = d_model
        self.d_key = d_key
        self.bptt_len = bptt_len

        if mask:
            # Registered as a buffer so it follows the module across devices
            t_mask = torch.ones([bptt_len, bptt_len]).tril()
            self.register_buffer('mask', t_mask)
        else:
            self.mask = None

        # head projections
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        """Return the attention output; its last dimension has size d_key.

        The inputs need a trailing (seq_len, d_model) pair of dimensions; any
        leading batch dimensions (or none) are carried through unchanged.
        """
        # project queries, keys, values
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)

        # compatibility function: scaled dot product of queries and keys
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scores = scores / sqrt(self.d_key)

        # Filter out attention to future positions. The mask is cut to the
        # last two score dimensions so it works regardless of how many batch
        # dimensions precede them.
        if self.mask is not None:
            this_mask = self.mask[:scores.shape[-2], :scores.shape[-1]]
            scores = scores.masked_fill(this_mask == 0, float('-inf'))

        # softmax over the key axis turns scores into attention weights
        weights = self.softmax(scores)

        # sum the weighted value vectors -> (..., query_len, d_key)
        return torch.matmul(weights, values)

### 3.2.2 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (section 3.2.2).

    Runs `heads` independent AttentionHead modules, each of width
    d_model // heads, concatenates their outputs along the last dimension and
    applies the output projection Wo.

    Raises:
        ValueError: if `heads` does not evenly divide `d_model`. The
                    concatenated heads would otherwise not be d_model wide
                    and the failure would only appear later inside Wo.
    """
    def __init__(self, mask=False, d_model=conf['d_model'],
                 heads=conf['attn_heads'], bptt_len=conf['bptt_len']):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(
                'heads=%s does not evenly divide d_model=%s' % (heads, d_model))
        d_key = d_model // heads

        attn_heads = [AttentionHead(mask=mask, d_model=d_model, d_key=d_key, bptt_len=bptt_len)
                      for _ in range(heads)]
        self.attn_heads = nn.ModuleList(attn_heads)
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, queries, keys, values):
        """Attend with every head and project the concatenated result."""
        head_attns = [h(queries=queries, keys=keys, values=values)
                      for h in self.attn_heads]
        # heads are laid side by side on the feature axis
        head_attn = torch.cat(head_attns, dim=-1)
        return self.Wo(head_attn)

### 3.3 Position-wise Feed-Forward Networks

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network (section 3.3).

    Two bias-free linear layers with a ReLU in between; the hidden width is
    `multiplier` times d_model.
    """
    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()

        hidden_width = int(multiplier * d_model)
        widen = nn.Linear(d_model, hidden_width, bias=False)
        narrow = nn.Linear(hidden_width, d_model, bias=False)

        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x):
        """Apply the network to the last dimension of x."""
        return self.ffn(x)

### 3.1 Encoder and Decoder Stacks

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer (section 3.1, Encoder).

    Unmasked self-attention followed by a feed-forward network. Each
    sub-layer output goes through dropout, is added to the sub-layer input
    (residual) and then layer-normalised.
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout']):
        super().__init__()

        self.attn = MultiHeadAttention(mask=False, d_model=d_model, heads=heads, bptt_len=bptt_len)
        self.attn_drop = nn.Dropout(p=dropout)
        self.attn_norm = nn.LayerNorm(d_model)

        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the block output, same shape as x."""
        a1 = self.attn(x, x, x)
        a1 = self.attn_drop(a1)
        a1 = self.attn_norm(x + a1)

        a2 = self.ffn(a1)
        # dropout must receive the FFN output; calling it without an
        # argument raised a TypeError
        a2 = self.ffn_drop(a2)
        a2 = self.ffn_norm(a1 + a2)

        return a2

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer (section 3.1, Decoder).

    Three sub-layers in sequence: masked self-attention, attention over
    `encoder_out`, and a feed-forward network. Each sub-layer is followed by
    dropout, a residual add and LayerNorm.
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout']):
        super().__init__()

        # masked self-attention over the decoder's own sequence
        self.self_attn = MultiHeadAttention(mask=True, d_model=d_model,
                                            heads=heads, bptt_len=bptt_len)
        self.self_attn_drop = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # unmasked attention whose keys and values come from encoder_out
        self.enc_attn = MultiHeadAttention(mask=False, d_model=d_model,
                                           heads=heads, bptt_len=bptt_len)
        self.enc_attn_drop = nn.Dropout(p=dropout)
        self.enc_attn_norm = nn.LayerNorm(d_model)

        # position-wise feed-forward
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out):
        """Run the three sub-layers and return the normalised result."""
        self_attended = self.self_attn_drop(self.self_attn(x, x, x))
        h1 = self.self_attn_norm(x + self_attended)

        # queries come from the decoder stream, keys/values from encoder_out
        cross_attended = self.enc_attn_drop(self.enc_attn(h1, encoder_out, encoder_out))
        h2 = self.enc_attn_norm(h1 + cross_attended)

        fed_forward = self.ffn_drop(self.ffn(h2))
        return self.ffn_norm(h2 + fed_forward)

In [None]:
class Encoder(nn.Module):
    """A stack of `num_blocks` EncoderBlock layers applied one after another."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_encoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(EncoderBlock(d_model, heads, bptt_len, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed x through every block; with zero blocks x is returned as is."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out

In [None]:
class Decoder(nn.Module):
    """A stack of `num_blocks` DecoderBlock layers sharing one encoder_out."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(DecoderBlock(d_model, heads, bptt_len, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, encoder_out, decoder_in):
        """Thread decoder_in through the blocks; each block also receives
        the same encoder_out."""
        out = decoder_in
        for layer in self.blocks:
            out = layer(out, encoder_out)
        return out

### Section 3 Model Architecture

In [None]:
class Transformer(nn.Module):
    """Section 3 model with the encoder stack removed.

    Token indices are embedded and summed with a fixed sinusoidal position
    encoding, passed through the Decoder stack, and projected to
    vocabulary-sized scores. What would normally be the encoder's output is
    just the embedded `encoder_out` sequence.
    """
    def __init__(self,
                 vocab_len,
                 pad_index,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks_encoder=conf['num_blocks_encoder'],
                 num_blocks_decoder=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.d_model = d_model
        self.bptt_len = bptt_len
        self.vocab_len = vocab_len
        self.pad_index = pad_index

        self.embedding = nn.Embedding(vocab_len, d_model, padding_idx=pad_index)
        # fixed (non-trainable) table, stored as a buffer so it moves with the model
        self.register_buffer('position_encoding', self._position_encoding())

        # The encoder is cut out, so num_blocks_encoder is accepted but unused:
        #self.encoder = Encoder(d_model, heads, bptt_len, num_blocks_encoder, dropout)
        self.decoder = Decoder(d_model, heads, bptt_len, num_blocks_decoder, dropout)

        self.linear = nn.Linear(d_model, vocab_len, bias=False)
        self.linear_softmax = nn.Softmax(dim=-1)

    def _position_encoding(self):
        """Build the (bptt_len, d_model) sinusoidal position table.

        Even feature indices hold a sine and odd indices the cosine of the
        same frequency as the even index just before them.
        """
        d_model = self.d_model

        def encode(pos, i):
            if i % 2 == 0:
                return sin(pos / (10000 ** (i / d_model)))
            return cos(pos / (10000 ** ((i - 1) / d_model)))

        per_position = [tensor([encode(pos, i) for i in range(d_model)])
                        for pos in range(self.bptt_len)]
        return torch.stack(per_position, dim=0)

    def embed(self, indices):
        """Implements the embedding from Section 3.4 Embeddings and Softmax:
        token embedding plus the position encoding for as many positions as
        the last dimension of `indices` holds."""
        seq_len = indices.shape[-1]
        positions = self.position_encoding[:seq_len, :]
        return positions + self.embedding(indices)

    def forward(self, encoder_out, decoder_in, pos=None, pre_embedded=False):
        """parameters:
        encoder_out:  vocab indices of the sequence the decoder attends to
                      (stands in for an encoder's output), or already
                      embedded vectors when pre_embedded is True.
        decoder_in:   vocab indices of the decoder input sequence, or
                      already embedded vectors when pre_embedded is True.
        pos:          (optional) when given, only position `pos` of the
                      decoder output is projected. The result is indexed as
                      [:, pos, :], so a leading batch dimension is expected.
        pre_embedded: skip the embedding step for both inputs.

        Returns vocabulary-sized scores; they are passed through a softmax
        only when conf['final_softmax'] is set."""
        # Embed
        if pre_embedded:
            memory, target = encoder_out, decoder_in
        else:
            memory, target = self.embed(encoder_out), self.embed(decoder_in)

        # No encoder pass: the embedded encoder_out is used directly
        hidden = self.decoder(memory, target)

        # Optionally keep a single position
        if pos is not None:
            hidden = hidden[:, pos, :]

        scores = self.linear(hidden)
        if conf['final_softmax']:
            return self.linear_softmax(scores)
        return scores


## Load Data and Build the Model

In [None]:
# dataloader and vocab
# Look up the torchtext dataset class named in conf and build the
# train/validation/test iterators from it. Batches are requested at twice
# bptt_len because get_minibatches() splits each one at bptt_len: the first
# part becomes encoder_out, the rest the decoder input and target.
dataloader = getattr(torchtext.datasets, conf['dataset'])
train_ds, val_ds, test_ds = dataloader.iters(batch_size=conf['minibatch_size'], 
                                             bptt_len=2 * conf['bptt_len'], 
                                             device=device)
# Vocabulary built for the text field; provides stoi / itos lookups
vocab = train_ds.dataset.fields['text'].vocab

# Index handed to the model as the embedding's padding_idx.
# NOTE(review): the sampling cells later rebind pad_token to '<pad>'. Confirm
# that '_' is really the padding token for this dataset and not a leftover
# from an earlier character-level experiment.
pad_token = '_'
pad_index = vocab.stoi[pad_token]

In [None]:
# Build the model on the primary device, then wrap it so batches can be
# spread over every GPU listed in conf
model = nn.DataParallel(
    Transformer(len(vocab), pad_index).to(device),
    device_ids=conf['cuda_device_ids'],
)

# Loss: plain cross-entropy over every target position. The variant that
# skips padding is kept for reference:
#criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
criterion = nn.CrossEntropyLoss()

### Validate the Model

## Training Helper Functions

In [None]:
def get_minibatches(dataset=train_ds):
    """Yield (encoder_out, decoder_in, y) triples cut from each batch.

    Every batch is transposed and split at bptt_len: the first bptt_len
    target tokens become encoder_out, the remaining target tokens are the
    labels y, and the matching slice of the text is the decoder input with
    its first column overwritten by the index of '?'. Batches whose first
    dimension is not longer than bptt_len are skipped.
    Presumably batch.text / batch.target are (seq_len, batch) — TODO confirm.
    """
    split_at = conf['bptt_len']
    for batch in dataset:
        if batch.text.shape[0] <= split_at:
            continue
        targets = batch.target.T
        inputs = batch.text.T
        eo = targets[:, :split_at].contiguous()
        y = targets[:, split_at:].contiguous()
        di = inputs[:, split_at:].contiguous()
        # first decoder position gets the '?' token's index as a start marker
        di[:, 0] = vocab.stoi['?']
        yield eo, di, y

        
def accuracy(output, expected_indices):
    """Return the fraction (0.0 - 1.0) of positions predicted correctly.

    Parameters:
        output:           scores whose last dimension ranges over the
                          candidate classes; the arg-max along it is the
                          prediction.
        expected_indices: the correct class index for each position.

    Predictions and labels are compared element for element. Labels are
    reshaped to the prediction shape when both hold the same number of
    elements, so size-1 dimensions can never trigger broadcasting (e.g. a
    (B, 1) prediction against (B, 1) labels being compared as (B, B)).
    An empty input yields 0.0 instead of dividing by zero.
    """
    predicted = output.argmax(dim=-1)
    if predicted.numel() == 0:
        return 0.0
    if expected_indices.numel() == predicted.numel():
        expected_indices = expected_indices.reshape(predicted.shape)
    matches = predicted == expected_indices
    return matches.float().mean().item()

def run_minibatch(encoder_out, decoder_in, y, optimizer):
    """Train on a single minibatch: forward pass, loss, backward pass and
    optimizer step. Returns (loss value, accuracy) for that minibatch."""
    optimizer.zero_grad()

    scores = model(encoder_out=encoder_out, decoder_in=decoder_in)
    batch_accuracy = accuracy(scores, y)

    # the criterion receives the scores with their last two axes swapped
    batch_loss = criterion(scores.transpose(-2, -1), y)
    batch_loss.backward()  # Not sure why, but this step logs a UserWarning
    optimizer.step()

    return batch_loss.item(), batch_accuracy

def test_set_accuracy(model, dataset=None):
    """Return the mean minibatch accuracy of `model`, as a percentage.

    Parameters:
        model:   the model to evaluate; it is switched to eval mode and left
                 that way.
        dataset: batch iterator to evaluate on. Defaults to the held-out
                 test_ds when omitted.

    Returns 0.0 when the dataset yields no usable minibatch (instead of the
    NaN a mean over an empty list would give).
    """
    if dataset is None:
        dataset = test_ds

    model.eval()
    with torch.no_grad():
        accuracies = [accuracy(model(encoder_out=encoder_out, decoder_in=decoder_in), y)
                      for encoder_out, decoder_in, y in get_minibatches(dataset)]

    if not accuracies:
        return 0.0
    return 100 * sum(accuracies) / len(accuracies)
            
def do_epoch(epoch, optimizer, model, bptt_len=conf['bptt_len']):
    """Train for one pass over the training minibatches.

    Returns (average loss, average train accuracy in percent, test accuracy
    in percent, training duration in seconds). The timing covers the
    training loop only, not the test-set evaluation. `epoch` and `bptt_len`
    are accepted but not used here.
    """
    model = model.train()

    started = time.time()
    results = [run_minibatch(encoder_out, decoder_in, y, optimizer)
               for encoder_out, decoder_in, y in get_minibatches()]
    finished = time.time()

    if results:
        losses, train_accuracies = zip(*results)
        avg_loss = tensor(losses, device=device).float().mean().item()
        avg_train_accuracy = 100 * tensor(train_accuracies, device=device).float().mean().item()
    else:
        # nothing was trained on; report zeros rather than NaN
        avg_loss = 0
        avg_train_accuracy = 0

    avg_test_accuracy = test_set_accuracy(model)
    return (avg_loss, avg_train_accuracy, avg_test_accuracy, finished - started)

def train(optimizer, num_epochs=conf['epochs_per_loop'], start_epoch=0, model=model, 
          vocab=vocab, criterion=criterion):
    """Run `num_epochs` epochs starting at `start_epoch`, logging each one.

    Every epoch's loss, accuracies and duration are printed and, when wandb
    is enabled, logged there too. `vocab` and `criterion` are accepted but
    not used here.

    Returns the number of the next epoch to run, i.e. start_epoch plus the
    epochs actually executed. When num_epochs is zero or negative this is
    start_epoch itself; previously that case raised UnboundLocalError.
    """
    end_epoch = start_epoch + max(num_epochs, 0)
    for epoch in range(start_epoch, end_epoch):
        loss, train_accuracy, test_accuracy, seconds = do_epoch(epoch, optimizer, model)
        if wandb:
            wandb.log({'epoch': epoch,
                       'loss': loss,
                       'train_accuracy': train_accuracy,
                       'test_accuracy': test_accuracy,
                       'seconds': seconds})
        print('epoch:', epoch, '(%.1fs)' % seconds, 'loss=%f' % loss, 'train_accuracy=%.1f%%' % (train_accuracy), 'test_accuracy=%.1f%%' % (test_accuracy))   
    return end_epoch

## Train the Model

In [None]:
# Define the Optimizer
#optimizer_class = getattr(torch.optim, conf['optimizer']) 
#lr = conf['learning_rate']
#optimizer = optimizer_class(model.parameters(), lr=lr)

In [None]:
# Running epoch counter. It starts at 1 (not 0) because the learning-rate
# formula in the training loop raises it to a negative power.
epoch = 1

In [None]:
# Each loop: derive a learning rate from the warm-up formula, train for
# epochs_per_loop epochs, then checkpoint the unwrapped model's weights.
# NOTE(review): the formula is evaluated on the epoch counter and a fresh
# Adam optimizer (with empty moment estimates) is created every loop. The
# paper applies this schedule per optimizer step — confirm this is intended.
warmup_steps = conf['warmup_steps']

for _ in range(conf['total_training_loops']):
    decay_term = epoch ** -.5
    warmup_term = epoch * (warmup_steps ** -1.5)
    lr = (conf['d_model'] ** -.5) * min(decay_term, warmup_term)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    # train() hands back the next epoch number, which also names the checkpoint
    epoch = train(optimizer=optimizer, num_epochs=conf['epochs_per_loop'], start_epoch=epoch)
    save_file_name = './my-transformer_%s_%s-layer_%s-epochs.pt' % (
        conf['dataset'], conf['num_blocks_decoder'], epoch)
    torch.save(model.module.state_dict(), save_file_name)


In [None]:
#torch.save(model.module.state_dict(), './my-transformer-wikitext-2-test-21.4_pct.pt')

## Sampling Helper Functions

In [None]:
# Special tokens used by the sampling helpers below.
unk_token = '<unk>'  # also placed in the first decoder position by sample()
pad_token = '<pad>'  # rebinds the module-level pad_token set to '_' at data-loading time
eos_token = '<eos>'  # sample() stops generating when it draws this token

In [None]:
def numericalize(tokens):
    """Takes a sequence of token strings and returns a tensor of their vocab
    indices, shaped (1, number of tokens), on `device`"""
    row = []
    for token in tokens:
        row.append(vocab.stoi[token])
    return torch.tensor([row]).to(device)

def tokenize(indices):
    """Takes a tensor of token indices and returns a space-joined string.

    The tensor is flattened in row-major order first, so a single-element
    tensor (which `squeeze()` would turn into a non-iterable 0-d tensor) and
    tensors with several rows are both handled.
    """
    flat_indices = indices.reshape(-1).tolist()
    return ' '.join(vocab.itos[i] for i in flat_indices)

def get_next_token(encoder_out, decoder_in, pos, model=model, deterministic=False):
    """Runs one step of auto-regression and returns (index, token) for
    position `pos` of the first sequence in the batch.

    With deterministic=True the highest-scoring index is taken; otherwise an
    index is drawn at random from the softmax of the scores. The whole
    sequence is scored (and sampled) on every call; only `pos` is read.
    """
    scores = model(encoder_out=encoder_out, decoder_in=decoder_in)

    if not deterministic:
        probs = nn.functional.softmax(scores.float(), dim=-1)
        sampler = torch.distributions.multinomial.Multinomial(probs=probs)
        # each draw is a count vector; its arg-max is the sampled index
        chosen = torch.max(sampler.sample(), dim=-1)[1]
    else:
        chosen = torch.max(scores, dim=-1)[1]

    next_index = int(chosen[0, pos])
    return next_index, vocab.itos[next_index]

def sample(prompt, deterministic=False, vocab=vocab, prnt=True):
    """Auto-regresses using prompt to create the encoder_out tensor.

    The prompt must split into exactly bptt_len whitespace-separated tokens.
    Decoding starts from an <unk> token followed by padding and fills in one
    position per step, stopping early on <eos> or <pad>. Returns the
    generated tokens joined by spaces and, when `prnt` is set, also prints
    the prompt and the result.
    """
    bptt_len = conf['bptt_len']
    prompt_tokens = prompt.split()
    assert len(prompt_tokens) == bptt_len, 'Prompt strings must be %s tokens long' % bptt_len

    generated = []
    with torch.no_grad():
        eval_model = model.eval()

        encoder_out = numericalize(prompt_tokens)
        decoder_in = numericalize([unk_token] + [pad_token] * (bptt_len - 1))

        for pos in range(bptt_len):
            next_index, next_token = get_next_token(encoder_out, decoder_in, pos=pos,
                                                    model=eval_model,
                                                    deterministic=deterministic)
            if next_token in (eos_token, pad_token):
                break
            if next_token is None:
                continue
            generated.append(next_token)
            # feed the new token back in as the next decoder input position
            if pos + 1 < bptt_len:
                decoder_in[0, pos + 1] = next_index

    out = ' '.join(generated)
    if prnt:
        print(prompt + '\n --> \n' + out)
    return out

## Sample the model

In [None]:
# Source passage for prompting, already tokenised in the dataset's style
prompt_tokens = """Robert <unk> is an English film , television and theatre actor . He had a guest @-@ starring role on the television series The Bill in 2000 . This was followed by a starring role in the play Herons written by Simon Stephens , which was performed in 2001 at the Royal Court Theatre . He had a guest role in the television series Judge John <unk> in 2002 . In 2004 <unk> landed a role as " Craig " in the episode " Teddy 's Story " of the television series The Long Firm ; he starred alongside actors Mark Strong and Derek Jacobi . He was cast in the 2005 theatre productions of the Philip Ridley play Mercury Fur , which was performed at the Drum Theatre in Plymouth and the <unk> <unk> Factory in London . He was directed by John <unk> and starred alongside Ben <unk> , Shane <unk> , Harry Kent , Fraser <unk> , Sophie Stanton and Dominic Hall .""".split()

# sample() insists on exactly bptt_len tokens, so keep only the first bptt_len
bptt_prompt = prompt_tokens[:bptt_len]
prompt = ' '.join(bptt_prompt)

# Draw ten continuations of the same prompt
for _ in range(10):
    print('--------------')
    out = sample(prompt)

In [None]:
# Accuracy (in percent) over the training set.
# The previous call passed train_ds where test_set_accuracy() expects the
# model, which fails on `.eval()`. The evaluation is spelled out here so the
# training iterator is actually the data being scored.
# NOTE(review): assumes the intent was train-set accuracy — confirm.
with torch.no_grad():
    model.eval()
    train_set_accuracies = [accuracy(model(encoder_out=eo, decoder_in=di), y)
                            for eo, di, y in get_minibatches(train_ds)]
100 * sum(train_set_accuracies) / max(len(train_set_accuracies), 1)

In [None]:
# Hand-written prompt; sample() asserts that it splits into exactly
# bptt_len tokens.
prompt = """Born in Omaha , Nebraska , Malcolm X spent his teenage years living in a series of foster homes after his father 's death and his mother 's hospitalization . He engaged in several illicit activities there , eventually being"""

# Draw five continuations (sample() prints each one)
for _ in range(5):
    sample(prompt)

In [None]:
# Another hand-written prompt; sample() asserts that it splits into exactly
# bptt_len tokens.
prompt = """There was the problematization of madness and illness arising out of social and medical practices , and defining a certain pattern of “ normalization “ ; a problematization of life , language , and labor in discursive practices that conformed"""

# Draw five continuations, separated by a rule (sample() prints each one)
for _ in range(5):
    print("========================")
    sample(prompt)