In [1]:
import torch
import torch.nn as nn

In [2]:
class SelfAttention(nn.Module):
    """
    A simple implementation of (unmasked) scaled dot-product self-attention.

    Self-Attention allows each element in a sequence to attend to (i.e.,
    look at and weigh) all other elements in the sequence when forming
    a new representation. This is the core mechanism behind
    Transformer models.

    Args:
        d_in (int): Dimensionality of the input features.
        d_out (int): Dimensionality of the output (query/key/value vectors).
        qkv_bias (bool, optional): Whether the query/key/value projections
            include a bias term. Defaults to False.

    Forward Input:
        x (torch.Tensor): Input tensor of shape (N, d_in), or a batched
            tensor of shape (..., N, d_in), where
            - N is the number of tokens (sequence length).
            - d_in is the feature size of each token.

    Forward Output:
        context_vec (torch.Tensor): Tensor of shape (N, d_out) (or
            (..., N, d_out) for batched input), where each output token
            representation is a weighted sum of all input token value vectors.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        # Initialize the parent nn.Module class
        super().__init__()

        # Linear projections mapping each input vector (d_in) into
        # query / key / value space (d_out).
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """
        Forward pass of self-attention.

        Args:
            x (torch.Tensor): Input tensor of shape (N, d_in) or (..., N, d_in).

        Returns:
            torch.Tensor: Output context vectors of shape (N, d_out)
            or (..., N, d_out).
        """
        # Project the input into query / key / value spaces: (..., N, d_out)
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Dot product between every query and every key: (..., N, N).
        # transpose(-2, -1) swaps only the last two axes. `keys.T` would
        # reverse *all* axes for batched input and is deprecated for
        # tensors whose ndim != 2.
        att_scores = queries @ keys.transpose(-2, -1)

        # Scale the scores by sqrt(d_out) to keep softmax out of its
        # saturated region (stabilizes gradients).
        att_scores = att_scores / keys.shape[-1] ** 0.5

        # Normalize over the key dimension; each row sums to 1.
        att_weights = torch.softmax(att_scores, dim=-1)  # (..., N, N)

        # Weighted sum of values: (..., N, N) @ (..., N, d_out) -> (..., N, d_out)
        context_vec = att_weights @ values

        # Return the new representation for each token
        return context_vec

In [None]:
class CausalAttention(nn.Module):
    """
    Causal (autoregressive) self-attention.

    This layer computes standard self-attention but applies a *causal mask* so
    that token t can only attend to positions <= t (no peeking into the future).
    This is the attention used in decoder-only Transformers (e.g., GPT).

    Args:
        d_in (int): Input embedding dimensionality.
        d_out (int): Projection dimensionality for queries/keys/values and outputs.
        context_length (int): Maximum sequence length supported by the mask.
        dropout (float, optional): Dropout probability applied to attention weights.
        qkv_bias (bool, optional): Whether to include bias in Q/K/V linear layers.

    Shapes:
        Input:  x of shape (B, T, d_in)
                B = batch size, T = sequence length (must be <= context_length)
        Output: context_vec of shape (B, T, d_out)

    Raises:
        ValueError: In ``forward`` if T exceeds ``context_length``.

    Notes:
        - The causal mask is registered as a non-trainable buffer so it moves with
          the module across devices and is saved in state_dict.
        - Attention scores are scaled by sqrt(d_out) for stable gradients.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout=0.0, qkv_bias=False):
        super().__init__()  # Initialize base nn.Module state

        # Linear projections to form Q, K, V from inputs.
        # Each maps d_in -> d_out; optional bias controlled by qkv_bias.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Dropout used *after* softmax on attention weights.
        self.dropout = nn.Dropout(dropout)

        # Square upper-triangular matrix with ones strictly above the diagonal:
        # mask[i, j] = 1 if j > i (future positions), else 0.
        # torch.triu needs a 2-D input, so the ones matrix must be square
        # (a 1-D torch.ones(context_length) makes triu raise).
        # Registered as a buffer so it's not a parameter, but moves with the module.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (B, T, d_in), with T <= context_length.

        Returns:
            torch.Tensor: (B, T, d_out) context vectors.

        Raises:
            ValueError: If T is larger than the context_length the mask was built for.
        """
        # Unpack shapes (also enforces a rank-3 input): only T is used below.
        _, T, _ = x.shape

        # Fail with a clear message instead of an opaque broadcast error
        # when the sequence is longer than the precomputed mask.
        max_len = self.mask.shape[0]
        if T > max_len:
            raise ValueError(
                f"Sequence length {T} exceeds context_length {max_len}"
            )

        # Project inputs to queries/keys/values.
        # Shapes: (B, T, d_out) for each.
        queries = self.W_query(x)
        keys    = self.W_key(x)
        values  = self.W_value(x)

        # Raw attention scores via batched matrix multiply:
        # queries: (B, T, d_out), keys.transpose(1, 2): (B, d_out, T)
        # => att_scores: (B, T, T)
        att_scores = queries @ keys.transpose(1, 2)

        # Apply the causal mask: disallow attention to future positions.
        # self.mask[:T, :T] is (T, T); broadcast across batch to (B, T, T).
        # Positions where mask==1 (future) are set to -inf so softmax -> 0.
        att_scores.masked_fill_(self.mask.bool()[:T, :T], -torch.inf)

        # Scale by sqrt(d_out) to stabilize gradients/magnitudes
        # (-inf entries stay -inf).
        att_scores = att_scores / (queries.shape[-1] ** 0.5)

        # Normalize scores along the key dimension to obtain attention weights.
        # (B, T, T), each row sums to 1.
        att_weights = torch.softmax(att_scores, dim=-1)

        # Regularize by dropping some attention probability mass.
        att_weights = self.dropout(att_weights)

        # Weighted average of values: (B, T, T) @ (B, T, d_out) -> (B, T, d_out)
        context_vec = att_weights @ values

        # Return per-token context-enhanced representations.
        return context_vec

In [4]:
class MultiHeadAttentionWrapper(nn.Module):
    """
    Multi-head attention built by stacking independent causal-attention heads.

    Every head is a separate `CausalAttention` module with its own Q/K/V
    projections and causal mask. All heads see the same input, and their
    outputs are joined side by side along the last (feature) axis.

    Args:
        d_in (int): Feature size of each input token.
        d_out (int): Feature size produced by *each* head.
        num_heads (int): How many heads to instantiate.
        context_length (int): Longest sequence each head's mask supports.
        dropout (float, optional): Dropout probability on attention weights.
        qkv_bias (bool, optional): Whether Q/K/V projections carry a bias.

    Shapes:
        Input:
            x: (B, T, d_in)
        Output:
            y: (B, T, num_heads * d_out)

        where:
            B = batch size, T = sequence length.
    """

    def __init__(self, d_in, d_out,
                 num_heads, context_length,
                 dropout=0.0, qkv_bias=False):
        super().__init__()

        # Every head is configured identically; only its weights differ.
        shared_config = dict(
            d_in=d_in,
            d_out=d_out,
            context_length=context_length,
            dropout=dropout,
            qkv_bias=qkv_bias,
        )

        # ModuleList registers each head as a submodule (parameters,
        # buffers, .to(device), state_dict all follow automatically).
        self.heads = nn.ModuleList(
            CausalAttention(**shared_config) for _ in range(num_heads)
        )

    def forward(self, x):
        """
        Apply every head to `x` and join the results feature-wise.

        Args:
            x (torch.Tensor): (B, T, d_in)

        Returns:
            torch.Tensor: (B, T, num_heads * d_out)
        """
        # Each head yields (B, T, d_out); stacking along dim=-1 concatenates
        # the per-head features in head order.
        per_head = [attention_head(x) for attention_head in self.heads]
        return torch.cat(per_head, dim=-1)