<a href="https://colab.research.google.com/github/DiiGii/gpt2-scratch/blob/main/gpt2_attention_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Attention transformer

In the following notebook, we will be implementing an attention transformer in PyTorch. We'll accomplish this by implementing each of the following:

1. LayerNorm

  Layer Normalization (**LayerNorm**), used in GPT-2, normalizes the activations of each token across all of its features. Batch Normalization, by contrast, normalizes each feature across the examples in a batch. GPT-2 applies LayerNorm at the input of each attention and feed-forward sub-block, plus once more after the last block. The simplified block in this notebook applies it after attention instead. In either placement, LayerNorm keeps activations on a consistent scale, which stabilizes training and usually makes it faster.

2. Attention mechanism
  
  The **attention mechanism** is the heart of GPT-2 because it addresses long-range dependencies, where words that are far apart in the text are related. Attention lets the model look directly at every earlier word when processing each word (GPT-2 cannot look ahead). It works like a spotlight that highlights the relevant words regardless of how far back they are.

3. Multi-head attention

  **Multi-head attention** extends the basic attention mechanism by running several attention calculations in parallel. Each "head" can learn to attend to a different aspect of the input. For example, in a sentence, one head might focus on the action while another focuses on the description of the subject. Running several heads lets the model capture more kinds of relationships at once. Because each head works in a smaller `n_embd / n_head` subspace, this costs little more than a single full-width head.

4. Bigram Attention Model

  The **BigramAttentionModel** (the `BigramLanguageModel` class in the code) adds token embeddings, which represent character identity, to positional embeddings, which represent order. It then feeds the sum through a stack of Transformer blocks with multi-head attention. A final LayerNorm normalizes the output, and a linear layer maps it to scores over the vocabulary. Training minimizes cross-entropy loss. Generation predicts one character at a time and feeds each sampled character back in. In this step, we combine all the previous pieces and test the model.

# Setup

We first have to import the PyTorch libraries to build the Transformer model.

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Can import from attention module in week4

Here we define the hyperparameters for the model and training.

In [None]:
### Hyperparameters
batch_size = 16 # how many independent sequences we process in parallel
block_size = 32 # maximum context length (in characters) for predictions
max_iters = 5000 # number of training iterations
eval_interval = 100 # evaluate the loss every eval_interval iterations
learning_rate = 1e-3 # optimizer step size
device = 'cuda' if torch.cuda.is_available() else 'cpu' # device that holds the model and tensors
eval_iters = 200 # number of batches averaged per split when estimating the loss
n_embd = 64 # embedding dimension of each token
n_head = 4 # number of attention heads working in parallel
n_layer = 4 # number of Transformer blocks stacked in the model
dropout = 0.0 # dropout probability: chance that an activation is zeroed during training

Here we load in our dataset to be used by the Transformer Model we are creating.

In [None]:
### Preparing Data
torch.manual_seed(1337)

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

--2025-01-02 03:55:19--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-01-02 03:55:19 (18.8 MB/s) - ‘input.txt’ saved [1115394/1115394]
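
To see what `encode` and `decode` do, here is a minimal sketch of the same character-level scheme on a hypothetical toy string. The `_demo` names are made up for illustration so that they do not overwrite the notebook's `chars`, `stoi`, and `itos`.

```python
# Toy version of the notebook's character-level tokenizer (hypothetical example text)
text_demo = "hello world"
chars_demo = sorted(set(text_demo))  # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
stoi_demo = {ch: i for i, ch in enumerate(chars_demo)}  # character -> integer id
itos_demo = {i: ch for i, ch in enumerate(chars_demo)}  # integer id -> character

ids = [stoi_demo[c] for c in "hello"]         # encode
decoded = ''.join(itos_demo[i] for i in ids)  # decode
print(ids)      # [3, 2, 4, 4, 5]
print(decoded)  # hello
```

Because the vocabulary is sorted, the space (which sorts before the letters) gets id 0.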



These are miscellaneous functions that will be used in our model training and evaluation.

The first function, `get_batch`, samples a batch of training examples. Each batch contains `batch_size` windows of `block_size` characters, starting at random positions in the dataset, and each training iteration uses one batch. The input `x` is a window of the text. The target `y` is the same window shifted one character to the right, so the target at every position is the character that comes next.
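A minimal sketch of the one-step shift, using a hypothetical toy tensor (`toy_data`, `toy_block_size`, and the fixed offset `i` are illustrative stand-ins for `train_data`, `block_size`, and one value drawn by `torch.randint`):

```python
import torch

# Hypothetical toy data standing in for train_data: ids 0..9 play the role of encoded characters.
toy_data = torch.arange(10)
toy_block_size = 4  # plays the role of block_size
i = 2               # one start offset, as torch.randint would pick at random

x = toy_data[i:i + toy_block_size]          # input window
y = toy_data[i + 1:i + toy_block_size + 1]  # same window shifted one step to the right
print(x.tolist())  # [2, 3, 4, 5]
print(y.tolist())  # [3, 4, 5, 6]

# One window yields toy_block_size training examples: each prefix predicts the next id.
for t in range(toy_block_size):
    print(f"context {x[:t + 1].tolist()} -> target {y[t].item()}")
```

A single window therefore provides `block_size` next-character predictions. This is why the loss later compares the logits at every position with the target at that same position.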

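The second function, `estimate_loss`, measures how well the model is doing. For both the training and validation splits, it puts the model in evaluation mode and averages the loss over `eval_iters` random batches. It runs under `@torch.no_grad()`, so no gradients are computed and no weights are updated.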

In [None]:
### Miscellaneous Functions
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    # estimate the average loss for each data split for evaluation
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# LayerNorm

Now we build LayerNorm from scratch.

![The equation we use for LayerNorm](https://miro.medium.com/v2/resize:fit:1040/0*qN-QGSHiY85obQfj)

LayerNorm computes the mean and variance of each input vector x across its features and uses them to normalize it. Subtracting the mean and dividing by the standard deviation (the square root of the variance) gives values with a mean of 0 and a variance of 1. This does not make the values normally distributed. It does keep the inputs to every layer on a consistent scale, so a few unusually large activations cannot throw off later computations. That helps prevent exploding or vanishing gradients and makes training more stable. To avoid dividing by 0, we add epsilon, a small constant, to the denominator.

Next come the scale and shift parameters. Gamma (`a_2` in the code) is the scale. It multiplies the normalized values, so the output has a standard deviation of |gamma| instead of 1; in other words, it controls how spread out the values are. Beta (`b_2`) is the shift. It is added afterwards, so the output is centered at beta instead of 0. In the standard formulation, both are vectors with one entry per feature, and they are learned during training. This lets the model choose whatever scale and offset suit the data, including undoing the normalization if that helps.
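As a quick numerical check of the equation, here is a minimal sketch with made-up values (`x_ln`, `ln_eps`, `gamma`, and `beta` are illustrative names). It normalizes each feature vector by hand, applies a scale and shift, and compares the result with PyTorch's `torch.nn.functional.layer_norm`. It follows the standard formulation, with the biased variance and epsilon inside the square root.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_ln = torch.randn(2, 3, 8)  # hypothetical (batch, time, features) input
ln_eps = 1e-5

# Normalize each token's feature vector with its own mean and (biased) variance.
mean = x_ln.mean(-1, keepdim=True)
var = ((x_ln - mean) ** 2).mean(-1, keepdim=True)
x_hat = (x_ln - mean) / torch.sqrt(var + ln_eps)

print(x_hat.mean(-1).abs().max())    # ~0: every vector is centered
print((x_hat ** 2).mean(-1).mean())  # ~1: every vector has unit variance

# Scale (gamma) and shift (beta), one value per feature; fixed here, learned in a real model.
gamma = torch.full((8,), 2.0)
beta = torch.full((8,), 0.5)
y_ln = gamma * x_hat + beta  # each vector now has mean 0.5 and standard deviation ~2

# PyTorch's built-in functional layer_norm computes the same result.
ref = F.layer_norm(x_ln, (8,), weight=gamma, bias=beta, eps=ln_eps)
print(torch.allclose(y_ln, ref, atol=1e-5))  # True
```

The `LayerNorm` class below instead adds epsilon to the unbiased standard deviation (`x.std`). Its output is therefore close to, but not bit-for-bit identical with, `F.layer_norm`.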

In [None]:
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
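        # initialize the learnable parameters of the LayerNorm equation:
        # scale (gamma) starts at 1 and shift (beta) at 0, one value per feature
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))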
        self.eps = eps

    def forward(self, x):
        # implement LayerNorm based on the equation above
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [None]:
#@title Solutions (no peeking!)
def hidden_layer_norm_forward(x, a_2, b_2, eps):
    # implement LayerNorm based on the equation above
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)
    return a_2 * (x - mean) / (std + eps) + b_2

In [None]:
eps = 1e-6
# use a separate name so this test does not overwrite the batch_size hyperparameter
test_batch_size = 32
sequence_length = 10
hidden_size = 64
test_input = torch.randn(test_batch_size, sequence_length, hidden_size)

test = LayerNorm(hidden_size, eps)

truth = hidden_layer_norm_forward(test_input, test.a_2, test.b_2, eps)
user = test.forward(test_input)

assert torch.allclose(truth, user), "oh no there was a mismatch between your implementation and ours!"
print("Passed all tests!")

Passed all tests!


# Attention

Next, let's build the attention mechanism for our Transformer. This is the core component that makes Transformers so powerful.

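First, we define the attention head, where the attention computation takes place. In it we initialize the key, query, and value projection layers. These project the inputs into another subspace, which you can think of as a different representation of the information each token carries. The projections are learned during training.

We also register a buffer called `tril`: a `block_size` by `block_size` lower-triangular matrix of ones. Row t of `tril` has ones in columns 0 through t. We use it to mask out the attention scores for future positions, so each token can only attend to itself and to earlier tokens. This causal mask lets the model be trained to predict the next character without seeing it.

Finally, there is dropout, a regularization technique that randomly zeroes some attention weights during training to reduce overfitting.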

In the attention mechanism, we first compute the key, query, and value vectors by passing the input through these linear layers. For example, suppose the input is a sentence. The sentence is split into tokens (single characters in this notebook), and each token is represented by an embedding. The projection layers map each embedding to three vectors:

- a query, which describes what this token is looking for;
- a key, which describes what this token offers to others;
- a value, which is the information it passes along when attended to.

The direction and magnitude of these vectors encode information about their token. The dot product of one token's query with another token's key measures how relevant the second token is to the first. After a softmax, these scores weight a sum of the value vectors.

We can use Week 4's attention notebook and implementation here.

![Here is the attention mechanism equation](https://miro.medium.com/v2/resize:fit:1400/0*4L90D4iDB_R1Uljs)
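To see what the `tril` mask does, here is a small sketch with a hypothetical sequence of 4 positions. It starts from equal raw scores, so the effect of masking and softmax is easy to read off.

```python
import torch
import torch.nn.functional as F

T = 4  # a hypothetical sequence length
# Lower-triangular matrix: row t has ones in columns 0..t and zeros afterwards.
tril = torch.tril(torch.ones(T, T))

scores = torch.zeros(T, T)                                # equal raw scores everywhere
scores = scores.masked_fill(tril == 0, float('-inf'))     # block attention to future positions
weights = F.softmax(scores, dim=-1)                       # exp(-inf) = 0, so future weights are 0
print(weights)
# rows: [1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0], [1/4, 1/4, 1/4, 1/4]

# Multiplying by the values then averages each position's visible past.
v = torch.arange(T, dtype=torch.float32).unsqueeze(-1)  # values 0, 1, 2, 3 (one feature)
print((weights @ v).squeeze(-1))  # running means: 0.0, 0.5, 1.0, 1.5
```

In a real head, the scores come from `q @ k.transpose(-2, -1)` and are not equal. Each row therefore becomes a learned weighting of the visible positions rather than a plain average. The mask still guarantees that the weights on future positions are exactly zero.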

In [None]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular causal mask; a buffer moves with the module but is not trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B, T, hs) where hs = head_size
        q = self.query(x) # (B, T, hs)
        # compute attention scores ("affinities"), scaled by 1/sqrt(d_k) with d_k = head_size
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # hide future positions
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B, T, hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
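
The attention equation divides the scores by the square root of the key dimension. A minimal sketch with random, hypothetical queries and keys shows why. When the entries have unit variance, the variance of a dot product grows with `head_size`, and the scaling brings it back to about 1. This keeps the softmax from becoming overly peaked at the start of training.

```python
import torch

torch.manual_seed(0)
hs = 16  # hypothetical head size (n_embd // n_head in this notebook)
q = torch.randn(4, 256, hs)  # (B, T, hs) queries with unit-variance entries
k = torch.randn(4, 256, hs)  # (B, T, hs) keys

raw = q @ k.transpose(-2, -1)  # (B, T, T) unscaled scores
scaled = raw * hs**-0.5        # divide by sqrt(head_size)
print(raw.var().item())     # roughly 16 (= hs)
print(scaled.var().item())  # roughly 1
```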

# Multi Head Attention

After building the attention mechanism, the engine of the Transformer, we need a few more pieces to complete the architecture. The first is MultiHeadAttention, which runs several attention heads in parallel and concatenates their outputs. Each head has its own key, query, and value projections, so each one can learn a different function and pick up different relationships in the input. For example, in a sentence, one head might track what the subject is doing while another tracks how the subject is described. This is especially useful in longer sequences, where many different relationships between tokens matter at once.

![Diagram of MultiHeadAttention](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcR2H71D22diHDQKf6STcbHbRgvdynJ_c0RZZA&s)

In [None]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # initialize all the heads for the MultiHeadAttention module (hint: use ModuleList)
        self.heads = torch.nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        # feed x through all the attention heads
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return out
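
The `MultiHeadAttention` class above runs each `Head` separately and concatenates their outputs. Many implementations compute the same kind of result with a single fused projection followed by a reshape. Here is a minimal standalone sketch of that alternative. The fused `qkv` layer and the sizes are illustrative assumptions, not this notebook's code. The sketch shows how each head ends up working on its own `n_embd // n_head` slice of the channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_model, num_heads = 2, 5, 64, 4  # hypothetical sizes (like n_embd = 64, n_head = 4)
hs = d_model // num_heads               # 16 channels per head

x = torch.randn(B, T, d_model)
qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # hypothetical fused Q/K/V projection

q, k, v = qkv(x).split(d_model, dim=-1)  # each (B, T, d_model)
# Give every head its own slice of the channels: (B, T, d_model) -> (B, num_heads, T, hs)
q = q.view(B, T, num_heads, hs).transpose(1, 2)
k = k.view(B, T, num_heads, hs).transpose(1, 2)
v = v.view(B, T, num_heads, hs).transpose(1, 2)

mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # causal mask shared by all heads
wei = (q @ k.transpose(-2, -1)) * hs**-0.5              # (B, num_heads, T, T)
wei = wei.masked_fill(~mask, float('-inf')).softmax(dim=-1)
out = wei @ v                                           # (B, num_heads, T, hs)

# Concatenate the heads along the channel dimension, like torch.cat([...], dim=-1)
out = out.transpose(1, 2).contiguous().view(B, T, d_model)
print(out.shape)  # torch.Size([2, 5, 64])
```

The fused form replaces the Python loop over heads with a few large batched matrix multiplications, which is usually faster on a GPU.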


Next, we put it all together into a Transformer block. First, we initialize the MultiHeadAttention module used in the block. Each head works in a subspace of size `n_embd // n_head` (here 64 // 4 = 16), so concatenating the heads' outputs gives back a vector of size `n_embd`. More heads means more attention patterns computed in parallel. However, each head then works in a smaller subspace and can represent less information on its own.

We also initialize a LayerNorm to normalize the output of the MultiHeadAttention module. A standard Transformer block would also include a feed-forward layer, which we will cover in more detail later.

The forward function applies the MultiHeadAttention module to the input and then passes the result through the LayerNorm.

In [None]:
class Block(nn.Module):
    """ attention-only Transformer block: self-attention followed by LayerNorm """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        # initialize the components of an attention-only Block using MultiHeadAttention
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ln1 = LayerNorm(n_embd)

    def forward(self, x):
        # pass the input through the initialized components
        x = self.sa(x)
        x = self.ln1(x)
        return x
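
For comparison with the attention-only `Block`, here is a minimal sketch of what a fuller GPT-2-style block might look like. It is an illustration built on assumptions, not this notebook's design:

- it uses PyTorch's built-in `nn.LayerNorm` and `nn.MultiheadAttention`;
- it applies LayerNorm before each sub-layer (pre-norm);
- it adds the feed-forward layer mentioned above, with a 4x hidden size and GELU;
- it wraps each sub-layer in a residual connection (`x + ...`), which helps gradients flow through deep stacks.

The class name `SketchBlock` is hypothetical.

```python
import torch
import torch.nn as nn

class SketchBlock(nn.Module):
    """Hypothetical pre-norm block: attention and feed-forward, each with a residual."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x):
        T = x.size(1)
        # For nn.MultiheadAttention, True marks positions that may NOT be attended to (the future).
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + attn_out                # residual around attention ("communication")
        x = x + self.ffwd(self.ln2(x))  # residual around feed-forward ("computation")
        return x

torch.manual_seed(0)
block = SketchBlock(n_embd=64, n_head=4)
x_blk = torch.randn(2, 8, 64)
out_block = block(x_blk)
print(out_block.shape)  # torch.Size([2, 8, 64])
```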

# BigramAttentionModel

Finally, we put the blocks together to create the BigramAttentionModel, implemented below as the `BigramLanguageModel` class. First there is `token_embedding_table`, a lookup table that maps each token (character) index to a learned vector of size `n_embd`. These embeddings are numerical representations of the tokens, and training adjusts them so they carry useful information about the characters they represent.

![Here is a diagram of embeddings. As you can see words are mapped to vectors where more related words have more similar vectors](https://arize.com/wp-content/uploads/2022/06/blog-king-queen-embeddings.jpg)

Next, `position_embedding_table` encodes the position of each token in the input. It is another learned lookup table, with one row for each position from 0 to `block_size - 1`. It produces vectors of the same size as the token embeddings, so the two can be added together. This tells the model where each token sits in the sequence, which it can use to reason about order and about the distances between tokens.

![The original Transformer uses sine and cosine functions to create positional encodings; GPT-2, like this notebook, instead learns a table of position embeddings. Here is how positional encodings are used within an LLM.](https://miro.medium.com/v2/resize:fit:1400/0*oiP-eu8BmJx5SVp7.png)
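Here is a minimal sketch of the two lookup tables and of how their outputs are combined. It uses sizes like this notebook's: 65 characters, a context of 32, and 64-dimensional embeddings. The `demo_` names and the token ids are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
demo_vocab, demo_block, demo_embd = 65, 32, 64
tok_table = nn.Embedding(demo_vocab, demo_embd)  # one learned vector per character id
pos_table = nn.Embedding(demo_block, demo_embd)  # one learned vector per position 0..31

idx = torch.tensor([[20, 43, 50, 50, 53]])  # (B=1, T=5) hypothetical character ids
B, T = idx.shape
tok_emb = tok_table(idx)              # (1, 5, 64)
pos_emb = pos_table(torch.arange(T))  # (5, 64)
x_emb = tok_emb + pos_emb             # (5, 64) broadcasts over the batch -> (1, 5, 64)
print(tok_emb.shape, pos_emb.shape, x_emb.shape)

# Positions 2 and 3 hold the same id (50), so their token embeddings match,
# but adding different position embeddings makes their inputs differ.
print(torch.equal(tok_emb[0, 2], tok_emb[0, 3]), torch.equal(x_emb[0, 2], x_emb[0, 3]))  # True False
```

`pos_table` only has `demo_block` rows, so asking for position 32 or beyond would raise an index error. This is why generation crops the context to the last `block_size` tokens.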

Next come the Transformer blocks, which are the workhorse of the model. Each block contains the multi-head attention mechanism.

Finally, a LayerNorm normalizes the output of the blocks. A linear layer then maps each embedding to `vocab_size` scores, one per character, so the output can be decoded character by character into text.

The forward function applies these modules to the input in order and, if targets are provided, computes the loss. The loss is the cross-entropy between the predicted logits for the next character and the actual next character. Logits are unnormalized scores; applying softmax to them gives a probability distribution over the vocabulary. `F.cross_entropy` expects logits of shape (N, C) and targets of shape (N), so we merge the batch and time dimensions before computing the loss. The shapes of the tensors are given in the comments.
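Here is a minimal sketch of the reshape, with hypothetical sizes. A model that assigns equal logits to all 65 characters predicts a uniform distribution, so its cross-entropy loss is exactly ln 65 ≈ 4.17, whatever the targets are.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 65                     # hypothetical batch size, context length, vocabulary size
logits = torch.zeros(B, T, V)          # equal scores for every character -> uniform softmax
targets = torch.randint(0, V, (B, T))  # arbitrary next-character ids

# F.cross_entropy expects (N, C) logits and (N,) class indices, so merge batch and time.
loss = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))
print(round(loss.item(), 4), round(math.log(V), 4))  # 4.1744 4.1744
```

The training log later in this notebook starts at about 4.35 at step 0, close to this uniform baseline. The loss then falls well below it as the model learns which characters are likely.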

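We also have a `generate` function for inference. Starting from a context, it predicts a distribution over the next character, samples a character from it, and appends that character to the sequence. It repeats this to produce `max_new_tokens` new characters. Because the position table only has `block_size` rows, the context is cropped to the last `block_size` tokens before each prediction.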
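Here is a minimal sketch of the same sampling loop. It uses a hypothetical stand-in for the trained model so that it runs on its own: `toy_model` looks up a fixed random table of next-token scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
toy_vocab, toy_context = 5, 4                    # hypothetical vocabulary size and block_size
next_scores = torch.randn(toy_vocab, toy_vocab)  # hypothetical next-token logits per current token

def toy_model(idx):
    # Stand-in for the model: (B, T) ids -> (B, T, toy_vocab) logits
    return next_scores[idx]

idx = torch.zeros((1, 1), dtype=torch.long)  # start from token 0, as the notebook does
for _ in range(10):
    idx_cond = idx[:, -toy_context:]          # crop to the last toy_context tokens
    logits = toy_model(idx_cond)[:, -1, :]    # keep only the last time step: (B, toy_vocab)
    probs = F.softmax(logits, dim=-1)         # turn scores into probabilities
    idx_next = torch.multinomial(probs, num_samples=1)  # sample one id: (B, 1)
    idx = torch.cat((idx, idx_next), dim=1)   # append it: (B, T + 1)
print(idx.shape)  # torch.Size([1, 11])
```

Sampling with `torch.multinomial`, rather than always taking the most likely character, is what makes each generated passage different.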

## Depth of Model

A major hyperparameter to consider is the depth of the model. To create a deeper model, we can stack more Transformer blocks (increase `n_layer`). But what is the advantage of that?

In many settings, stacking more layers and adding more parameters leads to better results, as the growing size and capabilities of LLMs show. However, simply stacking more blocks brings its own difficulties.

More blocks mean more computation, which in turn means more GPUs and other hardware to train and run the model. This is one reason why only well-funded companies can build the largest models. Deeper models are also harder to train because their gradients can become unstable. Backpropagation starts at the output layer and works backwards toward the input, and the gradient can shrink (or grow) at every layer it passes through. In a deep stack, this leads to vanishing (or exploding) gradients, so the early layers learn very slowly or training diverges.
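Here is a minimal sketch of the vanishing-gradient effect. It uses a hypothetical deep stack of plain linear layers with sigmoid activations and no residual connections, not the Transformer blocks in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 30, 64  # hypothetical depth and width
layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])

h = torch.randn(16, width)
for layer in layers:
    h = torch.sigmoid(layer(h))  # the derivative of sigmoid is at most 0.25
h.sum().backward()               # backpropagation starts at the output...

first = layers[0].weight.grad.norm().item()  # ...and reaches the first layer last
last = layers[-1].weight.grad.norm().item()
print(f"first layer: {first:.1e}, last layer: {last:.1e}")  # the first is far smaller
```

Each layer multiplies the backward signal by a factor well below 1, so after 30 layers almost nothing reaches the first layer. Standard Transformers counter this with residual connections, which give gradients a direct path back through every block.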

In [None]:
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # initialize the token_embedding table
        self.token_embedding_table = torch.nn.Embedding(vocab_size, n_embd)
        # initialize the position embedding table
        self.position_embedding_table = torch.nn.Embedding(block_size, n_embd)

        # initialize the blocks, the stacked attention layers of our model
        self.blocks = torch.nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        # initialize the final layer norm and the projection layer that predicts the next character
        self.ln_f = LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        # feed the idx through the model and initialized parameters
        '''
        1. First create embeddings for the input using the token_embedding_table
        2. Then get the position embeddings (you can use torch.arange)
        3. Add the position embeddings and token_embeddings
        4. Feed the sum through the blocks
        5. Apply the final LayerNorm
        6. Feed the result through the last linear layer to get the logits
           (unnormalized scores for the next character)

        MAKE SURE TO PAY ATTENTION TO DIMENSIONS
        '''
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # create the positions on the same device as idx, so CPU inputs also work
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)


        ### ------------------- ###
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # (B*T, C)
            targets = targets.view(B*T) # (B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [None]:
#@title Bigram Solution (don't look, pretty please?)
def hidden_bigram_forward(idx, token_embedding_table, position_embedding_table, blocks, ln_f, lm_head, targets=None):
  B, T = idx.shape

  # idx and targets are both (B,T) tensor of integers
  tok_emb = token_embedding_table(idx) # (B,T,C)
  pos_emb = position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
  x = tok_emb + pos_emb # (B,T,C)
  x = blocks(x) # (B,T,C)
  x = ln_f(x) # (B,T,C)
  logits = lm_head(x) # (B,T,vocab_size)

  if targets is None:
      loss = None
  else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C) # (B*T, C)
      targets = targets.view(B*T) # (B*T)
      loss = F.cross_entropy(logits, targets)

  return logits, loss

In [None]:
test_input = torch.randint(0, 20, (32, 10))

user = BigramLanguageModel()
user_output = user.forward(test_input)
truth = hidden_bigram_forward(test_input, user.token_embedding_table, user.position_embedding_table, user.blocks, user.ln_f, user.lm_head)

assert torch.allclose(truth[0], user_output[0]), "oh no there was a mismatch between your implementation and ours!"
print("Passed all tests!")

Passed all tests!


# Training Module

This is the training module. Here we initialize the model, print its number of parameters, and then run a standard training loop. Every `eval_interval` iterations, the loop evaluates the loss on the training and validation sets.

In [None]:
### Model Training and Initialization
# initialize model and set it to device
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss: feed the input and targets into the model
    logits, loss = model(xb, yb)
    # backpropagate, update the weights, then reset the gradients for the next step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

0.059585 M parameters
step 0: train loss 4.3461, val loss 4.3383
step 100: train loss 3.3137, val loss 3.3482
step 200: train loss 3.3002, val loss 3.3313
step 300: train loss 3.1640, val loss 3.1847
step 400: train loss 2.8201, val loss 2.8168
step 500: train loss 2.6066, val loss 2.6042
step 600: train loss 2.5183, val loss 2.5152
step 700: train loss 2.4581, val loss 2.4591
step 800: train loss 2.4103, val loss 2.4033
step 900: train loss 2.3697, val loss 2.3660
step 1000: train loss 2.3482, val loss 2.3470
step 1100: train loss 2.3075, val loss 2.3206
step 1200: train loss 2.2677, val loss 2.2827
step 1300: train loss 2.2532, val loss 2.2652


KeyboardInterrupt: 

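This code uses the model to generate a Shakespeare-style passage, since the model was trained on Shakespeare's plays. Generation starts from a context containing the single token 0, which is the newline character in this vocabulary.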

In [None]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))


UCOFLERS:
Ny silliowc o themn. the soow ith to setple os sertprert on wath te?

Thet it pretr eane metol came.

DUET:
Serdeter
Doft, nif manct healoter;
Seres icon! sever wol nott, to tafe malle wiilot jel, Thel weuares reiglbleirns hopecenceld-elf hiters-shantht nuks, nast in sole heeae latexce: neutd;
Hn le ftrackes thevee ive lyoo theat as rikwe dinen: youl fuver rtize isimshomelsts; t,
Thont coteark. he ist bract exegt ilpriche mole on gilll
De tith as mant thee ney! tHe tracon me kin,
'ct lat Catinthtul for neat.

LALD:
Whde illen appraks wius and iI t.o gonf, propend the wepn mathakin yeoryat, o a,
Serawh thene, watuboldave for ox; pollee;
I'cl martatere thay wikes mad of,e wnow of dere atth cille of bentect.

GBERKEEAR:
And to hex.

The giubpel fvaen thoul thint,
'De me tar'ss hy; perothep,
Bender selvucpe; fue trsad'd natprind; leee, ad dee of nave;
ReNat of feriing goy, thre rves we, to olikord.
-Thanble stors,
Y'tr bace of frercand wwid tthay, wnod dloldef forke seari, thee 

What do you notice about the passage?

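The model picks up the overall layout of a Shakespeare play: speaker names in capital letters followed by a colon, short lines, and blank lines between speeches. However, most of the words are misspelled or invented, and the names are not real characters.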

In particular, some lines start with real words that are common in the plays, such as "And" and "The". The problems appear when the rest of the sentence is constructed.

Note that training with the current hyperparameters can take a while. You can reduce them, but decreasing the number of iterations or stopping the training loop early (the run above was interrupted after step 1300) will lower the quality of the generated text.