# Tutorial 2 — S4 Quickstart

A hands-on introduction to using the **s4_lib** library. You'll learn:

1. **Import & Instantiate** — Create S4 and S4D layers / full models
2. **Forward Pass** — Process a batch of sequences in CNN (parallel) mode
3. **RNN Stepping** — Auto-regressive / streaming inference
4. **Optimizer Setup** — The special learning-rate treatment S4 requires
5. **Training Loop** — A minimal training step
6. **Parameter Comparison** — S4 vs a Transformer encoder of equivalent capacity

---

In [None]:
import sys, os
sys.path.insert(0, os.path.join(os.getcwd(), "..") if "tutorials" in os.getcwd() else os.getcwd())

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(42)

## 1. Import & Instantiate

The library provides four main building blocks:

| Class | Description |
|-------|-------------|
| `S4Layer` | Full DPLR S4 layer (uses Cauchy kernel) — one layer of the original S4 |
| `S4DLayer` | Diagonal S4 variant (simpler, nearly as good) |
| `S4DBlock` | S4D + normalization + GELU + residual + dropout — a "Transformer block" analogue |
| `S4SequenceModel` | Complete model: embedding → stacked S4DBlocks → head (classification or regression) |

In [None]:
from s4_lib import S4Layer, S4DLayer, S4DBlock, S4SequenceModel

# -- Single S4 layer (DPLR) --
s4_layer = S4Layer(d_model=64, state_dim=64, dt_min=0.001, dt_max=0.1)
print("S4Layer:")
print(f"  d_model=64, state_dim=64")
print(f"  params: {sum(p.numel() for p in s4_layer.parameters()):,}")

# -- Single S4D layer (diagonal) --
s4d_layer = S4DLayer(d_model=64, state_dim=64)
print(f"\nS4DLayer:")
print(f"  params: {sum(p.numel() for p in s4d_layer.parameters()):,}")

# -- Full model for classification --
model = S4SequenceModel(
    d_input=3,          # input features per time step
    d_model=64,         # hidden dimension
    d_state=64,         # SSM state dimension
    n_layers=4,         # number of S4D blocks
    d_output=10,        # classification head (num classes)
    dropout=0.1,
    task="classification",  # mean-pool over time → single vector → classifier
)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nS4SequenceModel (4 layers, d=64):")
print(f"  Total params: {total_params:,}")

## 2. Forward Pass (CNN Mode)

In **CNN mode**, the S4 layer materializes its convolution kernel of length $L$ and applies it via FFT. This is fully parallelizable — like running a 1-D convolution over the sequence.
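
The FFT step can be illustrated without the library. The sketch below is a generic NumPy example, not `s4_lib`'s implementation: `causal_conv_fft` and `causal_conv_direct` are hypothetical helper names, and the kernel is random rather than derived from an SSM. It zero-pads to length $2L$ so the circular convolution computed by the FFT does not wrap around, then checks the result against a direct $O(L^2)$ sum.

```python
import numpy as np

def causal_conv_fft(u, k):
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t-s], computed via FFT."""
    L = u.shape[-1]
    n = 2 * L  # zero-pad so the circular convolution does not wrap around
    U = np.fft.rfft(u, n=n)
    K = np.fft.rfft(k, n=n)
    return np.fft.irfft(U * K, n=n)[..., :L]  # keep the first L outputs

def causal_conv_direct(u, k):
    """Reference implementation: explicit O(L^2) double loop."""
    L = u.shape[-1]
    y = np.zeros(L)
    for t in range(L):
        for s in range(t + 1):
            y[t] += k[s] * u[t - s]
    return y

rng = np.random.default_rng(0)
L = 64
u = rng.standard_normal(L)   # input sequence
k = rng.standard_normal(L)   # stand-in for a materialized SSM kernel

y_fft = causal_conv_fft(u, k)
y_direct = causal_conv_direct(u, k)
max_err = float(np.abs(y_fft - y_direct).max())
print(f"max |fft - direct| = {max_err:.2e}")
```

The FFT route costs $O(L \log L)$ per channel instead of $O(L^2)$, which is why CNN mode scales well to long sequences.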

In [None]:
batch, seq_len, features = 8, 256, 3
x = torch.randn(batch, seq_len, features)

# Full model forward
model.eval()
with torch.no_grad():
    logits = model(x)  # (batch, n_classes)

print(f"Input:  {x.shape}  →  (batch, seq_len, features)")
print(f"Output: {logits.shape}  →  (batch, n_classes)")
print(f"\nPredicted classes: {logits.argmax(dim=-1).tolist()}")

In [None]:
# Individual layer forward — returns (batch, seq_len, d_model)
x_proj = nn.Linear(features, 64)(x)  # project to d_model
with torch.no_grad():
    y_s4d = s4d_layer(x_proj)
print(f"S4DLayer: {x_proj.shape} → {y_s4d.shape}")

## 3. RNN Stepping (Auto-regressive / Streaming Inference)

For generation or streaming applications, S4 can switch to **RNN mode** — processing one time step at a time with constant memory:

```
state = initial_state()          # shape: (batch, d_model, state_dim)
for t in range(seq_len):
    y_t, state = layer.step(u_t, state)  # O(1) per step
```

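The key advantage is **constant memory** regardless of sequence length: only the current state is carried between steps. A Transformer, by contrast, must keep a KV cache that grows linearly with the number of tokens processed.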
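
To see why the two modes agree, here is a minimal pure-Python sketch of a single-channel diagonal SSM. It is a simplified illustration under stated assumptions, not the library's code: the poles `Lam`, the vectors `B` and `C`, and the step size `dt` are made-up values, and the discretization used is zero-order hold. The recurrence carries only `N` complex numbers of state, while the convolution form uses the kernel $K_k = \mathrm{Re}\big(\sum_n C_n \bar{A}_n^{k} \bar{B}_n\big)$.

```python
import cmath
import math
import random

random.seed(0)
N, L, dt = 4, 16, 0.1  # state size, sequence length, step size (all illustrative)

# Hypothetical diagonal SSM: stable poles (negative real part), simple B, random C.
Lam = [complex(-0.5, math.pi * n) for n in range(N)]
B = [complex(1.0, 0.0) for _ in range(N)]
C = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(N)]

# Zero-order-hold discretization of x' = Lam * x + B * u.
A_bar = [cmath.exp(dt * lam) for lam in Lam]
B_bar = [(a - 1) / lam * b for a, lam, b in zip(A_bar, Lam, B)]

u = [random.gauss(0, 1) for _ in range(L)]

# RNN mode: one step at a time, carrying only N state values.
state = [0j] * N
y_rnn = []
for u_t in u:
    state = [a * x + b * u_t for a, x, b in zip(A_bar, state, B_bar)]
    y_rnn.append(sum(c * x for c, x in zip(C, state)).real)

# CNN mode: materialize the length-L kernel, then apply a causal convolution.
K = [sum(c * a**k * b for c, a, b in zip(C, A_bar, B_bar)).real for k in range(L)]
y_cnn = [sum(K[s] * u[t - s] for s in range(t + 1)) for t in range(L)]

max_diff = max(abs(a - b) for a, b in zip(y_rnn, y_cnn))
print(f"max |rnn - cnn| = {max_diff:.2e}")
```

Both loops compute the same linear map, so the difference is at the level of floating-point rounding. The comparison cell below performs the equivalent check on `S4DLayer`.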

In [None]:
# RNN stepping with S4DLayer
batch_rnn = 2
x_rnn = torch.randn(batch_rnn, 32, 64)  # (batch, time, d_model)

# Step-by-step (inference only, so skip building an autograd graph)
with torch.no_grad():
    state = s4d_layer.init_state(batch_rnn)  # zero state
    outputs_rnn = []
    for t in range(32):
        y_t, state = s4d_layer.step(x_rnn[:, t, :], state)
        outputs_rnn.append(y_t)
    y_rnn = torch.stack(outputs_rnn, dim=1)  # (batch, time, d_model)

# Compare with CNN forward
with torch.no_grad():
    y_cnn = s4d_layer(x_rnn)

diff = (y_rnn - y_cnn).abs().max().item()
print(f"RNN output shape: {y_rnn.shape}")
print(f"CNN output shape: {y_cnn.shape}")
print(f"Max difference:   {diff:.2e}  ({'✓ Match' if diff < 1e-3 else '✗ Mismatch'})")

## 4. Optimizer Setup

S4 has a quirk: the SSM parameters ($A$, $B$, $\Delta$) need a **different learning rate** than the rest of the model (typically lower, e.g., `0.001` for SSM params vs `0.004` for others).

The library provides `get_ssm_param_groups` to set this up automatically:

In [None]:
from s4_lib import get_ssm_param_groups

param_groups = get_ssm_param_groups(model, lr=0.004, ssm_lr=0.001, weight_decay=0.01)

for i, g in enumerate(param_groups):
    n_params = sum(p.numel() for p in g["params"])
    print(f"Group {i}: lr={g['lr']}, weight_decay={g['weight_decay']}, params={n_params:,}")

optimizer = torch.optim.AdamW(param_groups)
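
For intuition, a split like this can be built by hand with plain PyTorch parameter groups. The sketch below is an assumption-laden illustration and not how `get_ssm_param_groups` is implemented: `ToySSMLayer`, `split_param_groups`, and the parameter names `log_dt`, `A_real`, and `B` are all hypothetical. Setting weight decay to zero for the SSM group is a common convention in S4 implementations, but whether `s4_lib` does the same is not confirmed here.

```python
import torch
import torch.nn as nn

class ToySSMLayer(nn.Module):
    """Hypothetical stand-in for an SSM layer; parameter names are illustrative only."""
    def __init__(self, d_model=8, d_state=4):
        super().__init__()
        self.log_dt = nn.Parameter(torch.zeros(d_model))
        self.A_real = nn.Parameter(-torch.ones(d_model, d_state))
        self.B = nn.Parameter(torch.ones(d_model, d_state))
        self.out = nn.Linear(d_model, d_model)

def split_param_groups(model, lr, ssm_lr, weight_decay,
                       ssm_names=("log_dt", "A_real", "B")):
    """Put parameters whose leaf name is in ssm_names into their own group."""
    ssm, other = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        leaf = name.split(".")[-1]  # e.g. "0.A_real" -> "A_real"
        (ssm if leaf in ssm_names else other).append(p)
    return [
        {"params": other, "lr": lr, "weight_decay": weight_decay},
        # Assumption: SSM parameters get their own lr and no weight decay.
        {"params": ssm, "lr": ssm_lr, "weight_decay": 0.0},
    ]

toy = nn.Sequential(ToySSMLayer(), ToySSMLayer())
groups = split_param_groups(toy, lr=0.004, ssm_lr=0.001, weight_decay=0.01)
toy_optimizer = torch.optim.AdamW(groups)

for g in toy_optimizer.param_groups:
    n = sum(p.numel() for p in g["params"])
    print(f"lr={g['lr']}, weight_decay={g['weight_decay']}, params={n}")
```

Matching on parameter names is fragile, so a real implementation might instead tag the SSM parameters with an attribute when they are created and group on that.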

## 5. Minimal Training Step

Let's run a quick training step on random data to verify everything works end-to-end:

In [None]:
model.train()
criterion = nn.CrossEntropyLoss()

# Fake data
x_train = torch.randn(16, 256, 3)
y_train = torch.randint(0, 10, (16,))

losses = []
for step in range(20):
    optimizer.zero_grad()
    logits = model(x_train)
    loss = criterion(logits, y_train)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if step % 5 == 0:
        print(f"Step {step:2d}: loss = {loss.item():.4f}")

print(f"\nLoss decreased: {losses[0]:.4f} → {losses[-1]:.4f}  ({'✓' if losses[-1] < losses[0] else '✗'})")

In [None]:
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.plot(losses, "o-", markersize=4)
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.set_title("Training loss (random data — just verifying the pipeline works)")
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Parameter Comparison: S4 vs Transformer

A key selling point of S4 is **parameter efficiency**. Let's build a Transformer encoder with the same hidden dimension and number of layers, then compare total parameter counts on identical classification tasks.

We'll fix:
- `d_model = 64`, `n_layers = 4`, `n_classes = 10`, `d_input = 3`

For the Transformer, we use `nn.TransformerEncoder` with `nhead=4` and `dim_feedforward=256` (4× expansion, the standard ratio).

In [None]:
class TransformerClassifier(nn.Module):
    """Simple Transformer encoder + mean-pool + classifier."""
    def __init__(self, d_input, d_model, n_layers, n_classes, nhead=4, dim_ff=256, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(d_input, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.proj(x)
        x = self.encoder(x)
        x = x.mean(dim=1)  # pool over time
        return self.head(x)


# Instantiate both
s4_model = S4SequenceModel(d_input=3, d_model=64, d_state=64, n_layers=4,
                           d_output=10, dropout=0.1, task="classification")
tf_model = TransformerClassifier(d_input=3, d_model=64, n_layers=4,
                                 n_classes=10, nhead=4, dim_ff=256, dropout=0.1)

s4_params = sum(p.numel() for p in s4_model.parameters())
tf_params = sum(p.numel() for p in tf_model.parameters())

print("=" * 55)
print(f"{'Model':<25} {'Params':>12} {'Ratio':>10}")
print("=" * 55)
print(f"{'S4SequenceModel (4L)':<25} {s4_params:>12,}      1.0×")
print(f"{'TransformerEncoder (4L)':<25} {tf_params:>12,}    {tf_params/s4_params:>5.1f}×")
print("=" * 55)

In [None]:
# Breakdown by component
def param_breakdown(model, name):
    breakdown = {}
    for n, p in model.named_parameters():
        component = n.split(".")[0]
        breakdown[component] = breakdown.get(component, 0) + p.numel()
    print(f"\n{name} breakdown:")
    for k, v in sorted(breakdown.items()):
        print(f"  {k:<20} {v:>8,}")

param_breakdown(s4_model, "S4SequenceModel")
param_breakdown(tf_model, "TransformerClassifier")

In [None]:
# Scaling comparison: vary d_model
dims = [32, 64, 128, 256]
s4_counts, tf_counts = [], []

for d in dims:
    s4_m = S4SequenceModel(d_input=3, d_model=d, d_state=64, n_layers=4,
                           d_output=10, dropout=0.0, task="classification")
    tf_m = TransformerClassifier(d_input=3, d_model=d, n_layers=4,
                                n_classes=10, nhead=4, dim_ff=d*4, dropout=0.0)
    s4_counts.append(sum(p.numel() for p in s4_m.parameters()))
    tf_counts.append(sum(p.numel() for p in tf_m.parameters()))

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Absolute counts (grouped bars, in thousands)
x_pos = np.arange(len(dims))
w = 0.35
axes[0].bar(x_pos - w/2, [c/1000 for c in s4_counts], w, label="S4", color="steelblue")
axes[0].bar(x_pos + w/2, [c/1000 for c in tf_counts], w, label="Transformer", color="coral")
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels([str(d) for d in dims])
axes[0].set_xlabel("d_model")
axes[0].set_ylabel("Params (thousands)")
axes[0].set_title("Total Parameters")
axes[0].legend()

# Ratio
ratios = [t / s for s, t in zip(s4_counts, tf_counts)]
axes[1].plot(dims, ratios, "o-", color="purple", linewidth=2, markersize=8)
axes[1].axhline(1, color="gray", linestyle="--", alpha=0.5)
axes[1].set_xlabel("d_model")
axes[1].set_ylabel("Transformer / S4 param ratio")
axes[1].set_title("Parameter Ratio (>1 means Transformer is bigger)")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nParameter ratio (Transformer / S4) by d_model:")
for d, r in zip(dims, ratios):
    print(f"  d_model={d:<4}: {r:.2f}×")

### Key takeaways

- The Transformer's feedforward sub-layers (two dense matrices per layer, $d \to 4d \to d$) account for the bulk of its parameters.
- S4 layers are **much more parameter-efficient** because the SSM parameters ($\Lambda$, $B$, $C$, $\log\Delta$) scale with the state dimension $N$ (and at most linearly with $d_{\text{model}}$ when each channel has its own SSM), not with $d_{\text{model}}^2$.
- As `d_model` grows, the gap widens — making S4 increasingly attractive for large-scale settings.
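
The first takeaway can be checked with arithmetic. The sketch below counts the parameters of one `nn.TransformerEncoderLayer` from its standard PyTorch layout (a fused Q/K/V projection plus an output projection, two feedforward linears, two LayerNorms, all with biases). The helper name `encoder_layer_params` is hypothetical, and no comparable formula is given for the S4 side because the internal shapes used by `s4_lib` are not shown in this tutorial.

```python
def encoder_layer_params(d_model, dim_ff):
    """Parameter count of one nn.TransformerEncoderLayer, split by component."""
    attention = 4 * d_model**2 + 4 * d_model                # Q, K, V, out projections + biases
    feedforward = 2 * d_model * dim_ff + dim_ff + d_model   # two dense layers + biases
    layer_norms = 4 * d_model                               # two LayerNorms, weight + bias each
    return {"attention": attention, "feedforward": feedforward, "layer_norms": layer_norms}

counts = encoder_layer_params(d_model=64, dim_ff=256)
per_layer = sum(counts.values())
ff_share = counts["feedforward"] / per_layer

print(counts)
print(f"per layer: {per_layer:,}  feedforward share: {ff_share:.0%}")
# per layer: 49,984  feedforward share: 66%
```

With `dim_ff = 4 * d_model` the total simplifies to $12d^2 + 13d$ per layer, so the Transformer's parameter count grows quadratically in `d_model`, with roughly two thirds of it in the feedforward sub-layers.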

---

**Next:** [03_classification.ipynb](03_classification.ipynb) — train S4 and a Transformer on a real classification task and compare accuracy, speed, and parameters.