## This notebook will attempt to build a GPT-2 style transformer model from scratch

The 'key' individual components are as follows:

1. Embedding (Word and Position Embedding) and Unembedding Layers

2. Layer Norm - This occurs before every new layer (i.e either before attention, MLP or unembed)

3. A transformer block consisiting of:

    i. Self-attention (usually multiple independent heads)

    ii. Multi-layer Perceptron 

The architecture used by the transformer is a 'residual' type architecture, so the main information highway of the transformer is the "residual stream". The putputs of each layer just add back to the residual stream.

The inputs to the model are a series of tokens these are typically an integer representation that are obtained through tokenizing the text

In [73]:
import torch
import torch.nn as nn
from einops import einsum, rearrange, reduce, repeat
import math
import transformer_lens

In [4]:
model_reference = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


### 1. Embedding and Unembedding layers

This include the embedding layer, the positional enmbedding and the unembedding

In [5]:
class EmbeddingLayer(nn.Module):

    """ 
        Embedding layer that takes as inputs a batch of 
        token and embeds them into vectors of size d_model.
        This is the first layer after the input.
    """

    def __init__(self, d_vocab, d_model, init_std = 0.02):
        super().__init__()
        self.W_E = nn.Parameter(torch.empty(d_vocab, d_model)) # embedding matrix
        nn.init.normal_(self.W_E, std=init_std)
        print(self.W_E.shape)

    def forward(self, input_tokens):
        # input_tokens are of size [batch_size, position] and are 
        # essentially integers that index the rows of W_E

        embedded = self.W_E[input_tokens,:]
        return embedded


In [None]:
class PositionEmbedding(nn.Module):
    """ 
    
        This layer produces the relevant positional information.
        This positional information is then added to the output of
        the Embedding layer resulting in the input to the first
        transformer block 
        
    """
    def __init__(self, max_ctx, d_model, init_std = 0.02):
        super().__init__()
        self.W_P = nn.Parameter(torch.empty(max_ctx, d_model))
        nn.init.normal_(self.W_P, std = init_std)
        print(self.W_P.shape)


    def forward(self, input_tokens):
        #input_tokens are of size [batch, position]
        pos_embed = repeat(self.W_P[:input_tokens.shape[1],:], "position d_model -> batch position d_model", batch = input_tokens.shape[0])
        print(pos_embed.shape)
        return pos_embed


In [None]:
class UnembeddingsLayer(nn.Module):
    """ Unembedding layer that takes as inputs the ouput
        from the last transformer block and expands it back
        to a vector of size d_vocab, which are the logits that
        get passed on to the Softmax. 
    
    """
    def __init__(self, d_vocab, d_model, init_std):
        super().__init__()
        self.W_U = nn.Parameter(torch.empty(d_model,d_vocab))
        nn.init.normal_(self.W_U, std = init_std)
        self.b_U = nn.Parameter(torch.zeros(d_vocab,))

    def forward(self, resid_embed_last):
        # resid_embed_last corresponds to the residual stream
        # after the last transformer block. It is of size 
        # [batch_size, position, d_model] 

        logits = einsum(resid_embed_last, self.W_U,  "batch position d_model, d_model d_vocab -> batch position d_vocab") + self.b_U
        return logits


### 2. Layer Norm

In [58]:
class LayerNorm(nn.Module):

    """ Layer Normalization. Effectively z-scores it's input
        along the embedding dimension and then multiplies 
        each embedding dimension independently by learnable
        gains"""

    def __init__(self, d_model, init_std = 0.02, layer_norm_eps = 1e-5):
        super().__init__()
        self.gains = nn.Parameter(torch.empty(d_model,))
        nn.init.normal_(self.gains,std = init_std)
        self.bias = nn.Parameter(torch.zeros(d_model,))
        self.layer_norm_eps = layer_norm_eps

    def forward(self, input):
        # Input is the residual stram and is a tensor of 
        # size [batch, position, d_model]. The layer norm
        # subtracts the mean and variance computed across the
        # d_model dimension
        layer_mean = input.mean(dim=-1,keepdim=True)
        print(layer_mean[0])
        input_centered = (input - layer_mean)
        layer_var = input_centered.var(dim=-1,keepdim=True)
        layer_scale = torch.sqrt(layer_var + self.layer_norm_eps) 
        print(layer_scale[0])
        input_normalized = input_centered/layer_scale
        output = input_normalized *self.gains + self.bias
        return output

In [59]:
d_model = 1024
LN1 = LayerNorm(d_model=d_model)
batch_size, position = 64, 10
input = torch.randn(batch_size,position,d_model)
print(input.shape)
output = LN1(input)
print(output.shape)

torch.Size([64, 10, 1024])
tensor([[ 0.0088],
        [ 0.0064],
        [-0.0147],
        [-0.0163],
        [ 0.0269],
        [-0.0111],
        [-0.0473],
        [-0.0179],
        [ 0.0051],
        [ 0.0409]])
tensor([[1.0105],
        [0.9872],
        [0.9871],
        [1.0257],
        [0.9980],
        [0.9808],
        [1.0369],
        [1.0059],
        [1.0105],
        [1.0125]])
torch.Size([64, 10, 1024])


### 3. Self-Attention

In [95]:
class SelfAttention(nn.Module):
    """ This module implements the self-attention mechanism with
     multiple heads assuming that there are a total of "n_heads" heads.
    The forward() method of this class provides the output of the
    attention module which is a weighted combination of value vectors,
    where the weights are the "attention weights" """
    def __init__(self, n_heads, d_model, d_head, mask_val = -1e5, init_std=0.02):
        super().__init__()
        self.W_Q = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        nn.init.normal_(self.W_Q,std = init_std)
        self.b_Q = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        self.W_K = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        nn.init.normal_(self.W_K,std = init_std)
        self.b_K = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        self.W_V = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        self.b_V = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        nn.init.normal_(self.W_V,std = init_std)
        self.W_O = torch.nn.Parameter(torch.empty(n_heads, d_head, d_model))
        nn.init.normal_(self.W_O,std = init_std)
        self.b_O = torch.nn.Parameter(torch.zeros(d_model))
        self.register_buffer("mask_val",torch.tensor(mask_val,dtype = torch.float32))

    def forward(self, resid_pre):
        # resid_pre is of size [batch, position, d_model]
        keys = einsum(resid_pre, self.W_K, "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head") + self.b_K
        print(f'Keys of size:{keys.shape}')
        queries = einsum(resid_pre, self.W_Q, "batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head") + self.b_Q
        print(f'Queries of size:{queries.shape}')
        values = einsum(resid_pre, self.W_V, "batch val_pos d_model, n_heads d_model d_head -> batch val_pos n_heads d_head") + self.b_V
        print(f'Values of size:{values.shape}')
        attn_pattern = einsum(queries, keys, "batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos") 
        attn_pattern /= math.sqrt(self.W_Q.shape[-1])

        print(f'Shape of attention pattern:{attn_pattern.shape}')
        attn_pattern = self.apply_causal_mask(attn_pattern)
        print(attn_pattern[0,0])
        context_vec = einsum(attn_pattern, values, "batch n_heads query_pos key_pos", "batch val_pos n_heads d_head -> batch n_heads query_pos d_head")
        print(context_vec.shape)

    def apply_causal_mask(self, attn_pattern):
        mask = torch.triu(torch.ones(attn_pattern.shape[-2], attn_pattern.shape[-1]),diagonal=1).bool()
        print(mask)
        attn_pattern.masked_fill_(mask, self.mask_val)
        return attn_pattern

In [96]:
attn_heads = SelfAttention(n_heads = 2, d_model = 1024, d_head = 64)

In [97]:
batch,position = 64,10
resid_pre = torch.randn(batch,position,1024)
attn_heads(resid_pre)

Keys of size:torch.Size([64, 10, 2, 64])
Queries of size:torch.Size([64, 10, 2, 64])
Values of size:torch.Size([64, 10, 2, 64])
Shape of attention pattern:torch.Size([64, 2, 10, 10])
tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[ 2.0982

EinopsError: Unknown axis query_pos on right side of einsum batch val_pos n_heads d_head -> batch n_heads query_pos d_head.

In [94]:
print(attn_heads.mask_val)

None
