In [1]:
import torch.nn as nn

In [6]:
# Hyperparameters consumed by the model classes in this file
# (looked up by key, e.g. cfg["emb_dim"]).
GPT_CONFIG_124M = {
    # Number of rows in the token-embedding table / width of the output logits.
    "vocab_size": 50257,
    # Longest sequence the positional embedding and causal mask are sized for.
    "context_length": 1024,
    # Width of every token representation flowing through the network.
    "emb_dim": 768,
    # Attention heads per transformer block (must divide emb_dim evenly).
    "n_heads": 12,
    # Number of stacked transformer blocks.
    "n_layers": 12,
    # Dropout probability shared by embedding, attention and residual dropout.
    "drop_rate": 0.1,
    # Whether the query/key/value projections carry a bias term.
    "qkv_bias": False,
}

In [5]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention.

    The input is projected once each into queries, keys and values of width
    `d_out`. Each projection is then cut into `num_heads` slices of width
    `head_dim = d_out // num_heads`, scaled dot-product attention is run per
    head with future positions masked out, and the per-head results are
    concatenated and passed through a final linear layer.

    Args:
        d_in (int): Feature size of the incoming embeddings.
        d_out (int): Combined feature size of all heads (and of the output).
        context_length (int): Side length of the precomputed causal mask,
            i.e. the longest sequence the mask is sized for.
        num_heads (int): How many heads `d_out` is divided into.
        dropout (float, optional): Dropout probability applied to the
            attention weights.
        qkv_bias (bool, optional): Add a bias term to the Q/K/V projections.

    Shapes:
        Input:  x -> (B, T, d_in)
        Output: y -> (B, T, d_out)

        where:
            B = batch size, T = sequence length.
    """

    def __init__(self, d_in, d_out,
                 context_length, num_heads,
                 dropout=0.0, qkv_bias=False):
        super().__init__()

        # Every head must receive the same number of features.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # One linear map per role; each covers all heads at once.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Boolean matrix that is True strictly above the diagonal: entry
        # (i, j) marks key position j as lying in the future of query i.
        future = torch.ones(context_length, context_length).triu(diagonal=1).bool()
        self.register_buffer("mask", future)

        # Applied to the normalized attention weights, not to the raw scores.
        self.dropout = nn.Dropout(dropout)

        # Blends the concatenated head outputs: d_out -> d_out.
        self.out_proj = nn.Linear(d_out, d_out)

    def _split_heads(self, projected):
        """Reshape (B, T, d_out) into per-head layout (B, H, T, head_dim)."""
        batch, seq_len, _ = projected.shape
        per_head = projected.view(batch, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (B, T, d_in)
        Returns:
            torch.Tensor: (B, T, d_out)
        """
        batch, seq_len, _ = x.shape

        # Project, then lay out as (B, H, T, head_dim) for batched matmuls.
        q = self._split_heads(self.W_query(x))
        k = self._split_heads(self.W_key(x))
        v = self._split_heads(self.W_value(x))

        # Pairwise query/key similarity per head, (B, H, T, T), scaled by
        # the square root of the per-head width.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # The (T, T) slice of the mask broadcasts over batch and heads;
        # -inf entries become exactly zero weight after softmax.
        causal = self.mask[:seq_len, :seq_len]
        scores = scores.masked_fill(causal, -torch.inf)

        # Normalize over the key axis, then drop some of the weights.
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Mix the values, (B, H, T, head_dim), and put time back before
        # heads so the heads can be flattened together.
        heads = torch.matmul(weights, v).transpose(1, 2)

        # (B, T, H, head_dim) -> (B, T, d_out); the transpose leaves the
        # tensor non-contiguous, hence the explicit copy before view.
        merged = heads.contiguous().view(batch, seq_len, self.d_out)

        return self.out_proj(merged)

In [2]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last (feature) dimension.

    Each feature vector is shifted to zero mean and scaled to unit variance,
    then passed through a learnable element-wise affine transform
    (`scale * x_norm + shift`).

    Args:
        emb_dim (int): Size of the last dimension of the input.

    Shapes:
        Input:  x -> (..., emb_dim)
        Output: y -> (..., emb_dim)
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Small constant added to the variance to avoid division by zero.
        self.eps = 1e-5
        # Learnable gain and bias, initialized to the identity transform.
        # Both must be nn.Parameter (capital P) so they are registered with
        # the module; `nn.parameter` is a submodule and is not callable.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (..., emb_dim)
        Returns:
            torch.Tensor: Normalized tensor with the same shape as `x`.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # Use the biased (population) variance: this is what
        # torch.nn.LayerNorm computes, and it stays finite when the feature
        # dimension has a single element (the unbiased estimator divides
        # by n - 1).
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        return self.scale * x_norm + self.shift

In [3]:
class GELU(nn.Module):
    """
    GELU activation using the tanh approximation:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    Applied element-wise; the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Tensor of any shape.
        Returns:
            torch.Tensor: Activated tensor with the same shape as `x`.
        """
        # sqrt(2 / pi) is a plain Python float, so take the root with `**`;
        # torch.sqrt only accepts tensors and would raise a TypeError here.
        coeff = (2.0 / torch.pi) ** 0.5
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

In [4]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: expand, activate, project back.

    The hidden layer is four times as wide as the embedding, with the
    file's `GELU` module as the non-linearity.

    Args:
        cfg (dict): Model configuration. Only `cfg['emb_dim']` is read.

    Shapes:
        Input:  x -> (..., emb_dim)
        Output: y -> (..., emb_dim)
    """

    def __init__(self, cfg):
        super().__init__()
        # The config key is 'emb_dim'; the previous 'emd_dim' spelling
        # raised KeyError for every config used in this file.
        emb_dim = cfg['emb_dim']
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * 4),
            GELU(),
            nn.Linear(emb_dim * 4, emb_dim),
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (..., emb_dim)
        Returns:
            torch.Tensor: Tensor with the same shape as `x`.
        """
        return self.layers(x)

In [7]:
class TransformerBlock(nn.Module):
    """
    One transformer layer: self-attention followed by a feed-forward network.

    Each sub-layer is applied to a layer-normalized copy of its input, its
    output goes through dropout, and the result is added back onto the
    un-normalized input (residual connection).

    Args:
        cfg (dict): Model configuration. Reads 'emb_dim', 'context_length',
            'n_heads', 'drop_rate' and 'qkv_bias'; the same dict is also
            handed to `FeedForward`.

    Shapes:
        Input:  x -> (B, T, emb_dim)
        Output: y -> (B, T, emb_dim)
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['emb_dim']
        drop_p = cfg['drop_rate']

        # One normalization layer in front of each sub-layer.
        self.layernorm1 = LayerNorm(width)
        self.layernorm2 = LayerNorm(width)

        # Attention keeps the feature size unchanged (d_in == d_out).
        self.attn = MultiHeadAttention(
            d_in=width,
            d_out=width,
            context_length=cfg['context_length'],
            num_heads=cfg['n_heads'],
            dropout=drop_p,
            qkv_bias=cfg['qkv_bias'],
        )
        self.ffn = FeedForward(cfg)

        # A single dropout module is reused on both residual branches.
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (B, T, emb_dim)
        Returns:
            torch.Tensor: (B, T, emb_dim)
        """
        # Attention branch: normalize -> attend -> dropout -> add input back.
        residual = x
        x = self.dropout(self.attn(self.layernorm1(x))) + residual

        # Feed-forward branch, same pattern.
        residual = x
        x = self.dropout(self.ffn(self.layernorm2(x))) + residual
        return x

In [8]:
class GPTModel(nn.Module):
    """
    GPT-style decoder-only language model.

    Token ids are embedded, learned positional embeddings are added, and the
    result passes through dropout, a stack of `TransformerBlock`s, a final
    `LayerNorm`, and a linear head that produces one logit per vocabulary
    entry.

    Args:
        cfg (dict): Model configuration. Reads 'vocab_size',
            'context_length', 'emb_dim', 'drop_rate' and 'n_layers' here;
            the same dict is passed on to every `TransformerBlock`.

    Shapes:
        Input:  idx    -> (B, T) integer token ids
        Output: logits -> (B, T, vocab_size)
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        # One learned vector per position, up to context_length positions.
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.dropout = nn.Dropout(cfg['drop_rate'])
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg['n_layers'])]
        )
        self.layernorm = LayerNorm(cfg['emb_dim'])
        # The output head never has a bias. It was previously controlled by
        # cfg['qkv_bias'], which only describes the attention projections;
        # enabling it would silently add an extra bias to the logits.
        self.lm_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, idx):
        """
        Args:
            idx (torch.Tensor): (B, T) token ids; T must not exceed the
                number of rows in the positional embedding table.
        Returns:
            torch.Tensor: (B, T, vocab_size) unnormalized logits.
        """
        B, T = idx.shape
        tok_emb = self.tok_emb(idx)  # (B, T, emb_dim)
        # Positions 0..T-1, created on the same device as the input ids.
        pos_emb = self.pos_emb(torch.arange(T, device=idx.device))  # (T, emb_dim)
        x = tok_emb + pos_emb  # positional term broadcasts over the batch
        x = self.dropout(x)  # (B, T, emb_dim)
        x = self.transformer_blocks(x)  # (B, T, emb_dim)
        x = self.layernorm(x)  # (B, T, emb_dim)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits
        