# Multi-Horizon Traffic Forecasting on PeMS (Graph Models)

## Goal (Paper Claim)
Build a leakage-safe, reproducible pipeline on PeMS traffic data and evaluate multi-horizon forecasting models fairly.

Primary goal:
- Demonstrate that the proposed **GraphWaveNet+GRU+LSTM** performs best on PeMS under the same train/val/test protocol.

Key principles:
- No time leakage (all statistics computed from train only).
- One shared dataset representation for all deep models: **$X \in \mathbb{R}^{T \times N \times F}$, $Y \in \mathbb{R}^{T \times N}$** (T timestamps, N stations, F features).
- One fixed evaluation harness (same horizons, same metrics, same seeds).
- Strong baselines + ablations:
  - HA / Persistence
  - GRU / LSTM (non-graph)
  - GraphWaveNet
  - GraphWaveNet+GRU
  - GraphWaveNet+LSTM
  - **GraphWaveNet+GRU+LSTM (proposed)**


In [1]:
!pip -q install -r requirements.txt



In [2]:
!pip -q install numpy pandas openpyxl scikit-learn torch tqdm



In [3]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd

import torch
from tqdm.auto import tqdm

def set_seed(seed: int = 42, deterministic: bool = True):
    """
    Sets seeds for reproducibility.
    deterministic=True makes results more reproducible but can reduce speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

SEED = 42
set_seed(SEED, deterministic=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


## Configuration

We fix:
- Input window length (`IN_LEN`) and forecast horizon length (`OUT_LEN`)
- Train/val/test boundaries (time-based split)
- Station inclusion rule (coverage threshold)
- Output dataset artifact path (so every model uses the same processed dataset)

Important: GraphWaveNet expects a consistent node set and a continuous time axis, so we build a clean (timestamp × station) matrix.


In [4]:
# -------------------------
# Paths (input files in the notebook's working directory)
# -------------------------
TRAFFIC_CSV = Path("cleaned_traffic_data.csv")
META_XLSX   = Path("pems_output.xlsx")

assert TRAFFIC_CSV.exists(), f"Missing {TRAFFIC_CSV}"
assert META_XLSX.exists(), f"Missing {META_XLSX}"

# -------------------------
# Split boundaries (time-based; end timestamps are inclusive)
# -------------------------
TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

# -------------------------
# Forecast setup
# -------------------------
IN_LEN  = 24     # hours of history used as input
OUT_LEN = 72     # predict next 72 hours (we will evaluate at 12/24/48/72)

EVAL_HORIZONS = [12, 24, 48, 72]  # hours ahead

# -------------------------
# Station coverage threshold
# -------------------------
# 1.0 means station must have ALL timestamps present.
# 0.98 is often a good compromise if some stations are missing a few points.
COVERAGE_THRESHOLD = 0.98

# -------------------------
# Adjacency setup (static graph baseline)
# -------------------------
K_NEIGHBORS = 2   # connect up to 2 upstream + 2 downstream along the freeway chain

# -------------------------
# Output artifact (important for reproducibility)
# -------------------------
OUT_DIR = Path("artifacts")
OUT_DIR.mkdir(exist_ok=True)

DATASET_NPZ = OUT_DIR / "pems_graph_dataset.npz"
print("Will save processed dataset to:", DATASET_NPZ)


Will save processed dataset to: artifacts/pems_graph_dataset.npz


## Load raw traffic + metadata

We:
1) Load cleaned traffic data
2) Load station metadata
3) Standardize column names
4) Merge metadata onto traffic records (inner join so every station has metadata)
5) Verify timestamp parsing and basic integrity checks


In [5]:
def require_col(df: pd.DataFrame, candidates, friendly_name: str):
    """
    Find the first matching column in candidates.
    Raise a helpful error if not found.
    """
    for c in candidates:
        if c in df.columns:
            return c
    raise KeyError(
        f"Could not find column for '{friendly_name}'. Tried: {candidates}\n"
        f"Available columns: {list(df.columns)}"
    )

def to_datetime_safe(s: pd.Series) -> pd.Series:
    return pd.to_datetime(s, errors="coerce")

def pct_missing(s: pd.Series) -> float:
    return float(s.isna().mean() * 100.0)



In [6]:
traffic_raw = pd.read_csv(TRAFFIC_CSV)
meta_raw = pd.read_excel(META_XLSX)

print("Traffic shape:", traffic_raw.shape)
print("Meta shape:", meta_raw.shape)

# --- Identify expected columns robustly ---
ts_col   = require_col(traffic_raw, ["Timestamp", "timestamp", "Time", "Datetime"], "Timestamp")
st_col   = require_col(traffic_raw, ["Station", "station", "ID"], "Station ID")
flow_col = require_col(traffic_raw, ["Total Flow", "total_flow", "Flow", "total flow"], "Total Flow")
spd_col  = require_col(traffic_raw, ["Avg Speed", "avg_speed", "Speed", "Avg speed"], "Avg Speed")

lane_col = require_col(traffic_raw, ["Lane Type", "lane_type", "LaneType"], "Lane Type")
dir_col  = require_col(traffic_raw, ["Direction of Travel", "direction", "Dir"], "Direction")
dist_col = require_col(traffic_raw, ["District", "district"], "District")

# --- Standardize traffic ---
traffic = traffic_raw.rename(columns={
    ts_col: "timestamp",
    st_col: "station",
    flow_col: "total_flow",
    spd_col: "avg_speed",
    lane_col: "lane_type",
    dir_col: "direction",
    dist_col: "district",
}).copy()

traffic["timestamp"] = to_datetime_safe(traffic["timestamp"])
traffic["station"] = pd.to_numeric(traffic["station"], errors="coerce").astype("Int64")

traffic = traffic.dropna(subset=["timestamp", "station"]).copy()
traffic["station"] = traffic["station"].astype(int)

print("After basic parsing:", traffic.shape)
print("Timestamp range:", traffic["timestamp"].min(), "→", traffic["timestamp"].max())

# --- Standardize metadata ---
# the station ID column in the metadata is usually 'ID'
meta_id_col = require_col(meta_raw, ["ID", "station", "Station"], "Meta Station ID")
meta = meta_raw.rename(columns={meta_id_col: "station"}).copy()
meta["station"] = pd.to_numeric(meta["station"], errors="coerce").astype("Int64")
meta = meta.dropna(subset=["station"]).copy()
meta["station"] = meta["station"].astype(int)

print("Meta columns (peek):", list(meta.columns)[:20])

# Merge metadata (inner ensures we only keep stations that have metadata)
df = traffic.merge(meta, on="station", how="inner", validate="m:1")
print("Merged df shape:", df.shape)
print("Unique stations:", df["station"].nunique())


Traffic shape: (4114680, 42)
Meta shape: (1861, 15)
After basic parsing: (4114680, 42)
Timestamp range: 2024-10-01 00:00:00 → 2024-12-31 23:00:00
Meta columns (peek): ['Fwy', 'District', 'County', 'City', 'CA PM', 'Abs PM', 'Length', 'station', 'Name', 'Lanes', 'Type', 'Sensor Type', 'HOV', 'MS ID', 'IRM']
Merged df shape: (4051621, 56)
Unique stations: 1861


## Sanity checks

We check:
- Duplicate rows per (timestamp, station)
- Time frequency (hourly vs not)
- Missingness rates
These checks catch silent data problems that would otherwise invalidate results.
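A toy illustration of the duplicate-resolution rule and the time-delta check, on a hypothetical four-row frame (column names mirror the standardized ones; the values are made up):

```python
import pandas as pd

# Hypothetical toy records: station 1 has a duplicated 00:00 row
toy = pd.DataFrame({
    "timestamp": pd.to_datetime(["2024-10-01 00:00", "2024-10-01 00:00",
                                 "2024-10-01 01:00", "2024-10-01 00:00"]),
    "station": [1, 1, 1, 2],
    "total_flow": [100.0, 50.0, 120.0, 80.0],
    "avg_speed": [60.0, 64.0, 58.0, 65.0],
})

dups = int(toy.duplicated(subset=["timestamp", "station"]).sum())

# Same rule as the notebook: flows are summed, speeds are averaged
dedup = (toy.groupby(["timestamp", "station"], as_index=False)
            .agg({"total_flow": "sum", "avg_speed": "mean"}))

# Distribution of gaps between consecutive unique timestamps
times = pd.DatetimeIndex(sorted(dedup["timestamp"].unique()))
deltas = pd.Series(times[1:] - times[:-1]).value_counts()
print(dups, len(dedup), deltas.to_dict())
```

A single 1-hour delta, as in the real output below, is what confirms a gap-free hourly axis.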


In [7]:
# 1) Duplicates by (timestamp, station)
dup_count = df.duplicated(subset=["timestamp", "station"]).sum()
print("Duplicate (timestamp, station) rows:", int(dup_count))

if dup_count > 0:
    # Resolve duplicates safely: flow sums, speed averages
    df = (df.groupby(["timestamp", "station"], as_index=False)
            .agg({
                "total_flow": "sum",
                "avg_speed": "mean",
                "lane_type": "first",
                "direction": "first",
                "district": "first",
                # keep metadata columns by first
                **{c: "first" for c in meta.columns if c != "station"}
            }))
    print("After de-duplication:", df.shape)

# 2) Check time deltas
times = pd.DatetimeIndex(sorted(df["timestamp"].unique()))
deltas = pd.Series(times[1:] - times[:-1]).value_counts().head(5)
print("Most common timestamp deltas:\n", deltas)

# 3) Missingness
print("Missing total_flow (%):", pct_missing(df["total_flow"]))
print("Missing avg_speed (%):", pct_missing(df["avg_speed"]))


Duplicate (timestamp, station) rows: 0
Most common timestamp deltas:
 0 days 01:00:00    2207
Name: count, dtype: int64
Missing total_flow (%): 7.243150334150209
Missing avg_speed (%): 37.88496011843161


## Build station-time matrices

Graph models require a clean tensor format.
We create two matrices:
- Flow:  (T timestamps × N stations)
- Speed: (T timestamps × N stations)

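We also select a stable station set using a coverage threshold: a station is kept if it has a row for at least `COVERAGE_THRESHOLD` of all timestamps. Coverage counts rows, not non-missing values, so kept stations can still have NaN flow or speed; those gaps are handled by the imputation step.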
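A toy sketch of the coverage filter and the reindexed pivot (hypothetical stations and values; the 0.75 threshold is for illustration only):

```python
import numpy as np
import pandas as pd

all_times = pd.date_range("2024-10-01", periods=4, freq=pd.Timedelta(hours=1))
# Hypothetical long-format records: station 7 misses one hour, station 9 misses two
rows = pd.DataFrame({
    "timestamp": list(all_times) + list(all_times[[0, 1, 3]]) + list(all_times[[0, 2]]),
    "station":   [5] * 4 + [7] * 3 + [9] * 2,
    "total_flow": [10.0, np.nan, 12.0, 13.0, 20.0, 21.0, 23.0, 30.0, 32.0],
})

T = len(all_times)
coverage = rows.groupby("station")["timestamp"].nunique() / T   # row presence, not non-NaN flow
keep = coverage[coverage >= 0.75].index                          # toy threshold

flow = (rows[rows["station"].isin(keep)]
        .pivot(index="timestamp", columns="station", values="total_flow")
        .reindex(index=all_times, columns=sorted(keep)))          # (T, N); absent rows become NaN
print(coverage.to_dict())
print(flow)
```

Station 5 is fully covered yet still contributes a NaN (its 01:00 row exists but has no flow), while station 7's NaN comes from the missing row that `reindex` inserts.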


In [8]:
# Full timestamp index
all_times = pd.DatetimeIndex(sorted(df["timestamp"].unique()))
T = len(all_times)

# Station coverage
counts = df.groupby("station")["timestamp"].nunique()
coverage = counts / T

keep_stations = coverage[coverage >= COVERAGE_THRESHOLD].index
df2 = df[df["station"].isin(keep_stations)].copy()

stations = np.array(sorted(df2["station"].unique()), dtype=int)
N = len(stations)

print(f"Timestamps (T) = {T}")
print(f"Stations kept (N) = {N}  (coverage threshold={COVERAGE_THRESHOLD})")

# Build matrices
flow = (df2.pivot(index="timestamp", columns="station", values="total_flow")
          .reindex(index=all_times, columns=stations)
          .sort_index())

speed = (df2.pivot(index="timestamp", columns="station", values="avg_speed")
           .reindex(index=all_times, columns=stations)
           .sort_index())

print("Flow matrix:", flow.shape, "Speed matrix:", speed.shape)
print("Flow missing fraction:", float(np.isnan(flow.to_numpy()).mean()))
print("Speed missing fraction:", float(np.isnan(speed.to_numpy()).mean()))


Timestamps (T) = 2208
Stations kept (N) = 1821  (coverage threshold=0.98)
Flow matrix: (2208, 1821) Speed matrix: (2208, 1821)
Flow missing fraction: 0.0712893656137335
Speed missing fraction: 0.3772684720928937


## Leakage-safe imputation

We must not use validation/test information when estimating fill values.

Strategy:
- Forward-fill across time (uses only past observations, as in a streaming setting).
- Remaining NaNs filled using TRAIN statistics only.

Flow:
- ffill → fill with per-station TRAIN mean → fill with global TRAIN mean

Speed:
- ffill → fill using a TRAIN-only group lookup (lane_type, meta type, hour, fwy, district)
- then per-station TRAIN mean → global TRAIN mean
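A minimal sketch of the flow chain on two toy stations (hypothetical values; `train_end` is a toy boundary). Every statistic is computed on rows at or before `train_end`, so changing a test-period value cannot change any fill value:

```python
import numpy as np
import pandas as pd

idx = pd.date_range("2024-10-01", periods=6, freq=pd.Timedelta(hours=1))
train_end = idx[3]  # hypothetical boundary: first 4 hours are "train"

flow = pd.DataFrame({"s1": [np.nan, 10.0, np.nan, 14.0, np.nan, 99.0],
                     "s2": [np.nan, np.nan, np.nan, np.nan, 5.0, np.nan]}, index=idx)

def impute(flow: pd.DataFrame) -> pd.DataFrame:
    ff = flow.ffill()                                  # past values only (causal)
    train = ff.loc[ff.index <= train_end]
    station_mean = train.mean(axis=0)                  # per-station TRAIN mean (NaN if never observed)
    global_mean = float(np.nanmean(train.to_numpy()))  # pooled TRAIN mean over all stations
    return ff.fillna(station_mean).fillna(global_mean)

out = impute(flow)
print(out)
```

Station `s2` is never observed during the toy train period, so its leading gap falls through to the global mean. Note that an unlimited `ffill` can carry one stale value across a long outage; `ffill(limit=...)` is the knob if that matters.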


In [9]:
# Identify metadata columns we need for speed lookup
meta_type_col = None
for cand in ["Type", "type", "Station Type"]:
    if cand in df2.columns:
        meta_type_col = cand
        break

fwy_col = None
for cand in ["Fwy", "FWY", "fwy", "Freeway"]:
    if cand in df2.columns:
        fwy_col = cand
        break

if meta_type_col is None or fwy_col is None:
    raise KeyError(f"Missing metadata columns for speed lookup. Found meta_type={meta_type_col}, fwy={fwy_col}")

train_time_mask = flow.index <= TRAIN_END

# -------------------------
# Flow imputation
# -------------------------
flow_ff = flow.ffill()

flow_train_mean_station = flow_ff.loc[train_time_mask].mean(axis=0)
flow_train_mean_global = flow_ff.loc[train_time_mask].stack().mean()

flow_imp = flow_ff.fillna(flow_train_mean_station).fillna(flow_train_mean_global)

# -------------------------
# Speed lookup (TRAIN only)
# -------------------------
train_rows = df2[df2["timestamp"] <= TRAIN_END].copy()
train_rows["hour"] = train_rows["timestamp"].dt.hour

speed_grp_cols = ["lane_type", meta_type_col, "hour", fwy_col, "district"]
speed_lookup = train_rows.groupby(speed_grp_cols)["avg_speed"].mean()

global_speed_train_mean = train_rows["avg_speed"].mean()

# Station-level "mode" descriptors used when applying the lookup
station_info = (df2.groupby("station")
                  .agg(
                      lane_type=("lane_type", lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      meta_type=(meta_type_col, lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      fwy=(fwy_col, lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      district=("district", lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                  )
                  .reindex(stations))

speed_ff = speed.ffill()
speed_np = speed_ff.to_numpy(dtype=np.float32)
miss = np.isnan(speed_np)
hours = speed_ff.index.hour.values

# Fill with lookup
for j, st in enumerate(stations):
    if not miss[:, j].any():
        continue
    info = station_info.loc[st]
    lane_type = info["lane_type"]
    meta_type = info["meta_type"]
    fwy = info["fwy"]
    district = info["district"]

    idxs = np.where(miss[:, j])[0]
    fill_vals = []
    for t_idx in idxs:
        h = int(hours[t_idx])
        key = (lane_type, meta_type, h, fwy, district)
        fill_vals.append(speed_lookup.get(key, np.nan))
    speed_np[idxs, j] = np.array(fill_vals, dtype=np.float32)

speed_imp = pd.DataFrame(speed_np, index=speed_ff.index, columns=speed_ff.columns)

# Remaining NaNs → station TRAIN mean → global TRAIN mean
speed_train_mean_station = speed_imp.loc[train_time_mask].mean(axis=0)
speed_imp = speed_imp.fillna(speed_train_mean_station).fillna(global_speed_train_mean)

print("After imputation:")
print("Flow missing fraction:", float(np.isnan(flow_imp.to_numpy()).mean()))
print("Speed missing fraction:", float(np.isnan(speed_imp.to_numpy()).mean()))


After imputation:
Flow missing fraction: 0.0
Speed missing fraction: 0.0


## Build graph-ready tensors

We create:
- X: (T, N, F)
  Features include:
  - flow (1)
  - speed (1)
  - time encodings: hour_sin, hour_cos, dow_sin, dow_cos (4)
  Total F = 6

- Y: (T, N)
  Target is flow at each station.

Later, each training sample is a sliding window:
- Input:  X[t : t+IN_LEN]
- Output: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]
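A small sketch of the cyclic hour encoding and the window slicing, with toy sizes (`T`, `N`, `in_len`, `out_len` here are illustrative, not the notebook's values). It also checks that hour 23 and hour 0 land next to each other on the circle, which is the reason for sin/cos rather than a raw hour index:

```python
import numpy as np
import pandas as pd

ts = pd.date_range("2024-10-01", periods=10, freq=pd.Timedelta(hours=1))
hours = ts.hour.values
hour_enc = np.stack([np.sin(2 * np.pi * hours / 24.0),
                     np.cos(2 * np.pi * hours / 24.0)], axis=1)   # (T, 2)

T, N, F = len(ts), 3, 6
X = np.random.default_rng(0).normal(size=(T, N, F)).astype(np.float32)
Y = X[:, :, 0]                                   # target = flow channel, (T, N)

in_len, out_len, t = 4, 3, 2
x_win = X[t : t + in_len]                        # (in_len, N, F)
y_win = Y[t + in_len : t + in_len + out_len]     # (out_len, N), starts right after the inputs

def enc(h):
    return np.array([np.sin(2 * np.pi * h / 24.0), np.cos(2 * np.pi * h / 24.0)])

d_23_0 = np.linalg.norm(enc(23) - enc(0))   # neighbours on the clock
d_12_0 = np.linalg.norm(enc(12) - enc(0))   # opposite sides of the clock
print(x_win.shape, y_win.shape, round(d_23_0, 3), round(d_12_0, 3))
```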


In [10]:
def make_time_features(timestamps: pd.DatetimeIndex) -> np.ndarray:
    hours = timestamps.hour.values
    dow   = timestamps.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)  # (T,4)

time_feats = make_time_features(flow_imp.index)  # (T,4)
time_feats_b = np.repeat(time_feats[:, None, :], repeats=N, axis=1)  # (T,N,4)

flow_arr  = flow_imp.to_numpy(dtype=np.float32)[:, :, None]   # (T,N,1)
speed_arr = speed_imp.to_numpy(dtype=np.float32)[:, :, None]  # (T,N,1)

X = np.concatenate([flow_arr, speed_arr, time_feats_b], axis=2)  # (T,N,6)
Y = flow_arr.squeeze(-1).astype(np.float32)                      # (T,N)

print("X shape:", X.shape, " (T,N,F)")
print("Y shape:", Y.shape, " (T,N)")


X shape: (2208, 1821, 6)  (T,N,F)
Y shape: (2208, 1821)  (T,N)


## Build adjacency matrix A (static graph baseline)

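We build a physical-neighborhood adjacency using metadata:
- Sort stations by (freeway, absolute postmile)
- Connect up to K neighbors upstream and K downstream
- Weight edges with a Gaussian kernel of postmile distance, $w(d) = \exp(-d^2/\sigma^2)$, where $\sigma$ is the median positive gap between consecutive stations
- Add self-loops and symmetrize

Notes:
- GraphWaveNet can also learn an adaptive adjacency; this static graph is the fixed-graph baseline.
- Chains are built per freeway only. The metadata has no direction column, so stations for both travel directions of a freeway share one postmile-sorted chain.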
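A toy version of the chain construction for one freeway with four hypothetical stations at made-up postmiles, using the kernel $w(d)=\exp(-d^2/\sigma^2)$ with $\sigma$ set to the median neighbor gap and `k_neighbors=1`:

```python
import numpy as np

pm = np.array([0.0, 0.5, 1.0, 3.0])    # hypothetical Abs PM values, already sorted
k_neighbors = 1
N = len(pm)

gaps = np.diff(pm)
sigma = float(np.median(gaps[gaps > 0]))          # median neighbour distance

A = np.zeros((N, N), dtype=np.float32)
for i in range(N):
    for step in range(1, k_neighbors + 1):
        for j in (i - step, i + step):            # upstream and downstream neighbour
            if 0 <= j < N:
                A[i, j] = np.exp(-((pm[i] - pm[j]) ** 2) / sigma ** 2)

np.fill_diagonal(A, 1.0)     # self-loops
A = np.maximum(A, A.T)       # symmetrize
print(sigma)
print(np.round(A, 3))
```

The 2-mile gap between the last two stations gets a weight of about $e^{-16}$: with $\sigma$ equal to the median gap, the kernel decays quickly, so sparsely instrumented stretches end up nearly disconnected.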


In [11]:
def build_adjacency_from_metadata(meta_df: pd.DataFrame, stations: np.ndarray, k_neighbors: int = 2) -> np.ndarray:
    """
    Build adjacency within each freeway chain using Abs PM order.
    Edge weights = exp(-(dist^2 / sigma^2)), sigma = median neighbor distance.
    """
    # Find needed meta columns
    id_col = "station"
    abs_pm_col = None
    for cand in ["Abs PM", "abs_pm", "AbsPM", "Postmile", "PM"]:
        if cand in meta_df.columns:
            abs_pm_col = cand
            break
    fwy_col2 = None
    for cand in ["Fwy", "FWY", "fwy", "Freeway"]:
        if cand in meta_df.columns:
            fwy_col2 = cand
            break

    if abs_pm_col is None or fwy_col2 is None:
        raise KeyError(f"Metadata missing Abs PM or Fwy columns. Found AbsPM={abs_pm_col}, Fwy={fwy_col2}")

    meta_sub = meta_df[meta_df[id_col].isin(stations)].copy()
    meta_sub["abs_pm"] = pd.to_numeric(meta_sub[abs_pm_col], errors="coerce")
    meta_sub["fwy"] = meta_sub[fwy_col2].astype(str)

    # station index map
    station_to_idx = {s: i for i, s in enumerate(stations)}
    N = len(stations)
    A = np.zeros((N, N), dtype=np.float32)

    # estimate sigma from typical neighbor distances
    all_dists = []
    for fwy, grp in meta_sub.sort_values(["fwy", "abs_pm"]).groupby("fwy"):
        pm = grp["abs_pm"].dropna().values
        if len(pm) < 2:
            continue
        d = np.diff(np.sort(pm))
        d = d[d > 0]
        all_dists.extend(d.tolist())

    sigma = float(np.median(all_dists)) if len(all_dists) else 0.5
    sigma = max(sigma, 1e-3)

    def w(dist):  # gaussian weight
        return float(np.exp(- (dist**2) / (sigma**2)))

    # connect neighbors
    for fwy, grp in meta_sub.sort_values(["fwy", "abs_pm"]).groupby("fwy"):
        grp = grp.dropna(subset=["abs_pm"]).sort_values("abs_pm")
        ids = grp[id_col].astype(int).tolist()
        pms = grp["abs_pm"].astype(float).tolist()

        for i, sid in enumerate(ids):
            ii = station_to_idx[sid]
            for step in range(1, k_neighbors + 1):
                if i - step >= 0:
                    sj = ids[i - step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i-step]))
                if i + step < len(ids):
                    sj = ids[i + step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i+step]))

    # self loops + symmetrize
    np.fill_diagonal(A, 1.0)
    A = np.maximum(A, A.T)
    return A

# metadata table for adjacency should be meta with standardized station column
meta_for_adj = meta.copy()
meta_for_adj["station"] = meta_for_adj["station"].astype(int)

A = build_adjacency_from_metadata(meta_for_adj, stations=stations, k_neighbors=K_NEIGHBORS)
print("A shape:", A.shape)
print("Adjacency density (A>0):", float((A > 0).mean()))


A shape: (1821, 1821)
Adjacency density (A>0): 0.0023693916932872663


## Sliding windows + splits

Each sample uses:
- Input window:  X[t : t+IN_LEN]
- Output window: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]

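We split by the **time of the first predicted hour** (t + IN_LEN):
- Train if output_start_time ≤ TRAIN_END
- Val   if TRAIN_END < output_start_time ≤ VAL_END
- Test  if output_start_time > VAL_END

This rule checks only the first predicted hour, so the last training windows carry labels up to 71 hours after the last training timestamp (and the last validation windows reach into the test period). The strict split in the next section fixes this.

Then we save everything to a single `.npz` artifact so every model reads the exact same dataset.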
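The split counts printed below can be reproduced from the calendar alone, assuming the timestamps form a gap-free hourly index starting 2024-10-01 00:00 (consistent with T = 2208 and the single 1-hour delta found in the sanity checks):

```python
import numpy as np
import pandas as pd

IN_LEN, OUT_LEN = 24, 72
timestamps = pd.date_range("2024-10-01", periods=2208, freq=pd.Timedelta(hours=1))
TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END = pd.Timestamp("2024-11-30 23:59:59")

starts = np.arange(len(timestamps) - (IN_LEN + OUT_LEN) + 1)   # every full window
out_start = timestamps[starts + IN_LEN]                          # first predicted hour

n_train = int((out_start <= TRAIN_END).sum())
n_val = int(((out_start > TRAIN_END) & (out_start <= VAL_END)).sum())
n_test = int((out_start > VAL_END).sum())
print(len(starts), n_train, n_val, n_test)
```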


In [12]:
# Sliding window starts
T_total = X.shape[0]
max_t = T_total - (IN_LEN + OUT_LEN) + 1
starts = np.arange(max_t, dtype=np.int32)

timestamps = pd.DatetimeIndex(flow_imp.index)
out_start_times = timestamps[starts + IN_LEN]

train_starts = starts[out_start_times <= TRAIN_END]
val_starts   = starts[(out_start_times > TRAIN_END) & (out_start_times <= VAL_END)]
test_starts  = starts[out_start_times > VAL_END]

print(f"Window starts: train={len(train_starts)}, val={len(val_starts)}, test={len(test_starts)}")

# Train-only scalers (per node) for flow and speed (channels 0 and 1)
train_time_mask = timestamps <= TRAIN_END

flow_mean = X[train_time_mask, :, 0].mean(axis=0).astype(np.float32)
flow_std  = (X[train_time_mask, :, 0].std(axis=0) + 1e-6).astype(np.float32)

speed_mean = X[train_time_mask, :, 1].mean(axis=0).astype(np.float32)
speed_std  = (X[train_time_mask, :, 1].std(axis=0) + 1e-6).astype(np.float32)

np.savez_compressed(
    DATASET_NPZ,
    X=X.astype(np.float32),
    Y=Y.astype(np.float32),
    A=A.astype(np.float32),
    stations=stations.astype(np.int32),
    timestamps=np.array(timestamps.astype("datetime64[ns]")),
    train_starts=train_starts,
    val_starts=val_starts,
    test_starts=test_starts,
    in_len=np.array([IN_LEN], dtype=np.int32),
    out_len=np.array([OUT_LEN], dtype=np.int32),
    flow_mean=flow_mean, flow_std=flow_std,
    speed_mean=speed_mean, speed_std=speed_std,
    seed=np.array([SEED], dtype=np.int32),
)

print("Saved:", DATASET_NPZ)


Window starts: train=1080, val=360, test=673
Saved: artifacts/pems_graph_dataset.npz


## Fix window split leakage (strict horizon containment)

A window starting at time t uses:
- Input:  X[t : t+IN_LEN]
- Output: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]

To prevent label leakage across train/val/test boundaries, we require:
- Train: output_end_time ≤ TRAIN_END
- Val:   output_start_time > TRAIN_END AND output_end_time ≤ VAL_END
- Test:  output_start_time > VAL_END
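Under the same assumption (gap-free hourly index starting 2024-10-01), the strict rule removes exactly the windows whose 72-hour label span straddles a boundary, i.e. OUT_LEN - 1 = 71 windows at each of the two boundaries, which matches the counts printed below:

```python
import numpy as np
import pandas as pd

IN_LEN, OUT_LEN = 24, 72
timestamps = pd.date_range("2024-10-01", periods=2208, freq=pd.Timedelta(hours=1))
TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END = pd.Timestamp("2024-11-30 23:59:59")

starts = np.arange(len(timestamps) - (IN_LEN + OUT_LEN) + 1)
out_start = timestamps[starts + IN_LEN]
out_end = timestamps[starts + IN_LEN + OUT_LEN - 1]     # last predicted hour

train = starts[out_end <= TRAIN_END]
val = starts[(out_start > TRAIN_END) & (out_end <= VAL_END)]
test = starts[out_start > VAL_END]                      # no upper bound, so nothing dropped here

print(len(train), len(val), len(test), len(starts) - len(train) - len(val) - len(test))
```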


In [13]:
import numpy as np
import pandas as pd
from pathlib import Path

DATASET_NPZ = Path("artifacts/pems_graph_dataset.npz")
DATASET_NPZ_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")

data = np.load(DATASET_NPZ, allow_pickle=True)

X = data["X"]
Y = data["Y"]
A = data["A"]
stations = data["stations"]
timestamps = pd.to_datetime(data["timestamps"])

IN_LEN = int(data["in_len"][0])
OUT_LEN = int(data["out_len"][0])

flow_mean = data["flow_mean"]
flow_std  = data["flow_std"]
speed_mean = data["speed_mean"]
speed_std  = data["speed_std"]

T_total = X.shape[0]
max_t = T_total - (IN_LEN + OUT_LEN) + 1
starts = np.arange(max_t, dtype=np.int32)

out_start_times = timestamps[starts + IN_LEN]
out_end_times   = timestamps[starts + IN_LEN + OUT_LEN - 1]

# Must match the boundaries in the configuration cell
TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

# Strict splits
train_starts = starts[out_end_times <= TRAIN_END]
val_starts   = starts[(out_start_times > TRAIN_END) & (out_end_times <= VAL_END)]
test_starts  = starts[out_start_times > VAL_END]

print("STRICT window starts:")
print("train:", len(train_starts))
print("val:  ", len(val_starts))
print("test: ", len(test_starts))

np.savez_compressed(
    DATASET_NPZ_STRICT,
    X=X.astype(np.float32),
    Y=Y.astype(np.float32),
    A=A.astype(np.float32),
    stations=stations.astype(np.int32),
    timestamps=np.array(timestamps.astype("datetime64[ns]")),
    train_starts=train_starts,
    val_starts=val_starts,
    test_starts=test_starts,
    in_len=np.array([IN_LEN], dtype=np.int32),
    out_len=np.array([OUT_LEN], dtype=np.int32),
    flow_mean=flow_mean.astype(np.float32),
    flow_std=flow_std.astype(np.float32),
    speed_mean=speed_mean.astype(np.float32),
    speed_std=speed_std.astype(np.float32),
    seed=data["seed"],  # carry the seed over from the first artifact
)

print("Saved strict dataset to:", DATASET_NPZ_STRICT)


STRICT window starts:
train: 1009
val:   289
test:  673
Saved strict dataset to: artifacts/pems_graph_dataset_strict.npz


## Load strict dataset artifact

We will only use the strict `.npz` going forward to ensure no leakage in labels.


In [14]:
import numpy as np
import pandas as pd
from pathlib import Path

DATASET_NPZ_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")
d = np.load(DATASET_NPZ_STRICT, allow_pickle=True)

X = d["X"]          # (T,N,F)
Y = d["Y"]          # (T,N)
A = d["A"]          # (N,N)
stations = d["stations"]
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"]
val_starts   = d["val_starts"]
test_starts  = d["test_starts"]

IN_LEN = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean = d["flow_mean"]  # (N,)
flow_std  = d["flow_std"]   # (N,)

print("X:", X.shape, "Y:", Y.shape, "A:", A.shape)
print("starts:", len(train_starts), len(val_starts), len(test_starts))
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)


X: (2208, 1821, 6) Y: (2208, 1821) A: (1821, 1821)
starts: 1009 289 673
IN_LEN: 24 OUT_LEN: 72


## Baseline evaluation

We evaluate at horizons: 12, 24, 48, 72 hours ahead.

Important detail:
The output sequence starts 1 hour after the last input timestamp, so horizon `h` corresponds to output index `h-1` in the 72-step target.
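A quick check of the index arithmetic on a toy series whose value equals its time index (hypothetical sizes): for a window starting at `t`, the last input hour is `t + IN_LEN - 1`, and output row `h - 1` sits at absolute time `t + IN_LEN + h - 1`, exactly `h` hours later.

```python
import numpy as np

IN_LEN, OUT_LEN = 24, 72
T = 200
Y = np.arange(T, dtype=np.float32)[:, None]     # toy target: value == time index, N = 1

t = 10
y_out = Y[t + IN_LEN : t + IN_LEN + OUT_LEN]    # (OUT_LEN, N)
last_input_time = t + IN_LEN - 1

for h in [12, 24, 48, 72]:
    target_time = int(y_out[h - 1, 0])          # horizon h -> output index h - 1
    print(h, target_time - last_input_time)     # lead time in hours
```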


In [15]:
import numpy as np
from tqdm.auto import tqdm

EVAL_HORIZONS = [12, 24, 48, 72]

def init_metric_accumulators(horizons):
    return {
        h: {"abs_sum": 0.0, "sq_sum": 0.0, "count": 0}
        for h in horizons
    }

def finalize_metrics(acc):
    out = {}
    for h, v in acc.items():
        mae = v["abs_sum"] / max(v["count"], 1)
        rmse = np.sqrt(v["sq_sum"] / max(v["count"], 1))
        out[h] = {"MAE": mae, "RMSE": rmse}
    return out

def print_metrics(title, metrics_dict):
    print("\n" + title)
    for h in sorted(metrics_dict.keys()):
        print(f"  {h:>3}h  MAE={metrics_dict[h]['MAE']:.3f}  RMSE={metrics_dict[h]['RMSE']:.3f}")


### Baseline 1 — Persistence

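For each station, predict that every future horizon equals the **last observed flow** in the input window.

Because flow follows a strong daily cycle, persistence at 12h compares against the opposite phase of the day, so it should be far worse than at 24h, 48h, or 72h; the results below show exactly this.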
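A synthetic illustration, assuming a pure 24-hour sinusoidal cycle (real flows are noisier): the persistence error at lag h is the mean absolute change over h hours, which is large at 12 hours and essentially zero at multiples of 24.

```python
import numpy as np

hours = np.arange(24 * 30)                                 # 30 days, hourly
flow = 1000 + 800 * np.sin(2 * np.pi * hours / 24.0)       # synthetic daily cycle

def persistence_mae(series, lag):
    # predict series[t + lag] with series[t]
    return float(np.abs(series[lag:] - series[:-lag]).mean())

maes = {h: persistence_mae(flow, h) for h in [12, 24, 48, 72]}
print({h: round(v, 3) for h, v in maes.items()})
```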


In [16]:
def eval_persistence(X, Y, starts, in_len, horizons, desc=""):
    acc = init_metric_accumulators(horizons)

    for t in tqdm(starts, desc=desc):
        # last observed flow at the end of input window
        last_flow = X[t + in_len - 1, :, 0]  # (N,)

        for h in horizons:
            idx = h - 1
            true = Y[t + in_len + idx, :]     # (N,)
            pred = last_flow                  # (N,)

            err = pred - true
            acc[h]["abs_sum"] += float(np.abs(err).sum())
            acc[h]["sq_sum"]  += float((err ** 2).sum())
            acc[h]["count"]   += err.size

    return finalize_metrics(acc)

pers_val  = eval_persistence(X, Y, val_starts, IN_LEN, EVAL_HORIZONS, desc="Persistence (val)")
pers_test = eval_persistence(X, Y, test_starts, IN_LEN, EVAL_HORIZONS, desc="Persistence (test)")

print_metrics("Persistence — Validation", pers_val)
print_metrics("Persistence — Test", pers_test)


Persistence (val):   0%|          | 0/289 [00:00<?, ?it/s]

Persistence (test):   0%|          | 0/673 [00:00<?, ?it/s]


Persistence — Validation
   12h  MAE=910.133  RMSE=1437.403
   24h  MAE=151.755  RMSE=354.476
   48h  MAE=203.020  RMSE=451.021
   72h  MAE=220.150  RMSE=478.588

Persistence — Test
   12h  MAE=917.014  RMSE=1455.414
   24h  MAE=147.896  RMSE=340.247
   48h  MAE=200.592  RMSE=443.364
   72h  MAE=196.856  RMSE=431.611


### Baseline 2 — Historical Average (HA-168)

We compute a per-node seasonal mean using train data only:
- slot = (day_of_week * 24 + hour) ∈ [0..167]
- mean_flow[slot, node] = average flow in train for that slot

Forecast:
- For each horizon step, use the slot mean of that future timestamp.
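The per-slot loop can also be written with `np.add.at`, which scatters each training row into its slot sum in one pass. A sketch under the same train-only rule, checked against a direct per-slot mean on random toy data (`ha168_vectorized` is a hypothetical helper name, not used elsewhere in this notebook):

```python
import numpy as np
import pandas as pd

def ha168_vectorized(Y, timestamps, train_end):
    mask = np.asarray(timestamps <= train_end)
    ts = timestamps[mask]
    slot = ts.dayofweek.to_numpy() * 24 + ts.hour.to_numpy()   # 0..167
    sums = np.zeros((168, Y.shape[1]), dtype=np.float64)
    np.add.at(sums, slot, Y[mask])                              # scatter-add rows into their slots
    counts = np.bincount(slot, minlength=168)
    # slots never seen in train fall back to the overall train mean
    means = np.where(counts[:, None] > 0,
                     sums / np.maximum(counts, 1)[:, None],
                     Y[mask].mean(axis=0))
    return means.astype(np.float32), counts

ts = pd.date_range("2024-10-01", periods=24 * 21, freq=pd.Timedelta(hours=1))
Y = np.random.default_rng(0).uniform(0, 500, size=(len(ts), 3)).astype(np.float32)
means, counts = ha168_vectorized(Y, ts, ts[24 * 14 - 1])        # first 14 days = train
print(means.shape, counts.min(), counts.max())
```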


In [17]:
def build_ha168_means(Y, timestamps, train_end):
    train_mask = timestamps <= train_end
    Y_train = Y[train_mask]  # (T_train, N)
    ts_train = timestamps[train_mask]

    slot = ts_train.dayofweek.to_numpy() * 24 + ts_train.hour.to_numpy()  # (T_train,)
    G = 168
    N = Y.shape[1]

    means = np.zeros((G, N), dtype=np.float32)
    counts = np.zeros((G,), dtype=np.int64)

    for g in range(G):
        m = (slot == g)
        if m.any():
            means[g] = Y_train[m].mean(axis=0)
            counts[g] = int(m.sum())
        else:
            # fallback (should be rare)
            means[g] = Y_train.mean(axis=0)
            counts[g] = 0
    return means, counts

def eval_ha168(Y, timestamps, starts, in_len, horizons, means_ha168, desc=""):
    acc = init_metric_accumulators(horizons)

    for t in tqdm(starts, desc=desc):
        for h in horizons:
            idx = h - 1
            future_time = timestamps[t + in_len + idx]
            g = int(future_time.dayofweek * 24 + future_time.hour)

            pred = means_ha168[g, :]                 # (N,)
            true = Y[t + in_len + idx, :]            # (N,)

            err = pred - true
            acc[h]["abs_sum"] += float(np.abs(err).sum())
            acc[h]["sq_sum"]  += float((err ** 2).sum())
            acc[h]["count"]   += err.size

    return finalize_metrics(acc)

TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
ha_means, ha_counts = build_ha168_means(Y, timestamps, TRAIN_END)

ha_val  = eval_ha168(Y, timestamps, val_starts, IN_LEN, EVAL_HORIZONS, ha_means, desc="HA-168 (val)")
ha_test = eval_ha168(Y, timestamps, test_starts, IN_LEN, EVAL_HORIZONS, ha_means, desc="HA-168 (test)")

print_metrics("HA-168 — Validation", ha_val)
print_metrics("HA-168 — Test", ha_test)


HA-168 (val):   0%|          | 0/289 [00:00<?, ?it/s]

HA-168 (test):   0%|          | 0/673 [00:00<?, ?it/s]


HA-168 — Validation
   12h  MAE=116.424  RMSE=258.349
   24h  MAE=123.997  RMSE=277.816
   48h  MAE=134.454  RMSE=302.695
   72h  MAE=137.523  RMSE=306.085

HA-168 — Test
   12h  MAE=119.650  RMSE=283.458
   24h  MAE=120.382  RMSE=284.451
   48h  MAE=122.195  RMSE=286.807
   72h  MAE=126.187  RMSE=293.463


## PyTorch Dataset for sliding windows

Each item returns:
- x: (C, N, IN_LEN)   where C = number of features (6)
- y: (OUT_LEN, N)     scaled flow targets

Scaling:
- We scale flow and speed using TRAIN-only mean/std (per node)
- Time features remain unchanged (already bounded by sin/cos)
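A NumPy-only sketch of what `__getitem__` does to one window, on toy arrays (sizes are illustrative): per-node standardization of the flow and speed channels, the transpose to (F, N, IN_LEN), and the inverse transform that would be needed to report model errors in original flow units, comparable to the baselines above.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, F, in_len, out_len = 50, 4, 6, 8, 5
X = rng.uniform(0, 100, size=(T, N, F)).astype(np.float32)
Y = X[:, :, 0].copy()

train_T = 30                                         # toy train period
flow_mean = X[:train_T, :, 0].mean(axis=0)
flow_std = X[:train_T, :, 0].std(axis=0) + 1e-6
speed_mean = X[:train_T, :, 1].mean(axis=0)
speed_std = X[:train_T, :, 1].std(axis=0) + 1e-6

t = 10
x = X[t : t + in_len].copy()                         # (in_len, N, F)
x[..., 0] = (x[..., 0] - flow_mean) / flow_std       # (N,) stats broadcast over time
x[..., 1] = (x[..., 1] - speed_mean) / speed_std
y = (Y[t + in_len : t + in_len + out_len] - flow_mean) / flow_std   # (out_len, N)
x = np.transpose(x, (2, 1, 0))                       # (F, N, in_len); time channels untouched

y_back = y * flow_std + flow_mean                    # undo scaling before MAE/RMSE in flow units
```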


In [18]:
import torch
from torch.utils.data import Dataset, DataLoader

class PemsWindowDataset(Dataset):
    def __init__(self, X, Y, starts, in_len, out_len, flow_mean, flow_std, speed_mean, speed_std):
        self.X = X
        self.Y = Y
        self.starts = starts
        self.in_len = in_len
        self.out_len = out_len

        self.flow_mean = flow_mean.astype(np.float32)
        self.flow_std  = flow_std.astype(np.float32)
        self.speed_mean = speed_mean.astype(np.float32)
        self.speed_std  = speed_std.astype(np.float32)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])

        x = self.X[t : t + self.in_len].copy().astype(np.float32)  # (IN_LEN, N, F)
        y = self.Y[t + self.in_len : t + self.in_len + self.out_len].copy().astype(np.float32)  # (OUT_LEN, N)

        # scale input channels: flow=0, speed=1
        x[..., 0] = (x[..., 0] - self.flow_mean[None, :]) / self.flow_std[None, :]
        x[..., 1] = (x[..., 1] - self.speed_mean[None, :]) / self.speed_std[None, :]

        # scale targets (flow)
        y = (y - self.flow_mean[None, :]) / self.flow_std[None, :]

        # rearrange x to (C, N, IN_LEN) for conv models
        x = np.transpose(x, (2, 1, 0))  # (F, N, IN_LEN)

        return torch.from_numpy(x), torch.from_numpy(y)

# Load speed scaler too
speed_mean = d["speed_mean"]
speed_std  = d["speed_std"]

train_ds = PemsWindowDataset(X, Y, train_starts, IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
val_ds   = PemsWindowDataset(X, Y, val_starts,   IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
test_ds  = PemsWindowDataset(X, Y, test_starts,  IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)

BATCH_SIZE = 16

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

xb, yb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape)
# Expect: x=(B, 6, N, 24) and y=(B, 72, N)


Batch x: torch.Size([16, 6, 1821, 24]) Batch y: torch.Size([16, 72, 1821])


In [19]:
import numpy as np

def mean_abs_change(Y, h):
    return float(np.abs(Y[h:] - Y[:-h]).mean())

for h in [1, 6, 12, 24, 48]:
    print(f"Mean |Δ| over {h}h shift: {mean_abs_change(Y, h):.3f}")


Mean |Δ| over 1h shift: 156.388
Mean |Δ| over 6h shift: 666.328
Mean |Δ| over 12h shift: 921.578
Mean |Δ| over 24h shift: 143.347
Mean |Δ| over 48h shift: 202.141


## Graph supports (normalized adjacency)

GraphWaveNet uses graph propagation through adjacency matrices ("supports").
We will build:
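- A_rw  = row-normalized adjacency $D^{-1}A$ (random-walk normalization)
- A_rwT = row-normalized transpose (the backward direction). Because A was symmetrized when it was built, A_rwT is identical to A_rw here; the second support only adds information for a directed graph.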

We store them as sparse tensors for speed (our adjacency is very sparse).
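A toy check of the random-walk normalization $D^{-1}A$ (each row sums to 1) and of a property worth knowing: for a symmetric $A$, normalizing $A^\top$ gives the same matrix, so forward and backward supports only differ on a directed graph. The 3-node matrices are hypothetical.

```python
import numpy as np

def row_normalize_dense(A, eps=1e-6):
    return A / (A.sum(axis=1, keepdims=True) + eps)

A_sym = np.array([[1.0, 0.5, 0.0],
                  [0.5, 1.0, 0.2],
                  [0.0, 0.2, 1.0]], dtype=np.float32)
A_dir = np.array([[1.0, 0.5, 0.0],          # directed: edges 0 -> 1 and 1 -> 2 only
                  [0.0, 1.0, 0.2],
                  [0.0, 0.0, 1.0]], dtype=np.float32)

for name, A in [("symmetric", A_sym), ("directed", A_dir)]:
    fwd, bwd = row_normalize_dense(A), row_normalize_dense(A.T)
    print(name, np.allclose(fwd, bwd), np.round(fwd.sum(axis=1), 4))
```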


In [20]:
import numpy as np
import torch

def row_normalize_dense(A: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    d = A.sum(axis=1, keepdims=True)
    return A / (d + eps)

def dense_to_torch_sparse(A: np.ndarray, device: str):
    A = A.astype(np.float32)
    idx = np.nonzero(A)
    values = A[idx]
    indices = np.vstack(idx)  # (2, nnz)

    indices = torch.tensor(indices, dtype=torch.long, device=device)
    values  = torch.tensor(values, dtype=torch.float32, device=device)
    shape = A.shape

    sp = torch.sparse_coo_tensor(indices, values, size=shape, device=device).coalesce()
    return sp

A_rw = row_normalize_dense(A)
A_rwT = row_normalize_dense(A.T)

supports = [
    dense_to_torch_sparse(A_rw, DEVICE),
    dense_to_torch_sparse(A_rwT, DEVICE),
]

print("Supports:", [s.shape for s in supports], "nnz:", [int(s._nnz()) for s in supports])


Supports: [torch.Size([1821, 1821]), torch.Size([1821, 1821])] nnz: [7856, 7856]
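A quick NumPy check of the two supports on a made-up 3-node graph: row normalization makes every row sum to (almost) 1, the forward and backward supports coincide for a symmetric adjacency, and dropping a single edge direction makes them differ. `A_sym` and `A_directed` are illustrative values, not PeMS data.

```python
import numpy as np

def row_normalize_dense(A: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    d = A.sum(axis=1, keepdims=True)
    return A / (d + eps)

# Hypothetical 3-node chain with self-loops (symmetric)
A_sym = np.array([[1.0, 0.5, 0.0],
                  [0.5, 1.0, 0.5],
                  [0.0, 0.5, 1.0]])
P_fwd = row_normalize_dense(A_sym)
P_bwd = row_normalize_dense(A_sym.T)
print(P_fwd.sum(axis=1))          # each row sums to ~1
print(np.allclose(P_fwd, P_bwd))  # True: identical supports for a symmetric graph

# Directed variant: remove only the edge 2 -> 1
A_directed = A_sym.copy()
A_directed[2, 1] = 0.0
print(np.allclose(row_normalize_dense(A_directed), row_normalize_dense(A_directed.T)))  # False
```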


## Diffusion graph convolution (sparse)

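Graph propagation means computing A @ X along the node axis. X is batched with shape (B, C, N, T), but `torch.sparse.mm` only multiplies a sparse matrix by a 2-D dense matrix. We therefore permute X to (N, B, C, T), flatten it to (N, B*C*T), apply a single sparse matmul, and reshape the result back to (B, C, N, T).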


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NConv(nn.Module):
    """Sparse graph multiplication: (N,N) @ (B,C,N,T) -> (B,C,N,T)"""
    def forward(self, x, A_sp):
        # x: (B,C,N,T)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)      # (N, B*C*T)
        x_r = torch.sparse.mm(A_sp, x_r)                # (N, B*C*T)
        x_out = x_r.reshape(N, B, C, T).permute(1, 2, 0, 3)  # (B,C,N,T)
        return x_out

class DiffusionGraphConv(nn.Module):
    """
    Diffusion graph conv:
    concat [X, A1 X, ..., A1^K X, A2 X, ..., A2^K X] along channels (K = order),
    then a 1x1 conv down to c_out channels.

    order=1 means only one hop per support (fast and stable).
    """
    def __init__(self, c_in, c_out, supports, order=1, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.order = order
        self.dropout = dropout

        # total input channels after concatenation
        c_total = c_in * (1 + len(supports) * order)
        self.mlp = nn.Conv2d(c_total, c_out, kernel_size=(1, 1))

    def forward(self, x):
        out = [x]
        for A in self.supports:
            x1 = self.nconv(x, A)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, A)
                out.append(x1)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h
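The permute-and-reshape trick in `NConv` can be checked against an explicit sum over neighbors. This NumPy sketch mirrors the same steps with a dense matrix standing in for the sparse support (an assumption made only to keep the check dependency-free) and compares the result with `einsum`; sizes and values are random.

```python
import numpy as np

def nconv_via_reshape(x: np.ndarray, A: np.ndarray) -> np.ndarray:
    """(N,N) @ (B,C,N,T) -> (B,C,N,T), following the NConv reshape steps."""
    B, C, N, T = x.shape
    x_r = x.transpose(2, 0, 1, 3).reshape(N, -1)           # (N, B*C*T)
    out = A @ x_r                                          # one matrix multiply
    return out.reshape(N, B, C, T).transpose(1, 2, 0, 3)   # back to (B, C, N, T)

rng = np.random.default_rng(0)
x_demo = rng.normal(size=(2, 3, 5, 4))   # (B, C, N, T)
A_demo = rng.random((5, 5))              # (N, N)

# Reference: out[b, c, n, t] = sum_m A[n, m] * x[b, c, m, t]
ref = np.einsum("nm,bcmt->bcnt", A_demo, x_demo)
print(np.allclose(nconv_via_reshape(x_demo, A_demo), ref))  # True
```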


## GraphWaveNet baseline model

We use:
- Causal dilated temporal convolutions (TCN-style)
- Gated activations (tanh * sigmoid)
- Residual + skip connections
- Diffusion graph convolution inside each layer

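Output head:
- Takes the embedding at the final time step and maps it with a 1x1 convolution to OUT_LEN channels, one per horizon, so all horizons are predicted in one forward pass (direct multi-horizon, no autoregressive rollout)
- Output shape: (B, OUT_LEN, N)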


In [23]:
class CausalConv2d(nn.Module):
    """Causal conv along time axis only (last dimension)."""
    def __init__(self, c_in, c_out, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(
            c_in, c_out,
            kernel_size=(1, kernel_size),
            dilation=(1, dilation)
        )

    def forward(self, x):
        # F.pad pads the last dim first: (time_left, time_right, node_top, node_bottom).
        # Pad only the left of the time axis so each output step sees current and past inputs only.
        x = F.pad(x, (self.pad, 0, 0, 0))
        return self.conv(x)

class GraphWaveNet(nn.Module):
    def __init__(
        self,
        num_nodes: int,
        in_dim: int,
        out_len: int,
        supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=3,
        gcn_order=1,
        dropout=0.3,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.out_len = out_len
        self.dropout = dropout

        self.start_conv = nn.Conv2d(in_dim, residual_channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs   = nn.ModuleList()
        self.res_convs    = nn.ModuleList()
        self.skip_convs   = nn.ModuleList()
        self.bn           = nn.ModuleList()
        self.gconvs       = nn.ModuleList()

        # Build temporal + graph blocks
        for _ in range(blocks):
            for i in range(layers_per_block):
                dilation = 2 ** i

                self.filter_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))
                self.gate_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))

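                # NOTE: res_convs is built but never used in forward(); the graph conv already
                # maps dilation_channels back to residual_channels. It is left in place so the
                # random initialization (and therefore the logged results) stays unchanged.
                self.res_convs.append(nn.Conv2d(dilation_channels, residual_channels, kernel_size=(1, 1)))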
                self.skip_convs.append(nn.Conv2d(dilation_channels, skip_channels, kernel_size=(1, 1)))

                self.gconvs.append(
                    DiffusionGraphConv(dilation_channels, residual_channels, supports, order=gcn_order, dropout=dropout)
                )

                self.bn.append(nn.BatchNorm2d(residual_channels))

        self.end_conv_1 = nn.Conv2d(skip_channels, end_channels, kernel_size=(1, 1))
        self.end_conv_2 = nn.Conv2d(end_channels, out_len, kernel_size=(1, 1))  # outputs OUT_LEN channels

    def forward(self, x):
        """
        x: (B, F, N, T_in)
        return: (B, OUT_LEN, N)
        """
        x = self.start_conv(x)  # (B, residual, N, T)
        skip = None

        for i in range(len(self.filter_convs)):
            residual = x

            # gated TCN
            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = filt * gate
            x = F.dropout(x, p=self.dropout, training=self.training)

            # skip
            s = self.skip_convs[i](x)
            skip = s if skip is None else (skip + s)

            # graph conv -> residual channels
            x = self.gconvs[i](x)

            # residual connection (time length is preserved by causal padding)
            x = x + residual
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))

        # Use last time step to predict future horizons
        x_last = x[..., -1:].contiguous()          # (B, end_channels, N, 1)
        out = self.end_conv_2(x_last).squeeze(-1)  # (B, OUT_LEN, N)
        return out
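Each causal layer with kernel size k and dilation d extends the receptive field by (k - 1)·d, so the stack covers 1 + Σ (k - 1)·d_i input steps. The helper below is hypothetical, but it is the same loop that `GraphWaveNetTimeAware` uses further down. With this baseline's configuration (kernel_size=2, blocks=2, layers_per_block=3, dilations 1, 2, 4 per block) it gives 15, so the gated-conv path to the final time step, which is all the output head reads, covers only the last 15 of the IN_LEN = 24 input hours. Setting `layers_per_block=4` raises it to 31.

```python
def receptive_field(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    """Receptive field of stacked causal convs with dilations 1, 2, 4, ... in each block."""
    rf = 1
    for _ in range(blocks):
        for i in range(layers_per_block):
            rf += (kernel_size - 1) * (2 ** i)
    return rf

print(receptive_field(2, 2, 3))  # 15 (baseline above)
print(receptive_field(2, 2, 4))  # 31 (>= IN_LEN = 24)
```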


## Training & evaluation loop

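We train with MSE on standardized targets (per-station z-scores, so stations with very different flow levels contribute on a comparable scale), then report MAE/RMSE in original flow units at horizons of 12, 24, 48 and 72 hours ahead (output index h - 1 of the OUT_LEN = 72 predictions).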

Early stopping monitors average validation MAE across horizons.
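`eval_horizons` accumulates absolute and squared errors batch by batch and divides once at the end. Because every batch adds `err.numel()` to the count, this yields the MAE and RMSE over all (window, node) pairs at a horizon. Averaging per-batch metrics instead would over-weight the last batch of an epoch, which is usually smaller than BATCH_SIZE. A NumPy sketch with made-up errors and uneven batches:

```python
import numpy as np

rng = np.random.default_rng(1)
err_all = rng.normal(scale=100.0, size=(19, 50))  # (windows, nodes) errors at one horizon
batches = np.split(err_all, [8, 16])              # uneven batches: 8, 8, 3 windows

acc_abs, acc_sq, count = 0.0, 0.0, 0
for e in batches:
    acc_abs += float(np.abs(e).sum())
    acc_sq += float((e ** 2).sum())
    count += e.size

mae = acc_abs / count
rmse = (acc_sq / count) ** 0.5
print(np.isclose(mae, np.abs(err_all).mean()))           # True
print(np.isclose(rmse, np.sqrt((err_all ** 2).mean())))  # True

# Mean of per-batch MAEs: gives the 3-window batch the same weight as the 8-window ones
naive = float(np.mean([np.abs(e).mean() for e in batches]))
print(round(mae, 3), round(naive, 3))
```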


In [25]:
import torch
import numpy as np
from tqdm.auto import tqdm

# tensors for unscaling (put on device once)
flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

EVAL_HORIZONS = [12, 24, 48, 72]

@torch.no_grad()
def eval_horizons(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)   # (B,F,N,T)
        yb = yb.to(DEVICE, non_blocking=True)   # (B,OUT,N) scaled

        pred = model(xb)                        # (B,OUT,N) scaled

        # unscale to original flow units
        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        for h in EVAL_HORIZONS:
            idx = h - 1
            err = pred_u[:, idx, :] - true_u[:, idx, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

def train_gwn(
    epochs=30,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=6,
):
    model = GraphWaveNet(
        num_nodes=X.shape[1],
        in_dim=X.shape[2],     # 6 features
        out_len=OUT_LEN,
        supports=supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=3,
        gcn_order=1,
        dropout=0.3,
    ).to(DEVICE)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.MSELoss()

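    # AMP stays off: the sparse graph ops are not reliable in fp16.
    # With enabled=False, the scaler calls below are pass-throughs to the plain optimizer step.
    use_amp = False
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)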

    best_score = float("inf")
    best_state = None
    bad = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = model(xb)
                loss = loss_fn(pred, yb)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            scaler.step(opt)
            scaler.update()

            running += float(loss.item())

        # validation metrics (original scale)
        val_metrics = eval_horizons(model, val_loader)
        score = avg_mae(val_metrics)

        print(f"Epoch {epoch}: train_loss={running/len(train_loader):.6f}  val_avg_MAE={score:.3f}")
        print_metrics("Val metrics", val_metrics)

        if score < best_score:
            best_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}. Best val_avg_MAE={best_score:.3f}")
                break

    # restore the best validation checkpoint (guard against every epoch scoring NaN)
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

gwn_model = train_gwn(epochs=30, lr=1e-3, weight_decay=1e-4, clip=5.0, patience=6)

# Final evaluation
val_m = eval_horizons(gwn_model, val_loader)
test_m = eval_horizons(gwn_model, test_loader)

print_metrics("GraphWaveNet — Validation", val_m)
print_metrics("GraphWaveNet — Test", test_m)


Epoch 1: train_loss=0.529129  val_avg_MAE=203.973

Val metrics
   12h  MAE=225.832  RMSE=393.051
   24h  MAE=184.901  RMSE=343.538
   48h  MAE=208.168  RMSE=379.516
   72h  MAE=196.993  RMSE=369.422


Epoch 2: train_loss=0.325922  val_avg_MAE=185.356

Val metrics
   12h  MAE=193.387  RMSE=348.977
   24h  MAE=166.734  RMSE=319.229
   48h  MAE=185.582  RMSE=349.533
   72h  MAE=195.720  RMSE=369.310


Epoch 3: train_loss=0.287739  val_avg_MAE=180.951

Val metrics
   12h  MAE=179.654  RMSE=333.382
   24h  MAE=157.967  RMSE=309.919
   48h  MAE=187.718  RMSE=347.820
   72h  MAE=198.465  RMSE=367.385


Epoch 4: train_loss=0.268916  val_avg_MAE=175.876

Val metrics
   12h  MAE=186.728  RMSE=342.408
   24h  MAE=154.416  RMSE=309.196
   48h  MAE=171.948  RMSE=328.877
   72h  MAE=190.413  RMSE=357.469


Epoch 5: train_loss=0.257975  val_avg_MAE=170.299

Val metrics
   12h  MAE=184.271  RMSE=338.048
   24h  MAE=148.209  RMSE=299.617
   48h  MAE=168.291  RMSE=323.040
   72h  MAE=180.426  RMSE=343.043


Epoch 6: train_loss=0.250068  val_avg_MAE=164.475

Val metrics
   12h  MAE=162.037  RMSE=303.679
   24h  MAE=150.514  RMSE=301.402
   48h  MAE=164.907  RMSE=316.562
   72h  MAE=180.444  RMSE=343.314


Epoch 7: train_loss=0.244371  val_avg_MAE=167.873

Val metrics
   12h  MAE=164.548  RMSE=308.877
   24h  MAE=156.601  RMSE=304.825
   48h  MAE=176.487  RMSE=327.605
   72h  MAE=173.856  RMSE=331.363


Epoch 8: train_loss=0.239806  val_avg_MAE=160.063

Val metrics
   12h  MAE=160.283  RMSE=302.982
   24h  MAE=148.200  RMSE=298.973
   48h  MAE=161.825  RMSE=314.139
   72h  MAE=169.943  RMSE=325.858


Epoch 9: train_loss=0.234704  val_avg_MAE=163.708

Val metrics
   12h  MAE=158.480  RMSE=303.029
   24h  MAE=158.671  RMSE=308.487
   48h  MAE=163.579  RMSE=313.825
   72h  MAE=174.101  RMSE=332.094


Epoch 10: train_loss=0.234201  val_avg_MAE=155.655

Val metrics
   12h  MAE=153.175  RMSE=294.953
   24h  MAE=141.512  RMSE=293.586
   48h  MAE=159.561  RMSE=315.417
   72h  MAE=168.371  RMSE=323.676


Epoch 11: train_loss=0.231096  val_avg_MAE=154.549

Val metrics
   12h  MAE=152.754  RMSE=293.402
   24h  MAE=139.974  RMSE=287.064
   48h  MAE=157.335  RMSE=307.804
   72h  MAE=168.134  RMSE=323.799


Epoch 12: train_loss=0.229034  val_avg_MAE=164.792

Val metrics
   12h  MAE=151.243  RMSE=291.496
   24h  MAE=147.712  RMSE=297.619
   48h  MAE=167.106  RMSE=321.370
   72h  MAE=193.109  RMSE=358.011


Epoch 13: train_loss=0.223892  val_avg_MAE=164.343

Val metrics
   12h  MAE=148.523  RMSE=289.870
   24h  MAE=146.654  RMSE=296.672
   48h  MAE=173.357  RMSE=328.243
   72h  MAE=188.839  RMSE=347.545


Epoch 14: train_loss=0.222705  val_avg_MAE=166.283

Val metrics
   12h  MAE=153.376  RMSE=294.916
   24h  MAE=155.567  RMSE=312.574
   48h  MAE=176.343  RMSE=332.107
   72h  MAE=179.847  RMSE=333.986


Epoch 15: train_loss=0.221832  val_avg_MAE=163.332

Val metrics
   12h  MAE=169.177  RMSE=314.841
   24h  MAE=143.233  RMSE=290.326
   48h  MAE=168.777  RMSE=325.777
   72h  MAE=172.140  RMSE=329.309


Epoch 16: train_loss=0.221454  val_avg_MAE=155.481

Val metrics
   12h  MAE=144.691  RMSE=279.696
   24h  MAE=139.187  RMSE=284.711
   48h  MAE=161.616  RMSE=311.975
   72h  MAE=176.428  RMSE=331.767


Epoch 17: train_loss=0.217759  val_avg_MAE=155.037

Val metrics
   12h  MAE=151.572  RMSE=293.354
   24h  MAE=139.192  RMSE=286.145
   48h  MAE=161.421  RMSE=316.536
   72h  MAE=167.963  RMSE=321.738
Early stopping at epoch 17. Best val_avg_MAE=154.549


GraphWaveNet — Validation
   12h  MAE=152.754  RMSE=293.402
   24h  MAE=139.974  RMSE=287.064
   48h  MAE=157.335  RMSE=307.804
   72h  MAE=168.134  RMSE=323.799

GraphWaveNet — Test
   12h  MAE=156.900  RMSE=298.198
   24h  MAE=134.966  RMSE=269.862
   48h  MAE=153.469  RMSE=300.160
   72h  MAE=161.221  RMSE=310.938


# GraphWaveNet Baseline (Journal-Grade, Revised)

This notebook section trains a strong GraphWaveNet baseline on PeMS using:
- Strict train/val/test window splits (no horizon leakage)
- Direction-aware adjacency: edges only within same freeway AND same direction
- Known future calendar covariates (hour-of-day/day-of-week), which are available at forecast time
- Multi-horizon evaluation at 12/24/48/72 hours

Outputs:
- Validation and test MAE/RMSE at each horizon


In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 42
set_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


In [2]:
DATASET_BASE = Path("artifacts/pems_graph_dataset.npz")
DATASET_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")

TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

def make_strict_dataset(base_npz: Path, strict_npz: Path):
    d = np.load(base_npz, allow_pickle=True)

    X = d["X"]; Y = d["Y"]; A = d["A"]
    stations = d["stations"]
    timestamps = pd.to_datetime(d["timestamps"])
    IN_LEN = int(d["in_len"][0])
    OUT_LEN = int(d["out_len"][0])

    flow_mean = d["flow_mean"]; flow_std = d["flow_std"]
    speed_mean = d["speed_mean"]; speed_std = d["speed_std"]

    T_total = X.shape[0]
    max_t = T_total - (IN_LEN + OUT_LEN) + 1
    starts = np.arange(max_t, dtype=np.int32)

    out_start_times = timestamps[starts + IN_LEN]
    out_end_times   = timestamps[starts + IN_LEN + OUT_LEN - 1]

    train_starts = starts[out_end_times <= TRAIN_END]
    val_starts   = starts[(out_start_times > TRAIN_END) & (out_end_times <= VAL_END)]
    test_starts  = starts[out_start_times > VAL_END]

    np.savez_compressed(
        strict_npz,
        X=X.astype(np.float32),
        Y=Y.astype(np.float32),
        A=A.astype(np.float32),
        stations=stations.astype(np.int32),
        timestamps=np.array(timestamps.astype("datetime64[ns]")),
        train_starts=train_starts,
        val_starts=val_starts,
        test_starts=test_starts,
        in_len=np.array([IN_LEN], dtype=np.int32),
        out_len=np.array([OUT_LEN], dtype=np.int32),
        flow_mean=flow_mean.astype(np.float32),
        flow_std=flow_std.astype(np.float32),
        speed_mean=speed_mean.astype(np.float32),
        speed_std=speed_std.astype(np.float32),
        seed=np.array([SEED], dtype=np.int32),
    )
    print("Saved strict dataset:", strict_npz)

if (not DATASET_STRICT.exists()) and DATASET_BASE.exists():
    make_strict_dataset(DATASET_BASE, DATASET_STRICT)

assert DATASET_STRICT.exists(), "Strict dataset not found. Make sure artifacts are created."

d = np.load(DATASET_STRICT, allow_pickle=True)
X = d["X"]                      # (T,N,F)
Y = d["Y"]                      # (T,N)
stations = d["stations"]        # (N,)
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"]
val_starts   = d["val_starts"]
test_starts  = d["test_starts"]

IN_LEN = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean = d["flow_mean"]
flow_std  = d["flow_std"]
speed_mean = d["speed_mean"]
speed_std  = d["speed_std"]

print("X:", X.shape, "Y:", Y.shape)
print("N stations:", len(stations), "T:", len(timestamps))
print("starts:", len(train_starts), len(val_starts), len(test_starts))
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)


X: (2208, 1821, 6) Y: (2208, 1821)
N stations: 1821 T: 2208
starts: 1009 289 673
IN_LEN: 24 OUT_LEN: 72
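The strict rule assigns a window to a split only if its entire OUT_LEN target span lies inside that period. The input window may reach back into the previous period, which is fine because those hours are already observed. Windows whose targets straddle a boundary are dropped, OUT_LEN - 1 of them per boundary for hourly data. The sketch below re-applies the same masks to a synthetic hourly index. `strict_split` is a hypothetical helper, and the start date 2024-10-01 is an assumption: it is simply a start that, with T = 2208, IN_LEN = 24 and OUT_LEN = 72, reproduces the 1009 / 289 / 673 split sizes printed above.

```python
import numpy as np
import pandas as pd

def strict_split(timestamps, in_len, out_len, train_end, val_end):
    """Hypothetical helper mirroring the window masks in make_strict_dataset."""
    starts = np.arange(len(timestamps) - (in_len + out_len) + 1)
    out_start = timestamps[starts + in_len]               # first target timestamp
    out_end = timestamps[starts + in_len + out_len - 1]   # last target timestamp
    train = starts[out_end <= train_end]
    val = starts[(out_start > train_end) & (out_end <= val_end)]
    test = starts[out_start > val_end]
    return starts, train, val, test

# Assumed hourly index; the start date is an illustration, not read from the data
ts_demo = pd.Timestamp("2024-10-01") + pd.to_timedelta(np.arange(2208), unit="h")
all_starts, tr_idx, va_idx, te_idx = strict_split(
    ts_demo, 24, 72,
    pd.Timestamp("2024-11-15 23:59:59"), pd.Timestamp("2024-11-30 23:59:59"),
)
print(len(tr_idx), len(va_idx), len(te_idx))                    # 1009 289 673
print(len(all_starts) - (len(tr_idx) + len(va_idx) + len(te_idx)))  # 142 = 2 * (72 - 1)
```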


## Direction-aware adjacency (recommended baseline graph)

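We build edges only among sensors that share:
- the same freeway (Fwy)
- the same direction of travel (Direction of Travel)

Within each (freeway, direction) group, stations are sorted by absolute postmile (Abs PM) and each station is linked to up to k = 4 neighbors on either side. Edge weights use a Gaussian kernel exp(-d²/σ²) on the postmile distance d, with σ set to the median positive gap between consecutive stations; self-loops get weight 1.

This keeps opposite-direction stations apart. Sensors on the two carriageways at the same postmile can follow quite different patterns (for example, opposite morning and evening commute peaks), so linking them tends to blur the signal that graph propagation is meant to share.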


In [4]:
TRAFFIC_CSV = Path("cleaned_traffic_data.csv")
META_XLSX   = Path("pems_output.xlsx")

assert TRAFFIC_CSV.exists(), "cleaned_traffic_data.csv not found in working directory."
assert META_XLSX.exists(), "pems_output.xlsx not found in working directory."

# 1) station -> direction mapping (mode direction per station)
tmp = pd.read_csv(TRAFFIC_CSV, usecols=["Station", "Direction of Travel"])
tmp = tmp.rename(columns={"Station": "station", "Direction of Travel": "direction"})
tmp["station"] = pd.to_numeric(tmp["station"], errors="coerce").astype("Int64")
tmp = tmp.dropna(subset=["station"])
tmp["station"] = tmp["station"].astype(int)

station_dir = tmp.groupby("station")["direction"].agg(lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0])
station_dir = station_dir.reindex(stations)

print("Direction mapping coverage:", float(station_dir.notna().mean()))
print("Direction counts:\n", station_dir.value_counts(dropna=False))

# 2) metadata
meta = pd.read_excel(META_XLSX)
# Expect a "station" column (from the earlier merge); fall back to renaming "ID" if it is missing.
if "station" not in meta.columns and "ID" in meta.columns:
    meta = meta.rename(columns={"ID": "station"})
meta["station"] = pd.to_numeric(meta["station"], errors="coerce").astype("Int64")
meta = meta.dropna(subset=["station"])
meta["station"] = meta["station"].astype(int)

assert "Fwy" in meta.columns and "Abs PM" in meta.columns, "Metadata must contain 'Fwy' and 'Abs PM'."

def build_adjacency_fwy_dir(meta_df, stations, station_dir, k_neighbors=4):
    meta_sub = meta_df[meta_df["station"].isin(stations)].copy()
    meta_sub["fwy"] = meta_sub["Fwy"].astype(str)
    meta_sub["abs_pm"] = pd.to_numeric(meta_sub["Abs PM"], errors="coerce")
    meta_sub["direction"] = meta_sub["station"].map(station_dir)

    # If direction missing, we exclude from neighbor edges (but keep self-loop later)
    meta_sub = meta_sub.dropna(subset=["abs_pm", "direction"]).copy()
    meta_sub["direction"] = meta_sub["direction"].astype(str)

    station_to_idx = {s: i for i, s in enumerate(stations)}
    N = len(stations)
    A = np.zeros((N, N), dtype=np.float32)

    # sigma = median positive gap between consecutive stations (Abs PM) within each freeway/direction
    all_dists = []
    for (fwy, direc), grp in meta_sub.sort_values(["fwy", "direction", "abs_pm"]).groupby(["fwy", "direction"]):
        pm = grp["abs_pm"].values
        if len(pm) < 2:
            continue
        d = np.diff(np.sort(pm))
        d = d[d > 0]
        all_dists.extend(d.tolist())

    sigma = float(np.median(all_dists)) if len(all_dists) else 0.5
    sigma = max(sigma, 1e-3)

    def w(dist):
        return float(np.exp(- (dist**2) / (sigma**2)))

    for (fwy, direc), grp in meta_sub.sort_values(["fwy", "direction", "abs_pm"]).groupby(["fwy", "direction"]):
        grp = grp.sort_values("abs_pm")
        ids = grp["station"].astype(int).tolist()
        pms = grp["abs_pm"].astype(float).tolist()

        for i, sid in enumerate(ids):
            ii = station_to_idx[sid]
            for step in range(1, k_neighbors + 1):
                if i - step >= 0:
                    sj = ids[i-step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i-step]))
                if i + step < len(ids):
                    sj = ids[i+step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i+step]))

    np.fill_diagonal(A, 1.0)
    A = np.maximum(A, A.T)
    return A

A_dir = build_adjacency_fwy_dir(meta, stations, station_dir, k_neighbors=4)
print("A_dir shape:", A_dir.shape)
print("A_dir density:", float((A_dir > 0).mean()))

def row_normalize_dense(A, eps=1e-6):
    d = A.sum(axis=1, keepdims=True)
    return A / (d + eps)

def dense_to_sparse(A, device):
    idx = np.nonzero(A)
    values = A[idx].astype(np.float32)
    indices = np.vstack(idx)
    return torch.sparse_coo_tensor(
        torch.tensor(indices, dtype=torch.long, device=device),
        torch.tensor(values, dtype=torch.float32, device=device),
        size=A.shape,
        device=device
    ).coalesce()

A_rw  = row_normalize_dense(A_dir)
A_rwT = row_normalize_dense(A_dir.T)

supports = [dense_to_sparse(A_rw, DEVICE), dense_to_sparse(A_rwT, DEVICE)]
print("Supports nnz:", [int(s._nnz()) for s in supports])


Direction mapping coverage: 1.0
Direction counts:
 direction
E    509
W    499
S    410
N    403
Name: count, dtype: int64
A_dir shape: (1821, 1821)
A_dir density: 0.0038374073179432942
Supports nnz: [12720, 12720]
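A tiny, made-up station table shows the effect of the grouping. `toy_fwy_dir_adjacency` is a hypothetical, simplified re-implementation of the recipe in `build_adjacency_fwy_dir` (group by freeway and direction, sort by Abs PM, link up to k neighbors on each side with weight exp(-d²/σ²), σ = median positive gap, self-loops = 1), run with k = 1 so the matrix stays readable. The IDs, postmiles and directions are invented.

```python
import numpy as np
import pandas as pd

def toy_fwy_dir_adjacency(meta: pd.DataFrame, k: int = 1) -> np.ndarray:
    """Hypothetical simplified freeway/direction kNN graph with Gaussian weights."""
    idx = {s: i for i, s in enumerate(meta["station"])}
    A = np.zeros((len(meta), len(meta)))
    gaps = []
    for _, g in meta.groupby(["fwy", "direction"]):
        d = np.diff(np.sort(g["abs_pm"].to_numpy()))
        gaps.extend(d[d > 0].tolist())
    sigma = max(float(np.median(gaps)) if gaps else 0.5, 1e-3)
    for _, g in meta.sort_values("abs_pm").groupby(["fwy", "direction"]):
        ids, pms = g["station"].tolist(), g["abs_pm"].tolist()
        for i in range(len(ids)):
            for j in range(max(0, i - k), min(len(ids), i + k + 1)):
                if j != i:  # link only to the k nearest stations on each side
                    A[idx[ids[i]], idx[ids[j]]] = np.exp(-((pms[i] - pms[j]) ** 2) / sigma ** 2)
    np.fill_diagonal(A, 1.0)
    return A

meta_toy = pd.DataFrame({
    "station":   [101, 102, 103, 201, 202],
    "fwy":       ["5", "5", "5", "5", "5"],
    "direction": ["N", "N", "N", "S", "S"],
    "abs_pm":    [10.0, 10.5, 11.5, 10.0, 10.5],
})
A_toy = toy_fwy_dir_adjacency(meta_toy, k=1)
print(np.round(A_toy, 3))
```

Stations 101 and 201 sit at the same postmile but in opposite directions, so they get no edge; 101 and 103 are not linked either, because k = 1 only connects immediate neighbors.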


## Dataset and loaders

Each sample returns:
- x:  (F, N, IN_LEN)
- y:  (OUT_LEN, N)  scaled flow target
- tf: (OUT_LEN, 4)  known future calendar covariates (hour/dow sin/cos)

We compute tf from timestamps (not from future traffic), so it is NOT leakage.
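Encoding hour of day and day of week as sine/cosine pairs places them on a circle, so 23:00 and 00:00 end up exactly as far apart as 10:00 and 11:00, which a raw 0-23 integer would not express. A small check using the same formula as `time_encoding` (the dates are arbitrary):

```python
import numpy as np
import pandas as pd

def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow = dt_index.dayofweek.values
    return np.stack([
        np.sin(2 * np.pi * hours / 24.0), np.cos(2 * np.pi * hours / 24.0),
        np.sin(2 * np.pi * dow / 7.0), np.cos(2 * np.pi * dow / 7.0),
    ], axis=1).astype(np.float32)  # (T, 4)

ts_demo = pd.DatetimeIndex(["2024-11-03 10:00", "2024-11-03 11:00",
                            "2024-11-03 23:00", "2024-11-04 00:00"])
hour_part = time_encoding(ts_demo)[:, :2]  # keep only the hour-of-day pair
d_10_11 = float(np.linalg.norm(hour_part[0] - hour_part[1]))
d_23_00 = float(np.linalg.norm(hour_part[2] - hour_part[3]))
print(round(d_10_11, 4), round(d_23_00, 4))  # equal distances
```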


In [5]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)  # (T,4)

class PemsWindowDatasetTF(Dataset):
    """
    Returns:
      x:  (F, N, IN_LEN)
      y:  (OUT_LEN, N) scaled
      tf: (OUT_LEN, 4) known future time features
    """
    def __init__(self, X, Y, timestamps, starts, in_len, out_len,
                 flow_mean, flow_std, speed_mean, speed_std):
        self.X = X
        self.Y = Y
        self.timestamps = pd.DatetimeIndex(timestamps)
        self.starts = starts.astype(np.int32)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

        self.flow_mean = flow_mean.astype(np.float32)
        self.flow_std  = flow_std.astype(np.float32)
        self.speed_mean = speed_mean.astype(np.float32)
        self.speed_std  = speed_std.astype(np.float32)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])

        x = self.X[t:t+self.in_len].copy().astype(np.float32)  # (IN_LEN, N, F)
        y = self.Y[t+self.in_len:t+self.in_len+self.out_len].copy().astype(np.float32)  # (OUT_LEN, N)

        ts_future = self.timestamps[t+self.in_len:t+self.in_len+self.out_len]
        tf = time_encoding(ts_future)  # (OUT_LEN, 4)

        # scale inputs with train-set statistics: feature 0 = flow, feature 1 = speed (other features unscaled)
        x[..., 0] = (x[..., 0] - self.flow_mean[None, :]) / self.flow_std[None, :]
        x[..., 1] = (x[..., 1] - self.speed_mean[None, :]) / self.speed_std[None, :]

        # scale targets (flow)
        y = (y - self.flow_mean[None, :]) / self.flow_std[None, :]

        # (IN_LEN, N, F) -> (F, N, IN_LEN)
        x = np.transpose(x, (2, 1, 0))

        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(tf)

# Smaller batch than the earlier run (8 vs 16) to keep GPU memory in check with N=1821 nodes
BATCH_SIZE = 8

train_ds = PemsWindowDatasetTF(X, Y, timestamps, train_starts, IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
val_ds   = PemsWindowDatasetTF(X, Y, timestamps, val_starts,   IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
test_ds  = PemsWindowDatasetTF(X, Y, timestamps, test_starts,  IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)



Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])


## GraphWaveNet baseline (sparse diffusion graph conv + dilated temporal conv)

We:
- use causal dilated temporal convolutions + gating
- do sparse graph propagation with the supports
- use residual + skip connections
- use a time-aware output head with future calendar covariates


In [6]:
class NConv(nn.Module):
    """Sparse graph multiplication: (N,N) @ (B,C,N,T) -> (B,C,N,T)"""
    def forward(self, x, A_sp):
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)  # (N, B*C*T)

        # torch.sparse.mm on CUDA is only reliable in float32, so cast the dense operand first
        out = torch.sparse.mm(A_sp, x_r.float())
        out = out.reshape(N, B, C, T).permute(1, 2, 0, 3)
        return out.to(dtype=x.dtype)

class DiffusionGraphConv(nn.Module):
    def __init__(self, c_in, c_out, supports, order=1, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.order = order
        self.dropout = dropout

        c_total = c_in * (1 + len(supports) * order)
        self.mlp = nn.Conv2d(c_total, c_out, kernel_size=(1, 1))

    def forward(self, x):
        out = [x]
        for A in self.supports:
            x1 = self.nconv(x, A)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, A)
                out.append(x1)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h

class CausalConv2d(nn.Module):
    """Causal conv along time axis only."""
    def __init__(self, c_in, c_out, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=(1, kernel_size), dilation=(1, dilation))

    def forward(self, x):
        x = F.pad(x, (self.pad, 0, 0, 0))
        return self.conv(x)

class GraphWaveNetTimeAware(nn.Module):
    def __init__(
        self,
        num_nodes,
        in_dim,
        out_len,
        supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=4,   # with kernel_size=2, blocks=2: receptive field 31 >= IN_LEN (24)
        gcn_order=1,
        dropout=0.1,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.out_len = out_len
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.blocks = blocks
        self.layers_per_block = layers_per_block

        # receptive field
        receptive_field = 1
        for _ in range(blocks):
            for i in range(layers_per_block):
                receptive_field += (kernel_size - 1) * (2 ** i)
        self.receptive_field = receptive_field

        self.start_conv = nn.Conv2d(in_dim, residual_channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs   = nn.ModuleList()
        self.skip_convs   = nn.ModuleList()
        self.bn           = nn.ModuleList()
        self.gconvs       = nn.ModuleList()

        for _ in range(blocks):
            for i in range(layers_per_block):
                dilation = 2 ** i
                self.filter_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))
                self.gate_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))

                self.skip_convs.append(nn.Conv2d(dilation_channels, skip_channels, kernel_size=(1, 1)))
                self.gconvs.append(DiffusionGraphConv(dilation_channels, residual_channels, supports, order=gcn_order, dropout=dropout))
                self.bn.append(nn.BatchNorm2d(residual_channels))

        self.end_conv_1 = nn.Conv2d(skip_channels, end_channels, kernel_size=(1, 1))

        # time-aware head
        self.time_embed = nn.Linear(4, end_channels)
        self.horizon_out = nn.Linear(end_channels, 1)

    def forward(self, x, tf_future):
        """
        x:         (B, F, N, IN_LEN)   input window
        tf_future: (B, OUT_LEN, 4)     time features for each future step
        returns:   (B, OUT_LEN, N)     predictions on the scaled target
        """
        # left-pad windows shorter than the receptive field
        if x.size(-1) < self.receptive_field:
            pad_len = self.receptive_field - x.size(-1)
            x = F.pad(x, (pad_len, 0, 0, 0))

        x = self.start_conv(x)  # (B, residual, N, T)
        skip = None

        for i in range(len(self.filter_convs)):
            residual = x

            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = filt * gate
            x = F.dropout(x, p=self.dropout, training=self.training)

            s = self.skip_convs[i](x)
            skip = s if skip is None else (skip + s)

            x = self.gconvs[i](x)
            x = x + residual
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))     # (B, end_channels, N, T)

        z = x[..., -1]                     # (B, end_channels, N)
        z = z.permute(0, 2, 1)             # (B, N, end_channels)

        te = self.time_embed(tf_future)    # (B, OUT_LEN, end_channels)

        h = F.relu(z.unsqueeze(1) + te.unsqueeze(2))  # (B, OUT_LEN, N, end_channels)
        out = self.horizon_out(h).squeeze(-1)         # (B, OUT_LEN, N)
        return out
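
The receptive-field loop in `__init__` adds `(kernel_size - 1) * 2**i` for each layer, with dilations 1, 2, 4, ..., 2^(L-1) repeated in every block, so it has the closed form RF = 1 + blocks · (kernel_size − 1) · (2^L − 1). The pure-Python sketch below (helper names are hypothetical, not part of this notebook) checks the loop against that formula, shows that the defaults give 31 steps, and computes the left padding `forward` applies when the window is shorter than that. The 24-step window in the example is an illustrative assumption, not a value read from the configuration.

```python
def receptive_field_loop(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    """Mirrors the loop in GraphWaveNetTimeAware.__init__."""
    rf = 1
    for _ in range(blocks):
        for i in range(layers_per_block):
            rf += (kernel_size - 1) * (2 ** i)
    return rf


def receptive_field_closed_form(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    """Geometric series: sum_{i<L} 2**i = 2**L - 1 per block."""
    return 1 + blocks * (kernel_size - 1) * (2 ** layers_per_block - 1)


def left_pad_needed(in_len: int, receptive_field: int) -> int:
    """Left padding forward() adds when the input window is shorter than the RF."""
    return max(0, receptive_field - in_len)


rf = receptive_field_loop(kernel_size=2, blocks=2, layers_per_block=4)
assert rf == receptive_field_closed_form(2, 2, 4)
print("receptive field:", rf)                                # 31
print("pad for a 24-step window:", left_pad_needed(24, rf))  # 7 (hypothetical window)
```

Doubling `layers_per_block` grows the receptive field exponentially, while adding blocks grows it linearly; that is why two blocks of four layers are enough here.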


## Training + evaluation

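- Loss: SmoothL1 (Huber, `beta=1.0`) on SCALED targets
- Metrics: MAE/RMSE on the ORIGINAL (unscaled) flow values, at horizons 12, 24, 48 and 72
- Early stopping on the validation MAE averaged across these horizons (patience 8)
- AMP disabled: the sparse graph ops are not reliably supported in fp16, so training runs in fp32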
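
These bullets separate the space the model is optimised in from the space results are reported in. The pure-Python sketch below illustrates that split under stated assumptions: targets are z-scored as `(y - mean) / std` (consistent with how `eval_horizons` inverts them), `smooth_l1` follows the SmoothL1 definition PyTorch documents for `nn.SmoothL1Loss`, the mean/std values 500 and 200 are made up for illustration, and `EarlyStopper` is a hypothetical helper that mirrors the `best_score` / `bad` / `patience` logic in `train_gwn_baseline`. As a usage example it replays the `val_avg_MAE` values logged for epochs 1 to 23 of the run in this section.

```python
import math


def smooth_l1(pred, target, beta=1.0):
    """Mean SmoothL1: 0.5*d^2/beta if |d| < beta, else |d| - 0.5*beta."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def unscale(values, mean, std):
    """Invert z-score scaling: y = y_scaled * std + mean."""
    return [v * std + mean for v in values]


def mae_rmse(pred, true):
    """Accumulate absolute and squared error sums, divide by the count at the end."""
    abs_sum = sum(abs(p - t) for p, t in zip(pred, true))
    sq_sum = sum((p - t) ** 2 for p, t in zip(pred, true))
    n = len(pred)
    return abs_sum / n, math.sqrt(sq_sum / n)


class EarlyStopper:
    """Stops after `patience` consecutive epochs without a strictly lower score."""

    def __init__(self, patience=8):
        self.patience = patience
        self.best = float("inf")
        self.best_epoch = None
        self.bad = 0

    def step(self, epoch, score):
        """Returns True when training should stop."""
        if score < self.best:
            self.best, self.best_epoch, self.bad = score, epoch, 0
            return False
        self.bad += 1
        return self.bad >= self.patience


# Loss on scaled values; metrics after unscaling (mean=500, std=200 are illustrative).
pred_s, true_s = [0.2, -1.5, 3.0], [0.0, 0.5, 1.0]
print("SmoothL1 (scaled):", smooth_l1(pred_s, true_s))
print("MAE, RMSE (original):", mae_rmse(unscale(pred_s, 500.0, 200.0), unscale(true_s, 500.0, 200.0)))

# Replay the logged val_avg_MAE values (epochs 1-23) with patience=8.
val_avg_mae = [192.611, 189.943, 183.902, 166.079, 167.928, 163.555, 168.625,
               165.275, 161.167, 161.562, 156.545, 153.288, 147.968, 154.376,
               155.880, 151.961, 154.090, 160.386, 154.503, 154.026, 142.855,
               148.385, 147.939]
stopper = EarlyStopper(patience=8)
stop_epoch = None
for epoch, score in enumerate(val_avg_mae, start=1):
    if stopper.step(epoch, score):
        stop_epoch = epoch
        break
print("best epoch:", stopper.best_epoch, "stopped at:", stop_epoch)
```

Replaying the log shows how sensitive this protocol is to `patience`: epochs 14 to 20 are seven consecutive non-improvements, so with `patience=7` the run would have stopped at epoch 20 and missed the epoch-21 improvement to 142.855. Because MAE is linear in the error, the unscaled MAE equals the scaled MAE times `std` for a single node, which is why reporting on the original scale does not change model ranking per node but does weight high-volume stations more heavily across nodes.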


In [None]:
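EVAL_HORIZONS = [12, 24, 48, 72]
# horizon h is read at index h - 1 of the OUT_LEN axis, so OUT_LEN must cover the largest one
assert max(EVAL_HORIZONS) <= OUT_LEN, f"OUT_LEN={OUT_LEN} is shorter than max(EVAL_HORIZONS)"

# per-node train-split statistics, shaped (1, 1, N) to broadcast over (B, OUT_LEN, N)
flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)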

@torch.no_grad()
def eval_horizons(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # scaled

        # invert the scaling so metrics are in original flow units
        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        for h in EVAL_HORIZONS:
            idx = h - 1
            err = pred_u[:, idx, :] - true_u[:, idx, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

def train_gwn_baseline(
    epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=8,
):
    model = GraphWaveNetTimeAware(
        num_nodes=X.shape[1],
        in_dim=X.shape[2],
        out_len=OUT_LEN,
        supports=supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=4,
        gcn_order=1,
        dropout=0.1,
    ).to(DEVICE)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    bad = 0

    # IMPORTANT: no AMP/autocast here; the sparse graph ops are not reliably
    # supported in fp16, so training runs entirely in fp32
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            running += float(loss.item())

        val_metrics = eval_horizons(model, val_loader)
        score = avg_mae(val_metrics)

        print(f"\nEpoch {epoch}: train_loss={running/len(train_loader):.6f}  val_avg_MAE={score:.3f}")
        print_metrics("Validation metrics", val_metrics)

        if score < best_score:
            best_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                break

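    if best_state is None:
        # every validation score was NaN, so no checkpoint was ever stored
        raise RuntimeError("No finite validation score recorded; check for NaN losses.")
    model.load_state_dict(best_state)
    return model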

gwn = train_gwn_baseline()

val_m = eval_horizons(gwn, val_loader)
test_m = eval_horizons(gwn, test_loader)

print_metrics("GraphWaveNet Baseline — Validation", val_m)
print_metrics("GraphWaveNet Baseline — Test", test_m)


Epoch 1: train_loss=0.166336  val_avg_MAE=192.611

Validation metrics
   12h  MAE=191.403  RMSE=358.438
   24h  MAE=189.733  RMSE=363.071
   48h  MAE=196.351  RMSE=372.638
   72h  MAE=192.957  RMSE=360.658


Epoch 2: train_loss=0.117438  val_avg_MAE=189.943

Validation metrics
   12h  MAE=182.940  RMSE=338.154
   24h  MAE=187.478  RMSE=348.945
   48h  MAE=195.416  RMSE=359.328
   72h  MAE=193.937  RMSE=352.510


Epoch 3: train_loss=0.104683  val_avg_MAE=183.902

Validation metrics
   12h  MAE=178.179  RMSE=326.433
   24h  MAE=180.125  RMSE=333.626
   48h  MAE=189.455  RMSE=346.797
   72h  MAE=187.850  RMSE=341.038


Epoch 4: train_loss=0.097027  val_avg_MAE=166.079

Validation metrics
   12h  MAE=158.517  RMSE=299.517
   24h  MAE=160.295  RMSE=308.539
   48h  MAE=172.899  RMSE=327.715
   72h  MAE=172.604  RMSE=323.193


Epoch 5: train_loss=0.091661  val_avg_MAE=167.928

Validation metrics
   12h  MAE=158.100  RMSE=296.278
   24h  MAE=162.273  RMSE=308.964
   48h  MAE=175.807  RMSE=330.409
   72h  MAE=175.530  RMSE=328.313


Epoch 6: train_loss=0.088536  val_avg_MAE=163.555

Validation metrics
   12h  MAE=154.952  RMSE=290.983
   24h  MAE=155.721  RMSE=300.056
   48h  MAE=170.827  RMSE=322.243
   72h  MAE=172.719  RMSE=322.721


Epoch 7: train_loss=0.085676  val_avg_MAE=168.625

Validation metrics
   12h  MAE=154.776  RMSE=290.849
   24h  MAE=162.844  RMSE=308.113
   48h  MAE=177.034  RMSE=329.264
   72h  MAE=179.846  RMSE=333.807


Epoch 8: train_loss=0.083878  val_avg_MAE=165.275

Validation metrics
   12h  MAE=150.444  RMSE=283.392
   24h  MAE=157.638  RMSE=302.180
   48h  MAE=174.519  RMSE=327.460
   72h  MAE=178.499  RMSE=334.144


Epoch 9: train_loss=0.082351  val_avg_MAE=161.167

Validation metrics
   12h  MAE=143.012  RMSE=274.647
   24h  MAE=154.429  RMSE=299.217
   48h  MAE=172.190  RMSE=325.996
   72h  MAE=175.035  RMSE=329.621


Epoch 10: train_loss=0.081194  val_avg_MAE=161.562

Validation metrics
   12h  MAE=144.849  RMSE=278.380
   24h  MAE=156.305  RMSE=303.425
   48h  MAE=171.748  RMSE=326.235
   72h  MAE=173.348  RMSE=327.463


Epoch 11: train_loss=0.079560  val_avg_MAE=156.545

Validation metrics
   12h  MAE=140.382  RMSE=272.981
   24h  MAE=149.853  RMSE=294.445
   48h  MAE=166.326  RMSE=321.796
   72h  MAE=169.619  RMSE=328.382


Epoch 12: train_loss=0.078578  val_avg_MAE=153.288

Validation metrics
   12h  MAE=135.332  RMSE=264.037
   24h  MAE=145.388  RMSE=287.101
   48h  MAE=164.178  RMSE=317.646
   72h  MAE=168.254  RMSE=326.255


Epoch 13: train_loss=0.077964  val_avg_MAE=147.968

Validation metrics
   12h  MAE=132.972  RMSE=263.475
   24h  MAE=139.162  RMSE=277.580
   48h  MAE=157.886  RMSE=305.720
   72h  MAE=161.853  RMSE=312.469


Epoch 14: train_loss=0.076900  val_avg_MAE=154.376

Validation metrics
   12h  MAE=139.283  RMSE=273.490
   24h  MAE=146.242  RMSE=291.323
   48h  MAE=164.243  RMSE=320.756
   72h  MAE=167.738  RMSE=327.017


Epoch 15: train_loss=0.075930  val_avg_MAE=155.880

Validation metrics
   12h  MAE=137.547  RMSE=271.921
   24h  MAE=148.102  RMSE=298.113
   48h  MAE=167.544  RMSE=326.551
   72h  MAE=170.325  RMSE=330.410


Epoch 16: train_loss=0.075627  val_avg_MAE=151.961

Validation metrics
   12h  MAE=135.106  RMSE=264.507
   24h  MAE=143.022  RMSE=284.876
   48h  MAE=162.791  RMSE=315.517
   72h  MAE=166.925  RMSE=322.595


Epoch 17: train_loss=0.075027  val_avg_MAE=154.090

Validation metrics
   12h  MAE=137.682  RMSE=268.177
   24h  MAE=145.152  RMSE=287.310
   48h  MAE=164.215  RMSE=319.679
   72h  MAE=169.310  RMSE=329.051


Epoch 18: train_loss=0.074583  val_avg_MAE=160.386

Validation metrics
   12h  MAE=138.560  RMSE=268.469
   24h  MAE=151.101  RMSE=294.014
   48h  MAE=172.828  RMSE=328.903
   72h  MAE=179.054  RMSE=340.562


Epoch 19: train_loss=0.074133  val_avg_MAE=154.503

Validation metrics
   12h  MAE=137.373  RMSE=262.704
   24h  MAE=145.690  RMSE=282.779
   48h  MAE=165.019  RMSE=313.960
   72h  MAE=169.931  RMSE=324.774


Epoch 20: train_loss=0.074154  val_avg_MAE=154.026

Validation metrics
   12h  MAE=135.481  RMSE=266.801
   24h  MAE=145.055  RMSE=289.757
   48h  MAE=165.571  RMSE=322.917
   72h  MAE=169.995  RMSE=331.973


Epoch 21: train_loss=0.073161  val_avg_MAE=142.855

Validation metrics
   12h  MAE=126.640  RMSE=256.993
   24h  MAE=134.450  RMSE=273.870
   48h  MAE=153.180  RMSE=302.537
   72h  MAE=157.152  RMSE=309.364


Epoch 22: train_loss=0.073137  val_avg_MAE=148.385

Validation metrics
   12h  MAE=130.637  RMSE=263.231
   24h  MAE=139.916  RMSE=283.505
   48h  MAE=159.444  RMSE=314.125
   72h  MAE=163.542  RMSE=323.153


Epoch 23: train_loss=0.072553  val_avg_MAE=147.939

Validation metrics
   12h  MAE=132.895  RMSE=261.557
   24h  MAE=137.680  RMSE=276.316
   48h  MAE=158.085  RMSE=309.733
   72h  MAE=163.094  RMSE=321.161

