# My Transformer

In [None]:
from collections import defaultdict
from math import sqrt, sin, cos
import os
import random
import sys
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2

 ## Config, Logging, RNG

In [None]:
# Hyper-parameters and run-time settings read by the cells below.
conf = {
        'attn_heads': 2,
        'bptt_len': 50,
        #'cuda_device_ids': [3, 2, 1, 0],
        'cuda_device_ids': [3],
        'd_model': 4,
        'device': 'cuda',
        'datafile': './city_names.txt',
        'dropout': 0.1,
        'learning_rate': 0.1,
        'max_epochs': 1000,
        'num_blocks_encoder': 6,
        'num_blocks_decoder': 2,
        'optimizer': 'Adam',
        'random_seed': 0,

        #'batch_size': 400,
        #'dataset': 'imagenette2-320',
        #'init_gain': 5,
        #'initializer': None,
        #'load_workers': os.cpu_count(),
        #'training_loops': 4,
        #'cuda_device_ids': [0, 1, 2],
        #'num_hidden_nodes': 300,
        }

# Pulling list of cities from: https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068

# Width of each attention head.  Check divisibility first, then use floor
# division so that d_key is an int: it is later used as a layer dimension and
# inside torch.Size comparisons, neither of which accepts a float like 2.0.
assert conf['d_model'] % conf['attn_heads'] == 0, 'attn_heads=%s does not evenly divide d_model=%s' % (conf['attn_heads'], conf['d_model'])
conf['d_key'] = conf['d_model'] // conf['attn_heads']

In [None]:
# Seed every random number generator this notebook draws from so that a run
# can be reproduced.  Python's own `random` module is included because the
# data-loading code shuffles lines with `random.shuffle`.
# NOTE(review): the truthiness test means a seed of 0 disables seeding
# altogether -- confirm that 0 is meant as "off" rather than as a real seed.
if conf['random_seed']:
    random.seed(conf['random_seed'])
    np.random.seed(conf['random_seed'])
    torch.manual_seed(conf['random_seed'])
    torch.cuda.manual_seed(conf['random_seed'])
    # Trade cuDNN autotuning for repeatable kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
#wandb.init(project="my-transformer", config=conf)


## Define the Transformer Architecture

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product attention head over a single, unbatched sequence.

    Inputs to ``forward`` are projected from ``d_model`` to ``d_key`` columns.
    The shape assertions in ``forward`` require the score matrix to be exactly
    ``(bptt_len, bptt_len)``, so query and key sequences must both have
    ``bptt_len`` rows.

    Args:
        mask: when truthy, a causal mask is applied so that position ``i``
            cannot attend to positions after ``i``.
        d_model: width of the incoming vectors.
        d_key: width of the projected queries/keys/values.
        bptt_len: fixed sequence length.
    """

    def __init__(self, mask=True, d_model=conf['d_model'], d_key=conf['d_key'], bptt_len=conf['bptt_len']):
        super().__init__()
        self.d_model = d_model
        # d_key may arrive as a float (e.g. 2.0 when it was derived by true
        # division); nn.Linear and torch.Size both need a real int.
        self.d_key = int(d_key)
        self.bptt_len = bptt_len

        if mask:
            # -inf strictly above the diagonal, 0 elsewhere.  Stored as a frozen
            # parameter so it follows the module across .to()/.half().
            # float('-inf') replaces np.NINF, which NumPy 2.0 removed.
            causal = torch.full([bptt_len, bptt_len], float('-inf')).triu(1)
            self.mask = nn.Parameter(causal, requires_grad=False)
        else:
            self.mask = None

        # head projections
        self.Wq = nn.Linear(d_model, self.d_key, bias=False)
        self.Wk = nn.Linear(d_model, self.d_key, bias=False)
        self.Wv = nn.Linear(d_model, self.d_key, bias=False)

        # Normalise each query row over the key positions.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, queries, keys, values):
        """Return the attention output, shape ``(bptt_len, d_key)``."""
        # project queries, keys, values
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)

        # calculate compatibility function
        scores = torch.matmul(queries, keys.T) / sqrt(self.d_key)  # shape = (bptt_len, bptt_len)
        assert scores.shape == torch.Size([self.bptt_len, self.bptt_len])

        # Filter out attention to future positions: zero the upper triangle,
        # then add -inf there so softmax assigns those positions zero weight.
        if self.mask is not None:
            scores = scores.tril() + self.mask

        # softmax
        scores = self.softmax(scores)

        # sum the weighted value vectors
        attn = torch.matmul(scores, values)  # shape = (bptt_len, d_key)
        assert attn.shape == torch.Size([self.bptt_len, self.d_key])

        return attn

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several AttentionHead modules side by side and mix their outputs.

    Each head projects to ``int(d_model / heads)`` columns; the head outputs
    are concatenated along dim 1 and passed through a ``d_model -> d_model``
    linear layer.

    Args:
        mask: forwarded to every AttentionHead as its ``mask`` flag.
        d_model: model width.
        heads: number of attention heads.
        bptt_len: fixed sequence length forwarded to every head.
    """

    def __init__(self, mask=False, d_model=conf['d_model'], heads=conf['attn_heads'], bptt_len=conf['bptt_len']):
        super().__init__()
        per_head_width = int(d_model / heads)

        # Output projection applied after the heads are concatenated.
        self.Wo = nn.Linear(d_model, d_model)

        head_modules = []
        for _ in range(heads):
            head_modules.append(AttentionHead(mask, d_model, per_head_width, bptt_len))
        self.attn_heads = nn.ModuleList(head_modules)

    def forward(self, queries, keys, values):
        """Apply every head to the same inputs and project the concatenation."""
        per_head = []
        for head in self.attn_heads:
            per_head.append(head(queries, keys, values))
        combined = torch.cat(per_head, dim=1)
        return self.Wo(combined)

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: input and output width.
        multiplier: the hidden width is ``int(multiplier * d_model)``.
    """

    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()

        hidden_width = int(multiplier * d_model)

        # Built in the same order as they appear in the Sequential.
        expand = nn.Linear(d_model, hidden_width)
        contract = nn.Linear(hidden_width, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the two-layer network to ``x``."""
        out = self.ffn(x)
        return out

In [None]:
class EncoderBlock(nn.Module):
    """Encoder layer: unmasked self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.

    Args:
        d_model: model width.
        heads: number of attention heads.
        bptt_len: fixed sequence length.
        dropout: dropout probability used after each sub-layer.
        mask_all: accepted but not used by this block.
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # Self-attention sub-layer (no causal mask).
        self.self_attn = MultiHeadAttention(False, d_model, heads, bptt_len)
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the block output for input ``x``."""
        attended = self.self_attn(x, x, x)
        attended = self.self_attn_dropout(attended)
        after_attn = self.self_attn_norm(x + attended)

        transformed = self.ffn(after_attn)
        transformed = self.ffn_dropout(transformed)
        return self.ffn_norm(after_attn + transformed)

In [None]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, attention over ``encoded``, FFN.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.

    Args:
        d_model: model width.
        heads: number of attention heads.
        bptt_len: fixed sequence length.
        dropout: dropout probability used after each sub-layer.
        mask_all: passed as the ``mask`` flag of the second attention
            sub-layer (the one that attends over ``encoded``).
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()

        # Causally masked self-attention sub-layer.
        self.self_attn = MultiHeadAttention(True, d_model, heads, bptt_len)
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)

        # Attention whose keys and values come from `encoded`.
        self.enc_attn = MultiHeadAttention(mask_all, d_model, heads, bptt_len)
        self.enc_attn_dropout = nn.Dropout(p=dropout)
        self.enc_attn_norm = nn.LayerNorm(d_model)

        # Feed-forward sub-layer.
        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoded):
        """Return the block output for input ``x`` and memory ``encoded``."""
        self_attended = self.self_attn_dropout(self.self_attn(x, x, x))
        after_self = self.self_attn_norm(x + self_attended)

        cross_attended = self.enc_attn_dropout(self.enc_attn(after_self, encoded, encoded))
        after_cross = self.enc_attn_norm(after_self + cross_attended)

        transformed = self.ffn_dropout(self.ffn(after_cross))
        return self.ffn_norm(after_cross + transformed)

In [None]:
class Encoder(nn.Module):
    """A stack of ``num_blocks`` EncoderBlock layers applied in sequence.

    Args:
        d_model: model width.
        heads: number of attention heads per block.
        bptt_len: fixed sequence length.
        num_blocks: how many EncoderBlock layers to stack.
        dropout: dropout probability handed to every block.
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_encoder'],
                 dropout=conf['dropout']):
        super().__init__()

        layers = []
        for _ in range(num_blocks):
            layers.append(EncoderBlock(d_model, heads, bptt_len, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every block; each block consumes the previous output."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden

In [None]:
class Decoder(nn.Module):
    """A stack of ``num_blocks`` DecoderBlock layers applied in sequence.

    Args:
        d_model: model width.
        heads: number of attention heads per block.
        bptt_len: fixed sequence length.
        num_blocks: how many DecoderBlock layers to stack.
        dropout: dropout probability handed to every block.
    """

    def __init__(self,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.blocks = nn.ModuleList([DecoderBlock(d_model, heads, bptt_len, dropout) for _ in range(num_blocks)])

    def forward(self, x, encoded=None):
        """Run ``x`` through every block.

        Args:
            x: decoder input.
            encoded: sequence every block attends over in its second
                attention sub-layer.  When omitted, ``x`` itself is used,
                which is what a single-argument call has always done.

        Returns:
            The output of the last block.
        """
        if encoded is None:
            encoded = x
        a = x
        for block in self.blocks:
            a = block(a, encoded)
        return a

In [None]:
class Transformer(nn.Module):
    """Decoder-only transformer over a fixed-length, unbatched index sequence.

    Token indices are embedded, a fixed sinusoidal position table is added,
    the result is passed through the Decoder stack, and the output layer
    (which shares its weight matrix with the embedding) produces one score per
    vocabulary entry for each position.  The encoder is currently disabled, so
    ``num_blocks_encoder`` is accepted but unused.

    Args:
        vocab: vocabulary object; only ``len(vocab)`` is used here.  Index 1
            is treated as the padding index by the embedding.
        d_model: model width.
        heads: number of attention heads per block.
        bptt_len: fixed sequence length.
        num_blocks_encoder: unused while the encoder is commented out.
        num_blocks_decoder: number of DecoderBlock layers.
        dropout: dropout probability.
    """

    def __init__(self,
                 vocab,
                 d_model=conf['d_model'],
                 heads=conf['attn_heads'],
                 bptt_len=conf['bptt_len'],
                 num_blocks_encoder=conf['num_blocks_encoder'],
                 num_blocks_decoder=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.vocab = vocab
        self.d_model = d_model
        self.bptt_len = bptt_len

        self.embedding = torch.nn.Embedding(len(vocab), d_model, padding_idx=1)
        # Frozen parameter so the table follows the module across .to()/.half().
        self.position_encoding = nn.Parameter(self._position_encoding(), requires_grad=False)
        self.initial_dropout = nn.Dropout(p=dropout)

        #self.encoder = Encoder(d_model, heads, bptt_len, num_blocks_encoder, dropout)
        self.decoder = Decoder(d_model, heads, bptt_len, num_blocks_decoder, dropout)

        self.linear = nn.Linear(d_model, len(self.vocab), bias=False)
        self.linear.weight = self.embedding.weight  # Section 3.4
        self.final_dropout = nn.Dropout(p=dropout)

    def _position_encoding(self):
        """Build the ``(bptt_len, d_model)`` sinusoidal position table.

        Even columns hold ``sin(pos / 10000**(i/d_model))``; each odd column
        holds the cosine at the frequency of the even column before it.
        """
        d_model = self.d_model
        rows = [[sin(pos / (10000 ** (i / d_model)))
                 if i % 2 == 0
                 else cos(pos / (10000 ** ((i - 1) / d_model)))
                 for i in range(d_model)]
                for pos in range(self.bptt_len)]
        # Build the table directly in its final orientation.
        table = tensor(rows).reshape(self.bptt_len, d_model)
        assert table.shape == torch.Size([self.bptt_len, self.d_model])

        return table

    def embed(self, indices):
        """Embed token ``indices`` and add the position table.

        ``indices`` may be a tensor or a plain sequence of ints and must hold
        exactly ``bptt_len`` entries.
        """
        # as_tensor reuses an existing tensor instead of copy-constructing it
        # (torch.tensor(tensor) copies and warns), and places list input on
        # the same device as the embedding weights.
        indices = torch.as_tensor(indices, dtype=torch.long, device=self.embedding.weight.device)
        embedded = self.embedding(indices)
        assert embedded.shape == torch.Size([self.bptt_len, self.d_model])
        return embedded + self.position_encoding

    def forward(self, indices):
        """Return per-position vocabulary scores, shape ``(bptt_len, len(vocab))``."""
        # Embed
        embedded = self.initial_dropout(self.embed(indices))

        # Encode
        #encoded= self.encoder(embedded)

        # Decoder
        #decoded = self.decoder(encoded)
        decoded = self.decoder(embedded)

        # Dis-embed
        # NOTE(review): dropout is applied to the output scores themselves, so
        # random vocabulary scores are zeroed while training -- confirm intent.
        predictions = self.final_dropout(self.linear(decoded))

        return predictions

## Helper Functions

In [None]:
def get_data(vocab, data_file=conf['datafile']):
    """Yield each line of ``data_file``, stripped and lower-cased.

    ``vocab`` is accepted but not used.
    """
    with open(data_file,'r') as f:
        for raw_line in f:
            cleaned = raw_line.strip()
            yield cleaned.lower()

def make_vocab(data_file=conf['datafile']):
    """Build a character-level torchtext vocabulary from ``data_file``.

    The file is read whole, newlines are removed and the text is lower-cased;
    the resulting string is handed to ``build_vocab_from_iterator``, so every
    character is one item of the iterator.
    """
    with open(data_file,'r') as f:
        text = f.read()
    characters = text.replace('\n','').lower()
    return torchtext.vocab.build_vocab_from_iterator(characters)

def get_indices(string, vocab, max_tokens=conf['bptt_len'], include_shifted=False):
    """Convert ``string`` to a fixed-length tensor of character indices.

    The string is stripped, lower-cased, split into characters, truncated to
    ``max_tokens`` and right-padded with ``'<pad>'``; each token is looked up
    in ``vocab.stoi``.

    Returns:
        The index tensor, or -- when ``include_shifted`` is true -- a pair
        ``(indices, shifted)`` where ``shifted`` is the same sequence moved one
        step to the right with index 0 in front and the last entry dropped.
    """
    chars = list(string.strip().lower())[:max_tokens]
    chars += ['<pad>'] * (max_tokens - len(chars))
    indices = [vocab.stoi[ch] for ch in chars]
    t = tensor(indices)
    if not include_shifted:
        return t
    return t, tensor([0] + indices[:-1])

In [None]:
def train_string(model, vocab, string, criterion, optimizer):
    """Run one optimisation step on a single string and return its loss.

    The model is fed the right-shifted index sequence and trained to predict
    the unshifted one, so the output at position ``i`` is scored against token
    ``i`` while its input is the token before it.

    Args:
        model: module mapping an index tensor to per-position scores.
        vocab: vocabulary passed through to ``get_indices``.
        string: training example.
        criterion: loss called as ``criterion(predictions, targets)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        The loss for this step as a Python float.
    """
    # get_indices returns (unshifted, shifted).  The shifted copy is the
    # input; using the unshifted one as input would make the target at each
    # position a token the model has already been shown.
    targets, shifted = get_indices(string, vocab, include_shifted=True)
    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    x = shifted.to(model_device)
    y = targets.to(model_device)
    optimizer.zero_grad()
    y_pred = model(x)
    print('y_pred:', y_pred)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
def do_epoch(model, vocab, criterion, optimizer, epoch='    '):
    """Train on every string from ``get_data`` once and return the mean loss.

    Args:
        model, vocab, criterion, optimizer: forwarded to ``train_string``.
        epoch: label printed in front of each per-string log line.

    Returns:
        The average of the per-string losses, or ``nan`` when the data source
        yields no strings.
    """
    losses = []
    for string in get_data(vocab=vocab):
        # train_string returns the loss value itself; float() accepts both a
        # plain number and a 0-d tensor, so no .item() call is needed here.
        loss = float(train_string(model, vocab, string, criterion, optimizer))
        losses.append(loss)
        #sys.stdout.write('.')
        print(epoch, 'string:', string, 'loss:', loss)
    if not losses:
        # Nothing to average; avoid a ZeroDivisionError on an empty data file.
        return float('nan')
    return sum(losses) / len(losses)

In [None]:
def train(model, vocab, criterion, optimizer, max_epochs=conf['max_epochs']):
    """Run ``do_epoch`` ``max_epochs`` times, printing the mean loss of each.

    Args:
        model, vocab, criterion, optimizer: forwarded to ``do_epoch``.
        max_epochs: number of passes over the data.
    """
    for epoch in range(max_epochs):
        # Pass the epoch number through so the per-string log lines show it
        # instead of do_epoch's blank placeholder.
        loss = do_epoch(model, vocab, criterion, optimizer, epoch=epoch)
        print('epoch:', epoch, 'loss:', loss)

## Build and Train the Model

In [None]:
# Report the GPUs selected in the config.
device_ids = conf['cuda_device_ids']
print("Using", len(device_ids), "GPU(s):")
for gpu_id in device_ids:
    print("    cuda:%s:" % gpu_id, torch.cuda.get_device_name(gpu_id))

# The first configured GPU hosts the model and the training tensors.
device = torch.device('cuda:' + str(device_ids[0]))

# Character vocabulary built from the data file.
vocab = make_vocab()

# Build the model in half precision, move it to the primary GPU, then wrap it
# for data-parallel execution over the configured GPUs.
# NOTE(review): the optimizer below is created on fp16 parameters with default
# settings -- confirm its epsilon is representable in half precision.
model = Transformer(vocab).half().to(device)
model = nn.DataParallel(model, device_ids=device_ids)

criterion = nn.CrossEntropyLoss()
# The optimizer class is looked up by name in torch.optim.
optimizer_cls = getattr(torch.optim, conf['optimizer'])
optimizer = optimizer_cls(model.parameters(), lr=conf['learning_rate'])

In [None]:
# Show the model parameters before the single training step in the next cell.
list(model.parameters())

In [None]:
# Smoke test: one optimisation step on a single city name; the cell's value is
# the loss returned by train_string.
train_string(model, vocab, 'Atlanta', criterion, optimizer)

In [None]:
# Show the parameters again, after the single training step above.
list(model.parameters())

In [None]:
# One full pass over the data file; the cell's value is do_epoch's return value.
do_epoch(model, vocab, criterion, optimizer)

In [None]:
# Full training run for the default number of epochs (conf['max_epochs']).
train(model, vocab, criterion, optimizer)

In [None]:
# NOTE(review): `self` and `decoded` are not defined at module scope, so this
# cell raises NameError as written.  It looks like a scratch line for picking
# the highest-scoring index per row -- move it into a method or remove it.
_, out_indices = torch.max(self.softmax(decoded), dim=1)




def next_predicted_token(model, vocab, string, beam_width=1):
    """Return the vocabulary token with the highest score for ``string``.

    The model is called with ``string`` in eval mode and without gradient
    tracking; its previous train/eval mode is restored afterwards so that
    sampling in the middle of training does not leave dropout switched off.

    Args:
        model: module called as ``model(string)``; its output, once squeezed,
            must be a 1-D vector holding one score per vocabulary entry.
        vocab: object whose ``itos`` sequence maps an index to a token.
        string: input handed to the model unchanged.
        beam_width: accepted but not used.

    Returns:
        ``vocab.itos[k]`` where ``k`` is the index of the first maximal score.

    Raises:
        ValueError: if the squeezed model output is not one-dimensional.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            scores = model(string).squeeze()
    finally:
        model.train(was_training)

    if scores.dim() != 1:
        raise ValueError('expected a 1-D score vector from the model, got shape %s' % (tuple(scores.shape),))

    # A single argmax replaces a Python-level max()/index() over tensor
    # elements; like that pair, it picks the first index on ties.
    index = int(torch.argmax(scores))
    return vocab.itos[index]

In [None]:
def sample_with_prompt(model, vocab, prompt, beam_width=1):
    """Extend ``prompt`` one predicted token at a time and return the text.

    Generation stops once the model predicts ``'<EOS>'`` or the token list has
    grown past ``conf['bptt_len']`` entries.  ``vocab`` must provide
    ``tokenize`` and ``detokenize``; ``beam_width`` is accepted but not used.
    """
    tokens = vocab.tokenize(prompt)
    limit = conf['bptt_len']

    latest = None
    while not (latest == '<EOS>' or len(tokens) > limit):
        # detokenize is handed the working list itself on every pass.
        context = vocab.detokenize(tokens)
        latest = next_predicted_token(model, vocab, context)
        tokens.append(latest)

    return vocab.detokenize(tokens)

In [None]:
class Vocab(nn.Module):
    """Token vocabulary built from a text file, with a random vector per token.

    On construction every line of ``data_file`` is tokenised and each new
    token receives an index and a random half-precision vector of length
    ``d_model``.  ``'<EOS>'`` is always registered first, so it has index 0.

    Args:
        data_file: path of the text file, one string per line.
        d_model: length of each token vector.
        split_field: separator used to split a string into tokens; the empty
            string means "split into characters".
        bptt_len: fixed sequence length; tokenised strings are truncated so
            that they fit including the trailing ``'<EOS>'``.
        device: device on which the token vectors are created.
    """

    def __init__(self,
                 data_file=conf['datafile'],
                 d_model=conf['d_model'],
                 split_field='',
                 bptt_len=conf['bptt_len'],
                 device=device):
        super().__init__()

        self.data_file = data_file
        self.split_field = split_field
        self.d_model = d_model
        self.bptt_len = bptt_len
        # Keep the requested device so token vectors are created there rather
        # than on whatever a module-level `device` happens to be.
        self.device = device

        self.itos = []                 # index -> token
        self.stoi = {}                 # token -> index
        self.stoe = {}                 # token -> vector
        self.freq = defaultdict(int)   # token -> number of registrations
        self._register_token('<EOS>')

        for line in self.load_strings():
            for token in self.tokenize(line):
                self._register_token(token)

    def __len__(self):
        """Number of distinct tokens."""
        return len(self.itos)

    def _register_token(self, token):
        """Count ``token`` and, on first sight, give it an index and a vector."""
        if token not in self.stoi:
            self.itos.append(token)
            self.stoi[token] = len(self) - 1
            self.stoe[token] = torch.randn(self.d_model, device=self.device, dtype=torch.half, requires_grad=True)
        self.freq[token] += 1

    def tokenize(self, string):
        """Split ``string`` into lower-cased tokens ending in ``'<EOS>'``.

        At most ``bptt_len - 1`` tokens are kept before the ``'<EOS>'``.
        """
        if self.split_field == '':
            ret = list(string)
        else:
            ret = string.split(self.split_field)
        tokens = list(map(str.lower, ret))
        tokens = tokens[:self.bptt_len - 1] + ['<EOS>']
        return tokens

    def detokenize(self, tokens):
        """Join ``tokens`` into a string, dropping one trailing ``'<EOS>'``.

        The trailing ``'<EOS>'`` is removed from the caller's list in place;
        callers that keep extending the same list rely on that.  An empty
        list yields the empty string.
        """
        if tokens and tokens[-1] == '<EOS>':
            tokens.pop()
        return ''.join(tokens)

    def load_strings(self, shuffle=True):
        """Return the stripped lines of the data file, shuffled by default."""
        with open(self.data_file, 'r') as f:
            lines = [line.strip() for line in f]
        if shuffle:
            random.shuffle(lines)
        return lines

    def embed(self, string):
        """Return a ``(bptt_len, d_model)`` tensor of token vectors for ``string``.

        Sequences shorter than ``bptt_len`` are padded with ``'<EOS>'``.
        """
        tokens = self.tokenize(string)
        if len(tokens) < self.bptt_len:
            tokens.extend(['<EOS>'] * (self.bptt_len - len(tokens)))
        vectors = [self.stoe[token] for token in tokens]
        return torch.stack(vectors, 0)

    def forward(self, string):
        """Alias for :meth:`embed`."""
        return self.embed(string)

In [None]:
# Generate a name starting from the prompt 'N'.
# NOTE(review): sample_with_prompt calls vocab.tokenize / vocab.detokenize,
# which the Vocab class above defines; unless `vocab` has been rebound to a
# Vocab() instance, it is still the object returned by make_vocab() -- verify.
sample_with_prompt(model, vocab, 'N')