# My Transformer

I wrote this to help me understand _Attention Is All You Need_: https://arxiv.org/abs/1706.03762

I cut out the Encoder, and I'm using the remaining Decoder-only model to generate English words.

In [None]:
from collections import defaultdict
from math import sqrt, sin, cos
import os
import random
import sys
import time
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2

## Setup
 

In [None]:
# Hyperparameters and run configuration. The whole dict is also handed to
# wandb so every run records the settings it was trained with.
conf = {
        'attn_heads': 4,
        'bptt_len': 25,
        #'cuda_device_ids': [3, 2, 1, 0],  # I need better GPU cooling first
        'cuda_device_ids': [3],
        'd_model': 20,
        #'datafile': './city_names.txt', # from: https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068
        'datafile': './corncob_lowercase.txt',  # from: http://www.mieliestronk.com/corncob_lowercase.txt
        'dropout': 0.1,
        'learning_rate': 0.1,
        'epochs_per_loop': 50,
        'total_training_loops': 30,
        'num_blocks_encoder': 1,
        'num_blocks_decoder': 2,
        'minibatch_size': 45000,
        #'optimizer': 'Adam'  # Adam gives me nans. Not sure why yet.
        'optimizer': 'SGD',
        'random_seed': 0,  # set to None to skip seeding
        }

# Make sure d_model, heads, and d_key are compatible. A real exception is
# raised (instead of an assert) so the check survives `python -O`.
if conf['d_model'] % conf['attn_heads'] != 0:
    raise ValueError('attn_heads=%s does not evenly divide d_model=%s'
                     % (conf['attn_heads'], conf['d_model']))
# Integer division keeps d_key an int, so it can be used directly as a layer
# size (true division would store 5.0, which nn.Linear rejects).
conf['d_key'] = conf['d_model'] // conf['attn_heads']

# Set up the RNGs for repeatability. Compare against None rather than relying
# on truthiness: 0 is a perfectly valid seed and must not disable seeding.
if conf['random_seed'] is not None:
    random.seed(conf['random_seed'])
    np.random.seed(conf['random_seed'])
    torch.manual_seed(conf['random_seed'])
    torch.cuda.manual_seed(conf['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Logging: start a Weights & Biases run and attach the configuration to it
wandb.init(config=conf, project="my-transformer")

## Model Architecture

### Section 3.2.1: Scaled Dot-Product Attention

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head (Section 3.2.1).

    Queries, keys and values are each projected from `d_model` to `d_key`
    features, the query/key dot products are scaled by 1/sqrt(d_key) and
    turned into weights with a softmax over the key axis, and the weighted
    sum of the projected values is returned.

    Parameters:
    mask:     when true, a causal mask is applied so that position i gives
              zero weight to every position j > i.
    d_model:  size of the last dimension of the inputs.
    d_key:    size of the last dimension of the output.
    bptt_len: side length of the square causal mask, i.e. the longest
              sequence the masked head can handle.
    """
    def __init__(self, mask=True, d_model=conf['d_model'], d_key=conf['d_key'],
                 bptt_len=conf['bptt_len']):
        super().__init__()
        # nn.Linear needs integer sizes; coerce in case d_key arrives as a
        # whole-valued float (e.g. the result of a true division).
        d_key = int(d_key)

        self.d_model = d_model
        self.d_key = d_key
        self.bptt_len = bptt_len

        if mask:
            # -inf strictly above the diagonal, 0 elsewhere. float('-inf') is
            # used because np.NINF no longer exists in NumPy 2.0.
            ninf = float('-inf') * torch.ones([bptt_len, bptt_len])
            self.mask = nn.Parameter(ninf.triu(1), requires_grad=False)
        else:
            self.mask = None

        # head projections
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        # Normalize over the key axis, which is always the LAST axis of the
        # score matrix. dim=1 is only the key axis for unbatched 2-D input;
        # for batched (batch, query, key) scores it is the query axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        """Return the attention output; its last dimension has size d_key.

        The last dimension of each input must be d_model. Any leading batch
        dimensions are carried through unchanged.
        """
        # project queries, keys, values
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)

        # calculate compatibility function
        scores = torch.matmul(queries, torch.transpose(keys, -2, -1))
        scores = scores / sqrt(self.d_key)

        # Filter out attention to future positions: zero the upper triangle,
        # then add -inf there so the softmax gives those entries zero weight.
        if self.mask is not None:
            # Slice so sequences shorter than bptt_len also work.
            mask = self.mask[:scores.size(-2), :scores.size(-1)]
            scores = scores.tril() + mask

        # softmax over the keys
        scores = self.softmax(scores)

        # sum the weighted value vectors
        attn = torch.matmul(scores, values)

        return attn

### 3.2.2 Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (Section 3.2.2).

    Runs `heads` independent AttentionHead modules on the same inputs,
    concatenates their outputs along the last dimension and mixes them with
    a bias-free d_model -> d_model linear layer.
    """
    def __init__(self, mask=False, d_model=conf['d_model'],
                 heads=conf['attn_heads'], bptt_len=conf['bptt_len']):
        super().__init__()
        # Each head produces d_model / heads features
        per_head_dim = int(d_model / heads)

        head_modules = []
        for _ in range(heads):
            head_modules.append(
                AttentionHead(mask, d_model, per_head_dim, bptt_len))
        self.attn_heads = nn.ModuleList(head_modules)

        # Output projection applied to the concatenated heads
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, queries, keys, values):
        """Apply every head to the inputs and project the concatenation."""
        per_head_outputs = []
        for head in self.attn_heads:
            per_head_outputs.append(
                head(queries=queries, keys=keys, values=values))
        concatenated = torch.cat(per_head_outputs, dim=-1)
        return self.Wo(concatenated)

### 3.3 Position-wise Feed-Forward Networks

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network (Section 3.3).

    Two linear layers with a ReLU in between: d_model is expanded to
    `multiplier * d_model` hidden features and projected back to d_model.
    """
    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()

        hidden_width = int(multiplier * d_model)

        # Built in the same order they are applied
        expand = nn.Linear(d_model, hidden_width)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_width, d_model)

        self.ffn = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the feed-forward stack to `x`."""
        transformed = self.ffn(x)
        return transformed

### 3.1 Encoder and Decoder Stacks

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer (Section 3.1, Encoder).

    Unmasked multi-head self-attention followed by the position-wise
    feed-forward network. Each sub-layer output goes through dropout, is
    added to the sub-layer input, and is layer-normalized.

    `mask_all` is accepted but not used by this block.
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # Self-attention sub-layer (no causal mask)
        self.self_attn = MultiHeadAttention(mask=False, d_model=d_model,
                                            heads=heads, bptt_len=bptt_len)
        self.self_attn_drop = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run `x` through both sub-layers and return the result."""
        attended = self.self_attn(x, x, x)
        attended = self.self_attn_drop(attended)
        hidden = self.self_attn_norm(x + attended)

        fed_forward = self.ffn(hidden)
        fed_forward = self.ffn_drop(fed_forward)
        return self.ffn_norm(hidden + fed_forward)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer (Section 3.1, Decoder).

    Three sub-layers in sequence: causally masked self-attention, attention
    over `encoder_out` (masked only when `mask_all` is true), and the
    position-wise feed-forward network. Each sub-layer output goes through
    dropout, is added to the sub-layer input, and is layer-normalized.
    """
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # Masked self-attention sub-layer
        self.self_attn = MultiHeadAttention(mask=True, d_model=d_model,
                                            heads=heads, bptt_len=bptt_len)
        self.self_attn_drop = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Attention over the encoder output
        self.enc_attn = MultiHeadAttention(mask=mask_all, d_model=d_model,
                                           heads=heads, bptt_len=bptt_len)
        self.enc_attn_drop = nn.Dropout(p=dropout)
        self.enc_attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer
        self.ffn = FFN(d_model)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out):
        """Run `x` through the three sub-layers.

        `x` supplies the queries, keys and values of the self-attention and
        the queries of the second attention; `encoder_out` supplies the keys
        and values of the second attention.
        """
        self_attended = self.self_attn_drop(self.self_attn(x, x, x))
        after_self = self.self_attn_norm(x + self_attended)

        cross_attended = self.enc_attn_drop(
            self.enc_attn(after_self, encoder_out, encoder_out))
        after_cross = self.enc_attn_norm(after_self + cross_attended)

        fed_forward = self.ffn_drop(self.ffn(after_cross))
        return self.ffn_norm(after_cross + fed_forward)

In [None]:
class Encoder(nn.Module):
    """A stack of `num_blocks` EncoderBlock layers applied one after another."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_encoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(EncoderBlock(d_model, heads, bptt_len, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed `x` through every block in order and return the final output."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

In [None]:
class Decoder(nn.Module):
    """A stack of `num_blocks` DecoderBlock layers applied one after another."""
    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(DecoderBlock(d_model, heads, bptt_len, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, encoder_out, decoder_in):
        """Feed `decoder_in` through every block in order.

        The same `encoder_out` is handed to each block as its second input.
        """
        hidden = decoder_in
        for layer in self.blocks:
            hidden = layer(hidden, encoder_out)
        return hidden

### Section 3 Model Architecture

In [None]:
class Transformer(nn.Module):
    """Transformer with the encoder stack removed (Section 3).

    Token indices are embedded, a fixed sinusoidal position encoding is
    added, the result is run through the Decoder stack, and a final linear
    layer maps each position to one score per vocabulary entry.

    `num_blocks_encoder` is accepted for interface compatibility but unused,
    because no Encoder is built: the embedded `encoder_out` tokens are fed
    straight to the decoder.
    """
    def __init__(self,
                 vocab,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks_encoder=conf['num_blocks_encoder'],
                 num_blocks_decoder=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.vocab = vocab
        self.d_model = d_model
        self.bptt_len = bptt_len

        pad_idx = vocab.stoi['<pad>']
        self.embedding = nn.Embedding(len(vocab), d_model, padding_idx=pad_idx)
        # Fixed (non-trainable) table, stored on the module so it moves with
        # .to(device)
        self.position_encoding = nn.Parameter(self._position_encoding(),
                                              requires_grad=False)
        # NOTE(review): defined but not applied anywhere in forward() --
        # confirm whether dropout on the embeddings is wanted.
        self.embed_dropout = nn.Dropout(p=dropout)

        self.decoder = Decoder(d_model, heads, bptt_len, num_blocks_decoder,
                               dropout)

        self.linear = nn.Linear(d_model, len(self.vocab))
        # No longer applied inside forward() (see there); kept as an
        # attribute so code that wants explicit probabilities can use it.
        self.linear_softmax = nn.Softmax(dim=-1)

    def _position_encoding(self):
        """Build the (bptt_len, d_model) sinusoidal position table.

        Even feature index i holds sin(pos / 10000**(i / d_model)); odd
        index i holds cos(pos / 10000**((i - 1) / d_model)).
        """
        d_model = self.d_model
        rows = []
        for pos in range(self.bptt_len):
            row = []
            for i in range(d_model):
                if i % 2 == 0:
                    row.append(sin(pos / (10000 ** (i / d_model))))
                else:
                    row.append(cos(pos / (10000 ** ((i - 1) / d_model))))
            rows.append(row)
        return tensor(rows)

    def embed(self, indices):
        """Implements the embedding from Section 3.4 Embeddings and Softmax.

        `indices` holds vocab indices with the sequence on the last axis,
        either as a tensor or as a (nested) list of ints. Returns the token
        embeddings plus the position encoding.
        """
        if not isinstance(indices, torch.Tensor):
            # Build new tensors next to the embedding weights. Tensors that
            # are passed in are used as-is instead of being copied again.
            indices = tensor(indices, device=self.embedding.weight.device)
        embedded = self.embedding(indices)
        # Slice the table so sequences shorter than bptt_len also work.
        return embedded + self.position_encoding[:embedded.size(-2)]

    def forward(self, encoder_out, decoder_in):
        """Score the next token for every position.

        parameters:
        encoder_out: vocab indices of the token sequence that stands in for
                     the encoder output. It is only embedded (no encoder is
                     run) and then handed to the decoder's second attention.
        decoder_in:  vocab indices of the prior decoder output for
                     auto-regression, right shifted by one position.

        Returns unnormalized scores (logits) whose last dimension is the
        vocabulary. No softmax is applied here: nn.CrossEntropyLoss applies
        log-softmax itself, and the sampling code applies its own softmax,
        so normalizing here as well would apply softmax twice.
        """
        # "Encode": just embed the tokens
        encoder_out = self.embed(encoder_out)

        # Decode
        decoder_in = self.embed(decoder_in)
        decoder_out = self.decoder(encoder_out, decoder_in)

        # Scores for the next token at each position
        return self.linear(decoder_out)


## Build the Vocab and Model

In [None]:
# Set up Cuda: list the configured GPUs and use the first one as the
# primary device
print("Using", len(conf['cuda_device_ids']), "GPU(s):")
for i in conf['cuda_device_ids']:
    print("    cuda:%s:" % i, torch.cuda.get_device_name(i))
device = torch.device('cuda', conf['cuda_device_ids'][0])
print()
sys.stdout.flush()

# Make the vocabulary: one token per character that occurs in the data file
with open(conf['datafile'], 'r') as f:
    strings = [line.strip().lower() for line in f]

# A str iterates character by character, so joining all the lines gives
# build_vocab_from_iterator a stream of single-character tokens.
token_iterator = ''.join(strings)
vocab = torchtext.vocab.build_vocab_from_iterator(token_iterator)

# Pin the special tokens to fixed indices in BOTH directions of the mapping.
vocab.stoi['<pad>'] = 0
vocab.itos[0] = '<pad>'
vocab.stoi['<eos>'] = 1
# This was `vocab.stoi[1] = '<eos>'`, which added a junk key to stoi and
# left itos[1] unchanged, so index 1 never decoded back to '<eos>'.
vocab.itos[1] = '<eos>'
# One <eos> per training string
vocab.freqs['<eos>'] = len(strings)

# Create the model on the primary device and wrap it for the configured GPUs
#model = model.half()
model = nn.DataParallel(Transformer(vocab).to(device),
                        device_ids=conf['cuda_device_ids'])

# Define the Optimizer: look the class up by the name given in the config
optimizer_class = getattr(optim, conf['optimizer'])
optimizer = optimizer_class(model.parameters(), lr=conf['learning_rate'])

# Define the Loss
# (disabled experiment: weight each class by inverse token frequency and
# ignore the pad index)
#CE_freqs = [float(vocab.freqs[t]) for t in vocab.itos]
#CE_weight = [(0. if f == 0 else 1/f) for f in CE_freqs]
#CE_weight = torch.tensor(CE_weight, dtype=torch.half, device=device)
#pad_idx = vocab.stoi['<pad>']
#criterion = nn.CrossEntropyLoss(weight=CE_weight, ignore_index=pad_idx)
criterion = nn.CrossEntropyLoss()

## Helper Functions

In [None]:
# In-memory cache of the training tensors. get_data() fills it on its first
# call (keys 'encoder_out', 'decoder_in', 'y') and reuses it afterwards.
data = defaultdict(list)

In [None]:
def pad_indices(indices, right_shift=False, bptt_len=conf['bptt_len']):
    """Return `indices` as a list of exactly `bptt_len` ints.

    The <eos> index is appended unless the list already ends with it. With
    `right_shift=True` a <pad> index is inserted at the front, shifting
    everything else one position to the right. The result is then cut to
    `bptt_len` entries and padded on the right with <pad> indices."""
    padded = [int(ix) for ix in indices]
    eos_index = vocab.stoi['<eos>']
    pad_index = vocab.stoi['<pad>']

    ends_with_eos = bool(padded) and padded[-1] == eos_index
    if not ends_with_eos:
        padded.append(eos_index)

    if right_shift:
        padded = [pad_index] + padded

    # Truncate first, then top up with padding to the exact length
    padded = padded[:bptt_len]
    padded.extend([pad_index] * (bptt_len - len(padded)))
    return padded
    
def get_indices(string, vocab=vocab, bptt_len=conf['bptt_len']):
    """Tokenize `string` into characters and return their vocab indices.

    The string is stripped and lower-cased, and only the first `bptt_len`
    characters are kept. The returned list is unpadded and is suitable
    input for `pad_indices()`."""
    characters = string.strip().lower()[:bptt_len]
    return [vocab.stoi[ch] for ch in characters]

def _get_tensors(string, model=model, vocab=vocab, criterion=criterion,
                 optimizer=optimizer, bptt_len=conf['bptt_len']):
    """Build one training example per split position of `string`.

    For every position i the token indices are split into a prefix [:i] and
    a suffix [i:]. The prefix becomes `encoder_out`, the right-shifted
    suffix becomes `decoder_in`, and the unshifted suffix becomes `y`.
    Returns three lists of (1, bptt_len) tensors, all padded by
    `pad_indices()`.

    `model`, `criterion`, `optimizer` and `bptt_len` are accepted but not
    used here."""
    indices = get_indices(string, vocab)

    def as_row(index_list, **pad_options):
        # Pad to full length and add a leading batch dimension of size 1
        return tensor(pad_indices(index_list, **pad_options)).unsqueeze(0)

    encoder_out, decoder_in, y = [], [], []
    for split_at in range(len(indices)):
        prefix = indices[:split_at]
        suffix = indices[split_at:]
        encoder_out.append(as_row(prefix))
        decoder_in.append(as_row(suffix, right_shift=True))
        y.append(as_row(suffix))
    return encoder_out, decoder_in, y

def get_data(minibatch_size=conf['minibatch_size'], vocab=vocab,
             data_file=conf['datafile']):
    """Yield minibatches of (`encoder_out`, `decoder_in`, `y`) tensors.

    On the first call the newline-separated training strings are read from
    `data_file`, converted with `_get_tensors()`, concatenated, moved to
    `device` and cached in the global `data`. Later calls reuse that cache.
    Each yielded tensor holds vocab indices and has at most
    `minibatch_size` rows of `bptt_len` entries."""

    global data

    keys = ('encoder_out', 'decoder_in', 'y')

    if not data:
        # Cache this in memory so we don't have to re-read it. My datasets
        # are small enough for now to fit in ram easily
        with open(data_file, 'r') as f:
            strings = [line.strip().lower() for line in f.readlines()]

        for string in strings:
            for key, examples in zip(keys, _get_tensors(string)):
                data[key].extend(examples)

        for key in keys:
            data[key] = torch.cat(data[key]).to(device)

    # Split all three tensors the same way along the batch axis and hand
    # the pieces out in lockstep
    batches = [torch.split(data[key], minibatch_size, dim=0) for key in keys]
    yield from zip(*batches)
        
def train_data(encoder_out, decoder_in, y):
    """Take one optimizer step on a single minibatch and return its loss
    as a Python float."""
    optimizer.zero_grad()

    predictions = model(encoder_out=encoder_out, decoder_in=decoder_in)
    # Swap the last two axes so the vocabulary axis comes right after the
    # batch axis before the loss is computed
    predictions = predictions.transpose(-2, -1)

    batch_loss = criterion(predictions, y)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def do_epoch(epoch, model=model, vocab=vocab, criterion=criterion,
             optimizer=optimizer):
    """Train on every minibatch from `get_data()` once.

    Returns a tuple (average minibatch loss, elapsed seconds). The
    arguments are accepted but not used in the body."""
    started = time.time()
    losses = []
    for batch in get_data():
        losses.append(train_data(*batch))
    finished = time.time()

    mean_loss = sum(losses) / len(losses)
    return (mean_loss, finished - started)

def train(num_epochs=conf['epochs_per_loop'], start_epoch=0, model=model,
          vocab=vocab, criterion=criterion, optimizer=optimizer):
    """Run `num_epochs` training epochs, logging and printing each result.

    Epochs are numbered from `start_epoch`. After each epoch the loss and
    duration are logged to wandb and printed.

    Returns the number of the next epoch to run (`start_epoch` plus the
    number of epochs actually run), so the result can be passed back in as
    `start_epoch` on the next call."""
    model = model.train()
    for epoch in range(start_epoch, start_epoch + num_epochs):
        loss, seconds = do_epoch(epoch, model, vocab, criterion, optimizer)
        wandb.log({'epoch': epoch,
                   'loss': loss,
                   'seconds': seconds})
        print('epoch:', epoch, '(%.1fs)' % seconds, 'loss=%.4f' % loss)
    # Computed from the arguments instead of the loop variable, which is
    # unbound (UnboundLocalError) when the loop body never runs.
    return start_epoch + max(num_epochs, 0)


#induction_weight = CE_weight.unsqueeze(dim=0)

def get_next_token(encoder_out, decoder_in, pos, model=model, vocab=vocab):
    """Run one step of auto-regression and return the token string sampled
    for position `pos`.

    parameters:
    encoder_out: string whose characters become the `encoder_out` input.
    decoder_in:  string generated so far; it is right-shifted and padded
                 before being fed to the model.
    pos:         index of the output position to read the token from.

    The model output is passed through a softmax and one token per position
    is drawn from a Multinomial distribution; only the draw at `pos` is
    returned."""
    eval_model = model.eval()

    encoder_out = pad_indices(get_indices(encoder_out))
    encoder_out = tensor(encoder_out).to(device)

    decoder_in = pad_indices(get_indices(decoder_in), right_shift=True)
    decoder_in = tensor(decoder_in).to(device)

    # Inference only: skip building the autograd graph, which would just
    # waste memory and time for every sampled token.
    with torch.no_grad():
        decoder_out = eval_model(encoder_out=encoder_out,
                                 decoder_in=decoder_in)
        #decoder_out = decoder_out * induction_weight

        probs = nn.functional.softmax(decoder_out.float(), dim=-1)

    # Each sample is one-hot per position; argmax recovers the drawn index
    m = torch.distributions.multinomial.Multinomial(probs=probs)
    _, indices = torch.max(m.sample(), dim=-1)

    next_index = int(indices[pos])
    return vocab.itos[next_index]

def sample_with_prompt(prompt, model=model, vocab=vocab, bptt_len=conf['bptt_len']):
    """Sample a string from the network that begins with `prompt`.

    Tokens are generated one at a time with `get_next_token()` until a
    '<pad>' or '<eos>' token is produced or the generated part reaches
    `bptt_len - 1` characters. Returns `prompt` followed by the generated
    characters."""
    stop_tokens = ('<pad>', '<eos>')
    max_generated = bptt_len - 1

    encoder_out = prompt
    generated = ''
    next_token = ''

    while True:
        if next_token in stop_tokens:
            break
        if len(generated) >= max_generated:
            break
        generated += next_token
        next_token = get_next_token(encoder_out, generated,
                                    pos=len(generated))

    return prompt + generated

def alphabet_sample():
    """Print one sampled string for each prompt letter 'a' through 'z'."""
    first_code, last_code = ord('a'), ord('z')
    for code in range(first_code, last_code + 1):
        letter = chr(code)
        print('%r' % sample_with_prompt(letter))

## Train the Model

In [None]:
# Number of the next epoch to train. The training loop passes it to train()
# and replaces it with train()'s return value.
epoch = 0

In [None]:
# Print a round of samples before this training run starts
alphabet_sample()

# Alternate between a burst of training epochs and a fresh round of samples
for _ in range(conf['total_training_loops']):
    epoch = train(num_epochs=conf['epochs_per_loop'], start_epoch=epoch)
    alphabet_sample()