In [1]:
# Implement LayerNorm layer
# MLP layer
# Self-attention layer (decoder)
# Casual/masked self-attention layer (encoder)

# Transformer block (inlcude layer norm, MLP, and self-attention)

# Decoder (consisting of blocks, mlps and residual connections)
# Encoder (consisting of blocks, mlps and residual connections)

In [191]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
import numpy as np

import math
from dataclasses import dataclass

In [67]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Usage
set_seed(42)  # Or any integer of your choice

## LayerNorm 

This applies layer normalization over the input (minibatch) of inputs, given by the formula:

\begin{equation}
y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}}*\gamma + \beta
\end{equation}

where $\gamma$ and $\beta$ are learnable parameters. The mean and variance are computed over the last $D$ dimensions.

In [3]:
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

In [164]:
# Example of LayerNorm

batch_size = 10 # Number of examples in the batch
sequence_length = 5 # Number of tokens in the sequence
embedding_dim = 3 # Dimensionality of each token

# Generate random inputs for a single batch 
x = torch.randn(batch_size, sequence_length, embedding_dim)

layer_norm = LayerNorm(embedding_dim, bias=True)
output = layer_norm(x)
manual_output = ((x - x.mean(-1).unsqueeze(-1))/(torch.sqrt(x.var(-1, unbiased=False).unsqueeze(-1)) + 1e-5)) #torch.round((x - x.mean(-1).unsqueeze(-1))/(torch.sqrt(x.var(-1, unbiased=False).unsqueeze(-1)) + 1e-5), decimals=2) * torch.tensor(layer_norm.weight).unsqueeze(-1) * torch.scalar_tensor(1.0) + torch.scalar_tensor(0.0)

# Something weird with numerical precision here but the values are close...

print(torch.allclose(torch.round(output, decimals=2), torch.round(manual_output, decimals=2)))

False


## Casual Self-Attention (aka masked SA)



In [180]:
print("The casual self-attention mask looks like this:")
block_size = 5
print(casual_mask := torch.tril(torch.ones((block_size, block_size))))

print("In order to apply this mask to the attention weights, we need to expand the mask to the batch size and number of heads so that it can be broadcasted over the sequence and batch.")
print(casual_mask:= casual_mask.view(1, 1, block_size, block_size).shape)

The casual self-attention mask looks like this:
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
In order to apply this mask to the attention weights, we need to expand the mask to the batch size and number of heads so that it can be broadcasted over the sequence and batch.
torch.Size([1, 1, 5, 5])


In [198]:
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.embedding_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads"

        # Input layer. Calculate Q, K and V.
        self.attn = nn.Linear(config.embedding_dim, 3 * config.embedding_dim, bias=config.bias)
        # Output layer projection
        self.attn_output = nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias)

        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)

        self.num_heads = config.num_heads
        self.embedding_dim = config.embedding_dim
        self.dropout_rate = config.dropout_rate


        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: Flash Attention not available, using PyTorch implementation of slow attention.")
        self.layer_norm = LayerNorm(config.embedding_dim, config.bias)

    def forward(self, x):
        B, T, C = x.size() # B - batch size, T - sequence length, C - embedding dimension 

        # Calculate the Queries, Keys and Values 
        q, k, v = self.attn(x).split(self.embedding_dim, dim=2)

        # Project the queries, keys and values to the shape (B, num_heads, T, head_size)
        # C = note num_heads*head_size 
        # This splits the embedding dimension equally across the number of attention heads. 

        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)

        if self.flash:
            # att_mask here is important for ignoring elements that should be masked or padded.
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout_rate if self.training else 0, is_causal=False)
        else:
            # Implement self-attention manually
            (B, num_heads, T, head_size)* (B, num_heads, T, head_size)
            # Need to transpsoe k from (B, num_heads, T, head_size) to (B, num_heads, head_size, T) which when multipled by q gives (B, num_heads, T, T)
            attn = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(C // self.num_heads)))
            attn = F.softmax(attn, dim=-1) # Compute SM across the seqeunce 
            attn = self.attn_dropout(attn)
            y = attn @ v
        
        # Reassemble y 
        y = y.transpose(1, 2).contiguous().view(B, T, C) # reasseble y into the original shape 
        y = self.residual_dropout(self.attn_dropout(y)) # Apply dropout to attn and dropout. Project head to output.
        return y

class CasualSelfAttention(nn.Module):
    def __init__(self, config):
        '''
        config: a dictionary containing the following parameters:
            - block_size: the number of tokens in the sequence
            - embedding_dim: the dimensionality of the input embeddings
            - num_heads: the number of attention heads
            - dropout_rate: the dropout rate for the attention weights
            - bias: whether to use bias in the linear layers
        '''
        super().__init__()
        assert config.embedding_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads"

        # Calculate the Queries, Keys, and Values
        self.casual_attn = nn.Linear(config.embedding_dim, 3 * config.embedding_dim, bias=config.bias)

        # Output layer projection 
        self.casual_attn_output = nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias)

        # Regularization
        self.casual_attn_dropout = nn.Dropout(config.dropout_rate)
        self.residual_dropout = nn.Dropout(config.dropout_rate)
        
        self.num_heads = config.num_heads
        self.embedding_dim = config.embedding_dim
        self.dropout_rate = config.dropout_rate

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        if not self.flash:
            print("WARNING: Flash Attention not available, using PyTorch implementation of slow attention.")
            # Add the casual (triangular) mask to ensure attention is only applied to tokens previously seen in the sequence.
            # Use a buffer to ensure the mask is stored as a buffer in the module, not as a parameter. 
            # Buffers are not updated during training, are saved with the model state and moved to the device. 
            # Commonly used for constants, caching computed values and maintaining statistics.
   
            self.register_buffer("bias", torch.tril(torch.ones((config.block_size, config.block_size))).view(1, 1, block_size, block_size))
        # Layer normalization
        self.layer_norm = LayerNorm(config.embedding_dim, config.bias)
    
    def forward(self, x):
        B, T, C = x.size() # B - batch size, T - sequence length, C - embedding dimension

        # calculate the queries, keys, and values
        q, k, v = self.casual_attn(x).split(self.embedding_dim, dim=2) # split the output of the linear network into q,k and v.

        # project each of the queries, keys, and values to:
        # the shape (B, num_heads, T, head_size) i.e. split the embedding dimensions from 
        # q,k and v equally across the number of attention heads in a block. (head_size = embedding_dim / num_heads)
        # Then transpose the dimensions so that the sequence length is first.
    
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) 
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        
        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout_rate if self.training else 0, is_causal=True)
        else:
            # Implement causal self-attention manually
            # q = (B, num_heads, T, head_size), k = (B, num_heads, T, head_size), k.transpose(-2, -1) -> (B, num_heads, head_size, T) - swapping the last two dimensions enables broadcasting

            attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(C // self.num_heads)) # scaled dot product self-attention (d_k is the square root of the head size)

            # Apply the softmax

            attn = F.softmax(attn, dim=-1)

            # Apply dropout to the attention weights    
        
            attn = self.casual_attn_dropout(attn)

            # Compute the output
            y = attn @ v # (B, num_heads, T, T) x (B, num_heads, T, head_size) -> (B, num_heads, T, head_size)
        
        y = y.transpose(1, 2).contiguous().view(B, T, C) # Reassemble all of the head outputs (B, num_heads, T, head_size) -> (B, T, num_heads, head_size) -> (B, T, C)
        y = self.residual_dropout(self.casual_attn_output(y)) # Project the heads to the output and 
        return y 

##  Multi Layer Perceptron (MLP)

In [187]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__() 
        self.c_fc = nn.Linear(config.n_embed, 4 * config.n_embed, bias=config.bias)
        self.gelu = nn.GELU() 
        self.c_proj = nn.Linear(4 * config.n_embed, config.n_embed, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

## Transformer Block 

In [189]:
class AttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embed, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embed, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x)) # residual connected attention layer
        x = x + self.mlp(self.ln_2(x)) # residual connectioned mlp layer 
        return x    

class CasualAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embed, bias=config.bias)
        self.attn = CasualSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embed, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x)) # residual connection
        x = x + self.mlp(self.ln_2(x)) # residual connection
        return x

In [193]:
@dataclass
class Config:
    block_size: int = 5
    n_token: int = 6
    n_embed: int = 3
    n_layer: int = 4
    n_head: int = 4
    dropout: float = 0.1
    bias: bool = True

# Dummy numbers for the moment

## Decoder 

In [201]:
class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embed is not None 
        assert config.block_size is not None 
        self.config = config

        self.decoder = nn.ModuleDict(dict(
            tok_to_emb = nn.Embedding(config.n_token, config.n_embed),
            pos_emb = nn.Embedding(config.block_size, config.n_embed),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([CasualAttentionBlock(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embed, bias=config.bias)
        ))

        # project back to the token space
        self.lm_head = nn.Linear(config.n_embed, config.n_token, bias=config.bias) 
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.decoder.tok_to_emb.weight = self.lm_head.weight

        # Initialize weights here...
        self.apply(self._init_weights)

        for pn, p in self.named_parameters():
            # Apply special scaled init to the residual projections, per GPT-2 paper
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding):
        """
        Return the number of parameters in the model. 
        For non-embedding count (default), the position embeddings get subtracted.
        Token embeddings would be subtracted too, except due to the parameter sharing they are actually 
        used as weights in the final layer, so they are included.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.decoder.pos_emb.weight.numel()
        return n_params
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()      
        
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        
        pos = torch.arange(0, t, dtype=torch.long, device=device) # position indices

        # forward the model 
        tok_emb = self.decoder.tok_to_emb(idx) # token embedding of shape - (b, t, embed_dim)
        pos_emb = self.decoder.pos_emb(pos) # position embedding of shape - (t, n_embed)

        x = self.decoder.drop(tok_emb + pos_emb)

        for block in self.decoder.h:
            x = block(x)
        x = self.decoder.ln_f(x)

        if targets is not None:
            # calculate the loss given some targets 
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference mode. Mini-optimization: only forward the lm_head on the last position 
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None 
        return logits, loss 

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary 
        assert block_size <= self.config.block_size 
        self.config.block_size = block_size 
        self.decoder.tok_emb.weight = nn.Parameter(self.decoder.tok_emb.weight[:block_size])

        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]


        for block in self.decoder.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
    
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out params that don't require grad
        param_dict = {pn:p for pn, p in param_dict.items() if p.requires_grad}
        # create optimizer groups. Any parameters that are 2D will be weight decayed.
        # This includes all weight tensors in matmuls + embeddings. All biases and layernorms will not be weight decayed.

        decay_params = [p for n,p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n,p in param_dict.items() if p.dim() < 2]

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)

        print(f"num decayed parameter tensors: {len(optim_groups[0]['params'])}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(optim_groups[1]['params'])}, with {num_nodecay_params:,} parameters")

        # Create AdamW optimizer and use the fused version if it's available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"Using fused AdamW: {use_fused}")
        return optimizer

## Encoder

In [202]:
class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embed is not None 
        assert config.block_size is not None 

        self.config = config 

        self.encoder = nn.ModuleDict(dict(
            tok_to_emb = nn.Embedding(config.n_token, config.n_embed),
            pos_emb = nn.Embedding(config.block_size, config.n_embed),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([SelfAttentionBlock(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embed, bias=config.bias)
        ))

        # Project output back to the token space - replace with flux + error prediction ...
        self.lm_head = nn.Linear(config.n_embed, config.n_token, bias=config.bias)
        
        # Weight tying 
        self.encoder.tok_to_emb.weight = self.lm_head.weight

        # Init weights
        # ...

    def get_num_params(self, non_embedding):
        """
        Return the number of parameters in the model. 
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            # Subtract positional embeddings 
            n_params -= self.encoder.pos_emb.weight.numel()
        return n_params 

    def _init_weights(self, module):
        if isinstance(module, nn.Linear): 
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size() # check the size of the device

        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 

        pos = torch.arange(0, t, dtype=torch.long, device=device) # position indices 

        # forward the model 
        tok_emb = self.encoder.tok_to_emb(idx) # token embeddings 
        pos_emb = self.encoder.pos_emb(pos) # position embeddings 

        x = self.encoder.drop(tok_emb + pos_emb) # drop tok_emb and pos_emb 

        for block in self.encoder.h:
            x = block(x)
        x = self.encoder.ln_f(x)

        if targets is not None: # Training mode
            # Change to match the desired output...
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        else: # Inference mode
            logits = self.lm_heead(x[:, [-1], :]) # Preserve the time dimension with [-1]
            loss = None 
        return logits, loss 


    def crop_block_size(self, block_size):
        # model surgey to decrease block size if necessary 
        assert block_size <= self.config.block_size 
        self.config.block_size = block_size 
        
        self.encoder.tok_emb.weight = nn.Parameter(self.encoder.tok_emb.weight[:block_size])

        for block in self.encoder.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out params that don't require grad 
        param_dict = {pn:p for pn, p in param_dict.items() if p.requires_grad}
        # Create optimizer groups for decay and non-decay params 

        decay_params = [p for n,p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]

        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)

        print(f"num decayed parameter tensors: {len(optim_groups[0]['params'])}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(optim_groups[1]['params'])}, with {num_nodecay_params:,} parameters")

        # Create AdamW optimizer and use the fused version if available for efficiency
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 
        use_fused = fused_available and device_type == 'cuda' 
        extra_args = dict(fused=True) if use_fused else dict() 
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"Using fused AdamW: {use_fused}")
        return optimizer 