#Import libraries

In [None]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

# Print all activation shapes of reference model (for debugging)

In [None]:
# Debug aid: walk the reference model's activation cache and show each entry's
# shape. Entries that belong to a transformer block are skipped unless their
# name contains ".0." (presumably block 0 — the other blocks repeat the same
# shapes), which keeps the listing short.
for activation_name, activation in cache.cache_dict.items():
    if "blocks" in activation_name and ".0." not in activation_name:
        continue
    print(activation_name, activation.shape)

In [None]:

@dataclass
class Config:
    """Hyperparameters shared by every module of the demo transformer."""

    # Width of the residual stream: size of each token's representation.
    d_model: int = 768
    # When True, modules print the shapes of their inputs/outputs.
    debug: bool = True
    # Epsilon used by layer normalization to avoid dividing by zero.
    layer_norm_eps: float = 1e-5
    # Vocabulary size: number of distinct token ids.
    d_vocab: int = 50_257
    # Standard deviation of the normal distribution used to initialize weights.
    init_range: float = 0.02
    # Maximum sequence length (number of positions with a learned embedding).
    n_ctx: int = 1_024
    # Size of the query/key/value vectors of a single attention head.
    d_head: int = 64
    # Width of the hidden layer of the MLP in every block.
    d_mlp: int = 3_072
    # Attention heads per transformer block.
    n_heads: int = 12
    # Number of transformer blocks stacked in the model.
    n_layers: int = 12


# Default configuration used by the cells below.
cfg = Config()
print(cfg)


#LayerNorm

1. Make mean 0

2. Normalize to have variance 1

3. Scale with learned weights

4. Translate with learned bias





In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last (d_model) dimension.

    Each vector along the last dimension is centered to mean 0, scaled to
    variance 1, then multiplied by the learned weight ``w`` and shifted by the
    learned bias ``b`` (both of size ``cfg.d_model``).
    """

    def __init__(self, cfg):
        super().__init__()
        # Keep the config on the module so forward() does not depend on a
        # module-level `cfg` variable.
        self.cfg = cfg
        # Learnable scale, initialized to ones (identity scaling).
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        # Learnable bias, initialized to zeros (no shift).
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        """Return the normalized tensor; the shape is the same as ``residual``."""
        if self.cfg.debug:
            print("Residual:", residual.shape)
        # Center: subtract the mean over the last dimension. keepdim=True keeps
        # a trailing axis of size 1 so the result broadcasts against `residual`.
        residual = residual - residual.mean(dim=-1, keepdim=True)
        # Standard deviation of the centered values. The epsilon is added to
        # the variance *before* the square root so the divisor is never zero.
        scale = (residual.pow(2).mean(dim=-1, keepdim=True) + self.cfg.layer_norm_eps).sqrt()
        normalized = residual / scale
        # Apply the learned elementwise scale and bias.
        normalized = normalized * self.w + self.b
        if self.cfg.debug:
            print("Normalized:", normalized.shape)
        return normalized

#Embedding


In [None]:
class Embed(nn.Module):
    """Token embedding: a learned lookup table from token id to a d_model vector."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Embedding matrix W_E with one row of size d_model per vocabulary entry.
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        # Fill it with normally distributed values (std = cfg.init_range).
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    # NOTE: forward must be a method of the class (not nested inside __init__),
    # otherwise nn.Module cannot dispatch calls to it.
    def forward(self, tokens):
        """Return the rows of W_E selected by ``tokens``.

        ``tokens`` is an integer tensor of token ids; the result has the shape
        of ``tokens`` with an extra trailing dimension of size d_model.
        """
        if self.cfg.debug:
            print("Tokens:", tokens.shape)
        # Index the embedding matrix with the token ids.
        embed = self.W_E[tokens, :]
        if self.cfg.debug:
            print("Embeddings:", embed.shape)
        return embed

#Positional Embedding

In [None]:
class PosEmbed(nn.Module):
    '''A lookup table for positional embeddings, providing an embedding vector for each position in the sequence.'''

    def __init__(self, cfg):
        super().__init__()
        # Store the configuration (context length, embedding size, init range, debug flag).
        self.cfg = cfg
        # Learnable positional embedding matrix W_P of shape (n_ctx, d_model):
        # one d_model-sized vector for every position up to the context length.
        self.W_P = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        # Initialize from a normal distribution with std = cfg.init_range.
        nn.init.normal_(self.W_P, std=self.cfg.init_range)

    def forward(self, tokens):
        """Return positional embeddings of shape (batch, position, d_model).

        Only the shape of ``tokens`` (batch, position) is used; the token
        values themselves do not influence the result.
        """
        if self.cfg.debug:
            print("Tokens:", tokens.shape)
        # Keep only the embeddings for the positions present in this input:
        # shape (position, d_model).
        pos_embed = self.W_P[:tokens.size(1), :]
        # Add a leading batch axis and copy the table once per sequence in the
        # batch, giving (batch, position, d_model).
        pos_embed = pos_embed.unsqueeze(0).repeat(tokens.size(0), 1, 1)
        if self.cfg.debug:
            print("pos_embeddings:", pos_embed.shape)
        return pos_embed


#Attention

**Step 1**: Produce an attention pattern – for each destination token, a probability distribution over previous tokens (incl. the current token)
* Linear map from input -> query, key, each of shape [batch, position, head_index, d_head]
* Dot product every *pair* of queries and keys to get attn_scores [batch, head_index, query_pos, key_pos] (query = dest, key = source)
* Scale and mask attn_scores to make it lower triangular, i.e. causal
* Softmax row-wise, to get a probability distribution along the key_pos dimension - the attention pattern!

**Step 2**: Move information from source tokens to destination token using attention pattern (move=apply linear map)
* Linear map from input -> value [batch, key_pos, head_index, d_head]
* Mix along the key_pos with attn pattern to get z, a mixed value [batch, query_pos, head_index, d_head]
* Map to output, [batch, position, d_model] (position = query_pos, we've summed over all heads)

In [None]:
class Attention(nn.Module):
    '''Causal multi-head self-attention for Transformer models.

    Queries, keys and values are computed per head from the (layer-normalized)
    residual stream; scaled dot-product scores are causally masked and
    softmaxed into an attention pattern, which mixes the values; the per-head
    results are projected back to d_model and summed over heads.
    '''

    def __init__(self, cfg):
        super().__init__()
        # Store config, includes model hyperparameters
        self.cfg = cfg
        # Per-head projection weights (n_heads, d_model, d_head) and biases
        # (n_heads, d_head) for queries, keys and values.
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        # Output projection (n_heads, d_head, d_model): maps each head's result
        # back to the model dimension.
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        # Large negative score written into masked (future) positions so they
        # get ~0 probability after softmax. Registered as a buffer WITHOUT a
        # hard-coded device: it is created on the CPU and follows the module
        # through .to()/.cuda(), so the class also works on CPU-only machines.
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        """Apply attention to a (batch, position, d_model) tensor.

        Returns a tensor of shape (batch, position, d_model).
        """
        if self.cfg.debug:
            print("Normalized_resid_pre:", normalized_resid_pre.shape)

        # Queries and keys: project the input with W_Q / W_K and add the bias.
        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K

        # Attention scores: dot product of every query with every key, per head.
        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        # Scale by sqrt(d_head), then mask out future positions (causal).
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)

        # Attention pattern: softmax over the key positions, so each query
        # position holds a probability distribution over the keys.
        attn = attn_scores.softmax(dim=-1)

        # Values: project the input with W_V and add the bias.
        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        # Weighted average of the value vectors using the attention pattern.
        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", attn, v)

        # Project every head back to d_model with W_O and sum over heads.
        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):
        """Overwrite scores of future positions with ``self.IGNORE`` (in place).

        The last two dimensions of ``attn_scores`` are (query_pos, key_pos).
        Entries strictly above the diagonal (key_pos > query_pos) are masked.
        The same tensor is returned.
        """
        # Boolean upper-triangular mask (diagonal excluded) marking future positions.
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores



#MLP

In [None]:
class MLP(nn.Module):
    '''The MLP class represents a feed-forward network in each transformer layer.
    It provides additional non-linear transformations after the attention mechanism.'''

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # First linear layer: weight (d_model, d_mlp) drawn from a normal
        # distribution with std cfg.init_range, bias initialized to zeros.
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros(cfg.d_mlp))

        # Second linear layer: weight (d_mlp, d_model), same initialization.
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, normalized_resid_mid):
        '''Apply linear -> GELU -> linear to a (batch, position, d_model) tensor.

        Returns a tensor of shape (batch, position, d_model).'''
        # Read the debug flag from this module's own config rather than from a
        # module-level `cfg`, so the flag passed at construction is honored.
        if self.cfg.debug:
            print("Normalized_resid_mid:", normalized_resid_mid.shape)
        # Expand from d_model to d_mlp dimensions.
        pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in

        # Non-linearity (gelu_new from easy_transformer.utils).
        post = gelu_new(pre)

        # Project back from d_mlp to d_model; the caller adds this result to
        # the residual stream.
        mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
        return mlp_out

#Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    '''One transformer layer: LayerNorm -> Attention and LayerNorm -> MLP,
    each wrapped in a residual connection.'''

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Sub-modules, created in the order in which forward() applies them.
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        '''Run the block on the incoming residual stream and return the updated stream.

        resid_pre: the residual stream entering this block.
        '''
        # First residual connection: attention reads the normalized stream and
        # its output is added back onto the un-normalized input.
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        # Second residual connection: same scheme with the MLP. The sum is the
        # stream handed to the next block.
        return resid_mid + self.mlp(self.ln2(resid_mid))

#Unembedding

In [None]:
class Unembed(nn.Module):
    """Unembedding: linear map from the residual stream to vocabulary logits."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Unembedding matrix W_U of shape (d_model, d_vocab), initialized from
        # a normal distribution with std cfg.init_range.
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # Per-token bias b_U. requires_grad=False keeps it fixed at its
        # initial value (zeros) during training.
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab), requires_grad=False)

    # NOTE: forward must be a method of the class (not nested inside __init__),
    # otherwise nn.Module cannot dispatch calls to it.
    def forward(self, normalized_resid_final):
        """Map (batch, position, d_model) hidden states to (batch, position, d_vocab) logits."""
        if self.cfg.debug:
            print("Normalized_resid_final:", normalized_resid_final.shape)
        # Matrix-multiply the hidden states with W_U and add the bias.
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits

#Full Transformer

In [None]:
class DemoTransformer(nn.Module):
    """Full transformer: embeddings -> n_layers transformer blocks -> final LayerNorm -> unembedding."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Token embedding: converts token ids to vectors of size d_model.
        self.embed = Embed(cfg)
        # Positional embedding: gives the model a notion of token order.
        self.pos_embed = PosEmbed(cfg)
        # Stack of cfg.n_layers transformer blocks (self-attention + MLP each).
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        # Final layer normalization applied to the output of the last block.
        self.ln_final = LayerNorm(cfg)
        # Unembedding layer mapping final hidden states to vocabulary logits.
        self.unembed = Unembed(cfg)

    # NOTE: forward must be a method of the class (not nested inside __init__),
    # otherwise calling the model raises because nn.Module has no forward.
    def forward(self, tokens):
        """Return the logits produced by the unembedding layer for ``tokens``."""
        # Initial residual stream: token embeddings plus positional embeddings.
        residual = self.embed(tokens) + self.pos_embed(tokens)

        # Pass the residual stream through every transformer block in order.
        for block in self.blocks:
            residual = block(residual)

        # Final layer normalization, then map to logits over the vocabulary.
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        return logits


# Cross Entropy Loss

In [None]:
def lm_cross_entropy_loss(logits, tokens):
    """Mean next-token cross-entropy loss.

    logits: [batch, position, d_vocab]; tokens: [batch, position].
    The prediction at position i is scored against the token at position
    i + 1, so the last prediction and the first token are not used.
    """
    # Target ids: every token except the first, with a trailing axis for gather.
    targets = tokens[:, 1:][..., None]
    # Log-probabilities of every prediction except the last one.
    log_probs = torch.log_softmax(logits, dim=-1)[:, :-1]
    # Pick the log-probability assigned to each correct next token.
    pred_log_probs = torch.gather(log_probs, -1, targets)[..., 0]
    return -(pred_log_probs.mean())
# Sanity-check the loss on the demo logits / test tokens and show it in a few
# equivalent forms.
loss = lm_cross_entropy_loss(demo_logits, test_tokens)
print(loss)
# exp(-loss): the geometric-mean probability given to the correct next token.
print("Loss as average prob", torch.exp(-loss))
# exp(loss): size of a uniform distribution that would give the same loss.
print("Loss as 'uniform over this many variables'", torch.exp(loss))
# Baseline: loss of a uniform guess over the whole vocabulary.
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))

# Train Model

## Read data

## Initialize Model

In [None]:
# Build the model from the shared config.
model = DemoTransformer(cfg)
# The training loop feeds CUDA tensors to the model, so the parameters must
# live on the GPU as well; otherwise the first forward pass fails with a
# device mismatch.
if torch.cuda.is_available():
    model.cuda()

## Optimizer

In [None]:
# AdamW over all model parameters; `lr` and `weight_decay` are expected to be
# defined in an earlier cell.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Run Training Loop

In [None]:
# Training loop: one optimizer step per batch; the loss of every step is
# collected in `losses`.
losses = []
print("Number of batches:", len(data_loader))
# Send each batch to wherever the model's parameters live instead of
# hard-coding CUDA, so the loop also runs when the model is on the CPU.
train_device = next(model.parameters()).device
for epoch in range(num_epochs):
    # Wrap the loader itself (not the enumerate object) so tqdm can read its
    # length and display a total / ETA.
    for c, batch in enumerate(tqdm.tqdm(data_loader)):
        tokens = batch['tokens'].to(train_device)
        logits = model(tokens)
        loss = lm_cross_entropy_loss(logits, tokens)
        loss.backward()
        optimizer.step()
        # Clear the gradients so they do not accumulate into the next step.
        optimizer.zero_grad()
        # Read the scalar once and reuse it for bookkeeping and logging.
        loss_value = loss.item()
        losses.append(loss_value)
        if c % log_every == 0:
            print(f"Step: {c}, Loss: {loss_value:.4f}")
        # NOTE(review): this only leaves the inner loop, so `max_steps` acts as
        # a per-epoch cap; confirm that is the intended meaning.
        if c > max_steps:
            break