# Tutorial 3 — Classification: S4 vs Transformer

In this notebook we train **S4** and a **Transformer** baseline on the same synthetic classification task and compare:

- **Test accuracy**
- **Parameter count**
- **Training speed** (time per epoch)

## Task: Delayed-XOR

A sequence of length $L=256$ with 3 channels:

| Channel | Content |
|---------|--------|
| 0 | random bit (0/1) placed at position 0, rest is 0 |
| 1 | random bit (0/1) placed at position 128 (halfway), rest is 0 |
| 2 | Gaussian noise with standard deviation 0.1 (distractor) |

**Label** = XOR of the two bits = `ch0[0] ⊕ ch1[128]` → binary classification.

Why is this hard? The model must **remember** a signal from step 0 until step 128 — a 128-step long-range dependency. SSMs are built for this; Transformers rely on attention to bridge the gap.
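
To make the labelling rule concrete, the short sketch below enumerates the four possible bit pairs and the label each one produces. It uses plain Python only; `xor_label` is a hypothetical helper introduced for illustration and is not part of the notebook's code.

```python
def xor_label(bit_a: int, bit_b: int) -> int:
    """Label for one sequence: XOR of the bit at step 0 and the bit at step 128."""
    return bit_a ^ bit_b

# Enumerate all four (bit_a, bit_b) combinations.
truth_table = {(a, b): xor_label(a, b) for a in (0, 1) for b in (0, 1)}

for (a, b), label in truth_table.items():
    print(f"ch0[0]={a}  ch1[128]={b}  ->  label={label}")
```

Neither bit predicts the label on its own: for either value of bit A, the label is 0 for one value of bit B and 1 for the other, so the model has to combine information from both positions.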

---

In [None]:
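import sys, os, time
# Add the parent directory to the import path when the working directory path
# contains "tutorials"; otherwise add the current directory.
sys.path.insert(0, os.path.join(os.getcwd(), "..") if "tutorials" in os.getcwd() else os.getcwd())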

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(0)
np.random.seed(0)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

## 1. Generate the Delayed-XOR Dataset

In [None]:
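def make_delayed_xor(n_samples, seq_len=256, delay=128):
    """
    Build the Delayed-XOR dataset.

    Channel 0 carries bit A at step 0, channel 1 carries bit B at step
    `delay`, and channel 2 is Gaussian noise with standard deviation 0.1.

    Returns:
        X: (n_samples, seq_len, 3) float tensor
        y: (n_samples,) long tensor, 0 or 1 (bit A XOR bit B)
    """
    if not 0 <= delay < seq_len:
        raise ValueError(f"delay must be in [0, seq_len), got {delay}")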
    X = torch.zeros(n_samples, seq_len, 3)
    bits_a = torch.randint(0, 2, (n_samples,))
    bits_b = torch.randint(0, 2, (n_samples,))
    X[:, 0, 0] = bits_a.float()
    X[:, delay, 1] = bits_b.float()
    X[:, :, 2] = torch.randn(n_samples, seq_len) * 0.1  # noise distractor
    y = (bits_a ^ bits_b).long()
    return X, y

N_TRAIN, N_TEST = 4000, 1000
SEQ_LEN = 256

X_train, y_train = make_delayed_xor(N_TRAIN, SEQ_LEN)
X_test, y_test   = make_delayed_xor(N_TEST, SEQ_LEN)

print(f"Train: X={X_train.shape}, y={y_train.shape}, class balance: {y_train.float().mean():.2f}")
print(f"Test:  X={X_test.shape},  y={y_test.shape},  class balance: {y_test.float().mean():.2f}")

In [None]:
# Visualize one sample
idx = 0
fig, axes = plt.subplots(3, 1, figsize=(10, 4), sharex=True)
for ch, name in enumerate(["Bit A (pos 0)", "Bit B (pos 128)", "Noise"]):
    axes[ch].plot(X_train[idx, :, ch].numpy(), linewidth=1)
    axes[ch].set_ylabel(name, fontsize=9)
axes[-1].set_xlabel("Time step")
fig.suptitle(f"Sample {idx}: label = {y_train[idx].item()} (XOR)", fontsize=12)
plt.tight_layout()
plt.show()

## 2. DataLoaders

In [None]:
from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 64

train_ds = TensorDataset(X_train, y_train)
test_ds  = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl  = DataLoader(test_ds,  batch_size=BATCH_SIZE)

## 3. Define Models

### S4 model
Our `S4SequenceModel` with `task="classification"` mean-pools the layer outputs over time and passes the pooled vector through a linear head.

### Transformer baseline
A standard `nn.TransformerEncoder` with learned positional embeddings, followed by mean-pooling and a classifier head. We match `d_model` and `n_layers` for a fair comparison.

In [None]:
from s4_lib import S4SequenceModel, get_ssm_param_groups

D_MODEL = 64
N_LAYERS = 4
N_CLASSES = 2

# ---- S4 model ----
s4_model = S4SequenceModel(
    d_input=3, d_model=D_MODEL, d_output=N_CLASSES,
    n_layers=N_LAYERS, d_state=64, dropout=0.1,
    task="classification",
).to(DEVICE)

s4_params = sum(p.numel() for p in s4_model.parameters())
print(f"S4 model params: {s4_params:,}")

In [None]:
class TransformerClassifier(nn.Module):
    """Transformer encoder + positional embedding + mean-pool + classifier."""
    def __init__(self, d_input, d_model, n_layers, n_classes,
                 max_len=512, nhead=4, dim_ff=256, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(d_input, d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        B, L, _ = x.shape
        x = self.proj(x) + self.pos_emb[:, :L, :]
        x = self.encoder(x)
        x = x.mean(dim=1)
        return self.head(x)

tf_model = TransformerClassifier(
    d_input=3, d_model=D_MODEL, n_layers=N_LAYERS, n_classes=N_CLASSES,
    max_len=SEQ_LEN, nhead=4, dim_ff=D_MODEL*4, dropout=0.1,
).to(DEVICE)

tf_params = sum(p.numel() for p in tf_model.parameters())
print(f"Transformer model params: {tf_params:,}")
print(f"Ratio (Transformer / S4): {tf_params/s4_params:.2f}×")

## 4. Training

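We train both models for 30 epochs with the same optimizer settings: AdamW with a learning rate of 0.004 and weight decay of 0.01. S4 additionally uses a lower learning rate of 0.001 for its SSM parameters, assigned through `get_ssm_param_groups`, following the usual S4 recommendation of a smaller step size for the state-space parameters.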
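
The idea behind the dual learning rate can be sketched without the library. The helper below is a hypothetical stand-in, not the actual `get_ssm_param_groups` from `s4_lib`: it assumes SSM parameters can be recognised by a substring in their names (the `ssm_keys` default is made up for illustration) and keeps the same weight decay for both groups, which may differ from what the library does. It returns a list of dicts in the per-parameter-group format that `torch.optim` optimizers accept.

```python
def split_param_groups(named_params, lr, ssm_lr, weight_decay, ssm_keys=("kernel.",)):
    """Split (name, parameter) pairs into a regular group and an SSM group.

    Hypothetical stand-in for illustration; `ssm_keys` is an assumed naming
    convention, not taken from s4_lib.
    """
    regular, ssm = [], []
    for name, param in named_params:
        is_ssm = any(key in name for key in ssm_keys)
        (ssm if is_ssm else regular).append(param)
    return [
        {"params": regular, "lr": lr, "weight_decay": weight_decay},
        {"params": ssm, "lr": ssm_lr, "weight_decay": weight_decay},
    ]

# Plain strings stand in for tensors so the sketch runs without PyTorch.
fake_params = [
    ("encoder.weight", "W_enc"),
    ("layers.0.kernel.log_dt", "log_dt_0"),
    ("layers.0.kernel.A", "A_0"),
    ("head.weight", "W_head"),
]
groups = split_param_groups(fake_params, lr=0.004, ssm_lr=0.001, weight_decay=0.01)
for g in groups:
    print(g["lr"], g["params"])
```

With a real model, the pairs would come from `model.named_parameters()` and the returned list could be passed directly to `torch.optim.AdamW`.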

In [None]:
def train_one_epoch(model, optimizer, loader, criterion, device):
    model.train()
    total_loss, correct, n = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(-1) == yb).sum().item()
        n += xb.size(0)
    return total_loss / n, correct / n


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, n = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        total_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(-1) == yb).sum().item()
        n += xb.size(0)
    return total_loss / n, correct / n

In [None]:
N_EPOCHS = 30
criterion = nn.CrossEntropyLoss()

# S4 optimizer (dual LR)
s4_opt = torch.optim.AdamW(
    get_ssm_param_groups(s4_model, lr=0.004, ssm_lr=0.001, weight_decay=0.01)
)

# Transformer optimizer
tf_opt = torch.optim.AdamW(tf_model.parameters(), lr=0.004, weight_decay=0.01)

# ---- Train S4 ----
s4_hist = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": [], "epoch_time": []}
print("Training S4...")
for ep in range(N_EPOCHS):
    t0 = time.perf_counter()
    tr_loss, tr_acc = train_one_epoch(s4_model, s4_opt, train_dl, criterion, DEVICE)
    dt = time.perf_counter() - t0  # time the training pass only; evaluation is excluded
    te_loss, te_acc = evaluate(s4_model, test_dl, criterion, DEVICE)
    s4_hist["train_loss"].append(tr_loss)
    s4_hist["test_loss"].append(te_loss)
    s4_hist["train_acc"].append(tr_acc)
    s4_hist["test_acc"].append(te_acc)
    s4_hist["epoch_time"].append(dt)
    if (ep + 1) % 5 == 0:
        print(f"  [S4] Epoch {ep+1:2d}/{N_EPOCHS}: loss={tr_loss:.4f}  "
              f"acc={tr_acc:.3f}  test_acc={te_acc:.3f}  ({dt:.1f}s)")

# ---- Train Transformer ----
tf_hist = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": [], "epoch_time": []}
print("\nTraining Transformer...")
for ep in range(N_EPOCHS):
    t0 = time.perf_counter()
    tr_loss, tr_acc = train_one_epoch(tf_model, tf_opt, train_dl, criterion, DEVICE)
    dt = time.perf_counter() - t0  # time the training pass only; evaluation is excluded
    te_loss, te_acc = evaluate(tf_model, test_dl, criterion, DEVICE)
    tf_hist["train_loss"].append(tr_loss)
    tf_hist["test_loss"].append(te_loss)
    tf_hist["train_acc"].append(tr_acc)
    tf_hist["test_acc"].append(te_acc)
    tf_hist["epoch_time"].append(dt)
    if (ep + 1) % 5 == 0:
        print(f"  [TF] Epoch {ep+1:2d}/{N_EPOCHS}: loss={tr_loss:.4f}  "
              f"acc={tr_acc:.3f}  test_acc={te_acc:.3f}  ({dt:.1f}s)")

## 5. Results & Comparison

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(s4_hist["train_loss"], label="S4 train", linewidth=2)
axes[0].plot(s4_hist["test_loss"], "--", label="S4 test", linewidth=2)
axes[0].plot(tf_hist["train_loss"], label="Transformer train", linewidth=2)
axes[0].plot(tf_hist["test_loss"], "--", label="Transformer test", linewidth=2)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training & Test Loss")
axes[0].legend(fontsize=8)
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(s4_hist["test_acc"], "o-", label="S4", markersize=3, linewidth=2)
axes[1].plot(tf_hist["test_acc"], "s-", label="Transformer", markersize=3, linewidth=2)
axes[1].axhline(0.5, color="gray", linestyle=":", alpha=0.5, label="random")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Test Accuracy")
axes[1].set_title("Test Accuracy")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0.4, 1.05)

# Timing
labels = ["S4", "Transformer"]
avg_times = [np.mean(s4_hist["epoch_time"]), np.mean(tf_hist["epoch_time"])]
colors = ["steelblue", "coral"]
axes[2].bar(labels, avg_times, color=colors, edgecolor="k")
axes[2].set_ylabel("Avg seconds / epoch")
axes[2].set_title("Training Speed")
for i, v in enumerate(avg_times):
    axes[2].text(i, v + 0.01, f"{v:.2f}s", ha="center", fontsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Summary table
print("\n" + "=" * 60)
print(f"{'Metric':<25} {'S4':>15} {'Transformer':>15}")
print("=" * 60)
print(f"{'Parameters':<25} {s4_params:>15,} {tf_params:>15,}")
print(f"{'Best test accuracy':<25} {max(s4_hist['test_acc']):>15.3f} {max(tf_hist['test_acc']):>15.3f}")
print(f"{'Final test accuracy':<25} {s4_hist['test_acc'][-1]:>15.3f} {tf_hist['test_acc'][-1]:>15.3f}")
print(f"{'Avg epoch time (s)':<25} {np.mean(s4_hist['epoch_time']):>15.2f} {np.mean(tf_hist['epoch_time']):>15.2f}")
print(f"{'Final test loss':<25} {s4_hist['test_loss'][-1]:>15.4f} {tf_hist['test_loss'][-1]:>15.4f}")
print("=" * 60)

## 6. Discussion

### What to expect

**S4** should reach near-perfect accuracy quickly because:
- The HiPPO-initialized state matrix is specifically designed to retain information over long spans.
- The 128-step delay is short compared with the sequence lengths S4 was designed for, which run to thousands of steps.

**The Transformer** can also solve this task, but:
- It relies on **positional embeddings + attention** to bridge the 128-step gap.
- With only 4 layers and `d_model=64`, it may take longer to converge or plateau at a slightly lower accuracy.
- It is expected to use **more parameters** because of the attention weights and feedforward sub-layers; the ratio printed in Section 3 gives the exact figure for this configuration.
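
The Transformer's parameter count can be worked out by hand from the hyperparameters used in Section 3 (`d_model=64`, `dim_ff=256`, 4 layers, `max_len=256`, 3 input channels, 2 classes). The sketch below does that arithmetic in plain Python, assuming the default `nn.TransformerEncoderLayer` layout: a fused Q/K/V projection, an output projection, two feedforward linears, and two LayerNorms, all with biases. `transformer_param_count` is a hypothetical helper, and its result should match the printed `tf_params` only if those defaults hold in your PyTorch version.

```python
def transformer_param_count(d_input, d_model, n_layers, n_classes, max_len, dim_ff):
    """Count trainable parameters of the TransformerClassifier by hand."""
    # Self-attention: fused Q/K/V projection plus output projection, with biases.
    attn = (3 * d_model * d_model + 3 * d_model) + (d_model * d_model + d_model)
    # Feedforward: d_model -> dim_ff -> d_model, with biases.
    ff = (d_model * dim_ff + dim_ff) + (dim_ff * d_model + d_model)
    # Two LayerNorms per layer, each with a weight and a bias vector.
    norms = 2 * (2 * d_model)
    per_layer = attn + ff + norms

    proj = d_input * d_model + d_model        # input projection
    pos_emb = max_len * d_model               # learned positional embedding
    head = d_model * n_classes + n_classes    # classifier head
    return {
        "per_layer": per_layer,
        "encoder": n_layers * per_layer,
        "total": n_layers * per_layer + proj + pos_emb + head,
    }

counts = transformer_param_count(
    d_input=3, d_model=64, n_layers=4, n_classes=2, max_len=256, dim_ff=256
)
print(counts)
```

Under these assumptions the four encoder layers account for 199,936 of the 216,706 parameters, and within each layer the feedforward block holds roughly twice as many weights as attention.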

### Key takeaways

1. **Long-range tasks favor SSMs.** S4's inductive bias (HiPPO memory) gives it a structural advantage on tasks requiring information retention across many steps.
2. **Parameter efficiency.** S4 achieves similar or better accuracy with fewer parameters.
3. **Training cost.** S4 uses O(L log L) FFT convolutions; the Transformer uses O(L²) attention. The gap widens with sequence length.
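
To get a feel for how the two cost models diverge, the sketch below compares the leading terms $L \log_2 L$ and $L^2$ for a few sequence lengths. These are asymptotic operation counts with constants and the model dimension ignored, so they are not a prediction of wall-clock time; at $L=256$ the measured epoch times may well be close.

```python
import math

def cost_ratio(seq_len: int) -> float:
    """Ratio of the L^2 term (attention) to the L*log2(L) term (FFT convolution)."""
    return (seq_len ** 2) / (seq_len * math.log2(seq_len))

for L in (256, 1024, 4096, 16384):
    print(f"L={L:>6}: L^2 / (L log2 L) = {cost_ratio(L):.1f}")
```

The ratio is 32 at $L=256$ and grows to roughly 1170 at $L=16384$, which is the sense in which the gap widens with sequence length.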

### When would the Transformer win?

- Tasks requiring **content-based retrieval** (e.g., "find the token that matches this query") — attention is explicitly designed for this.
- Very **short** sequences where the O(L²) cost is negligible.
- Tasks requiring **in-context learning** (e.g., few-shot prompting in LLMs).
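
Content-based retrieval is easy to see in a toy version of scaled dot-product attention: the query is compared with every key, and the softmax weights pick out the best match regardless of where it sits in the sequence. The sketch below is plain Python with hand-picked vectors, written for illustration only; it is not the attention code inside `nn.TransformerEncoder`.

```python
import math

def attention_weights(query, keys):
    """Softmax over scaled dot products between one query and a list of keys."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Three stored keys; the query is closest to the key at position 2.
keys = [[4.0, 0.0], [0.0, 4.0], [-4.0, 0.0]]
query = [-4.0, 0.0]

weights = attention_weights(query, keys)
best = max(range(len(keys)), key=lambda i: weights[i])
print(f"weights={[round(w, 3) for w in weights]}, best match at position {best}")
```

Reordering `keys` moves the match but not the weight it receives, which is why the Transformer needs positional embeddings to know where a token sits in the sequence.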

---

**Next:** [04_time_series.ipynb](04_time_series.ipynb) — regression on time-series data, again comparing S4 vs Transformer.