# Tutorial 4 — Time Series Forecasting: S4 vs Transformer

In this notebook we tackle a **regression** (forecasting) task and compare **S4** against a **Transformer** baseline.

## Task: Multi-variate Sinusoid Forecasting

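We generate 3-channel sinusoidal signals with random amplitudes, frequencies, and phases, plus additive Gaussian noise:

$$x_c(t) = A_c \sin(\omega_c t + \phi_c) + \epsilon_c(t), \quad c \in \{0, 1, 2\}$$

Here $t$ takes $L+1$ evenly spaced values in $[0, 4\pi]$. For every sample and channel, $A_c \sim \mathcal{U}(0.5, 1.5)$, $\omega_c \sim \mathcal{U}(0.5, 3.0)$, and $\phi_c \sim \mathcal{U}(0, 2\pi)$, while $\epsilon_c(t) \sim \mathcal{N}(0, 0.05^2)$ is drawn independently at each step.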

- **Input:** steps $0 \ldots L{-}1$ (the past)
- **Target:** steps $1 \ldots L$ (one-step-ahead prediction)

This tests whether each model can learn the **temporal dynamics** of continuous signals. That is a natural fit for SSMs, whereas the Transformer has no built-in notion of time beyond its learned positional embeddings and must learn the dynamics from data.
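Concretely, each input/target pair consists of two overlapping slices of a single series of $L+1$ samples, exactly as `make_sinusoid_dataset` builds them below. A toy sketch, with a ramp standing in for one channel (`series` is illustrative only):

```python
import numpy as np

L = 5
series = np.arange(L + 1, dtype=np.float64)  # toy stand-in for one channel: values 0, 1, ..., L

x_in = series[:-1]  # input: steps 0 .. L-1
y_tgt = series[1:]  # target: steps 1 .. L

# At every position t, the target is the input's next value.
for t in range(L):
    print(f"t={t}: input={x_in[t]:.0f} -> target={y_tgt[t]:.0f}")
```

Because the two slices overlap, the target at step $t$ reappears as the input at step $t+1$. A model that could see step $t+1$ of the input would simply copy it, which is why the Transformer below needs a causal mask.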

We compare:
- **Test MSE**
- **Parameter count**
- **Training speed**
- **Qualitative predictions** (overlay plots)

---

In [None]:
import sys, os, time
sys.path.insert(0, os.path.join(os.getcwd(), "..") if "tutorials" in os.getcwd() else os.getcwd())

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(42)
np.random.seed(42)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

## 1. Generate the Dataset

In [None]:
def make_sinusoid_dataset(n_samples, seq_len=200, n_features=3, noise_std=0.05):
    """
    Generate multi-channel sinusoidal time series.

    Returns:
        X: (n_samples, seq_len, n_features)  — input (steps 0..L-1)
        Y: (n_samples, seq_len, n_features)  — target (steps 1..L)
    """
    t = np.linspace(0, 4 * np.pi, seq_len + 1)  # one extra step for target
    data = np.zeros((n_samples, seq_len + 1, n_features))
    for i in range(n_samples):
        for c in range(n_features):
            freq = np.random.uniform(0.5, 3.0)
            phase = np.random.uniform(0, 2 * np.pi)
            amp = np.random.uniform(0.5, 1.5)
            data[i, :, c] = amp * np.sin(freq * t + phase)
    data += np.random.randn(*data.shape) * noise_std

    X = torch.tensor(data[:, :-1, :], dtype=torch.float32)
    Y = torch.tensor(data[:, 1:, :],  dtype=torch.float32)
    return X, Y

N_TRAIN, N_TEST = 3000, 500
SEQ_LEN = 200
N_FEATURES = 3

X_train, Y_train = make_sinusoid_dataset(N_TRAIN, SEQ_LEN, N_FEATURES)
X_test,  Y_test  = make_sinusoid_dataset(N_TEST,  SEQ_LEN, N_FEATURES)

print(f"Train: X={X_train.shape}, Y={Y_train.shape}")
print(f"Test:  X={X_test.shape},  Y={Y_test.shape}")
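The double loop in `make_sinusoid_dataset` maps directly onto the formula above, and it can also be written with NumPy broadcasting. The following is a sketch of a hypothetical helper (`make_sinusoid_dataset_vectorized` is not part of the notebook). It uses the same parameter ranges, but draws from its own `np.random.default_rng` stream, so it will not reproduce the loop version's exact samples. It returns NumPy arrays; wrap them with `torch.tensor(..., dtype=torch.float32)` to match the original.

```python
import numpy as np

def make_sinusoid_dataset_vectorized(n_samples, seq_len=200, n_features=3,
                                     noise_std=0.05, seed=None):
    """Hypothetical broadcasting variant of make_sinusoid_dataset.

    Returns NumPy arrays X, Y of shape (n_samples, seq_len, n_features).
    """
    rng = np.random.default_rng(seed)
    t = np.linspace(0, 4 * np.pi, seq_len + 1)  # (L+1,), one extra step for the target
    # One (freq, phase, amp) triple per sample and channel, same ranges as the loop version.
    freq = rng.uniform(0.5, 3.0, size=(n_samples, 1, n_features))
    phase = rng.uniform(0, 2 * np.pi, size=(n_samples, 1, n_features))
    amp = rng.uniform(0.5, 1.5, size=(n_samples, 1, n_features))
    # Broadcast: (n, 1, c) * (1, L+1, 1) -> (n, L+1, c)
    data = amp * np.sin(freq * t[None, :, None] + phase)
    data = data + rng.normal(0.0, noise_std, size=data.shape)  # i.i.d. Gaussian noise
    return data[:, :-1, :], data[:, 1:, :]

X_np, Y_np = make_sinusoid_dataset_vectorized(8, seq_len=50, n_features=3, seed=0)
print(X_np.shape, Y_np.shape)
```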

In [None]:
# Visualize a sample
idx = 0
fig, axes = plt.subplots(N_FEATURES, 1, figsize=(10, 5), sharex=True)
for c in range(N_FEATURES):
    axes[c].plot(X_train[idx, :, c].numpy(), label="Input", linewidth=1.5)
    axes[c].plot(Y_train[idx, :, c].numpy(), "--", label="Target (shifted by 1)", linewidth=1.5, alpha=0.7)
    axes[c].set_ylabel(f"Channel {c}")
    if c == 0:
        axes[c].legend(fontsize=9)
axes[-1].set_xlabel("Time step")
fig.suptitle("Sample 0: input and one-step-ahead target", fontsize=12)
plt.tight_layout()
plt.show()

## 2. DataLoaders

In [None]:
from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 64

train_dl = DataLoader(TensorDataset(X_train, Y_train), batch_size=BATCH_SIZE, shuffle=True)
test_dl  = DataLoader(TensorDataset(X_test, Y_test),   batch_size=BATCH_SIZE)

## 3. Define Models

### S4 model
We use `S4SequenceModel` with `task="regression"`, which outputs a prediction at every time step instead of pooling over the sequence.
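The internals of `s4_lib` are not shown here, so the following is only a shape-level sketch of what "no pooling" means. The random `features` stand in for the backbone's output, and the head weights `W`, `b` are hypothetical. A per-step head applies the same linear map at every time step, whereas a classification-style head first averages over time.

```python
import numpy as np

rng = np.random.default_rng(0)
B, L, H, d_out = 2, 200, 64, 3

features = rng.standard_normal((B, L, H))  # stand-in for the backbone output, shape (B, L, H)
W = rng.standard_normal((H, d_out)) * 0.1  # hypothetical linear head weights
b = np.zeros(d_out)

# Regression: the same head at every time step -> one prediction per step.
per_step = features @ W + b                # (B, L, d_out)

# Classification-style alternative: pool over time first -> one prediction per sequence.
pooled = features.mean(axis=1) @ W + b     # (B, d_out)

print(per_step.shape, pooled.shape)  # (2, 200, 3) (2, 3)
```

Since the head is linear, averaging the per-step predictions over time gives the pooled prediction. The regression setup keeps all $L$ of them because every step has its own target.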

### Transformer baseline
A Transformer encoder with learned positional embeddings, a causal attention mask (so step $t$ cannot attend to future steps), and a per-step linear output head. It uses the same `d_model` and `n_layers` as the S4 model.
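To see what the causal mask does, here is a minimal single-head attention sketch in NumPy (an illustration, not the `nn.TransformerEncoder` implementation). The additive mask built by `causal_mask` follows the same pattern that `nn.Transformer.generate_square_subsequent_mask` returns: $0$ on and below the diagonal, $-\infty$ above it. The check perturbs one input step and confirms that earlier outputs do not change. The random weights `Wq`, `Wk`, `Wv` are placeholders.

```python
import numpy as np

def causal_mask(L):
    """Additive mask: 0 on/below the diagonal, -inf above (step t cannot see t+1..L-1)."""
    return np.triu(np.full((L, L), -np.inf), k=1)

def causal_self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product attention with a causal mask; x has shape (L, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1]) + causal_mask(len(x))
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                      # exp(-inf) = 0 for masked (future) entries
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
L, d = 6, 4
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
x = rng.standard_normal((L, d))
out = causal_self_attention(x, Wq, Wk, Wv)

x_perturbed = x.copy()
x_perturbed[4] += 10.0  # change the input at step 4 only
out_perturbed = causal_self_attention(x_perturbed, Wq, Wk, Wv)

print(np.allclose(out[:4], out_perturbed[:4]))  # True: steps 0..3 never see step 4
print(np.allclose(out[4:], out_perturbed[4:]))  # False: steps 4 and 5 do
```

The same perturbation test can be run on any forecaster in `eval()` mode to catch accidental future leakage.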

In [None]:
from s4_lib import S4SequenceModel, get_ssm_param_groups

D_MODEL = 64
N_LAYERS = 4

# ---- S4 model ----
s4_model = S4SequenceModel(
    d_input=N_FEATURES, d_model=D_MODEL, d_output=N_FEATURES,
    n_layers=N_LAYERS, d_state=64, dropout=0.1,
    task="regression",
).to(DEVICE)

s4_params = sum(p.numel() for p in s4_model.parameters())
print(f"S4 model params: {s4_params:,}")

In [None]:
class TransformerForecaster(nn.Module):
    """Causal Transformer encoder for per-step regression."""
    def __init__(self, d_input, d_model, n_layers, d_output,
                 max_len=512, nhead=4, dim_ff=256, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(d_input, d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, d_output)

    def forward(self, x):
        B, L, _ = x.shape
        x = self.proj(x) + self.pos_emb[:, :L, :]
        # Causal mask — prevent attending to future steps
        mask = nn.Transformer.generate_square_subsequent_mask(L, device=x.device)
        x = self.encoder(x, mask=mask)
        return self.head(x)  # (B, L, d_output)

tf_model = TransformerForecaster(
    d_input=N_FEATURES, d_model=D_MODEL, n_layers=N_LAYERS, d_output=N_FEATURES,
    max_len=SEQ_LEN, nhead=4, dim_ff=D_MODEL*4, dropout=0.1,
).to(DEVICE)

tf_params = sum(p.numel() for p in tf_model.parameters())
print(f"Transformer model params: {tf_params:,}")
print(f"Ratio (Transformer / S4): {tf_params/s4_params:.2f}×")

## 4. Training

In [None]:
def train_one_epoch(model, optimizer, loader, criterion, device):
    model.train()
    total_loss, n = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
        n += xb.size(0)
    return total_loss / n


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, n = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb)
        loss = criterion(preds, yb)
        total_loss += loss.item() * xb.size(0)
        n += xb.size(0)
    return total_loss / n

In [None]:
N_EPOCHS = 30
criterion = nn.MSELoss()

# S4 optimizer: dual LR (ssm_lr for the SSM parameters, lr for all other weights)
s4_opt = torch.optim.AdamW(
    get_ssm_param_groups(s4_model, lr=0.004, ssm_lr=0.001, weight_decay=0.01)
)

# Transformer optimizer
tf_opt = torch.optim.AdamW(tf_model.parameters(), lr=0.004, weight_decay=0.01)

# ---- Train S4 ----
s4_hist = {"train_mse": [], "test_mse": [], "epoch_time": []}
print("Training S4...")
for ep in range(N_EPOCHS):
    t0 = time.time()
    tr_mse = train_one_epoch(s4_model, s4_opt, train_dl, criterion, DEVICE)
    te_mse = evaluate(s4_model, test_dl, criterion, DEVICE)
    dt = time.time() - t0
    s4_hist["train_mse"].append(tr_mse)
    s4_hist["test_mse"].append(te_mse)
    s4_hist["epoch_time"].append(dt)
    if (ep + 1) % 5 == 0:
        print(f"  [S4] Epoch {ep+1:2d}/{N_EPOCHS}: train_MSE={tr_mse:.6f}  "
              f"test_MSE={te_mse:.6f}  ({dt:.1f}s)")

# ---- Train Transformer ----
tf_hist = {"train_mse": [], "test_mse": [], "epoch_time": []}
print("\nTraining Transformer...")
for ep in range(N_EPOCHS):
    t0 = time.time()
    tr_mse = train_one_epoch(tf_model, tf_opt, train_dl, criterion, DEVICE)
    te_mse = evaluate(tf_model, test_dl, criterion, DEVICE)
    dt = time.time() - t0
    tf_hist["train_mse"].append(tr_mse)
    tf_hist["test_mse"].append(te_mse)
    tf_hist["epoch_time"].append(dt)
    if (ep + 1) % 5 == 0:
        print(f"  [TF] Epoch {ep+1:2d}/{N_EPOCHS}: train_MSE={tr_mse:.6f}  "
              f"test_MSE={te_mse:.6f}  ({dt:.1f}s)")
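`get_ssm_param_groups` comes from `s4_lib`, and its source is not shown here. A common S4 practice, which we assume it follows, is to give the SSM parameters (such as the step size $\Delta$ and $A$) a lower learning rate and no weight decay. The following is a hypothetical sketch of that idea using PyTorch's per-parameter-group options. The helper name `split_param_groups`, the `ssm_names` list, and `ToySSMLayer` are all assumptions for illustration.

```python
import torch
import torch.nn as nn

def split_param_groups(model, lr, ssm_lr, weight_decay,
                       ssm_names=("log_dt", "A_real", "A_imag")):
    """Hypothetical sketch: two AdamW param groups, one for SSM parameters.

    Parameters whose final attribute name is in `ssm_names` get `ssm_lr` and no
    weight decay; all other trainable parameters get `lr` and `weight_decay`.
    """
    ssm, other = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (ssm if name.split(".")[-1] in ssm_names else other).append(p)
    groups = [
        {"params": other, "lr": lr, "weight_decay": weight_decay},
        {"params": ssm, "lr": ssm_lr, "weight_decay": 0.0},
    ]
    return [g for g in groups if g["params"]]  # drop empty groups

class ToySSMLayer(nn.Module):
    """Hypothetical stand-in with SSM-style parameter names."""
    def __init__(self, d_model=8, d_state=4):
        super().__init__()
        self.log_dt = nn.Parameter(torch.zeros(d_model))                 # per-channel log step size
        self.A_real = nn.Parameter(-0.5 * torch.ones(d_model, d_state))  # diagonal A (real part)
        self.out = nn.Linear(d_model, d_model)

toy = ToySSMLayer()
opt = torch.optim.AdamW(split_param_groups(toy, lr=0.004, ssm_lr=0.001, weight_decay=0.01))
for g in opt.param_groups:
    print(g["lr"], g["weight_decay"], sum(p.numel() for p in g["params"]))
# 0.004 0.01 72
# 0.001 0.0 40
```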

## 5. Results & Comparison

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# MSE curves
axes[0].semilogy(s4_hist["train_mse"], label="S4 train", linewidth=2)
axes[0].semilogy(s4_hist["test_mse"], "--", label="S4 test", linewidth=2)
axes[0].semilogy(tf_hist["train_mse"], label="TF train", linewidth=2)
axes[0].semilogy(tf_hist["test_mse"], "--", label="TF test", linewidth=2)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("MSE (log)")
axes[0].set_title("Training & Test MSE")
axes[0].legend(fontsize=8)
axes[0].grid(True, alpha=0.3)

# Final test MSE bar
labels = ["S4", "Transformer"]
final_mse = [s4_hist["test_mse"][-1], tf_hist["test_mse"][-1]]
colors = ["steelblue", "coral"]
axes[1].bar(labels, final_mse, color=colors, edgecolor="k")
axes[1].set_ylabel("Test MSE")
axes[1].set_title("Final Test MSE")
for i, v in enumerate(final_mse):
    axes[1].text(i, v + 0.0001, f"{v:.5f}", ha="center", fontsize=10)

# Timing
avg_times = [np.mean(s4_hist["epoch_time"]), np.mean(tf_hist["epoch_time"])]
axes[2].bar(labels, avg_times, color=colors, edgecolor="k")
axes[2].set_ylabel("Avg seconds / epoch (train + eval)")
axes[2].set_title("Epoch time (train + eval)")
for i, v in enumerate(avg_times):
    axes[2].text(i, v + 0.01, f"{v:.2f}s", ha="center", fontsize=11)

plt.tight_layout()
plt.show()

## 6. Qualitative Predictions

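Let's overlay each model's predictions on a few test samples to see how closely it tracks the signal. These are teacher-forced one-step-ahead predictions: at every step the model sees the true (noisy) history, so the plots do not show free-running multi-step forecasts.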

In [None]:
s4_model.eval()
tf_model.eval()

with torch.no_grad():
    s4_preds = s4_model(X_test[:4].to(DEVICE)).cpu()
    tf_preds = tf_model(X_test[:4].to(DEVICE)).cpu()

fig, axes = plt.subplots(4, N_FEATURES, figsize=(14, 10), sharex=True)
for row in range(4):
    for c in range(N_FEATURES):
        ax = axes[row, c]
        ax.plot(Y_test[row, :, c].numpy(), "k-", label="Ground truth", linewidth=1.5, alpha=0.6)
        ax.plot(s4_preds[row, :, c].numpy(), "-", label="S4", linewidth=1.5)
        ax.plot(tf_preds[row, :, c].numpy(), "--", label="Transformer", linewidth=1.5)
        if row == 0:
            ax.set_title(f"Channel {c}")
        if c == 0:
            ax.set_ylabel(f"Sample {row}")
        if row == 0 and c == 0:
            ax.legend(fontsize=7)
axes[-1, 1].set_xlabel("Time step")
fig.suptitle("Predictions vs Ground Truth (4 test samples)", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Summary table
print("\n" + "=" * 60)
print(f"{'Metric':<25} {'S4':>15} {'Transformer':>15}")
print("=" * 60)
print(f"{'Parameters':<25} {s4_params:>15,} {tf_params:>15,}")
print(f"{'Best test MSE':<25} {min(s4_hist['test_mse']):>15.6f} {min(tf_hist['test_mse']):>15.6f}")
print(f"{'Final test MSE':<25} {s4_hist['test_mse'][-1]:>15.6f} {tf_hist['test_mse'][-1]:>15.6f}")
print(f"{'Avg epoch time (s)':<25} {np.mean(s4_hist['epoch_time']):>15.2f} {np.mean(tf_hist['epoch_time']):>15.2f}")
print("=" * 60)
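The MSE values in the table are easier to interpret against two reference points. First, the target at step $t$ is $x(t+1)$, including its noise term, and a causal model only sees inputs up to $t$. So the target noise is unpredictable, and no causal model can beat an expected MSE of $\sigma^2 = 0.05^2 = 0.0025$ (with `noise_std = 0.05`). Second, the naive persistence forecast ("the next value equals the current one") gives a trivial baseline. The sketch below computes both on freshly generated data that follow the same recipe. These are new random draws, not the notebook's `X_test`; in the notebook itself you could call `persistence_mse(X_test.numpy(), Y_test.numpy())` directly.

```python
import numpy as np

def persistence_mse(X, Y):
    """MSE of the naive forecast 'next value = current value' (prediction = input)."""
    return float(np.mean((Y - X) ** 2))

noise_std = 0.05
noise_floor = noise_std ** 2  # irreducible error for a causal one-step-ahead predictor

# Regenerate data with the same recipe as make_sinusoid_dataset (fresh random draws).
rng = np.random.default_rng(0)
n, L, C = 500, 200, 3
t = np.linspace(0, 4 * np.pi, L + 1)
freq = rng.uniform(0.5, 3.0, size=(n, 1, C))
phase = rng.uniform(0, 2 * np.pi, size=(n, 1, C))
amp = rng.uniform(0.5, 1.5, size=(n, 1, C))
data = amp * np.sin(freq * t[None, :, None] + phase) + rng.normal(0, noise_std, size=(n, L + 1, C))
X, Y = data[:, :-1], data[:, 1:]

print(f"noise floor       : {noise_floor:.4f}")
print(f"persistence (Y~X) : {persistence_mse(X, Y):.4f}")
```

A model that learned the dynamics should land well below the persistence baseline and approach, but not cross, the noise floor.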

## 7. Discussion

### Why SSMs are a natural fit for time series

Many time series, including the sinusoids here, are samples of an underlying **continuous-time, sequential** process. State space models were originally developed in control theory for exactly this kind of signal:

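1. **Continuous-time formulation:** The SSM is defined by the linear ODE $x'(t) = Ax(t) + Bu(t)$, $y(t) = Cx(t) + Du(t)$, where $x$ is the hidden state (not the input signal $x_c(t)$ from the task definition) and $u$ is the input. The learnable step size $\Delta$ controls how this ODE is discretized, and therefore the time scale each channel responds to.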

2. **Causal by construction:** A unidirectional SSM updates its state left to right, so the output at step $t$ depends only on inputs up to $t$. No attention mask is needed.

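3. **Efficient step-by-step generation:** In recurrent mode, S4 produces each new step with compute and memory that are constant in the sequence length (they depend only on the state size and model width). For training, S4 typically uses the equivalent convolutional mode, which processes the whole sequence in parallel. Note that this notebook only evaluates teacher-forced one-step-ahead prediction, not autoregressive rollouts.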
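These three points can be made concrete with a toy single-input, single-output SSM in NumPy. The sketch assumes a real, diagonal, stable $A$ with random $B$ and $C$ (S4 itself uses a structured, HiPPO-derived $A$ and complex parameters). It discretizes with zero-order hold, $\bar A = e^{\Delta A}$ and $\bar B = A^{-1}(e^{\Delta A} - I)B$. The original S4 paper uses the bilinear transform instead; ZOH is a common alternative. The same discrete system is then run in recurrent mode and in convolutional mode, with kernel $\bar K_k = C \bar A^k \bar B$.

```python
import numpy as np

def discretize_zoh(A_diag, B, dt):
    """Zero-order-hold discretization of x' = A x + B u for a diagonal A."""
    A_bar = np.exp(dt * A_diag)
    B_bar = (A_bar - 1.0) / A_diag * B
    return A_bar, B_bar

def ssm_recurrent(A_bar, B_bar, C, u):
    """RNN mode: x_k = A_bar * x_{k-1} + B_bar * u_k, y_k = C . x_k (O(N) work per step)."""
    x = np.zeros_like(A_bar)
    ys = []
    for u_k in u:
        x = A_bar * x + B_bar * u_k
        ys.append(C @ x)
    return np.array(ys)

def ssm_convolutional(A_bar, B_bar, C, u):
    """Convolution mode: y_k = sum_j K_j u_{k-j} with causal kernel K_j = C . (A_bar**j * B_bar)."""
    L = len(u)
    powers = A_bar[None, :] ** np.arange(L)[:, None]  # (L, N)
    K = (powers * B_bar) @ C                           # (L,)
    return np.array([K[: k + 1][::-1] @ u[: k + 1] for k in range(L)])

rng = np.random.default_rng(0)
N, L = 8, 50
A_diag = -np.linspace(0.5, 4.0, N)  # stable real diagonal A (assumption, not S4's HiPPO A)
B = rng.standard_normal(N)
C = rng.standard_normal(N)
dt = 0.1                            # step size Delta (learned in S4, fixed here)
A_bar, B_bar = discretize_zoh(A_diag, B, dt)

u = np.sin(np.linspace(0, 4 * np.pi, L))  # toy input signal
y_rnn = ssm_recurrent(A_bar, B_bar, C, u)
y_conv = ssm_convolutional(A_bar, B_bar, C, u)
print(np.allclose(y_rnn, y_conv))  # True: both modes compute the same causal map
```

Here the role of $\Delta$ is visible in `A_bar = exp(dt * A_diag)`: a smaller step keeps $\bar A$ closer to 1, so the state decays more slowly per step and remembers further back. A larger step makes it forget faster.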

### How the Transformer approaches this differently

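- The Transformer has **no built-in notion of time**. Self-attention is permutation-equivariant, so without positional embeddings it would treat the sequence as an unordered set of positions; all temporal structure must be learned through those embeddings.
- The causal mask restricts attention to the past, but step $t$ still attends to all $t+1$ positions up to and including itself. A full forward pass therefore costs $O(L^2)$ in the sequence length, and generating each new step autoregressively costs $O(L)$ even with cached keys and values.
- Transformers can still perform well given enough capacity and data, but they lack a built-in **inductive bias** toward smooth, continuous dynamics.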
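The scaling difference can be checked with simple arithmetic. The sketch counts query-key scores in one causal attention head and state updates in one SSM channel run recurrently, assuming $N = 64$ to match `d_state` above. The two counts are not in the same units: a score is a dot product over the head dimension, while a state update is a (possibly complex) multiply-add. Only their growth with $L$ is meant to be compared.

```python
def causal_attention_scores(L):
    """Query-key dot products in one causal attention head: step t attends to t+1 positions."""
    return L * (L + 1) // 2

def ssm_state_updates(L, N):
    """Diagonal state-element updates in one SSM channel run recurrently over L steps."""
    return L * N

for L in (200, 1000, 4000):
    print(f"L={L:5d}  attention scores={causal_attention_scores(L):>10,}  "
          f"SSM updates (N=64)={ssm_state_updates(L, 64):>10,}")
```

Going from $L = 200$ to $L = 4000$ (20× longer) multiplies the attention count by roughly 400, but the SSM count by only 20.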

### When would a Transformer outperform S4 on time series?

- **Irregular time series** with complex event-driven patterns (e.g., financial trades, clinical events)
- **Multi-modal time series** where cross-channel attention patterns matter
- When combined with time-series-specific Transformer variants (PatchTST, Informer, etc.)

---

## Congratulations!

You've completed all four tutorials. You now know:
1. The **theory** behind SSMs, HiPPO, and S4
2. How to **use the s4_lib library** in practice
3. How S4 compares to Transformers on **classification** and **time-series** tasks

For further exploration:
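- Try increasing the sequence length to 1000+ and compare how test MSE and epoch time change for each model. Attention cost grows quadratically in $L$, so expect the Transformer's epoch time to grow faster.
- Experiment with S4's `d_state` parameter: a larger state can retain more history, at the cost of more computation.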
- Read the [S4 paper](https://arxiv.org/abs/2111.00396) and [S4D paper](https://arxiv.org/abs/2206.11893) for the full mathematical details