# Tutorial 1 — SSM Basics

This notebook walks through the foundational concepts behind **S4** (the Structured State Space sequence model). By the end, you'll understand:

1. **What is a State Space Model (SSM)?** — The continuous-time dynamical system at the heart of S4
2. **The HiPPO Matrix** — Why the initialization of the state matrix *A* is the secret sauce
3. **Discretization** — How we bridge the continuous math to discrete sequences (text, audio, pixels)
4. **CNN ↔ RNN Duality** — How S4 trains like a CNN but infers like an RNN
5. **SSMs vs Transformers** — Where SSMs shine compared to attention-based LLMs

---

In [None]:
import sys, os
sys.path.insert(0, os.path.join(os.getcwd(), "..") if "tutorials" in os.getcwd() else os.getcwd())

import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## 1. What is a State Space Model?

A **State Space Model (SSM)** is a linear dynamical system that maps an input signal $u(t)$ to an output $y(t)$ through a latent state $x(t)$:

$$
x'(t) = A\, x(t) + B\, u(t) \qquad \text{(state equation)}
$$
$$
y(t) = C\, x(t) + D\, u(t) \qquad \text{(output equation)}
$$

Where:
- $A \in \mathbb{R}^{N \times N}$ — **state matrix** (controls the dynamics — this is the key!)
- $B \in \mathbb{R}^{N \times 1}$ — input-to-state projection
- $C \in \mathbb{R}^{1 \times N}$ — state-to-output projection
- $D \in \mathbb{R}$ — skip / direct feed-through connection
- $N$ — state dimension (how much "memory" the model has)
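
To make the two equations concrete, here is a minimal sketch that simulates a toy 2-state system. The matrix values are hypothetical, and forward-Euler integration is used only because it is the simplest thing to write down; it is not the discretization S4 relies on.

```python
import numpy as np

# Toy 2-state system; all values are hypothetical, chosen only for illustration.
A = np.array([[-1.0, 0.0],
              [1.0, -2.0]])       # state matrix, N x N
B = np.array([[1.0], [0.0]])      # input-to-state, N x 1
C = np.array([[0.0, 1.0]])        # state-to-output, 1 x N
D = 0.0                           # feed-through scalar

def simulate_euler(A, B, C, D, u, dt):
    """Integrate x' = Ax + Bu with forward Euler; return y at every step."""
    x = np.zeros((A.shape[0], 1))
    ys = []
    for u_t in u:
        ys.append((C @ x).item() + D * u_t)   # output equation
        x = x + dt * (A @ x + B * u_t)        # state equation, one Euler step
    return np.array(ys)

dt = 0.01
u = np.ones(2000)                 # unit step input held for 20 time units
y = simulate_euler(A, B, C, D, u, dt)

# With a constant input the state settles where Ax + Bu = 0, i.e. x = -A^{-1} B u.
y_steady = (C @ np.linalg.solve(-A, B)).item()
print(f"final y = {y[-1]:.4f}, analytic steady state = {y_steady:.4f}")
```

Because both eigenvalues of this toy $A$ are negative, the transient dies out and the simulated output approaches the analytic steady state.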

### Why does this matter for deep learning?

Think of an SSM as a **learnable filter**: it reads input one step at a time, updates a compressed memory (the state), and produces output. This is similar to an RNN — but with a crucial difference: **the structure of $A$ can be exploited for efficient parallel computation**.

The **key insight** of S4: a *random* $A$ gives poor long-range performance. But a specially designed $A$ (the **HiPPO matrix**) lets the state optimally approximate the entire input history using polynomial projections.

## 2. The HiPPO Matrix

**HiPPO** (High-order Polynomial Projection Operators) provides a principled way to initialize the state matrix $A$.

The idea: at each time $t$, the state $x(t)$ should store the **optimal polynomial approximation** of the input history $u(s)$ for $s \leq t$. The HiPPO-LegS variant uses **scaled Legendre polynomials** as the basis.

The matrix entries are:
$$
A_{nk} = \begin{cases}
-\sqrt{(2n+1)(2k+1)} & \text{if } n > k \\
-(n+1) & \text{if } n = k \\
0 & \text{if } n < k
\end{cases}
$$

This makes $A$ **lower-triangular**, so its eigenvalues are simply the diagonal entries $-(n+1)$: all real and negative, which guarantees stability.
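
The case formula translates almost line by line into NumPy. The sketch below is a standalone illustration (the function name `hippo_legs` is hypothetical, and this is not necessarily how `make_hippo_legs` in `s4_lib` is written).

```python
import numpy as np

def hippo_legs(N):
    """Build the N x N HiPPO-LegS matrix from the case formula above."""
    n = np.arange(N)
    scale = np.sqrt(2 * n + 1)
    # Strictly lower-triangular part: -sqrt((2n+1)(2k+1)) for n > k.
    A = np.tril(-np.outer(scale, scale), k=-1)
    # Diagonal: -(n+1).
    A -= np.diag(n + 1.0)
    return A

A_small = hippo_legs(4)
print(np.round(A_small, 3))

# A triangular matrix carries its eigenvalues on the diagonal: -1, -2, ..., -N.
print(np.diag(A_small))
```

Reading the eigenvalues off the diagonal avoids a numerical eigensolver entirely, which is handy when $N$ is large.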

Let's visualize it:

In [None]:
from s4_lib.hippo import make_hippo_legs, make_hippo_b, diagonal_init

N = 16  # small state dim for visualization
A = make_hippo_legs(N)
B = make_hippo_b(N, "legs")

print(f"HiPPO-LegS matrix A shape: {A.shape}")
print(f"A is lower-triangular: {np.allclose(A, np.tril(A))}")

eigs = np.linalg.eigvals(A)
print(f"\nEigenvalue real parts: [{eigs.real.min():.2f}, {eigs.real.max():.2f}]")
print(f"All stable (Re < 0): {(eigs.real < 0).all()}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

# Matrix heatmap
im = axes[0].imshow(A, cmap="RdBu_r", vmin=-abs(A).max(), vmax=abs(A).max())
axes[0].set_title("HiPPO-LegS Matrix (N=16)", fontsize=13)
axes[0].set_xlabel("Column k")
axes[0].set_ylabel("Row n")
plt.colorbar(im, ax=axes[0], shrink=0.8)

# Eigenvalue plot
axes[1].scatter(eigs.real, eigs.imag, c="steelblue", edgecolors="k", s=60, zorder=5)
axes[1].axvline(0, color="gray", linestyle="--", alpha=0.5)
axes[1].set_title("Eigenvalues of HiPPO-LegS", fontsize=13)
axes[1].set_xlabel("Real part")
axes[1].set_ylabel("Imaginary part")
axes[1].annotate("All eigenvalues in\nleft half-plane → stable",
                 xy=(-8, 0), fontsize=10, color="steelblue",
                 bbox=dict(boxstyle="round", fc="lightyellow", alpha=0.8))

plt.tight_layout()
plt.show()

**Key takeaway:** The HiPPO matrix gives SSMs a principled mechanism for capturing **long-range dependencies**, something that RNNs with generic random initialization famously struggle with and that Transformers address with quadratic-cost attention.

## 3. Discretization

The continuous SSM operates on continuous signals, but real-world data (text tokens, audio samples, image pixels) is **discrete**. We need to **discretize** the system.

Given a learnable step size $\Delta$, we convert $(A, B) \to (\bar{A}, \bar{B})$:

| Method | $\bar{A}$ | $\bar{B}$ | Used in |
|--------|-----------|-----------|--------|
| **Bilinear** (Tustin) | $(I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A)$ | $(I - \frac{\Delta}{2}A)^{-1} \Delta B$ | Original S4 |
| **Zero-Order Hold** | $e^{\Delta A}$ | $A^{-1}(e^{\Delta A} - I)B$ | S4D, Mamba |

The discrete system then becomes a standard recurrence:
$$x[k] = \bar{A}\, x[k-1] + \bar{B}\, u[k], \qquad y[k] = C\, x[k] + D\, u[k]$$
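
Both rows of the table can be written out for a small dense system. This is a sketch under stated assumptions: the toy matrices are hypothetical, and the ZOH helper assumes $A$ is diagonalizable and invertible so that the matrix exponential can be formed from an eigendecomposition (a general-purpose routine such as `scipy.linalg.expm` would be the more robust choice).

```python
import numpy as np

def bilinear_dense(A, B, dt):
    """Bilinear (Tustin) rule, using linear solves instead of an explicit inverse."""
    I = np.eye(A.shape[0])
    left = I - 0.5 * dt * A
    A_bar = np.linalg.solve(left, I + 0.5 * dt * A)
    B_bar = np.linalg.solve(left, dt * B)
    return A_bar, B_bar

def zoh_dense(A, B, dt):
    """Zero-order hold for a real, diagonalizable, invertible A."""
    w, V = np.linalg.eig(A)
    A_bar = (V @ np.diag(np.exp(dt * w)) @ np.linalg.inv(V)).real  # exp(dt * A)
    B_bar = np.linalg.solve(A, (A_bar - np.eye(A.shape[0])) @ B)   # A^{-1}(A_bar - I)B
    return A_bar, B_bar

# Hypothetical toy system with eigenvalues -1 and -2.
A = np.array([[-1.0, 0.0],
              [1.0, -2.0]])
B = np.array([[1.0], [0.0]])
dt = 0.01

A_bil, B_bil = bilinear_dense(A, B, dt)
A_zoh, B_zoh = zoh_dense(A, B, dt)

print("max |A_bil - A_zoh| =", np.abs(A_bil - A_zoh).max())
print("spectral radius (bilinear):", np.abs(np.linalg.eigvals(A_bil)).max())
```

For a small step size the two rules nearly coincide, since the bilinear map matches the exponential up to second order in $\Delta A$.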

The step size $\Delta$ is **learnable** — it controls the *resolution* at which the model "sees" the input (analogous to the receptive field in CNNs).
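
One way to see what "resolution" means here: under ZOH a single mode with eigenvalue $\lambda$ shrinks by $|e^{\Delta\lambda}| = e^{\Delta\,\mathrm{Re}\,\lambda}$ at every step, so an input fades to $1/e$ of its initial contribution after roughly $1/(\Delta\,|\mathrm{Re}\,\lambda|)$ steps. The helper below is a hypothetical illustration of that arithmetic, not part of the library.

```python
import math

def steps_to_decay(re_lambda, dt, factor=math.e):
    """Number of discrete steps for a ZOH-discretized mode to shrink by `factor`."""
    a_bar = math.exp(dt * re_lambda)               # |A_bar| for this mode
    return math.log(1.0 / factor) / math.log(a_bar)

for dt in (0.001, 0.01, 0.1):
    horizon = steps_to_decay(-1.0, dt)
    print(f"dt={dt:<6} -> memory horizon of about {horizon:.0f} steps")
```

A smaller $\Delta$ stretches the memory over more steps, while a larger $\Delta$ makes the model focus on recent inputs.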

In [None]:
from s4_lib.discretization import bilinear, zoh

# Example: discretize a 2D diagonal system
Lambda = torch.tensor([-1.0 + 2j, -3.0 - 1j], dtype=torch.cfloat)
B_ex = torch.tensor([1.0 + 0j, 1.0 + 0j], dtype=torch.cfloat)
dt = torch.tensor(0.01)

Ab_bil, Bb_bil = bilinear(Lambda, B_ex, dt)
Ab_zoh, Bb_zoh = zoh(Lambda, B_ex, dt)

print("Continuous eigenvalues:")
for i, l in enumerate(Lambda):
    print(f"  λ{i} = {l.item():.4f}  (|λ| = {abs(l):.4f})")

print(f"\nAfter discretization (Δ = {dt.item()}):\n")
print(f"  Bilinear |Ā|: {[f'{abs(a):.6f}' for a in Ab_bil.tolist()]}")
print(f"  ZOH      |Ā|: {[f'{abs(a):.6f}' for a in Ab_zoh.tolist()]}")
print(f"\n  All |Ā| < 1 (stable): Bilinear={all(abs(a)<1 for a in Ab_bil.tolist())}, "
      f"ZOH={all(abs(a)<1 for a in Ab_zoh.tolist())}")

## 4. CNN ↔ RNN Duality

The discrete SSM can be viewed in **two equivalent ways**:

### RNN view (sequential) — good for inference
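```python
x = np.zeros(N)                        # initial state
y = np.empty(L)
for k in range(L):
    x = A_bar @ x + B_bar * u[k]       # keep only the current state
    y[k] = C @ x                       # D * u[k] omitted for brevity
```
Process one step at a time: only the current state is stored, so memory is $O(N)$ regardless of the sequence length $L$.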

### CNN view (parallel) — good for training
Unroll the recurrence to get a **convolution kernel**:
$$K = (C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \ldots,\; C\bar{A}^{L-1}\bar{B})$$
$$y = u * K \qquad \text{(convolution, computed via FFT in } O(L \log L) \text{)}$$
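
Starting from a zero state, unrolling gives $y[k] = \sum_{j=0}^{k} C\bar{A}^{j}\bar{B}\,u[k-j]$, which is exactly a causal convolution with $K$. The sketch below builds $K$ the naive way, by repeated matrix-vector products, and checks it against the recurrence. All parameters are random placeholders, not values from the library.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 32

# Hypothetical discrete parameters; A_bar is a scaled orthogonal matrix so that
# every eigenvalue has modulus 0.9 and the recurrence stays stable.
A_bar = 0.9 * np.linalg.qr(rng.standard_normal((N, N)))[0]
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

# RNN view: step the recurrence from a zero initial state.
x = np.zeros((N, 1))
y_rnn = np.empty(L)
for k in range(L):
    x = A_bar @ x + B_bar * u[k]
    y_rnn[k] = (C @ x).item()

# CNN view: materialize K_j = C A_bar^j B_bar, costing O(N^2) per kernel entry.
K = np.empty(L)
v = B_bar.copy()
for j in range(L):
    K[j] = (C @ v).item()
    v = A_bar @ v

# Causal convolution via FFT; zero-pad to 2L so circular wrap-around is discarded.
y_cnn = np.fft.irfft(np.fft.rfft(u, 2 * L) * np.fft.rfft(K, 2 * L), 2 * L)[:L]

print("max |y_rnn - y_cnn| =", np.abs(y_rnn - y_cnn).max())
```

The kernel loop is where the $O(N^2 L)$ cost comes from: one $N \times N$ matrix-vector product for each of the $L$ kernel entries.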

**S4's key contribution:** compute $K$ efficiently using the DPLR structure of $A$ and Cauchy kernels, avoiding the $O(N^2 L)$ naive cost.
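
The DPLR/Cauchy algorithm itself is beyond the scope of this notebook, but the simpler diagonal case (the setting used by S4D) already shows why structure helps. When $A = \mathrm{diag}(\lambda)$, every power $\bar{A}^j$ is elementwise, so the kernel reduces to a Vandermonde product costing $O(NL)$. The sketch below assumes ZOH discretization and a made-up stable spectrum; the function name `kernel_diag` is hypothetical.

```python
import numpy as np

def kernel_diag(lam, B, C, dt, L):
    """Length-L kernel for A = diag(lam) under ZOH; returns the real part."""
    a_bar = np.exp(dt * lam)                      # (N,)
    b_bar = (a_bar - 1.0) / lam * B               # (N,)
    powers = a_bar[:, None] ** np.arange(L)       # Vandermonde matrix, (N, L)
    return ((C * b_bar) @ powers).real            # (L,)

rng = np.random.default_rng(0)
N, L, dt = 8, 64, 0.05
lam = -0.5 + 1j * np.pi * np.arange(N)            # hypothetical stable spectrum
B = np.ones(N, dtype=complex)
C = rng.standard_normal(N) + 1j * rng.standard_normal(N)

K = kernel_diag(lam, B, C, dt, L)

# Cross-check: the kernel is the impulse response of the diagonal recurrence.
a_bar = np.exp(dt * lam)
b_bar = (a_bar - 1.0) / lam * B
x = np.zeros(N, dtype=complex)
K_ref = np.empty(L)
for j in range(L):
    x = a_bar * x + b_bar * (1.0 if j == 0 else 0.0)
    K_ref[j] = (C * x).sum().real

print("max |K - K_ref| =", np.abs(K - K_ref).max())
```

No $N \times N$ matrix product appears anywhere, which is the saving that diagonal structure buys.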

Let's verify that both views produce **identical output**:

In [None]:
from s4_lib.hippo import s4d_inv_init
from s4_lib.kernels import kernel_diagonal, fft_conv

N_demo, L_demo = 8, 64
Lambda_d = torch.from_numpy(s4d_inv_init(N_demo))
B_d = torch.randn(N_demo, dtype=torch.cfloat) * 0.1
C_d = torch.randn(N_demo, dtype=torch.cfloat) * 0.1
dt_d = torch.tensor(0.05)

# ---- CNN mode: compute kernel, convolve ----
K = kernel_diagonal(Lambda_d, B_d, C_d, dt_d, L_demo, method="zoh")
u = torch.randn(1, L_demo)
y_cnn = fft_conv(u, K)

# ---- RNN mode: step-by-step ----
Ab_d = torch.exp(Lambda_d * dt_d)
Bb_d = B_d * (Ab_d - 1.0) / Lambda_d
state = torch.zeros(N_demo, dtype=torch.cfloat)
y_rnn = []
for k in range(L_demo):
    state = Ab_d * state + Bb_d * u[0, k]
    y_rnn.append((C_d * state).sum().real.item())
y_rnn = torch.tensor(y_rnn).unsqueeze(0)

error = (y_cnn - y_rnn).abs().max().item()
print(f"Max absolute difference between CNN and RNN outputs: {error:.2e}")
print(f"Equivalent: {error < 1e-4} {'✓' if error < 1e-4 else '✗'}")

In [None]:
fig, ax = plt.subplots(figsize=(10, 3.5))
ax.plot(y_cnn[0].detach().numpy(), label="CNN mode (parallel)", linewidth=2)
ax.plot(y_rnn[0].detach().numpy(), "--", label="RNN mode (sequential)", linewidth=2)
ax.set_xlabel("Time step")
ax.set_ylabel("Output")
ax.set_title("CNN and RNN modes produce identical output")
ax.legend()
plt.tight_layout()
plt.show()

## 5. SSMs vs Transformers / LLMs

Both S4/SSMs and Transformers are **sequence models**, but they take fundamentally different approaches:

| | **Transformer (GPT, BERT, etc.)** | **S4 / SSM** |
|---|---|---|
| **Core mechanism** | Self-attention: every token attends to every other token | State recurrence: compress history into a fixed-size state |
| **Training complexity** | $O(L^2 d)$ — quadratic in sequence length | $O(L \log L \cdot N)$ — near-linear via FFT |
| **Inference (per step)** | $O(L \cdot d)$ — must attend to full KV cache | $O(N \cdot d)$ — constant, just update state |
| **Memory at inference** | KV cache grows linearly with context | Fixed $O(N \cdot d)$ state |
| **Long sequences** | Struggles beyond training length (or needs tricks like RoPE, ALiBi) | Handles 10K–100K+ steps naturally |
| **Strengths** | In-context learning, flexible attention patterns | Long-range dependencies, continuous-time signals, efficiency |
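
The inference-memory row can be turned into rough arithmetic. This sketch counts stored elements only, under simplifying assumptions: standard attention that caches one key and one value vector of width $d$ per token per layer, and an SSM layer that keeps one $N$-dimensional state per feature channel. The model sizes are hypothetical.

```python
def kv_cache_elements(seq_len, d_model, n_layers):
    """Keys and values cached for every past token in every layer."""
    return 2 * seq_len * d_model * n_layers

def ssm_state_elements(state_dim, d_model, n_layers):
    """One N-dimensional state per feature channel per layer, independent of seq_len."""
    return state_dim * d_model * n_layers

d_model, n_layers, state_dim = 256, 4, 64   # hypothetical model sizes
for seq_len in (1_024, 16_384):
    kv = kv_cache_elements(seq_len, d_model, n_layers)
    ssm = ssm_state_elements(state_dim, d_model, n_layers)
    print(f"L={seq_len:>6}: KV cache {kv:,} elements, SSM state {ssm:,} elements")
```

The KV cache grows in proportion to the context length, while the SSM state stays the same size however long the sequence gets.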

### When to use each:

- **Transformers** excel at tasks where **flexible, content-dependent routing** matters (e.g., language understanding, in-context learning, retrieval).
- **SSMs/S4** excel at tasks with **very long sequences** or **continuous signals** (audio, time series, genomics, long-range arena benchmarks).

Modern architectures like **Mamba** (Gu & Dao, 2023) combine SSM efficiency with input-dependent *selectivity*, approaching Transformer quality on language while keeping linear scaling.

### Key numbers from the S4 paper:
- **91%** on sequential CIFAR-10 (pixel-by-pixel, length 1024) — matching 2D ResNets
- **86%** average on Long Range Arena (length 1K–16K) — all prior models < 60%
- **60× faster** generation than Transformers on language modeling

In [None]:
# Quick complexity comparison: asymptotic terms only, constant factors ignored
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
d = 256  # model dim
N = 64   # state dim

transformer_flops = [L * L * d for L in seq_lengths]       # O(L^2 d)
s4_flops          = [L * np.log2(L) * N for L in seq_lengths]  # O(L log L * N)

fig, ax = plt.subplots(figsize=(8, 4.5))
ax.semilogy(seq_lengths, transformer_flops, "o-", label="Transformer: O(L² d)", linewidth=2)
ax.semilogy(seq_lengths, s4_flops, "s-", label="S4: O(L log L · N)", linewidth=2)
ax.set_xlabel("Sequence length L")
ax.set_ylabel("Relative FLOPs (log scale)")
ax.set_title("Training Cost: Transformer vs S4")
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"At L={seq_lengths[-1]:,}: the Transformer estimate is {transformer_flops[-1]/s4_flops[-1]:.0f}× the S4 estimate (asymptotic terms only)")

---
## Summary

| Concept | Key idea |
|---------|----------|
| **SSM** | Linear dynamical system: $x' = Ax + Bu$, $y = Cx + Du$ |
| **HiPPO** | Special $A$ matrix that optimally compresses input history |
| **Discretization** | Convert continuous $(A,B) \to$ discrete $(\bar A, \bar B)$ with learnable step $\Delta$ |
| **CNN/RNN duality** | Train via FFT convolution (parallel), generate via recurrence (constant cost per step, independent of $L$) |
| **vs Transformers** | S4 training scales near-linearly with sequence length, $O(L \log L)$; Transformer attention scales quadratically |

**Next:** [02_s4_quickstart.ipynb](02_s4_quickstart.ipynb) — hands-on usage of the library.