# Notes: You will have to:
  >1- Make a Copy of this notebook to edit it for your solutions;
  >
  >2- Upload on Gradescope: The `Colab Notebook edited with your solutions`, and a `pdf Report` with:
  >
   >>- All plots from experiments
   >>- Written interpretations and analyses
   >>- Summary tables.

# Question 1: xLSTM - Extended Long Short-Term Memory and Long-Sequence Generalization [50 points]


This question is divided into five parts. Please read carefully and follow all implementation and reporting instructions.

**Note:** Detailed instructions are provided in the HW PDF document.

## Section: Imports

In [None]:
import torch
from torch.nn.functional import logsigmoid
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Dict
from tqdm import tqdm

## Part 1.1: Exponential Gates [8 points]

In [None]:
class ExponentialGates(nn.Module):
    """
    ExponentialGates module.

    Core building block of xLSTM introducing exponential gating for
    better gradient flow over long sequences. A stabilizer state (m_t)
    keeps the exponentials in a numerically safe range by tracking the
    running maximum of the log-gate activations.

    The module supports two modes:
        - Exponential gating with stabilizer (xLSTM)
        - Standard sigmoid gating (for baseline/ablation)

    Returns (from every forward path):
        new_states: torch.Tensor [4, batch_size, hidden_size]
            (h_t, c_t, n_t, m_t)
        gates: torch.Tensor [4, batch_size, hidden_size]
            (input_gate, forget_gate, cell_input, output_gate)
    """

    def __init__(self, hidden_size: int, use_exponential: bool = True):
        """
        Args:
            hidden_size: int
                Dimensionality of the hidden representation.
            use_exponential: bool
                If True, use exponential gating with stabilizer.
                If False, use standard sigmoid gating.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.use_exponential = use_exponential

        # One hidden_size -> hidden_size projection per gate.
        self.input_gate_proj = nn.Linear(hidden_size, hidden_size)
        self.forget_gate_proj = nn.Linear(hidden_size, hidden_size)
        self.cell_input_proj = nn.Linear(hidden_size, hidden_size)
        self.output_gate_proj = nn.Linear(hidden_size, hidden_size)
        # --------------------------------------------------------------------------
        # Note:
        # The linear layers and the .forward() method below are for Part 1.1 only,
        # allowing ExponentialGates to be tested as a standalone module.
        # In the full sLSTM/xLSTM architecture, the gate projections are handled
        # by the sLSTMLayer; only the pointwise_* methods are used during training.
        # --------------------------------------------------------------------------

    def forward_pointwise_exp(self, Wx, Ry, b, states, constants):
        """
        Exponential gating path (stabilized).

        Works entirely in log space so that the exponential input gate never
        overflows:

            log i_t = i_raw                      (i_t = exp(i_raw))
            log f_t = logsigmoid(f_raw)          (f_t = sigmoid(f_raw))
            m_t     = max(log f_t + m_{t-1}, log i_t)
            i'_t    = exp(log i_t - m_t)
            f'_t    = exp(log f_t + m_{t-1} - m_t)
            c_t     = f'_t * c_{t-1} + i'_t * tanh(z_raw)
            n_t     = f'_t * n_{t-1} + i'_t
            h_t     = sigmoid(o_raw) * c_t / n_t

        Args:
            Wx: torch.Tensor [batch_size, 4 * hidden_size]
                Input projections for all gates, laid out as [i | f | z | o].
            Ry: torch.Tensor [batch_size, 4 * hidden_size]
                Recurrent contributions from h_{t-1}, same layout.
            b: torch.Tensor broadcastable to [batch_size, 4 * hidden_size]
                Bias term for all gates, same layout.
            states: torch.Tensor [4, batch_size, hidden_size]
                Previous recurrent states (h_{t-1}, c_{t-1}, n_{t-1}, m_{t-1}).
            constants: dict
                Placeholder for optional constants (unused; kept for
                interface consistency).

        Returns:
            new_states: torch.Tensor [4, batch_size, hidden_size]
                Updated states (h_t, c_t, n_t, m_t).
            gates: torch.Tensor [4, batch_size, hidden_size]
                Gate activations (i'_t, f'_t, tanh(z_raw), o_t).
        """
        # h_{t-1} already entered through Ry, so it is not needed here.
        _, c_prev, n_prev, m_prev = states

        i_raw, f_raw, z_raw, o_raw = (Wx + Ry + b).chunk(4, dim=-1)

        # Log-space gate activations. The input gate is a true exponential,
        # so its log is simply the pre-activation; no sigmoid squashing and
        # no log(. + eps) round trip.
        log_i = i_raw
        log_f_plus_m = nn.functional.logsigmoid(f_raw) + m_prev

        # Stabilizer: subtracting the running max guarantees both
        # exponentials below are <= 1.
        m_t = torch.maximum(log_f_plus_m, log_i)
        i_t_prime = torch.exp(log_i - m_t)
        f_t_prime = torch.exp(log_f_plus_m - m_t)

        z_t = torch.tanh(z_raw)
        o_t = torch.sigmoid(o_raw)

        c_t = f_t_prime * c_prev + i_t_prime * z_t
        n_t = f_t_prime * n_prev + i_t_prime

        # Small epsilon guards the division when the normalizer is ~0.
        h_t = o_t * (c_t / (n_t + 1e-8))

        new_states = torch.stack([h_t, c_t, n_t, m_t], dim=0)
        gates = torch.stack([i_t_prime, f_t_prime, z_t, o_t], dim=0)
        return new_states, gates

    def forward_pointwise_sigmoid(self, Wx, Ry, b, states, constants):
        """
        Sigmoid gating path (standard LSTM).

        Implements the classical LSTM update using sigmoid and tanh gates.
        Provided for comparison and ablation studies.

        Args:
            Wx, Ry, b, states, constants: same as forward_pointwise_exp().

        Returns:
            new_states: torch.Tensor [4, batch_size, hidden_size]
                (h_t, c_t, n_t, m_t); n_t and m_t are passed through
                unchanged and only act as placeholders here.
            gates: torch.Tensor [4, batch_size, hidden_size]
                (input_gate, forget_gate, cell_input, output_gate).
        """
        _, c_prev, n_prev, m_prev = states

        i_raw, f_raw, z_raw, o_raw = (Wx + Ry + b).chunk(4, dim=-1)

        i_t = torch.sigmoid(i_raw)
        f_t = torch.sigmoid(f_raw)
        z_t = torch.tanh(z_raw)
        o_t = torch.sigmoid(o_raw)

        c_t = f_t * c_prev + i_t * z_t
        h_t = o_t * torch.tanh(c_t)

        # The normalizer and stabilizer are not used by the classical LSTM.
        new_states = torch.stack([h_t, c_t, n_prev, m_prev], dim=0)
        gates = torch.stack([i_t, f_t, z_t, o_t], dim=0)
        return new_states, gates

    def forward(self, x_t, h_prev, states):
        """
        Main forward interface for one timestep (standalone testing).

        Steps:
            1. Project x_t with the weight of each gate projection.
            2. Project h_prev with the same weights for the recurrent term.
            3. Collect the four biases into a single bias row.
            4. Depending on `use_exponential`, call:
                   - forward_pointwise_exp(...)     -> stabilized exponential gates
                   - forward_pointwise_sigmoid(...) -> standard LSTM gates

        The bias of each projection is applied exactly once (through `b`);
        Wx and Ry are bias-free so the pointwise functions do not count it
        three times.

        NOTE: this module only owns one projection per gate, so the input and
        recurrent terms share weights here. The full model uses a separate
        recurrent matrix in sLSTMCell.

        Args:
            x_t: torch.Tensor [batch_size, hidden_size]
                Input at time step t.
            h_prev: torch.Tensor [batch_size, hidden_size]
                Previous hidden state h_{t-1}.
            states: torch.Tensor [4, batch_size, hidden_size]
                Previous states (h, c, n, m).

        Returns:
            new_states: torch.Tensor [4, batch_size, hidden_size]
            gates: torch.Tensor [4, batch_size, hidden_size]
        """
        # Order defines the [i | f | z | o] layout expected by chunk(4).
        projections = (
            self.input_gate_proj,
            self.forget_gate_proj,
            self.cell_input_proj,
            self.output_gate_proj,
        )

        Wx = torch.cat(
            [nn.functional.linear(x_t, proj.weight) for proj in projections], dim=-1
        )
        Ry = torch.cat(
            [nn.functional.linear(h_prev, proj.weight) for proj in projections], dim=-1
        )
        b = torch.cat([proj.bias for proj in projections]).unsqueeze(0)

        constants = {}
        if self.use_exponential:
            return self.forward_pointwise_exp(Wx, Ry, b, states, constants)
        return self.forward_pointwise_sigmoid(Wx, Ry, b, states, constants)


## Part 1.2 – sLSTM: Scalar Memory Architecture [8 points]



### Part 1.2.1 sLSTMCell

In [None]:
class sLSTMCell(nn.Module):
    """
    Core scalar-memory LSTM cell (sLSTM).

    Implements the scalar-memory recurrence defined in the handout:
        c_t = f'_t * c_{t-1} + i'_t * tanh(z_t)
        n_t = f'_t * n_{t-1} + i'_t
        h_t = o_t * (c_t / n_t)

    The pointwise gate math is delegated to an ExponentialGates instance;
    this class owns the recurrent weights `R`, the gate bias `b`, and the
    loop over time. The cell maintains four internal states
    (h_t, c_t, n_t, m_t) and supports exponential or sigmoid gating.

    All 4 * hidden_size vectors use the contiguous layout [i | f | z | o],
    which is what the pointwise functions split with `chunk(4, dim=-1)`.
    """

    def __init__(self, hidden_size: int, use_exponential: bool = True):
        """
        Args:
            hidden_size: int
                Dimensionality of the hidden representation.
            use_exponential: bool
                If True, use exponential gating with stabilizer (xLSTM).
                If False, use standard sigmoid gating (baseline).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.use_exponential = use_exponential

        # Recurrent weights for all four gates; the bias lives in self.b.
        self.R = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.b = nn.Parameter(torch.zeros(4 * hidden_size))
        with torch.no_grad():
            # Positive forget-gate bias helps with stable training. The gates
            # are contiguous blocks [i | f | z | o], so the forget gate is the
            # second block of size hidden_size (not every 4th element).
            self.b[hidden_size:2 * hidden_size] = 1.0

        # Provides the pointwise exponential / sigmoid update functions.
        self.gates = ExponentialGates(hidden_size, use_exponential)

    def forward_sequence(self, x, states, pointwise_forward):
        """
        Performs the internal recurrent computation across time steps.

        Args:
            x: torch.Tensor [seq_len, batch_size, 4 * hidden_size]
                Sequence of pre-activation inputs for all gates.
            states: torch.Tensor [4, batch_size, hidden_size]
                Previous recurrent states (h, c, n, m).
            pointwise_forward: Callable
                Function that computes a single time-step update; called with
                keyword arguments (Wx, Ry, b, states, constants).

        Returns:
            states_all: torch.Tensor [4, seq_len + 1, batch_size, hidden_size]
                All intermediate states, including the initial one.
            final_state: torch.Tensor [4, batch_size, hidden_size]
                Final recurrent state after the last time step.
            gates: torch.Tensor [seq_len, 4, batch_size, hidden_size]
                Gate activations for each time step.
        """
        state_history = [states]  # index 0 holds the initial state
        gate_history = []

        for wx_t in x:  # iterate over the time dimension
            # Recurrent contribution from the previous hidden state.
            ry_t = self.R(states[0])

            states, gates_t = pointwise_forward(
                Wx=wx_t,
                Ry=ry_t,
                b=self.b,
                states=states,
                constants=None,
            )
            state_history.append(states)
            gate_history.append(gates_t)

        states_all = torch.stack(state_history, dim=1)
        if gate_history:
            gates_all = torch.stack(gate_history, dim=0)
        else:
            # Empty sequence: torch.stack cannot handle an empty list.
            gates_all = x.new_zeros((0, 4) + tuple(states.shape[1:]))

        return states_all, states, gates_all

    def forward(self, x, state=None):
        """
        Computes the sLSTM cell output for an entire input sequence.

        This is the public interface used by higher-level modules (e.g., sLSTMLayer).
        It handles initial state setup and delegates the temporal computation
        to `forward_sequence`.

        Args:
            x: torch.Tensor [batch_size, seq_len, 4 * hidden_size]
                Concatenated gate projections for the input sequence.
            state: torch.Tensor [4, batch_size, hidden_size], optional
                Initial recurrent states (h, c, n, m). Defaults to zeros.

        Returns:
            output: torch.Tensor [batch_size, seq_len, hidden_size]
                Sequence of hidden outputs (h_t).
            final_state: torch.Tensor [4, batch_size, hidden_size]
                Final recurrent states after processing the sequence.
        """
        batch_size, _, _ = x.shape

        # The recurrence iterates over the leading dimension, so go time-major.
        x = x.transpose(0, 1)

        if state is None:
            # Zero-initialised (h, c, n, m) on the same device/dtype as x.
            state = x.new_zeros(4, batch_size, self.hidden_size)

        if self.use_exponential:
            pointwise_forward = self.gates.forward_pointwise_exp
        else:
            pointwise_forward = self.gates.forward_pointwise_sigmoid

        states_all, final_state, _ = self.forward_sequence(x, state, pointwise_forward)

        # states_all[0] is the h-trajectory [seq_len + 1, batch, hidden];
        # drop the initial state and return to batch-major layout.
        output = states_all[0][1:].transpose(0, 1)

        return output, final_state


### Part 1.2.2 – sLSTMLayer: High-Level Sequence Processor

In [None]:
class sLSTMLayer(nn.Module):
    """
    Sequence-level wrapper around sLSTMCell.

    Projects an input of shape [B, T, H] into the four gate pre-activations,
    runs them through the cell, and applies dropout to the resulting hidden
    sequence. Returns the hidden sequence together with the final states.
    """

    def __init__(self, hidden_size: int, dropout_prob: float=0.1, use_exponential: bool = True):
        """
        Args:
            hidden_size: int
                Dimensionality of the hidden representation.
            dropout_prob: float
                Dropout probability applied to the output sequence.
            use_exponential: bool
                If True, use exponential gating (xLSTM mode).
                If False, use standard sigmoid gating (baseline).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout_prob)

        # Input-side projections, one hidden_size -> hidden_size map per gate:
        # input (i), forget (f), cell input (z) and output (o).
        self.igate = nn.Linear(hidden_size, hidden_size)
        self.fgate = nn.Linear(hidden_size, hidden_size)
        self.zgate = nn.Linear(hidden_size, hidden_size)
        self.ogate = nn.Linear(hidden_size, hidden_size)

        # Recurrent core that consumes the concatenated projections.
        self.cell = sLSTMCell(hidden_size, use_exponential)

    def forward(self, x, initial_state=None):
        """
        Run the layer over a batch of sequences.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_size]
            initial_state: Optional initial states forwarded to the cell.

        Returns:
            output: Output tensor [batch_size, seq_len, hidden_size]
            final_state: Final recurrent states returned by the cell.
        """
        # Concatenate the gate pre-activations in [i | f | z | o] order.
        gate_inputs = torch.cat(
            [proj(x) for proj in (self.igate, self.fgate, self.zgate, self.ogate)],
            dim=-1,
        )

        hidden_seq, final_state = self.cell(gate_inputs, initial_state)
        return self.dropout(hidden_seq), final_state


## **Part 1.3 – xLSTM Block and Stack [8 points]**

In this section, we assemble full **xLSTM layers** by combining the previously implemented components into Transformer-style residual blocks.

The code below already provides:
- **LayerNorm** with residual weighting — ensures numerical stability during deep stacking. *(No edits required.)*
- **GatedMLP** using GeLU activation and a projection factor of 4/3. *(No edits required.)*

You will complete:
- **xLSTMBlock** — one residual block that wraps:
  1. LayerNorm → sLSTM layer → first residual connection  
  2. LayerNorm → GatedMLP → second residual connection  
  The block should return the final tensor after both residuals.

- **xLSTM** — a stack of multiple xLSTMBlocks that processes entire sequences in parallel.  
  It should loop over the configured number of layers and apply each block sequentially.

Follow the structure and equations described in the handout.



In [None]:
class LayerNorm(nn.Module):
    """Layer normalization whose learnable scale is stored as an offset from 1."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # The scale is parameterised as (1 + weight), so zeros mean "identity scale".
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = 1e-5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension of x and apply (1 + weight) and bias."""
        feature_shape = (x.shape[-1],)
        effective_scale = 1.0 + self.weight
        return nn.functional.layer_norm(
            x, feature_shape, effective_scale, self.bias, self.eps
        )


class GatedMLP(nn.Module):
    """Two-layer feed-forward block: up-projection, GeLU, down-projection (factor 4/3)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.projection_factor = 4 / 3
        # Width of the intermediate layer, truncated to an integer.
        self.inner_size = int(hidden_size * self.projection_factor)

        self.up_proj = nn.Linear(hidden_size, self.inner_size)
        self.down_proj = nn.Linear(self.inner_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up_proj -> GeLU -> down_proj on the last dimension of x."""
        expanded = self.up_proj(x)
        activated = torch.nn.functional.gelu(expanded)
        return self.down_proj(activated)


In [None]:
class xLSTMBlock(nn.Module):
    """
    One residual block of the xLSTM architecture.

    Each block follows:
        h1 = sLSTMLayer(LayerNorm(x))
        x1 = x + h1
        h2 = GatedMLP(LayerNorm(x1))      # W_down * GeLU(W_up * .)
        x2 = x1 + h2

    Combines recurrent memory updates (via sLSTM) and feed-forward
    processing within dual residual connections.
    """

    def __init__(self, hidden_size: int, dropout_prob: float, use_exponential: bool = True):
        """
        Args:
            hidden_size: int
                Dimensionality of the hidden representation.
            dropout_prob: float
                Dropout probability forwarded to the sLSTM layer.
            use_exponential: bool
                Whether to use exponential gating (xLSTM) or sigmoid gating (baseline).
        """
        super().__init__()
        self.layer_norm = LayerNorm(hidden_size)
        # Pass by keyword: sLSTMLayer's second positional parameter is
        # dropout_prob, so a positional `use_exponential` would silently be
        # used as the dropout rate and the gating mode would stay at its default.
        self.slstm = sLSTMLayer(
            hidden_size,
            dropout_prob=dropout_prob,
            use_exponential=use_exponential,
        )
        self.mlp_norm = LayerNorm(hidden_size)  # Separate LayerNorm for MLP
        self.gated_mlp = GatedMLP(hidden_size)
        # Not applied in forward(); the sLSTM layer already applies dropout.
        self.dropout = nn.Dropout(dropout_prob)
        self.hidden_size = hidden_size

    def forward(self, x):
        """
        Forward pass of xLSTM block for batch processing.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_size]

        Returns:
            output: Output tensor [batch_size, seq_len, hidden_size]
        """
        # Best-effort rescue for channel-first input [B, H, T]: only triggers
        # when the last dimension cannot be the feature dimension.
        if x.dim() == 3 and x.shape[1] == self.hidden_size and x.shape[2] != self.hidden_size:
            x = x.transpose(1, 2)

        # First residual: recurrent mixing over time.
        h1, _ = self.slstm(self.layer_norm(x))
        x1 = x + h1

        # Second residual: position-wise feed-forward.
        h2 = self.gated_mlp(self.mlp_norm(x1))
        return x1 + h2




In [None]:
class xLSTM(nn.Module):
    """
    Full xLSTM stack of L residual blocks.

    Applies, in order:
        x <- input_proj(x)
        x <- xLSTMBlock_1(x)
        ...
        x <- xLSTMBlock_L(x)
        out = output_proj(x[:, -1, :])

    Args:
        d: Hidden dimension.
        num_layers: Number of stacked xLSTM blocks.
        dropout_prob: Dropout probability.
        use_exponential: Exponential (True) or sigmoid (False) gating.
    """

    def __init__(self, d: int, num_layers: int, dropout_prob: float = 0.1, use_exponential: bool = True):
        """
        Args:
            d: int
                Dimensionality of the hidden representation.
            num_layers: int
                Number of stacked xLSTM blocks.
            dropout_prob: float
                Dropout probability passed to every block.
            use_exponential: bool
                If True, use exponential gating (xLSTM mode).
                If False, use standard sigmoid gating (baseline).
        """
        super().__init__()
        # Input width is inferred on the first call unless the attribute is
        # replaced by the caller with a concrete nn.Linear.
        self.input_proj = nn.LazyLinear(d)
        self.layers = nn.ModuleList([
            xLSTMBlock(d, dropout_prob, use_exponential)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout_prob)  # currently not applied in forward()
        # Created eagerly so it is part of model.parameters() (and therefore
        # seen by an optimizer built before the first forward pass) and of
        # the state_dict from the start.
        self.output_proj = nn.Linear(d, d)
        self.hidden_size = d

    def forward(self, x):
        """
        Forward pass through all blocks with true batch processing.

        Args:
            x: Input tensor [batch_size, seq_len, input_size]

        Returns:
            output: Projected representation of the last time step,
                [batch_size, hidden_size]
        """
        x = self.input_proj(x)
        for layer in self.layers:
            x = layer(x)

        # Summarise the sequence by its last time step: [B, H].
        last_step = x[:, -1, :]
        return self.output_proj(last_step)

## Part 1.4 – xLSTM Training on the Parity Task [6 points]

In this part, you will train and evaluate your **xLSTM implementation** on a synthetic *parity task* that tests the model’s ability to capture **long-range dependencies**.

Several helper components are provided to make experimentation easier:

- **ParityDataset** – Generates binary sequences and corresponding parity labels for both training and testing. *(No modifications required.)*  
- **ModelFactory** – Creates and configures different model types (xLSTM-Exp, xLSTM-Sigmoid, and Vanilla LSTM). *(No modifications required.)*  
- **ModelTrainer** – Implements batching, optimization, and evaluation.  
  ➜ *You will complete this section.*  
- **ExperimentRunner** – Manages the full training and evaluation pipeline, connecting all components together. *(No modifications required.)*

Your goal in this part is to **complete the missing code in `ModelTrainer`** and verify that your models from previous parts can successfully learn and generalize on the parity dataset.


### Parity Dataset (No Modifications Required)





In [None]:

# ParityDataset: Complete implementation provided - handles the generation of the parity dataset.
# No modifications required.
class ParityDataset:
    """Synthetic parity data: random bit sequences labelled with sum(bits) mod 2."""

    def __init__(self, max_length: int, num_samples: int = 1000,
                 variable_length: bool = True, test_range: tuple = None):
        """
        Args:
            max_length: Maximum training length (for variable-length sequences).
            num_samples: Number of samples to generate.
            variable_length: Whether to randomly sample sequence lengths (for training/test).
            test_range: Tuple (min_len, max_len) for generalization test (e.g., (40, 256)).
        """
        self.max_length = max_length
        self.num_samples = num_samples
        self.variable_length = variable_length
        self.test_range = test_range

    def _sample_length(self) -> int:
        """Pick the length of the next sequence according to the configured mode."""
        if self.test_range:
            # Generalization test: both ends of the range are inclusive.
            low, high = self.test_range[0], self.test_range[1]
            return np.random.randint(low, high + 1)
        if self.variable_length:
            # Training / in-distribution test: uniform over 1..max_length.
            return np.random.randint(1, self.max_length + 1)
        # Fixed-length case.
        return self.max_length

    def generate_sequence(self, length: int) -> Tuple[torch.Tensor, int]:
        """Return a random float bit vector of the given length and its parity."""
        bits = torch.randint(0, 2, (length,))
        label = bits.sum().item() % 2
        return bits.float(), label

    def create_dataset(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build zero-padded sequences, their labels, and a 0/1 validity mask."""
        samples = []
        targets = []
        for _ in range(self.num_samples):
            bits, parity = self.generate_sequence(self._sample_length())
            samples.append(bits)
            targets.append(parity)

        # Pad every sequence on the right up to the longest one in this draw.
        widest = max(len(bits) for bits in samples)
        padded_sequences = torch.zeros(self.num_samples, widest)
        mask = torch.zeros_like(padded_sequences)

        for row, bits in enumerate(samples):
            valid = len(bits)
            padded_sequences[row, :valid] = bits
            mask[row, :valid] = 1  # mark valid positions

        return padded_sequences, torch.tensor(targets), mask

### Model Factory (No Modifications Required)





In [None]:
# ModelFactory: Complete implementation provided - handles the creation of different model types.
# including model creation. No modifications required.
class ModelFactory:
    """Builds the models compared in the experiments from a shared config dict."""

    def __init__(self, config: dict):
        self.config = config

    def _attach_heads(self, backbone: nn.Module) -> nn.Module:
        """Attach the 1 -> hidden input projection and the hidden -> 2 classifier."""
        backbone.input_proj = nn.Linear(1, self.config['hidden_size'])  # Project binary input to hidden_size
        backbone.classifier = nn.Linear(self.config['hidden_size'], 2)
        return backbone

    def create_xlstm_exp(self) -> xLSTM:
        """Create xLSTM with exponential gates."""
        backbone = xLSTM(
            self.config['hidden_size'],
            self.config['num_layers'],
            use_exponential=True,
        )
        return self._attach_heads(backbone)

    def create_xlstm_sigmoid(self) -> xLSTM:
        """Create xLSTM with sigmoid gates."""
        backbone = xLSTM(
            self.config['hidden_size'],
            self.config['num_layers'],
            use_exponential=False,
        )
        return self._attach_heads(backbone)

    def create_vanilla_lstm(self) -> nn.Module:
        """Create vanilla LSTM baseline."""
        width = self.config['hidden_size']
        depth = self.config['num_layers']
        backbone = self._attach_heads(nn.LSTM(width, width, depth, batch_first=True))
        backbone.hidden_size = self.config['hidden_size']  # Add for compatibility
        return backbone

    def create_all_models(self) -> dict:
        """Create all models for comparison, keyed by display name."""
        builders = (
            ('xLSTM (Exponential)', self.create_xlstm_exp),
            ('xLSTM (Sigmoid)', self.create_xlstm_sigmoid),
            ('Vanilla LSTM', self.create_vanilla_lstm),
        )
        return {label: build() for label, build in builders}

### Model Training and Evaluation

In [None]:
from pyexpat import model


class ModelTrainer:
    """
    Handles training and evaluation of sequence classifiers.

    Datasets are passed as (sequences, labels, masks) triples. The masks are
    accepted for interface compatibility but are not used: classification is
    based on the model output at the last (possibly padded) time step.
    """

    def __init__(self, batch_size: int = 256, lr: float = 0.001):
        """
        Args:
            batch_size: Number of sequences per mini-batch.
            lr: Initial learning rate for AdamW.
        """
        self.batch_size = batch_size
        self.lr = lr
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

    @staticmethod
    def _forward_logits(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Run `model` on a batch and return class logits.

        Models carrying both `input_proj` and `classifier` attributes are
        treated as backbones: their output (last time step if it is a
        sequence) is passed through `model.classifier`. Any other model is
        expected to return logits directly.
        """
        if x.ndim == 2:
            x = x.unsqueeze(-1)  # [B, T] -> [B, T, 1]

        if not (hasattr(model, "input_proj") and hasattr(model, "classifier")):
            return model(x)

        if isinstance(model, nn.RNNBase):
            # torch RNNs do not apply the attached input projection themselves.
            out = model(model.input_proj(x))
        else:
            try:
                out = model(x)
            except RuntimeError as e:
                # Kept for wrappers around an RNN that reject raw input.
                if "input.size(-1)" not in str(e):
                    raise
                out = model(model.input_proj(x))

        # nn.LSTM returns (output, (h_n, c_n)); keep the sequence output.
        if isinstance(out, tuple):
            out = out[0]

        # Handle both [B, T, H] and [B, H].
        if out.ndim == 3:
            out = out[:, -1, :]

        return model.classifier(out)

    def train_model(self, model: nn.Module, train_data: tuple, epochs: int = 10):
        """
        Train `model` with AdamW, cross-entropy and a cosine LR schedule.

        Args:
            model: Model to train (moved to `self.device`).
            train_data: Tuple (sequences, labels, masks); masks are unused.
            epochs: Number of passes over the data; also the cosine period.

        Returns:
            List with the mean batch loss of every epoch.
        """
        sequences, labels, _masks = train_data
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=0.1)
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        model = model.to(self.device)
        sequences = sequences.to(self.device)
        labels = labels.to(self.device)

        losses_per_epoch = []

        for epoch in range(epochs):
            model.train()
            total_loss, correct, num_batches = 0.0, 0, 0

            for batch_start in range(0, len(sequences), self.batch_size):
                batch_end = batch_start + self.batch_size  # slicing clamps at the end
                x = sequences[batch_start:batch_end]
                y = labels[batch_start:batch_end]

                optimizer.zero_grad()
                logits = self._forward_logits(model, x)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                correct += (logits.argmax(dim=1) == y).sum().item()
                num_batches += 1

            # T_max is expressed in epochs, so the schedule advances once per
            # epoch rather than once per batch.
            scheduler.step()

            # Average over the batches actually processed.
            avg_loss = total_loss / max(num_batches, 1)
            acc = correct / len(sequences)
            losses_per_epoch.append(avg_loss)
            print(f"Epoch {epoch+1}/{epochs} ‚Äî Loss: {avg_loss:.4f}, Acc: {acc:.4f}")

        return losses_per_epoch

    def evaluate_model(self, model: nn.Module, test_data: tuple) -> float:
        """
        Compute classification accuracy of `model` on `test_data`.

        Args:
            model: Model to evaluate (moved to `self.device`, set to eval mode).
            test_data: Tuple (sequences, labels, masks); masks are unused.

        Returns:
            Fraction of correctly classified sequences.
        """
        sequences, labels, _masks = test_data
        model = model.to(self.device)
        sequences = sequences.to(self.device)
        labels = labels.to(self.device)

        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_start in range(0, len(sequences), self.batch_size):
                batch_end = batch_start + self.batch_size
                x = sequences[batch_start:batch_end]
                y = labels[batch_start:batch_end]

                logits = self._forward_logits(model, x)
                correct += (logits.argmax(dim=1) == y).sum().item()

        return correct / len(sequences)


### Experiment Runner (No Modifications Required)

In [None]:
# ExperimentRunner: Complete implementation provided - handles the full experiment pipeline
# including model testing, training, and evaluation. No modifications required.
class ExperimentRunner:
    """Drive the xLSTM assignment: component smoke tests plus the full experiment.

    The three ``test_*`` methods push random tensors through the Part 1.1-1.3
    modules and print the resulting shapes. ``run_experiment`` builds the
    parity datasets, trains every model produced by the factory, and hands the
    trained models to ``analyze_results`` for evaluation, plots and a table.
    """

    def __init__(self, config: dict):
        """Keep ``config`` and create the model factory and the trainer.

        Args:
            config: Experiment settings. The whole dictionary is passed to
                ``ModelFactory``; ``batch_size`` (default 256) and
                ``learning_rate`` (default 1e-3) are passed to ``ModelTrainer``.
        """
        self.config = config
        self.model_factory = ModelFactory(config)
        trainer_kwargs = {
            'batch_size': config.get('batch_size', 256),
            'lr': config.get('learning_rate', 1e-3),
        }
        self.trainer = ModelTrainer(**trainer_kwargs)

    def test_exponential_gates(self):
        """Part 1.1: run ExponentialGates once in each gating mode and print shapes."""
        print("Testing ExponentialGates...")

        hidden = self.config['hidden_size']

        # One module per gating mode (exponential vs. sigmoid).
        exp_module = ExponentialGates(hidden, use_exponential=True)
        sigmoid_module = ExponentialGates(hidden, use_exponential=False)

        # Random single-step inputs; `packed_states` has a leading axis of 4.
        n_batch = 4
        step_input = torch.randn(n_batch, hidden)
        prev_hidden = torch.randn(n_batch, hidden)
        packed_states = torch.randn(4, n_batch, hidden)

        exp_states, _ = exp_module(step_input, prev_hidden, packed_states)
        print(f"‚úÖ Exponential gates output shape: {exp_states.shape}")

        sigmoid_states, _ = sigmoid_module(step_input, prev_hidden, packed_states)
        print(f"‚úÖ Sigmoid gates output shape: {sigmoid_states.shape}")

        print("‚úÖ Part 1.1: ExponentialGates working correctly!")

    def test_slstm_cell(self):
        """Part 1.2: run sLSTMCell over a random sequence and print shapes."""
        print("Testing sLSTMCell...")

        hidden = self.config['hidden_size']
        cell = sLSTMCell(hidden, use_exponential=True)

        # The cell is fed a last dimension of 4 * hidden_size.
        n_batch, n_steps = 4, 10
        inputs = torch.randn(n_batch, n_steps, 4 * hidden)

        outputs, last_state = cell(inputs)
        print(f"‚úÖ sLSTM output shape: {outputs.shape}")
        print(f"‚úÖ sLSTM final state shape: {last_state.shape}")

        print("‚úÖ Part 1.2: sLSTMCell working correctly!")

    def test_xlstm_architecture(self):
        """Part 1.3: run the stacked xLSTM over a random batch and print the shape."""
        print("Testing xLSTM architecture...")

        hidden = self.config['hidden_size']
        stack = xLSTM(hidden, self.config['num_layers'], use_exponential=True)

        n_batch, n_steps = 4, 10
        inputs = torch.randn(n_batch, n_steps, hidden)

        outputs = stack(inputs)
        print(f"‚úÖ xLSTM output shape: {outputs.shape}")

        print("‚úÖ Part 1.3: xLSTM architecture working correctly!")

    # -------------------------------------------------------------------------
    # Main experiment flow
    # -------------------------------------------------------------------------
    def run_experiment(self):
        """Build the datasets, train every model, then evaluate and plot."""
        print("Running complete experiment...")
        cfg = self.config

        print("Preparing datasets...")
        # Training set: variable-length sequences capped at train_length.
        train_data = ParityDataset(
            max_length=cfg['train_length'],
            num_samples=cfg['num_samples'],
            variable_length=True
        ).create_dataset()

        # In-distribution test set: same length settings, 1000 samples.
        test_id_data = ParityDataset(
            max_length=cfg['train_length'],
            num_samples=1000,
            variable_length=True
        ).create_dataset()

        # Generalization test set: fixed-length mode with test_range=(40, 256).
        test_ood_data = ParityDataset(
            max_length=cfg['test_max_length'],
            num_samples=1000,
            variable_length=False,
            test_range=(40, 256)
        ).create_dataset()

        models = self.model_factory.create_all_models()

        # Train each model in turn, keeping its loss history under its name.
        loss_histories = {}
        for name, model in models.items():
            print(f"\nTraining {name}...")
            loss_histories[name] = self.trainer.train_model(
                model, train_data, cfg['epochs']
            )

        self.analyze_results(models, loss_histories, test_id_data, test_ood_data)

    # -------------------------------------------------------------------------
    # Evaluation
    # -------------------------------------------------------------------------
    def analyze_results(self, models: dict, loss_histories: dict,
                        test_id_data: tuple, test_ood_data: tuple):
        """Evaluate every model on both test sets, then plot and tabulate.

        Args:
            models: Mapping from model name to trained model.
            loss_histories: Mapping from model name to its training losses.
            test_id_data: In-distribution test data passed to the trainer.
            test_ood_data: Generalization test data passed to the trainer.
        """
        print("\nüîç Analyzing results...")

        # name -> (in-distribution accuracy, generalization accuracy)
        results = {}
        for name, model in models.items():
            print(f"\nEvaluating {name}...")
            id_acc = self.trainer.evaluate_model(model, test_id_data)
            ood_acc = self.trainer.evaluate_model(model, test_ood_data)
            results[name] = (id_acc, ood_acc)
            print(f"‚úÖ {name}: In-Dist={id_acc:.3f}, Generalization={ood_acc:.3f}")

        self.plot_training_loss(loss_histories)
        self.plot_in_distribution_accuracy(results)
        self.plot_generalization_accuracy(results)
        self.create_results_table(results)

    # -------------------------------------------------------------------------
    # Plots and summary table
    # -------------------------------------------------------------------------
    def plot_training_loss(self, loss_histories: dict):
        """Draw one training-loss curve per model against the epoch number."""
        plt.figure(figsize=(10, 6))
        for name, losses in loss_histories.items():
            epochs = range(1, len(losses)+1)
            plt.plot(epochs, losses, marker='o', label=name, markersize=3)
        plt.xlabel('Epoch')
        plt.ylabel('Training Loss')
        plt.title('Training Loss Over Time')
        plt.legend()
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

    def _plot_accuracy_bars(self, results: dict, column: int, color: str, title: str):
        """Bar-plot element ``column`` of each model's result tuple on a 0-1 axis."""
        plt.figure(figsize=(8, 6))
        names = list(results.keys())
        values = [results[name][column] for name in names]
        plt.bar(names, values, color=color)
        plt.ylabel('Accuracy')
        plt.ylim(0, 1)
        plt.title(title)
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.show()

    def plot_in_distribution_accuracy(self, results: dict):
        """Bar chart of the in-distribution accuracy (first tuple element) per model."""
        self._plot_accuracy_bars(
            results, 0, 'skyblue', 'In-Distribution Accuracy (1‚Äì40)'
        )

    def plot_generalization_accuracy(self, results: dict):
        """Bar chart of the generalization accuracy (second tuple element) per model."""
        self._plot_accuracy_bars(
            results, 1, 'lightcoral', 'Generalization Accuracy (40‚Äì256)'
        )

    def create_results_table(self, results: dict):
        """Print a fixed-width table with one row of accuracies per model."""
        heavy_rule = "="*70
        print("\n" + heavy_rule)
        print(f"{'Model':<25}{'In-Distribution (1‚Äì40)':<20}{'Generalization (40‚Äì256)':<20}")
        print("-"*70)
        for name, accuracies in results.items():
            id_acc, ood_acc = accuracies
            print(f"{name:<25}{id_acc:<20.3f}{ood_acc:<20.3f}")
        print(heavy_rule)

## Part 1.5 Result and Analysis [20 points]

### Training tips
When getting started, try training on a small sample — for example, use `num_samples = batch_size` (e.g., 64 samples) — just to confirm that your model runs without errors and that the loss decreases. Once things look stable, you can increase the number of epochs and data samples for a full training run.

In [None]:
print("üöÄ xLSTM Assignment Implementation")
print("="*50)

# Report whether CUDA is usable and, if so, the name and memory of device 0.
cuda_ok = torch.cuda.is_available()
print(f"\nüñ•Ô∏è  GPU Available: {cuda_ok}")
if cuda_ok:
    print(f"   GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print()

# Hyperparameters for the parity experiment.
# NOTE(review): ExperimentRunner.__init__ forwards only batch_size and
# learning_rate to ModelTrainer; confirm that weight_decay is consumed
# somewhere (e.g. inside ModelFactory), otherwise it has no effect.
config = dict(
    hidden_size=64,
    num_layers=2,
    train_length=40,        # variable-length training sequences, 1-40
    test_max_length=256,    # generalization test, 40-256
    num_samples=10000,
    epochs=20,
    batch_size=256,
    learning_rate=1e-3,
    weight_decay=0.1,
)
print(f"Configuration: {config}")

experiment = ExperimentRunner(config)

# Each stage prints a banner and then runs the matching ExperimentRunner method.
stages = [
    ("PART 1.1: Testing ExponentialGates", experiment.test_exponential_gates),
    ("PART 1.2: Testing sLSTMCell", experiment.test_slstm_cell),
    ("PART 1.3: Testing xLSTM Architecture", experiment.test_xlstm_architecture),
    ("PART 1.4 & 1.5: Complete Experiment", experiment.run_experiment),
]
for banner, run_stage in stages:
    print("\n" + "="*50)
    print(banner)
    print("="*50)
    run_stage()

print("\n‚úÖ HW1 completed successfully!")

# Question 2: Building a Modern Transformer for Modular Arithmetic [50 points]

In this question, you will implement a state-of-the-art decoder-only transformer architecture from scratch using PyTorch. Your model will learn to perform modular arithmetic operations. Through careful implementation and systematic experimentation, you will gain deep understanding of modern architectural components including RMSNorm, SwiGLU activations, rotary position embeddings (RoPE), and grouped-query attention (GQA). This assignment emphasizes both implementation rigor and experimental methodology.

## Section 0: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Optional, Dict
import math
from tqdm import tqdm
import json
import pickle

def set_seed(seed=42):
    """Seed the PyTorch (CPU and all CUDA devices) and NumPy global RNGs.

    Args:
        seed: Integer seed handed to every generator (default: 42).
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)


# Fix the seeds once at import time so the cells below are reproducible.
set_seed(42)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## Section 1: Core Component Implementations (30 points total)

### 1.1 RMSNorm: Root Mean Square Normalization [3 points]

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Paper: "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019)
    https://arxiv.org/abs/1910.07467

    Each feature vector is divided by its root mean square and multiplied by
    a learnable per-channel gain:
    - RMS(x) = sqrt(mean(x^2) + eps), mean taken over the last dimension
    - Output = scale * (x / RMS(x))
    - No mean centering (unlike LayerNorm) and no bias parameter

    Args:
        d: Model dimension
        eps: Small constant for numerical stability (default: 1e-6)
    """
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain of shape (d,), starting at one.
        self.scale = nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` over its last dimension.

        Args:
            x: Input tensor of shape (B, S, d)

        Returns:
            Normalized tensor of shape (B, S, d)
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        # rsqrt yields 1 / RMS(x), so the division becomes a multiplication.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.scale


### 1.2 SwiGLU Feed-Forward Network [3 points]

In [None]:
class FeedForward(nn.Module):
    """
    SwiGLU Feed-Forward Network.

    Paper: "GLU Variants Improve Transformer" (Shazeer, 2020)
    https://arxiv.org/abs/2002.05202

    Computes FFN(x) = Dropout(((SiLU(x W_gate)) * (x W_value)) W_out), where
    SiLU(z) = z * sigmoid(z) (also called Swish) and * is element-wise.
    All three projections are bias-free.

    Args:
        d: Model dimension
        d_ff: Hidden dimension (typically 4*d, or 8d/3 for SwiGLU to match params)
        dropout_prob: Dropout probability
    """
    def __init__(self, d: int, d_ff: int, dropout_prob: float):
        super().__init__()

        # Gate path (d -> d_ff); SiLU is applied to its output.
        self.w_gate = nn.Linear(d, d_ff, bias=False)
        # Value path (d -> d_ff); stays linear.
        self.w_value = nn.Linear(d, d_ff, bias=False)
        # Output projection back to the model dimension (d_ff -> d).
        self.w_out = nn.Linear(d_ff, d, bias=False)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transform.

        Args:
            x: Input tensor of shape (B, S, d)

        Returns:
            Output tensor of shape (B, S, d)
        """
        gate = torch.nn.functional.silu(self.w_gate(x))
        value = self.w_value(x)
        projected = self.w_out(gate * value)
        # Dropout is applied to the block output, after the final projection.
        return self.dropout(projected)

### 1.3 RoPE: Rotary Position Embeddings [6 points]

In [None]:
from networkx import omega


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    """
    Helper function for RoPE: rotates adjacent pairs of elements.

    Paper: "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021)
    https://arxiv.org/abs/2104.09864

    Transforms [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...] along the last
    dimension, i.e. each (even, odd) pair (a, b) becomes (-b, a).

    Args:
        t: Tensor of shape (..., d) where d is even

    Returns:
        Rotated tensor of shape (..., d)

    Raises:
        AssertionError: if the last dimension of ``t`` is odd
    """
    assert t.shape[-1] % 2 == 0, "dimension must be even for rotate_half"
    even = t[..., 0::2]  # x0, x2, ...
    odd = t[..., 1::2]   # x1, x3, ...
    # Pair up (-odd, even) on a new trailing axis, then merge that axis back
    # into the channel axis so the pairs end up interleaved.
    return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)

    


def apply_rope(x: torch.Tensor, positions: torch.Tensor,
               d_rope: Optional[int] = None, theta: float = 10000.0) -> torch.Tensor:
    """
    Apply Rotary Position Embeddings to input tensor.

    For the token at position p, channel pair (2j, 2j+1) is rotated by the
    angle p * theta^(-2j / d_rope), computed as
        x * cos(angle) + rotate_half(x) * sin(angle)
    Only the first d_rope channels are rotated; the rest pass through unchanged.

    Args:
        x: Input tensor of shape (B, S, H, d_h)
        positions: Position indices of shape (S,); moved to ``x``'s device
            if it lives elsewhere
        d_rope: Dimension to apply rotation (if None, use all d_h)
        theta: Base frequency (default: 10000)

    Returns:
        Rotated tensor of shape (B, S, H, d_h)

    Raises:
        AssertionError: if d_rope is odd or larger than d_h
    """
    _, _, _, d_h = x.shape  # unpacking also requires x to be 4-D

    # Use full dimension if d_rope not specified
    if d_rope is None:
        d_rope = d_h
    assert d_rope <= d_h and d_rope % 2 == 0, "d_rope must be even and <= d_h"

    device = x.device

    # Step 1: inverse frequencies theta^(-2j / d_rope), one per channel pair.
    inv_freq = torch.tensor(
        [theta ** (-2 * j / d_rope) for j in range(d_rope // 2)],
        device=device,
    )

    # Step 2: angle matrix of shape (S, d_rope/2).
    # Callers often build `positions` with a plain torch.arange(S), which is a
    # CPU tensor; move it next to x so the product never mixes devices.
    positions = positions.to(device)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Step 3: duplicate each angle so both channels of a pair share it -> (S, d_rope).
    cos_angles = torch.cos(angles).repeat_interleave(2, dim=-1)
    sin_angles = torch.sin(angles).repeat_interleave(2, dim=-1)

    # Step 4: rotate the first d_rope channels; unsqueeze(1) adds the head
    # axis so (S, 1, d_rope) broadcasts against (B, S, H, d_rope).
    x_rope = x[..., :d_rope]
    x_rotated = (x_rope * cos_angles.unsqueeze(1) +
                 rotate_half(x_rope) * sin_angles.unsqueeze(1))

    # Step 5: re-attach the untouched channels when only part of d_h is rotated.
    if d_rope < d_h:
        x_rotated = torch.cat([x_rotated, x[..., d_rope:]], dim=-1)
    return x_rotated

### 1.4 Grouped-Query Attention (GQA) [12 points]

In [None]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention mechanism with RoPE and a causal mask.

    Paper: "GQA: Training Generalized Multi-Query Transformer Models" (Ainslie et al., 2023)
    https://arxiv.org/abs/2305.13245

    Key implementation details:
    - Q has H heads, K and V have G groups where G divides H
    - Each KV group is shared across H/G query heads
    - repeat_interleave broadcasts the KV groups to the query heads
    - RoPE is applied to Q and K; positions are 0..S-1

    Args:
        d: Model dimension
        num_heads: Number of query heads (H)
        num_kv_groups: Number of key-value groups (G), must divide num_heads
        dropout_prob: Dropout probability
        d_rope: Dimension for RoPE (if None, uses d_h)
    """
    def __init__(self, d: int, num_heads: int, num_kv_groups: int,
                 dropout_prob: float, d_rope: Optional[int] = None):
        super().__init__()
        assert d % num_heads == 0, "d must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d = d
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.d_h = d // num_heads
        self.d_rope = d_rope if d_rope is not None else self.d_h

        # Bias-free projections.
        self.w_query = nn.Linear(d, d, bias=False)  # d -> H * d_h = d
        self.w_kv = nn.Linear(d, 2 * self.d_h * num_kv_groups, bias=False)  # d -> 2 * G * d_h
        self.w_output = nn.Linear(d, d, bias=False)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, S, d)

        Returns:
            output: Context tensor of shape (B, S, d)
            attn_weights: Attention weights of shape (B, H, S, S), taken after
                attention dropout (so some entries are zeroed in training mode)
        """
        B, S, d = x.shape
        heads_per_group = self.num_heads // self.num_kv_groups

        # Steps 1-2: project and split the queries into heads -> (B, S, H, d_h).
        q = self.w_query(x).view(B, S, self.num_heads, self.d_h)

        # Step 3: project keys/values jointly; each group's 2*d_h slice holds
        # its key channels followed by its value channels.
        kv = self.w_kv(x).view(B, S, self.num_kv_groups, 2 * self.d_h)
        k, v = kv.split(self.d_h, dim=-1)  # each (B, S, G, d_h)

        # Step 4: share every KV group across its H/G query heads -> (B, S, H, d_h).
        k = k.repeat_interleave(heads_per_group, dim=2)
        v = v.repeat_interleave(heads_per_group, dim=2)

        # Step 5: rotary embeddings on Q and K. The position indices are
        # created on x's device so the module also runs on an accelerator.
        positions = torch.arange(S, device=x.device)
        q = apply_rope(q, positions, self.d_rope)
        k = apply_rope(k, positions, self.d_rope)

        # Step 6: move the head axis forward -> (B, H, S, d_h).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Step 7: scaled dot-product scores -> (B, H, S, S).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_h)

        # Step 8: causal mask. Entries above the diagonal (key position after
        # query position) are set to -inf; the (S, S) mask broadcasts over
        # the batch and head axes.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float('-inf'))

        # Step 9: softmax over keys, then attention dropout.
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Step 10: weighted sum of values, heads merged back -> (B, S, d).
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, S, d)

        # Step 11: output projection.
        output = self.w_output(context)
        # NOTE(review): DecoderBlock.forward applies its own dropout to this
        # tensor as well, so the attention branch is dropped out twice there.
        # Kept to preserve training behavior; confirm whether it is intended.
        output = self.dropout(output)
        return output, attn_weights

### 1.5 Decoder Block with Parallel Pre-Normalization [3 points]

In [None]:
class DecoderBlock(nn.Module):
    """
    Transformer decoder block with parallel pre-normalization.

    Architecture based on modern LLMs (e.g., PaLM, LLaMA):
    - Pre-normalization (norm before attention/FFN, not after)
    - Parallel formulation: both branches normalize the same input
    - Equation: y = x + Dropout(Attn(Norm(x))) + Dropout(FFN(Norm(x)))

    Note: This differs from sequential (GPT-2 style) where FFN sees attention output

    Args:
        d: Model dimension
        num_heads: Number of attention heads
        num_kv_groups: Number of key-value groups
        d_ff: Feed-forward hidden dimension
        dropout_prob: Dropout probability, used by the attention module, the
            feed-forward module and the residual dropout alike
    """
    def __init__(self, d: int, num_heads: int, num_kv_groups: int,
                 d_ff: int, dropout_prob: float = 0.1):
        super().__init__()

        # Separate norms (and gains) for the attention and FFN branches.
        self.attn_RMSNorm = RMSNorm(d=d)
        self.FFN_RMSNorm = RMSNorm(d=d)

        # The attention module receives the configured dropout_prob; a fixed
        # 0.1 here would leave attention dropout on even when the caller
        # asks for dropout_prob=0.0.
        self.group_attention = GroupedQueryAttention(
            d=d, num_heads=num_heads, num_kv_groups=num_kv_groups,
            dropout_prob=dropout_prob,
        )
        self.feed_forward = FeedForward(d=d, d_ff=d_ff, dropout_prob=dropout_prob)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with parallel pre-normalization.

        Args:
            x: Input tensor of shape (B, S, d)

        Returns:
            output: Output tensor of shape (B, S, d)
            attn_weights: Attention weights of shape (B, H, S, S)
        """
        # Attention branch: normalize x, then attend.
        attn_output, attn_weights = self.group_attention(self.attn_RMSNorm(x))

        # FFN branch: normalizes the SAME input x, not the attention output.
        ffn_output = self.feed_forward(self.FFN_RMSNorm(x))

        # Residual sum of the input and both (dropped-out) branches.
        output = x + self.dropout(attn_output) + self.dropout(ffn_output)
        return output, attn_weights


### 1.6 Complete Model: ModernDecoderLM [3 points]

In [None]:
class ModernDecoderLM(nn.Module):
    """
    Decoder-only language model built from DecoderBlock layers.

    Architecture features:
    - Token embeddings only (no absolute positional embeddings; RoPE inside
      the attention layers supplies position information)
    - Stack of decoder blocks with parallel pre-norm
    - Final RMSNorm before the output projection
    - Weight tying: the LM head shares its weight with the token embedding

    Args:
        vocab_size: Vocabulary size
        d: Model dimension
        num_layers: Number of decoder layers
        num_heads: Number of attention heads
        num_kv_groups: Number of key-value groups
        d_ff: Feed-forward hidden dimension
        dropout_prob: Dropout probability
    """
    def __init__(self, vocab_size: int, d: int, num_layers: int,
                 num_heads: int, num_kv_groups: int, d_ff: int,
                 dropout_prob: float):
        super().__init__()

        self.vocab_size = vocab_size
        self.d = d
        self.num_layers = num_layers

        self.token_embedding = nn.Embedding(vocab_size, d)

        # num_layers identically configured blocks, applied one after another.
        block_kwargs = dict(d=d, num_heads=num_heads, num_kv_groups=num_kv_groups,
                            d_ff=d_ff, dropout_prob=dropout_prob)
        self.decoder_blocks = nn.ModuleList(
            DecoderBlock(**block_kwargs) for _ in range(num_layers)
        )

        self.final_norm = RMSNorm(d=d)

        # Bias-free projection to the vocabulary; its weight is replaced by
        # the embedding matrix so both layers share one parameter.
        self.lm_head = nn.Linear(d, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_ids: Token indices of shape (B, S)

        Returns:
            logits: Logits of shape (B, S, vocab_size)
            hidden_states: Output of every decoder block (before the final
                norm), stacked to shape (L, B, S, d)
            attn_weights: Attention weights of every block, stacked to
                shape (L, B, H, S, S)
        """
        h = self.token_embedding(input_ids)

        per_layer_hidden = []
        per_layer_attn = []
        for layer in self.decoder_blocks:
            h, layer_attn = layer(h)
            per_layer_hidden.append(h)
            per_layer_attn.append(layer_attn)

        logits = self.lm_head(self.final_norm(h))

        return logits, torch.stack(per_layer_hidden), torch.stack(per_layer_attn)

## Section 2: Dataset Generation (DO NOT MODIFY)

In [None]:
def create_modular_arithmetic_dataset(p=11, train_split=0.9, seed=42):
    """
    Generate the modular arithmetic dataset.

    Enumerates every equation over Z/pZ (integers modulo p) as an 8-token
    sequence:
    - Binary:  [BOS] a op b [=] r [EOS] [PAD]
    - Ternary: [BOS] a op b op c [=] r
    where op is + or * (the same operator is used twice in a ternary
    equation) and r is the result mod p. The sequences are shuffled with
    NumPy's global RNG (re-seeded with ``seed``) and split into train and
    validation parts.

    Args:
        p: Modulus (prime number, default 11)
        train_split: Fraction of data for training (default 0.9)
        seed: Random seed for reproducibility

    Returns:
        train_inputs: Training input sequences (N_train, S-1)
        train_targets: Training target sequences (N_train, S-1)
        val_inputs: Validation input sequences (N_val, S-1)
        val_targets: Validation target sequences (N_val, S-1)
        vocab: Dictionary with 'token_to_idx' and 'idx_to_token' mappings
    """
    np.random.seed(seed)

    # Token ids: digits occupy 0..p-1, the six special tokens follow.
    ADD, MUL, BOS, EOS, PAD, EQ = (p + offset for offset in range(6))

    token_to_idx = {str(digit): digit for digit in range(p)}
    token_to_idx.update({
        '+': ADD,
        '*': MUL,
        '[BOS]': BOS,
        '[EOS]': EOS,
        '[PAD]': PAD,
        '[=]': EQ,
    })
    vocab = {
        'token_to_idx': token_to_idx,
        'idx_to_token': {idx: token for token, idx in token_to_idx.items()},
    }

    # Binary equations, padded to the same length as the ternary ones.
    binary = [
        [BOS, a, op, b, EQ, (a + b) % p if op == ADD else (a * b) % p, EOS, PAD]
        for op in (ADD, MUL)
        for a in range(p)
        for b in range(p)
    ]
    # Ternary equations.
    ternary = [
        [BOS, a, op, b, op, c, EQ, (a + b + c) % p if op == ADD else (a * b * c) % p]
        for op in (ADD, MUL)
        for a in range(p)
        for b in range(p)
        for c in range(p)
    ]
    sequences = np.array(binary + ternary)

    # Shuffle, then cut at the train/validation boundary.
    order = np.random.permutation(len(sequences))
    n_train = int(len(sequences) * train_split)
    train_seqs = sequences[order[:n_train]]
    val_seqs = sequences[order[n_train:]]

    # Next-token setup: inputs drop the last token, targets drop the first.
    train_inputs = torch.LongTensor(train_seqs[:, :-1])
    train_targets = torch.LongTensor(train_seqs[:, 1:])
    val_inputs = torch.LongTensor(val_seqs[:, :-1])
    val_targets = torch.LongTensor(val_seqs[:, 1:])

    return train_inputs, train_targets, val_inputs, val_targets, vocab

## Section 3: Training and Evaluation Infrastructure (DO NOT MODIFY)

In [None]:
def compute_metrics(logits, targets, vocab):
    """
    Compute loss and accuracy on tokens after [=] sign only.

    A target position is scored when it lies strictly after the first [=]
    in its row and its target token is neither [PAD] nor [EOS]. Rows whose
    targets contain no [=] contribute no scored positions and therefore
    count as correct in the sequence-level accuracy.

    Args:
        logits: Model logits (B, S, V)
        targets: Target tokens (B, S)
        vocab: Vocabulary dictionary; vocab['token_to_idx'] must contain
            '[=]', '[PAD]' and '[EOS]'

    Returns:
        loss: Mean cross-entropy loss on the scored tokens
        accuracy: Fraction of rows whose scored tokens are all predicted
            correctly
        Both are 0.0 tensors (on the logits' device) when no position is scored.
    """
    B, S, V = logits.shape
    token_to_idx = vocab['token_to_idx']
    EQ_idx = token_to_idx['[=]']
    PAD_idx = token_to_idx['[PAD]']
    EOS_idx = token_to_idx['[EOS]']

    # Number of [=] tokens strictly before each position; a positive count
    # means the position comes after the first [=] of its row. Doing this
    # with cumsum avoids a Python loop over every (row, position) pair.
    is_eq = (targets == EQ_idx).long()
    eq_before = is_eq.cumsum(dim=1) - is_eq
    mask = (eq_before > 0) & (targets != PAD_idx) & (targets != EOS_idx)

    if not mask.any():
        return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device)

    # Loss over the scored positions only.
    flat_mask = mask.reshape(-1)
    logits_flat = logits.reshape(-1, V)[flat_mask]
    targets_flat = targets.reshape(-1)[flat_mask]
    loss = F.cross_entropy(logits_flat, targets_flat)

    # Sequence-level accuracy: unscored positions are treated as correct.
    preds = logits.argmax(dim=-1)
    correct_tokens = (preds == targets) | (~mask)
    correct_sequences = correct_tokens.all(dim=1).float().mean()

    return loss, correct_sequences


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """
    Build a LambdaLR scheduler: linear warmup followed by cosine decay.

    The returned scheduler multiplies the optimizer's base learning rate by
    a factor that rises linearly from 0 to 1 over ``num_warmup_steps`` and
    then follows half a cosine period, clamped so it never goes below 0.
    """
    # Both denominators are floored at 1 so a zero-length phase cannot
    # cause a division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_multiplier(current_step):
        if current_step >= num_warmup_steps:
            fraction_done = float(current_step - num_warmup_steps) / decay_span
            cosine_factor = 0.5 * (1.0 + math.cos(math.pi * fraction_done))
            return max(0.0, cosine_factor)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_multiplier)


def train_model(model, train_loader, val_loader, vocab, config, device):
    """
    Train the model with AdamW optimizer and cosine schedule.

    After every epoch the model is evaluated on ``val_loader``. Whenever the
    mean validation accuracy improves on the best value seen so far, the
    model's ``state_dict`` is written to the checkpoint file. The first
    epoch always writes a checkpoint, so the file is guaranteed to exist
    (and to belong to this run) once training has completed one epoch.

    Args:
        model: Model to train; calling it on an input batch must return a
            3-tuple whose first element is the logits
        train_loader: Training dataloader yielding (inputs, targets)
        val_loader: Validation dataloader yielding (inputs, targets)
        vocab: Vocabulary dictionary (forwarded to compute_metrics)
        config: Configuration dictionary with hyperparameters.
            Required keys: 'learning_rate', 'weight_decay', 'num_epochs',
            'grad_clip'.
            Optional keys: 'track_param_norm' (default False) records the
            L2 norm of the trainable parameters after each epoch;
            'checkpoint_path' (default 'best_model.pt') is where the best
            weights are saved.
        device: Device to train on

    Returns:
        history: Dictionary with training history. 'train_loss',
            'train_acc', 'val_loss', 'val_acc' and 'steps' hold one entry
            per epoch ('steps' is the cumulative optimizer step count);
            'param_norms' is present only when norm tracking is enabled.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                   lr=config['learning_rate'],
                                   weight_decay=config['weight_decay'])

    # The schedule is defined per optimizer step; 10% of all steps are warmup.
    num_training_steps = config['num_epochs'] * len(train_loader)
    num_warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    track_param_norm = bool(config.get('track_param_norm', False))
    checkpoint_path = config.get('checkpoint_path', 'best_model.pt')

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'steps': []
    }

    if track_param_norm:
        history['param_norms'] = []

    # Start below any reachable accuracy so the first epoch always saves a
    # checkpoint, even if validation accuracy never rises above zero.
    best_val_acc = float('-inf')
    step = 0

    for epoch in range(config['num_epochs']):
        model.train()
        train_loss_epoch = []
        train_acc_epoch = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
        for batch_inputs, batch_targets in pbar:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            logits, _, _ = model(batch_inputs)
            loss, acc = compute_metrics(logits, batch_targets, vocab)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
            optimizer.step()
            scheduler.step()

            train_loss_epoch.append(loss.item())
            train_acc_epoch.append(acc.item())

            pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})
            step += 1

        # Validation
        model.eval()
        val_loss_list = []
        val_acc_list = []

        with torch.no_grad():
            for batch_inputs, batch_targets in val_loader:
                batch_inputs = batch_inputs.to(device)
                batch_targets = batch_targets.to(device)

                logits, _, _ = model(batch_inputs)
                loss, acc = compute_metrics(logits, batch_targets, vocab)

                val_loss_list.append(loss.item())
                val_acc_list.append(acc.item())

        # Epoch-level values are unweighted means over the batches.
        train_loss_avg = np.mean(train_loss_epoch)
        train_acc_avg = np.mean(train_acc_epoch)
        val_loss_avg = np.mean(val_loss_list)
        val_acc_avg = np.mean(val_acc_list)

        history['train_loss'].append(train_loss_avg)
        history['train_acc'].append(train_acc_avg)
        history['val_loss'].append(val_loss_avg)
        history['val_acc'].append(val_acc_avg)
        history['steps'].append(step)

        # Track parameter norm if requested
        if track_param_norm:
            param_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters() if p.requires_grad) ** 0.5
            history['param_norms'].append(param_norm)

        print(f"Epoch {epoch+1}: Train Loss={train_loss_avg:.4f}, Train Acc={train_acc_avg:.4f}, "
              f"Val Loss={val_loss_avg:.4f}, Val Acc={val_acc_avg:.4f}")

        # Save best model
        if val_acc_avg > best_val_acc:
            best_val_acc = val_acc_avg
            torch.save(model.state_dict(), checkpoint_path)

    return history

## Section 4A: Plots and Summary Utilities (DO NOT MODIFY)

In [None]:
def plot_training_curves(history, title="Training Curves", save_name=None):
    """
    Draw train/validation loss and accuracy side by side.

    Both panels share history['steps'] as the x axis. When ``save_name`` is
    truthy the figure is also written to "<save_name>.png" before it is
    shown.
    """
    train_color = '#2E86AB'
    val_color = '#A23B72'
    x_values = history['steps']

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # (axis, train key, val key, train label, val label, y label, panel title)
    panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss', 'Loss Curves'),
        (acc_ax, 'train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Accuracy Curves'),
    )
    for axis, train_key, val_key, train_label, val_label, y_label, panel_title in panels:
        axis.plot(x_values, history[train_key], label=train_label, alpha=0.8, linewidth=2, color=train_color)
        axis.plot(x_values, history[val_key], label=val_label, alpha=0.8, linewidth=2, color=val_color)
        axis.set_xlabel('Training Steps', fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(panel_title, fontsize=14, fontweight='bold')
        axis.legend(fontsize=11)
        axis.grid(True, alpha=0.3)

    # Accuracy lives in [0, 1]; leave a little headroom above 1.
    acc_ax.set_ylim([0, 1.05])

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()

    if save_name:
        plt.savefig(f"{save_name}.png", dpi=150, bbox_inches='tight')
    plt.show()


def compute_summary_metrics(history):
    """
    Summarize a training history.

    For each of train/val loss (lower is better) and train/val accuracy
    (higher is better) this reports the best value and the entry of
    history['steps'] at which it first occurred, plus the "lag" between
    the validation and training best steps for loss and for accuracy.
    """
    recorded_steps = history['steps']

    # (history key, how to pick the best value, how to locate its index)
    curves = (
        ('train_loss', min, np.argmin),
        ('val_loss', min, np.argmin),
        ('train_acc', max, np.argmax),
        ('val_acc', max, np.argmax),
    )

    summary = {}
    for curve_name, pick_best, _ in curves:
        summary[f'best_{curve_name}'] = pick_best(history[curve_name])
    for curve_name, _, locate_best in curves:
        summary[f'step_best_{curve_name}'] = recorded_steps[locate_best(history[curve_name])]

    # Positive lag: validation reached its best later than training did.
    summary['lag_loss'] = summary['step_best_val_loss'] - summary['step_best_train_loss']
    summary['lag_acc'] = summary['step_best_val_acc'] - summary['step_best_train_acc']

    return summary

def print_summary_table(metrics):
    """
    Print the summary metrics as a fixed-width two-column table.

    Expects the keys produced by compute_summary_metrics: best values are
    shown with four decimals, steps as-is, and lags with a " steps" suffix.
    """
    heavy_rule = "=" * 60

    value_rows = (
        ('Best Train Loss', 'best_train_loss'),
        ('Best Val Loss', 'best_val_loss'),
        ('Best Train Accuracy', 'best_train_acc'),
        ('Best Val Accuracy', 'best_val_acc'),
    )
    step_rows = (
        ('Step (Best Train Loss)', 'step_best_train_loss'),
        ('Step (Best Val Loss)', 'step_best_val_loss'),
        ('Step (Best Train Acc)', 'step_best_train_acc'),
        ('Step (Best Val Acc)', 'step_best_val_acc'),
    )
    lag_rows = (
        ('Generalization Lag (Loss)', 'lag_loss'),
        ('Generalization Lag (Acc)', 'lag_acc'),
    )

    # Header
    print("\n" + heavy_rule)
    print("SUMMARY METRICS")
    print(heavy_rule)
    print(f"{'Metric':<35} {'Value':<20}")
    print("-" * 60)

    # Body
    for row_label, metric_key in value_rows:
        print(f"{row_label:<35} {metrics[metric_key]:.4f}")
    for row_label, metric_key in step_rows:
        print(f"{row_label:<35} {metrics[metric_key]}")
    for row_label, metric_key in lag_rows:
        print(f"{row_label:<35} {metrics[metric_key]} steps")

    print(heavy_rule)

## Section 4B: Visualization Utilities (DO NOT MODIFY)

In [None]:
def visualize_attention_patterns(model, val_inputs, vocab, device):
    """
    Visualize attention patterns for all 5 standardized examples.

    For every standardized example the model is run on the sequence minus
    its last token, and one heatmap per attention head of the final layer is
    drawn. Each figure is saved as 'attention_<example_key>.png' and shown.

    Args:
        model: Model whose call returns (logits, hidden_states, attn_weights);
            attn_weights[-1, 0] is taken as the final layer's attention for
            the single batch element, with shape (H, S, S)
        val_inputs: Validation input sequences; only used to report whether
            an example is present in the validation set
        vocab: Vocabulary dictionary; vocab['idx_to_token'] supplies the
            axis labels
        device: Device on which the model is run
    """

    STANDARD_EXAMPLES = {
        'binary_add_small': [13, 2, 11, 3, 16, 5, 14, 15],
        'binary_add_carry': [13, 7, 11, 8, 16, 4, 14, 15],
        'binary_mult_small': [13, 2, 12, 3, 16, 6, 14, 15],
        'ternary_add': [13, 1, 11, 2, 11, 3, 16, 6],
        'ternary_mult': [13, 2, 12, 3, 12, 4, 16, 2],
    }

    EXAMPLE_TITLES = {
        'binary_add_small': '1. Binary Addition (Small): [BOS] 2 + 3 [=] 5 [EOS] [PAD]',
        'binary_add_carry': '2. Binary Addition (Carry): [BOS] 7 + 8 [=] 4 [EOS] [PAD]',
        'binary_mult_small': '3. Binary Multiplication: [BOS] 2 * 3 [=] 6 [EOS] [PAD]',
        'ternary_add': '4. Ternary Addition: [BOS] 1 + 2 + 3 [=] 6 [EOS]',
        'ternary_mult': '5. Ternary Multiplication: [BOS] 2 * 3 * 4 [=] 2 [EOS]'
    }

    # Heatmaps are laid out four to a row; the number of rows follows the
    # model's head count (2 rows for 8 heads).
    HEADS_PER_ROW = 4

    model.eval()

    for example_key, target_seq in STANDARD_EXAMPLES.items():
        # Input is all but last token (for autoregressive prediction)
        input_seq = target_seq[:-1]
        input_tensor = torch.LongTensor(input_seq)

        # Find the example in validation set; fall back to the constructed
        # tensor (and say so) when it is not there.
        for candidate in val_inputs:
            if len(candidate) == len(input_tensor) and torch.all(candidate == input_tensor):
                example_input = candidate
                break
        else:
            print(f"Note: Using constructed example for {example_key}")
            example_input = input_tensor

        # Get tokens for labeling (use full sequence for display)
        full_seq_tokens = [vocab['idx_to_token'][idx] for idx in target_seq]

        # Get attention weights using input sequence
        with torch.no_grad():
            input_ids = example_input.unsqueeze(0).to(device)
            _, _, attn_weights = model(input_ids)

            # Get last layer attention: (H, S, S) where S is input length
            attn = attn_weights[-1, 0].cpu().numpy()

        num_heads = attn.shape[0]
        seq_len = attn.shape[1]

        # Use tokens corresponding to input length
        tokens = full_seq_tokens[:seq_len]

        # Create visualization: enough rows to hold every head. squeeze=False
        # keeps `axes` two-dimensional even when there is a single row.
        num_rows = max(1, -(-num_heads // HEADS_PER_ROW))
        fig, axes = plt.subplots(num_rows, HEADS_PER_ROW,
                                 figsize=(20, 5 * num_rows), squeeze=False)
        axes = axes.flatten()

        for h in range(num_heads):
            ax = axes[h]
            im = ax.imshow(attn[h], cmap='viridis', aspect='auto', vmin=0, vmax=1)

            # Set ticks and labels
            ax.set_xticks(range(seq_len))
            ax.set_yticks(range(seq_len))
            ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
            ax.set_yticklabels(tokens, fontsize=9)

            ax.set_title(f'Head {h}', fontsize=12, fontweight='bold')
            ax.set_xlabel('Key Position', fontsize=10)
            ax.set_ylabel('Query Position', fontsize=10)

            # Add colorbar
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Hide grid cells that have no head to show.
        for spare_ax in axes[num_heads:]:
            spare_ax.set_visible(False)

        full_title = EXAMPLE_TITLES[example_key]
        plt.suptitle(f'Attention Patterns - Final Layer\n{full_title}',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()

        filename = f'attention_{example_key}.png'
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        print(f"Saved: {filename}")
        plt.show()
        print()  # Add spacing between examples

## Section 5: Experiment 1 - Sanity Check and Baseline [6 points]

In [None]:
# Build the dataset (p=11) together with its vocabulary
train_inputs, train_targets, val_inputs, val_targets, vocab = create_modular_arithmetic_dataset(p=11)

# Report the size of each split and of the vocabulary
num_train_samples = len(train_inputs)
num_val_samples = len(val_inputs)
print("Dataset Statistics:")
print(f"  Training samples: {num_train_samples}")
print(f"  Validation samples: {num_val_samples}")
print(f"  Vocabulary size: {len(vocab['token_to_idx'])}")
print(f"  Sequence length: {train_inputs.shape[1]}")
print(f"  Total equations: {num_train_samples + num_val_samples}")

# Baseline hyperparameters: architecture first, then optimization settings
config = dict(
    vocab_size=len(vocab['token_to_idx']),
    d=128,
    num_layers=4,
    num_heads=8,
    num_kv_groups=4,
    d_ff=512,
    dropout_prob=0.1,
    learning_rate=3e-4,
    weight_decay=1e-4,
    num_epochs=100,
    grad_clip=1.0,
    batch_size=64,
)

print("\nModel Configuration:")
for setting_name in config:
    print(f"  {setting_name}: {config[setting_name]}")

# Wrap the tensors in dataloaders. Training batches are shuffled and the
# last incomplete batch is dropped; validation keeps every sample in order.
batch_size = config['batch_size']
train_dataset = TensorDataset(train_inputs, train_targets)
val_dataset = TensorDataset(val_inputs, val_targets)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

def init_weights(module):
    """
    Re-initialize Linear and Embedding layers with small weights.

    Weights are drawn from a normal distribution with mean 0 and std 0.02;
    Linear biases (when present) are set to zero. Any other module type is
    left untouched. Intended for use with ``model.apply``.
    """
    is_linear = isinstance(module, nn.Linear)
    if not is_linear and not isinstance(module, nn.Embedding):
        return

    torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # Only Linear layers carry a bias here, and it may be disabled.
    if is_linear and module.bias is not None:
        torch.nn.init.zeros_(module.bias)

# Build the baseline model: the architecture arguments are taken from config
# under the same names the constructor uses.
architecture_keys = (
    'vocab_size', 'd', 'num_layers', 'num_heads',
    'num_kv_groups', 'd_ff', 'dropout_prob',
)
model = ModernDecoderLM(**{name: config[name] for name in architecture_keys})
model.apply(init_weights)

# Count only the parameters that will receive gradients
trainable_params = [p for p in model.parameters() if p.requires_grad]
num_params = sum(p.numel() for p in trainable_params)
print(f"\nTotal trainable parameters: {num_params:,}")

# Train model
print("\nTraining baseline model...")
history_baseline = train_model(
    model, train_loader, val_loader, vocab, config, device
)

# SOLUTION: Part A - Generate plots
print("\n--- Part A: Training Curves ---")
plot_training_curves(
    history_baseline,
    "Experiment 1: Baseline Training",
    "exp1_baseline_curves",
)

# SOLUTION: Part B - Summary metrics and table
print("\n--- Part B: Summary Metrics ---")
metrics_baseline = compute_summary_metrics(history_baseline)
print_summary_table(metrics_baseline)

In [None]:
# print("\n" + "="*80)
# print("TODO: Write your interpretation in the PDF report [3 points]:")
# print("Address the following questions:")
# print("1. Does the model successfully learn the modular arithmetic task?")
# print("2. Is there evidence of overfitting or underfitting?")
# print("3. How quickly does validation performance follow training performance?")
# print("4. What do the generalization lags suggest about task difficulty?")
# print("5. Are there any unexpected behaviors in the learning curves?")
# print("="*80 + "\n")

## Section 6: Experiment 2 - Regularization and Optimization [8 points]

### Part A: Dropout Sweep [4 points]

In [None]:
import os
import shutil

dropout_values = [0.0, 0.2, 0.4]
dropout_histories = {}

# train_model() writes its best weights to 'best_model.pt' on every call, so
# each run below replaces whatever checkpoint is already there (the baseline
# model's, when the cells are run in order). Keep a copy and put it back
# after the sweep so this cell leaves the checkpoint file as it found it.
shared_ckpt = 'best_model.pt'
ckpt_backup = 'best_model_before_dropout_sweep.pt'
had_ckpt = os.path.exists(shared_ckpt)
if had_ckpt:
    shutil.copyfile(shared_ckpt, ckpt_backup)

try:
    for dropout_prob in dropout_values:
        print(f"\nTraining with dropout={dropout_prob}")

        # Same settings as the baseline except for the dropout probability
        config_dropout = config.copy()
        config_dropout['dropout_prob'] = dropout_prob
        config_dropout['num_epochs'] = 100

        model_dropout = ModernDecoderLM(
            vocab_size=config_dropout['vocab_size'],
            d=config_dropout['d'],
            num_layers=config_dropout['num_layers'],
            num_heads=config_dropout['num_heads'],
            num_kv_groups=config_dropout['num_kv_groups'],
            d_ff=config_dropout['d_ff'],
            dropout_prob=config_dropout['dropout_prob']
        )
        model_dropout.apply(init_weights)

        history_dropout = train_model(model_dropout, train_loader, val_loader,
                                       vocab, config_dropout, device)
        dropout_histories[dropout_prob] = history_dropout
finally:
    # Restore even if a run is interrupted part-way through the sweep.
    if had_ckpt:
        shutil.copyfile(ckpt_backup, shared_ckpt)

# Plot comparison of dropout values
print("\nPlotting dropout comparison...")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

colors = {'0.0': '#E63946', '0.2': '#2A9D8F', '0.4': '#F77F00'}

# One line per dropout value in each of the four panels
for dropout_prob, history in dropout_histories.items():
    color = colors[str(dropout_prob)]
    label = f'Dropout={dropout_prob}'

    axes[0, 0].plot(history['steps'], history['train_loss'],
                    label=label, alpha=0.8, linewidth=2, color=color)
    axes[0, 1].plot(history['steps'], history['val_loss'],
                    label=label, alpha=0.8, linewidth=2, color=color)
    axes[1, 0].plot(history['steps'], history['train_acc'],
                    label=label, alpha=0.8, linewidth=2, color=color)
    axes[1, 1].plot(history['steps'], history['val_acc'],
                    label=label, alpha=0.8, linewidth=2, color=color)

# Titles are listed in the row-major order of axes.flat
titles = ['Training Loss', 'Validation Loss', 'Training Accuracy', 'Validation Accuracy']
for ax, title in zip(axes.flat, titles):
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Steps', fontsize=10)
    ax.set_ylabel(title.split()[-1], fontsize=10)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    if 'Accuracy' in title:
        ax.set_ylim([0, 1.05])

plt.suptitle('Experiment 2A: Dropout Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('exp2a_dropout_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# SOLUTION: Summary table
print("\n" + "="*60)
print("DROPOUT SWEEP SUMMARY")
print("="*60)
print(f"{'Dropout':<12} {'Best Val Loss':<16} {'Best Val Acc':<16} {'Lag (Acc)':<15}")
print("-" * 60)
for dropout_prob, history in dropout_histories.items():
    metrics_drop = compute_summary_metrics(history)
    print(f"{dropout_prob:<12} {metrics_drop['best_val_loss']:<16.4f} "
          f"{metrics_drop['best_val_acc']:<16.4f} {metrics_drop['lag_acc']:<15}")
print("="*60)

In [None]:
# print("\n" + "="*80)
# print("TODO: Write your analysis in the PDF report [2 points]:")
# print("Address the following questions:")
# print("1. Which dropout value yields the best validation performance? Why?")
# print("2. Explain the trade-off between underfitting (too much dropout)")
# print("   and overfitting (too little dropout).")
# print("3. How does dropout affect the generalization lag?")
# print("4. Is there a clear optimal dropout rate for this task?")
# print("="*80 + "\n")

### Part B: Weight Decay Analysis [4 points]

In [None]:
import os
import shutil

weight_decay_values = [2.5e-4, 5e-4, 1e-3]
weight_decay_histories = {}

# train_model() writes its best weights to 'best_model.pt' on every call, so
# each run below replaces whatever checkpoint is already there. Keep a copy
# and put it back after the sweep so this cell leaves the checkpoint file as
# it found it.
shared_ckpt = 'best_model.pt'
ckpt_backup = 'best_model_before_weight_decay_sweep.pt'
had_ckpt = os.path.exists(shared_ckpt)
if had_ckpt:
    shutil.copyfile(shared_ckpt, ckpt_backup)

try:
    for wd in weight_decay_values:
        print(f"\nTraining with weight_decay={wd}")

        # Same settings as the baseline except for the weight decay; the
        # parameter norm is recorded after every epoch for the plot below.
        config_wd = config.copy()
        config_wd['weight_decay'] = wd
        config_wd['num_epochs'] = 100
        config_wd['track_param_norm'] = True

        model_wd = ModernDecoderLM(
            vocab_size=config_wd['vocab_size'],
            d=config_wd['d'],
            num_layers=config_wd['num_layers'],
            num_heads=config_wd['num_heads'],
            num_kv_groups=config_wd['num_kv_groups'],
            d_ff=config_wd['d_ff'],
            dropout_prob=config_wd['dropout_prob']
        )
        model_wd.apply(init_weights)

        history_wd = train_model(model_wd, train_loader, val_loader, vocab, config_wd, device)
        weight_decay_histories[wd] = history_wd
finally:
    # Restore even if a run is interrupted part-way through the sweep.
    if had_ckpt:
        shutil.copyfile(ckpt_backup, shared_ckpt)

# Plot parameter norms
print("\nPlotting parameter norms...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

colors_wd = {2.5e-4: '#264653', 5e-4: '#2A9D8F', 1e-3: '#E76F51'}

# Parameter norms
for wd, history in weight_decay_histories.items():
    if 'param_norms' in history:
        color = colors_wd[wd]
        ax1.plot(history['steps'], history['param_norms'],
                 label=f'WD={wd}', alpha=0.8, linewidth=2, color=color)

ax1.set_xlabel('Training Steps', fontsize=12)
ax1.set_ylabel('L2 Parameter Norm', fontsize=12)
ax1.set_title('Parameter Norm vs Training Steps', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Validation accuracy
for wd, history in weight_decay_histories.items():
    color = colors_wd[wd]
    ax2.plot(history['steps'], history['val_acc'],
             label=f'WD={wd}', alpha=0.8, linewidth=2, color=color)

ax2.set_xlabel('Training Steps', fontsize=12)
ax2.set_ylabel('Validation Accuracy', fontsize=12)
ax2.set_title('Validation Accuracy vs Training Steps', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1.05])

plt.suptitle('Experiment 2B: Weight Decay Analysis', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('exp2b_weight_decay_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

# SOLUTION: Summary table
print("\n" + "="*60)
print("WEIGHT DECAY SUMMARY")
print("="*60)
print(f"{'Weight Decay':<18} {'Best Val Acc':<18} {'Final Param Norm':<20}")
print("-" * 60)
for wd, history in weight_decay_histories.items():
    best_val_acc = max(history['val_acc'])
    final_norm = history['param_norms'][-1] if 'param_norms' in history else 0.0
    print(f"{wd:<18} {best_val_acc:<18.4f} {final_norm:<20.2f}")
print("="*60)

In [None]:
# print("\n" + "="*80)
# print("TODO: Write your discussion in the PDF report [2 points]:")
# print("Address the following questions:")
# print("1. How does weight decay affect parameter norms over training?")
# print("2. Is there a correlation between norm magnitude and validation performance?")
# print("3. What does this suggest about weight decay's role in preventing overfitting?")
# print("4. Which weight decay value provides the best balance?")
# print("="*80 + "\n")

## Section 7: Experiment 3 - Attention Pattern Visualization [6 points]

In [None]:
# Load best baseline model
# NOTE(review): every train_model() call writes to 'best_model.pt', so this
# file holds the baseline's weights only if no later training run has
# replaced it. Confirm that before interpreting the plots.
# map_location lets a checkpoint saved on one device (e.g. a GPU) be loaded
# when the current runtime uses a different one.
best_state = torch.load('best_model.pt', map_location=device)
model.load_state_dict(best_state)
model = model.to(device)
model.eval()
print("Loaded best baseline model checkpoint")

print("\nGenerating attention visualizations for all 5 standardized examples...")
visualize_attention_patterns(model, val_inputs, vocab, device)

In [None]:
# print("\n" + "="*80)
# print("TODO: Write your interpretability analysis in the PDF report [3 points]:")
# print("For each of the 5 examples, analyze the attention patterns:")
# print("1. Do tokens after [=] attend back to the operands and operators?")
# print("2. Do different heads exhibit different specializations?")
# print("   (e.g., local vs. global attention, operator-focused vs. operand-focused)")
# print("3. Can you identify any heads that seem to implement specific")
# print("   algorithmic components (e.g., carrying information forward)?")
# print("="*80 + "\n")

## Section 8: Save Results

In [None]:
# Gather the outputs of every experiment into one dictionary, one entry per
# experiment, and pickle it to disk.
results = dict(
    experiment_1=dict(
        history=history_baseline,
        metrics=metrics_baseline,
        config=config,
    ),
    experiment_2a=dict(
        dropout_histories=dropout_histories,
        dropout_values=dropout_values,
    ),
    experiment_2b=dict(
        weight_decay_histories=weight_decay_histories,
        weight_decay_values=weight_decay_values,
    ),
)

with open('HW2_A25_modern_transformer_results.pkl', 'wb') as f:
    pickle.dump(results, f)

print("Results saved to 'HW2_A25_modern_transformer_results.pkl'")