# Transformers 101

This notebook serves as an exploration of the transformer architecture (Vaswani et. al.) Here, we'll implement in native PyTorch the basic building blocks of the transformer and then put them all together so we have a model architecture to put into `../models`

In the process of putting this together (much like my other exploratory projects) I tried to limit viewing existing code online, and primarily used my notes (pdf attached for anyone interested) as a foundation for this work.

In [2]:
import torch
import torch.nn as nn

We want something with output dims: (sequence_length, output_dim)

In [3]:
def positional_embedding(input_tensor: torch.Tensor, output_dim: int, n=10000): 
    """
    Here, we implement the naive approach from the original 
    paper with the sin and cosine functions. 
    """
    P = torch.zeros((input_tensor.shape[-1], output_dim))
    indices = torch.arange(input_tensor.size(-1))
    i_values = torch.arange(int(output_dim/2))
    denominators = torch.float_power(n, 2*i_values/output_dim)
    P[:, 0::2] = torch.sin(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 0, step by 2 sin for even nums
    P[:, 1::2] = torch.cos(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 1, step by 2 cos for odd nums
    return P


In [4]:
a = torch.rand((2, 5))
output_dims = 3
positional_embedding(a, output_dims)

tensor([[ 0.0000,  1.0000,  0.0000],
        [ 0.8415,  0.5403,  0.8415],
        [ 0.9093, -0.4161,  0.9093],
        [ 0.1411, -0.9900,  0.1411],
        [-0.7568, -0.6536, -0.7568]])

In [5]:
def attention(x): 
    """
    Simple dot product based attention
    """
    query_layer, key_layer, value_layer = nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1])
    query, key, value = query_layer(x), key_layer(x), value_layer(x)
    attention_weights  = torch.nn.Softmax(-1)(torch.tensordot(query, key, dims=1))
    return torch.sum(value * attention_weights)

In [6]:
x = torch.rand(1, 12)
attention(x)

tensor(0.2987, grad_fn=<SumBackward0>)

In [14]:
def add_norm(residual: torch.Tensor, hidden: torch.Tensor): 
    if residual.shape != hidden.shape: 
        raise ValueError("Shapes mismatch")
    else: 
        output = residual + hidden # element wise addition
        layer_norm = nn.LayerNorm([residual.shape[-2], residual.shape[-1]])
        return layer_norm(output)

In [15]:
# usage example: 

tensor_a = torch.rand([1, 5, 6]) # batch size, sequence length, embedding dimensions
tensor_b = torch.rand([1, 5, 6])
print(tensor_a)
print(tensor_b)
print(f"Final: {add_norm(tensor_a, tensor_b)}")

tensor([[[0.8880, 0.2732, 0.9008, 0.7261, 0.3817, 0.5167],
         [0.9505, 0.1640, 0.9440, 0.8701, 0.3452, 0.9666],
         [0.8365, 0.6466, 0.7730, 0.3295, 0.8195, 0.6639],
         [0.5948, 0.3429, 0.1542, 0.4754, 0.5124, 0.6733],
         [0.8461, 0.8573, 0.4217, 0.9035, 0.1438, 0.6029]]])
tensor([[[0.3105, 0.7641, 0.8900, 0.7077, 0.2874, 0.8929],
         [0.8561, 0.3155, 0.8284, 0.6441, 0.8224, 0.9303],
         [0.8458, 0.9986, 0.6091, 0.9989, 0.5183, 0.3800],
         [0.2779, 0.3161, 0.4261, 0.1374, 0.2840, 0.5407],
         [0.1030, 0.4531, 0.0115, 0.3710, 0.4745, 0.6053]]])
Added: tensor([[[1.1985, 1.0373, 1.7908, 1.4338, 0.6692, 1.4096],
         [1.8067, 0.4795, 1.7724, 1.5142, 1.1676, 1.8969],
         [1.6823, 1.6453, 1.3820, 1.3284, 1.3378, 1.0439],
         [0.8727, 0.6590, 0.5803, 0.6128, 0.7964, 1.2140],
         [0.9491, 1.3104, 0.4331, 1.2745, 0.6183, 1.2083]]])
Final: tensor([[[ 0.0655, -0.3161,  1.4676,  0.6224, -1.1876,  0.5652],
         [ 1.5051, -1.6365,  1

In [None]:
def multihead_attention():
    """
    Scaled Dot product based multi-head attention
    """
    query_layer, key_layer, value_layer = nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1])
    query, key, value = query_layer(x), key_layer(x), value_layer(x)
    attention_weights  = torch.nn.Softmax(-1)(torch.tensordot(query, key, dims=1))
    return torch.sum(value * attention_weights)