# Spatio-Temporal Traffic Forecasting with Graph Neural Networks (METR-LA / PEMS)

This notebook focuses on **spatio-temporal / graph-based traffic forecasting**.

Instead of a single aggregated traffic series, we model traffic at many
sensors (nodes) connected by a road network graph, using a simple
**Spatio-Temporal Graph Convolutional Network (ST-GCN)**-style model.

We use a METR-LA / PEMS-style dataset with:

- A set of **sensors / stations** (graph nodes).
- Traffic speed or volume time series at each node.
- A graph **adjacency matrix** describing road connections.

The goal is to predict **future traffic at all sensors** given a window of
past observations, leveraging both temporal dynamics and spatial structure.


## 0. How to run this notebook

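1. Obtain a preprocessed METR-LA or PEMS-style dataset.
   Public repositories for models such as DCRNN and STGCN distribute
   preprocessed sensor readings and adjacency matrices (in formats such as
   H5, NPZ, CSV, or pickle); convert them into the single NPZ layout below.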

2. For this notebook, we assume an NPZ file with at least:

   - `traffic`: 3D array of shape `(T, N, 1)` or 2D `(T, N)`
     where `T` is the number of time steps and `N` the number of sensors.
   - `adjacency`: 2D array `(N, N)` with non-negative weights.

   Save it as:

   ```text
   data/graph_traffic.npz
   ```

3. Install required packages:

   ```bash
   pip install numpy pandas matplotlib torch scikit-learn
   ```

4. Open this notebook and run it top-to-bottom. You can adjust model and
   training hyperparameters to match your hardware.
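If you only want to check that the notebook runs end to end before obtaining real data, a small synthetic file in the same layout is enough. The helper below, `write_synthetic_graph_traffic`, is hypothetical (not part of any dataset repository). It writes a random series with a daily cycle (assuming 5-minute steps, i.e. 288 per day) and a symmetric chain-graph adjacency. It writes to a separate file so it never overwrites a real `data/graph_traffic.npz`; point `DATA_PATH` at it to use it.

```python
from pathlib import Path

import numpy as np


def write_synthetic_graph_traffic(
    path: Path, num_steps: int = 2016, num_nodes: int = 5, seed: int = 0
) -> None:
    """Write a small random dataset with 'traffic' (T, N) and 'adjacency' (N, N).

    Hypothetical helper for smoke-testing only; the values are synthetic.
    """
    rng = np.random.default_rng(seed)
    t = np.arange(num_steps)
    # Daily cycle assuming 5-minute steps (288 per day), plus per-node noise
    daily = 50.0 + 10.0 * np.sin(2 * np.pi * t / 288)
    traffic = daily[:, None] + rng.normal(0.0, 2.0, size=(num_steps, num_nodes))

    # Chain graph: sensor i connected to sensor i + 1, symmetric, weight 1.0
    adjacency = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for i in range(num_nodes - 1):
        adjacency[i, i + 1] = adjacency[i + 1, i] = 1.0

    path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(path, traffic=traffic.astype(np.float32), adjacency=adjacency)


# Writes one week of 5-minute data for 5 sensors to a separate file
write_synthetic_graph_traffic(Path("data") / "graph_traffic_synthetic.npz")
```

The resulting metrics are meaningless for model comparison; treat the file purely as a smoke test.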


## 1. Imports and configuration


In [ ]:
from __future__ import annotations

from pathlib import Path
from typing import Tuple, Dict, List

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import mean_absolute_error, mean_squared_error

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

plt.rcParams["figure.figsize"] = (11, 5)

RANDOM_STATE: int = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

DATA_PATH: Path = Path("data") / "graph_traffic.npz"
if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Expected dataset at {DATA_PATH.resolve()}\n"
        "It should contain 'traffic' (T,N or T,N,1) and 'adjacency' (N,N)."
    )

npz = np.load(DATA_PATH)
traffic = npz["traffic"]  # (T,N) or (T,N,1)
adjacency = npz["adjacency"]  # (N,N)

print("traffic shape:", traffic.shape)
print("adjacency shape:", adjacency.shape)

## 2. Basic preprocessing and EDA

We ensure a standard shape and visualise traffic at a few sensors.


In [ ]:
# Ensure traffic has shape (T, N)
if traffic.ndim == 3 and traffic.shape[-1] == 1:
    traffic = traffic[..., 0]
elif traffic.ndim != 2:
    raise ValueError("Expected traffic array with shape (T,N) or (T,N,1).")

T, N = traffic.shape
print(f"Using traffic array of shape T={T}, N={N} sensors.")

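num_example_sensors: int = min(3, N)
# One week is 24 * 7 steps at hourly resolution; for 5-minute data such as
# METR-LA, a week is 12 * 24 * 7 = 2016 steps.
steps_to_show: int = min(24 * 7, T)
time_index = np.arange(T)

for node_id in range(num_example_sensors):
    plt.plot(
        time_index[:steps_to_show],
        traffic[:steps_to_show, node_id],
        label=f"sensor {node_id}",
    )

plt.title(f"First {steps_to_show} time steps – traffic at a few sensors")
plt.xlabel("Time step (e.g. hour or 5-min interval)")
plt.ylabel("Traffic (e.g. speed or volume)")
plt.legend()
plt.show()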

## 3. Train/val/test split and normalization

We split the time axis into train/validation/test segments and normalise
per-node using statistics from the training segment only.
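Fitting the statistics on the training slice only keeps information from the validation and test periods out of preprocessing. The sketch below uses hypothetical synthetic data with an upward drift to illustrate the per-node broadcasting: after normalisation the training slice is zero-mean and unit-variance for every node, the later slice is not, and the inverse transform recovers the original values.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic (T, N) series whose level drifts upward over time (hypothetical data)
T_demo, N_demo = 1000, 3
series = rng.normal(50.0, 5.0, size=(T_demo, N_demo)) + np.linspace(0.0, 20.0, T_demo)[:, None]

train_end = int(T_demo * 0.6)  # same 0.6 fraction as split_time_series
mean = series[:train_end].mean(axis=0)         # shape (N,)
std = np.maximum(series[:train_end].std(axis=0), 1e-3)  # floor, as in the notebook

norm = (series - mean) / std  # (T, N) broadcast against (N,)
print(norm[:train_end].mean(axis=0))  # approximately zero for every node
print(norm[train_end:].mean(axis=0))  # positive: the later segment has drifted

restored = norm * std + mean  # inverse transform, as in denormalize()
```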


In [ ]:
def split_time_series(
    data: np.ndarray,
    train_frac: float = 0.6,
    val_frac: float = 0.2,
) -> Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]:
    """Return index ranges (start, end) for train, val, test on time axis.

    The end index is exclusive. Fractions should sum to <= 1.0.
    """
    T = data.shape[0]
    train_end = int(T * train_frac)
    val_end = train_end + int(T * val_frac)
    train_range = (0, train_end)
    val_range = (train_end, val_end)
    test_range = (val_end, T)
    return train_range, val_range, test_range


train_range, val_range, test_range = split_time_series(traffic)
train_range, val_range, test_range

In [ ]:
def compute_normalization_stats(train_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-node mean and std from training data.

    Parameters
    ----------
    train_data : np.ndarray
        Array of shape (T_train, N).

    Returns
    -------
    mean : np.ndarray
        Mean per node, shape (N,).
    std : np.ndarray
        Standard deviation per node, shape (N,), with minimum floor applied.
    """
    mean = train_data.mean(axis=0)
    std = train_data.std(axis=0)
    std[std < 1e-3] = 1e-3
    return mean, std


train_data = traffic[train_range[0] : train_range[1]]
mean_nodes, std_nodes = compute_normalization_stats(train_data)

traffic_norm = (traffic - mean_nodes) / std_nodes
traffic_norm.shape

## 4. Sequence dataset for spatio-temporal forecasting

We now build a dataset of sliding windows:

- Input: past `input_len` steps for all nodes.
- Target: next `horizon` steps for all nodes.
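Each window needs `input_len + horizon` consecutive steps, so a segment of length `L = end - start` yields `L - input_len - horizon + 1` windows; this is the `max_start` value computed in `SpatioTemporalDataset`. A standalone sketch of the counting and slicing (the helper names are hypothetical):

```python
def num_windows(start: int, end: int, input_len: int, horizon: int) -> int:
    """Count sliding windows that fit entirely inside [start, end).

    A window starting at s uses steps s .. s + input_len + horizon - 1, so the
    last valid start is end - input_len - horizon.
    """
    return max(0, (end - start) - input_len - horizon + 1)


def window_slices(start: int, input_len: int, horizon: int) -> tuple[slice, slice]:
    """Return the input and target slices for a window starting at `start`."""
    split = start + input_len
    return slice(start, split), slice(split, split + horizon)


# A 1000-step training segment with 12 input and 12 target steps
print(num_windows(0, 1000, 12, 12))  # 977
x_sl, y_sl = window_slices(0, 12, 12)
print(x_sl, y_sl)  # slice(0, 12, None) slice(12, 24, None)
```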


In [ ]:
class SpatioTemporalDataset(Dataset):
    """Dataset of spatio-temporal sequences for graph forecasting.

    Each item is a pair `(X, Y)`:

    - `X` has shape `(input_len, N, 1)` – past inputs.
    - `Y` has shape `(horizon, N, 1)` – future targets.
    """

    def __init__(
        self,
        data: np.ndarray,
        time_range: Tuple[int, int],
        input_len: int,
        horizon: int,
    ) -> None:
        super().__init__()
        self.data = data.astype(np.float32)
        self.start, self.end = time_range
        self.input_len = input_len
        self.horizon = horizon

        max_start = self.end - self.start - input_len - horizon + 1
        if max_start <= 0:
            raise ValueError("Time range too small for given input_len and horizon.")
        self.indices = np.arange(self.start, self.start + max_start)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        start_idx = int(self.indices[idx])
        x = self.data[start_idx : start_idx + self.input_len]
        y = self.data[start_idx + self.input_len : start_idx + self.input_len + self.horizon]

        x = x[..., None]
        y = y[..., None]
        return torch.from_numpy(x), torch.from_numpy(y)


INPUT_LEN: int = 12
HORIZON: int = 12
BATCH_SIZE: int = 64

train_dataset = SpatioTemporalDataset(traffic_norm, train_range, INPUT_LEN, HORIZON)
val_dataset = SpatioTemporalDataset(traffic_norm, val_range, INPUT_LEN, HORIZON)
test_dataset = SpatioTemporalDataset(traffic_norm, test_range, INPUT_LEN, HORIZON)

len(train_dataset), len(val_dataset), len(test_dataset)

In [ ]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## 5. Graph convolution and ST-GCN block
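Both layers in this section rely on the symmetric adjacency normalisation commonly used in GCNs. In the notation of `normalize_adjacency` and `GraphConv` below, with $A$ the raw adjacency, $X$ the node features of one graph and $W$ the learned weights:

$$
\hat{A} = \tilde{D}^{-1/2} \, (A + I_N) \, \tilde{D}^{-1/2},
\qquad
\tilde{D}_{ii} = \sum_{j} (A + I_N)_{ij}
$$

$$
X' = \hat{A} \, X \, W,
\qquad
X \in \mathbb{R}^{N \times F_{\text{in}}},\;
W \in \mathbb{R}^{F_{\text{in}} \times F_{\text{out}}}
$$

Entry-wise, $\hat{A}_{ij} = (A + I_N)_{ij} / \sqrt{\tilde{D}_{ii}\,\tilde{D}_{jj}}$, so each node's output is a degree-normalised sum of its own and its neighbours' projected features. In `GraphConv`, $XW$ is the bias-free `nn.Linear` and the product with $\hat{A}$ is the `einsum`, applied independently to every graph in the batch (the $B \cdot T$ time slices inside `STGCNBlock`).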


In [ ]:
def normalize_adjacency(adj: np.ndarray, add_self_loops: bool = True) -> np.ndarray:
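    """Compute the symmetric normalised adjacency D^{-1/2} (A + I) D^{-1/2}.

    D is the diagonal row-sum (degree) matrix of A + I. With
    ``add_self_loops=False`` the identity is omitted; zero-degree nodes get
    zero rows and columns.
    """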
    A = adj.astype(np.float32)
    if add_self_loops:
        A = A + np.eye(A.shape[0], dtype=np.float32)
    d = A.sum(axis=1)
    d_inv_sqrt = np.power(d, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    D_inv_sqrt = np.diag(d_inv_sqrt)
    return D_inv_sqrt @ A @ D_inv_sqrt


A_hat_np = normalize_adjacency(adjacency)
A_hat = torch.tensor(A_hat_np, dtype=torch.float32)
A_hat.shape
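A hand-checkable example helps confirm the normalisation. `sym_norm` below is a hypothetical standalone copy of `normalize_adjacency` that uses broadcasting instead of `np.diag`; the result is the same, but it avoids building two dense diagonal matrices. For the path graph 0-1-2 the self-loop degrees are 2, 3, 2, so the diagonal of $\hat{A}$ should be 1/2, 1/3, 1/2 and each edge weight $1/\sqrt{2 \cdot 3} \approx 0.408$.

```python
import numpy as np


def sym_norm(adj: np.ndarray, add_self_loops: bool = True) -> np.ndarray:
    """Standalone version of normalize_adjacency for checking by hand."""
    A = adj.astype(np.float32)
    if add_self_loops:
        A = A + np.eye(A.shape[0], dtype=np.float32)
    d = A.sum(axis=1)
    with np.errstate(divide="ignore"):  # isolated nodes give 0 ** -0.5 = inf
        d_inv_sqrt = np.power(d, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}
    return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]


# Path graph 0 - 1 - 2
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float32)
A_hat_demo = sym_norm(A)
print(np.round(A_hat_demo, 3))
```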

In [ ]:
class GraphConv(nn.Module):
    """Simple graph convolution layer using a precomputed A_hat.

    Given node features X of shape (B, N, F_in), computes X' = A_hat @ X @ W.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor, A_hat: torch.Tensor) -> torch.Tensor:
        xw = self.linear(x)
        out = torch.einsum("ij,bjf->bif", A_hat, xw)
        return out


class STGCNBlock(nn.Module):
    """Spatio-temporal block combining temporal and graph convolutions.

    Input:  (B, T, N, F_in)
    Output: (B, T, N, F_out)
    """

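    def __init__(
        self,
        in_channels: int,
        spatial_channels: int,
        out_channels: int,
        kernel_size: int = 3,
    ) -> None:
        super().__init__()
        # nn.Conv2d pads both ends of the time axis by kernel_size - 1, giving
        # T + kernel_size - 1 outputs; trimming the trailing kernel_size - 1
        # makes the convolution causal and keeps the output length at T.
        self.trim = kernel_size - 1
        self.temporal1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=spatial_channels,
            kernel_size=(kernel_size, 1),
            padding=(kernel_size - 1, 0),
        )
        self.gconv = GraphConv(spatial_channels, spatial_channels)
        self.temporal2 = nn.Conv2d(
            in_channels=spatial_channels,
            out_channels=out_channels,
            kernel_size=(kernel_size, 1),
            padding=(kernel_size - 1, 0),
        )
        self.relu = nn.ReLU()

    def _trim(self, x: torch.Tensor) -> torch.Tensor:
        """Drop trailing padded outputs: (B, C, T + trim, N) -> (B, C, T, N)."""
        return x[:, :, : x.shape[2] - self.trim, :]

    def forward(self, x: torch.Tensor, A_hat: torch.Tensor) -> torch.Tensor:
        # (B, T, N, F_in) -> (B, F_in, T, N): channels first for Conv2d
        x = x.permute(0, 3, 1, 2)
        x = self.relu(self._trim(self.temporal1(x)))

        # (B, C, T, N) -> (B, T, N, C); apply the graph conv at every time step
        x = x.permute(0, 2, 3, 1)
        B, T_seq, N_nodes, F_sp = x.shape
        x_flat = x.reshape(B * T_seq, N_nodes, F_sp)
        x_gc = self.gconv(x_flat, A_hat)
        x = x_gc.reshape(B, T_seq, N_nodes, F_sp)

        # Second causal temporal conv, then back to (B, T, N, F_out)
        x = x.permute(0, 3, 1, 2)
        x = self.relu(self._trim(self.temporal2(x)))
        x = x.permute(0, 2, 3, 1)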
        return x
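A causal temporal convolution lets the output at step t depend only on inputs up to t, which is what a forecasting model needs. The sketch below isolates that idea in a hypothetical `CausalTemporalConv` module operating on `(B, C, T, N)`: it uses `nn.Conv2d`'s symmetric padding of `kernel_size - 1` and drops the trailing `kernel_size - 1` outputs, so the output length equals the input length. The check at the end perturbs the last input step and confirms that all earlier outputs are unchanged.

```python
import torch
from torch import nn


class CausalTemporalConv(nn.Module):
    """Conv over the time axis of (B, C, T, N); output at t sees inputs <= t."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3) -> None:
        super().__init__()
        self.trim = kernel_size - 1
        # Symmetric padding of kernel_size - 1 on time, none on the node axis
        self.conv = nn.Conv2d(in_ch, out_ch, (kernel_size, 1), padding=(self.trim, 0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)  # (B, out_ch, T + trim, N)
        return out[:, :, : out.shape[2] - self.trim, :]  # (B, out_ch, T, N)


torch.manual_seed(0)
conv = CausalTemporalConv(1, 4)
x = torch.randn(2, 1, 12, 5)  # (B, C, T, N)
y = conv(x)
print(y.shape)  # torch.Size([2, 4, 12, 5])

# Changing only the last input step must leave outputs 0 .. T-2 untouched
x2 = x.clone()
x2[:, :, -1, :] += 10.0
print(torch.allclose(conv(x2)[:, :, :-1], y[:, :, :-1]))  # True
```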


## 6. Full ST-GCN model


In [ ]:
class STGCN(nn.Module):
    """Spatio-Temporal GCN for multi-step traffic forecasting.

    Input:  (B, input_len, N, 1)
    Output: (B, horizon, N, 1)
    """

    def __init__(
        self,
        num_nodes: int,
        input_len: int,
        horizon: int,
        in_channels: int = 1,
        spatial_channels: int = 16,
        hidden_channels: int = 32,
    ) -> None:
        super().__init__()
        self.num_nodes = num_nodes
        self.input_len = input_len
        self.horizon = horizon

        self.block1 = STGCNBlock(
            in_channels=in_channels,
            spatial_channels=spatial_channels,
            out_channels=hidden_channels,
        )
        self.block2 = STGCNBlock(
            in_channels=hidden_channels,
            spatial_channels=spatial_channels,
            out_channels=hidden_channels,
        )

        self.temporal_projection = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=horizon,
            kernel_size=(1, 1),
        )

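    def forward(self, x: torch.Tensor, A_hat: torch.Tensor) -> torch.Tensor:
        out = self.block1(x, A_hat)          # (B, T, N, hidden)
        out = self.block2(out, A_hat)        # (B, T, N, hidden)
        out = out.permute(0, 3, 1, 2)        # (B, hidden, T, N)
        out = self.temporal_projection(out)  # (B, horizon, T, N): one channel per step
        out = out[:, :, -1, :]               # last time position: (B, horizon, N)
        out = out[..., None]                 # (B, horizon, N, 1), same layout as Y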
        return out
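The output head treats channels as forecast steps: a 1x1 convolution maps the hidden channels to `horizon` channels, and only the last time position is kept. A standalone shape walk-through of that idea, with a random tensor standing in (hypothetically) for the output of `block2`:

```python
import torch
from torch import nn

B, T_in, N_nodes, C_hidden, H = 4, 12, 7, 32, 12

# Hypothetical stand-in for block2's output: (B, T, N, C_hidden)
features = torch.randn(B, T_in, N_nodes, C_hidden)

# 1x1 convolution: C_hidden input channels -> H output channels (one per step)
head = nn.Conv2d(in_channels=C_hidden, out_channels=H, kernel_size=(1, 1))

out = head(features.permute(0, 3, 1, 2))  # (B, H, T, N)
out = out[:, :, -1, :]                     # keep the last time position: (B, H, N)
out = out.unsqueeze(-1)                    # (B, H, N, 1), matches the targets
print(out.shape)  # torch.Size([4, 12, 7, 1])
```

Because only the last time position feeds the head, it is the causal structure of the temporal convolutions that determines how much past context each forecast sees.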


## 7. Training loop and metrics


In [ ]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = STGCN(num_nodes=N, input_len=INPUT_LEN, horizon=HORIZON).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

A_hat_device = A_hat.to(device)

def denormalize(data_norm: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    return data_norm * std + mean


def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    mae = mean_absolute_error(y_true, y_pred)
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {"mae": float(mae), "rmse": rmse}


def train_epoch(model: nn.Module, loader: DataLoader) -> float:
    model.train()
    losses: List[float] = []
    for X_batch, Y_batch in loader:
        X_batch = X_batch.to(device)
        Y_batch = Y_batch.to(device)
        optimizer.zero_grad()
        preds = model(X_batch, A_hat_device)
        loss = criterion(preds, Y_batch)
        loss.backward()
        optimizer.step()
        losses.append(float(loss.item()))
    return float(np.mean(losses))


def eval_epoch(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    losses: List[float] = []
    with torch.no_grad():
        for X_batch, Y_batch in loader:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)
            preds = model(X_batch, A_hat_device)
            loss = criterion(preds, Y_batch)
            losses.append(float(loss.item()))
    return float(np.mean(losses))


EPOCHS: int = 10
for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, train_loader)
    val_loss = eval_epoch(model, val_loader)
    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

### 7.1 Test-set evaluation


In [ ]:
model.eval()
y_true_list: List[np.ndarray] = []
y_pred_list: List[np.ndarray] = []

with torch.no_grad():
    for X_batch, Y_batch in test_loader:
        X_batch = X_batch.to(device)
        preds = model(X_batch, A_hat_device)
        y_true_list.append(Y_batch.numpy())
        y_pred_list.append(preds.cpu().numpy())

y_true_arr = np.concatenate(y_true_list, axis=0)
y_pred_arr = np.concatenate(y_pred_list, axis=0)

y_true_flat = y_true_arr[..., 0].reshape(-1, N)
y_pred_flat = y_pred_arr[..., 0].reshape(-1, N)

y_true_denorm = denormalize(y_true_flat, mean_nodes, std_nodes)
y_pred_denorm = denormalize(y_pred_flat, mean_nodes, std_nodes)

stgcn_metrics = compute_metrics(y_true_denorm.ravel(), y_pred_denorm.ravel())
print("STGCN test metrics:", stgcn_metrics)
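The pooled metrics above average over all 12 horizon steps, and errors usually grow with the horizon. A hypothetical helper `per_horizon_mae_rmse` breaks them down per step; it expects arrays shaped `(S, H, N)`, which `y_true_denorm` and `y_pred_denorm` recover via `.reshape(-1, HORIZON, N)` because their rows are ordered (sample, horizon step).

```python
import numpy as np


def per_horizon_mae_rmse(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return an (H, 2) array of [MAE, RMSE] per horizon step.

    Both inputs have shape (S, H, N) in original (denormalised) units.
    """
    err = y_pred - y_true
    mae = np.abs(err).mean(axis=(0, 2))          # average over samples and nodes
    rmse = np.sqrt((err ** 2).mean(axis=(0, 2)))
    return np.stack([mae, rmse], axis=1)


# In the notebook:
# per_horizon_mae_rmse(y_true_denorm.reshape(-1, HORIZON, N),
#                      y_pred_denorm.reshape(-1, HORIZON, N))

# Toy check: constant errors of 1, 2, 3 at horizon steps 0, 1, 2
y_true_toy = np.zeros((2, 3, 4))
y_pred_toy = np.broadcast_to(np.array([1.0, 2.0, 3.0])[None, :, None], (2, 3, 4))
result = per_horizon_mae_rmse(y_true_toy, y_pred_toy)
print(result)
```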

## 8. Persistence baseline
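The persistence (last-value) baseline predicts that every future step equals the most recent observation, $\hat{x}_{t+h} = x_t$ for $h = 1, \dots, H$, separately at each sensor. Any learned model should beat it to be worth its cost. A minimal NumPy sketch of the same operation as the PyTorch cell below (the function name is hypothetical):

```python
import numpy as np


def persistence_forecast(x: np.ndarray, horizon: int) -> np.ndarray:
    """Repeat the last observed step across the forecast horizon.

    x has shape (B, input_len, N, 1); the result has shape (B, horizon, N, 1).
    """
    last = x[:, -1:, :, :]  # (B, 1, N, 1): slicing keeps the time axis
    return np.repeat(last, horizon, axis=1)


x = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2, 1)
pred = persistence_forecast(x, horizon=4)
print(pred.shape)  # (2, 4, 2, 1)
```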


In [ ]:
y_true_p_list: List[np.ndarray] = []
y_pred_p_list: List[np.ndarray] = []

for X_batch, Y_batch in test_loader:
    # Last observed step for every node: (B, N, 1)
    last_step = X_batch[:, -1, :, :]
    # Repeat it across the forecast horizon: (B, HORIZON, N, 1)
    last_rep = last_step[:, None, :, :].repeat(1, HORIZON, 1, 1)
    y_true_p_list.append(Y_batch.numpy())
    y_pred_p_list.append(last_rep.numpy())

y_true_p = np.concatenate(y_true_p_list, axis=0)
y_pred_p = np.concatenate(y_pred_p_list, axis=0)

y_true_p_flat = y_true_p[..., 0].reshape(-1, N)
y_pred_p_flat = y_pred_p[..., 0].reshape(-1, N)

y_true_p_denorm = denormalize(y_true_p_flat, mean_nodes, std_nodes)
y_pred_p_denorm = denormalize(y_pred_p_flat, mean_nodes, std_nodes)

baseline_metrics = compute_metrics(y_true_p_denorm.ravel(), y_pred_p_denorm.ravel())
print("Persistence baseline test metrics:", baseline_metrics)

## 9. Visualising forecasts at a sample sensor


In [ ]:
sensor_id: int = 0
steps_to_plot: int = 200

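# Rows of y_*_denorm are ordered (sample, horizon step). Restore that layout
# and keep horizon step 0, so consecutive points are consecutive time steps.
y_true_h1 = y_true_denorm.reshape(-1, HORIZON, N)[:, 0, sensor_id]
y_pred_h1 = y_pred_denorm.reshape(-1, HORIZON, N)[:, 0, sensor_id]

plt.plot(y_true_h1[:steps_to_plot], label="actual")
plt.plot(y_pred_h1[:steps_to_plot], label="STGCN", linestyle="--")
plt.title(f"One-step-ahead forecast at sensor {sensor_id}")
plt.xlabel("Test window index")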
plt.ylabel("Traffic (denormalised)")
plt.legend()
plt.show()

## 10. Summary and extensions

In this notebook we:

- Loaded a **graph-based traffic dataset** (METR-LA / PEMS-style).
- Built a **spatio-temporal supervised dataset**.
- Implemented a simple **Spatio-Temporal GCN** model.
- Compared it to a spatio-temporal **persistence baseline**.
- Visualised one-step-ahead forecasts for a sample sensor.

Possible extensions:

- Add dilated temporal convolutions and residual connections.
- Use multiple adjacency matrices (distance, correlation, connectivity).
- Introduce time-of-day and day-of-week embeddings as extra inputs.
- Explore more advanced architectures such as Graph WaveNet or MTGNN.
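As a possible starting point for the first extension, a hypothetical dilated causal temporal block with a residual connection could look like the sketch below. Graph WaveNet also builds on dilated causal convolutions, but this is not its exact layer. It left-pads the time axis by `(kernel_size - 1) * dilation` with `torch.nn.functional.pad`, so the sequence length is preserved and the output at step t sees only inputs up to t.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DilatedResidualTemporalBlock(nn.Module):
    """Causal dilated conv over time with a residual connection on (B, C, T, N)."""

    def __init__(self, channels: int, kernel_size: int = 2, dilation: int = 1) -> None:
        super().__init__()
        # Left padding that makes the dilated convolution causal and length-preserving
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(channels, channels, (kernel_size, 1), dilation=(dilation, 1))
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.pad pads the last dim first: (N_left, N_right, T_left, T_right)
        padded = F.pad(x, (0, 0, self.left_pad, 0))
        return x + self.relu(self.conv(padded))  # residual: same shape as x


# Dilations 1, 2, 4 with kernel_size 2: receptive field 1 + 1 + 2 + 4 = 8 steps
torch.manual_seed(0)
stack = nn.Sequential(*[DilatedResidualTemporalBlock(16, 2, d) for d in (1, 2, 4)])
x = torch.randn(3, 16, 12, 5)
print(stack(x).shape)  # torch.Size([3, 16, 12, 5])
```

Doubling the dilation per layer grows the receptive field exponentially with depth, which is the usual motivation for this design over stacking more undilated layers.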
