# My Transformer

In [None]:
from collections import defaultdict
from math import sqrt, sin, cos
import os
import random
import sys
import time
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2

## Setup
 
The training data (`./city_names.txt`) is a list of U.S. city names, pulled from:
https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068

In [None]:
conf = {
        'attn_heads': 3,
        'bptt_len': 20,
        #'cuda_device_ids': [3, 2, 1, 0],
        'cuda_device_ids': [3],
        'd_model': 9,
        'device': 'cuda',
        'datafile': './city_names.txt',
        'dropout': 0.1,
        'learning_rate': 0.1,
        'max_epochs': 100,
        # The decoder-only model builds no encoder stack, but the Encoder and
        # Transformer class definitions read this key for their default
        # arguments, so leaving it out raises KeyError when those cells run.
        'num_blocks_encoder': 6,
        'num_blocks_decoder': 3,
        #'optimizer': 'Adam',
        'optimizer': 'SGD',
        # None disables seeding; 0 is a valid seed.
        'random_seed': 0,

        #'batch_size': 400,
        #'dataset': 'imagenette2-320',
        #'init_gain': 5,
        #'initializer': None,
        #'load_workers': os.cpu_count(),
        #'training_loops': 4,
        #'cuda_device_ids': [0, 1, 2],
        #'num_hidden_nodes': 300,
        }

# Make sure d_model, heads, and d_key are compatible. d_key is the per-head
# projection width; it must be an int because it sizes nn.Linear layers.
if conf['d_model'] % conf['attn_heads']:
    raise ValueError('attn_heads=%s does not evenly divide d_model=%s' % (conf['attn_heads'], conf['d_model']))
conf['d_key'] = conf['d_model'] // conf['attn_heads']

# Set up the RNGs. Compare against None rather than truthiness: the
# configured seed is 0, which is falsy and previously skipped seeding.
if conf['random_seed'] is not None:
    random.seed(conf['random_seed'])
    np.random.seed(conf['random_seed'])
    torch.manual_seed(conf['random_seed'])
    torch.cuda.manual_seed(conf['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Logging: open a Weights & Biases run in the "my-transformer" project and
# attach the full configuration dict to it.
wandb.init(config=conf, project="my-transformer")

## Model Architecture

In [None]:
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head, optionally causal.

    Projects queries/keys/values from d_model down to d_key and returns
    softmax(Q K^T / sqrt(d_key)) V. With mask=True, position i may only
    attend to positions j <= i.
    """

    def __init__(self, mask=True, d_model=conf['d_model'], d_key=conf['d_key'], bptt_len=conf['bptt_len']):
        super().__init__()
        self.d_model = d_model
        self.d_key = d_key
        self.bptt_len = bptt_len

        if mask:
            # Additive causal mask: 0 on/below the diagonal, -inf above it.
            # float('-inf') replaces np.NINF, which NumPy 2.0 removed.
            # Registered as a buffer (it is never trained): it still follows
            # .half()/.to(device) and is saved in the state_dict as 'mask'.
            causal = torch.full([bptt_len, bptt_len], float('-inf')).triu(1)
            self.register_buffer('mask', causal)
        else:
            self.register_buffer('mask', None)

        # head projections
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        # Normalize over the last (key) axis so each query's weights sum to 1.
        # dim=1 was only right for unbatched (L, L) scores; for batched
        # (B, L, L) scores it normalized across queries, so the denominator
        # mixed in scores computed from later positions.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        """Attend `queries` over `keys`/`values`.

        Inputs are (..., seq_len, d_model): (batch, seq_len, d_model) during
        training, (seq_len, d_model) when sampling. Returns
        (..., query_len, d_key). When masked, seq lengths must be <= bptt_len.
        """
        # project queries, keys, values
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)

        # compatibility function, shape (..., query_len, key_len)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / sqrt(self.d_key)

        # Filter out attention to future positions. tril() first zeroes the
        # future entries so adding -inf always yields exactly -inf there.
        if self.mask is not None:
            query_len, key_len = scores.shape[-2:]
            scores = scores.tril() + self.mask[:query_len, :key_len]

        weights = self.softmax(scores)

        # sum the weighted value vectors
        return torch.matmul(weights, values)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run `heads` AttentionHeads of width d_model // heads side by side,
    concatenate their outputs along the feature axis and mix them with a
    bias-free d_model x d_model projection (Wo)."""

    def __init__(self, mask=False, d_model=conf['d_model'], heads=conf['attn_heads'], bptt_len=conf['bptt_len']):
        super().__init__()
        head_width = d_model // heads

        # Heads are built before Wo, matching parameter-initialization order.
        self.attn_heads = nn.ModuleList(
            AttentionHead(mask, d_model, head_width, bptt_len) for _ in range(heads)
        )
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, queries, keys, values):
        outputs = []
        for head in self.attn_heads:
            outputs.append(head(queries=queries, keys=keys, values=values))
        return self.Wo(torch.cat(outputs, dim=-1))

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Linear, with a
    hidden width of int(multiplier * d_model)."""

    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()
        hidden_width = int(multiplier * d_model)
        layers = [
            nn.Linear(d_model, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, d_model),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        return self.ffn(x)

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: unmasked self-attention followed by a feed-forward
    sublayer, each wrapped as LayerNorm(x + Dropout(sublayer(x))).

    `mask_all` is accepted for signature parity with DecoderBlock but unused.
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # self-attention sublayer
        self.self_attn = MultiHeadAttention(False, d_model, heads, bptt_len)
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # feed-forward sublayer
        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        attended = self.self_attn(x, x, x)
        hidden = self.self_attn_norm(x + self.self_attn_dropout(attended))
        transformed = self.ffn(hidden)
        return self.ffn_norm(hidden + self.ffn_dropout(transformed))

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, attention over `encoder_out`,
    then a feed-forward sublayer; each wrapped as
    LayerNorm(x + Dropout(sublayer(x))).

    mask_all: when True, the attention over encoder_out is also causal.
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # causal self-attention over the decoder input
        self.self_attn = MultiHeadAttention(True, d_model, heads, bptt_len)
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # attention from decoder positions to encoder_out
        self.enc_attn = MultiHeadAttention(mask_all, d_model, heads, bptt_len)
        self.enc_attn_dropout = nn.Dropout(p=dropout)
        self.enc_attn_norm = nn.LayerNorm(d_model)

        # feed-forward sublayer
        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out):
        hidden = self.self_attn_norm(x + self.self_attn_dropout(self.self_attn(x, x, x)))
        cross = self.enc_attn(hidden, encoder_out, encoder_out)
        hidden = self.enc_attn_norm(hidden + self.enc_attn_dropout(cross))
        return self.ffn_norm(hidden + self.ffn_dropout(self.ffn(hidden)))

In [None]:
class Encoder(nn.Module):
    """A stack of `num_blocks` EncoderBlocks applied in sequence."""

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 # conf.get: 'num_blocks_encoder' may be missing from conf, and
                 # a plain lookup here raises KeyError as soon as the class is
                 # defined (default arguments are evaluated at definition time).
                 num_blocks=conf.get('num_blocks_encoder', 6),
                 dropout=conf['dropout']):
        super().__init__()

        self.blocks = nn.ModuleList([EncoderBlock(d_model, heads, bptt_len, dropout) for _ in range(num_blocks)])

    def forward(self, x):
        """Feed `x` through every block in order and return the last output."""
        for block in self.blocks:
            x = block(x)
        return x

In [None]:
class Decoder(nn.Module):
    """A stack of `num_blocks` DecoderBlocks; every block attends to the same
    `encoder_out`, and each block's output is the next block's input."""

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = [DecoderBlock(d_model, heads, bptt_len, dropout) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, encoder_out, decoder_in):
        hidden = decoder_in
        for layer in self.blocks:
            hidden = layer(hidden, encoder_out)
        return hidden

In [None]:
class Transformer(nn.Module):
    """Decoder-only next-token model (no encoder stack is built).

    Both inputs are vocab indices. The context sequence (`encoder_out`) is
    embedded and used directly as the memory for the decoder's cross-attention.
    The output is a LayerNorm over per-position vocabulary logits.
    """

    def __init__(self,
                 vocab,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 # conf.get: this key may be missing from conf, and a plain
                 # lookup raises KeyError when the class is defined. The value
                 # is unused because no encoder stack is built.
                 num_blocks_encoder=conf.get('num_blocks_encoder', 6),
                 num_blocks_decoder=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.vocab = vocab
        self.d_model = d_model
        self.bptt_len = bptt_len

        self.embedding = nn.Embedding(len(vocab), d_model, padding_idx=vocab.stoi['<pad>'])
        # Fixed sinusoidal encoding, shape (bptt_len, d_model). A buffer rather
        # than a frozen Parameter: it still follows .half()/.to() and is saved
        # in the state_dict under the same key.
        self.register_buffer('position_encoding', self._position_encoding())
        # NOTE(review): constructed but never applied in forward().
        self.embed_dropout = nn.Dropout(p=dropout)

        self.decoder = Decoder(d_model, heads, bptt_len, num_blocks_decoder, dropout)

        self.linear = nn.Linear(d_model, len(self.vocab))
        self.linear_dropout = nn.Dropout(p=dropout)
        self.linear_norm = nn.LayerNorm(len(vocab))

    def _position_encoding(self):
        """Return the (bptt_len, d_model) sinusoidal position encoding
        ("Attention Is All You Need", section 3.5): even feature i uses
        sin(pos / 10000^(i/d_model)), odd feature i uses
        cos(pos / 10000^((i-1)/d_model))."""
        d_model = self.d_model
        rows = [tensor([sin(pos / (10000 ** (i / d_model)))
                        if i % 2 == 0
                        else cos(pos / (10000 ** ((i - 1) / d_model)))
                        for i in range(d_model)])
                for pos in range(self.bptt_len)]
        encoding = torch.stack(rows, dim=0)
        assert encoding.shape == torch.Size([self.bptt_len, self.d_model])
        return encoding

    def embed(self, indices):
        """Embed vocab `indices` and add the position encoding.

        indices: list or integer tensor of shape (..., seq_len), with
        seq_len <= bptt_len. Returns a tensor of shape (..., seq_len, d_model).
        """
        # as_tensor avoids re-copying tensors (torch.tensor(tensor) copies and
        # warns) and puts lists on the same device as the embedding weights.
        indices = torch.as_tensor(indices, device=self.embedding.weight.device)
        embedded = self.embedding(indices)
        seq_len = embedded.shape[-2]
        return embedded + self.position_encoding[:seq_len]

    def forward(self, encoder_out, decoder_in):
        """Predict next-token scores.

        encoder_out: vocab indices of the context sequence, shape (seq_len,)
                     or (batch, seq_len). They are embedded here and attended
                     to by every decoder block.
        decoder_in:  vocab indices of the target sequence, right-shifted by one
                     position, with the same shape as encoder_out.
        Returns a tensor of shape (..., seq_len, len(vocab)): the vocabulary
        logits after dropout and LayerNorm.
        """
        memory = self.embed(encoder_out)
        target = self.embed(decoder_in)
        decoder_out = self.decoder(memory, target)

        # Return predictions for next token
        logits = self.linear(decoder_out)
        return self.linear_norm(self.linear_dropout(logits))

## Build the Vocab and Model

In [None]:
# Set up Cuda: list every GPU that will be used. The first id in
# conf['cuda_device_ids'] becomes the primary device for model and data.
print("Using", len(conf['cuda_device_ids']), "GPU(s):")
for i in conf['cuda_device_ids']:
    print("    cuda:%s:" % i, torch.cuda.get_device_name(i))

device = torch.device('cuda', conf['cuda_device_ids'][0])

# Make the vocabulary. The whole lower-cased data file (newlines removed) is
# passed as one string, so it is iterated character by character. Index 0 is
# then reserved for '<pad>' and index 1 for '<eos>'.
# NOTE(review): relies on the stoi/itos/freqs attributes of the legacy
# torchtext Vocab; confirm against the pinned torchtext version.
with open(conf['datafile'], 'r') as f:
    datafile_lines = f.readlines()
vocab = torchtext.vocab.build_vocab_from_iterator(''.join(datafile_lines).replace('\n', '').lower())
vocab.stoi['<pad>'] = 0
vocab.stoi['<eos>'] = 1
vocab.itos[0] = '<pad>'
# Was `vocab.stoi[1] = '<eos>'`, which left itos[1] as a different token. The
# inverse-frequency CE weights are built by walking itos, so '<eos>' never got
# its weight (index 1 ended up weighted 0) and decoding index 1 gave that token.
vocab.itos[1] = '<eos>'
# One '<eos>' per line of the data file.
vocab.freqs['<eos>'] = len(datafile_lines)

# define the model: build the Transformer, cast it to fp16, move it to the
# primary GPU and wrap it in DataParallel over conf['cuda_device_ids'];
# the optimizer class is looked up by name from conf['optimizer'].
model = nn.DataParallel(Transformer(vocab).half().to(device),
                        device_ids=conf['cuda_device_ids'])
optimizer = getattr(optim, conf['optimizer'])(model.parameters(),
                                              lr=conf['learning_rate'])

# Class weights for the loss: the inverse of each token's frequency (in itos
# order), with 0 for tokens whose frequency is 0. '<pad>' targets are also
# excluded from the loss through ignore_index.
CE_freqs = [float(vocab.freqs[token]) for token in vocab.itos]
CE_weight = torch.tensor([1 / freq if freq != 0 else 0. for freq in CE_freqs],
                         dtype=torch.half, device=device)
criterion = nn.CrossEntropyLoss(weight=CE_weight, ignore_index=vocab.stoi['<pad>'])

## Helper Functions

In [None]:
# Module-level cache of the whole training set. get_data() fills it on its
# first call (keys 'encoder_out', 'decoder_in', 'y', each turned into one
# concatenated tensor on `device`) and reuses it on every later call.
data = defaultdict(list)

def pad_indices(indices, right_shift=False, bptt_len=conf['bptt_len']):
    """Return a new list of exactly `bptt_len` int indices built from `indices`.

    Steps, in order:
    1. Append '<eos>' unless the sequence already ends with it.
    2. With right_shift=True, prepend '<pad>' (the shifted decoder input).
    3. Truncate to bptt_len, which may drop the trailing '<eos>'.
    4. Right-pad with '<pad>'.

    The input is not modified.
    """
    eos_index = vocab.stoi['<eos>']
    pad_index = vocab.stoi['<pad>']
    seq = [int(ix) for ix in indices]
    if not seq or seq[-1] != eos_index:
        seq = seq + [eos_index]
    if right_shift:
        seq = [pad_index] + seq
    seq = seq[:bptt_len]
    return seq + [pad_index] * (bptt_len - len(seq))
    
def get_indices(string, vocab=vocab, bptt_len=conf['bptt_len']):
    """Strip and lower-case `string`, then map each of its first `bptt_len`
    characters to its vocab index via vocab.stoi."""
    chars = string.strip().lower()[:bptt_len]
    return [vocab.stoi[ch] for ch in chars]

def _make_filter(shape):
    return 

def _get_tensors(string, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer, bptt_len=conf['bptt_len']):
    """Build teacher-forcing examples for one string: one per split point i.

    For each i in range(len(indices)):
      encoder_out[i]: pad_indices(indices[:i])                   (the prefix)
      decoder_in[i]:  pad_indices(indices[i:], right_shift=True) (shifted suffix)
      y[i]:           pad_indices(indices[i:])                   (targets)
    Each entry is a (1, bptt_len) int64 tensor. Returns three lists.

    `vocab` and `bptt_len` are now forwarded to get_indices/pad_indices
    (previously ignored). `model`, `criterion` and `optimizer` are unused.
    """
    indices = get_indices(string, vocab, bptt_len)
    encoder_out, decoder_in, y = [], [], []
    for i in range(len(indices)):
        prefix, suffix = indices[:i], indices[i:]
        encoder_out.append(tensor(pad_indices(prefix, bptt_len=bptt_len)).unsqueeze(0))
        decoder_in.append(tensor(pad_indices(suffix, right_shift=True, bptt_len=bptt_len)).unsqueeze(0))
        y.append(tensor(pad_indices(suffix, bptt_len=bptt_len)).unsqueeze(0))
    return encoder_out, decoder_in, y

def get_data(minibatch_size=500, vocab=vocab, data_file=conf['datafile']):
    """Yield (encoder_out, decoder_in, y) minibatches of up to minibatch_size rows.

    On the first call, every line of `data_file` (stripped, lower-cased) is
    expanded with _get_tensors. The rows are concatenated, moved to `device`
    and cached in the global `data`. Later calls reuse that cache, even if a
    different data_file is passed. `vocab` is unused. Batches are yielded in
    file order.
    """
    global data
    if not data:
        with open(data_file,'r') as f:
            lines = [raw.strip().lower() for raw in f.readlines()]
        keys = ('encoder_out', 'decoder_in', 'y')
        for line in lines:
            for key, rows in zip(keys, _get_tensors(line)):
                data[key].extend(rows)
        for key in keys:
            data[key] = torch.cat(data[key]).to(device)

    total_rows = data['y'].shape[0]
    start = 0
    while start < total_rows:
        stop = start + minibatch_size
        yield (data['encoder_out'][start:stop, :],
               data['decoder_in'][start:stop, :],
               data['y'][start:stop, :])
        start = stop
        
def train_data(encoder_out, decoder_in, y):
    """Run one optimizer step on a minibatch and return its loss as a float.

    Uses the module-level model, optimizer and criterion. The model's
    (batch, seq_len, vocab) output is transposed to (batch, vocab, seq_len),
    the class-axis-second layout nn.CrossEntropyLoss expects for sequences.
    """
    optimizer.zero_grad()
    logits = model(encoder_out=encoder_out, decoder_in=decoder_in).transpose(-2, -1)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    return loss.item()

def do_epoch(epoch, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer):
    """Train on every minibatch from get_data() once.

    Returns (mean minibatch loss, wall-clock seconds). `epoch` and the other
    parameters are accepted but unused; train_data() always operates on the
    module-level model, criterion and optimizer.
    """
    started = time.time()
    batch_losses = []
    for batch in get_data():
        batch_losses.append(train_data(*batch))
    elapsed = time.time() - started
    return (sum(batch_losses) / len(batch_losses), elapsed)

def train(num_epochs=conf['max_epochs'], start_epoch=0, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer):
    """Train for `num_epochs` epochs numbered from `start_epoch`.

    Logs epoch, mean loss and duration to wandb and stdout after each epoch.
    Returns the next epoch number, start_epoch + num_epochs, so runs chain
    as `epoch = train(n, epoch)`.
    """
    # model.eval() is called elsewhere in this notebook (sampling, debugging);
    # switch back so dropout is active while training.
    model.train()
    for epoch in range(start_epoch, start_epoch + num_epochs):
        loss, seconds = do_epoch(epoch, model, vocab, criterion, optimizer)
        wandb.log({'epoch': epoch,
                   'loss': loss,
                   'seconds': seconds})
        print(epoch, '%.2f secs:' % seconds, 'loss: %.4f' % loss)
    # Computed directly: `epoch + 1` raised UnboundLocalError for num_epochs == 0.
    return start_epoch + num_epochs


# The inverse-frequency class weights as a (1, vocab) row, so they broadcast
# across the per-position logits in get_next_token.
induction_weight = CE_weight[None, :]

def get_next_token(encoder_out, decoder_in, pos, model=model, vocab=vocab):
    """Greedily predict the token at output position `pos`.

    encoder_out: context string (e.g. the sampling prompt).
    decoder_in:  text generated so far. It is right-shifted, so output
                 position len(decoder_in) predicts the next character.
    pos:         output position to read.
    Logits are scaled by `induction_weight` before the argmax. Returns the
    chosen token's string from vocab.itos.
    """
    # model.eval() mutates the shared model; remember its mode and restore it
    # so a later training run does not silently run without dropout.
    was_training = model.training
    eval_model = model.eval()
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph
            enc = tensor(pad_indices(get_indices(encoder_out, vocab))).to(device)
            dec = tensor(pad_indices(get_indices(decoder_in, vocab), right_shift=True)).to(device)
            logits = eval_model(encoder_out=enc, decoder_in=dec)  # (seq_len, vocab)
            logits = logits * induction_weight
            _, indices = torch.max(nn.functional.softmax(logits, dim=-1), dim=-1)
    finally:
        model.train(was_training)
    return vocab.itos[int(indices[pos])]

def sample_with_prompt(prompt, model=model, vocab=vocab, bptt_len=conf['bptt_len']):
    """Greedily extend `prompt` one predicted token at a time.

    Stops when the model predicts '<pad>' or '<eos>', or once the generated
    continuation is at least bptt_len - 1 characters long. Returns prompt +
    continuation; the stop token and the last prediction are not appended.
    `model` and `vocab` are now forwarded to get_next_token (previously ignored).
    """
    generated = ''
    next_token = ''
    while next_token not in ('<pad>', '<eos>') and len(generated) < (bptt_len - 1):
        generated += next_token
        next_token = get_next_token(prompt, generated, pos=len(generated), model=model, vocab=vocab)
    return prompt + generated

## Train the Model

In [None]:
# Next epoch number to train. train() returns the updated value, so the
# training cell can be re-run to continue numbering where it stopped.
epoch = 0

In [None]:
# 15 runs of `epochs_per_run` epochs each. After every run, print a greedy
# sample for each lowercase starting letter.
epochs_per_run = 100
for _ in range(15):
    epoch = train(epochs_per_run, epoch)
    for p in map(chr, range(ord('a'), ord('z') + 1)):
        print(repr(sample_with_prompt(p)))

In [None]:
# Debug: greedy per-position predictions for the context 'N' (lower-cased by
# get_indices) with an empty decoder input. The predicted indices are the
# cell's displayed output.
_was_training = model.training
eval_model = model.eval()
try:
    # Inference only: no autograd graph.
    with torch.no_grad():
        encoder_out = tensor(pad_indices(get_indices('N'))).to(device)
        decoder_in = tensor(pad_indices(get_indices(''), right_shift=True)).to(device)
        decoder_out = eval_model(encoder_out=encoder_out, decoder_in=decoder_in)
        _, indices = torch.max(nn.functional.softmax(decoder_out, dim=1), dim=1)
finally:
    # Restore the previous mode so later training keeps dropout active.
    model.train(_was_training)
indices