# Transformers 101

This notebook serves as an exploration of the transformer architecture (Vaswani et. al.) Here, we'll implement in native PyTorch the basic building blocks of the transformer and then put them all together so we have a model architecture to put into `../models`

In the process of putting this together (much like my other exploratory projects) I tried to limit viewing existing code online, and primarily used my notes (pdf attached for anyone interested) as a foundation for this work.

In [1]:
import torch
import math
import torch.nn as nn

We want something with output dims: (sequence_length, output_dim)

In [2]:
def positional_embedding(input_tensor: torch.Tensor, output_dim: int, n=10000): 
    """
    Here, we implement the naive approach from the original 
    paper with the sin and cosine functions. 
    """
    P = torch.zeros((input_tensor.shape[-1], output_dim))
    indices = torch.arange(input_tensor.size(-1))
    i_values = torch.arange(int(output_dim/2))
    denominators = torch.float_power(n, 2*i_values/output_dim)
    P[:, 0::2] = torch.sin(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 0, step by 2 sin for even nums
    P[:, 1::2] = torch.cos(indices.unsqueeze(1) / denominators.unsqueeze(0)) # start at 1, step by 2 cos for odd nums
    return P


In [3]:
a = torch.rand((2, 5))
output_dims = 3
positional_embedding(a, output_dims)

tensor([[ 0.0000,  1.0000,  0.0000],
        [ 0.8415,  0.5403,  0.8415],
        [ 0.9093, -0.4161,  0.9093],
        [ 0.1411, -0.9900,  0.1411],
        [-0.7568, -0.6536, -0.7568]])

In [4]:
def attention(x): 
    """
    Simple dot product based attention
    """
    query_layer, key_layer, value_layer = nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1]), nn.Linear(x.shape[-1], x.shape[-1])
    query, key, value = query_layer(x), key_layer(x), value_layer(x)
    attention_weights  = torch.nn.Softmax(-1)(torch.tensordot(query, key, dims=1))
    return torch.sum(value * attention_weights)

In [5]:
x = torch.rand(1, 12)
attention(x)

tensor(-0.0631, grad_fn=<SumBackward0>)

Just to emulate how it would be implemented, we write out the add norm function below. However in practice, this will be encompassed by each transformer sub module since each of them are followed by addition with residual and layer normalization. 

In [6]:
def add_norm(residual: torch.Tensor, hidden: torch.Tensor): 
    if residual.shape != hidden.shape: 
        raise ValueError("Shapes mismatch")
    else: 
        output = residual + hidden # element wise addition
        layer_norm = nn.LayerNorm([residual.shape[-2], residual.shape[-1]])
        return layer_norm(output)

In [7]:
# usage example: 

tensor_a = torch.rand([1, 5, 6]) # batch size, sequence length, embedding dimensions
tensor_b = torch.rand([1, 5, 6])
print(tensor_a)
print(tensor_b)
print(f"Final: {add_norm(tensor_a, tensor_b)}")

tensor([[[0.4798, 0.0733, 0.9246, 0.9232, 0.8420, 0.7289],
         [0.5115, 0.8718, 0.9227, 0.9176, 0.7529, 0.6058],
         [0.4938, 0.9175, 0.9505, 0.0821, 0.8946, 0.8835],
         [0.8759, 0.1793, 0.3222, 0.4535, 0.7326, 0.8833],
         [0.1397, 0.9988, 0.5021, 0.1885, 0.6068, 0.5807]]])
tensor([[[0.5278, 0.7106, 0.0225, 0.1594, 0.6950, 0.1188],
         [0.1737, 0.4454, 0.9145, 0.4098, 0.7631, 0.5530],
         [0.6709, 0.3885, 0.7824, 0.5751, 0.2715, 0.5452],
         [0.0527, 0.0020, 0.6902, 0.0224, 0.4964, 0.2405],
         [0.9476, 0.8544, 0.1650, 0.4228, 0.0887, 0.4015]]])
Final: tensor([[[-0.1810, -0.7537, -0.3358,  0.0108,  1.1739, -0.5905],
         [-1.0066,  0.6114,  1.9423,  0.6375,  1.1202,  0.2058],
         [ 0.2211,  0.5826,  1.6755, -1.0783,  0.2246,  0.8967],
         [-0.3834, -2.2963, -0.1688, -1.5421,  0.3856,  0.1163],
         [ 0.0228,  1.9835, -1.0526, -1.1954, -0.9801, -0.2461]]],
       grad_fn=<NativeLayerNormBackward0>)


In [8]:
def scaled_dot_product_attention(q, k, d_k):
    # in order to align the dimensions for the dot product, we transpose k along the last two dimensions like this
    return torch.nn.Softmax(-1)(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k))

d_k and d_v are essentially hyperparameters that are fixed before training. This allows for the query and keys to have the same dimensionality, and for all 3 of them to have consistent dimensionality. In many transformer implementations, d_k and d_v are set to be the same for simplicity but this is not always the case.

In [14]:
def multihead_attention(k, q, v, d_k, d_v, d_model, num_heads):
    """
    Scaled Dot product based multi-head attention
    """
    # declare projection layers - assume all inputs have d_model size in the last dimension, and project to number of heads * d_k or d_v 
    query_layer, key_layer, value_layer = nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads*d_v)
    k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
    residual = q

    # in the following line we apply the linear projections and then reshape the outputs for multihead attention. 
    #The reshaping splits the last dimension of the linear layer's output into num_heads and d_k (or d_v for value). 
    # This creates multiple "heads" in the tensor, each with its own d_k (or d_v) dimension
    k, q, v = key_layer(k).view(batch_size, k_len,  num_heads, d_k), query_layer(q).view(batch_size, q_len,  num_heads, d_k), value_layer(v).view(batch_size, v_len,  num_heads, d_v)
    
    # we perform the following transpose so that the num heads dimension preceeds the seq length dimension. This way, each head can capture different information about the same sequence
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    attention = scaled_dot_product_attention(q, k, d_k)
    output = torch.matmul(attention, v)

    # following reshaping is done so that we can add our output to the residual 
    output = output.transpose(1, 2).contiguous().view(batch_size, q_len, -1)
    concatenated_projection = nn.Linear(num_heads * d_v, d_model, bias=False)

    output = concatenated_projection(output)
    output += residual

    print(residual.shape)
    layer_norm = nn.LayerNorm([residual.shape[-2], residual.shape[-1]])
    output = layer_norm(output)

    return output, attention

In [15]:
d_model = 512

# from the paper: To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d model 
k, q, v = torch.rand((1, 2, d_model)), torch.rand((1, 2, d_model)), torch.rand((1, 2, d_model))
d_k, d_v = 5, 5
num_heads = 4

out, attn = multihead_attention(k, q, v, d_k, d_v, d_model, num_heads)
print(out.shape, attn.shape)

torch.Size([1, 2, 512])
torch.Size([1, 2, 512]) torch.Size([1, 4, 2, 2])


In [11]:
class PositionWiseFFN(nn.Module): 
    def __init__(self, d_model, d_ff, dropout) -> None:
        super(PositionWiseFFN, self).__init__()
        self.fc1 = nn.Sequential(nn.Linear(d_model, d_ff, bias=True),nn.ReLU())
        self.fc2 = nn.Linear(d_ff, d_model, bias=True)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        residual = x
        x = self.fc2(self.fc1(x))        
        return self.dropout(self.layer_norm(x+residual))


In [12]:
ffn = PositionWiseFFN(d_model, 2048)
x = torch.rand((1, 2, d_model))
ffn(x)

tensor([[[-0.5878, -0.7818,  0.5670,  ...,  0.0621, -1.2793,  0.7879],
         [ 0.5175,  1.8379,  1.3007,  ...,  0.3632, -0.2532,  0.3898]]],
       grad_fn=<NativeLayerNormBackward0>)

Now that we've implemented the lowest level building blocks of the transormer, below we put them together to build transformer blocks, encoder and decoder layers, and the complete transformer architecture. Now we try and condense everything to a more concise-less experimental implementation. 

In [13]:
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_k, d_model, d_v, dropout, num_heads) -> None:
        super(MultiHeadAttention, self).__init__()
        self.d_k, self.d_v, self.d_model = d_k, d_v, d_model
        self.query_layer, self.key_layer, self.value_layer = nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads* d_k), nn.Linear(d_model, num_heads*d_v)
        self.layer_norm = nn.LayerNorm(d_model)
        self.concat_projection = nn.Linear(num_heads*d_v, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
        residual = q
        k, q, v = self.key_layer(k).view(batch_size, k_len,  num_heads, d_k), self.query_layer(q).view(batch_size, q_len,  num_heads, d_k), self.value_layer(v).view(batch_size, v_len,  num_heads, d_v)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attention = scaled_dot_product_attention(q, k, d_k)
        output = torch.matmul(attention, v)
        output = self.concat_projection(output.transpose(1, 2).contiguous().view(batch_size, q_len, -1))
        return self.dropout(self.layer_norm(output+residual))

In [None]:
class EncoderLayer(nn.Module): 
    def __init__(self, d_k, d_model, d_v, num_heads, d_ff, dropout) -> None:
        super(EncoderLayer).__init__()
        self.k_layer, self.q_layer, self.v_layer = nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)
        self.multihead_attention = MultiHeadAttention(d_k, d_model, d_v, dropout, num_heads)
        self.pointwise_ffn = PositionWiseFFN(d_model, d_ff)
    
    def forward(self, x): 
        k, q, v = self.k_layer(x), self.q_layer(x), self.v_layer(x)
        output = self.multihead_attention(q, k, v)
        return self.pointwise_ffn(output)

In [None]:
class Encoder(nn.Module):
    def __init__(self, ) -> None:
        super(Encoder, self).__init__()