# Transformers 101

This notebook serves as an exploration of the transformer architecture (Vaswani et. al.) Here, we'll implement in native PyTorch the basic building blocks of the transformer and then put them all together so we have a model architecture to put into `../models`

In the process of putting this together (much like my other exploratory projects) I tried to limit viewing existing code online, and primarily used my notes (pdf attached for anyone interested) as a foundation for this work.

In [1]:
import torch
import math
import torch.nn as nn
from torch.utils.data import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
from torch.utils.data import DataLoader

We want something with output dims: (sequence_length, output_dim)

In [60]:
def positional_embedding(input_tensor: torch.Tensor, output_dim: int, n=10000): 
    """
    Here, we implement the naive approach from the original 
    paper with the sin and cosine functions. 
    """
    P = torch.zeros((input_tensor.shape[-1], output_dim))
    indices = torch.arange(input_tensor.size(-1))
    i_values = torch.arange(int(output_dim/2))
    denominators = torch.float_power(n, 2*i_values/output_dim)
    P[:, 0::2] = torch.sin(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 0, step by 2 sin for even nums
    P[:, 1::2] = torch.cos(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 1, step by 2 cos for odd nums
    return P


In [61]:
a = torch.rand((2, 5))
output_dims = 3
positional_embedding(a, output_dims)

tensor([[ 0.0000,  1.0000,  0.0000],
        [ 0.8415,  0.5403,  0.8415],
        [ 0.9093, -0.4161,  0.9093],
        [ 0.1411, -0.9900,  0.1411],
        [-0.7568, -0.6536, -0.7568]])

In [62]:
def attention(x): 
    """
    Simple dot product based attention
    """
    query_layer, key_layer, value_layer = nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1])
    query, key, value = query_layer(x), key_layer(x), value_layer(x)
    attention_weights  = torch.nn.Softmax(-1)(torch.tensordot(query, key, dims=1))
    return torch.sum(value * attention_weights)

In [63]:
x = torch.rand(1, 12)
attention(x)

tensor(0.0601, grad_fn=<SumBackward0>)

Just to emulate how it would be implemented, we write out the add norm function below. However in practice, this will be encompassed by each transformer sub module since each of them are followed by addition with residual and layer normalization. 

In [64]:
def add_norm(residual: torch.Tensor, hidden: torch.Tensor): 
    if residual.shape != hidden.shape: 
        raise ValueError("Shapes mismatch")
    else: 
        output = residual + hidden # element wise addition
        layer_norm = nn.LayerNorm([residual.shape[-2], residual.shape[-1]])
        return layer_norm(output)

In [65]:
# usage example: 

tensor_a = torch.rand([1, 5, 6]) # batch size, sequence length, embedding dimensions
tensor_b = torch.rand([1, 5, 6])
print(tensor_a)
print(tensor_b)
print(f"Final: {add_norm(tensor_a, tensor_b)}")

tensor([[[0.0205, 0.9839, 0.4722, 0.5144, 0.4750, 0.3709],
         [0.6898, 0.0637, 0.7878, 0.8657, 0.6293, 0.9553],
         [0.0415, 0.0994, 0.4251, 0.9324, 0.8413, 0.8969],
         [0.2195, 0.2179, 0.5099, 0.2900, 0.5529, 0.1354],
         [0.6915, 0.2766, 0.6307, 0.5326, 0.1450, 0.8390]]])
tensor([[[0.6223, 0.9084, 0.0278, 0.3628, 0.4586, 0.9681],
         [0.1042, 0.7222, 0.9661, 0.1567, 0.8702, 0.8263],
         [0.0307, 0.7071, 0.1334, 0.7505, 0.0034, 0.6686],
         [0.8367, 0.1280, 0.1493, 0.6331, 0.3516, 0.3006],
         [0.0411, 0.5540, 0.1565, 0.5201, 0.1892, 0.5447]]])
Final: tensor([[[-0.6920,  2.0346, -1.0038, -0.1804, -0.0576,  0.8272],
         [-0.3622, -0.3797,  1.7324,  0.1364,  1.1773,  1.7930],
         [-1.9372, -0.3348, -0.8759,  1.5774, -0.2516,  1.3213],
         [ 0.2100, -1.3400, -0.6562, -0.0804, -0.1209, -1.1433],
         [-0.4961, -0.2822, -0.3770,  0.2024, -1.3655,  0.9246]]],
       grad_fn=<NativeLayerNormBackward0>)


In [66]:
def scaled_dot_product_attention(q, k, d_k):
    # in order to align the dimensions for the dot product, we transpose k along the last two dimensions like this
    return torch.nn.Softmax(-1)(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k))

d_k and d_v are essentially hyperparameters that are fixed before training. This allows for the query and keys to have the same dimensionality, and for all 3 of them to have consistent dimensionality. In many transformer implementations, d_k and d_v are set to be the same for simplicity but this is not always the case.

In [67]:
def multihead_attention(k, q, v, d_k, d_v, d_model, num_heads):
    """
    Scaled Dot product based multi-head attention
    """
    # declare projection layers - assume all inputs have d_model size in the last dimension, and project to number of heads * d_k or d_v 
    query_layer, key_layer, value_layer = nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads*d_v)
    k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
    residual = q

    # in the following line we apply the linear projections and then reshape the outputs for multihead attention. 
    #The reshaping splits the last dimension of the linear layer's output into num_heads and d_k (or d_v for value). 
    # This creates multiple "heads" in the tensor, each with its own d_k (or d_v) dimension
    k, q, v = key_layer(k).view(batch_size, k_len,  num_heads, d_k), query_layer(q).view(batch_size, q_len,  num_heads, d_k), value_layer(v).view(batch_size, v_len,  num_heads, d_v)
    
    # we perform the following transpose so that the num heads dimension preceeds the seq length dimension. This way, each head can capture different information about the same sequence
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    attention = scaled_dot_product_attention(q, k, d_k)
    output = torch.matmul(attention, v)

    # following reshaping is done so that we can add our output to the residual 
    output = output.transpose(1, 2).contiguous().view(batch_size, q_len, -1)
    concatenated_projection = nn.Linear(num_heads * d_v, d_model, bias=False)

    output = concatenated_projection(output)
    output += residual

    print(residual.shape)
    layer_norm = nn.LayerNorm([residual.shape[-2], residual.shape[-1]])
    output = layer_norm(output)

    return output, attention

In [68]:
d_model = 512

# from the paper: To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d model 
k, q, v = torch.rand((1, 2, d_model)), torch.rand((1, 2, d_model)), torch.rand((1, 2, d_model))
d_k, d_v = 5, 5
num_heads = 4

out, attn = multihead_attention(k, q, v, d_k, d_v, d_model, num_heads)
print(out.shape, attn.shape)

torch.Size([1, 2, 512])
torch.Size([1, 2, 512]) torch.Size([1, 4, 2, 2])


In [69]:
class PositionWiseFFN(nn.Module): 
    def __init__(self, d_model, d_ff, dropout) -> None:
        super(PositionWiseFFN, self).__init__()
        self.fc1 = nn.Sequential(nn.Linear(d_model, d_ff, bias=True),nn.ReLU())
        self.fc2 = nn.Linear(d_ff, d_model, bias=True)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        residual = x
        x = self.fc2(self.fc1(x))        
        return self.dropout(self.layer_norm(x+residual))


In [70]:
ffn = PositionWiseFFN(d_model, 2048, 0.1)
x = torch.rand((1, 2, d_model))
ffn(x)

tensor([[[-0.2430,  0.0000,  0.1437,  ..., -1.7312, -2.0171,  0.4877],
         [-0.2407, -1.0071,  0.4952,  ..., -0.6492,  0.5152, -0.5740]]],
       grad_fn=<MulBackward0>)

Now that we've implemented the lowest level building blocks of the transormer, below we put them together to build transformer blocks, encoder and decoder layers, and the complete transformer architecture. Now we try and condense everything to a more concise-less experimental implementation. 

In [146]:
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_k, d_model, d_v, dropout, num_heads) -> None:
        super(MultiHeadAttention, self).__init__()
        self.d_k, self.d_v, self.d_model, self.num_heads = d_k, d_v, d_model, num_heads
        self.query_layer, self.key_layer, self.value_layer = nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads*d_v)
        self.layer_norm = nn.LayerNorm(d_model)
        self.concat_projection = nn.Linear(num_heads*d_v, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
        residual = q
        k, q, v = self.key_layer(k).view(batch_size, k_len,  self.num_heads, self.d_k), self.query_layer(q).view(batch_size, q_len,  self.num_heads, self.d_k), self.value_layer(v).view(batch_size, v_len,  self.num_heads, self.d_v)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attention = scaled_dot_product_attention(q, k, self.d_k)
        output = torch.matmul(attention, v)
        output = self.concat_projection(output.transpose(1, 2).contiguous().view(batch_size, q_len, -1))
        return self.dropout(self.layer_norm(output+residual))

In [147]:
class EncoderLayer(nn.Module): 
    def __init__(self, d_k, d_model, d_v, num_heads, d_ff, dropout) -> None:
        super(EncoderLayer, self).__init__()
        self.k_layer, self.q_layer, self.v_layer = nn.Linear(d_model, d_model), nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)
        self.multihead_attention = MultiHeadAttention(d_k, d_model, d_v, dropout, num_heads)
        self.pointwise_ffn = PositionWiseFFN(d_model, d_ff, dropout)
    
    def forward(self, x): 
        k, q, v = self.k_layer(x), self.q_layer(x), self.v_layer(x)
        output = self.multihead_attention(q, k, v)
        return self.pointwise_ffn(output)

The following encoder implementation is based off of the block diagram from Attention Is All You Need

In [148]:
class Encoder(nn.Module):
    def __init__(self, d_k, d_model, d_v, d_ff, num_heads, num_layers, vocab_size, dropout=0.1) -> None:
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.positional_embedding = positional_embedding
        self.dropout = nn.Dropout(dropout)
        self.layers = [EncoderLayer(d_k, d_model, d_v, num_heads, d_ff, dropout) for _ in range(num_layers)]
    
    def forward(self, x):
        embedded = self.embedding(x)
        x = self.dropout(embedded + self.positional_embedding(x, self.d_model))
        for layer in self.layers:
            x = layer(x)
        return x

## Preliminary testing on text classification task (encoder only)

In [149]:
train_iter = AG_NEWS(split='train')

# Convert to list to enable random splitting
train_dataset = list(train_iter)

#80-20 train-val split 
train_size = int(len(train_dataset) * 0.8)  
val_size = len(train_dataset) - train_size  
train_data, val_data = random_split(train_dataset, [train_size, val_size])

tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

VOCAB_SIZE = 5000

# Build vocab based on the train_data
train_data_iter = (text for _, text in train_data)
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter), specials=["<unk>"], max_tokens=VOCAB_SIZE)
vocab.set_default_index(vocab["<unk>"])

In [150]:
from torch.nn.utils.rnn import pad_sequence

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_list, text_list, lengths = [], [], []
    
    # Sort the batch in the descending order
    batch.sort(key=lambda x: len(x[1]), reverse=True)
    
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        lengths.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    
    # Pad sequences
    text_list = pad_sequence(text_list, batch_first=True)
    
    return label_list.to(device), text_list.to(device), lengths

In [151]:
train_loader = DataLoader(train_data, batch_size = 8, shuffle = True, collate_fn = collate_batch)
val_loader = DataLoader(val_data, batch_size = 8, shuffle = False, collate_fn = collate_batch)

In [152]:
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50
DROPOUT = 0.1
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

D_K = 128
D_V = 128
D_FF = 512
D_MODEL = 256
NUM_LAYERS = 2
OUTPUT_DIM = 4

In [153]:
model = Encoder(D_K, D_MODEL, D_V, D_FF, num_heads=4, num_layers=4, vocab_size=VOCAB_SIZE)
model = model.to(DEVICE)

In [154]:
for i, batch_data in enumerate(train_loader):
            
    model.train()
    (y, x, x_size) = batch_data
    #print("Labels: {}, data: {}, x_size.cpu(): {}".format(batch_data[0], x.shape,x_size.cpu()))

    logits = model(x)