# My Transformer

In [1]:
from collections import defaultdict
from math import sqrt, sin, cos
import os
import random
import sys
import time
import numpy as np

import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import wandb
from torchtext.data import RawField, ReversibleField, LabelField
from torchtext.datasets import WikiText2

## Setup
 

In [2]:
conf = {
        'attn_heads': 4,
        'bptt_len': 25,
        #'cuda_device_ids': [3, 2, 1, 0],
        'cuda_device_ids': [3],
        'd_model': 20,
        #'datafile': './city_names.txt', # from: https://www.britannica.com/topic/list-of-cities-and-towns-in-the-United-States-2023068
        'datafile': './corncob_lowercase.txt',  # from: http://www.mieliestronk.com/corncob_lowercase.txt
        'dropout': 0.1,
        'learning_rate': 0.1,
        'epochs_per_loop': 50,
        'total_training_loops': 30,
        #'max_epochs': 100,
        'num_blocks_encoder': 1,
        'num_blocks_decoder': 2,
        'minibatch_size': 50000,
        #'optimizer': 'Adam'
        'optimizer': 'SGD',
        'random_seed': 0,
           
        #'batch_size': 400,
        #'dataset': 'imagenette2-320',
        #'init_gain': 5,
        #'initializer': None,
        #'load_workers': os.cpu_count(), 
        #'training_loops': 4,
        #'cuda_device_ids': [0, 1, 2],
        #'num_hidden_nodes': 300,
        }

# Make sure d_model, heads, and d_key are compatible
conf['d_key'] = conf['d_model'] / conf['attn_heads']
assert conf['d_key'] == int(conf['d_key']), 'attn_heads=%s does not evenly divide d_model=%s' % (conf['attn_heads'], conf['d_model'])

# Set up the RNGs
if conf['random_seed']:
    torch.manual_seed(conf['random_seed'])
    torch.cuda.manual_seed(conf['random_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(conf['random_seed'])

# Logging
wandb.init(project="my-transformer", config=conf)

wandb: Wandb version 0.8.32 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


W&B Run: https://app.wandb.ai/aletheap/my-transformer/runs/cvyncisv

## Model Architecture

In [3]:
class AttentionHead(nn.Module):
    def __init__(self, mask=True, d_model=conf['d_model'], d_key=conf['d_key'], bptt_len=conf['bptt_len']):
        super().__init__()
        self.d_model = d_model
        self.d_key = d_key
        self.bptt_len = bptt_len

        if mask:
            self.mask = nn.Parameter((np.NINF * torch.ones([bptt_len, bptt_len])).triu(1), requires_grad=False)
        else:
            self.mask = None
        
        # head projections
        self.Wq = nn.Linear(d_model, d_key, bias=False)
        self.Wk = nn.Linear(d_model, d_key, bias=False)
        self.Wv = nn.Linear(d_model, d_key, bias=False)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, queries, keys, values):
        # project queries, keys, values
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)

        # calculate compatibility function
        scores = torch.matmul(queries, torch.transpose(keys, -2, -1)) / sqrt(self.d_key) #shape = (heads, bptt_len, bptt_len)
        #assert scores.shape == torch.Size([self.bptt_len, self.bptt_len])

        # Filter out attention to future positions
        if self.mask is not None:
            scores = scores.tril() + self.mask

        # softmax
        scores = self.softmax(scores)
        
        # sum the weighted value vectors
        attn = torch.matmul(scores, values)  # shape = (bptt_len, d_key)
        #assert attn.shape == torch.Size([self.bptt_len, self.d_key])

        return attn

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, mask=False, d_model=conf['d_model'], heads=conf['attn_heads'], bptt_len=conf['bptt_len']):
        super().__init__()
        d_key = int(d_model / heads)

        self.attn_heads = nn.ModuleList([AttentionHead(mask, d_model, d_key, bptt_len) for _ in range(heads)])
        self.Wo = nn.Linear(d_model, d_model, bias=False)
                    
    def forward(self, queries, keys, values):
        head_attns = [h(queries=queries, keys=keys, values=values) for h in self.attn_heads]
        head_attn = torch.cat(head_attns, dim=-1)
        return self.Wo(head_attn)

In [5]:
class FFN(nn.Module):
    def __init__(self, d_model=conf['d_model'], multiplier=4):
        super().__init__()
        
        d_ff = int(multiplier * d_model)

        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), 
                                 nn.ReLU(), 
                                 nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.ffn(x)

In [6]:
class EncoderBlock(nn.Module):
    def __init__(self, 
                 d_model=conf['d_model'], 
                 heads=conf['attn_heads'], 
                 bptt_len=conf['bptt_len'], 
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()
        
        self.self_attn = MultiHeadAttention(False, d_model, heads, bptt_len)
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)
        
        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        a1 = self.self_attn_norm(x + self.self_attn_dropout(self.self_attn(x, x, x)))
        a2 = self.ffn_norm(a1 + self.ffn_dropout(self.ffn(a1)))

        return a2

In [7]:
class DecoderBlock(nn.Module):
    def __init__(self, 
                 d_model=conf['d_model'], 
                 heads=conf['attn_heads'], 
                 bptt_len=conf['bptt_len'], 
                 dropout=conf['dropout'],
                 mask_all=False):
        super().__init__()
                
        self.self_attn = MultiHeadAttention(True, d_model, heads, bptt_len)
        
        self.self_attn_dropout = nn.Dropout(p=dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)
        
        self.enc_attn = MultiHeadAttention(mask_all, d_model, heads, bptt_len)
        self.enc_attn_dropout = nn.Dropout(p=dropout)
        self.enc_attn_norm = nn.LayerNorm(d_model)

        self.ffn = FFN(d_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_out):
        a1 = self.self_attn_norm(x + self.self_attn_dropout(self.self_attn(x, x, x)))
        a2 = self.enc_attn_norm(a1 + self.enc_attn_dropout(self.enc_attn(a1, encoder_out, encoder_out)))
        a3 = self.ffn_norm(a2 + self.ffn_dropout(self.ffn(a2)))
        return a3

In [8]:
class Encoder(nn.Module):
    def __init__(self, 
                 d_model=conf['d_model'], 
                 heads=conf['attn_heads'], 
                 bptt_len=conf['bptt_len'], 
                 num_blocks=conf['num_blocks_encoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.blocks = nn.ModuleList([EncoderBlock(d_model, heads, bptt_len, dropout) for _ in range(num_blocks)])
            
    def forward(self, x):
        a = x
        for block in self.blocks:
            a = block(a)
        return a

In [9]:
class Decoder(nn.Module):
    def __init__(self, 
                 d_model=conf['d_model'], 
                 heads=conf['attn_heads'], 
                 bptt_len=conf['bptt_len'], 
                 num_blocks=conf['num_blocks_decoder'],
                 dropout=conf['dropout']):
        super().__init__()

        self.blocks = nn.ModuleList([DecoderBlock(d_model, heads, bptt_len, dropout) for _ in range(num_blocks)])
            
    def forward(self, encoder_out, decoder_in):
        a = decoder_in
        for block in self.blocks:
            a = block(a, encoder_out)
        return a

In [10]:
class Transformer(nn.Module):
    def __init__(self, 
                 vocab, 
                 d_model=conf['d_model'], 
                 heads=conf['attn_heads'], 
                 bptt_len=conf['bptt_len'],
                 num_blocks_encoder=conf['num_blocks_encoder'],
                 num_blocks_decoder=conf['num_blocks_decoder'], 
                 dropout=conf['dropout']):
        super().__init__()

        self.vocab = vocab
        self.d_model = d_model
        self.bptt_len = bptt_len
        
        self.embedding = nn.Embedding(len(vocab), d_model, padding_idx=vocab.stoi['<pad>'])
        self.position_encoding = nn.Parameter(self._position_encoding(), requires_grad=False)
        self.embed_dropout = nn.Dropout(p=dropout)
                                            
        #self.encoder = Encoder(d_model, heads, bptt_len, num_blocks_encoder, dropout)
        self.decoder = Decoder(d_model, heads, bptt_len, num_blocks_decoder, dropout)

        self.linear = nn.Linear(d_model, len(self.vocab))
        #self.linear.weight = self.embedding.weight  # Section 3.4
        self.linear_softmax = nn.Softmax(dim=-1)
        #self.linear_dropout = nn.Dropout(p=dropout)
        #self.linear_norm = nn.LayerNorm(len(vocab))
           
    def _position_encoding(self):
        d_model = self.d_model
        rows = [tensor([sin(pos/(10000**(i/d_model))) 
                        if i % 2 == 0 
                        else 
                        cos(pos/(10000**((i-1)/d_model))) 
                        for i in range(d_model)])
                for pos in range(self.bptt_len)]
        stack = torch.stack(rows, dim=1)
        assert stack.shape == torch.Size([self.d_model, self.bptt_len])
        
        return stack.T
    
    def embed(self, indices):
        embedded = self.embedding(tensor(indices))
        #assert embedded.shape == torch.Size([self.bptt_len, self.d_model])
        return embedded + self.position_encoding
        
    #def forward(self, encoder_in=None, encoder_out=None, decoder_in=[]):
    def forward(self, encoder_out, decoder_in):
        """parameters:
        encoder_in:  (rank-1 tensor) vocab indices of encoder input token 
                     sequence
        encoder_out: (optional rank-1 tensor) passing this skips 
                     the encoder execution, and acts and if this were 
                     the indices the encoder produced.
        decoder_in:  (optional rank-1 tensor) vocab indices of prior 
                     decoder output for auto-regression. Right 
                     shifted by one position."""
        
        # Embed
        #embedded = self.embed_dropout(self.embed(encoder_in))
            
        # Encode
        #encoder_out = self.embed_dropout(self.embed(embedded))
        
        # Decode
        encoder_out = self.embed(encoder_out)
        #print('encoder_out = ', encoder_out)
        decoder_in = self.embed(decoder_in)
        #print('decoder_in = ', decoder_in)
        decoder_out = self.decoder(encoder_out, decoder_in)
        #print('decoder_out = ', decoder_out)
        #print('decoded:', decoded)

        # Return predictions for next token
        y_pred = self.linear_softmax(self.linear(decoder_out))
        return y_pred
        #return self.linear_norm(self.linear_dropout(y_pred))
        #y_pred = torch.matmul(decoded, self.embedding.weight.T)

## Build the Vocab and Model

In [11]:
# Set up Cuda
print("Using", len(conf['cuda_device_ids']), "GPU(s):")
for i in conf['cuda_device_ids']:
    print("    cuda:%s:" % i, torch.cuda.get_device_name(i))

device = torch.device('cuda:' + str(conf['cuda_device_ids'][0]))

# Make the vocabulary
with open(conf['datafile'], 'r') as f:
    vocab = torchtext.vocab.build_vocab_from_iterator(f.read().replace('\n','').lower())
    vocab.stoi['<pad>'] = 0
    vocab.stoi['<eos>'] = 1
    vocab.itos[0] = '<pad>'
    vocab.stoi[1] = '<eos>'
with open(conf['datafile'], 'r') as f:
    vocab.freqs['<eos>'] = len(f.readlines())

# define the model
model = Transformer(vocab)
model = model.half().to(device)
model = nn.DataParallel(model, device_ids=conf['cuda_device_ids'])
optimizer = getattr(torch.optim, conf['optimizer'])(model.parameters(), lr=conf['learning_rate'])

CE_freqs = [float(vocab.freqs[t]) for t in vocab.itos]
CE_weight = [(0. if f == 0 else 1/f) for f in CE_freqs]
CE_weight = torch.tensor(CE_weight, dtype=torch.half, device=device)
##criterion = nn.CrossEntropyLoss(weight=CE_weight, ignore_index=vocab.stoi['<pad>'])
criterion = nn.CrossEntropyLoss(weight=CE_weight)
#criterion = nn.CrossEntropyLoss()

53488lines [00:00, 534872.86lines/s]

Using 1 GPU(s):
    cuda:3: GeForce RTX 2080 Ti


484586lines [00:00, 586833.57lines/s]


## Helper Functions

In [None]:
data = defaultdict(list)

In [13]:
def pad_indices(indices, right_shift=False, bptt_len=conf['bptt_len']):
    """Takes a list of token `indices`, appends the index for the
    <eos> token and pads the list on the right to the bptt_len length.
    If you pass `right_shift=True`, then it also inserts the index for 
    the pad token to the beginning of the list and shifts everything 
    else to the right (still maintaining bptt_len."""
    indices = list(map(int, indices))
    eos_index = vocab.stoi['<eos>']
    pad_index = vocab.stoi['<pad>']
    if (not indices) or (indices[-1] != eos_index):
        indices.append(eos_index)
    if right_shift:
        indices.insert(0, pad_index)
    indices = indices[:bptt_len]
    pad_len = bptt_len - len(indices)
    indices += [pad_index] * pad_len
    return indices
    
def get_indices(string, vocab=vocab, bptt_len=conf['bptt_len']):
    """takes a string, tokenizes it, and returns the (unpadded) 
    list of token incides for the tokens in the string. The output
    of this method is suitable input for `pad_indices()`"""
    tokens = list(string.strip().lower())
    tokens = tokens[:bptt_len]
    indices = list(map(lambda x: vocab.stoi[x], tokens))
    return indices

def _get_tensors(string, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer, bptt_len=conf['bptt_len']):
    """Takes a string returns the appropriately padded and shifted 
    `encoder_out`, `decoder_in`, and `y` tensors for it"""
    indices = get_indices(string, vocab)
    encoder_out = []
    decoder_in = []
    y = []
    for i in range(len(indices)):
        encoder_out.append(tensor(pad_indices(indices[:i])).unsqueeze(0))
        decoder_in.append(tensor(pad_indices(indices[i:], right_shift=True)).unsqueeze(0))
        y.append(tensor(pad_indices(indices[i:])).unsqueeze(0))
    return encoder_out, decoder_in, y

def get_data(minibatch_size=conf['minibatch_size'], vocab=vocab, data_file=conf['datafile']):
    """Reads the \n separated training strings from the data file,
    and returns file and returns a generator of `encoder_out`, 
    `encoder_in`, and `y` tensors. Each has a shape:
    (minibatch_size, bptt_len, d_model)"""
    global data
    if not data:
        with open(data_file,'r') as f:
            strings = [line.strip().lower() for line in f.readlines()]
        for string in strings:
            encoder_out, decoder_in, y = _get_tensors(string)
            data['encoder_out'].extend(encoder_out)
            data['decoder_in'].extend(decoder_in)
            data['y'].extend(y)
        data['encoder_out'] = torch.cat(data['encoder_out']).to(device)
        data['decoder_in'] = torch.cat(data['decoder_in']).to(device)
        data['y'] = torch.cat(data['y']).to(device)

    eo_split = torch.split(data['encoder_out'], minibatch_size, dim=0)
    di_split = torch.split(data['decoder_in'], minibatch_size, dim=0)
    y_split = torch.split(data['y'], minibatch_size, dim=0)
    for encoder_out, decoder_in, y_split in zip(eo_split, di_split, y_split):
        yield encoder_out, decoder_in, y_split 
    #batch_size = data['y'].shape[0]
    #i = 0
    #while i < batch_size:
    #    j = i + minibatch_size
    #    encoder_out = data['encoder_out'][i:j,:]
    #    decoder_in = data['decoder_in'][i:j,:]
    #    y = data['y'][i:j,:]
    #    yield encoder_out, decoder_in, y
    #    i = j
        
def train_data(encoder_out, decoder_in, y):
    """Runs one minibatch training and returns the loss for that minibatch"""
    optimizer.zero_grad()
    y_pred = model(encoder_out=encoder_out, decoder_in=decoder_in)
    y_pred = torch.transpose(y_pred, -2, -1)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    #sys.stdout.write('.')
    return loss.item()

def do_epoch(epoch, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer):
    """Runs one full training batch and returns the average loss and 
    duration time in seconds"""
    t0 = time.time()
    losses = [train_data(*args) for args in get_data()]
    tf = time.time()
    avg_loss = sum(losses) / len(losses)
    return (avg_loss, tf-t0)

def train(num_epochs=conf['epochs_per_loop'], start_epoch=0, model=model, vocab=vocab, criterion=criterion, optimizer=optimizer):
    """Runs num_epochs training batchs and prints out results"""
    model = model.train()
    for epoch in range(start_epoch, start_epoch+num_epochs):
        loss, seconds = do_epoch(epoch, model, vocab, criterion, optimizer)
        #print()
        wandb.log({'epoch': epoch,
                   'loss': loss,
                   'seconds': seconds})
        print(epoch, '(%.1fs)' % seconds, 'loss=%.4f' % loss)    
    return epoch + 1


induction_weight = CE_weight.unsqueeze(dim=0)

def get_next_token(encoder_out, decoder_in, pos, model=model, vocab=vocab):
    """Runs one step of auto-regression, returning the output token for
    position `pos`. Uses a Multinomial distribution for sampling."""
    eval_model = model.eval()
    encoder_out = tensor(pad_indices(get_indices(encoder_out))).to(device)
    decoder_in = tensor(pad_indices(get_indices(decoder_in), right_shift=True)).to(device)
    decoder_out = eval_model(encoder_out=encoder_out, decoder_in=decoder_in)
    #decoder_out = decoder_out * induction_weight
    probs = nn.functional.softmax(decoder_out.float(), dim=-1)
    m = torch.distributions.multinomial.Multinomial(probs=probs)
    _, indices = torch.max(m.sample(), dim=-1)
    #_, indices = torch.max(nn.functional.softmax(decoder_out, dim=1), dim=1)
    index = int(indices[pos])
    #print('index:', index, 'pos:', pos)
    return vocab.itos[index]

def sample_with_prompt(prompt, model=model, vocab=vocab, bptt_len=conf['bptt_len']):
    """samples a string the network beginning with prompt"""
    encoder_out = prompt
    decoder_out = ''
    next_token = ''
    
    while next_token not in ('<pad>', '<eos>') and len(decoder_out) < (bptt_len - 1):
        #print('prompt:', prompt, 'decoder_out:', decoder_out, 'next_index:', next_index, 'next_token:', next_token)
        decoder_out += next_token
        next_token = get_next_token(encoder_out, decoder_out, pos=len(decoder_out))
        
    return prompt + decoder_out

def alphabet_sample():
    """samples strings beginning with each letter of the alphabet"""
    for p in [chr(i) for i in range(ord('a'), ord('z')+1)]: 
        print('%r' % sample_with_prompt(p))

## Train the Model

In [14]:
epoch = 0

In [15]:
for _ in range(conf['total_training_loops']):
    alphabet_sample()
    epoch = train(conf['epochs_per_loop'], epoch)
alphabet_sample()



'afmsiizbknlsauvcgqhkpdgui'
'bvqchvazejeokxvspjxsneyau'
'cbrxljuryazqhmkdsb'
'didbm'
'egxozn-aqomgrrzthtctjeiby'
'fifwgennnyax'
'gzg'
'hbixgrztiythaaoj'
'iyvg'
'jyh'
'kufzaftkwbmhwawzrliukohap'
'lckcmwvjrhwkdbmzjscyf-'
'mgjtnhedju'
'nztytgabhxltx'
'oa-hwixh-riwm-vxqccj'
'ppfqxkiujdghaubxaoluvsapd'
'qgbqcpbmcbyo'
'rjdq-o-mt'
'sr'
'txpqawgy'
'ufr'
'veshdhmarbgvxexfpvwtahhna'
'w'
'xwnyxb'
'yglafautthrjonk'
'zrdusvocampcuu-zwohzejuww'


wandb: Wandb version 0.8.32 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


0 (25.6s) loss=3.4605




1 (7.2s) loss=3.4617
2 (7.3s) loss=3.4605
3 (7.3s) loss=3.4598
4 (7.3s) loss=3.4604
5 (7.4s) loss=3.4598
6 (7.4s) loss=3.4588
7 (7.4s) loss=3.4586
8 (7.4s) loss=3.4594
9 (7.4s) loss=3.4572
10 (7.4s) loss=3.4557
11 (7.4s) loss=3.4557
12 (7.4s) loss=3.4537
13 (7.4s) loss=3.4529
14 (7.5s) loss=3.4523
15 (7.5s) loss=3.4488
16 (7.5s) loss=3.4457
17 (7.5s) loss=3.4412
18 (7.5s) loss=3.4381
19 (7.5s) loss=3.4318
20 (7.5s) loss=3.4256
21 (7.5s) loss=3.4199
22 (7.5s) loss=3.4160
23 (7.5s) loss=3.4062
24 (7.5s) loss=3.3969
25 (7.5s) loss=3.3896
26 (7.5s) loss=3.3855
27 (7.5s) loss=3.3814
28 (7.5s) loss=3.3787
29 (7.5s) loss=3.3770
30 (7.5s) loss=3.3742
31 (7.5s) loss=3.3750
32 (7.5s) loss=3.3713
33 (7.5s) loss=3.3725
34 (7.5s) loss=3.3709
35 (7.5s) loss=3.3701
36 (7.5s) loss=3.3693
37 (7.5s) loss=3.3684
38 (7.5s) loss=3.3684
39 (7.6s) loss=3.3691
40 (7.5s) loss=3.3680
41 (7.5s) loss=3.3678
42 (7.5s) loss=3.3678
43 (7.5s) loss=3.3666
44 (7.5s) loss=3.3658
45 (7.5s) loss=3.3658
46 (7.5s) loss=3.36

In [None]:
sample_with_prompt('New')

In [None]:
eval_model = model.eval()
encoder_out = tensor(pad_indices(get_indices('N'))).to(device)
decoder_in = tensor(pad_indices(get_indices(''), right_shift=True)).to(device)
decoder_out = eval_model(encoder_out=encoder_out, decoder_in=decoder_in)
_, indices = torch.max(nn.functional.softmax(decoder_out, dim=1), dim=1)
indices