## This notebook will attempt to build a GPT-2 style transformer model from scratch

**Disclaimer**: I drew inspiration from Neel Nanda's walkthrough and Anthropic's Transformer Circuits thread. However, the implementation is entirely my own.

**Author**: Aniruddh Galgali, 2024

The 'key' individual components are as follows:

1. Embedding (Word and Position Embedding) and Unembedding Layers

2. Layer Norm - This is applied before every sub-layer (i.e., before attention, before the MLP, and before the unembedding)

3. A transformer block consisting of:

    i. Self-attention (usually multiple independent heads)

    ii. Multi-layer Perceptron 

The transformer uses a 'residual' architecture, i.e., the main information highway is the transformer "residual stream", which has dimensionality 'd_model'. The output of each sub-layer (self-attention or MLP) is simply added back to the residual stream.

The inputs to the model are a sequence of tokens. These are integer IDs, obtained by tokenizing the text, that each index an entry of the model's vocabulary.
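As a minimal sketch of this data flow (toy sizes, not the notebook's Config), the snippet below embeds integer token IDs into a residual stream and lets two sub-layers add their outputs back to it. The "attention" here is a hypothetical stand-in (a plain `nn.Linear`, which does not mix positions), and PyTorch's built-in `nn.LayerNorm`/`nn.Embedding` are used only for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_vocab, d_model, seq_len = 100, 16, 8  # toy sizes, chosen for illustration

# Token ids: integers indexing the vocabulary, shape [batch, position]
tokens = torch.randint(0, d_vocab, (2, seq_len))

embed = nn.Embedding(d_vocab, d_model)        # stand-in for the embedding layer
ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn_like = nn.Linear(d_model, d_model)       # hypothetical stand-in for self-attention
mlp_like = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                         nn.Linear(4 * d_model, d_model))

resid = embed(tokens)                   # residual stream: [batch, position, d_model]
resid = resid + attn_like(ln1(resid))   # sub-layer output is added back to the stream
resid = resid + mlp_like(ln2(resid))    # same for the MLP
print(resid.shape)  # torch.Size([2, 8, 16])
```

Because every sub-layer reads from and writes back into the same `d_model`-wide stream, the shape never changes between blocks.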

In [1]:
import torch
import torch.nn as nn
import numpy as np
from einops import einsum, rearrange, reduce, repeat
import math
from dataclasses import dataclass
from transformer_lens.utils import tokenize_and_concatenate, keep_single_column
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import pprint

### 0.0 Dataset
We will use the TinyStories dataset, available on the Hugging Face Hub, to train our transformer.

In [2]:
# Load all splits (train and validation) as a DatasetDict, so that ds["train"] works below
ds = load_dataset("roneneldan/TinyStories")

### 0.1 Exploring the dataset

In [3]:
print(f'Dataset:{ds}')
print(f'Dataset features: {ds["train"].features}')

Dataset:DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})
Dataset features: {'text': Value(dtype='string', id=None)}


As you can see above, the dataset already comes split into a train and a validation set. Makes life easy for us! Let's now look at some examples from the training set.

In [4]:
# Printing some examples
print(f' Tiny story : random sample 1 \n')
pprint.pprint(ds['train'][0])
print(f' Tiny story : random sample 2 \n')
pprint.pprint(ds['train'][100])

 Tiny story : random sample 1 

{'text': 'One day, a little girl named Lily found a needle in her room. She '
         'knew it was difficult to play with it because it was sharp. Lily '
         'wanted to share the needle with her mom, so she could sew a button '
         'on her shirt.\n'
         '\n'
         'Lily went to her mom and said, "Mom, I found this needle. Can you '
         'share it with me and sew my shirt?" Her mom smiled and said, "Yes, '
         'Lily, we can share the needle and fix your shirt."\n'
         '\n'
         "Together, they shared the needle and sewed the button on Lily's "
         'shirt. It was not difficult for them because they were sharing and '
         'helping each other. After they finished, Lily thanked her mom for '
         'sharing the needle and fixing her shirt. They both felt happy '
         'because they had shared and worked together.'}
 Tiny story : random sample 2 

{'text': 'There was a little girl with dark hair. Her name was

### 0.2 Tokenizer

Let's use the GPT-2 tokenizer to tokenize the examples above.

In [5]:
# The TinyStories-1M checkpoint ships with the same tokenizer as GPT-2 (GPT2TokenizerFast, vocab_size=50257)
tokenizer_model_name = "roneneldan/TinyStories-1M"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
print(tokenizer)

GPT2TokenizerFast(name_or_path='roneneldan/TinyStories-1M', vocab_size=50257, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)


We will use the 'tokenize_and_concatenate' function from TransformerLens to obtain the tokenized representation of the text.
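As far as I understand TransformerLens's implementation, `tokenize_and_concatenate` joins the documents with the end-of-text token, tokenizes long chunks of the joined text, drops the leftover tokens that don't fill a complete row, and prepends a BOS token so each row has length `max_length`. Tokenizing those long chunks is likely why the "sequence length is longer than the specified maximum" warnings appear below; they should be harmless here. The sketch below mimics only the concatenate-and-chunk step on already-tokenized toy lists. `concat_and_chunk` is a hypothetical helper, not the library's code, and details may differ.

```python
import torch

def concat_and_chunk(token_lists, max_length, bos_id, eos_id):
    """Hypothetical helper: join tokenized documents and cut them into fixed-length rows.

    Returns a LongTensor of shape [num_rows, max_length], each row starting with bos_id.
    """
    stream = []
    for i, doc in enumerate(token_lists):
        if i > 0:
            stream.append(eos_id)          # separate consecutive documents
        stream.extend(doc)
    seq_len = max_length - 1               # leave room for the BOS token
    num_rows = len(stream) // seq_len      # drop the incomplete remainder
    body = torch.tensor(stream[: num_rows * seq_len], dtype=torch.long).view(num_rows, seq_len)
    bos = torch.full((num_rows, 1), bos_id, dtype=torch.long)
    return torch.cat([bos, body], dim=1)

EOT = 50256  # GPT-2 uses <|endoftext|> as both BOS and EOS
rows = concat_and_chunk([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_length=4, bos_id=EOT, eos_id=EOT)
print(rows.tolist())  # [[50256, 1, 2, 3], [50256, 50256, 4, 5], [50256, 50256, 6, 7]]
```

Note that rows can start mid-story: document boundaries are only marked by the end-of-text token, not aligned to rows.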

In [12]:
MAX_CTX_LENGTH = 512 # This is the maximum context length
tokenized_dataset = tokenize_and_concatenate(ds["train"], tokenizer=tokenizer, max_length=MAX_CTX_LENGTH, num_proc=4)


Token indices sequence length is longer than the specified maximum sequence length for this model (10666 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (12536 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (12297 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (13147 > 2048). Running this sequence through the model will result in indexing errors


In [15]:
# Looking at some tokens
pprint.pprint(tokenized_dataset['tokens'][0,:40])

# Checking that the tokens decode back to the correct part of the text
pprint.pprint(tokenizer.decode(tokenized_dataset['tokens'][0,:40]))

tensor([50256,  3198,  1110,    11,   257,  1310,  2576,  3706, 20037,  1043,
          257, 17598,   287,   607,  2119,    13,  1375,  2993,   340,   373,
         2408,   284,   711,   351,   340,   780,   340,   373,  7786,    13,
        20037,  2227,   284,  2648,   262, 17598,   351,   607,  1995,    11])
('<|endoftext|>One day, a little girl named Lily found a needle in her room. '
 'She knew it was difficult to play with it because it was sharp. Lily wanted '
 'to share the needle with her mom,')


In [None]:
BATCH_SIZE = 4
data_loader = DataLoader(tokenized_dataset,batch_size=BATCH_SIZE, shuffle=True,pin_memory=True,num_workers=4)
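Since the tokenized dataset returns rows as dicts holding a `tokens` tensor (which the indexing `tokenized_dataset['tokens'][0,:40]` above suggests), PyTorch's default collate function should stack them into a dict with a `[BATCH_SIZE, MAX_CTX_LENGTH]` tensor. A minimal sketch with a toy list of dicts standing in for the real dataset:

```python
import torch
from torch.utils.data import DataLoader

# Toy stand-in for tokenized_dataset: each item is a dict holding one row of 8 token ids
toy_rows = [{"tokens": torch.arange(i, i + 8)} for i in range(10)]

loader = DataLoader(toy_rows, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["tokens"].shape)  # torch.Size([4, 8])
```

In the training loop, `batch["tokens"]` is therefore what gets fed to the model.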

In [None]:
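# Creating a config dataclass that contains all the hyper-parameters
@dataclass
class Config:
    # These numbers are not from the standard implementation of GPT-2. Instead,
    # most numbers are much smaller due to training resource constraints. Only
    # d_vocab is the same as GPT-2, since the tokenizer depends on it.
    # Note: in GPT-2, n_heads * d_head == d_model. Here it need not, because
    # each head has its own W_O projecting back to d_model.
    d_model: int = 256
    d_head: int = 64
    n_heads: int = 5
    d_mlp: int = 1024
    d_vocab: int = 50257
    layer_norm_eps: float = 1e-5
    init_std: float = 0.02
    max_ctx: int = MAX_CTX_LENGTH
    batch_size: int = BATCH_SIZE
    # Number of transformer blocks (used by Transformer). The value is an
    # illustrative choice, not specified elsewhere in this notebook.
    num_blocks: int = 2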
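To get a feel for the size of this configuration, here is a rough parameter count that mirrors the modules defined below (per-head W_Q/W_K/W_V/W_O with biases, a two-layer MLP with biases, two LayerNorms per block, plus embeddings and unembedding). `count_params` is a hypothetical helper, and the number of blocks is an assumed value.

```python
def count_params(d_model=256, d_head=64, n_heads=5, d_mlp=1024,
                 d_vocab=50257, max_ctx=512, n_blocks=2):
    """Hypothetical helper: parameter count for the modules in this notebook."""
    embed = d_vocab * d_model                    # W_E
    pos_embed = max_ctx * d_model                # W_P
    attn = (3 * n_heads * d_model * d_head       # W_Q, W_K, W_V
            + 3 * n_heads * d_head               # b_Q, b_K, b_V
            + n_heads * d_head * d_model         # W_O
            + d_model)                           # b_O
    mlp = d_model * d_mlp + d_mlp + d_mlp * d_model + d_model  # two nn.Linear layers
    layer_norm = 2 * d_model                     # gains and bias
    block = attn + mlp + 2 * layer_norm          # ln1 and ln2
    unembed = d_model * d_vocab + d_vocab        # W_U, b_U
    return embed + pos_embed + n_blocks * block + layer_norm + unembed

per_block = count_params(n_blocks=1) - count_params(n_blocks=0)
print(per_block)                 # 855488
print(count_params(n_blocks=0))  # 25913425 (embeddings, final LayerNorm, unembedding)
```

With these settings the vocabulary-sized matrices W_E and W_U dominate: each block adds well under a million parameters, while the embedding and unembedding together hold about 26 million.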

### 1. Embedding and Unembedding layers

This includes the token embedding layer, the positional embedding and the final unembedding layer.

In [2]:
class EmbeddingLayer(nn.Module):

    """ 
        Embedding layer that takes as inputs a batch of 
        token and embeds them into vectors of size d_model.
        This is the first layer after the input.
    """

    def __init__(self, d_vocab, d_model, init_std = 0.02):
        super().__init__()
        self.W_E = nn.Parameter(torch.empty(d_vocab, d_model)) # embedding matrix
        nn.init.normal_(self.W_E, std=init_std)
        print(self.W_E.shape)

    def forward(self, input_tokens):
        # input_tokens are of size [batch_size, position] and are 
        # essentially integers that index the rows of W_E

        embedded = self.W_E[input_tokens,:]
        return embedded
    
class PositionEmbedding(nn.Module):
    """ 
    
        This layer produces the relevant positional information.
        This positional information is then added to the output of
        the Embedding layer resulting in the input to the first
        transformer block 
        
    """
    def __init__(self, max_ctx, d_model, init_std = 0.02):
        super().__init__()
        self.W_P = nn.Parameter(torch.empty(max_ctx, d_model))
        nn.init.normal_(self.W_P, std = init_std)
        print(self.W_P.shape)


    def forward(self, input_tokens):
        #input_tokens are of size [batch, position]
        pos_embed = repeat(self.W_P[:input_tokens.shape[1],:], "pos d_model -> batch pos d_model", batch = input_tokens.shape[0])
        print(pos_embed.shape)
        return pos_embed
    
class UnembeddingLayer(nn.Module):
    """ Unembedding layer that takes as inputs the ouput
        from the last transformer block and expands it back
        to a vector of size d_vocab, which are the logits that
        get passed on to the Softmax. 
    
    """
    def __init__(self, d_vocab, d_model, init_std):
        super().__init__()
        self.W_U = nn.Parameter(torch.empty(d_model,d_vocab))
        nn.init.normal_(self.W_U, std = init_std)
        self.b_U = nn.Parameter(torch.zeros(d_vocab,))

    def forward(self, resid_embed_last):
        # resid_embed_last corresponds to the residual stream
        # after the last transformer block. It is of size 
        # [batch_size, position, d_model] 

        logits = einsum(resid_embed_last, self.W_U,  "batch pos d_model, d_model d_vocab -> batch pos d_vocab") + self.b_U
        return logits
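A quick way to see what these layers do: indexing `W_E` with token IDs is the same as multiplying one-hot token vectors by `W_E`, and the positional embedding only depends on the position, so it can be broadcast over the batch (the `repeat` in `PositionEmbedding` makes this explicit). A minimal sketch with toy sizes and random matrices standing in for `W_E` and `W_P`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_model, max_ctx = 50, 8, 16          # toy sizes, not the notebook's Config
W_E = torch.randn(d_vocab, d_model) * 0.02
W_P = torch.randn(max_ctx, d_model) * 0.02
tokens = torch.randint(0, d_vocab, (2, 5))     # [batch, position]

# Row lookup, as in EmbeddingLayer.forward ...
lookup = W_E[tokens]
# ... equals multiplying one-hot token vectors by W_E
one_hot = F.one_hot(tokens, num_classes=d_vocab).float()
assert torch.allclose(lookup, one_hot @ W_E)

# Positional embeddings: the first seq_len rows of W_P, broadcast over the batch
pos = W_P[: tokens.shape[1]]                   # [position, d_model]
resid0 = lookup + pos                          # input to the first transformer block
print(resid0.shape)  # torch.Size([2, 5, 8])
```

The unembedding goes the other way: `resid @ W_U + b_U` turns each `d_model` vector into `d_vocab` logits, which is exactly what the `einsum` in `UnembeddingLayer.forward` computes.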



### 2. Layer Norm

Note that using the built-in nn.LayerNorm() is probably more convenient. But here, I have written my own, for learning purposes.

In [5]:
class LayerNorm(nn.Module):

    """ Layer Normalization. Effectively z-scores its input
        along the embedding dimension, then multiplies 
        each embedding dimension independently by a learnable
        gain and adds a learnable bias."""

    def __init__(self, d_model, init_std = 0.02, layer_norm_eps = 1e-5):
        super().__init__()
        # Gains start at 1 and biases at 0, so the layer is initially a plain
        # normalization. init_std is unused; it is kept so existing calls still work.
        self.gains = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.layer_norm_eps = layer_norm_eps

    def forward(self, input):
        # Input is the residual stream and is a tensor of 
        # size [batch, position, d_model]. The layer norm
        # subtracts the mean and divides by the standard deviation,
        # both computed across the d_model dimension
        layer_mean = input.mean(dim=-1, keepdim=True)
        input_centered = input - layer_mean
        # Biased variance (divide by d_model, not d_model - 1), as in nn.LayerNorm
        layer_var = input_centered.pow(2).mean(dim=-1, keepdim=True)
        layer_scale = torch.sqrt(layer_var + self.layer_norm_eps)
        input_normalized = input_centered / layer_scale
        output = input_normalized * self.gains + self.bias
        return output
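A from-scratch layer norm can be sanity-checked against PyTorch's `torch.nn.functional.layer_norm`. The sketch below re-implements the formula as a standalone function (`layer_norm_scratch` is a hypothetical name) and also shows why the variance must be the biased one: `torch.var` divides by `d_model - 1` by default, which no longer matches.

```python
import torch
import torch.nn.functional as F

def layer_norm_scratch(x, gains, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    centered = x - mean
    var = centered.pow(2).mean(dim=-1, keepdim=True)   # biased variance
    return centered / torch.sqrt(var + eps) * gains + bias

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 3, d_model)                         # [batch, position, d_model]
gains, bias = torch.ones(d_model), torch.zeros(d_model)
ref = F.layer_norm(x, (d_model,), weight=gains, bias=bias, eps=1e-5)

# Using torch.var's default (unbiased, divide by d_model - 1) instead
unbiased = (x - x.mean(-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, keepdim=True) + 1e-5)

print(torch.allclose(layer_norm_scratch(x, gains, bias), ref, atol=1e-5))  # True
print(torch.allclose(unbiased, ref, atol=1e-5))                            # False
```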

### 3. Self-Attention
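Before the multi-head version, a single-head sketch of causal scaled dot-product attention may help. Each query scores every key, scores above the diagonal (queries looking at later positions) are masked out, and a softmax over keys turns each row into weights that sum to 1. `causal_attention` is a hypothetical helper; it uses `-inf` for masking, whereas the class below uses a large finite negative value (`mask_val`), which gives practically the same result.

```python
import math
import torch

def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v have shape [batch, pos, d_head]."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])   # [batch, query_pos, key_pos]
    n = scores.shape[-1]
    # True above the diagonal: a query attending to a later (future) key
    future = torch.triu(torch.ones(n, n, dtype=torch.bool, device=scores.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    pattern = scores.softmax(dim=-1)                             # each row sums to 1
    return pattern @ v, pattern

torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
out, pattern = causal_attention(q, k, v)
print(pattern[0])  # lower-triangular; the first row is [1, 0, 0, 0]
```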

In [7]:
class SelfAttention(nn.Module):
    
    """ This module implements the self-attention mechanism with
     multiple heads assuming that there are a total of "n_heads" heads.
    The forward() method of this class provides the output of the
    attention module which is a weighted combination of value vectors,
    where the weights are the "attention weights" """

    def __init__(self, n_heads, d_model, d_head, mask_val = -1e5, init_std=0.02):
        super().__init__()
        self.W_Q = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        nn.init.normal_(self.W_Q,std = init_std)
        self.b_Q = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        self.W_K = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        nn.init.normal_(self.W_K,std = init_std)
        self.b_K = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        self.W_V = torch.nn.Parameter(torch.empty(n_heads, d_model, d_head))
        nn.init.normal_(self.W_V,std = init_std)
        self.b_V = torch.nn.Parameter(torch.zeros(n_heads, d_head))
        self.W_O = torch.nn.Parameter(torch.empty(n_heads, d_head, d_model))
        nn.init.normal_(self.W_O,std = init_std)
        self.b_O = torch.nn.Parameter(torch.zeros(d_model))
        self.register_buffer("mask_val",torch.tensor(mask_val,dtype = torch.float32))

    def forward(self, resid_pre):
        # resid_pre is of size [batch, position, d_model]
        
        # Computing the keys, queries and values 
        keys = einsum(resid_pre, self.W_K, "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head") + self.b_K
        queries = einsum(resid_pre, self.W_Q, "batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head") + self.b_Q
        values = einsum(resid_pre, self.W_V, "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head") + self.b_V
        
        # Computing the attention pattern
        attn_pattern = einsum(queries, keys, "batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos") 
        attn_pattern /= math.sqrt(self.W_Q.shape[-1])
        # Causal mask: True above the diagonal, i.e. where a query would attend to a later key.
        # It is created on the same device as the scores so this also works on a GPU.
        mask = torch.triu(torch.ones(attn_pattern.shape[-2], attn_pattern.shape[-1], device=attn_pattern.device),diagonal=1).bool()
        attn_pattern.masked_fill_(mask, self.mask_val)
        attn_pattern = attn_pattern.softmax(dim=-1)
        context_vec = einsum(attn_pattern, values, "batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch n_heads query_pos d_head")

        # In the original formulation of the transformer, one concatenates the outputs of the heads and then multiplies by a single W_O whose
        # row blocks are the individual W_O matrices of each head. Summing each head's own projection, as done below, is mathematically
        # equivalent; writing it per head simply makes the structure explicit for didactic clarity.

        # Output : sum over all attention heads (see comment above)
        output = einsum(context_vec, self.W_O, "batch n_heads query_pos d_head, n_heads d_head d_model -> batch query_pos d_model") + self.b_O
        return output
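The equivalence between "sum over heads of each head's output projection" and "concatenate heads, then one big matrix multiply" can be checked numerically. In this toy sketch, `z` stands in for `context_vec` and `W_O` has the same `[n_heads, d_head, d_model]` layout as in the class above; stacking the per-head matrices as row blocks gives the single concatenated projection.

```python
import torch

torch.manual_seed(0)
batch, pos, n_heads, d_head, d_model = 2, 3, 5, 4, 8   # toy sizes
z = torch.randn(batch, n_heads, pos, d_head)           # per-head outputs (like context_vec)
W_O = torch.randn(n_heads, d_head, d_model)            # per-head output projections

# Sum over heads of z_h @ W_O[h], as in the einsum above
summed = torch.einsum("bhpd,hdm->bpm", z, W_O)

# Concatenate heads along the feature axis, then multiply once
z_cat = z.permute(0, 2, 1, 3).reshape(batch, pos, n_heads * d_head)   # [batch, pos, n_heads*d_head]
W_O_cat = W_O.reshape(n_heads * d_head, d_model)                        # W_O[h] stacked as row blocks
concat = z_cat @ W_O_cat

print(torch.allclose(summed, concat, atol=1e-5))  # True
```

Since the einsum contracts the head and `d_head` axes in a single call, it should also cost about the same as the concatenated version.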


### 3. MLP

In [8]:
def gelu_new(input):
    """
    This is the tanh approximation of GELU used by GPT-2. PyTorch's default
    nn.GELU() uses the exact erf-based form instead; F.gelu(x, approximate="tanh")
    matches this function.
    """
    return 0.5*input* (1.0 + torch.tanh(math.sqrt(2.0/math.pi) * 
                        (input + 0.044715 *torch.pow(input, 3.0))))


class MLP(nn.Module):
    def __init__(self, d_mlp, d_model):
        super().__init__()
        self.input_layer = nn.Linear(d_model, d_mlp)
        self.output_layer = nn.Linear(d_mlp, d_model)
        self.act_fn = gelu_new

    def forward(self, resid_attended):
        pre_act = self.input_layer(resid_attended)
        act = self.act_fn(pre_act)
        output = self.output_layer(act)
        return output
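A quick check of the activation: the tanh-approximate GELU (with the 0.044715 coefficient) should match PyTorch's `F.gelu(..., approximate="tanh")`, and differ only slightly from the exact erf-based default. The function is re-defined here so the snippet runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def gelu_new(x):
    # tanh approximation of GELU, as used by GPT-2
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.linspace(-5, 5, steps=101)
print(torch.allclose(gelu_new(x), F.gelu(x, approximate="tanh"), atol=1e-5))  # True
print((gelu_new(x) - F.gelu(x)).abs().max())  # small but non-zero gap to the exact GELU
```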


### 4. Transformer Block

In [9]:
class TransformerBlock(nn.Module):

    def __init__(self, d_model, d_mlp, d_head, n_heads):
        super().__init__()
        self.ln1 = LayerNorm(d_model =d_model)
        self.attn = SelfAttention(n_heads, d_model, d_head)
        self.ln2 = LayerNorm(d_model =d_model)
        self.mlp = MLP(d_mlp, d_model) 

    def forward(self, resid_pre):
        # resid_pre is of size [batch, position, d_model]
        ln1_out = self.ln1(resid_pre)
        attn_out = self.attn(ln1_out)
        resid_mid = resid_pre +  attn_out
        ln2_out = self.ln2(resid_mid)
        resid_post = resid_mid + self.mlp(ln2_out)
        return resid_post

### 5. Full Transformer

In [None]:
class Transformer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.embed = EmbeddingLayer(config.d_vocab, config.d_model, config.init_std)
        self.pos_embed = PositionEmbedding(config.max_ctx, config.d_model, config.init_std)
        self.transformer = nn.ModuleList([TransformerBlock(config.d_model, config.d_mlp,
                                config.d_head, config.n_heads) for _ in range(config.num_blocks)])
        self.ln = LayerNorm(config.d_model, config.init_std, config.layer_norm_eps)
        self.unembed = UnembeddingLayer(config.d_vocab, config.d_model, config.init_std)


    def forward(self, tokens):
        embedded = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        resid = embedded + pos_embed
        for block in self.transformer:
            resid = block(resid)
        resid_normalized_final = self.ln(resid)
        logits = self.unembed(resid_normalized_final)

        return logits
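One useful property to test on the full model is causality: changing the last token must not change the logits at any earlier position. Below is a minimal sketch of such a check. `check_causal` is a hypothetical helper that works with any callable mapping `[batch, pos]` token IDs to `[batch, pos, d_vocab]` logits; here it is demonstrated on two toy models (a running mean, which is causal, and a reversed running mean, which is not). It could equally be pointed at an instance of the `Transformer` class above.

```python
import torch

def check_causal(model, d_vocab, seq_len=6, atol=1e-5):
    """Return True if changing the last token leaves logits at earlier positions unchanged."""
    g = torch.Generator().manual_seed(0)
    tokens = torch.randint(0, d_vocab, (1, seq_len), generator=g)
    perturbed = tokens.clone()
    perturbed[0, -1] = (tokens[0, -1] + 1) % d_vocab   # change only the final token
    with torch.no_grad():
        a, b = model(tokens), model(perturbed)
    return torch.allclose(a[:, :-1], b[:, :-1], atol=atol)

# Toy models for illustration only
emb = torch.nn.Embedding(10, 10)

def prefix_mean(t):   # position i averages tokens 0..i (causal)
    e = emb(t)
    return e.cumsum(dim=1) / torch.arange(1, t.shape[1] + 1).view(1, -1, 1)

def suffix_mean(t):   # position i averages tokens i..end (not causal)
    return prefix_mean(t.flip(1)).flip(1)

print(check_causal(prefix_mean, d_vocab=10))  # True
print(check_causal(suffix_mean, d_vocab=10))  # False
```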

### 6. Training

In [None]:
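def cross_entropy_lm(logits, tokens):
    # tokens of size [batch, pos]
    # logits of size [batch, pos, d_vocab]
    # The logits at position i predict the token at position i+1, so the last
    # logit and the first token have no partner and are dropped.
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = torch.gather(log_probs[:,:-1,:], dim=-1, index=tokens[:,1:].unsqueeze(dim=-1)).squeeze(dim=-1)
    # Cross-entropy is the *negative* mean log-probability of the correct next token,
    # averaged over batch and positions into a scalar suitable for loss.backward()
    return -pred_log_probs.mean()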
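The next-token loss can be cross-checked against PyTorch's `F.cross_entropy` by shifting logits and targets by one position and flattening. In this sketch, `lm_loss` is a hypothetical standalone version of the loss above; with all-zero (uniform) logits it should equal `log(d_vocab)`, a handy sanity check for an untrained model.

```python
import torch
import torch.nn.functional as F

def lm_loss(logits, tokens):
    # logits[:, i] predicts tokens[:, i + 1]: drop the last logit and the first token
    log_probs = logits[:, :-1].log_softmax(dim=-1)
    target_log_probs = log_probs.gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -target_log_probs.mean()

torch.manual_seed(0)
batch, pos, d_vocab = 2, 7, 11                 # toy sizes
logits = torch.randn(batch, pos, d_vocab)
tokens = torch.randint(0, d_vocab, (batch, pos))

ref = F.cross_entropy(logits[:, :-1].reshape(-1, d_vocab), tokens[:, 1:].reshape(-1))
print(torch.allclose(lm_loss(logits, tokens), ref))  # True
```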



