# Understanding Mamba: State Space Models from First Principles

Transformers revolutionized sequence modeling, but they have a fundamental limitation: **O(n²) attention complexity**. For a 100k token context, that's 10 billion operations per layer.

Mamba offers an alternative: **O(n) linear complexity** while maintaining expressiveness. But how does it work? And why did it take until 2023 to figure this out?

This notebook builds intuition from the ground up: starting from physics and control theory, moving through the evolution of state space models, and ending with exactly why Mamba works.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange, repeat, einsum
from tqdm.auto import tqdm
import math

# Add parent directory to path for our utilities
import sys
sys.path.append('../..')
from silen_lib.utils import utils
utils.set_seed(42)

# Check if we have a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## The Transformer Bottleneck

Before we understand Mamba, let's feel the pain it solves.

In a Transformer, every token attends to every other token. If we have $n$ tokens, that's $n \times n$ attention computations per layer.


In [None]:
# With 1,000 tokens
n = 1000
n * n


In [None]:
# With 100,000 tokens (modern context windows)
n = 100_000
f"{n * n:,} operations"


In [None]:
# Let's visualize how these scale

seq_lengths = torch.arange(1, 10001, 100)

# Transformer: O(n²) attention
transformer_ops = seq_lengths ** 2

# Mamba: O(n) linear
mamba_ops = seq_lengths * 16  # state dimension typically 16

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(seq_lengths, transformer_ops / 1e6, label='Transformer O(n²)', linewidth=2)
ax.plot(seq_lengths, mamba_ops / 1e6, label='Mamba O(n)', linewidth=2)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Operations (millions)')
ax.set_title('The Quadratic Wall')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()


## Where Did This Thinking Come From?

State Space Models weren't invented for deep learning. They come from **control theory** and **signal processing**, where they were developed in the 1960s.

**The lineage:**
- **1960s**: Rudolf Kalman develops state space representations for control systems (Kalman filters)
- **1970s-2000s**: SSMs become fundamental to signal processing, robotics, and engineering
- **2020**: HiPPO paper shows SSMs can have long-range memory in deep learning
- **2021**: S4 (Structured State Spaces) achieves breakthrough on Long Range Arena
- **2023**: Mamba adds selectivity, matching Transformer quality with linear complexity

The key insight: **engineers have been solving the "process a sequence efficiently" problem for 60 years**. Deep learning just needed to adapt these ideas.


## What is "State"? Building from Physics

Before we dive into equations, let's understand what "state" means intuitively.

**A physics example:** Imagine a ball flying through the air. If I tell you only its position right now, can you predict where it will be in 1 second?


In [None]:
# Ball at position x=5 meters
position = 5.0
position


No! It could be moving left, right, up, down, or sitting still. Position alone doesn't determine the future.

But if I tell you **both position and velocity**...


In [None]:
# Position AND velocity together = the STATE
position = 5.0
velocity = 2.0  # moving right at 2 m/s

state = torch.tensor([position, velocity])
state


In [None]:
# Now I CAN predict the future! After 1 second:
dt = 1.0
new_position = position + velocity * dt
new_position


**This is the key insight:** State is the minimal information needed to predict the future, given the dynamics.

For sequences (like text), we can think of state as a **compressed summary of everything we've seen so far**. Instead of storing all past tokens (like Transformers do with their KV cache), we maintain a fixed-size state that evolves as we see new tokens.


### State as Compressed Memory

Let's make this concrete with a simple example. Imagine you're reading a story and trying to predict the next word.

**Option 1: Store everything** (Transformer approach)
- "The cat sat on the mat and then the cat jumped onto the..."
- Keep all 13 words in memory, compute attention over all of them

**Option 2: Compress into state** (SSM approach)  
- Maintain a hidden state that captures: "we're talking about a cat doing actions"
- Update this state with each new token
- Fixed memory regardless of sequence length
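
The memory difference between the two options can be put in rough numbers. The sketch below is only a back-of-the-envelope count under stated assumptions: one layer, one key vector and one value vector of width `d_model` cached per token, and an SSM state of `d_inner * d_state` values. The function names and sizes are illustrative and not taken from any library.

```python
def kv_cache_values(seq_len, d_model):
    """Values a Transformer layer caches: one key and one value vector per token."""
    return 2 * seq_len * d_model


def ssm_state_values(d_inner, d_state):
    """Values an SSM layer keeps: a fixed-size state, independent of sequence length."""
    return d_inner * d_state


d_model, d_inner, d_state = 64, 128, 16  # illustrative sizes only
for seq_len in (10, 1_000, 100_000):
    kv = kv_cache_values(seq_len, d_model)
    ssm = ssm_state_values(d_inner, d_state)
    print(f"{seq_len:>7} tokens: KV cache = {kv:>10,} values, SSM state = {ssm:,} values")
```

The cache grows linearly with the sequence, while the state stays the same size no matter how many tokens have been read.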


In [None]:
# A very simple state that "remembers" by exponential averaging
# Think of it as a leaky bucket of information

state = 0.0
decay = 0.9  # how much of the old state we keep

# Incoming "tokens" (just numbers for now)
tokens = [1.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0]

print("Token → New State")
for t in tokens:
    state = decay * state + (1 - decay) * t
    print(f"  {t:.1f}  →  {state:.3f}")


Notice how:
- The state gradually forgets the first input, 1.0 (0.100 → 0.090 → 0.081 ...)
- When the input 5.0 arrives, it gets incorporated (the state jumps to about 0.566)
- The state is a **weighted average of the history**, with recent tokens weighted more

This is the simplest possible "state space model". The real ones are more sophisticated, but the intuition is the same: **compress history into a fixed-size representation that evolves over time**.
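
To check the "weighted average of the history" claim, the recurrence can be unrolled into an explicit sum in which an input seen `k` steps ago carries weight `(1 - decay) * decay**k`. This is a minimal sketch in plain Python; the helper names `ema_recurrent` and `ema_unrolled` are made up for illustration.

```python
def ema_recurrent(tokens, decay):
    """Update the state one token at a time, as in the loop above."""
    state = 0.0
    for t in tokens:
        state = decay * state + (1 - decay) * t
    return state


def ema_unrolled(tokens, decay):
    """Same result as an explicit weighted sum over the whole history."""
    n = len(tokens)
    # Token i was seen (n - 1 - i) steps ago, so it is damped that many times.
    return sum((1 - decay) * decay ** (n - 1 - i) * t for i, t in enumerate(tokens))


tokens = [1.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0]
print(ema_recurrent(tokens, 0.9))
print(ema_unrolled(tokens, 0.9))
```

Both calls give the same number (up to floating-point rounding), which is why the fixed-size state can stand in for the full history here.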


## Continuous-Time State Space Models

Now let's formalize this. In control theory, a continuous-time state space model is defined by:

$$\frac{dx}{dt} = Ax + Bu$$
$$y = Cx + Du$$

Where:
- $x$ is the **state** (hidden representation)
- $u$ is the **input** (incoming signal)
- $y$ is the **output** (what we predict)
- $A$ controls how the state evolves on its own
- $B$ controls how input affects the state
- $C$ controls how state maps to output
- $D$ is a skip connection (input directly to output)

Let's build intuition by simulating a simple system.


### Example: A Damped Oscillator

A mass on a spring with friction is a perfect SSM example. The state is [position, velocity], and the physics determines how it evolves.


In [None]:
# State: [position, velocity]
# dx/dt = Ax means: d[pos, vel]/dt = [[0, 1], [-k, -b]] @ [pos, vel]
# This encodes: d(pos)/dt = vel, d(vel)/dt = -k*pos - b*vel (spring + friction)

k = 1.0   # spring constant
b = 0.3   # damping/friction

A = torch.tensor([
    [0, 1],      # d(position)/dt = velocity
    [-k, -b]     # d(velocity)/dt = -k*position - b*velocity
])
A


In [None]:
# Simulate using Euler's method (simple numerical integration)
# x(t + dt) ≈ x(t) + dt * dx/dt = x(t) + dt * A @ x(t)

dt = 0.05
steps = 200
x = torch.tensor([1.0, 0.0])  # start at position=1, velocity=0

trajectory = [x.clone()]
for _ in range(steps):
    dx_dt = A @ x
    x = x + dt * dx_dt
    trajectory.append(x.clone())

trajectory = torch.stack(trajectory)
trajectory.shape  # (steps+1, 2) - position and velocity at each step


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

time = torch.arange(len(trajectory)) * dt

# Position over time
axes[0].plot(time, trajectory[:, 0])
axes[0].set_xlabel('Time')
axes[0].set_ylabel('Position')
axes[0].set_title('Damped Oscillator: Position')
axes[0].axhline(y=0, color='k', linestyle='--', alpha=0.3)

# Phase space (position vs velocity)
axes[1].plot(trajectory[:, 0], trajectory[:, 1])
axes[1].scatter([trajectory[0, 0]], [trajectory[0, 1]], color='green', s=100, label='Start', zorder=5)
axes[1].scatter([trajectory[-1, 0]], [trajectory[-1, 1]], color='red', s=100, label='End', zorder=5)
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Velocity')
axes[1].set_title('Phase Space: State Evolution')
axes[1].legend()

plt.tight_layout()


The state spirals inward (energy dissipates due to friction) and settles at equilibrium. 

**Key insight about A:** The matrix A determines the system's dynamics. Different A matrices create different behaviors:
- If all eigenvalues have negative real parts → system is stable (decays to equilibrium)
- If any eigenvalue has a positive real part → system is unstable (explodes)
- Complex eigenvalues → oscillation


In [None]:
# Check eigenvalues of our A matrix
eigenvalues = torch.linalg.eigvals(A)
eigenvalues  # Complex with negative real parts = stable oscillation
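
For this 2x2 matrix the eigenvalues can also be worked out by hand: the characteristic polynomial of `[[0, 1], [-k, -b]]` is λ² + bλ + k = 0, so the quadratic formula applies. The sketch below uses only the standard library; `oscillator_eigenvalues` is a hypothetical helper name.

```python
import cmath


def oscillator_eigenvalues(k, b):
    """Roots of lambda^2 + b*lambda + k = 0 for the spring-damper matrix."""
    disc = cmath.sqrt(b * b - 4 * k)  # imaginary when b^2 < 4k (underdamped)
    return (-b + disc) / 2, (-b - disc) / 2


lam1, lam2 = oscillator_eigenvalues(k=1.0, b=0.3)
print(lam1, lam2)  # real part -b/2 gives the decay, imaginary part the oscillation

# With much stronger friction the roots become real: decay without oscillation.
print(oscillator_eigenvalues(k=1.0, b=3.0))
```

The real part `-b/2` sets how fast the spiral shrinks, and the imaginary part sets how fast it rotates.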


### Adding Inputs: The B Matrix

So far our system has no inputâ€”it just evolves on its own. But for language modeling, we need to feed in tokens! That's what B does: it maps input to state changes.


In [None]:
# B maps input to state changes
# If we push the mass, it affects velocity (not position directly)
B = torch.tensor([[0.0], [1.0]])  # input affects velocity
B


In [None]:
# Now simulate with periodic "pushes" (like tokens coming in)
x = torch.tensor([0.0, 0.0])  # start at rest

# Input signal: periodic pushes
inputs = torch.zeros(steps)
inputs[20] = 5.0   # push at step 20
inputs[80] = -3.0  # push other direction at step 80
inputs[140] = 4.0  # another push

trajectory_with_input = [x.clone()]
for i in range(steps):
    u = inputs[i:i+1]  # current input (shape [1])
    dx_dt = A @ x + (B @ u).squeeze()
    x = x + dt * dx_dt
    trajectory_with_input.append(x.clone())

trajectory_with_input = torch.stack(trajectory_with_input)


In [None]:
fig, ax = plt.subplots(figsize=(12, 4))

ax.plot(time, trajectory_with_input[:, 0], label='Position')

# Mark input times
for i, inp in enumerate(inputs):
    if inp != 0:
        ax.axvline(x=i*dt, color='red', alpha=0.5, linestyle='--')
        ax.annotate(f'Push: {inp:.0f}', (i*dt, 0.8), fontsize=9)

ax.set_xlabel('Time')
ax.set_ylabel('Position')
ax.set_title('SSM with Inputs: Each "push" affects future states')
ax.legend()
plt.tight_layout()


**This is exactly how SSMs process sequences!**
- Each token is like a "push" (input through B)
- The state evolves according to A, remembering past inputs
- The effect of each input ripples through time

### The C Matrix: Reading the State

Finally, C maps state to output. We might only care about certain aspects of the state.


In [None]:
# C extracts output from state
# Maybe we only care about position, not velocity
C = torch.tensor([[1.0, 0.0]])  # output = position only
C


In [None]:
# Output for each timestep
outputs = (C @ trajectory_with_input.T).squeeze()
outputs.shape  # one output per timestep


## From Continuous to Discrete: Discretization

There's a problem: our equations are continuous (derivatives), but our data is discrete (tokens at specific positions).

We need to **discretize** the continuous SSM. The most common method is **Zero-Order Hold (ZOH)**: assume the input is constant between timesteps.


### The Discretization Formulas

Given step size $\Delta$ (not to be confused with "change"):

$$\bar{A} = e^{A\Delta}$$
$$\bar{B} = (A)^{-1}(e^{A\Delta} - I) B$$

Then the discrete recurrence becomes:
$$x_k = \bar{A} x_{k-1} + \bar{B} u_k$$
$$y_k = C x_k$$
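
For small $\Delta$, expanding the matrix exponential to first order shows how these formulas relate to the Euler update used in the oscillator simulation above. This is a sketch that assumes $A$ is invertible:

$$e^{A\Delta} = I + A\Delta + \tfrac{1}{2}A^2\Delta^2 + \dots$$
$$\bar{A} \approx I + \Delta A, \qquad \bar{B} = A^{-1}\left(A\Delta + \tfrac{1}{2}A^2\Delta^2 + \dots\right)B \approx \Delta B$$

So Euler's method is the first-order truncation of ZOH, and the two agree as $\Delta \to 0$.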

Let's see what this means intuitively.


In [None]:
# Simplest case: 1D state, 1D input
# dx/dt = a*x + b*u  (scalar version)

a = -0.5  # decay rate (negative = stable)
b = 1.0   # input sensitivity


In [None]:
# Discretize with step size delta
delta = 0.1

# A_bar = exp(a * delta)
A_bar = math.exp(a * delta)
A_bar


In [None]:
# B_bar = (1/a) * (exp(a*delta) - 1) * b
B_bar = (1/a) * (A_bar - 1) * b
B_bar


**What does $\bar{A} = 0.951$ mean?**

Each step, the state is multiplied by 0.951, so after one step we retain 95.1% of the previous state. After 10 steps: $\bar{A}^{10} = e^{-0.5} \approx 0.607$. After 100 steps: $\bar{A}^{100} = e^{-5} \approx 0.007$.

This is **exponential decay**, the hallmark of linear systems. The eigenvalues of A control how fast things are forgotten.


In [None]:
# Discrete SSM step is now trivial: x_k = A_bar * x_{k-1} + B_bar * u_k
x = 0.0
for u in [1.0, 0.0, 0.0, 0.0, 0.0]:
    x = A_bar * x + B_bar * u
    print(f"input={u:.1f} → state={x:.4f}")
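
One way to convince yourself that the ZOH formulas are exact for a constant input is to compare a single ZOH step against many tiny Euler steps over the same interval. The sketch below does this for the scalar system in plain Python; the helper names and the test values `x=1.0`, `u=2.0` are arbitrary choices for illustration.

```python
import math


def zoh_step(x, u, a, b, delta):
    """One exact step of dx/dt = a*x + b*u with u held constant for `delta`."""
    a_bar = math.exp(a * delta)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar * x + b_bar * u


def euler_fine(x, u, a, b, delta, substeps=10_000):
    """Brute-force reference: integrate the same interval with many tiny Euler steps."""
    h = delta / substeps
    for _ in range(substeps):
        x = x + h * (a * x + b * u)
    return x


a, b, delta = -0.5, 1.0, 0.1
exact = zoh_step(1.0, 2.0, a, b, delta)
reference = euler_fine(1.0, 2.0, a, b, delta)
coarse = 1.0 + delta * (a * 1.0 + b * 2.0)  # a single Euler step of size delta

print(f"ZOH step:          {exact:.6f}")
print(f"fine Euler:        {reference:.6f}")
print(f"single Euler step: {coarse:.6f}")
```

The ZOH step matches the fine-grained integration closely, while a single Euler step of the same size is visibly off.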


### Visualizing Discretization

Let's see how the continuous and discrete systems compare:


In [None]:
# Compare continuous and discrete at different step sizes
fig, ax = plt.subplots(figsize=(10, 5))

# "True" continuous solution: x(t) = x0 * exp(a*t) for zero input starting from x=1
t_continuous = torch.linspace(0, 5, 500)
x_continuous = torch.exp(a * t_continuous)
ax.plot(t_continuous, x_continuous, 'k-', label='Continuous', linewidth=2)

# Discrete with different step sizes
for delta in [0.1, 0.5, 1.0]:
    A_bar = math.exp(a * delta)
    t_discrete = torch.arange(0, 5 + delta, delta)
    x_discrete = torch.tensor([A_bar ** k for k in range(len(t_discrete))])
    ax.plot(t_discrete, x_discrete, 'o--', label=f'Discrete Δ={delta}', markersize=5)

ax.set_xlabel('Time')
ax.set_ylabel('State')
ax.set_title('Continuous vs Discrete: ZOH samples land on the continuous curve')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()


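**Key insight for Mamba:** In Mamba, $\Delta$ (the discretization step) is no longer a fixed constant. It is computed from the input, so it can vary per token!

- Small $\Delta$ → $\bar{A}$ stays close to the identity and $\bar{B}$ is small: the state persists and the current input is mostly ignored
- Large $\Delta$ → $\bar{A}$ shrinks toward zero (for stable $A$): the old state fades and the current input dominates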

This will become crucial later when we discuss **selectivity**.


### 🧪 Test Problem: Discretize a 2D SSM

Given the continuous A and B matrices from our oscillator, compute the discrete versions.


In [None]:
# Test problem: Discretize our 2D oscillator SSM
test_A = torch.tensor([[0.0, 1.0], [-1.0, -0.3]], dtype=torch.float32)
test_B = torch.tensor([[0.0], [1.0]], dtype=torch.float32)
test_delta = 0.1

# Fill in the code to compute A_bar and B_bar
# Hint: Use torch.linalg.matrix_exp for matrix exponential
# Hint: Use torch.linalg.solve or torch.linalg.inv for the inverse

# test_A_bar = ???  # fill in code here
# test_B_bar = ???  # fill in code here

# Uncomment to check your answer:
# assert test_A_bar.shape == (2, 2), "A_bar should be 2x2"
# assert test_B_bar.shape == (2, 1), "B_bar should be 2x1"
# assert torch.allclose(test_A_bar[0, 0], torch.tensor(0.9950), atol=1e-3), "Check your A_bar computation"
# print("✓ Discretization correct!")


## The Two Faces of SSMs: Recurrence vs Convolution

Here's a remarkable fact: the same SSM computation can be done two completely different ways:

1. **Recurrent mode**: Process step-by-step like an RNN
2. **Convolutional mode**: Precompute a kernel and convolve

They give **identical results**, but have different computational tradeoffs.


### Mode 1: Recurrent Computation

The obvious way is to process each token sequentially:


In [None]:
def ssm_recurrent(A_bar, B_bar, C, u_sequence):
    """
    Process sequence recurrently.
    
    Args:
        A_bar: (N, N) discrete state matrix
        B_bar: (N, 1) discrete input matrix
        C: (1, N) output matrix
        u_sequence: (L,) input sequence
    
    Returns:
        (L,) output sequence
    """
    N = A_bar.shape[0]
    L = len(u_sequence)
    
    x = torch.zeros(N)  # initial state
    outputs = []
    
    for t in range(L):
        x = A_bar @ x + B_bar.squeeze() * u_sequence[t]
        y = C @ x
        outputs.append(y.item())
    
    return torch.tensor(outputs)


In [None]:
# Create a simple 1D SSM for demonstration
N = 1  # state dimension
A_demo = torch.tensor([[0.9]])  # decay factor
B_demo = torch.tensor([[0.1]])  # input scaling
C_demo = torch.tensor([[1.0]])  # output = state

# Fake scalar inputs standing in for the embeddings of 10 tokens
u_demo = torch.randn(10)
u_demo


In [None]:
y_recurrent = ssm_recurrent(A_demo, B_demo, C_demo, u_demo)
y_recurrent


### Mode 2: Convolutional Computation

Now the magic: we can **precompute a convolution kernel** that does the same thing!

The kernel is: $\mathcal{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, ..., C\bar{A}^{L-1}\bar{B})$

This is basically "how much does input at time 0 affect output at time k?"
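
Unrolling the recurrence from a zero initial state shows where this kernel comes from:

$$x_0 = \bar{B}u_0, \quad x_1 = \bar{A}\bar{B}u_0 + \bar{B}u_1, \quad x_2 = \bar{A}^2\bar{B}u_0 + \bar{A}\bar{B}u_1 + \bar{B}u_2$$
$$y_k = C x_k = \sum_{j=0}^{k} C\bar{A}^{j}\bar{B}\,u_{k-j} = \sum_{j=0}^{k} \mathcal{K}_j\,u_{k-j}$$

The last expression is a causal convolution of the input with $\mathcal{K}$.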


In [None]:
def compute_ssm_kernel(A_bar, B_bar, C, L):
    """
    Compute the SSM convolution kernel.
    
    K[k] = C @ A_bar^k @ B_bar
    """
    kernel = []
    A_power = torch.eye(A_bar.shape[0])
    
    for k in range(L):
        K_k = C @ A_power @ B_bar
        kernel.append(K_k.item())
        A_power = A_power @ A_bar
    
    return torch.tensor(kernel)


In [None]:
kernel = compute_ssm_kernel(A_demo, B_demo, C_demo, len(u_demo))
kernel  # Notice: exponentially decaying!


In [None]:
def ssm_convolutional(kernel, u_sequence):
    """
    Process sequence via convolution.
    This is a causal convolution (only past affects present).
    """
    L = len(u_sequence)
    outputs = torch.zeros(L)
    
    for t in range(L):
        # Sum over past inputs weighted by kernel
        for k in range(t + 1):
            outputs[t] += kernel[k] * u_sequence[t - k]
    
    return outputs


In [None]:
y_conv = ssm_convolutional(kernel, u_demo)
y_conv


In [None]:
# They should be identical!
torch.allclose(y_recurrent, y_conv)


### Why Two Modes?

**Recurrent mode:**
- Complexity: O(L) sequential steps
- Training: Slow (can't parallelize across time)  
- Inference: Fast (just update state with each new token)

**Convolutional mode:**
- Complexity: O(L log L) using FFT
- Training: Fast (fully parallelizable!)
- Inference: Slow (need full sequence to convolve)

**The clever trick:** During training, use convolutional mode for parallelism. During inference, use recurrent mode for efficiency.
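
Here is a minimal sketch of how convolutional mode reaches O(L log L): zero-pad the kernel and the input, multiply their FFTs, and transform back. It uses a small hand-written radix-2 FFT so that it runs without extra dependencies; in practice one would more likely call a library FFT. All function names are illustrative.

```python
import cmath


def fft(values, inverse=False):
    """Recursive radix-2 FFT (unnormalized); len(values) must be a power of two."""
    n = len(values)
    if n == 1:
        return list(values)
    even = fft(values[0::2], inverse)
    odd = fft(values[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        twiddle = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


def causal_conv_fft(kernel, u):
    """Causal convolution via FFT; assumes len(kernel) == len(u)."""
    L = len(u)
    size = 1
    while size < 2 * L:  # pad so the circular convolution does not wrap around
        size *= 2
    K = fft(list(kernel) + [0.0] * (size - L))
    U = fft(list(u) + [0.0] * (size - L))
    y = fft([k * v for k, v in zip(K, U)], inverse=True)
    return [val.real / size for val in y[:L]]  # divide by size to normalize the inverse


def causal_conv_direct(kernel, u):
    """O(L^2) reference: y[t] = sum_k kernel[k] * u[t - k]."""
    return [sum(kernel[k] * u[t - k] for k in range(t + 1)) for t in range(len(u))]


a_bar, b_bar, c = 0.9, 0.1, 1.0
u = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 0.25]
kernel = [c * a_bar ** k * b_bar for k in range(len(u))]

y_fft = causal_conv_fft(kernel, u)
y_direct = causal_conv_direct(kernel, u)
print(max(abs(p - q) for p, q in zip(y_fft, y_direct)))  # tiny rounding error only
```

The padding to at least `2 * L` matters: without it the FFT computes a circular convolution, and late inputs would leak into early outputs.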


In [None]:
# Visualize the kernel - it's like a "memory window"
fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(range(len(kernel)), kernel)
ax.set_xlabel('Lag (how many steps ago)')
ax.set_ylabel('Weight')
ax.set_title('SSM Kernel: How much past inputs affect current output')
ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.tight_layout()


## Connection to RNNs: Linear Recurrence

Look at the recurrent update again:

$$x_k = \bar{A} x_{k-1} + \bar{B} u_k$$

Compare to an RNN:

$$h_k = \tanh(W_h h_{k-1} + W_x x_k)$$

The SSM is like a **linear RNN** (no nonlinearity in the recurrence itself). This seems limiting, but it's actually a feature!


### Why Linearity is Actually Good

**Problem with RNNs:** The tanh squashes gradients. After many steps, gradients either vanish (→ 0) or explode (→ ∞).

**SSM advantage:** Linear recurrence means we can analyze gradient flow exactly using the eigenvalues of $\bar{A}$. If all eigenvalues have magnitude < 1, gradients cannot explode!

Let's see this:


In [None]:
# For linear recurrence x_k = A @ x_{k-1}, after n steps:
# x_n = A^n @ x_0
# Gradient of x_n w.r.t. x_0 is just A^n

# If eigenvalues have magnitude < 1, A^n → 0 (bounded, no explosion)
# If eigenvalues have magnitude > 1, A^n → ∞ (explosion)

A_stable = torch.tensor([[0.9]])  # eigenvalue = 0.9 < 1
A_unstable = torch.tensor([[1.1]])  # eigenvalue = 1.1 > 1

n_steps = 50
powers_stable = [torch.linalg.matrix_power(A_stable, n).item() for n in range(n_steps)]
powers_unstable = [torch.linalg.matrix_power(A_unstable, n).item() for n in range(n_steps)]

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(powers_stable, label='Stable (λ=0.9)', linewidth=2)
ax.plot(powers_unstable, label='Unstable (λ=1.1)', linewidth=2)
ax.set_xlabel('Steps')
ax.set_ylabel('A^n')
ax.set_title('Gradient Scaling: Stable vs Unstable Systems')
ax.legend()
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.tight_layout()


### But Wait, Isn't Linear Too Simple?

Yes, a single linear SSM is limited. The solution:

1. Stack multiple SSM layers with **nonlinearities between them**
2. The nonlinearity is outside the recurrence (in the MLP/gating), not inside
3. The SSM handles long-range dependencies; MLPs handle local nonlinear mixing

This gives us stability of linear systems + expressiveness of deep networks.
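
The role of the nonlinearity between layers can be checked directly. Two stacked linear SSMs are still one linear map, so the response to `u + v` equals the sum of the separate responses. Putting a `tanh` between them breaks that property. The sketch below uses scalar SSM layers with arbitrary coefficients; all names are illustrative.

```python
import math


def ssm_layer(u, a_bar, b_bar, c=1.0):
    """Scalar linear SSM: x_k = a_bar * x_{k-1} + b_bar * u_k, y_k = c * x_k."""
    x, ys = 0.0, []
    for u_t in u:
        x = a_bar * x + b_bar * u_t
        ys.append(c * x)
    return ys


def stack_linear(u):
    """Two SSM layers back to back, nothing in between."""
    return ssm_layer(ssm_layer(u, 0.9, 0.1), 0.5, 0.5)


def stack_nonlinear(u):
    """Same two layers with a tanh applied between them (outside the recurrence)."""
    hidden = [math.tanh(v) for v in ssm_layer(u, 0.9, 0.1)]
    return ssm_layer(hidden, 0.5, 0.5)


def is_additive(f, u, v, tol=1e-9):
    """Check whether f(u + v) == f(u) + f(v) elementwise."""
    summed = f([p + q for p, q in zip(u, v)])
    separate = [p + q for p, q in zip(f(u), f(v))]
    return all(abs(s - t) < tol for s, t in zip(summed, separate))


u = [3.0, -2.0, 4.0, 1.0]
v = [5.0, 0.5, -3.0, 2.0]
print(is_additive(stack_linear, u, v))     # stacking alone stays linear
print(is_additive(stack_nonlinear, u, v))  # the tanh makes the stack nonlinear
```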


## The HiPPO Matrix: Learning to Remember

We've seen that the A matrix controls how the system remembers. But how should we initialize A?

Random initialization works poorly. The breakthrough came from the **HiPPO** (High-order Polynomial Projection Operators) framework.

**Key insight:** Instead of hoping the network learns good memory, we can mathematically derive an A matrix that **optimally compresses the history** into the state.


### The Intuition: Polynomial Approximation

Imagine you want to remember a signal f(t) using only N numbers (your state). What's the best way?

**Answer:** Project f(t) onto an orthogonal basis (like Legendre polynomials). Store the N coefficients.

HiPPO derives the A and B matrices that maintain these optimal projections as new inputs arrive!
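
To make "store the N coefficients" concrete, here is a minimal sketch of the static version of the idea: project a function on [-1, 1] onto the first N Legendre polynomials and rebuild it from the coefficients. This only illustrates the projection step under simplifying assumptions (a fixed interval, midpoint-rule integration); it is not the online HiPPO update, and the helper names are made up.

```python
def legendre(n, x):
    """Evaluate the Legendre polynomial P_n(x) with the three-term recurrence."""
    p_prev, p = 1.0, x
    if n == 0:
        return p_prev
    for m in range(1, n):
        # (m + 1) * P_{m+1} = (2m + 1) * x * P_m - m * P_{m-1}
        p_prev, p = p, ((2 * m + 1) * x * p - m * p_prev) / (m + 1)
    return p


def legendre_coefficients(f, N, samples=2000):
    """c_n = (2n + 1) / 2 * integral of f(x) * P_n(x) over [-1, 1] (midpoint rule)."""
    h = 2.0 / samples
    xs = [-1.0 + (i + 0.5) * h for i in range(samples)]
    return [(2 * n + 1) / 2 * sum(f(x) * legendre(n, x) for x in xs) * h for n in range(N)]


def reconstruct(coeffs, x):
    """Rebuild the function value from the stored coefficients."""
    return sum(c * legendre(n, x) for n, c in enumerate(coeffs))


coeffs = legendre_coefficients(lambda x: x * x, N=4)
print([round(c, 4) for c in coeffs])  # x^2 = (1/3) * P_0 + (2/3) * P_2
print(reconstruct(coeffs, 0.5))       # close to 0.25
```

Four numbers are enough to reproduce `x**2` anywhere on the interval, which is the sense in which the coefficients act as a compressed memory of the signal.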


In [None]:
def make_hippo_legs(N):
    """
    Create the HiPPO-LegS (Legendre) matrix.
    This matrix optimally compresses history into Legendre polynomial coefficients.
    """
    P = torch.zeros(N, N)
    for n in range(N):
        for k in range(N):
            if n > k:
                P[n, k] = (2*n + 1) ** 0.5 * (2*k + 1) ** 0.5
            elif n == k:
                P[n, k] = n + 1
    return -P


In [None]:
# Create an 8-dimensional HiPPO matrix
N = 8
A_hippo = make_hippo_legs(N)

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(A_hippo, cmap='RdBu', vmin=-10, vmax=10)
ax.set_title('HiPPO-LegS Matrix (N=8)')
ax.set_xlabel('Input dimension')
ax.set_ylabel('Output dimension')
plt.colorbar(im)
plt.tight_layout()


### Memory Comparison: Random vs HiPPO

Let's compare how well different A matrices remember a signal:


In [None]:
# Create two different SSMs: random vs HiPPO initialization
N = 16
dt = 0.01

# Random A (scaled down; note that this scaling alone does not guarantee stability)
A_random = torch.randn(N, N) * 0.5

# HiPPO A
A_hippo = make_hippo_legs(N)

# Simple B (all ones, scaled)
B = torch.ones(N, 1)

# Discretize both
A_bar_random = torch.linalg.matrix_exp(A_random * dt)
A_bar_hippo = torch.linalg.matrix_exp(A_hippo * dt)
# Simplified B discretization for demonstration
B_bar = B * dt


In [None]:
# Input: a "spike" at the beginning, then zeros
# A good memory should maintain information about this spike
L = 500
u = torch.zeros(L)
u[0] = 1.0  # impulse at t=0

# Run both SSMs
x_random = torch.zeros(N)
x_hippo = torch.zeros(N)
states_random = []
states_hippo = []

for t in range(L):
    x_random = A_bar_random @ x_random + B_bar.squeeze() * u[t]
    x_hippo = A_bar_hippo @ x_hippo + B_bar.squeeze() * u[t]
    states_random.append(x_random.norm().item())
    states_hippo.append(x_hippo.norm().item())


In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(states_random, label='Random A', alpha=0.8)
ax.plot(states_hippo, label='HiPPO A', alpha=0.8)
ax.set_xlabel('Time step')
ax.set_ylabel('State norm (memory of impulse)')
ax.set_title('Memory Decay: HiPPO maintains information longer')
ax.legend()
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.tight_layout()


## S4: Structured State Spaces

The HiPPO matrix looks great, but there's a problem: computing the convolution kernel naively is O(N²L), which is expensive.

**S4 (Structured State Space)** discovered that if we parameterize A in a special way (DPLR: Diagonal Plus Low-Rank), we can compute the kernel efficiently using the **Cauchy kernel** trick.

The key insight: work in the **frequency domain** using complex numbers!


### The Diagonal Trick

When A is diagonal, the SSM becomes much simpler. Each dimension evolves independently:

$$x_k^{(i)} = \lambda_i \cdot x_{k-1}^{(i)} + b_i \cdot u_k$$

where $\lambda_i$ is the i-th diagonal element. The convolution kernel is just a sum of geometric series!


In [None]:
# Diagonal SSM: each dimension is independent
# Lambda values (complex for oscillation + decay)
lambdas = torch.tensor([0.9 + 0.1j, 0.95 - 0.05j, 0.8, 0.99])
lambdas


In [None]:
# For diagonal A, kernel[k] = C * A^k * B = sum_i (c_i * lambda_i^k * b_i)
# This is just a weighted sum of exponentials!

L = 100
c = torch.ones(4, dtype=torch.cfloat)  # output weights
b = torch.ones(4, dtype=torch.cfloat)  # input weights

# Compute kernel: K[k] = sum_i c_i * lambda_i^k * b_i
k_indices = torch.arange(L)
kernel_diagonal = torch.zeros(L, dtype=torch.cfloat)
for i in range(4):
    kernel_diagonal += c[i] * (lambdas[i] ** k_indices) * b[i]

# Take real part for visualization
kernel_diagonal = kernel_diagonal.real


In [None]:
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(kernel_diagonal)
ax.set_xlabel('Lag')
ax.set_ylabel('Kernel value')
ax.set_title('Diagonal SSM Kernel: Sum of exponentials (with oscillation from complex eigenvalues)')
ax.grid(True, alpha=0.3)
plt.tight_layout()
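
As a sanity check on the "sum of exponentials" formula, the kernel can also be obtained by feeding an impulse through the per-dimension recurrence. The sketch below does both in plain Python with `b_i = c_i = 1`; the function names are illustrative.

```python
lambdas = [0.9 + 0.1j, 0.95 - 0.05j, 0.8 + 0j, 0.99 + 0j]


def kernel_closed_form(lambdas, L):
    """K[k] = sum_i lambda_i^k (real part), with b_i = c_i = 1."""
    return [sum(lam ** k for lam in lambdas).real for k in range(L)]


def kernel_by_recurrence(lambdas, L):
    """Impulse response: run each independent dimension on u = [1, 0, 0, ...]."""
    states = [0j] * len(lambdas)
    out = []
    for k in range(L):
        u = 1.0 if k == 0 else 0.0
        states = [lam * x + u for lam, x in zip(lambdas, states)]
        out.append(sum(states).real)
    return out


closed = kernel_closed_form(lambdas, 20)
impulse = kernel_by_recurrence(lambdas, 20)
print(max(abs(p - q) for p, q in zip(closed, impulse)))  # rounding error only
```

Because the dimensions never interact, each one contributes its own geometric sequence and the kernel is their sum.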


## The Selectivity Problem: Why S4 Wasn't Enough

S4 achieved impressive results on benchmarks like Long Range Arena. But it struggled with tasks that require **content-aware** processing.

The core issue: in S4, the matrices A, B, C are **fixed** for all inputs. The same transformation is applied regardless of whether the input is "the" or "revolutionary".


### A Task That Requires Selection: Selective Copying

Consider this task: given a sequence with "markers", copy only the marked elements.

```
Input:  [a, b, c, *, d, e, *, f, g]
        (where * means "copy previous")
Output: [_, _, _, c, _, _, e, _, _]
```

A transformer can do this easily: it just attends to the marked positions. But a fixed SSM treats every position the same!


In [None]:
# Illustrating the problem: Transformer attention vs fixed SSM

# Transformer can attend selectively
print("Transformer attention (content-dependent):")
print("Token 'copy_marker' → high attention to previous tokens")
print("Token 'regular_word' → normal attention pattern")
print()
print("Fixed SSM (same for all):")
print("Token 'copy_marker' → same A, B, C applied")
print("Token 'regular_word' → same A, B, C applied")
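
"Treats every position the same" has a precise meaning: a fixed SSM is time-invariant, so delaying the input simply delays the output. The response depends on how long ago something arrived, not on what it was. A minimal scalar sketch follows, with made-up coefficients and the hypothetical helper name `fixed_ssm`.

```python
def fixed_ssm(u, a_bar=0.9, b_bar=0.1, c=1.0):
    """Scalar SSM with the same a_bar, b_bar, c at every position."""
    x, ys = 0.0, []
    for u_t in u:
        x = a_bar * x + b_bar * u_t
        ys.append(c * x)
    return ys


u = [1.0, 0.0, 2.0, 0.0, 0.0, 0.0]
shifted = [0.0, 0.0] + u[:-2]  # the same content, arriving two steps later

y = fixed_ssm(u)
y_shifted = fixed_ssm(shifted)

print(y[:4])
print(y_shifted[2:])  # identical to the line above, just delayed
```

There is no mechanism here for one input to change how another is stored or read, which is what the selective copying task needs.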


## Mamba: The Key Insight

Mamba's breakthrough is simple but powerful: **make B, C, and Δ depend on the input!**

Instead of:
- $B$ = fixed matrix
- $C$ = fixed matrix  
- $\Delta$ = fixed step size

We have:
- $B(x)$ = function of current input
- $C(x)$ = function of current input
- $\Delta(x)$ = function of current input

This makes the SSM **selective**: it can choose what to remember based on context.


**Figure placeholder:** the Mamba architecture diagram from the paper (Figure 3), showing the selective SSM block.

The diagram shows how B, C, Δ are computed from the input x via linear projections.


### What Does Selectivity Enable?

Think about it intuitively:

- **Large Δ(x):** "This token is important: let the old state fade and integrate this token strongly"
- **Small Δ(x):** "This token is noise: keep the state almost unchanged and barely update it"

- **B(x) large:** "Write this token's information into state"
- **B(x) small:** "Don't store this token"

- **C(x) large:** "Read from state to produce output here"
- **C(x) small:** "Don't need state information for this output"
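
The effect of Δ is easiest to see in the smallest possible case. The sketch below assumes a one-dimensional state with A = -1 and B = 1 (values chosen for illustration), discretized with the ZOH formulas from earlier. The update then becomes a blend between the old state and the new input, with Δ controlling the mix. `selective_step` is a hypothetical helper name.

```python
import math


def selective_step(h, x, delta):
    """One ZOH step of dh/dt = -h + x, i.e. A = -1 and B = 1, with step size delta."""
    a_bar = math.exp(-delta)  # how much of the old state survives
    b_bar = 1.0 - a_bar       # ZOH: (1/A) * (exp(A * delta) - 1) * B with A = -1, B = 1
    return a_bar * h + b_bar * x


h = 1.0  # state carried over from earlier tokens
for delta in (0.01, 1.0, 10.0):
    print(f"delta={delta:>5}: new state = {selective_step(h, x=5.0, delta=delta):.4f}")
```

With a small Δ the state barely moves from 1.0, and with a large Δ it is almost replaced by the input 5.0. In this special case the step size acts like a gate.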


### The Trade-off: No More Convolution Mode

There's a catch: when B, C, Δ vary per position, we can't precompute a single kernel. The convolutional trick breaks!

**The solution: Parallel Scan**

The recurrence $x_k = \bar{A}_k x_{k-1} + \bar{B}_k u_k$ looks sequential, but it can actually be parallelized using a technique called **parallel prefix scan** (also used in GPUs for cumulative sums).

Key insight: the operation is **associative**, so we can compute it in O(log L) parallel steps instead of O(L) sequential steps.


In [None]:
# Simple example: parallel prefix sum
# Sequential: [1, 2, 3, 4] → [1, 1+2, 1+2+3, 1+2+3+4] = [1, 3, 6, 10]
# This takes O(n) sequential operations

# But we can do it in O(log n) parallel steps!
# Step 1 (add within pairs):         [1, 1+2, 3, 3+4] = [1, 3, 3, 7]
# Step 2 (add the first pair's sum): [1, 3, 3+3, 3+7] = [1, 3, 6, 10]

x = torch.tensor([1, 2, 3, 4])

# Sequential
cumsum_seq = x.cumsum(0)
print(f"Sequential cumsum: {cumsum_seq}")

# Note: PyTorch's cumsum is already optimized, but the principle applies to SSM recurrence
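
The same idea carries over to the SSM recurrence once each step is viewed as an affine map `x -> a_k * x + b_k`. Composing two such maps gives another affine map, and that composition is associative, so a scan applies. The sketch below is a simple log-depth scan for scalar states, written for clarity rather than efficiency; it is an illustration of the principle, not the fused GPU kernel used in practice, and all names are made up.

```python
def combine(earlier, later):
    """Compose two affine maps x -> a*x + b, applying `earlier` first."""
    a1, b1 = earlier
    a2, b2 = later
    return (a1 * a2, a2 * b1 + b2)


def scan_sequential(a, b):
    """Reference: x_k = a_k * x_{k-1} + b_k, starting from x = 0."""
    x, out = 0.0, []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        out.append(x)
    return out


def scan_parallel(a, b):
    """Inclusive scan in about log2(n) rounds; updates within a round are independent."""
    elems = list(zip(a, b))
    n, offset, rounds = len(elems), 1, 0
    while offset < n:
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(n)
        ]
        offset *= 2
        rounds += 1
    # With a zero initial state, the offset term of each prefix map is the state itself.
    return [b_k for _, b_k in elems], rounds


a = [0.9, 0.5, 0.8, 1.0, 0.7, 0.6, 0.95, 0.3]   # per-step decay (the role of A_bar_k)
b = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5, -1.0, 2.0]  # per-step input term (B_bar_k * u_k)

states, rounds = scan_parallel(a, b)
print(rounds)  # number of rounds needed for 8 steps
print(max(abs(p - q) for p, q in zip(states, scan_sequential(a, b))))
```

This variant does more total work than the sequential loop, but each round can run in parallel. Work-efficient scan variants exist that keep the log-depth while reducing the total work.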


## The Mamba Block Architecture

Now let's build the complete Mamba block step by step.

**Components:**
1. **Input projection**: Expand input dimension
2. **Conv1D**: Short local convolution for local context
3. **Selective SSM**: The core state space model with input-dependent parameters
4. **Gating**: SiLU activation + gating for nonlinearity
5. **Output projection**: Return to original dimension


In [None]:
# Let's define our dimensions
d_model = 64       # input/output dimension
d_inner = 128      # expanded inner dimension (2x d_model is typical)
d_state = 16       # SSM state dimension (N)
d_conv = 4         # local convolution width
dt_rank = 8        # rank for Δ projection


### Step 1: Input Projection

Like in transformers, we first project to a larger dimension:


In [None]:
# Fake input: batch of 2 sequences, length 10, dim 64
# Think of this as embeddings for "The cat sat on the mat..."
x = torch.randn(2, 10, d_model)
x.shape


In [None]:
# Project to expanded dimension (split for gating later)
in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
xz = in_proj(x)
xz.shape  # (batch, seq, 2 * d_inner)


In [None]:
# Split: one path goes through SSM, other is used for gating
x_ssm, z = xz.chunk(2, dim=-1)
x_ssm.shape, z.shape


### Step 2: Short Convolution

A 1D convolution captures local patterns before the SSM processes global dependencies:


In [None]:
# Causal 1D convolution (padding to maintain causality)
conv1d = nn.Conv1d(
    in_channels=d_inner,
    out_channels=d_inner,
    kernel_size=d_conv,
    padding=d_conv - 1,  # causal padding
    groups=d_inner  # depthwise = each channel separately
)

# Conv1d expects (batch, channels, length)
x_conv = conv1d(x_ssm.transpose(1, 2))[:, :, :x_ssm.shape[1]]  # trim extra padding
x_conv = x_conv.transpose(1, 2)  # back to (batch, length, channels)
x_conv.shape


In [None]:
# Apply SiLU (Swish) activation
x_conv = F.silu(x_conv)
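
The "pad, then trim" trick is what makes the convolution causal. A dependency-free sketch for a single channel shows the property directly: changing an input at position 3 leaves outputs 0 to 2 untouched. The function name and weights are illustrative.

```python
def causal_conv1d(x, weights):
    """Single-channel causal convolution: y[t] depends only on x[t-k+1 .. t]."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(x)  # left-pad so early outputs see zeros, not the future
    return [sum(w * padded[t + j] for j, w in enumerate(weights)) for t in range(len(x))]


weights = [0.1, 0.2, 0.3, 0.4]  # kernel width 4; the last weight multiplies the current input
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = causal_conv1d(x, weights)

x_changed = x[:3] + [100.0] + x[4:]  # alter only position 3
y_changed = causal_conv1d(x_changed, weights)

print(y[:3] == y_changed[:3])  # outputs before position 3 are unaffected
print(y[3] != y_changed[3])    # position 3 itself does change
```

Padding on both sides and keeping only the first `L` outputs, as the `nn.Conv1d` cell above does, gives the same alignment as left-padding alone.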


### Step 3: Selective SSM

Now the core of Mamba! We compute B, C, Δ as functions of the input:


In [None]:
# Projections to compute B, C, Δ from input
x_dbl_proj = nn.Linear(d_inner, dt_rank + d_state * 2, bias=False)
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

# Compute selective parameters
x_dbl = x_dbl_proj(x_conv)  # (batch, seq, dt_rank + 2*d_state)


In [None]:
# Split into delta, B, C
delta, B, C = x_dbl.split([dt_rank, d_state, d_state], dim=-1)
print(f"delta shape (before expansion): {delta.shape}")
print(f"B shape: {B.shape}")
print(f"C shape: {C.shape}")


In [None]:
# Project delta to full inner dimension and apply softplus (ensures positive)
delta = dt_proj(delta)  # (batch, seq, d_inner)
delta = F.softplus(delta)  # positive step sizes
print(f"delta shape (after expansion): {delta.shape}")


In [None]:
# A is NOT input-dependent - it's a learnable parameter
# Parameterized in log-space for numerical stability
A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
A = -torch.exp(A_log)  # negative for stability
A.shape  # (d_state,) - diagonal elements


### Step 4: Discretization (Per-Position!)

Now we discretize A and B using the input-dependent delta:


In [None]:
# Discretize: A_bar = exp(delta * A)
# For diagonal A, this is element-wise exp
# delta: (batch, seq, d_inner)
# A: (d_state,)

# We need to broadcast properly
# delta_A: (batch, seq, d_inner, d_state)
delta_A = delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0).unsqueeze(0)
A_bar = torch.exp(delta_A)
A_bar.shape


In [None]:
# Simplified discretization for B (using Euler instead of exact for clarity)
# B_bar ≈ delta * B
# B: (batch, seq, d_state)
# delta: (batch, seq, d_inner)

# We need B to be (batch, seq, d_inner, d_state) to match shapes
B_bar = delta.unsqueeze(-1) * B.unsqueeze(-2)  # (batch, seq, d_inner, d_state)
B_bar.shape
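
How much does the Euler shortcut for B cost? For a diagonal A each entry can be checked on its own as a scalar. The sketch below compares `delta * b` with the exact ZOH expression for one illustrative entry, `a = -1`; the helper names are made up.

```python
import math


def b_bar_zoh(a, b, delta):
    """Exact ZOH input coefficient for a scalar (diagonal) entry of A."""
    return (math.exp(a * delta) - 1.0) / a * b


def b_bar_euler(b, delta):
    """First-order approximation used in the cell above."""
    return delta * b


a, b = -1.0, 1.0
for delta in (0.001, 0.1, 1.0):
    exact = b_bar_zoh(a, b, delta)
    approx = b_bar_euler(b, delta)
    rel_err = (approx - exact) / exact
    print(f"delta={delta}: ZOH={exact:.5f}, Euler={approx:.5f}, relative error={rel_err:.2%}")
```

The two agree closely for small step sizes and drift apart as `delta * |a|` grows, so the shortcut is a trade of accuracy for simplicity.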


### Step 5: Run the Selective SSM

Now we run the recurrence (simplified version - real Mamba uses parallel scan):


In [None]:
def selective_ssm_recurrent(x_conv, A_bar, B_bar, C):
    """
    Selective SSM forward pass (recurrent mode).
    
    Args:
        x_conv: (batch, seq, d_inner) - input after convolution
        A_bar: (batch, seq, d_inner, d_state) - discretized A
        B_bar: (batch, seq, d_inner, d_state) - discretized B  
        C: (batch, seq, d_state) - output projection
    
    Returns:
        (batch, seq, d_inner) - output
    """
    batch, seq_len, d_inner = x_conv.shape
    d_state = A_bar.shape[-1]
    
    # Initialize state
    h = torch.zeros(batch, d_inner, d_state, device=x_conv.device)
    outputs = []
    
    for t in range(seq_len):
        # h_new = A_bar[t] * h + B_bar[t] * x[t]
        # Note: element-wise for diagonal A
        h = A_bar[:, t] * h + B_bar[:, t] * x_conv[:, t, :, None]
        
        # y = C[t] @ h (for each inner dim)
        # h: (batch, d_inner, d_state)
        # C: (batch, seq, d_state) -> C[t]: (batch, d_state)
        y = (h * C[:, t, None, :]).sum(dim=-1)  # (batch, d_inner)
        outputs.append(y)
    
    return torch.stack(outputs, dim=1)  # (batch, seq, d_inner)


In [None]:
y_ssm = selective_ssm_recurrent(x_conv, A_bar, B_bar, C)
y_ssm.shape
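
Why can this loop be parallelized at all? Each step is an affine map h → a·h + b, and composing two such maps gives another map of the same form. The sketch below (plain Python, scalar state, hypothetical helper names) checks that algebra. It still runs sequentially and is not a parallel implementation; it only shows the associative operator a parallel scan would rely on.

```python
def combine(left, right):
    """Compose two steps h -> a*h + b, applying `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def sequential_scan(steps):
    """Reference recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0."""
    h, out = 0.0, []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out

def scan_by_combining(steps):
    """Inclusive prefix scan under `combine`; the b-component equals h_t."""
    out, acc = [], None
    for step in steps:
        acc = step if acc is None else combine(acc, step)
        out.append(acc[1])
    return out

steps = [(0.9, 1.0), (0.5, 2.0), (0.8, -1.0), (0.99, 0.5)]
print(sequential_scan(steps))
print(scan_by_combining(steps))

# Associativity is what lets a parallel scan regroup the work into a tree.
left_first = combine(combine(steps[0], steps[1]), steps[2])
right_first = combine(steps[0], combine(steps[1], steps[2]))
print(left_first, right_first)
```

Here `a` plays the role of one entry of `A_bar[:, t]` and `b` the matching entry of `B_bar[:, t] * x[:, t]`.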


### Step 6: Gating and Output

Finally, we apply gating (multiply with the z branch) and project back:


In [None]:
# Gating: multiply by silu(z)
# This is like the gating in GLU (Gated Linear Unit)
y_gated = y_ssm * F.silu(z)
y_gated.shape


In [None]:
# Project back to original dimension
out_proj = nn.Linear(d_inner, d_model, bias=False)
output = out_proj(y_gated)
output.shape  # Back to (batch, seq, d_model)!


### Complete Mamba Block

Let's wrap everything into a proper module:


In [None]:
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = int(expand * d_model)
        self.dt_rank = max(d_model // 16, 1)
        
        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        
        # Convolution
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner, 
            kernel_size=d_conv, padding=d_conv - 1,
            groups=self.d_inner
        )
        
        # SSM parameters
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        
        # A in log space (not input-dependent)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        
        # D is a skip connection parameter
        self.D = nn.Parameter(torch.ones(self.d_inner))
        
        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
    
    def forward(self, x):
        batch, seq_len, _ = x.shape
        
        # Input projection and split
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)
        
        # Conv1d
        x_conv = self.conv1d(x_ssm.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        
        # Compute selective parameters
        x_dbl = self.x_proj(x_conv)
        delta, B, C = x_dbl.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        
        # Discretize
        A = -torch.exp(self.A_log)
        A_bar = torch.exp(delta.unsqueeze(-1) * A)
        B_bar = delta.unsqueeze(-1) * B.unsqueeze(-2)
        
        # SSM recurrence
        y = self._ssm_recurrent(x_conv, A_bar, B_bar, C)
        
        # Skip connection
        y = y + x_conv * self.D
        
        # Gating and output
        y = y * F.silu(z)
        return self.out_proj(y)
    
    def _ssm_recurrent(self, x, A_bar, B_bar, C):
        batch, seq_len, d_inner = x.shape
        h = torch.zeros(batch, d_inner, self.d_state, device=x.device)
        outputs = []
        
        for t in range(seq_len):
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t, :, None]
            y = (h * C[:, t, None, :]).sum(dim=-1)
            outputs.append(y)
        
        return torch.stack(outputs, dim=1)


In [None]:
# Test our Mamba block
mamba_block = MambaBlock(d_model=64)
test_input = torch.randn(2, 10, 64)
test_output = mamba_block(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"Parameters: {sum(p.numel() for p in mamba_block.parameters()):,}")
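
Where do those parameters live? A small bookkeeping sketch that recomputes the count by hand from the layer shapes in `MambaBlock` above. The function `mamba_block_param_count` is a helper written for this check, not part of any library.

```python
def mamba_block_param_count(d_model, d_state=16, d_conv=4, expand=2):
    """Count parameters of the MambaBlock defined above, layer by layer."""
    d_inner = int(expand * d_model)
    dt_rank = max(d_model // 16, 1)
    return {
        "in_proj": d_model * (2 * d_inner),            # no bias
        "conv1d": d_inner * d_conv + d_inner,          # depthwise weights + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),   # no bias
        "dt_proj": dt_rank * d_inner + d_inner,        # weights + bias
        "A_log": d_state,
        "D": d_inner,
        "out_proj": d_inner * d_model,                 # no bias
    }

counts = mamba_block_param_count(d_model=64)
for name, n in counts.items():
    print(f"{name:>9}: {n:,}")
print(f"{'total':>9}: {sum(counts.values()):,}")
```

For `d_model=64` the total is 30,608, which should match the count printed by the test cell. Most of the budget sits in the input and output projections; the SSM-specific parameters (`A_log`, `D`, `dt_proj`) are tiny by comparison.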


### 🧪 Test Problem: Modify the Block

Add an RMSNorm before the Mamba block (pre-norm style, like modern transformers).


In [None]:
# Test problem: Create a pre-normed Mamba block

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    
    def forward(self, x):
        # Fill in code here: compute RMS norm
        # Hint: rms = sqrt(mean(x^2))
        # return x / rms * self.weight
        pass

class PreNormMambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16):
        super().__init__()
        # Fill in code here
        pass
    
    def forward(self, x):
        # Fill in code here: apply norm, then mamba, then residual connection
        pass

# Uncomment to test:
# test_block = PreNormMambaBlock(64)
# test_out = test_block(torch.randn(2, 10, 64))
# assert test_out.shape == (2, 10, 64), "Output shape should match input"
# print("✓ Pre-normed block working!")


## Building a Complete Mamba Model

Now let's stack Mamba blocks into a full language model. The same stack is reused for text classification and then language modeling in the sections that follow.


In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization"""
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.weight
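
A tiny worked example of what this normalization does, written in plain Python on a single feature vector. `rms_norm` is an illustrative stand-in for the module above, with the weight left at its initial value of ones.

```python
import math

def rms_norm(values, weight=None, eps=1e-5):
    """RMSNorm over one feature vector, mirroring the module above."""
    weight = weight or [1.0] * len(values)
    rms = math.sqrt(sum(v * v for v in values) / len(values) + eps)
    return [v / rms * w for v, w in zip(values, weight)]

x = [3.0, 4.0]
y = rms_norm(x)
print([round(v, 4) for v in y])

# Unlike LayerNorm, the mean is not subtracted: only the scale changes,
# so multiplying the input by 10 gives (almost exactly) the same output.
scaled = rms_norm([30.0, 40.0])
print([round(v, 4) for v in scaled])
```

The output always has root-mean-square close to 1, which keeps activations at a predictable scale as they pass through the residual stack.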


In [None]:
class MambaLM(nn.Module):
    """Complete Mamba Language Model"""
    def __init__(self, vocab_size, d_model, n_layers, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'norm': RMSNorm(d_model),
                'mamba': MambaBlock(d_model, d_state, d_conv, expand)
            })
            for _ in range(n_layers)
        ])
        
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying
        self.lm_head.weight = self.embedding.weight
    
    def forward(self, input_ids):
        x = self.embedding(input_ids)
        
        for layer in self.layers:
            x = x + layer['mamba'](layer['norm'](x))
        
        x = self.norm_f(x)
        return self.lm_head(x)


In [None]:
# Test the model
model = MambaLM(vocab_size=10000, d_model=128, n_layers=4)
test_ids = torch.randint(0, 10000, (2, 32))  # batch of 2, seq len 32
logits = model(test_ids)
print(f"Input: {test_ids.shape}")
print(f"Output logits: {logits.shape}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")


## Training Mamba: Text Classification

Before we tackle language modeling, let's train on a simpler task: text classification with AG News.

This will help us verify our implementation works and build intuition for training dynamics.


In [None]:
# Load AG News dataset
from datasets import load_dataset
from transformers import AutoTokenizer

# Use a simple tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load AG News (4 classes: World, Sports, Business, Sci/Tech)
dataset = load_dataset('ag_news')
print(f"Train size: {len(dataset['train'])}")
print(f"Test size: {len(dataset['test'])}")
print(f"Classes: {dataset['train'].features['label'].names}")


In [None]:
# Look at an example
example = dataset['train'][0]
print(f"Text: {example['text'][:200]}...")
print(f"Label: {example['label']} ({dataset['train'].features['label'].names[example['label']]})")


In [None]:
# Tokenize the dataset
max_length = 128

def tokenize_fn(examples):
    return tokenizer(
        examples['text'], 
        truncation=True, 
        padding='max_length',
        max_length=max_length,
        return_tensors=None
    )

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=['text'])
tokenized.set_format('torch')


In [None]:
class MambaClassifier(nn.Module):
    """Mamba for sequence classification"""
    def __init__(self, vocab_size, d_model, n_layers, n_classes, d_state=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'norm': RMSNorm(d_model),
                'mamba': MambaBlock(d_model, d_state)
            })
            for _ in range(n_layers)
        ])
        
        self.norm_f = RMSNorm(d_model)
        self.classifier = nn.Linear(d_model, n_classes)
    
    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        
        for layer in self.layers:
            x = x + layer['mamba'](layer['norm'](x))
        
        x = self.norm_f(x)
        
        # Pool: average token representations into one vector per sequence
        if attention_mask is not None:
            # Mean pooling over non-padded tokens
            mask = attention_mask.unsqueeze(-1).float()
            # clamp avoids a division by zero if a sequence is all padding
            x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            x = x.mean(dim=1)
        
        return self.classifier(x)
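
The masked pooling step is easy to get subtly wrong, so here is the same idea on plain Python lists. `masked_mean` is a hypothetical helper for illustration; it handles one sequence at a time rather than a batch.

```python
def masked_mean(rows, mask):
    """Average the rows whose mask entry is 1; padded rows are ignored."""
    kept = [row for row, m in zip(rows, mask) if m == 1]
    if not kept:                      # guard against an all-padding sequence
        return [0.0] * len(rows[0])
    return [sum(col) / len(kept) for col in zip(*kept)]

# One sequence: 3 real tokens followed by 2 padding positions, d_model = 2.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0], [9.0, 9.0]]
mask = [1, 1, 1, 0, 0]

print(masked_mean(hidden, mask))        # padding excluded
print(masked_mean(hidden, [1] * 5))     # naive mean is pulled toward the padding rows
```

With `padding='max_length'` most AG News examples carry many padding positions, so the difference between the two averages can be large.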


In [None]:
# Create dataloaders
from torch.utils.data import DataLoader

batch_size = 32
train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(tokenized['test'], batch_size=batch_size)


In [None]:
# Initialize model
classifier = MambaClassifier(
    vocab_size=tokenizer.vocab_size,
    d_model=128,
    n_layers=4,
    n_classes=4
).to(device)

print(f"Parameters: {sum(p.numel() for p in classifier.parameters()):,}")


In [None]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{correct/total:.4f}'})
    
    return total_loss / len(loader), correct / total

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            logits = model(input_ids, attention_mask)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    return correct / total


In [None]:
# Train for a few epochs
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-4)
n_epochs = 3

for epoch in range(n_epochs):
    print(f"\nEpoch {epoch + 1}/{n_epochs}")
    train_loss, train_acc = train_epoch(classifier, train_loader, optimizer, device)
    test_acc = evaluate(classifier, test_loader, device)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")


## Language Modeling with TinyStories

Now let's train a proper language model on TinyStories - a dataset of simple children's stories that's perfect for testing language models efficiently.


In [None]:
# Load TinyStories dataset
tinystories = load_dataset('roneneldan/TinyStories', split='train[:100000]')  # subset for speed
print(f"Loaded {len(tinystories)} stories")
print(f"\nExample story:\n{tinystories[0]['text'][:500]}...")


In [None]:
# Use GPT-2 tokenizer for language modeling
from transformers import GPT2Tokenizer

lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
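# GPT-2 has no pad token, so we reuse EOS. Caveat: the training loss ignores
# pad_token_id, so the model is never trained to predict EOS and generation
# will rarely stop before max_new_tokens.
lm_tokenizer.pad_token = lm_tokenizer.eos_token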

# Tokenize and create chunks
context_length = 256

def tokenize_lm(examples):
    tokens = lm_tokenizer(examples['text'], truncation=True, max_length=context_length, padding='max_length')
    tokens['labels'] = tokens['input_ids'].copy()
    return tokens

tokenized_stories = tinystories.map(tokenize_lm, batched=True, remove_columns=['text'])
tokenized_stories.set_format('torch')


In [None]:
# Create a smaller model for training
lm_model = MambaLM(
    vocab_size=lm_tokenizer.vocab_size,
    d_model=256,
    n_layers=6,
    d_state=16
).to(device)

print(f"LM Parameters: {sum(p.numel() for p in lm_model.parameters()):,}")


In [None]:
def train_lm_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    
    pbar = tqdm(loader, desc='Training LM')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        logits = model(input_ids)
        
        # Shift for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=lm_tokenizer.pad_token_id
        )
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'ppl': f'{math.exp(loss.item()):.2f}'})
    
    return total_loss / len(loader)
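
Two details of this loop are worth checking by hand: the one-position shift between inputs and targets, and what loss to expect before any learning happens. The sketch below uses plain Python; the token ids are arbitrary and `shift_for_next_token` is an illustrative helper.

```python
import math

def shift_for_next_token(token_ids):
    """Pair each position's input with the token that follows it."""
    inputs = token_ids[:-1]
    targets = token_ids[1:]
    return list(zip(inputs, targets))

tokens = [464, 3290, 3332, 319, 262, 2603]   # arbitrary example ids
for inp, tgt in shift_for_next_token(tokens):
    print(f"after {inp:>5} predict {tgt:>5}")

# A model that guesses uniformly over the vocabulary has loss ln(V),
# so its perplexity exp(loss) equals the vocabulary size.
vocab_size = 50257
uniform_loss = math.log(vocab_size)
print(f"uniform loss: {uniform_loss:.2f}, perplexity: {math.exp(uniform_loss):.0f}")
```

If the first logged losses are far above ln(V), that usually points to an initialization or data problem rather than slow learning.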


In [None]:
# Training loop
lm_loader = DataLoader(tokenized_stories, batch_size=16, shuffle=True)
lm_optimizer = torch.optim.AdamW(lm_model.parameters(), lr=3e-4)

# Train for a few epochs (increase for better results)
for epoch in range(2):
    print(f"\nEpoch {epoch + 1}")
    loss = train_lm_epoch(lm_model, lm_loader, lm_optimizer, device)
    print(f"Average Loss: {loss:.4f}, Perplexity: {math.exp(loss):.2f}")


In [None]:
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    model.eval()
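    # Put the prompt on the same device as the model instead of relying on a global
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(next(model.parameters()).device)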
    
    for _ in range(max_new_tokens):
        logits = model(input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(input_ids[0])
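
The `temperature` argument is the main knob in this sampler. A minimal sketch of its effect on a three-token distribution, in plain Python with a hypothetical helper `softmax_with_temperature`:

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed stably."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.5, 0.8, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

Lower temperatures concentrate probability on the top token (more repetitive, safer text), while higher temperatures flatten the distribution (more varied, less coherent text). Note also that `generate` above re-runs the whole sequence at every step, so it does not benefit from Mamba's constant-cost state update.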


In [None]:
# Generate some stories!
prompts = [
    "Once upon a time, there was a",
    "The little girl walked into the",
    "Tom and his dog went to"
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    print(f"Generated: {generate(lm_model, lm_tokenizer, prompt)}")
    print("-" * 50)


## Dissecting Mamba 1.4B

Now let's load a pretrained Mamba model and explore its internals to build deeper intuition.


In [None]:
# Install mamba-ssm if needed (run in terminal: pip install mamba-ssm)
try:
    from mamba_ssm import MambaLMHeadModel
    MAMBA_AVAILABLE = True
except ImportError:
    print("mamba-ssm not installed. Run: pip install mamba-ssm 'causal-conv1d>=1.1.0'")
    MAMBA_AVAILABLE = False


In [None]:
if MAMBA_AVAILABLE:
    # Load the 1.4B model (mamba-ssm's fused kernels target NVIDIA GPUs, so expect this to need CUDA)
    mamba_1b = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b", device='cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Loaded Mamba 1.4B")
    print(f"Parameters: {sum(p.numel() for p in mamba_1b.parameters()):,}")


In [None]:
if MAMBA_AVAILABLE:
    # Explore the model structure
    print("Model architecture:")
    print(mamba_1b)


In [None]:
if MAMBA_AVAILABLE:
    # Look at the A matrix in the first layer
    first_layer = mamba_1b.backbone.layers[0].mixer
    A_log = first_layer.A_log.detach().cpu()
    
    print(f"A_log shape: {A_log.shape}")
    print(f"A (negative eigenvalues): {-torch.exp(A_log[0, :10])}...")
    
    # Visualize the learned A values
    fig, ax = plt.subplots(figsize=(10, 4))
    A_vals = -torch.exp(A_log.mean(dim=0))
    ax.bar(range(len(A_vals)), A_vals.numpy())
    ax.set_xlabel('State dimension')
    ax.set_ylabel('A value (decay rate)')
    ax.set_title('Learned A matrix values (averaged over inner dims)')
    plt.tight_layout()


In [None]:
if MAMBA_AVAILABLE:
    # Generate text with the pretrained model
    from transformers import AutoTokenizer
    
    mamba_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    
    prompt = "The theory of relativity states that"
    input_ids = mamba_tokenizer(prompt, return_tensors='pt').input_ids.to(mamba_1b.device)
    
    output = mamba_1b.generate(input_ids, max_length=100, temperature=0.7)
    print(f"Prompt: {prompt}")
    print(f"Generated: {mamba_tokenizer.decode(output[0])}")


## Mamba vs Transformers: When to Use What

Now that we understand both architectures, let's compare them systematically.


### Complexity Comparison

**Memory:**
- Transformer: O(L²) for the attention score matrix when it is materialized during training; at inference the KV cache grows as O(L) with the sequence
- Mamba: O(1) state per layer (fixed size regardless of sequence length!)

**Compute:**
- Transformer: O(L² × d) for attention
- Mamba: O(L × d × N) where N is state dimension (typically 16)

**Training:**
- Transformer: Fully parallel (all positions computed simultaneously)
- Mamba: Parallel across positions via an associative (parallel) scan, so training avoids a step-by-step loop

**Inference:**
- Transformer: Need to store full KV cache; each new token is O(L)
- Mamba: Just update state; each new token is O(1)!
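
To put numbers on the compute comparison, here is a back-of-the-envelope sketch in plain Python. It counts only the dominant terms named above and ignores constants, projections, and MLPs, so treat the ratios as orders of magnitude. The function names are illustrative.

```python
def attention_ops(seq_len, d_model):
    """Rough count for the L x L score matrix: L^2 * d multiply-adds."""
    return seq_len ** 2 * d_model

def mamba_ops(seq_len, d_model, d_state=16):
    """Rough count for the SSM recurrence: L * d * N multiply-adds."""
    return seq_len * d_model * d_state

d_model = 1024
for seq_len in (1_000, 10_000, 100_000):
    ratio = attention_ops(seq_len, d_model) / mamba_ops(seq_len, d_model)
    print(f"L={seq_len:>7,}: attention / mamba = {ratio:,.1f}x")
```

Under these assumptions the ratio is simply L / N, so the gap widens linearly with context length.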


In [None]:
# Visualize memory usage during inference
seq_lengths = torch.arange(100, 10001, 100)
d_model = 1024
n_layers = 24
n_heads = 16
d_state = 16

# Transformer KV cache: 2 * n_layers * L * d_model (K and V for each layer)
transformer_memory = 2 * n_layers * seq_lengths * d_model * 4 / 1e9  # GB (float32)

# Mamba state: n_layers * d_inner * d_state (fixed!), with d_inner = 2 * d_model (expand=2);
# the small convolution buffer is ignored here
mamba_memory = torch.ones_like(seq_lengths) * n_layers * (2 * d_model) * d_state * 4 / 1e9  # GB

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(seq_lengths, transformer_memory, label='Transformer KV Cache', linewidth=2)
ax.plot(seq_lengths, mamba_memory, label='Mamba State', linewidth=2)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Memory (GB)')
ax.set_title('Inference Memory: Transformer vs Mamba')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()


### Qualitative Differences

**Transformers excel at:**
- **Retrieval**: Looking up specific information from context ("What did John say about X?")
- **In-context learning**: Learning patterns from few examples in the prompt
- **Copying**: Reproducing exact sequences from input
- **Precise attention**: When you need to focus on specific tokens

**Mamba excels at:**
- **Compression**: Summarizing long contexts into useful representations
- **Long-range dependencies**: When information must flow across very long sequences
- **Efficiency**: Especially for long sequences and real-time generation
- **Streaming**: Processing continuous input streams

**The intuition:** Transformers "store and retrieve" while Mamba "compresses and flows".


## Hybrid Architectures: Best of Both Worlds

Given the complementary strengths, why not combine them? This is exactly what hybrid architectures like **Jamba** and **Zamba2** do.

### Design Patterns

**1. Interleaved:** Alternate Mamba and attention layers
```
Mamba → Mamba → Attention → Mamba → Mamba → Attention → ...
```

**2. Parallel:** Run both in parallel and combine outputs
```
x → [Mamba(x) + Attention(x)] → ...
```

**3. Hierarchical:** Use Mamba for local, attention for global
```
Mamba(local) → Pool → Attention(global) → Unpool → Mamba(local)
```
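
The interleaved pattern is the simplest to pin down in code. A minimal sketch, assuming the rule "the last layer in each group of `attention_every` is attention" (the same rule `HybridLM` uses); `layer_pattern` is a hypothetical helper.

```python
def layer_pattern(n_layers, attention_every):
    """Return 'M' (Mamba) or 'A' (attention) for each layer index."""
    return ["A" if i % attention_every == attention_every - 1 else "M"
            for i in range(n_layers)]

# The diagram above corresponds to attention_every=3.
print(" ".join(layer_pattern(n_layers=6, attention_every=3)))

pattern = layer_pattern(n_layers=12, attention_every=4)
print(" ".join(pattern))
print(f"{pattern.count('M')} Mamba, {pattern.count('A')} attention")
```

Raising `attention_every` shifts the model toward Mamba's fixed-size state and away from the growing KV cache, at the cost of fewer layers that can look up exact tokens.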


In [None]:
class SimpleAttention(nn.Module):
    """Simplified attention for hybrid model"""
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, L, D = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each of q, k, v: (B, H, L, head_dim)
        
        # Causal attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.triu(torch.ones(L, L, device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        out = (attn @ v).transpose(1, 2).reshape(B, L, D)
        return self.proj(out)


In [None]:
class HybridBlock(nn.Module):
    """A block that uses either Mamba or Attention based on layer index"""
    def __init__(self, d_model, use_attention=False, d_state=16, n_heads=8):
        super().__init__()
        self.norm = RMSNorm(d_model)
        if use_attention:
            self.mixer = SimpleAttention(d_model, n_heads)
        else:
            self.mixer = MambaBlock(d_model, d_state)
    
    def forward(self, x):
        return x + self.mixer(self.norm(x))


In [None]:
class HybridLM(nn.Module):
    """Hybrid Mamba + Attention Language Model (like Zamba/Jamba)"""
    def __init__(self, vocab_size, d_model, n_layers, attention_every=4, d_state=16, n_heads=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Interleave: every `attention_every` layers, use attention
        self.layers = nn.ModuleList([
            HybridBlock(
                d_model, 
                use_attention=(i % attention_every == attention_every - 1),
                d_state=d_state,
                n_heads=n_heads
            )
            for i in range(n_layers)
        ])
        
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # tie weights
        
        # Count layers
        n_mamba = sum(1 for l in self.layers if isinstance(l.mixer, MambaBlock))
        n_attn = sum(1 for l in self.layers if isinstance(l.mixer, SimpleAttention))
        print(f"Hybrid model: {n_mamba} Mamba layers, {n_attn} Attention layers")
    
    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        return self.lm_head(x)


In [None]:
# Test the hybrid model
hybrid = HybridLM(vocab_size=50257, d_model=256, n_layers=12, attention_every=4)
print(f"Parameters: {sum(p.numel() for p in hybrid.parameters()):,}")


## Training a Hybrid Model with W&B Logging

Now let's train a larger hybrid model with proper monitoring using Weights & Biases.


In [None]:
import wandb

# Configuration for training
config = {
    'vocab_size': 50257,
    'd_model': 512,
    'n_layers': 16,
    'attention_every': 4,  # 1 attention layer every 4 layers
    'd_state': 16,
    'n_heads': 8,
    'batch_size': 8,
    'gradient_accumulation_steps': 4,  # effective batch = 32
    'learning_rate': 3e-4,
    'warmup_steps': 500,
    'max_steps': 5000,
    'context_length': 256,  # must match the 256-token chunks in tokenized_stories
    'mixed_precision': True,
}


In [None]:
# Initialize W&B
wandb.init(project='mamba-hybrid', config=config)

# Create model
hybrid_model = HybridLM(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    attention_every=config['attention_every'],
    d_state=config['d_state'],
    n_heads=config['n_heads']
).to(device)

print(f"Model parameters: {sum(p.numel() for p in hybrid_model.parameters()):,}")
wandb.watch(hybrid_model, log='all', log_freq=100)


In [None]:
def get_lr(step, warmup_steps, max_lr, max_steps):
    """Cosine learning rate schedule with warmup"""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

def log_mamba_specific_metrics(model, step):
    """Log metrics specific to Mamba layers"""
    metrics = {}
    
    for i, layer in enumerate(model.layers):
        if isinstance(layer.mixer, MambaBlock):
            # Log A matrix statistics (decay rates)
            A = -torch.exp(layer.mixer.A_log.detach())
            metrics[f'mamba_layer_{i}/A_mean'] = A.mean().item()
            metrics[f'mamba_layer_{i}/A_min'] = A.min().item()
            metrics[f'mamba_layer_{i}/A_max'] = A.max().item()
            
            # Log D (skip connection) statistics
            D = layer.mixer.D.detach()
            metrics[f'mamba_layer_{i}/D_mean'] = D.mean().item()
    
    wandb.log(metrics, step=step)
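
It helps to print a few points of the schedule before committing to a long run. The snippet below copies `get_lr` from the cell above so it runs on its own, then evaluates it with the warmup, peak rate, and step budget from `config`.

```python
import math

def get_lr(step, warmup_steps, max_lr, max_steps):
    """Same schedule as above: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

for step in (0, 250, 500, 2750, 5000):
    print(f"step {step:>4}: lr = {get_lr(step, 500, 3e-4, 5000):.2e}")
```

Two things to notice: the rate is exactly zero at step 0, so the first optimizer step changes nothing, and the rate decays all the way to zero rather than to a floor. Both are common choices, but a nonzero minimum rate is an easy variation if late training stalls.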


In [None]:
# Prepare data
train_loader_hybrid = DataLoader(tokenized_stories, batch_size=config['batch_size'], shuffle=True)

# Optimizer
optimizer = torch.optim.AdamW(hybrid_model.parameters(), lr=config['learning_rate'], weight_decay=0.1)

# Mixed precision
scaler = torch.cuda.amp.GradScaler() if config['mixed_precision'] and device.type == 'cuda' else None

# Training loop
global_step = 0
accumulation_steps = config['gradient_accumulation_steps']
hybrid_model.train()

pbar = tqdm(total=config['max_steps'], desc='Training Hybrid')
data_iter = iter(train_loader_hybrid)

while global_step < config['max_steps']:
    optimizer.zero_grad()
    total_loss = 0
    
    for micro_step in range(accumulation_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader_hybrid)
            batch = next(data_iter)
        
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        
        if scaler:
            with torch.cuda.amp.autocast():
                logits = hybrid_model(input_ids)
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=lm_tokenizer.pad_token_id
                ) / accumulation_steps
            scaler.scale(loss).backward()
        else:
            logits = hybrid_model(input_ids)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=lm_tokenizer.pad_token_id
            ) / accumulation_steps
            loss.backward()
        
        total_loss += loss.item()
    
    # Update learning rate
    lr = get_lr(global_step, config['warmup_steps'], config['learning_rate'], config['max_steps'])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Gradient clipping and step
    if scaler:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(hybrid_model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(hybrid_model.parameters(), 1.0)
        optimizer.step()
    
    # Logging
    if global_step % 10 == 0:
        wandb.log({
            'train/loss': total_loss,
            'train/perplexity': math.exp(total_loss),
            'train/lr': lr,
        }, step=global_step)
    
    if global_step % 100 == 0:
        log_mamba_specific_metrics(hybrid_model, global_step)
    
    pbar.update(1)
    pbar.set_postfix({'loss': f'{total_loss:.4f}', 'ppl': f'{math.exp(total_loss):.2f}'})
    global_step += 1

pbar.close()
wandb.finish()
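
Why divide each micro-batch loss by `accumulation_steps`? A scalar sketch in plain Python, using a one-parameter model and mean squared error (both chosen purely for illustration), shows that the accumulated gradient then matches the gradient of one large batch.

```python
def grad_mse(w, batch):
    """d/dw of mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]

full_batch_grad = grad_mse(w, data)

accumulation_steps = 2
micro_batches = [data[:2], data[2:]]
accumulated = 0.0
for micro in micro_batches:
    # Dividing each micro-batch loss by accumulation_steps scales its gradient the same way.
    accumulated += grad_mse(w, micro) / accumulation_steps

print(full_batch_grad, accumulated)
```

The match is exact only when every micro-batch averages over the same number of terms. In the loop above, `ignore_index` makes each micro-batch average over its own count of non-padding tokens, so the accumulated gradient is a close approximation of the large-batch gradient rather than an identical one.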


## Productionizing Mamba: Monitoring & Best Practices

Training Mamba models has some unique considerations compared to Transformers.


### Mamba-Specific Monitoring

**What to watch in Mamba that differs from Transformers:**

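1. **A matrix eigenvalues**: The decay rates control memory. A itself stays negative, and the per-step decay factor is exp(ΔA), which lies between 0 and 1. If that factor drifts toward 0 (A very negative), the model "forgets everything instantly". If it sits too close to 1 (A near 0), the state barely decays and its norm can keep growing over long sequences.

2. **Delta (Δ) distribution**: The step sizes should have healthy variance. If all Δ collapse to the same value, selectivity is lost.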

3. **State norm**: Unlike Transformer activations, Mamba state accumulates over time. Monitor for explosion or collapse.

4. **D (skip connection)**: If D dominates, the SSM isn't doing much work.

**What's the same:**
- Loss curves, gradient norms, learning rate schedules
- Weight distributions and activation statistics
- Dead neurons in MLPs
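
One way to make the A statistics interpretable on a dashboard is to convert them into a memory horizon. The sketch below assumes a constant step size Δ (in the real model Δ varies per token, so this is only a rough guide) and computes how many steps it takes a state component to decay by half. `half_life_steps` is a hypothetical helper.

```python
import math

def half_life_steps(delta, a):
    """Steps until a state component halves, assuming constant delta.

    With A_bar = exp(delta * a) and a < 0:
    A_bar**k = 0.5  =>  k = ln(2) / (delta * |a|).
    """
    return math.log(2) / (delta * abs(a))

delta = 0.01   # illustrative step size, not a measured value
for a in (-1.0, -4.0, -16.0):
    a_bar = math.exp(delta * a)
    print(f"A={a:>6}: A_bar={a_bar:.4f}, half-life ~ {half_life_steps(delta, a):.1f} steps")
```

Logging a derived quantity like this next to `A_mean` makes drift easier to read: a shrinking half-life means the layer is turning into a short-context filter, while a very long one means old information is rarely cleared.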


### Training Stability Tricks

**Mamba-specific:**
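- **Parameterize A in log-space**: Storing `A_log` and using `A = -exp(A_log)` keeps every eigenvalue negative, which rules out the positive eigenvalues that cause instability
- **Softplus for Δ**: Ensures positive step sizes
- **Bounded A initialization**: Start with reasonable decay rates (e.g., A = -1 to -N)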

**General (apply to both):**
- **Pre-norm**: Apply normalization before each block (more stable than post-norm)
- **Gradient clipping**: Clip to 1.0 to prevent explosion
- **Weight decay**: 0.1 is typical for language models
- **Learning rate warmup**: Crucial for stability at start

**Compute efficiency:**
- **Mixed precision (FP16/BF16)**: Roughly halves activation memory and speeds up compute on hardware that supports it
- **Gradient accumulation**: Simulate larger batches on limited memory
- **Gradient checkpointing**: Trade compute for memory on very long sequences


In [None]:
# Generate with our trained hybrid model
hybrid_model.eval()
print("Generating with hybrid model:\n")

for prompt in ["Once upon a time", "The scientist discovered"]:
    print(f"Prompt: {prompt}")
    output = generate(hybrid_model, lm_tokenizer, prompt, max_new_tokens=50)
    print(f"Output: {output}\n")


## Cool Intuitions & Cross-Domain Connections

### SSMs as Learnable Filters
From signal processing: SSMs are essentially learnable infinite impulse response (IIR) filters. The kernel we computed earlier is the impulse response. Different A matrices create different filter characteristics (low-pass, band-pass, etc.).
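
The filter view can be checked directly for a scalar SSM with fixed parameters. The sketch below, in plain Python with made-up values for the discretized parameters, builds the kernel K_j = C · Ā^j · B̄, confirms it equals the response to a unit impulse, and confirms that a causal convolution with it reproduces the recurrence. This equivalence holds only when the parameters do not depend on the input; with Mamba's per-token Ā and B̄ there is no single kernel, which is why it falls back on a scan.

```python
def ssm_recurrent(u, a_bar, b_bar, c):
    """Scalar SSM: x_k = a_bar * x_{k-1} + b_bar * u_k, y_k = c * x_k."""
    x, ys = 0.0, []
    for u_k in u:
        x = a_bar * x + b_bar * u_k
        ys.append(c * x)
    return ys

def ssm_kernel(length, a_bar, b_bar, c):
    """Impulse response K_j = c * a_bar**j * b_bar."""
    return [c * a_bar ** j * b_bar for j in range(length)]

def causal_conv(u, kernel):
    """y_k = sum_j K_j * u_{k-j}, using only past and present inputs."""
    return [sum(kernel[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]

a_bar, b_bar, c = 0.9, 0.5, 2.0   # illustrative values
u = [1.0, 0.0, -2.0, 3.0, 0.5]

kernel = ssm_kernel(len(u), a_bar, b_bar, c)
impulse = ssm_recurrent([1.0] + [0.0] * (len(u) - 1), a_bar, b_bar, c)
y_conv = causal_conv(u, kernel)
y_rec = ssm_recurrent(u, a_bar, b_bar, c)

print("kernel == impulse response:", all(abs(p - q) < 1e-12 for p, q in zip(kernel, impulse)))
print("conv == recurrence:", all(abs(p - q) < 1e-12 for p, q in zip(y_conv, y_rec)))
```

Because 0 < Ā < 1 here, the kernel decays geometrically but never reaches exactly zero, which is the "infinite" in infinite impulse response.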

### Connection to Differential Equations
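Mamba's continuous SSM is a linear ordinary differential equation, the simplest member of the family that neural ODEs generalize. The discretization step is a numerical integration rule: zero-order hold gives the exp(ΔA) form used above, and simpler schemes such as Euler appear as approximations of it.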

### Biological Plausibility  
Neurons in the brain can be modeled as dynamical systems with state. The selectivity mechanism mirrors how biological neurons gate information based on context.

### The Compression vs Retrieval Tradeoff
- **Transformers**: Store everything (like a database), retrieve via attention
- **Mamba**: Compress continuously (like a summary), with no explicit lookup over past tokens

This explains why Transformers dominate tasks requiring precise recall ("What was the 5th word?") while Mamba excels at tasks requiring integration ("What's the overall sentiment?").


## Summary: What We Learned

1. **State Space Models** come from control theory: they compress sequence history into a fixed-size state

2. **The key equations**: $x_k = \bar{A}x_{k-1} + \bar{B}u_k$ and $y_k = Cx_k$

3. **Two computation modes**: Recurrent (O(n) sequential) and Convolutional (O(n log n) parallel)

4. **HiPPO** provides optimal memory initialization for the A matrix

5. **The selectivity problem**: Fixed SSM parameters can't do content-aware processing

6. **Mamba's insight**: Make B, C, Δ input-dependent! Use parallel scan to stay efficient

7. **Hybrid models** combine Mamba's efficiency with attention's retrieval capability

8. **Production considerations**: Monitor A eigenvalues, Δ distribution, and state norms

The field is rapidly evolving: Mamba2, Griffin, and new architectures continue to push the boundaries of efficient sequence modeling!


### 📷 Insert Image Here
**Suggested: Mamba paper Figure 1 showing the architecture comparison between Transformer and Mamba**

URL: https://arxiv.org/abs/2312.00752 (Figure 1)


## References & Further Reading

**Papers:**
- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) - The original Mamba paper
- [Efficiently Modeling Long Sequences with Structured State Spaces (S4)](https://arxiv.org/abs/2111.00396) - The S4 breakthrough
- [HiPPO: Recurrent Memory with Optimal Polynomial Projections](https://arxiv.org/abs/2008.07669) - The memory theory
- [Jamba: A Hybrid Transformer-Mamba Language Model](https://arxiv.org/abs/2403.19887) - AI21's hybrid approach
- [Zamba: A Compact 7B SSM Hybrid Model](https://arxiv.org/abs/2405.16712) - Efficient hybrid design

**Code:**
- [state-spaces/mamba](https://github.com/state-spaces/mamba) - Official Mamba implementation
- [HazyResearch/safari](https://github.com/HazyResearch/safari) - Research code for convolutional sequence models (H3, Hyena)

**Tutorials:**
- [The Annotated S4](https://srush.github.io/annotated-s4/) - Detailed walkthrough of S4


**Blog posts:**
- [Inside NVIDIA Nemotron 3: Techniques, Tools, and Data That Make It Efficient and Accurate](https://developer.nvidia.com/blog/inside-nvidia-nemotron-3-techniques-tools-and-data-that-make-it-efficient-and-accurate/)