# Multi-Horizon Traffic Forecasting on PeMS (Graph Models)

## Goal (Paper Claim)
Build a leakage-safe, reproducible pipeline on PeMS traffic data and evaluate multi-horizon forecasting models fairly.

Primary goal:
- Demonstrate that the proposed **GraphWaveNet-GRU-LSTM** performs best on PeMS when every model is trained and evaluated under the same train/val/test protocol.

Key principles:
- No time leakage (all statistics computed from train only).
- One shared dataset representation for all deep models: **X ∈ R^{T×N×F}, Y ∈ R^{T×N}**.
- One fixed evaluation harness (same horizons, same metrics, same seeds).
- Strong baselines + ablations:
  - HA / Persistence
  - GRU / LSTM (non-graph)
  - GraphWaveNet
  - GraphWaveNet+GRU
  - GraphWaveNet+LSTM
  - **GraphWaveNet+GRU+LSTM (proposed)**


In [1]:
!pip -q install -r requirements.txt



In [2]:
!pip -q install numpy pandas openpyxl scikit-learn torch tqdm



In [3]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd

import torch
from tqdm.auto import tqdm

def set_seed(seed: int = 42, deterministic: bool = True):
    """
    Sets seeds for reproducibility.
    deterministic=True makes results more reproducible but can reduce speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

SEED = 42
set_seed(SEED, deterministic=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


## Configuration

We fix:
- Input window length (`IN_LEN`) and forecast horizon length (`OUT_LEN`)
- Train/val/test boundaries (time-based split)
- Station inclusion rule (coverage threshold)
- Output dataset artifact path (so every model uses the same processed dataset)

Important:
GraphWaveNet expects a consistent node set and a continuous time axis,
so we build a clean matrix format (timestamp × station).


In [4]:
# -------------------------
# Paths (your files are visible in the Paperspace file pane)
# -------------------------
TRAFFIC_CSV = Path("cleaned_traffic_data.csv")
META_XLSX   = Path("pems_output.xlsx")

assert TRAFFIC_CSV.exists(), f"Missing {TRAFFIC_CSV}"
assert META_XLSX.exists(), f"Missing {META_XLSX}"

# -------------------------
# Split boundaries (same as your earlier work)
# -------------------------
TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

# -------------------------
# Forecast setup
# -------------------------
IN_LEN  = 24     # hours of history used as input
OUT_LEN = 72     # predict next 72 hours (we will evaluate at 12/24/48/72)

EVAL_HORIZONS = [12, 24, 48, 72]  # hours ahead

# -------------------------
# Station coverage threshold
# -------------------------
# 1.0 means station must have ALL timestamps present.
# 0.98 is often a good compromise if some stations are missing a few points.
COVERAGE_THRESHOLD = 0.98

# -------------------------
# Adjacency setup (static graph baseline)
# -------------------------
K_NEIGHBORS = 2   # connect up to 2 upstream + 2 downstream along the freeway chain

# -------------------------
# Output artifact (important for reproducibility)
# -------------------------
OUT_DIR = Path("artifacts")
OUT_DIR.mkdir(exist_ok=True)

DATASET_NPZ = OUT_DIR / "pems_graph_dataset.npz"
print("Will save processed dataset to:", DATASET_NPZ)


Will save processed dataset to: artifacts/pems_graph_dataset.npz


## Load raw traffic + metadata

We:
1) Load cleaned traffic data
2) Load station metadata
3) Standardize column names
4) Merge metadata onto traffic records (inner join so every station has metadata)
5) Verify timestamp parsing and basic integrity checks


In [5]:
def require_col(df: pd.DataFrame, candidates, friendly_name: str):
    """
    Find the first matching column in candidates.
    Raise a helpful error if not found.
    """
    for c in candidates:
        if c in df.columns:
            return c
    raise KeyError(
        f"Could not find column for '{friendly_name}'. Tried: {candidates}\n"
        f"Available columns: {list(df.columns)}"
    )

def to_datetime_safe(s: pd.Series) -> pd.Series:
    return pd.to_datetime(s, errors="coerce")

def pct_missing(s: pd.Series) -> float:
    return float(s.isna().mean() * 100.0)



In [6]:
traffic_raw = pd.read_csv(TRAFFIC_CSV)
meta_raw = pd.read_excel(META_XLSX)

print("Traffic shape:", traffic_raw.shape)
print("Meta shape:", meta_raw.shape)

# --- Identify expected columns robustly ---
ts_col   = require_col(traffic_raw, ["Timestamp", "timestamp", "Time", "Datetime"], "Timestamp")
st_col   = require_col(traffic_raw, ["Station", "station", "ID"], "Station ID")
flow_col = require_col(traffic_raw, ["Total Flow", "total_flow", "Flow", "total flow"], "Total Flow")
spd_col  = require_col(traffic_raw, ["Avg Speed", "avg_speed", "Speed", "Avg speed"], "Avg Speed")

lane_col = require_col(traffic_raw, ["Lane Type", "lane_type", "LaneType"], "Lane Type")
dir_col  = require_col(traffic_raw, ["Direction of Travel", "direction", "Dir"], "Direction")
dist_col = require_col(traffic_raw, ["District", "district"], "District")

# --- Standardize traffic ---
traffic = traffic_raw.rename(columns={
    ts_col: "timestamp",
    st_col: "station",
    flow_col: "total_flow",
    spd_col: "avg_speed",
    lane_col: "lane_type",
    dir_col: "direction",
    dist_col: "district",
}).copy()

traffic["timestamp"] = to_datetime_safe(traffic["timestamp"])
traffic["station"] = pd.to_numeric(traffic["station"], errors="coerce").astype("Int64")

traffic = traffic.dropna(subset=["timestamp", "station"]).copy()
traffic["station"] = traffic["station"].astype(int)

print("After basic parsing:", traffic.shape)
print("Timestamp range:", traffic["timestamp"].min(), "→", traffic["timestamp"].max())

# --- Standardize metadata ---
# station id in metadata usually is 'ID'
meta_id_col = require_col(meta_raw, ["ID", "station", "Station"], "Meta Station ID")
meta = meta_raw.rename(columns={meta_id_col: "station"}).copy()
meta["station"] = pd.to_numeric(meta["station"], errors="coerce").astype("Int64")
meta = meta.dropna(subset=["station"]).copy()
meta["station"] = meta["station"].astype(int)

print("Meta columns (peek):", list(meta.columns)[:20])

# Merge metadata (inner ensures we only keep stations that have metadata)
df = traffic.merge(meta, on="station", how="inner", validate="m:1")
print("Merged df shape:", df.shape)
print("Unique stations:", df["station"].nunique())


Traffic shape: (4114680, 42)
Meta shape: (1861, 15)
After basic parsing: (4114680, 42)
Timestamp range: 2024-10-01 00:00:00 → 2024-12-31 23:00:00
Meta columns (peek): ['Fwy', 'District', 'County', 'City', 'CA PM', 'Abs PM', 'Length', 'station', 'Name', 'Lanes', 'Type', 'Sensor Type', 'HOV', 'MS ID', 'IRM']
Merged df shape: (4051621, 56)
Unique stations: 1861


## Sanity checks

We check:
- Duplicate rows per (timestamp, station)
- Time frequency (hourly vs not)
- Missingness rates

These checks prevent silent data problems that can invalidate results.
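
As a rough illustration of the time-frequency check, the sketch below lists every place where consecutive timestamps are not exactly one hour apart. The helper name `find_time_gaps` and the toy timestamps are assumptions made for this example; they are not part of the pipeline.

```python
import numpy as np

def find_time_gaps(timestamps, expected_step=np.timedelta64(1, "h")):
    """Return (previous, next) timestamp pairs whose spacing differs from expected_step."""
    ts = np.sort(np.unique(np.asarray(timestamps, dtype="datetime64[s]")))
    deltas = np.diff(ts)
    bad = np.nonzero(deltas != expected_step)[0]
    return [(ts[i], ts[i + 1]) for i in bad]

# Toy series: 02:00 is missing from an otherwise hourly index.
toy_times = np.array(
    ["2024-10-01T00:00", "2024-10-01T01:00", "2024-10-01T03:00", "2024-10-01T04:00"],
    dtype="datetime64[s]",
)
gaps = find_time_gaps(toy_times)
print(gaps)  # one gap, between 01:00 and 03:00
```

An empty result means the time axis is gap-free, which is what the delta counts in the cell below are meant to confirm.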


In [7]:
# 1) Duplicates by (timestamp, station)
dup_count = df.duplicated(subset=["timestamp", "station"]).sum()
print("Duplicate (timestamp, station) rows:", int(dup_count))

if dup_count > 0:
    # Resolve duplicates safely: flow sums, speed averages
    df = (df.groupby(["timestamp", "station"], as_index=False)
            .agg({
                "total_flow": "sum",
                "avg_speed": "mean",
                "lane_type": "first",
                "direction": "first",
                "district": "first",
                # keep metadata columns by first
                **{c: "first" for c in meta.columns if c != "station"}
            }))
    print("After de-duplication:", df.shape)

# 2) Check time deltas
times = pd.DatetimeIndex(sorted(df["timestamp"].unique()))
deltas = pd.Series(times[1:] - times[:-1]).value_counts().head(5)
print("Most common timestamp deltas:\n", deltas)

# 3) Missingness
print("Missing total_flow (%):", pct_missing(df["total_flow"]))
print("Missing avg_speed (%):", pct_missing(df["avg_speed"]))


Duplicate (timestamp, station) rows: 0
Most common timestamp deltas:
 0 days 01:00:00    2207
Name: count, dtype: int64
Missing total_flow (%): 7.243150334150209
Missing avg_speed (%): 37.88496011843161


## Build station-time matrices

Graph models require a clean tensor format.
We create two matrices:
- Flow:  (T timestamps × N stations)
- Speed: (T timestamps × N stations)

We also select a stable station set using a coverage threshold.
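
A minimal sketch of the coverage rule on a toy presence matrix, assuming "coverage" means the fraction of timestamps at which a station has a row. The function name `select_stations_by_coverage` and the toy data are hypothetical.

```python
import numpy as np

def select_stations_by_coverage(present, threshold):
    """present: boolean (T, N) array, True where a station reported at a timestamp."""
    coverage = present.mean(axis=0)              # fraction of timestamps seen per station
    keep = np.nonzero(coverage >= threshold)[0]  # column indices of stations to keep
    return keep, coverage

toy_present = np.array([
    [True, True,  False],
    [True, True,  False],
    [True, False, True],
    [True, True,  False],
])
keep_idx, cov = select_stations_by_coverage(toy_present, threshold=0.75)
print(keep_idx, cov)  # stations 0 and 1 pass; station 2 (25% coverage) is dropped
```

Fixing the station set once, before any model is trained, is what keeps N identical across all baselines and graph models.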


In [8]:
# Full timestamp index
all_times = pd.DatetimeIndex(sorted(df["timestamp"].unique()))
T = len(all_times)

# Station coverage
counts = df.groupby("station")["timestamp"].nunique()
coverage = counts / T

keep_stations = coverage[coverage >= COVERAGE_THRESHOLD].index
df2 = df[df["station"].isin(keep_stations)].copy()

stations = np.array(sorted(df2["station"].unique()), dtype=int)
N = len(stations)

print(f"Timestamps (T) = {T}")
print(f"Stations kept (N) = {N}  (coverage threshold={COVERAGE_THRESHOLD})")

# Build matrices
flow = (df2.pivot(index="timestamp", columns="station", values="total_flow")
          .reindex(index=all_times, columns=stations)
          .sort_index())

speed = (df2.pivot(index="timestamp", columns="station", values="avg_speed")
           .reindex(index=all_times, columns=stations)
           .sort_index())

print("Flow matrix:", flow.shape, "Speed matrix:", speed.shape)
print("Flow missing fraction:", float(np.isnan(flow.to_numpy()).mean()))
print("Speed missing fraction:", float(np.isnan(speed.to_numpy()).mean()))


Timestamps (T) = 2208
Stations kept (N) = 1821  (coverage threshold=0.98)
Flow matrix: (2208, 1821) Speed matrix: (2208, 1821)
Flow missing fraction: 0.0712893656137335
Speed missing fraction: 0.3772684720928937


## Leakage-safe imputation

We must not use validation/test information when estimating fill values.

Strategy:
- Forward-fill across time (realistic streaming behavior).
- Remaining NaNs filled using TRAIN statistics only.

Flow:
- ffill → fill with per-station TRAIN mean → fill with global TRAIN mean

Speed:
- ffill → fill using a TRAIN-only group lookup (lane_type, meta type, hour, fwy, district)
- then per-station TRAIN mean → global TRAIN mean
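
The flow strategy can be sketched on a small array as follows. This is a NumPy-only illustration under stated assumptions (the helper names `ffill_time` and `impute_train_only` are made up for the example); the notebook itself uses pandas `ffill` and `fillna`.

```python
import numpy as np

def ffill_time(a):
    """Forward-fill NaNs down axis 0 of a (T, N) array; leading NaNs stay NaN."""
    a = np.asarray(a, dtype=float)
    T, N = a.shape
    # For each cell, the row index of the most recent non-NaN value (0 if none yet).
    idx = np.where(~np.isnan(a), np.arange(T)[:, None], 0)
    np.maximum.accumulate(idx, axis=0, out=idx)
    return a[idx, np.arange(N)[None, :]]

def impute_train_only(a, train_mask):
    """ffill, then fill what is left with per-station means computed on train rows only."""
    filled = ffill_time(a)
    station_mean = np.nanmean(filled[train_mask], axis=0)
    global_mean = np.nanmean(filled[train_mask])
    station_mean = np.where(np.isnan(station_mean), global_mean, station_mean)
    return np.where(np.isnan(filled), station_mean[None, :], filled)

toy = np.array([[np.nan, 10.0],
                [2.0,    np.nan],
                [4.0,    30.0],
                [np.nan, 100.0]])
train_mask = np.array([True, True, True, False])  # last row plays the role of val/test
imputed = impute_train_only(toy, train_mask)
print(imputed)
```

The leading gap in column 0 is filled with 3.0, the mean of the train rows (2 and 4); the value 100 in the held-out row never enters any fill statistic.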


In [9]:
# Identify metadata columns we need for speed lookup
meta_type_col = None
for cand in ["Type", "type", "Station Type"]:
    if cand in df2.columns:
        meta_type_col = cand
        break

fwy_col = None
for cand in ["Fwy", "FWY", "fwy", "Freeway"]:
    if cand in df2.columns:
        fwy_col = cand
        break

if meta_type_col is None or fwy_col is None:
    raise KeyError(f"Missing metadata columns for speed lookup. Found meta_type={meta_type_col}, fwy={fwy_col}")

train_time_mask = flow.index <= TRAIN_END

# -------------------------
# Flow imputation
# -------------------------
flow_ff = flow.ffill()

flow_train_mean_station = flow_ff.loc[train_time_mask].mean(axis=0)
flow_train_mean_global = flow_ff.loc[train_time_mask].stack().mean()

flow_imp = flow_ff.fillna(flow_train_mean_station).fillna(flow_train_mean_global)

# -------------------------
# Speed lookup (TRAIN only)
# -------------------------
train_rows = df2[df2["timestamp"] <= TRAIN_END].copy()
train_rows["hour"] = train_rows["timestamp"].dt.hour

speed_grp_cols = ["lane_type", meta_type_col, "hour", fwy_col, "district"]
speed_lookup = train_rows.groupby(speed_grp_cols)["avg_speed"].mean()

global_speed_train_mean = train_rows["avg_speed"].mean()

# Station-level "mode" descriptors used when applying the lookup
station_info = (df2.groupby("station")
                  .agg(
                      lane_type=("lane_type", lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      meta_type=(meta_type_col, lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      fwy=(fwy_col, lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                      district=("district", lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0]),
                  )
                  .reindex(stations))

speed_ff = speed.ffill()
speed_np = speed_ff.to_numpy(dtype=np.float32)
miss = np.isnan(speed_np)
hours = speed_ff.index.hour.values

# Fill with lookup
for j, st in enumerate(stations):
    if not miss[:, j].any():
        continue
    info = station_info.loc[st]
    lane_type = info["lane_type"]
    meta_type = info["meta_type"]
    fwy = info["fwy"]
    district = info["district"]

    idxs = np.where(miss[:, j])[0]
    fill_vals = []
    for t_idx in idxs:
        h = int(hours[t_idx])
        key = (lane_type, meta_type, h, fwy, district)
        fill_vals.append(speed_lookup.get(key, np.nan))
    speed_np[idxs, j] = np.array(fill_vals, dtype=np.float32)

speed_imp = pd.DataFrame(speed_np, index=speed_ff.index, columns=speed_ff.columns)

# Remaining NaNs → station TRAIN mean → global TRAIN mean
speed_train_mean_station = speed_imp.loc[train_time_mask].mean(axis=0)
speed_imp = speed_imp.fillna(speed_train_mean_station).fillna(global_speed_train_mean)

print("After imputation:")
print("Flow missing fraction:", float(np.isnan(flow_imp.to_numpy()).mean()))
print("Speed missing fraction:", float(np.isnan(speed_imp.to_numpy()).mean()))


After imputation:
Flow missing fraction: 0.0
Speed missing fraction: 0.0


## Build graph-ready tensors

We create:
- X: (T, N, F)
  Features include:
  - flow (1)
  - speed (1)
  - time encodings: hour_sin, hour_cos, dow_sin, dow_cos (4)
  Total F = 6

- Y: (T, N)
  Target is flow at each station.

Later, each training sample is a sliding window:
- Input:  X[t : t+IN_LEN]
- Output: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]
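
The sin/cos time encodings are there so that the end of a cycle sits next to its start. A small sketch, using a hypothetical helper `cyclic_encode`, shows that 23:00 ends up much closer to 00:00 than 12:00 does:

```python
import numpy as np

def cyclic_encode(values, period):
    """Map values on a cycle of length `period` to (sin, cos) pairs."""
    angle = 2.0 * np.pi * np.asarray(values, dtype=float) / period
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

enc = cyclic_encode([0, 12, 23], period=24)
d_wrap = float(np.linalg.norm(enc[2] - enc[0]))  # 23:00 vs 00:00
d_half = float(np.linalg.norm(enc[1] - enc[0]))  # 12:00 vs 00:00
print(round(d_wrap, 3), round(d_half, 3))
```

With a raw integer hour feature, 23 and 0 would be the two most distant values, which is the opposite of what the traffic pattern suggests.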


In [10]:
def make_time_features(timestamps: pd.DatetimeIndex) -> np.ndarray:
    hours = timestamps.hour.values
    dow   = timestamps.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)  # (T,4)

time_feats = make_time_features(flow_imp.index)  # (T,4)
time_feats_b = np.repeat(time_feats[:, None, :], repeats=N, axis=1)  # (T,N,4)

flow_arr  = flow_imp.to_numpy(dtype=np.float32)[:, :, None]   # (T,N,1)
speed_arr = speed_imp.to_numpy(dtype=np.float32)[:, :, None]  # (T,N,1)

X = np.concatenate([flow_arr, speed_arr, time_feats_b], axis=2)  # (T,N,6)
Y = flow_arr.squeeze(-1).astype(np.float32)                      # (T,N)

print("X shape:", X.shape, " (T,N,F)")
print("Y shape:", Y.shape, " (T,N)")


X shape: (2208, 1821, 6)  (T,N,F)
Y shape: (2208, 1821)  (T,N)


## Build adjacency matrix A (static graph baseline)

We build a physical-neighborhood adjacency using metadata:
- Sort stations by (freeway, absolute postmile)
- Connect K neighbors upstream + downstream
- Weight edges using a Gaussian kernel of distance
- Add self-loops

Note:
GraphWaveNet can also learn an adaptive adjacency; this static graph is a strong baseline.
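
To make the construction concrete, here is a simplified sketch for a single freeway with three stations. It assumes at least two distinct postmiles, and `chain_adjacency` is a hypothetical stand-in for the fuller function in the next cell, not a replacement for it.

```python
import numpy as np

def chain_adjacency(postmiles, k=1):
    """Connect each station to its k nearest neighbors along one freeway (toy version)."""
    pm = np.asarray(postmiles, dtype=float)
    order = np.argsort(pm)                     # stations in postmile order
    n = len(pm)
    gaps = np.diff(pm[order])
    sigma = max(float(np.median(gaps[gaps > 0])), 1e-3)
    A = np.zeros((n, n))
    for pos, i in enumerate(order):
        for step in range(1, k + 1):
            if pos + step < n:
                j = order[pos + step]
                w = np.exp(-((pm[j] - pm[i]) ** 2) / sigma**2)  # Gaussian kernel of distance
                A[i, j] = A[j, i] = w          # symmetric edge
    np.fill_diagonal(A, 1.0)                   # self-loops
    return A

toy_A = chain_adjacency([0.0, 1.0, 3.0], k=1)
print(np.round(toy_A, 3))
```

Stations 0 and 1 are one mile apart and get a larger weight than stations 1 and 2 (two miles apart); stations 0 and 2 are not direct neighbors with `k=1`, so their entry stays zero.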


In [11]:
def build_adjacency_from_metadata(meta_df: pd.DataFrame, stations: np.ndarray, k_neighbors: int = 2) -> np.ndarray:
    """
    Build adjacency within each freeway chain using Abs PM order.
    Edge weights = exp(-(dist^2 / sigma^2)), sigma = median neighbor distance.
    """
    # Find needed meta columns
    id_col = "station"
    abs_pm_col = None
    for cand in ["Abs PM", "abs_pm", "AbsPM", "Postmile", "PM"]:
        if cand in meta_df.columns:
            abs_pm_col = cand
            break
    fwy_col2 = None
    for cand in ["Fwy", "FWY", "fwy", "Freeway"]:
        if cand in meta_df.columns:
            fwy_col2 = cand
            break

    if abs_pm_col is None or fwy_col2 is None:
        raise KeyError(f"Metadata missing Abs PM or Fwy columns. Found AbsPM={abs_pm_col}, Fwy={fwy_col2}")

    meta_sub = meta_df[meta_df[id_col].isin(stations)].copy()
    meta_sub["abs_pm"] = pd.to_numeric(meta_sub[abs_pm_col], errors="coerce")
    meta_sub["fwy"] = meta_sub[fwy_col2].astype(str)

    # station index map
    station_to_idx = {s: i for i, s in enumerate(stations)}
    N = len(stations)
    A = np.zeros((N, N), dtype=np.float32)

    # estimate sigma from typical neighbor distances
    all_dists = []
    for fwy, grp in meta_sub.sort_values(["fwy", "abs_pm"]).groupby("fwy"):
        pm = grp["abs_pm"].dropna().values
        if len(pm) < 2:
            continue
        d = np.diff(np.sort(pm))
        d = d[d > 0]
        all_dists.extend(d.tolist())

    sigma = float(np.median(all_dists)) if len(all_dists) else 0.5
    sigma = max(sigma, 1e-3)

    def w(dist):  # gaussian weight
        return float(np.exp(- (dist**2) / (sigma**2)))

    # connect neighbors
    for fwy, grp in meta_sub.sort_values(["fwy", "abs_pm"]).groupby("fwy"):
        grp = grp.dropna(subset=["abs_pm"]).sort_values("abs_pm")
        ids = grp[id_col].astype(int).tolist()
        pms = grp["abs_pm"].astype(float).tolist()

        for i, sid in enumerate(ids):
            ii = station_to_idx[sid]
            for step in range(1, k_neighbors + 1):
                if i - step >= 0:
                    sj = ids[i - step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i-step]))
                if i + step < len(ids):
                    sj = ids[i + step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i+step]))

    # self loops + symmetrize
    np.fill_diagonal(A, 1.0)
    A = np.maximum(A, A.T)
    return A

# metadata table for adjacency should be meta with standardized station column
meta_for_adj = meta.copy()
meta_for_adj["station"] = meta_for_adj["station"].astype(int)

A = build_adjacency_from_metadata(meta_for_adj, stations=stations, k_neighbors=K_NEIGHBORS)
print("A shape:", A.shape)
print("Adjacency density (A>0):", float((A > 0).mean()))


A shape: (1821, 1821)
Adjacency density (A>0): 0.0023693916932872663


## Sliding windows + splits

Each sample uses:
- Input window:  X[t : t+IN_LEN]
- Output window: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]

We split by the **time of the first predicted hour** (t + IN_LEN):
- Train if output_start_time ≤ TRAIN_END
- Val   if TRAIN_END < output_start_time ≤ VAL_END
- Test  if output_start_time > VAL_END

Then we save everything to a single `.npz` artifact so every model reads the exact same dataset.


In [12]:
# Sliding window starts
T_total = X.shape[0]
max_t = T_total - (IN_LEN + OUT_LEN) + 1
starts = np.arange(max_t, dtype=np.int32)

timestamps = pd.DatetimeIndex(flow_imp.index)
out_start_times = timestamps[starts + IN_LEN]

train_starts = starts[out_start_times <= TRAIN_END]
val_starts   = starts[(out_start_times > TRAIN_END) & (out_start_times <= VAL_END)]
test_starts  = starts[out_start_times > VAL_END]

print(f"Window starts: train={len(train_starts)}, val={len(val_starts)}, test={len(test_starts)}")

# Train-only scalers (per node) for flow and speed (channels 0 and 1)
train_time_mask = timestamps <= TRAIN_END

flow_mean = X[train_time_mask, :, 0].mean(axis=0).astype(np.float32)
flow_std  = (X[train_time_mask, :, 0].std(axis=0) + 1e-6).astype(np.float32)

speed_mean = X[train_time_mask, :, 1].mean(axis=0).astype(np.float32)
speed_std  = (X[train_time_mask, :, 1].std(axis=0) + 1e-6).astype(np.float32)

np.savez_compressed(
    DATASET_NPZ,
    X=X.astype(np.float32),
    Y=Y.astype(np.float32),
    A=A.astype(np.float32),
    stations=stations.astype(np.int32),
    timestamps=np.array(timestamps.astype("datetime64[ns]")),
    train_starts=train_starts,
    val_starts=val_starts,
    test_starts=test_starts,
    in_len=np.array([IN_LEN], dtype=np.int32),
    out_len=np.array([OUT_LEN], dtype=np.int32),
    flow_mean=flow_mean, flow_std=flow_std,
    speed_mean=speed_mean, speed_std=speed_std,
    seed=np.array([SEED], dtype=np.int32),
)

print("Saved:", DATASET_NPZ)


Window starts: train=1080, val=360, test=673
Saved: artifacts/pems_graph_dataset.npz


## Fix window split leakage (strict horizon containment)

A window starting at time t uses:
- Input:  X[t : t+IN_LEN]
- Output: Y[t+IN_LEN : t+IN_LEN+OUT_LEN]

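Splitting by output start time alone lets a training window's 72-step target extend up to 71 hours past TRAIN_END, into the validation period (and likewise from validation into test).
To prevent this label leakage across train/val/test boundaries, we require: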
- Train: output_end_time ≤ TRAIN_END
- Val:   output_start_time > TRAIN_END AND output_end_time ≤ VAL_END
- Test:  output_start_time > VAL_END
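
The three rules can be checked with plain integer positions. The sketch below assumes a gap-free hourly index starting at 2024-10-01 00:00, so that 2024-11-15 23:00 sits at position 1103 and 2024-11-30 23:00 at position 1463; the function name `strict_split` is hypothetical.

```python
import numpy as np

def strict_split(n_steps, in_len, out_len, train_end_idx, val_end_idx):
    """Assign window starts to splits so that no target crosses a split boundary."""
    starts = np.arange(n_steps - (in_len + out_len) + 1)
    out_start = starts + in_len                  # first predicted step
    out_end = starts + in_len + out_len - 1      # last predicted step
    train = starts[out_end <= train_end_idx]
    val = starts[(out_start > train_end_idx) & (out_end <= val_end_idx)]
    test = starts[out_start > val_end_idx]
    return train, val, test

tr, va, te = strict_split(n_steps=2208, in_len=24, out_len=72,
                          train_end_idx=1103, val_end_idx=1463)
print(len(tr), len(va), len(te))  # 1009 289 673
```

Windows whose targets straddle a boundary belong to no split and are simply dropped, which is why the train and validation counts are smaller than in the earlier, looser split.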


In [13]:
import numpy as np
import pandas as pd
from pathlib import Path

DATASET_NPZ = Path("artifacts/pems_graph_dataset.npz")
DATASET_NPZ_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")

data = np.load(DATASET_NPZ, allow_pickle=True)

X = data["X"]
Y = data["Y"]
A = data["A"]
stations = data["stations"]
timestamps = pd.to_datetime(data["timestamps"])

IN_LEN = int(data["in_len"][0])
OUT_LEN = int(data["out_len"][0])

flow_mean = data["flow_mean"]
flow_std  = data["flow_std"]
speed_mean = data["speed_mean"]
speed_std  = data["speed_std"]

T_total = X.shape[0]
max_t = T_total - (IN_LEN + OUT_LEN) + 1
starts = np.arange(max_t, dtype=np.int32)

out_start_times = timestamps[starts + IN_LEN]
out_end_times   = timestamps[starts + IN_LEN + OUT_LEN - 1]

TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

# Strict splits
train_starts = starts[out_end_times <= TRAIN_END]
val_starts   = starts[(out_start_times > TRAIN_END) & (out_end_times <= VAL_END)]
test_starts  = starts[out_start_times > VAL_END]

print("STRICT window starts:")
print("train:", len(train_starts))
print("val:  ", len(val_starts))
print("test: ", len(test_starts))

np.savez_compressed(
    DATASET_NPZ_STRICT,
    X=X.astype(np.float32),
    Y=Y.astype(np.float32),
    A=A.astype(np.float32),
    stations=stations.astype(np.int32),
    timestamps=np.array(timestamps.astype("datetime64[ns]")),
    train_starts=train_starts,
    val_starts=val_starts,
    test_starts=test_starts,
    in_len=np.array([IN_LEN], dtype=np.int32),
    out_len=np.array([OUT_LEN], dtype=np.int32),
    flow_mean=flow_mean.astype(np.float32),
    flow_std=flow_std.astype(np.float32),
    speed_mean=speed_mean.astype(np.float32),
    speed_std=speed_std.astype(np.float32),
)

print("Saved strict dataset to:", DATASET_NPZ_STRICT)


STRICT window starts:
train: 1009
val:   289
test:  673
Saved strict dataset to: artifacts/pems_graph_dataset_strict.npz


## Load strict dataset artifact

We will only use the strict `.npz` going forward to ensure no leakage in labels.


In [14]:
import numpy as np
import pandas as pd
from pathlib import Path

DATASET_NPZ_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")
d = np.load(DATASET_NPZ_STRICT, allow_pickle=True)

X = d["X"]          # (T,N,F)
Y = d["Y"]          # (T,N)
A = d["A"]          # (N,N)
stations = d["stations"]
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"]
val_starts   = d["val_starts"]
test_starts  = d["test_starts"]

IN_LEN = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean = d["flow_mean"]  # (N,)
flow_std  = d["flow_std"]   # (N,)

print("X:", X.shape, "Y:", Y.shape, "A:", A.shape)
print("starts:", len(train_starts), len(val_starts), len(test_starts))
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)


X: (2208, 1821, 6) Y: (2208, 1821) A: (1821, 1821)
starts: 1009 289 673
IN_LEN: 24 OUT_LEN: 72


## Baseline evaluation

We evaluate at horizons: 12, 24, 48, 72 hours ahead.

Important detail:
Our output sequence begins at +1 hour ahead of the last input time.
So horizon `h` corresponds to output index `h-1` in the 72-step target.
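
A quick way to convince ourselves of the `h-1` offset is to use a toy target whose value equals its own time index. The helper `horizon_targets` below is an assumption made for this illustration.

```python
import numpy as np

def horizon_targets(Y, t, in_len, horizons):
    """Return {h: target row h steps after the last input step} and the last input index."""
    last_input = t + in_len - 1
    return {h: Y[t + in_len + (h - 1)] for h in horizons}, last_input

Y_toy = np.arange(200, dtype=float)[:, None]   # shape (200, 1); value == time index
targets, last = horizon_targets(Y_toy, t=10, in_len=24, horizons=[12, 24, 48, 72])
for h, row in targets.items():
    print(h, int(row[0]) - last)               # offset from the last input step equals h
```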


In [15]:
import numpy as np
from tqdm.auto import tqdm

EVAL_HORIZONS = [12, 24, 48, 72]

def init_metric_accumulators(horizons):
    return {
        h: {"abs_sum": 0.0, "sq_sum": 0.0, "count": 0}
        for h in horizons
    }

def finalize_metrics(acc):
    out = {}
    for h, v in acc.items():
        mae = v["abs_sum"] / max(v["count"], 1)
        rmse = np.sqrt(v["sq_sum"] / max(v["count"], 1))
        out[h] = {"MAE": mae, "RMSE": rmse}
    return out

def print_metrics(title, metrics_dict):
    print("\n" + title)
    for h in sorted(metrics_dict.keys()):
        print(f"  {h:>3}h  MAE={metrics_dict[h]['MAE']:.3f}  RMSE={metrics_dict[h]['RMSE']:.3f}")


### Baseline 1 — Persistence

For each station:
- Predict that all future horizons equal the **last observed flow** in the input window.
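
Persistence interacts strongly with the daily cycle. The synthetic sketch below (a pure sine wave, not PeMS data) shows why such a forecast can look far worse at 12 hours than at 24 hours:

```python
import numpy as np

hours = np.arange(24 * 14)
flow_toy = 500 + 400 * np.sin(2 * np.pi * hours / 24)   # synthetic daily cycle

def persistence_mae(series, h):
    """MAE of predicting series[t + h] with series[t]."""
    return float(np.abs(series[h:] - series[:-h]).mean())

mae_12 = persistence_mae(flow_toy, 12)
mae_24 = persistence_mae(flow_toy, 24)
print(round(mae_12, 1), round(mae_24, 6))
```

At a 12-hour lag the forecast is exactly out of phase with the cycle, while at 24 hours it lands on the same hour of the next day. Real traffic is noisier, but the same pattern is worth keeping in mind when reading persistence results.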


In [16]:
def eval_persistence(X, Y, starts, in_len, horizons, desc=""):
    acc = init_metric_accumulators(horizons)

    for t in tqdm(starts, desc=desc):
        # last observed flow at the end of input window
        last_flow = X[t + in_len - 1, :, 0]  # (N,)

        for h in horizons:
            idx = h - 1
            true = Y[t + in_len + idx, :]     # (N,)
            pred = last_flow                  # (N,)

            err = pred - true
            acc[h]["abs_sum"] += float(np.abs(err).sum())
            acc[h]["sq_sum"]  += float((err ** 2).sum())
            acc[h]["count"]   += err.size

    return finalize_metrics(acc)

pers_val  = eval_persistence(X, Y, val_starts, IN_LEN, EVAL_HORIZONS, desc="Persistence (val)")
pers_test = eval_persistence(X, Y, test_starts, IN_LEN, EVAL_HORIZONS, desc="Persistence (test)")

print_metrics("Persistence — Validation", pers_val)
print_metrics("Persistence — Test", pers_test)




Persistence — Validation
   12h  MAE=910.133  RMSE=1437.403
   24h  MAE=151.755  RMSE=354.476
   48h  MAE=203.020  RMSE=451.021
   72h  MAE=220.150  RMSE=478.588

Persistence — Test
   12h  MAE=917.014  RMSE=1455.414
   24h  MAE=147.896  RMSE=340.247
   48h  MAE=200.592  RMSE=443.364
   72h  MAE=196.856  RMSE=431.611


### Baseline 2 — Historical Average (HA-168)

We compute a per-node seasonal mean using train data only:
- slot = (day_of_week * 24 + hour) ∈ [0..167]
- mean_flow[slot, node] = average flow in train for that slot

Forecast:
- For each horizon step, use the slot mean of that future timestamp.
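
The slot index can be sketched with the standard library alone. `ha_slot` is a hypothetical helper; it relies on `datetime.weekday()` using the same Monday=0 convention as pandas `dayofweek`. The loop counts how many training hours fall into each slot, assuming hourly timestamps from 2024-10-01 00:00 through 2024-11-15 23:00.

```python
from collections import Counter
from datetime import datetime, timedelta

def ha_slot(ts: datetime) -> int:
    """Hour-of-week slot in [0, 167]; Monday 00:00 is slot 0."""
    return ts.weekday() * 24 + ts.hour

start = datetime(2024, 10, 1, 0)        # a Tuesday, so its slot is 24
train_end = datetime(2024, 11, 15, 23)

slot_counts = Counter()
ts = start
while ts <= train_end:
    slot_counts[ha_slot(ts)] += 1
    ts += timedelta(hours=1)

print(len(slot_counts), min(slot_counts.values()), max(slot_counts.values()))  # 168 6 7
```

Every slot is covered, but each slot mean averages only six or seven training hours per station, so a single unusual week can shift the HA-168 forecast noticeably.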


In [17]:
def build_ha168_means(Y, timestamps, train_end):
    train_mask = timestamps <= train_end
    Y_train = Y[train_mask]  # (T_train, N)
    ts_train = timestamps[train_mask]

    slot = ts_train.dayofweek.to_numpy() * 24 + ts_train.hour.to_numpy()  # (T_train,)
    G = 168
    N = Y.shape[1]

    means = np.zeros((G, N), dtype=np.float32)
    counts = np.zeros((G,), dtype=np.int64)

    for g in range(G):
        m = (slot == g)
        if m.any():
            means[g] = Y_train[m].mean(axis=0)
            counts[g] = int(m.sum())
        else:
            # fallback (should be rare)
            means[g] = Y_train.mean(axis=0)
            counts[g] = 0
    return means, counts

def eval_ha168(Y, timestamps, starts, in_len, horizons, means_ha168, desc=""):
    acc = init_metric_accumulators(horizons)

    for t in tqdm(starts, desc=desc):
        for h in horizons:
            idx = h - 1
            future_time = timestamps[t + in_len + idx]
            g = int(future_time.dayofweek * 24 + future_time.hour)

            pred = means_ha168[g, :]                 # (N,)
            true = Y[t + in_len + idx, :]            # (N,)

            err = pred - true
            acc[h]["abs_sum"] += float(np.abs(err).sum())
            acc[h]["sq_sum"]  += float((err ** 2).sum())
            acc[h]["count"]   += err.size

    return finalize_metrics(acc)

TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
ha_means, ha_counts = build_ha168_means(Y, timestamps, TRAIN_END)

ha_val  = eval_ha168(Y, timestamps, val_starts, IN_LEN, EVAL_HORIZONS, ha_means, desc="HA-168 (val)")
ha_test = eval_ha168(Y, timestamps, test_starts, IN_LEN, EVAL_HORIZONS, ha_means, desc="HA-168 (test)")

print_metrics("HA-168 — Validation", ha_val)
print_metrics("HA-168 — Test", ha_test)




HA-168 — Validation
   12h  MAE=116.424  RMSE=258.349
   24h  MAE=123.997  RMSE=277.816
   48h  MAE=134.454  RMSE=302.695
   72h  MAE=137.523  RMSE=306.085

HA-168 — Test
   12h  MAE=119.650  RMSE=283.458
   24h  MAE=120.382  RMSE=284.451
   48h  MAE=122.195  RMSE=286.807
   72h  MAE=126.187  RMSE=293.463


## PyTorch Dataset for sliding windows

Each item returns:
- x: (C, N, IN_LEN)   where C = number of features (6)
- y: (OUT_LEN, N)     scaled flow targets

Scaling:
- We scale flow and speed using TRAIN-only mean/std (per node)
- Time features remain unchanged (already bounded by sin/cos)
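
Because the targets are scaled, model outputs presumably need to be mapped back to vehicle counts before they are compared with the baselines, which are scored in raw units. A minimal round-trip sketch, with hypothetical helper names and random toy data:

```python
import numpy as np

def zscore(y, mean, std):
    """Scale a (T, N) array with per-node statistics of shape (N,)."""
    return (y - mean[None, :]) / std[None, :]

def inverse_zscore(y_scaled, mean, std):
    """Undo zscore: back to the original units."""
    return y_scaled * std[None, :] + mean[None, :]

rng = np.random.default_rng(0)
y_toy = rng.uniform(0, 2000, size=(72, 5)).astype(np.float32)

# Statistics come from the first 48 rows only (the toy "train" period).
train_part = y_toy[:48]
mean = train_part.mean(axis=0)
std = train_part.std(axis=0) + 1e-6

roundtrip = inverse_zscore(zscore(y_toy, mean, std), mean, std)
print(float(np.abs(roundtrip - y_toy).max()))  # small float32 rounding error
```

Computing MAE or RMSE on the scaled values instead would weight each station by its own variance and would not be comparable with the Persistence and HA-168 numbers above.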


In [18]:
import torch
from torch.utils.data import Dataset, DataLoader

class PemsWindowDataset(Dataset):
    def __init__(self, X, Y, starts, in_len, out_len, flow_mean, flow_std, speed_mean, speed_std):
        self.X = X
        self.Y = Y
        self.starts = starts
        self.in_len = in_len
        self.out_len = out_len

        self.flow_mean = flow_mean.astype(np.float32)
        self.flow_std  = flow_std.astype(np.float32)
        self.speed_mean = speed_mean.astype(np.float32)
        self.speed_std  = speed_std.astype(np.float32)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])

        x = self.X[t : t + self.in_len].copy().astype(np.float32)  # (IN_LEN, N, F)
        y = self.Y[t + self.in_len : t + self.in_len + self.out_len].copy().astype(np.float32)  # (OUT_LEN, N)

        # scale input channels: flow=0, speed=1
        x[..., 0] = (x[..., 0] - self.flow_mean[None, :]) / self.flow_std[None, :]
        x[..., 1] = (x[..., 1] - self.speed_mean[None, :]) / self.speed_std[None, :]

        # scale targets (flow)
        y = (y - self.flow_mean[None, :]) / self.flow_std[None, :]

        # rearrange x to (C, N, IN_LEN) for conv models
        x = np.transpose(x, (2, 1, 0))  # (F, N, IN_LEN)

        return torch.from_numpy(x), torch.from_numpy(y)

# Load speed scaler too
speed_mean = d["speed_mean"]
speed_std  = d["speed_std"]

train_ds = PemsWindowDataset(X, Y, train_starts, IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
val_ds   = PemsWindowDataset(X, Y, val_starts,   IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
test_ds  = PemsWindowDataset(X, Y, test_starts,  IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)

BATCH_SIZE = 16

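# Pinned host memory only helps when batches are copied to a GPU; on CPU-only runs it just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)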

xb, yb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape)
# Expect: x=(B, 6, N, 24) and y=(B, 72, N)


Batch x: torch.Size([16, 6, 1821, 24]) Batch y: torch.Size([16, 72, 1821])


In [19]:
import numpy as np

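def mean_abs_change(Y, h):
    """
    Sanity check: mean absolute change in flow between time t and t+h, over all stations.
    This is the error a persistence forecast makes at lag h on the full series,
    so it should roughly track the Persistence MAE reported above.
    """
    return float(np.abs(Y[h:] - Y[:-h]).mean())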

for h in [1, 6, 12, 24, 48]:
    print(f"Mean |Δ| over {h}h shift: {mean_abs_change(Y, h):.3f}")


Mean |Δ| over 1h shift: 156.388
Mean |Δ| over 6h shift: 666.328
Mean |Δ| over 12h shift: 921.578
Mean |Δ| over 24h shift: 143.347
Mean |Δ| over 48h shift: 202.141


## Graph supports (normalized adjacency)

GraphWaveNet uses graph propagation through adjacency matrices ("supports").
We will build:
- A_rw  = row-normalized adjacency (random-walk normalization)
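- A_rwT = row-normalized transpose (the backward direction; useful if the graph is directed)

Note: the static adjacency built earlier was symmetrized (`A = np.maximum(A, A.T)`), so here A_rwT is identical to A_rw and the second support adds no new information. It only matters once a directed graph is used.

We store them as sparse tensors for speed (our adjacency is very sparse).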
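
A small dense sketch of the two supports, with toy matrices chosen for illustration. It checks that rows sum to one after normalization, and that forward and backward supports differ only when the graph is directed.

```python
import numpy as np

def row_normalize(A, eps=1e-6):
    """Random-walk normalization: divide each row by its sum."""
    return A / (A.sum(axis=1, keepdims=True) + eps)

A_sym = np.array([[1.0, 0.5, 0.0],
                  [0.5, 1.0, 0.2],
                  [0.0, 0.2, 1.0]])
fwd_sym, bwd_sym = row_normalize(A_sym), row_normalize(A_sym.T)

A_dir = np.array([[1.0, 0.5, 0.0],     # edges point "downstream" only
                  [0.0, 1.0, 0.2],
                  [0.0, 0.0, 1.0]])
fwd_dir, bwd_dir = row_normalize(A_dir), row_normalize(A_dir.T)

print(np.array_equal(fwd_sym, bwd_sym))   # symmetric graph: supports coincide
print(np.allclose(fwd_dir, bwd_dir))      # directed graph: supports differ
```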


In [20]:
import numpy as np
import torch

def row_normalize_dense(A: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    d = A.sum(axis=1, keepdims=True)
    return A / (d + eps)

def dense_to_torch_sparse(A: np.ndarray, device: str):
    A = A.astype(np.float32)
    idx = np.nonzero(A)
    values = A[idx]
    indices = np.vstack(idx)  # (2, nnz)

    indices = torch.tensor(indices, dtype=torch.long, device=device)
    values  = torch.tensor(values, dtype=torch.float32, device=device)
    shape = A.shape

    sp = torch.sparse_coo_tensor(indices, values, size=shape, device=device).coalesce()
    return sp

A_rw = row_normalize_dense(A)
A_rwT = row_normalize_dense(A.T)

supports = [
    dense_to_torch_sparse(A_rw, DEVICE),
    dense_to_torch_sparse(A_rwT, DEVICE),
]

print("Supports:", [s.shape for s in supports], "nnz:", [int(s._nnz()) for s in supports])


Supports: [torch.Size([1821, 1821]), torch.Size([1821, 1821])] nnz: [7856, 7856]


## Diffusion graph convolution (sparse)

We need a fast way to compute the graph propagation:
A @ X

X is batched with shape (B, C, N, T).
We move the node axis to the front and flatten the rest into (N, B*C*T), so a single sparse-dense matrix multiply handles every batch, channel, and time step at once.
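
The reshape trick can be checked on small arrays. The sketch below uses NumPy with a dense stand-in for one support (an assumption made to keep the example dependency-light; the notebook itself uses `torch.sparse.mm`) and compares the flattened product against an explicit contraction over the node axis.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, N, T = 2, 3, 4, 5            # small illustrative sizes
x = rng.standard_normal((B, C, N, T))
A = rng.random((N, N))             # dense stand-in for one sparse support

# Move the node axis to the front and flatten the rest: (N, B*C*T).
x_flat = x.transpose(2, 0, 1, 3).reshape(N, -1)
# One matrix product propagates every (batch, channel, time) column at once.
y_flat = A @ x_flat
# Undo the reshape and permutation: back to (B, C, N, T).
y = y_flat.reshape(N, B, C, T).transpose(1, 2, 0, 3)

# Reference: explicit contraction over the source-node axis.
y_ref = np.einsum("nm,bcmt->bcnt", A, x)
print(y.shape, np.allclose(y, y_ref))
```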


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NConv(nn.Module):
    """Sparse graph multiplication: (N,N) @ (B,C,N,T) -> (B,C,N,T)"""
    def forward(self, x, A_sp):
        # x: (B,C,N,T)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)      # (N, B*C*T)
        x_r = torch.sparse.mm(A_sp, x_r)                # (N, B*C*T)
        x_out = x_r.reshape(N, B, C, T).permute(1, 2, 0, 3)  # (B,C,N,T)
        return x_out

class DiffusionGraphConv(nn.Module):
    """
    Diffusion graph conv:
    concat [X, A1X, A2X, ...] then 1x1 conv.

    order=1 means we use only one hop per support (fast and stable).
    """
    def __init__(self, c_in, c_out, supports, order=1, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.order = order
        self.dropout = dropout

        # total input channels after concatenation
        c_total = c_in * (1 + len(supports) * order)
        self.mlp = nn.Conv2d(c_total, c_out, kernel_size=(1, 1))

    def forward(self, x):
        out = [x]
        for A in self.supports:
            x1 = self.nconv(x, A)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, A)
                out.append(x1)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h


## GraphWaveNet baseline model

We use:
- Causal dilated temporal convolutions (TCN-style)
- Gated activations (tanh * sigmoid)
- Residual + skip connections
- Diffusion graph convolution inside each layer

Output head:
- Uses the final time step embedding to predict OUT_LEN horizons directly
- Output shape: (B, OUT_LEN, N)
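
Because the head reads only the final time step, it is worth checking how many past steps actually reach that step. The sketch below applies the same accumulation rule that the time-aware model later in this notebook uses for its `receptive_field` attribute; the standalone helper function is introduced here for illustration.

```python
def receptive_field(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    """Time steps seen by the last output step of stacked causal dilated convs."""
    rf = 1
    for _ in range(blocks):
        for i in range(layers_per_block):
            rf += (kernel_size - 1) * (2 ** i)   # each layer adds (k-1)*dilation
    return rf

rf_3 = receptive_field(kernel_size=2, blocks=2, layers_per_block=3)
rf_4 = receptive_field(kernel_size=2, blocks=2, layers_per_block=4)
print(rf_3, rf_4)  # 15 31
```

With `kernel_size=2`, `blocks=2`, `layers_per_block=3` the receptive field is 15 steps, shorter than a 24-step input window, so the last-step embedding would depend only on the most recent 15 hours. Four layers per block gives 31 steps.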


In [23]:
class CausalConv2d(nn.Module):
    """Causal conv along time axis only (last dimension)."""
    def __init__(self, c_in, c_out, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(
            c_in, c_out,
            kernel_size=(1, kernel_size),
            dilation=(1, dilation)
        )

    def forward(self, x):
        # F.pad lists pads from the last dimension backwards: (time_left, time_right, node_left, node_right); pad only the left of the time axis
        x = F.pad(x, (self.pad, 0, 0, 0))
        return self.conv(x)

class GraphWaveNet(nn.Module):
    def __init__(
        self,
        num_nodes: int,
        in_dim: int,
        out_len: int,
        supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=3,
        gcn_order=1,
        dropout=0.3,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.out_len = out_len
        self.dropout = dropout

        self.start_conv = nn.Conv2d(in_dim, residual_channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs   = nn.ModuleList()
        self.res_convs    = nn.ModuleList()  # NOTE: populated below but never used in forward()
        self.skip_convs   = nn.ModuleList()
        self.bn           = nn.ModuleList()
        self.gconvs       = nn.ModuleList()

        # Build temporal + graph blocks
        for _ in range(blocks):
            for i in range(layers_per_block):
                dilation = 2 ** i

                self.filter_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))
                self.gate_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))

                self.res_convs.append(nn.Conv2d(dilation_channels, residual_channels, kernel_size=(1, 1)))
                self.skip_convs.append(nn.Conv2d(dilation_channels, skip_channels, kernel_size=(1, 1)))

                self.gconvs.append(
                    DiffusionGraphConv(dilation_channels, residual_channels, supports, order=gcn_order, dropout=dropout)
                )

                self.bn.append(nn.BatchNorm2d(residual_channels))

        self.end_conv_1 = nn.Conv2d(skip_channels, end_channels, kernel_size=(1, 1))
        self.end_conv_2 = nn.Conv2d(end_channels, out_len, kernel_size=(1, 1))  # outputs OUT_LEN channels

    def forward(self, x):
        """
        x: (B, F, N, T_in)
        return: (B, OUT_LEN, N)
        """
        x = self.start_conv(x)  # (B, residual, N, T)
        skip = None

        for i in range(len(self.filter_convs)):
            residual = x

            # gated TCN
            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = filt * gate
            x = F.dropout(x, p=self.dropout, training=self.training)

            # skip
            s = self.skip_convs[i](x)
            skip = s if skip is None else (skip + s)

            # graph conv -> residual channels
            x = self.gconvs[i](x)

            # residual connection (time length is preserved by causal padding)
            x = x + residual
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))

        # Use last time step to predict future horizons
        x_last = x[..., -1:].contiguous()          # (B, end_channels, N, 1)
        out = self.end_conv_2(x_last).squeeze(-1)  # (B, OUT_LEN, N)
        return out


## Training & evaluation loop

We train using MSE on scaled targets (stable optimization),
then compute MAE/RMSE on the original scale at horizons 12/24/48/72.

Early stopping monitors average validation MAE across horizons.
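
The patience rule can be replayed offline. The sketch below is a standalone restatement of the stopping logic (the helper name `early_stop_epoch` is ours), fed with the `val_avg_MAE` values copied from the training log printed later in this section.

```python
def early_stop_epoch(scores, patience):
    """Return (stop_epoch, best_epoch, best_score); epochs are 1-based."""
    best_score, best_epoch, bad = float("inf"), None, 0
    for epoch, score in enumerate(scores, start=1):
        if score < best_score:
            best_score, best_epoch, bad = score, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return epoch, best_epoch, best_score
    return len(scores), best_epoch, best_score

# val_avg_MAE per epoch, copied from the training log of this section.
logged = [203.973, 185.356, 180.951, 175.876, 170.299, 164.475, 167.873,
          160.063, 163.708, 155.655, 154.549, 164.792, 164.343, 166.283,
          163.332, 155.481, 155.037]

stop_epoch, best_epoch, best_score = early_stop_epoch(logged, patience=6)
print(stop_epoch, best_epoch, best_score)  # 17 11 154.549
```

The best checkpoint is the one from epoch 11, which is why the final validation metrics match that epoch's line in the log.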


In [25]:
import torch
import numpy as np
from tqdm.auto import tqdm

# tensors for unscaling (put on device once)
flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

EVAL_HORIZONS = [12, 24, 48, 72]

@torch.no_grad()
def eval_horizons(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)   # (B,F,N,T)
        yb = yb.to(DEVICE, non_blocking=True)   # (B,OUT,N) scaled

        pred = model(xb)                        # (B,OUT,N) scaled

        # unscale to original flow units
        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        for h in EVAL_HORIZONS:
            idx = h - 1
            err = pred_u[:, idx, :] - true_u[:, idx, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

def train_gwn(
    epochs=30,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=6,
):
    model = GraphWaveNet(
        num_nodes=X.shape[1],
        in_dim=X.shape[2],     # 6 features
        out_len=OUT_LEN,
        supports=supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=3,
        gcn_order=1,
        dropout=0.3,
    ).to(DEVICE)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.MSELoss()

    use_amp = False
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op while use_amp is False

    best_score = float("inf")
    best_state = None
    bad = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = model(xb)
                loss = loss_fn(pred, yb)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            scaler.step(opt)
            scaler.update()

            running += float(loss.item())

        # validation metrics (original scale)
        val_metrics = eval_horizons(model, val_loader)
        score = avg_mae(val_metrics)

        print(f"Epoch {epoch}: train_loss={running/len(train_loader):.6f}  val_avg_MAE={score:.3f}")
        print_metrics("Val metrics", val_metrics)

        if score < best_score:
            best_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}. Best val_avg_MAE={best_score:.3f}")
                break

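    # restore the best checkpoint (best_state stays None if no epoch improved, e.g. NaN scores)
    if best_state is not None:
        model.load_state_dict(best_state)
    return model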

gwn_model = train_gwn(epochs=30, lr=1e-3, weight_decay=1e-4, clip=5.0, patience=6)

# Final evaluation
val_m = eval_horizons(gwn_model, val_loader)
test_m = eval_horizons(gwn_model, test_loader)

print_metrics("GraphWaveNet — Validation", val_m)
print_metrics("GraphWaveNet — Test", test_m)


Epoch 1: train_loss=0.529129  val_avg_MAE=203.973

Val metrics
   12h  MAE=225.832  RMSE=393.051
   24h  MAE=184.901  RMSE=343.538
   48h  MAE=208.168  RMSE=379.516
   72h  MAE=196.993  RMSE=369.422


Epoch 2: train_loss=0.325922  val_avg_MAE=185.356

Val metrics
   12h  MAE=193.387  RMSE=348.977
   24h  MAE=166.734  RMSE=319.229
   48h  MAE=185.582  RMSE=349.533
   72h  MAE=195.720  RMSE=369.310


Epoch 3: train_loss=0.287739  val_avg_MAE=180.951

Val metrics
   12h  MAE=179.654  RMSE=333.382
   24h  MAE=157.967  RMSE=309.919
   48h  MAE=187.718  RMSE=347.820
   72h  MAE=198.465  RMSE=367.385


Epoch 4: train_loss=0.268916  val_avg_MAE=175.876

Val metrics
   12h  MAE=186.728  RMSE=342.408
   24h  MAE=154.416  RMSE=309.196
   48h  MAE=171.948  RMSE=328.877
   72h  MAE=190.413  RMSE=357.469


Epoch 5: train_loss=0.257975  val_avg_MAE=170.299

Val metrics
   12h  MAE=184.271  RMSE=338.048
   24h  MAE=148.209  RMSE=299.617
   48h  MAE=168.291  RMSE=323.040
   72h  MAE=180.426  RMSE=343.043


Epoch 6: train_loss=0.250068  val_avg_MAE=164.475

Val metrics
   12h  MAE=162.037  RMSE=303.679
   24h  MAE=150.514  RMSE=301.402
   48h  MAE=164.907  RMSE=316.562
   72h  MAE=180.444  RMSE=343.314


Epoch 7: train_loss=0.244371  val_avg_MAE=167.873

Val metrics
   12h  MAE=164.548  RMSE=308.877
   24h  MAE=156.601  RMSE=304.825
   48h  MAE=176.487  RMSE=327.605
   72h  MAE=173.856  RMSE=331.363


Epoch 8: train_loss=0.239806  val_avg_MAE=160.063

Val metrics
   12h  MAE=160.283  RMSE=302.982
   24h  MAE=148.200  RMSE=298.973
   48h  MAE=161.825  RMSE=314.139
   72h  MAE=169.943  RMSE=325.858


Epoch 9: train_loss=0.234704  val_avg_MAE=163.708

Val metrics
   12h  MAE=158.480  RMSE=303.029
   24h  MAE=158.671  RMSE=308.487
   48h  MAE=163.579  RMSE=313.825
   72h  MAE=174.101  RMSE=332.094


Epoch 10: train_loss=0.234201  val_avg_MAE=155.655

Val metrics
   12h  MAE=153.175  RMSE=294.953
   24h  MAE=141.512  RMSE=293.586
   48h  MAE=159.561  RMSE=315.417
   72h  MAE=168.371  RMSE=323.676


Epoch 11: train_loss=0.231096  val_avg_MAE=154.549

Val metrics
   12h  MAE=152.754  RMSE=293.402
   24h  MAE=139.974  RMSE=287.064
   48h  MAE=157.335  RMSE=307.804
   72h  MAE=168.134  RMSE=323.799


Epoch 12: train_loss=0.229034  val_avg_MAE=164.792

Val metrics
   12h  MAE=151.243  RMSE=291.496
   24h  MAE=147.712  RMSE=297.619
   48h  MAE=167.106  RMSE=321.370
   72h  MAE=193.109  RMSE=358.011


Epoch 13: train_loss=0.223892  val_avg_MAE=164.343

Val metrics
   12h  MAE=148.523  RMSE=289.870
   24h  MAE=146.654  RMSE=296.672
   48h  MAE=173.357  RMSE=328.243
   72h  MAE=188.839  RMSE=347.545


Epoch 14: train_loss=0.222705  val_avg_MAE=166.283

Val metrics
   12h  MAE=153.376  RMSE=294.916
   24h  MAE=155.567  RMSE=312.574
   48h  MAE=176.343  RMSE=332.107
   72h  MAE=179.847  RMSE=333.986


Epoch 15: train_loss=0.221832  val_avg_MAE=163.332

Val metrics
   12h  MAE=169.177  RMSE=314.841
   24h  MAE=143.233  RMSE=290.326
   48h  MAE=168.777  RMSE=325.777
   72h  MAE=172.140  RMSE=329.309


Epoch 16: train_loss=0.221454  val_avg_MAE=155.481

Val metrics
   12h  MAE=144.691  RMSE=279.696
   24h  MAE=139.187  RMSE=284.711
   48h  MAE=161.616  RMSE=311.975
   72h  MAE=176.428  RMSE=331.767


Epoch 17: train_loss=0.217759  val_avg_MAE=155.037

Val metrics
   12h  MAE=151.572  RMSE=293.354
   24h  MAE=139.192  RMSE=286.145
   48h  MAE=161.421  RMSE=316.536
   72h  MAE=167.963  RMSE=321.738
Early stopping at epoch 17. Best val_avg_MAE=154.549



GraphWaveNet — Validation
   12h  MAE=152.754  RMSE=293.402
   24h  MAE=139.974  RMSE=287.064
   48h  MAE=157.335  RMSE=307.804
   72h  MAE=168.134  RMSE=323.799

GraphWaveNet — Test
   12h  MAE=156.900  RMSE=298.198
   24h  MAE=134.966  RMSE=269.862
   48h  MAE=153.469  RMSE=300.160
   72h  MAE=161.221  RMSE=310.938


# GraphWaveNet Baseline (Journal-Grade): Revised Version

This notebook section trains a strong GraphWaveNet baseline on PeMS using:
- Strict train/val/test window splits (no horizon leakage)
- Direction-aware adjacency: edges only within same freeway AND same direction
- Known future calendar covariates (hour-of-day/day-of-week), which are available at forecast time
- Multi-horizon evaluation at 12/24/48/72 hours

Outputs:
- Validation and test MAE/RMSE at each horizon


In [3]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 42
set_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


In [4]:
DATASET_BASE = Path("artifacts/pems_graph_dataset.npz")
DATASET_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")

TRAIN_END = pd.Timestamp("2024-11-15 23:59:59")
VAL_END   = pd.Timestamp("2024-11-30 23:59:59")

def make_strict_dataset(base_npz: Path, strict_npz: Path):
    d = np.load(base_npz, allow_pickle=True)

    X = d["X"]; Y = d["Y"]; A = d["A"]
    stations = d["stations"]
    timestamps = pd.to_datetime(d["timestamps"])
    IN_LEN = int(d["in_len"][0])
    OUT_LEN = int(d["out_len"][0])

    flow_mean = d["flow_mean"]; flow_std = d["flow_std"]
    speed_mean = d["speed_mean"]; speed_std = d["speed_std"]

    T_total = X.shape[0]
    max_t = T_total - (IN_LEN + OUT_LEN) + 1
    starts = np.arange(max_t, dtype=np.int32)

    out_start_times = timestamps[starts + IN_LEN]
    out_end_times   = timestamps[starts + IN_LEN + OUT_LEN - 1]

    train_starts = starts[out_end_times <= TRAIN_END]
    val_starts   = starts[(out_start_times > TRAIN_END) & (out_end_times <= VAL_END)]
    test_starts  = starts[out_start_times > VAL_END]

    np.savez_compressed(
        strict_npz,
        X=X.astype(np.float32),
        Y=Y.astype(np.float32),
        A=A.astype(np.float32),
        stations=stations.astype(np.int32),
        timestamps=np.array(timestamps.astype("datetime64[ns]")),
        train_starts=train_starts,
        val_starts=val_starts,
        test_starts=test_starts,
        in_len=np.array([IN_LEN], dtype=np.int32),
        out_len=np.array([OUT_LEN], dtype=np.int32),
        flow_mean=flow_mean.astype(np.float32),
        flow_std=flow_std.astype(np.float32),
        speed_mean=speed_mean.astype(np.float32),
        speed_std=speed_std.astype(np.float32),
        seed=np.array([SEED], dtype=np.int32),
    )
    print("Saved strict dataset:", strict_npz)

if (not DATASET_STRICT.exists()) and DATASET_BASE.exists():
    make_strict_dataset(DATASET_BASE, DATASET_STRICT)

assert DATASET_STRICT.exists(), "Strict dataset not found. Make sure artifacts are created."

d = np.load(DATASET_STRICT, allow_pickle=True)
X = d["X"]                      # (T,N,F)
Y = d["Y"]                      # (T,N)
stations = d["stations"]        # (N,)
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"]
val_starts   = d["val_starts"]
test_starts  = d["test_starts"]

IN_LEN = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean = d["flow_mean"]
flow_std  = d["flow_std"]
speed_mean = d["speed_mean"]
speed_std  = d["speed_std"]

print("X:", X.shape, "Y:", Y.shape)
print("N stations:", len(stations), "T:", len(timestamps))
print("starts:", len(train_starts), len(val_starts), len(test_starts))
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)


X: (2208, 1821, 6) Y: (2208, 1821)
N stations: 1821 T: 2208
starts: 1009 289 673
IN_LEN: 24 OUT_LEN: 72
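
The split sizes printed above can be reproduced with plain integer indices. The sketch below assumes hourly timestamps starting at 2024-10-01 00:00. This is an assumption: it is consistent with T = 2208 (92 days of hourly data) but is not stated in the notebook. Under it, the last training and validation hours fall at indices 46*24 - 1 and 61*24 - 1. The names `TRAIN_END_IDX` and `VAL_END_IDX` are introduced here for illustration.

```python
import numpy as np

T_TOTAL, IN_LEN, OUT_LEN = 2208, 24, 72
# Assumption: hourly steps starting 2024-10-01 00:00, so the last training hour
# (2024-11-15 23:00) is index 46*24 - 1 and the last validation hour
# (2024-11-30 23:00) is index 61*24 - 1.
TRAIN_END_IDX = 46 * 24 - 1
VAL_END_IDX = 61 * 24 - 1

starts = np.arange(T_TOTAL - (IN_LEN + OUT_LEN) + 1)
out_start = starts + IN_LEN                 # first target step of each window
out_end = starts + IN_LEN + OUT_LEN - 1     # last target step of each window

train = starts[out_end <= TRAIN_END_IDX]
val = starts[(out_start > TRAIN_END_IDX) & (out_end <= VAL_END_IDX)]
test = starts[out_start > VAL_END_IDX]

# Windows whose targets straddle a boundary belong to no split.
dropped = len(starts) - (len(train) + len(val) + len(test))
print(len(train), len(val), len(test), dropped)  # 1009 289 673 142
```

Under that assumption each boundary discards OUT_LEN - 1 = 71 windows whose targets straddle it, which accounts for the 142 starts missing from the 2113 possible ones.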


## Direction-aware adjacency (recommended baseline graph)

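We build edges only among sensors that share:
- the same freeway (Fwy)
- the same direction of travel (Direction of Travel)

Within each (freeway, direction) group, sensors are sorted by absolute postmile (Abs PM), and each sensor is linked to its k nearest neighbors on either side with a Gaussian weight on the postmile distance.

This avoids mixing opposite-direction stations, which can hurt graph model performance.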
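
A minimal sketch of the neighbor rule for a single (freeway, direction) group is shown below. The postmiles are hypothetical and the helper `chain_adjacency` is introduced only for illustration; it mirrors the weighting idea (Gaussian kernel with sigma set to the median neighbor spacing) rather than reproducing the full notebook function.

```python
import numpy as np

def chain_adjacency(postmiles, k_neighbors=2):
    """Gaussian-weighted k-nearest-neighbour links along one freeway/direction."""
    pm = np.sort(np.asarray(postmiles, dtype=float))
    n = len(pm)
    gaps = np.diff(pm)
    gaps = gaps[gaps > 0]
    sigma = max(float(np.median(gaps)) if len(gaps) else 0.5, 1e-3)

    A = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        for step in range(1, k_neighbors + 1):
            for j in (i - step, i + step):
                if 0 <= j < n:
                    A[i, j] = np.exp(-((pm[i] - pm[j]) ** 2) / sigma ** 2)
    np.fill_diagonal(A, 1.0)       # self-loops
    return np.maximum(A, A.T)      # symmetrize

# Hypothetical postmiles for five sensors on one freeway and direction.
A_chain = chain_adjacency([0.0, 0.5, 1.0, 3.0, 3.4], k_neighbors=2)
print(np.round(A_chain, 3))
```

Sensors more than `k_neighbors` positions apart get no edge, and large postmile gaps yield weights close to zero even between listed neighbors.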


In [5]:
TRAFFIC_CSV = Path("cleaned_traffic_data.csv")
META_XLSX   = Path("pems_output.xlsx")

assert TRAFFIC_CSV.exists(), "cleaned_traffic_data.csv not found in working directory."
assert META_XLSX.exists(), "pems_output.xlsx not found in working directory."

# 1) station -> direction mapping (mode direction per station)
tmp = pd.read_csv(TRAFFIC_CSV, usecols=["Station", "Direction of Travel"])
tmp = tmp.rename(columns={"Station": "station", "Direction of Travel": "direction"})
tmp["station"] = pd.to_numeric(tmp["station"], errors="coerce").astype("Int64")
tmp = tmp.dropna(subset=["station"])
tmp["station"] = tmp["station"].astype(int)

station_dir = tmp.groupby("station")["direction"].agg(lambda x: x.mode().iloc[0] if len(x.mode()) else x.iloc[0])
station_dir = station_dir.reindex(stations)

print("Direction mapping coverage:", float(station_dir.notna().mean()))
print("Direction counts:\n", station_dir.value_counts(dropna=False))

# 2) metadata
meta = pd.read_excel(META_XLSX)
# The metadata is expected to have a "station" column; fall back to renaming "ID" if it does not
if "station" not in meta.columns and "ID" in meta.columns:
    meta = meta.rename(columns={"ID": "station"})
meta["station"] = pd.to_numeric(meta["station"], errors="coerce").astype("Int64")
meta = meta.dropna(subset=["station"])
meta["station"] = meta["station"].astype(int)

assert "Fwy" in meta.columns and "Abs PM" in meta.columns, "Metadata must contain 'Fwy' and 'Abs PM'."

def build_adjacency_fwy_dir(meta_df, stations, station_dir, k_neighbors=4):
    meta_sub = meta_df[meta_df["station"].isin(stations)].copy()
    meta_sub["fwy"] = meta_sub["Fwy"].astype(str)
    meta_sub["abs_pm"] = pd.to_numeric(meta_sub["Abs PM"], errors="coerce")
    meta_sub["direction"] = meta_sub["station"].map(station_dir)

    # If direction missing, we exclude from neighbor edges (but keep self-loop later)
    meta_sub = meta_sub.dropna(subset=["abs_pm", "direction"]).copy()
    meta_sub["direction"] = meta_sub["direction"].astype(str)

    station_to_idx = {s: i for i, s in enumerate(stations)}
    N = len(stations)
    A = np.zeros((N, N), dtype=np.float32)

    # sigma from typical neighbor distances
    all_dists = []
    for (fwy, direc), grp in meta_sub.sort_values(["fwy", "direction", "abs_pm"]).groupby(["fwy", "direction"]):
        pm = grp["abs_pm"].values
        if len(pm) < 2:
            continue
        d = np.diff(np.sort(pm))
        d = d[d > 0]
        all_dists.extend(d.tolist())

    sigma = float(np.median(all_dists)) if len(all_dists) else 0.5
    sigma = max(sigma, 1e-3)

    def w(dist):
        return float(np.exp(- (dist**2) / (sigma**2)))

    for (fwy, direc), grp in meta_sub.sort_values(["fwy", "direction", "abs_pm"]).groupby(["fwy", "direction"]):
        grp = grp.sort_values("abs_pm")
        ids = grp["station"].astype(int).tolist()
        pms = grp["abs_pm"].astype(float).tolist()

        for i, sid in enumerate(ids):
            ii = station_to_idx[sid]
            for step in range(1, k_neighbors + 1):
                if i - step >= 0:
                    sj = ids[i-step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i-step]))
                if i + step < len(ids):
                    sj = ids[i+step]; jj = station_to_idx[sj]
                    A[ii, jj] = w(abs(pms[i] - pms[i+step]))

    np.fill_diagonal(A, 1.0)
    A = np.maximum(A, A.T)
    return A

A_dir = build_adjacency_fwy_dir(meta, stations, station_dir, k_neighbors=4)
print("A_dir shape:", A_dir.shape)
print("A_dir density:", float((A_dir > 0).mean()))

def row_normalize_dense(A, eps=1e-6):
    d = A.sum(axis=1, keepdims=True)
    return A / (d + eps)

def dense_to_sparse(A, device):
    idx = np.nonzero(A)
    values = A[idx].astype(np.float32)
    indices = np.vstack(idx)
    return torch.sparse_coo_tensor(
        torch.tensor(indices, dtype=torch.long, device=device),
        torch.tensor(values, dtype=torch.float32, device=device),
        size=A.shape,
        device=device
    ).coalesce()

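# A_dir is symmetric (np.maximum(A, A.T) above), so A_rwT equals A_rw here
# and the second support duplicates the first.
A_rw  = row_normalize_dense(A_dir)
A_rwT = row_normalize_dense(A_dir.T)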

supports = [dense_to_sparse(A_rw, DEVICE), dense_to_sparse(A_rwT, DEVICE)]
print("Supports nnz:", [int(s._nnz()) for s in supports])


Direction mapping coverage: 1.0
Direction counts:
 direction
E    509
W    499
S    410
N    403
Name: count, dtype: int64
A_dir shape: (1821, 1821)
A_dir density: 0.0038374073179432942
Supports nnz: [12720, 12720]


## Dataset and loaders

Each sample returns:
- x:  (F, N, IN_LEN)
- y:  (OUT_LEN, N)  scaled flow target
- tf: (OUT_LEN, 4)  known future calendar covariates (hour/dow sin/cos)

We compute tf from timestamps (not from future traffic), so it is NOT leakage.
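
To illustrate why sin/cos pairs are used for the calendar features, the sketch below encodes three consecutive hours across midnight and checks that the step from 23:00 to 00:00 is as small as any other one-hour step. It takes plain hour and weekday integers instead of a `DatetimeIndex` (an assumption made to keep the example free of pandas), and the weekday is held fixed to isolate the hour component.

```python
import numpy as np

def cyclical_encoding(hours, dows):
    """Map hour-of-day and day-of-week onto unit circles -> array of shape (T, 4)."""
    hours = np.asarray(hours, dtype=float)
    dows = np.asarray(dows, dtype=float)
    return np.stack([
        np.sin(2 * np.pi * hours / 24.0), np.cos(2 * np.pi * hours / 24.0),
        np.sin(2 * np.pi * dows / 7.0),   np.cos(2 * np.pi * dows / 7.0),
    ], axis=1).astype(np.float32)

# Three consecutive hours across midnight, same weekday index (illustrative).
enc = cyclical_encoding(hours=[22, 23, 0], dows=[2, 2, 2])

step_22_23 = np.linalg.norm(enc[1] - enc[0])
step_23_00 = np.linalg.norm(enc[2] - enc[1])   # no jump at the wrap-around
print(enc.shape, step_22_23, step_23_00)
```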


In [6]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)  # (T,4)

class PemsWindowDatasetTF(Dataset):
    """
    Returns:
      x:  (F, N, IN_LEN)
      y:  (OUT_LEN, N) scaled
      tf: (OUT_LEN, 4) known future time features
    """
    def __init__(self, X, Y, timestamps, starts, in_len, out_len,
                 flow_mean, flow_std, speed_mean, speed_std):
        self.X = X
        self.Y = Y
        self.timestamps = pd.DatetimeIndex(timestamps)
        self.starts = starts.astype(np.int32)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

        self.flow_mean = flow_mean.astype(np.float32)
        self.flow_std  = flow_std.astype(np.float32)
        self.speed_mean = speed_mean.astype(np.float32)
        self.speed_std  = speed_std.astype(np.float32)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])

        x = self.X[t:t+self.in_len].copy().astype(np.float32)  # (IN_LEN, N, F)
        y = self.Y[t+self.in_len:t+self.in_len+self.out_len].copy().astype(np.float32)  # (OUT_LEN, N)

        ts_future = self.timestamps[t+self.in_len:t+self.in_len+self.out_len]
        tf = time_encoding(ts_future)  # (OUT_LEN, 4)

        # scale inputs: feature 0 is flow, feature 1 is speed (other features are left as-is)
        x[..., 0] = (x[..., 0] - self.flow_mean[None, :]) / self.flow_std[None, :]
        x[..., 1] = (x[..., 1] - self.speed_mean[None, :]) / self.speed_std[None, :]

        # scale targets (flow)
        y = (y - self.flow_mean[None, :]) / self.flow_std[None, :]

        # (IN_LEN, N, F) -> (F, N, IN_LEN)
        x = np.transpose(x, (2, 1, 0))

        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(tf)

# Safer batch size for GraphWaveNet on N=1821
BATCH_SIZE = 8

train_ds = PemsWindowDatasetTF(X, Y, timestamps, train_starts, IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
val_ds   = PemsWindowDatasetTF(X, Y, timestamps, val_starts,   IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)
test_ds  = PemsWindowDatasetTF(X, Y, timestamps, test_starts,  IN_LEN, OUT_LEN, flow_mean, flow_std, speed_mean, speed_std)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)



Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])


## GraphWaveNet baseline (sparse diffusion graph conv + dilated temporal conv)

We:
- use causal dilated temporal convolutions + gating
- do sparse graph propagation with the supports
- use residual + skip connections
- use a time-aware output head with future calendar covariates
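
The time-aware head combines one embedding per node with one embedding per future step through broadcasting. The NumPy sketch below reproduces only the shape logic: the weight arrays `W_time` and `w_out` are random stand-ins for the two linear layers, and biases are omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, OUT_LEN, E = 2, 5, 3, 8          # small illustrative sizes

z = rng.standard_normal((B, N, E))         # last-step embedding per node
tf = rng.standard_normal((B, OUT_LEN, 4))  # future calendar features
W_time = rng.standard_normal((4, E))       # stand-in for the time embedding weights
w_out = rng.standard_normal(E)             # stand-in for the final linear layer

te = tf @ W_time                                           # (B, OUT_LEN, E)
h = np.maximum(z[:, None, :, :] + te[:, :, None, :], 0.0)  # (B, OUT_LEN, N, E)
out = h @ w_out                                            # (B, OUT_LEN, N)

# Same result for one (batch, horizon, node) triple, computed without broadcasting.
b, k, n = 1, 2, 4
manual = np.maximum(z[b, n] + te[b, k], 0.0) @ w_out
print(out.shape, np.isclose(out[b, k, n], manual))
```

The intermediate tensor has B * OUT_LEN * N * E elements, so memory grows with the product of horizon length and node count; this is one reason a small batch size is a reasonable choice at N = 1821.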


In [7]:
class NConv(nn.Module):
    """Sparse graph multiplication: (N,N) @ (B,C,N,T) -> (B,C,N,T)"""
    def forward(self, x, A_sp):
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)  # (N, B*C*T)

        # run the sparse matmul in float32 (half precision is not reliably supported on CUDA)
        out = torch.sparse.mm(A_sp, x_r.float())
        out = out.reshape(N, B, C, T).permute(1, 2, 0, 3)
        return out.to(dtype=x.dtype)

class DiffusionGraphConv(nn.Module):
    def __init__(self, c_in, c_out, supports, order=1, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.order = order
        self.dropout = dropout

        c_total = c_in * (1 + len(supports) * order)
        self.mlp = nn.Conv2d(c_total, c_out, kernel_size=(1, 1))

    def forward(self, x):
        out = [x]
        for A in self.supports:
            x1 = self.nconv(x, A)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, A)
                out.append(x1)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h

class CausalConv2d(nn.Module):
    """Causal conv along time axis only."""
    def __init__(self, c_in, c_out, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=(1, kernel_size), dilation=(1, dilation))

    def forward(self, x):
        x = F.pad(x, (self.pad, 0, 0, 0))
        return self.conv(x)

class GraphWaveNetTimeAware(nn.Module):
    def __init__(
        self,
        num_nodes,
        in_dim,
        out_len,
        supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=4,   # receptive field 1 + 2*(1+2+4+8) = 31 >= IN_LEN (24) with kernel_size=2
        gcn_order=1,
        dropout=0.1,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.out_len = out_len
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.blocks = blocks
        self.layers_per_block = layers_per_block

        # receptive field
        receptive_field = 1
        for _ in range(blocks):
            for i in range(layers_per_block):
                receptive_field += (kernel_size - 1) * (2 ** i)
        self.receptive_field = receptive_field

        self.start_conv = nn.Conv2d(in_dim, residual_channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs   = nn.ModuleList()
        self.skip_convs   = nn.ModuleList()
        self.bn           = nn.ModuleList()
        self.gconvs       = nn.ModuleList()

        for _ in range(blocks):
            for i in range(layers_per_block):
                dilation = 2 ** i
                self.filter_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))
                self.gate_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))

                self.skip_convs.append(nn.Conv2d(dilation_channels, skip_channels, kernel_size=(1, 1)))
                self.gconvs.append(DiffusionGraphConv(dilation_channels, residual_channels, supports, order=gcn_order, dropout=dropout))
                self.bn.append(nn.BatchNorm2d(residual_channels))

        self.end_conv_1 = nn.Conv2d(skip_channels, end_channels, kernel_size=(1, 1))

        # time-aware head
        self.time_embed = nn.Linear(4, end_channels)
        self.horizon_out = nn.Linear(end_channels, 1)

    def forward(self, x, tf_future):
        """
        x:        (B, F, N, IN_LEN)
        tf_future:(B, OUT_LEN, 4)
        out:      (B, OUT_LEN, N)
        """
        # pad input if needed to match receptive field
        if x.size(-1) < self.receptive_field:
            pad_len = self.receptive_field - x.size(-1)
            x = F.pad(x, (pad_len, 0, 0, 0))

        x = self.start_conv(x)  # (B, residual, N, T)
        skip = None

        for i in range(len(self.filter_convs)):
            residual = x

            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = filt * gate
            x = F.dropout(x, p=self.dropout, training=self.training)

            s = self.skip_convs[i](x)
            skip = s if skip is None else (skip + s)

            x = self.gconvs[i](x)
            x = x + residual
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))     # (B, end_channels, N, T)

        z = x[..., -1]                     # (B, end_channels, N)
        z = z.permute(0, 2, 1)             # (B, N, end_channels)

        te = self.time_embed(tf_future)    # (B, OUT_LEN, end_channels)

        h = F.relu(z.unsqueeze(1) + te.unsqueeze(2))  # (B, OUT_LEN, N, end_channels)
        out = self.horizon_out(h).squeeze(-1)         # (B, OUT_LEN, N)
        return out
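
The receptive-field loop in the constructor above has a simple closed form. The following is a minimal sketch, assuming dilations restart at 1 in every block and double with each layer, as in the constructor. The helper names `receptive_field` and `receptive_field_closed_form` are hypothetical and not part of the notebook.

```python
def receptive_field(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    """Receptive field of a stack of dilated causal convolutions.

    Mirrors the loop in the model constructor: a layer with dilation 2**i
    widens the field by (kernel_size - 1) * 2**i time steps.
    """
    rf = 1
    for _ in range(blocks):
        for i in range(layers_per_block):
            rf += (kernel_size - 1) * (2 ** i)
    return rf


def receptive_field_closed_form(kernel_size: int, blocks: int, layers_per_block: int) -> int:
    # The dilations 2**0 .. 2**(L-1) sum to 2**L - 1, repeated once per block.
    return 1 + blocks * (kernel_size - 1) * (2 ** layers_per_block - 1)


# Default hyperparameters used in this notebook: kernel_size=2, blocks=2, layers_per_block=4
print(receptive_field(2, 2, 4))              # 31
print(receptive_field_closed_form(2, 2, 4))  # 31
```

With the defaults, the last time step of the encoder output therefore sees 31 input steps, which is why shorter inputs are left-padded in `forward`.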


## Training + evaluation

- Loss: SmoothL1 (Huber) on SCALED targets
- Metrics: MAE/RMSE on ORIGINAL scale
- Early stopping on average validation MAE (across horizons)
- AMP disabled (sparse graph ops do not support fp16 reliably)
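
For reference, the per-element SmoothL1 loss with `beta=1.0` is quadratic for small residuals and linear for large ones, which makes it less sensitive to outliers than MSE. Below is a minimal pure-Python sketch of that formula. The helper name `smooth_l1` is illustrative only; training itself uses `nn.SmoothL1Loss`.

```python
def smooth_l1(pred: float, target: float, beta: float = 1.0) -> float:
    """Per-element SmoothL1 (Huber-style) loss on scaled values."""
    diff = abs(pred - target)
    if diff < beta:
        # Quadratic region: small residuals behave like (scaled) squared error
        return 0.5 * diff * diff / beta
    # Linear region: large residuals behave like absolute error
    return diff - 0.5 * beta


print(smooth_l1(0.5, 0.0))  # 0.125 (quadratic region)
print(smooth_l1(3.0, 0.0))  # 2.5   (linear region)
```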


In [None]:
EVAL_HORIZONS = [12, 24, 48, 72]

flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

@torch.no_grad()
def eval_horizons(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # scaled

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        for h in EVAL_HORIZONS:
            idx = h - 1
            err = pred_u[:, idx, :] - true_u[:, idx, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

def train_gwn_baseline(
    epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=8,
):
    model = GraphWaveNetTimeAware(
        num_nodes=X.shape[1],
        in_dim=X.shape[2],
        out_len=OUT_LEN,
        supports=supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=4,
        gcn_order=1,
        dropout=0.1,
    ).to(DEVICE)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    bad = 0

    # IMPORTANT: disable AMP due to sparse graph ops
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            running += float(loss.item())

        val_metrics = eval_horizons(model, val_loader)
        score = avg_mae(val_metrics)

        print(f"\nEpoch {epoch}: train_loss={running/len(train_loader):.6f}  val_avg_MAE={score:.3f}")
        print_metrics("Validation metrics", val_metrics)

        if score < best_score:
            best_score = score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                break

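    # Restore the best weights; best_state stays None if no validation score ever improved (e.g. NaN MAE)
    if best_state is not None:
        model.load_state_dict(best_state)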
    return model

gwn = train_gwn_baseline()

val_m = eval_horizons(gwn, val_loader)
test_m = eval_horizons(gwn, test_loader)

print_metrics("GraphWaveNet Baseline — Validation", val_m)
print_metrics("GraphWaveNet Baseline — Test", test_m)


Epoch 1: train_loss=0.166336  val_avg_MAE=192.611

Validation metrics
   12h  MAE=191.403  RMSE=358.438
   24h  MAE=189.733  RMSE=363.071
   48h  MAE=196.351  RMSE=372.638
   72h  MAE=192.957  RMSE=360.658


Epoch 2: train_loss=0.117438  val_avg_MAE=189.943

Validation metrics
   12h  MAE=182.940  RMSE=338.154
   24h  MAE=187.478  RMSE=348.945
   48h  MAE=195.416  RMSE=359.328
   72h  MAE=193.937  RMSE=352.510


In [None]:
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h-1 for h in EVAL_HORIZONS], device=DEVICE)

flow_std_t = torch.tensor(flow_std, dtype=torch.float32, device=DEVICE).view(1, 1, -1)

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=True):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        # pred, yb are SCALED
        pred = model(xb, tfb)  # (B, OUT, N)

        # Convert error to ORIGINAL units via * std
        err = (pred[:, h_idx, :] - yb[:, h_idx, :]) * flow_std_t  # (B, H, N)

        abs_err = err.abs()
        sq_err  = err * err

        for k, h in enumerate(EVAL_HORIZONS):
            acc[h]["abs"]   += float(abs_err[:, k, :].sum())
            acc[h]["sq"]    += float(sq_err[:, k, :].sum())
            acc[h]["count"] += abs_err[:, k, :].numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae  = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics
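
`eval_horizons_fast` multiplies the scaled error by `flow_std` only. This works because the mean cancels when prediction and target are unscaled with the same statistics and then subtracted. A small check with made-up numbers (the values below are illustrative, not PeMS statistics):

```python
import math

# Hypothetical per-station train statistics and scaled values (illustration only)
mean, std = 350.0, 120.0
pred_scaled, true_scaled = 0.40, -0.15

# Full unscaling of both sides, as in eval_horizons
err_full = (pred_scaled * std + mean) - (true_scaled * std + mean)

# Shortcut used by eval_horizons_fast: the mean drops out of the difference
err_fast = (pred_scaled - true_scaled) * std

print(math.isclose(err_full, err_fast))  # True
```

Both evaluators should therefore report the same MAE/RMSE up to floating-point rounding.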


In [17]:
print("Evaluating on TEST set...")
test_m = eval_horizons_fast(gwn, test_loader)

print_metrics("GraphWaveNet Baseline — Test", test_m)


Evaluating on TEST set...



GraphWaveNet Baseline — Test
   12h  MAE=132.135  RMSE=263.511
   24h  MAE=132.739  RMSE=266.139
   48h  MAE=142.070  RMSE=285.151
   72h  MAE=146.225  RMSE=292.688


# GraphWaveNet-RNN Ablations

We extend the GraphWaveNet encoder with optional recurrent layers that run over each node's embedding sequence in time. When both are enabled, the GRU output feeds the LSTM.

Ablations:
1) GWN (baseline): use_gru=False, use_lstm=False
2) GWN+GRU:       use_gru=True,  use_lstm=False
3) GWN+LSTM:      use_gru=False, use_lstm=True
4) GWN+GRU+LSTM:  use_gru=True,  use_lstm=True  (proposed)

All experiments share:
- same dataset, splits, adjacency, loaders
- same evaluation horizons + metrics
- same early stopping rule
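
One way to keep the four ablations on identical hyperparameters is to build them from a single shared dictionary. The sketch below is only an illustration: `SHARED_KWARGS`, `ABLATIONS`, and `head_dim` are hypothetical names, and the values mirror the constructor calls used in this notebook.

```python
# Hyperparameters shared by every ablation (values as used in this notebook)
SHARED_KWARGS = dict(
    end_channels=128,
    rnn_hidden=128,
    dropout=0.1,
    residual_channels=32,
    dilation_channels=32,
    skip_channels=64,
    kernel_size=2,
    blocks=2,
    layers_per_block=4,
    gcn_order=1,
)

# Only the recurrent switches differ between ablations
ABLATIONS = {
    "GWN": dict(use_gru=False, use_lstm=False),
    "GWN+GRU": dict(use_gru=True, use_lstm=False),
    "GWN+LSTM": dict(use_gru=False, use_lstm=True),
    "GWN+GRU+LSTM": dict(use_gru=True, use_lstm=True),
}


def head_dim(use_gru: bool, use_lstm: bool, end_channels: int, rnn_hidden: int) -> int:
    """Width of the embedding fed to the time-aware head (mirrors `final_dim`)."""
    return rnn_hidden if (use_gru or use_lstm) else end_channels


for name, flags in ABLATIONS.items():
    kwargs = {**SHARED_KWARGS, **flags}
    dim = head_dim(flags["use_gru"], flags["use_lstm"], kwargs["end_channels"], kwargs["rnn_hidden"])
    print(f"{name:<14} head_dim={dim}")
    # Assumed usage with the notebook's globals:
    # model = GraphWaveNetRNN(num_nodes=X.shape[1], in_dim=X.shape[2],
    #                         out_len=OUT_LEN, supports=supports, **kwargs).to(DEVICE)
```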


In [8]:
class GraphWaveNetEncoder(nn.Module):
    """
    Returns node embeddings across time: (B, end_channels, N, T)
    """
    def __init__(
        self,
        num_nodes,
        in_dim,
        supports,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=64,
        end_channels=128,
        kernel_size=2,
        blocks=2,
        layers_per_block=4,
        gcn_order=1,
        dropout=0.1,
    ):
        super().__init__()
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.blocks = blocks
        self.layers_per_block = layers_per_block

        # receptive field
        receptive_field = 1
        for _ in range(blocks):
            for i in range(layers_per_block):
                receptive_field += (kernel_size - 1) * (2 ** i)
        self.receptive_field = receptive_field

        self.start_conv = nn.Conv2d(in_dim, residual_channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs   = nn.ModuleList()
        self.skip_convs   = nn.ModuleList()
        self.bn           = nn.ModuleList()
        self.gconvs       = nn.ModuleList()

        for _ in range(blocks):
            for i in range(layers_per_block):
                dilation = 2 ** i
                self.filter_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))
                self.gate_convs.append(CausalConv2d(residual_channels, dilation_channels, kernel_size, dilation))

                self.skip_convs.append(nn.Conv2d(dilation_channels, skip_channels, kernel_size=(1, 1)))
                self.gconvs.append(DiffusionGraphConv(dilation_channels, residual_channels, supports, order=gcn_order, dropout=dropout))
                self.bn.append(nn.BatchNorm2d(residual_channels))

        self.end_conv_1 = nn.Conv2d(skip_channels, end_channels, kernel_size=(1, 1))

    def forward(self, x):
        """
        x: (B, F, N, IN_LEN)
        returns: (B, end_channels, N, T)
        """
        if x.size(-1) < self.receptive_field:
            pad_len = self.receptive_field - x.size(-1)
            x = F.pad(x, (pad_len, 0, 0, 0))

        x = self.start_conv(x)
        skip = None

        for i in range(len(self.filter_convs)):
            residual = x

            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = filt * gate
            x = F.dropout(x, p=self.dropout, training=self.training)

            s = self.skip_convs[i](x)
            skip = s if skip is None else (skip + s)

            x = self.gconvs[i](x)
            x = x + residual
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        return x  # (B, end_channels, N, T)


class GraphWaveNetRNN(nn.Module):
    """
    GraphWaveNet encoder + optional GRU/LSTM over time for each node.
    Time-aware head using future calendar covariates.

    Output: (B, OUT_LEN, N)
    """
    def __init__(
        self,
        num_nodes,
        in_dim,
        out_len,
        supports,
        end_channels=128,
        use_gru=False,
        use_lstm=False,
        rnn_hidden=128,
        dropout=0.1,
        **encoder_kwargs
    ):
        super().__init__()
        self.out_len = out_len
        self.use_gru = use_gru
        self.use_lstm = use_lstm

        self.encoder = GraphWaveNetEncoder(
            num_nodes=num_nodes,
            in_dim=in_dim,
            supports=supports,
            end_channels=end_channels,
            dropout=dropout,
            **encoder_kwargs
        )

        # RNN operates on per-node sequences: (B*N, T, C)
        if use_gru:
            self.gru = nn.GRU(input_size=end_channels, hidden_size=rnn_hidden, batch_first=True)
        else:
            self.gru = None

        if use_lstm:
            self.lstm = nn.LSTM(input_size=(rnn_hidden if use_gru else end_channels),
                                hidden_size=rnn_hidden,
                                batch_first=True)
        else:
            self.lstm = None

        # final node embedding size after RNN(s)
        final_dim = rnn_hidden if (use_gru or use_lstm) else end_channels

        # time-aware head
        self.time_embed = nn.Linear(4, final_dim)
        self.horizon_out = nn.Linear(final_dim, 1)

    def forward(self, x, tf_future):
        """
        x: (B, F, N, IN_LEN)
        tf_future: (B, OUT_LEN, 4)
        """
        h = self.encoder(x)              # (B, C, N, T)
        B, C, N, T = h.shape

        # Take full temporal sequence per node
        h_seq = h.permute(0, 2, 3, 1).contiguous()   # (B, N, T, C)
        h_seq = h_seq.view(B*N, T, C)                # (B*N, T, C)

        # Optional GRU then optional LSTM
        if self.gru is not None:
            out_gru, _ = self.gru(h_seq)             # (B*N, T, H)
            h_seq2 = out_gru
        else:
            h_seq2 = h_seq

        if self.lstm is not None:
            out_lstm, (hn, cn) = self.lstm(h_seq2)   # (B*N, T, H)
            last = out_lstm[:, -1, :]                # (B*N, H)
        else:
            last = h_seq2[:, -1, :]                  # (B*N, C or H)

        # Reshape back to (B, N, D)
        z = last.view(B, N, -1)                      # (B, N, D)

        # time embedding for each horizon
        te = self.time_embed(tf_future)              # (B, OUT_LEN, D)

        # combine
        out = F.relu(z.unsqueeze(1) + te.unsqueeze(2))  # (B, OUT_LEN, N, D)
        out = self.horizon_out(out).squeeze(-1)         # (B, OUT_LEN, N)
        return out


In [9]:
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# Horizons we report in the paper
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h - 1 for h in EVAL_HORIZONS], device=DEVICE)

# For converting scaled error -> original units (multiply by std)
flow_std_t = torch.tensor(flow_std, dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    """
    Works for both:
      - GraphWaveNetTimeAware (forward(x, tf_future))
      - GraphWaveNetRNN      (forward(x, tf_future))
    Assumes y is scaled and predictions are scaled.
    """
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=True):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # (B, OUT, N) scaled

        # Original-units error at selected horizons: (pred - true) * std
        err = (pred[:, h_idx, :] - yb[:, h_idx, :]) * flow_std_t  # (B, H, N)

        abs_err = err.abs()
        sq_err = err * err

        for k, h in enumerate(EVAL_HORIZONS):
            acc[h]["abs"] += float(abs_err[:, k, :].sum())
            acc[h]["sq"]  += float(sq_err[:, k, :].sum())
            acc[h]["count"] += abs_err[:, k, :].numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

def train_gwn_baseline_fast(
    model,
    epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=6,
    eval_every=2,
):
    """
    Generic trainer for:
      - GraphWaveNet baseline
      - GraphWaveNet+GRU/LSTM variants

    Uses:
      - SmoothL1Loss on scaled targets
      - Early stopping on avg validation MAE (original units)
      - Evaluates every `eval_every` epochs to save time; `patience` counts
        these validation checks, not epochs
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    bad = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=True):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            running += float(loss.item())

        print(f"\nEpoch {epoch}: train_loss={running/len(train_loader):.6f}")

        # Validate periodically
        if (epoch % eval_every == 0) or (epoch == epochs):
            val_metrics = eval_horizons_fast(model, val_loader)
            score = avg_mae(val_metrics)

            print(f"Epoch {epoch}: val_avg_MAE={score:.3f}")
            print_metrics("Validation metrics", val_metrics)

            if score < best_score:
                best_score = score
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                bad = 0
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
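
Because validation runs only every `eval_every` epochs, `patience` counts validation checks rather than epochs. The following minimal sketch isolates the same stopping rule in a standalone helper. `EarlyStopper` is a hypothetical name, and the scores in the usage lines are made up.

```python
class EarlyStopper:
    """Tracks the best validation score and counts checks without improvement."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best = float("inf")
        self.bad = 0

    def update(self, score: float) -> bool:
        """Record one validation score; return True when training should stop."""
        if score < self.best:
            self.best = score
            self.bad = 0
            return False
        self.bad += 1
        return self.bad >= self.patience


# Made-up validation scores, one per validation check
stopper = EarlyStopper(patience=2)
decisions = [stopper.update(s) for s in (10.0, 9.0, 9.5, 9.2)]
print(decisions)     # [False, False, False, True]
print(stopper.best)  # 9.0
```

With `patience=6` and `eval_every=2`, as used in this notebook, training can continue for up to 12 epochs after the last improvement.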


In [10]:
gwn_base = GraphWaveNetRNN(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    end_channels=128,
    use_gru=False,
    use_lstm=False,
    rnn_hidden=128,
    dropout=0.1,
    residual_channels=32,
    dilation_channels=32,
    skip_channels=64,
    kernel_size=2,
    blocks=2,
    layers_per_block=4,
    gcn_order=1,
).to(DEVICE)


In [11]:
gwn_gru = GraphWaveNetRNN(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,

    # RNN options
    use_gru=True,
    use_lstm=False,
    rnn_hidden=128,   # you can try 64 later if needed

    # encoder hyperparameters
    end_channels=128,
    dropout=0.1,

    residual_channels=32,
    dilation_channels=32,
    skip_channels=64,
    kernel_size=2,
    blocks=2,
    layers_per_block=4,
    gcn_order=1,
).to(DEVICE)



In [12]:
gwn_lstm = GraphWaveNetRNN(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,

    # RNN options
    use_gru=False,
    use_lstm=True,
    rnn_hidden=128,

    # encoder hyperparameters
    end_channels=128,
    dropout=0.1,

    residual_channels=32,
    dilation_channels=32,
    skip_channels=64,
    kernel_size=2,
    blocks=2,
    layers_per_block=4,
    gcn_order=1,
).to(DEVICE)


In [13]:
gwn_gru_lstm = GraphWaveNetRNN(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,

    # RNN options (both ON)
    use_gru=True,
    use_lstm=True,
    rnn_hidden=128,   # consider 64 if memory is tight

    # encoder hyperparameters
    end_channels=128,
    dropout=0.1,

    residual_channels=32,
    dilation_channels=32,
    skip_channels=64,
    kernel_size=2,
    blocks=2,
    layers_per_block=4,
    gcn_order=1,
).to(DEVICE)


In [None]:
model = gwn_gru
model = train_gwn_baseline_fast(model, epochs=40, patience=6, eval_every=2)

print("Evaluating on TEST set...")
test_m = eval_horizons_fast(model, test_loader)
print_metrics("GraphWaveNet+GRU — Test", test_m)



Epoch 1: train_loss=0.069344

Epoch 2: train_loss=0.068216
Epoch 2: val_avg_MAE=142.784

Validation metrics
   12h  MAE=122.702  RMSE=250.600
   24h  MAE=133.762  RMSE=270.968
   48h  MAE=153.171  RMSE=305.410
   72h  MAE=161.501  RMSE=323.483


Epoch 3: train_loss=0.067949

Epoch 4: train_loss=0.067564
Epoch 4: val_avg_MAE=145.270

Validation metrics
   12h  MAE=122.212  RMSE=245.460
   24h  MAE=135.322  RMSE=270.348
   48h  MAE=156.585  RMSE=308.799
   72h  MAE=166.961  RMSE=330.038


Epoch 5: train_loss=0.067505

Epoch 6: train_loss=0.067179
Epoch 6: val_avg_MAE=138.746

Validation metrics
   12h  MAE=118.996  RMSE=244.390
   24h  MAE=130.782  RMSE=268.204
   48h  MAE=149.420  RMSE=301.413
   72h  MAE=155.786  RMSE=314.527


Epoch 7: train_loss=0.066889

Epoch 8: train_loss=0.066778
Epoch 8: val_avg_MAE=136.142

Validation metrics
   12h  MAE=121.889  RMSE=239.475
   24h  MAE=127.331  RMSE=256.420
   48h  MAE=144.405  RMSE=285.289
   72h  MAE=150.942  RMSE=299.515


Epoch 9: train_loss=0.067129

Epoch 10: train_loss=0.066682
Epoch 10: val_avg_MAE=147.497

Validation metrics
   12h  MAE=126.991  RMSE=251.459
   24h  MAE=138.185  RMSE=275.211
   48h  MAE=158.519  RMSE=311.785
   72h  MAE=166.292  RMSE=327.812


Epoch 11: train_loss=0.066730

Epoch 12: train_loss=0.067255
Epoch 12: val_avg_MAE=152.662

Validation metrics
   12h  MAE=138.393  RMSE=287.913
   24h  MAE=144.928  RMSE=300.079
   48h  MAE=160.973  RMSE=331.954
   72h  MAE=166.352  RMSE=341.669


Epoch 13: train_loss=0.066909

Epoch 14: train_loss=0.066567
Epoch 14: val_avg_MAE=142.554

Validation metrics
   12h  MAE=122.701  RMSE=244.487
   24h  MAE=132.333  RMSE=266.945
   48h  MAE=153.941  RMSE=304.071
   72h  MAE=161.242  RMSE=320.493


Epoch 15: train_loss=0.066083

In [None]:
model = gwn_lstm
model = train_gwn_baseline_fast(model, epochs=40, patience=6, eval_every=2)

print("Evaluating on TEST set...")
test_m = eval_horizons_fast(model, test_loader)
print_metrics("GraphWaveNet+LSTM - Test", test_m)


In [None]:
model = gwn_gru_lstm
model = train_gwn_baseline_fast(model, epochs=40, patience=6, eval_every=2)

print("Evaluating on TEST set...")
test_m = eval_horizons_fast(model, test_loader)
print_metrics("GraphWaveNet+GRU+LSTM - Test", test_m)


In [14]:
import json
import time
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# Choose Excel engine (xlsxwriter is much faster if available)
def _excel_engine():
    try:
        import xlsxwriter  # noqa: F401
        return "xlsxwriter"
    except Exception:
        return "openpyxl"

def make_run_dir(model_name: str, root: str = "artifacts/runs") -> Path:
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe = "".join([c if c.isalnum() or c in "-_." else "_" for c in model_name])
    run_dir = Path(root) / f"{ts}_{safe}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(obj, path: Path):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=str)

@torch.inference_mode()
def collect_predictions_for_horizons(
    model,
    loader,
    horizons=(12, 24, 48, 72),
    stations=None,
    timestamps=None,
    in_len=None,
    flow_mean=None,
    flow_std=None,
    device="cuda",
):
    """
    Collect predictions/targets on loader for selected horizons.
    Assumes y and pred are SCALED. Converts both to ORIGINAL units before saving.

    Returns:
      pred: (S, H, N) float32
      true: (S, H, N) float32
      times: (S, H) datetime64[ns]
      horizons: list[int]
    """
    model.eval()

    horizons = list(horizons)
    h_idx = np.array([h - 1 for h in horizons], dtype=np.int32)

    # Assumes the loader iterates in dataset order (shuffle=False), so that
    # starts[offset + j] belongs to sample j of the current batch.
    # The dataset must expose window start indices as `.starts` (as FastPemsWindowDataset does).
    ds = loader.dataset
    if not hasattr(ds, "starts"):
        raise AttributeError("loader.dataset must have attribute `.starts` to reconstruct timestamps per window.")
    starts = ds.starts

    assert timestamps is not None and in_len is not None, "Need timestamps and in_len to compute target times."
    assert stations is not None, "Need stations list for column labels."
    assert flow_mean is not None and flow_std is not None, "Need flow_mean/flow_std to unscale."

    N = len(stations)
    flow_mean = flow_mean.astype(np.float32)
    flow_std  = flow_std.astype(np.float32)

    preds_list, trues_list, times_list = [], [], []
    offset = 0

    for xb, yb, tfb in tqdm(loader, desc="Collect preds", leave=True):
        B = xb.shape[0]

        xb = xb.to(device, non_blocking=True)
        tfb = tfb.to(device, non_blocking=True)

        pred_scaled = model(xb, tfb).detach().cpu().numpy().astype(np.float32)  # (B, OUT, N)

        # yb is already on CPU typically; ensure numpy float32
        true_scaled = yb.detach().cpu().numpy().astype(np.float32)

        # Select horizons
        pred_sel_scaled = pred_scaled[:, h_idx, :]  # (B, H, N)
        true_sel_scaled = true_scaled[:, h_idx, :]  # (B, H, N)

        # Unscale to original units
        pred_sel = pred_sel_scaled * flow_std[None, None, :] + flow_mean[None, None, :]
        true_sel = true_sel_scaled * flow_std[None, None, :] + flow_mean[None, None, :]

        preds_list.append(pred_sel.astype(np.float32))
        trues_list.append(true_sel.astype(np.float32))

        # target timestamps for each sample/horizon
        batch_times = np.zeros((B, len(horizons)), dtype="datetime64[ns]")
        for j in range(B):
            t0 = int(starts[offset + j])
            target_indices = t0 + int(in_len) + h_idx
            batch_times[j, :] = np.array(pd.to_datetime(timestamps[target_indices]).astype("datetime64[ns]"))
        times_list.append(batch_times)

        offset += B

    pred = np.concatenate(preds_list, axis=0)
    true = np.concatenate(trues_list, axis=0)
    times = np.concatenate(times_list, axis=0)
    return pred, true, times, horizons

def export_preds_to_excel(
    excel_path: Path,
    pred: np.ndarray,
    true: np.ndarray,
    times: np.ndarray,
    horizons,
    stations,
    max_stations_excel: int | None = None,
):
    """
    Writes an Excel workbook:
      - metrics sheet can be added separately
      - pred_hXX and true_hXX sheets for each horizon

    If max_stations_excel is set, only the first max_stations_excel stations are exported to keep the file small.
    Full arrays should still be saved to NPZ.
    """
    engine = _excel_engine()
    stations = [str(s) for s in stations]

    if max_stations_excel is not None:
        stations = stations[:max_stations_excel]
        pred = pred[:, :, :max_stations_excel]
        true = true[:, :, :max_stations_excel]

    with pd.ExcelWriter(excel_path, engine=engine) as writer:
        # Also save station list
        pd.DataFrame({"station": stations}).to_excel(writer, sheet_name="stations", index=False)

        # Save each horizon
        for k, h in enumerate(horizons):
            tcol = pd.to_datetime(times[:, k])
            pred_df = pd.DataFrame(pred[:, k, :], columns=stations)
            pred_df.insert(0, "target_time", tcol)

            true_df = pd.DataFrame(true[:, k, :], columns=stations)
            true_df.insert(0, "target_time", tcol)

            pred_df.to_excel(writer, sheet_name=f"pred_h{h}", index=False)
            true_df.to_excel(writer, sheet_name=f"true_h{h}", index=False)

def append_summary_csv(summary_csv: Path, row: dict):
    summary_csv.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([row])
    if summary_csv.exists():
        old = pd.read_csv(summary_csv)
        out = pd.concat([old, df], ignore_index=True)
    else:
        out = df
    out.to_csv(summary_csv, index=False)

def train_and_save_best(
    model,
    model_name: str,
    run_dir: Path,
    train_loader,
    val_loader,
    eval_fn,                 # eval_horizons_fast
    epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=6,
    eval_every=2,
):
    """
    Trains model, saves:
      - best checkpoint to run_dir/best.pt
      - training history to run_dir/history.csv
      - config to run_dir/config.json
    Returns:
      model (loaded with best checkpoint)
      history_df
      best_val_metrics
    """
    run_dir.mkdir(parents=True, exist_ok=True)

    config = dict(
        model_name=model_name,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        clip=clip,
        patience=patience,
        eval_every=eval_every,
        timestamp=str(datetime.now()),
    )
    save_json(config, run_dir / "config.json")

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    best_val_metrics = None
    bad = 0

    history = []
    t_start = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=True):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            running += float(loss.item())

        train_loss = running / max(len(train_loader), 1)
        row = {"epoch": epoch, "train_loss": train_loss}

        print(f"\nEpoch {epoch}: train_loss={train_loss:.6f}")

        if (epoch % eval_every == 0) or (epoch == epochs):
            val_metrics = eval_fn(model, val_loader)
            score = float(np.mean([val_metrics[h]["MAE"] for h in val_metrics]))

            row["val_avg_MAE"] = score
            for h in sorted(val_metrics.keys()):
                row[f"val_MAE_{h}h"] = val_metrics[h]["MAE"]
                row[f"val_RMSE_{h}h"] = val_metrics[h]["RMSE"]

            print(f"Epoch {epoch}: val_avg_MAE={score:.3f}")

            # Save best checkpoint
            if score < best_score:
                best_score = score
                best_val_metrics = val_metrics
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                torch.save(
                    {
                        "model_name": model_name,
                        "epoch": epoch,
                        "best_val_avg_MAE": best_score,
                        "state_dict": best_state,
                    },
                    run_dir / "best.pt",
                )
                bad = 0
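            else:
                # `bad` counts evaluation rounds (every `eval_every` epochs), not epochs
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    history.append(row)
                    # Flush history here: `break` skips the per-epoch write below
                    pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)
                    break

        history.append(row)

        # Write history every epoch so you don't lose it
        pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)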

    # Load best checkpoint if exists
    best_path = run_dir / "best.pt"
    if best_path.exists():
        ckpt = torch.load(best_path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
    else:
        print("Warning: best.pt not found; using current model weights.")

    elapsed = time.time() - t_start
    save_json({"train_seconds": elapsed, "best_val_avg_MAE": best_score}, run_dir / "train_summary.json")

    history_df = pd.read_csv(run_dir / "history.csv")
    return model, history_df, best_val_metrics
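
The patience counter in the loop above advances once per evaluation round, not once per epoch, so with `eval_every=2` and `patience=6` training stops only after 12 epochs without a new best score. A minimal sketch of that rule in isolation, using a hypothetical helper `early_stop_epoch` and made-up validation scores (neither is part of the notebook):

```python
def early_stop_epoch(val_scores, patience):
    """Replay the patience rule on (epoch, score) pairs.

    Returns (stop_epoch, best_epoch); stop_epoch is None when patience
    is never exhausted. Lower scores are better, as with MAE.
    """
    best_score = float("inf")
    best_epoch = None
    bad = 0
    for epoch, score in val_scores:
        if score < best_score:
            # New best: remember it and reset the counter
            best_score, best_epoch, bad = score, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Illustrative scores only, evaluated every 2 epochs.
toy_scores = [(2, 5.0), (4, 4.0), (6, 4.5), (8, 4.2), (10, 4.1), (12, 3.9)]
print(early_stop_epoch(toy_scores, patience=3))  # (10, 4)
print(early_stop_epoch(toy_scores, patience=4))  # (None, 12)
```

With `patience=3` the third consecutive non-improving evaluation (epoch 10) ends training and the epoch-4 weights are the ones kept. With `patience=4` the run survives long enough to reach the better score at epoch 12.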


In [15]:

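# --------- USER EDITS THESE TWO LINES PER MODEL ----------
model_name = "GraphWaveNet_GRU"
model = gwn_gru
# --------------------------------------------------------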


run_dir = make_run_dir(model_name)

# Train and save best checkpoint + history
model, history_df, best_val_metrics = train_and_save_best(
    model=model,
    model_name=model_name,
    run_dir=run_dir,
    train_loader=train_loader,
    val_loader=val_loader,
    eval_fn=eval_horizons_fast,   # use your fast evaluator
    epochs=40,
    patience=6,
    eval_every=2,
)

print("Evaluating on TEST set...")
test_metrics = eval_horizons_fast(model, test_loader)
print_metrics(f"{model_name} — Test", test_metrics)

# Save metrics to JSON + Excel-friendly table
metrics_row = {"model_name": model_name, "run_dir": str(run_dir)}
for h in sorted(test_metrics.keys()):
    metrics_row[f"test_MAE_{h}h"] = test_metrics[h]["MAE"]
    metrics_row[f"test_RMSE_{h}h"] = test_metrics[h]["RMSE"]

save_json({"test_metrics": test_metrics}, run_dir / "test_metrics.json")
pd.DataFrame([metrics_row]).to_csv(run_dir / "test_metrics.csv", index=False)

# Also append to a master summary file (so you have one file for all models)
append_summary_csv(Path("artifacts/results_summary.csv"), metrics_row)

# Collect predictions + actuals on TEST set (original units), for paper horizons
HORIZONS_TO_SAVE = [12, 24, 48, 72]

pred, true, times, horizons = collect_predictions_for_horizons(
    model=model,
    loader=test_loader,
    horizons=HORIZONS_TO_SAVE,
    stations=stations,
    timestamps=timestamps,
    in_len=IN_LEN,
    flow_mean=flow_mean,
    flow_std=flow_std,
    device=DEVICE,
)

# Save compact full arrays (best format for later plots)
np.savez_compressed(
    run_dir / "test_pred_true_selected_horizons.npz",
    pred=pred,
    true=true,
    times=times,
    horizons=np.array(horizons, dtype=np.int32),
    stations=stations,
)

# Export to Excel (pred + true sheets for each horizon)
# If Excel gets too heavy, set max_stations_excel=200 (or 500)
excel_path = run_dir / "test_pred_true_selected_horizons.xlsx"
export_preds_to_excel(
    excel_path=excel_path,
    pred=pred,
    true=true,
    times=times,
    horizons=horizons,
    stations=stations,
    max_stations_excel=None,  
)

print("\nSaved run outputs to:", run_dir)
print(" - best checkpoint:", run_dir / "best.pt")
print(" - history:", run_dir / "history.csv")
print(" - test metrics:", run_dir / "test_metrics.json")
print(" - predictions (npz):", run_dir / "test_pred_true_selected_horizons.npz")
print(" - predictions (xlsx):", excel_path)
print(" - master summary:", Path("artifacts/results_summary.csv"))


Epoch 1: train_loss=0.167076
Epoch 2: train_loss=0.116750
Epoch 2: val_avg_MAE=181.891
Epoch 3: train_loss=0.100684
Epoch 4: train_loss=0.093186
Epoch 4: val_avg_MAE=170.453
Epoch 5: train_loss=0.088388
Epoch 6: train_loss=0.085367
Epoch 6: val_avg_MAE=156.216
Epoch 7: train_loss=0.082692
Epoch 8: train_loss=0.080544
Epoch 8: val_avg_MAE=153.305
Epoch 9: train_loss=0.079169
Epoch 10: train_loss=0.077482
Epoch 10: val_avg_MAE=160.192
Epoch 11: train_loss=0.076396
Epoch 12: train_loss=0.075392
Epoch 12: val_avg_MAE=147.165
Epoch 13: train_loss=0.074821
Epoch 14: train_loss=0.074361
Epoch 14: val_avg_MAE=149.617
Epoch 15: train_loss=0.073265
Epoch 16: train_loss=0.072879
Epoch 16: val_avg_MAE=149.144
Epoch 17: train_loss=0.072138
Epoch 18: train_loss=0.072228
Epoch 18: val_avg_MAE=154.570
Epoch 19: train_loss=0.071228
Epoch 20: train_loss=0.070735
Epoch 20: val_avg_MAE=146.966
Epoch 21: train_loss=0.070750
Epoch 22: train_loss=0.069830
Epoch 22: val_avg_MAE=150.344
Epoch 23: train_loss=0.069634
Epoch 24: train_loss=0.069842
Epoch 24: val_avg_MAE=145.888
Epoch 25: train_loss=0.069440
Epoch 26: train_loss=0.068699
Epoch 26: val_avg_MAE=141.426
Epoch 27: train_loss=0.068965
Epoch 28: train_loss=0.068485
Epoch 28: val_avg_MAE=137.197
Epoch 29: train_loss=0.068209
Epoch 30: train_loss=0.067903
Epoch 30: val_avg_MAE=143.328
Epoch 31: train_loss=0.067960
Epoch 32: train_loss=0.067656
Epoch 32: val_avg_MAE=143.007
Epoch 33: train_loss=0.067636
Epoch 34: train_loss=0.067217
Epoch 34: val_avg_MAE=156.258
Epoch 35: train_loss=0.067948
Epoch 36: train_loss=0.066877
Epoch 36: val_avg_MAE=151.170
Epoch 37: train_loss=0.067017
Epoch 38: train_loss=0.067049
Epoch 38: val_avg_MAE=145.086
Epoch 39: train_loss=0.066604
Epoch 40: train_loss=0.066550
Epoch 40: val_avg_MAE=141.399

Early stopping. Best val_avg_MAE=137.197
Evaluating on TEST set...

GraphWaveNet_GRU — Test
   12h  MAE=122.689  RMSE=245.462
   24h  MAE=127.521  RMSE=256.627
   48h  MAE=138.532  RMSE=283.292
   72h  MAE=143.747  RMSE=292.758



Saved run outputs to: artifacts/runs/20260208_182608_GraphWaveNet_GRU
 - best checkpoint: artifacts/runs/20260208_182608_GraphWaveNet_GRU/best.pt
 - history: artifacts/runs/20260208_182608_GraphWaveNet_GRU/history.csv
 - test metrics: artifacts/runs/20260208_182608_GraphWaveNet_GRU/test_metrics.json
 - predictions (npz): artifacts/runs/20260208_182608_GraphWaveNet_GRU/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260208_182608_GraphWaveNet_GRU/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv


In [16]:
# --------- USER EDITS THESE TWO LINES PER MODEL ----------
model_name = "GraphWaveNet_LSTM"
model = gwn_lstm
# --------------------------------------------------------

run_dir = make_run_dir(model_name)

# Train and save best checkpoint + history
model, history_df, best_val_metrics = train_and_save_best(
    model=model,
    model_name=model_name,
    run_dir=run_dir,
    train_loader=train_loader,
    val_loader=val_loader,
    eval_fn=eval_horizons_fast,   # use your fast evaluator
    epochs=40,
    patience=6,
    eval_every=2,
)

print("Evaluating on TEST set...")
test_metrics = eval_horizons_fast(model, test_loader)
print_metrics(f"{model_name} — Test", test_metrics)

# Save metrics to JSON + Excel-friendly table
metrics_row = {"model_name": model_name, "run_dir": str(run_dir)}
for h in sorted(test_metrics.keys()):
    metrics_row[f"test_MAE_{h}h"] = test_metrics[h]["MAE"]
    metrics_row[f"test_RMSE_{h}h"] = test_metrics[h]["RMSE"]

save_json({"test_metrics": test_metrics}, run_dir / "test_metrics.json")
pd.DataFrame([metrics_row]).to_csv(run_dir / "test_metrics.csv", index=False)

# Also append to a master summary file (so you have one file for all models)
append_summary_csv(Path("artifacts/results_summary.csv"), metrics_row)

# Collect predictions + actuals on TEST set (original units), for paper horizons
HORIZONS_TO_SAVE = [12, 24, 48, 72]

pred, true, times, horizons = collect_predictions_for_horizons(
    model=model,
    loader=test_loader,
    horizons=HORIZONS_TO_SAVE,
    stations=stations,
    timestamps=timestamps,
    in_len=IN_LEN,
    flow_mean=flow_mean,
    flow_std=flow_std,
    device=DEVICE,
)

# Save compact full arrays (best format for later plots)
np.savez_compressed(
    run_dir / "test_pred_true_selected_horizons.npz",
    pred=pred,
    true=true,
    times=times,
    horizons=np.array(horizons, dtype=np.int32),
    stations=stations,
)

# Export to Excel (pred + true sheets for each horizon)
# If Excel gets too heavy, set max_stations_excel=200 (or 500)
excel_path = run_dir / "test_pred_true_selected_horizons.xlsx"
export_preds_to_excel(
    excel_path=excel_path,
    pred=pred,
    true=true,
    times=times,
    horizons=horizons,
    stations=stations,
    max_stations_excel=None,
)

print("\nSaved run outputs to:", run_dir)
print(" - best checkpoint:", run_dir / "best.pt")
print(" - history:", run_dir / "history.csv")
print(" - test metrics:", run_dir / "test_metrics.json")
print(" - predictions (npz):", run_dir / "test_pred_true_selected_horizons.npz")
print(" - predictions (xlsx):", excel_path)
print(" - master summary:", Path("artifacts/results_summary.csv"))


Epoch 1: train_loss=0.155551
Epoch 2: train_loss=0.109867
Epoch 2: val_avg_MAE=185.194
Epoch 3: train_loss=0.099236
Epoch 4: train_loss=0.091060
Epoch 4: val_avg_MAE=180.625
Epoch 5: train_loss=0.086698
Epoch 6: train_loss=0.083360
Epoch 6: val_avg_MAE=160.233
Epoch 7: train_loss=0.080657
Epoch 8: train_loss=0.078295
Epoch 8: val_avg_MAE=150.111
Epoch 9: train_loss=0.077220
Epoch 10: train_loss=0.075237
Epoch 10: val_avg_MAE=153.954
Epoch 11: train_loss=0.073782
Epoch 12: train_loss=0.073607
Epoch 12: val_avg_MAE=146.033
Epoch 13: train_loss=0.072411
Epoch 14: train_loss=0.071308
Epoch 14: val_avg_MAE=153.542
Epoch 15: train_loss=0.072370
Epoch 16: train_loss=0.070219
Epoch 16: val_avg_MAE=147.814
Epoch 17: train_loss=0.070281
Epoch 18: train_loss=0.068997
Epoch 18: val_avg_MAE=148.378
Epoch 19: train_loss=0.068725
Epoch 20: train_loss=0.068504
Epoch 20: val_avg_MAE=146.096
Epoch 21: train_loss=0.068134
Epoch 22: train_loss=0.067596
Epoch 22: val_avg_MAE=145.132
Epoch 23: train_loss=0.067309
Epoch 24: train_loss=0.067599
Epoch 24: val_avg_MAE=139.573
Epoch 25: train_loss=0.066927
Epoch 26: train_loss=0.066799
Epoch 26: val_avg_MAE=154.140
Epoch 27: train_loss=0.066406
Epoch 28: train_loss=0.066337
Epoch 28: val_avg_MAE=148.620
Epoch 29: train_loss=0.065917
Epoch 30: train_loss=0.066229
Epoch 30: val_avg_MAE=136.750
Epoch 31: train_loss=0.065791
Epoch 32: train_loss=0.065526
Epoch 32: val_avg_MAE=140.470
Epoch 33: train_loss=0.065260
Epoch 34: train_loss=0.066075
Epoch 34: val_avg_MAE=148.643
Epoch 35: train_loss=0.064958
Epoch 36: train_loss=0.065093
Epoch 36: val_avg_MAE=137.769
Epoch 37: train_loss=0.064704
Epoch 38: train_loss=0.065052
Epoch 38: val_avg_MAE=137.614
Epoch 39: train_loss=0.064986
Epoch 40: train_loss=0.064698
Epoch 40: val_avg_MAE=143.630
Evaluating on TEST set...

GraphWaveNet_LSTM — Test
   12h  MAE=123.368  RMSE=246.555
   24h  MAE=125.954  RMSE=252.869
   48h  MAE=135.831  RMSE=277.550
   72h  MAE=142.286  RMSE=289.365



Saved run outputs to: artifacts/runs/20260208_200244_GraphWaveNet_LSTM
 - best checkpoint: artifacts/runs/20260208_200244_GraphWaveNet_LSTM/best.pt
 - history: artifacts/runs/20260208_200244_GraphWaveNet_LSTM/history.csv
 - test metrics: artifacts/runs/20260208_200244_GraphWaveNet_LSTM/test_metrics.json
 - predictions (npz): artifacts/runs/20260208_200244_GraphWaveNet_LSTM/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260208_200244_GraphWaveNet_LSTM/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv


In [17]:
# --------- USER EDITS THESE TWO LINES PER MODEL ----------
model_name = "GraphWaveNet_GRU_LSTM"
model = gwn_gru_lstm
# --------------------------------------------------------

run_dir = make_run_dir(model_name)

# Train and save best checkpoint + history
model, history_df, best_val_metrics = train_and_save_best(
    model=model,
    model_name=model_name,
    run_dir=run_dir,
    train_loader=train_loader,
    val_loader=val_loader,
    eval_fn=eval_horizons_fast,   # use your fast evaluator
    epochs=40,
    patience=6,
    eval_every=2,
)

print("Evaluating on TEST set...")
test_metrics = eval_horizons_fast(model, test_loader)
print_metrics(f"{model_name} — Test", test_metrics)

# Save metrics to JSON + Excel-friendly table
metrics_row = {"model_name": model_name, "run_dir": str(run_dir)}
for h in sorted(test_metrics.keys()):
    metrics_row[f"test_MAE_{h}h"] = test_metrics[h]["MAE"]
    metrics_row[f"test_RMSE_{h}h"] = test_metrics[h]["RMSE"]

save_json({"test_metrics": test_metrics}, run_dir / "test_metrics.json")
pd.DataFrame([metrics_row]).to_csv(run_dir / "test_metrics.csv", index=False)

# Also append to a master summary file (so you have one file for all models)
append_summary_csv(Path("artifacts/results_summary.csv"), metrics_row)

# Collect predictions + actuals on TEST set (original units), for paper horizons
HORIZONS_TO_SAVE = [12, 24, 48, 72]

pred, true, times, horizons = collect_predictions_for_horizons(
    model=model,
    loader=test_loader,
    horizons=HORIZONS_TO_SAVE,
    stations=stations,
    timestamps=timestamps,
    in_len=IN_LEN,
    flow_mean=flow_mean,
    flow_std=flow_std,
    device=DEVICE,
)

# Save compact full arrays (best format for later plots)
np.savez_compressed(
    run_dir / "test_pred_true_selected_horizons.npz",
    pred=pred,
    true=true,
    times=times,
    horizons=np.array(horizons, dtype=np.int32),
    stations=stations,
)

# Export to Excel (pred + true sheets for each horizon)
# If Excel gets too heavy, set max_stations_excel=200 (or 500)
excel_path = run_dir / "test_pred_true_selected_horizons.xlsx"
export_preds_to_excel(
    excel_path=excel_path,
    pred=pred,
    true=true,
    times=times,
    horizons=horizons,
    stations=stations,
    max_stations_excel=None,  # set to 200 if you want smaller files
)

print("\nSaved run outputs to:", run_dir)
print(" - best checkpoint:", run_dir / "best.pt")
print(" - history:", run_dir / "history.csv")
print(" - test metrics:", run_dir / "test_metrics.json")
print(" - predictions (npz):", run_dir / "test_pred_true_selected_horizons.npz")
print(" - predictions (xlsx):", excel_path)
print(" - master summary:", Path("artifacts/results_summary.csv"))


Epoch 1: train_loss=0.169730
Epoch 2: train_loss=0.119613
Epoch 2: val_avg_MAE=190.120
Epoch 3: train_loss=0.102975
Epoch 4: train_loss=0.097156
Epoch 4: val_avg_MAE=178.502
Epoch 5: train_loss=0.093715
Epoch 6: train_loss=0.097541
Epoch 6: val_avg_MAE=176.582
Epoch 7: train_loss=0.087460
Epoch 8: train_loss=0.084724
Epoch 8: val_avg_MAE=170.203
Epoch 9: train_loss=0.083168
Epoch 10: train_loss=0.081871
Epoch 10: val_avg_MAE=163.213
Epoch 11: train_loss=0.081001
Epoch 12: train_loss=0.080082
Epoch 12: val_avg_MAE=163.116
Epoch 13: train_loss=0.079269
Epoch 14: train_loss=0.078608
Epoch 14: val_avg_MAE=154.176
Epoch 15: train_loss=0.077285
Epoch 16: train_loss=0.076048
Epoch 16: val_avg_MAE=155.512
Epoch 17: train_loss=0.075317
Epoch 18: train_loss=0.074346
Epoch 18: val_avg_MAE=152.205
Epoch 19: train_loss=0.073666
Epoch 20: train_loss=0.073199
Epoch 20: val_avg_MAE=156.183
Epoch 21: train_loss=0.072152
Epoch 22: train_loss=0.071850
Epoch 22: val_avg_MAE=144.455
Epoch 23: train_loss=0.071661
Epoch 24: train_loss=0.070736
Epoch 24: val_avg_MAE=147.032
Epoch 25: train_loss=0.070533
Epoch 26: train_loss=0.069998
Epoch 26: val_avg_MAE=144.378
Epoch 27: train_loss=0.069685
Epoch 28: train_loss=0.069719
Epoch 28: val_avg_MAE=156.022
Epoch 29: train_loss=0.069103
Epoch 30: train_loss=0.068630
Epoch 30: val_avg_MAE=153.476
Epoch 31: train_loss=0.068539
Epoch 32: train_loss=0.067885
Epoch 32: val_avg_MAE=148.778
Epoch 33: train_loss=0.067969
Epoch 34: train_loss=0.068308
Epoch 34: val_avg_MAE=141.396
Epoch 35: train_loss=0.067742
Epoch 36: train_loss=0.067355
Epoch 36: val_avg_MAE=152.712
Epoch 37: train_loss=0.067230
Epoch 38: train_loss=0.067206
Epoch 38: val_avg_MAE=136.619
Epoch 39: train_loss=0.066833
Epoch 40: train_loss=0.067344
Epoch 40: val_avg_MAE=139.761
Evaluating on TEST set...

GraphWaveNet_GRU_LSTM — Test
   12h  MAE=119.049  RMSE=241.427
   24h  MAE=125.124  RMSE=253.402
   48h  MAE=135.325  RMSE=278.629
   72h  MAE=141.891  RMSE=291.309



Saved run outputs to: artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM
 - best checkpoint: artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM/best.pt
 - history: artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM/history.csv
 - test metrics: artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM/test_metrics.json
 - predictions (npz): artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260208_214919_GraphWaveNet_GRU_LSTM/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
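
The three test blocks printed above are easier to compare side by side. The sketch below copies those printed numbers into a plain dictionary and reports, per horizon, which model has the lowest MAE and the lowest RMSE. The `results` dictionary and the `best_model` helper are illustrative names, not part of the notebook; in practice the same comparison could be read from `artifacts/results_summary.csv`.

```python
# Test metrics copied from the three printed result blocks above.
# Keys are forecast horizons in hours; values are (MAE, RMSE).
results = {
    "GraphWaveNet_GRU": {12: (122.689, 245.462), 24: (127.521, 256.627),
                         48: (138.532, 283.292), 72: (143.747, 292.758)},
    "GraphWaveNet_LSTM": {12: (123.368, 246.555), 24: (125.954, 252.869),
                          48: (135.831, 277.550), 72: (142.286, 289.365)},
    "GraphWaveNet_GRU_LSTM": {12: (119.049, 241.427), 24: (125.124, 253.402),
                              48: (135.325, 278.629), 72: (141.891, 291.309)},
}


def best_model(horizon, metric_idx):
    """Name of the model with the lowest metric at one horizon (0 = MAE, 1 = RMSE)."""
    return min(results, key=lambda name: results[name][horizon][metric_idx])


for h in (12, 24, 48, 72):
    print(f"{h:>2}h  best MAE: {best_model(h, 0):<22} best RMSE: {best_model(h, 1)}")
```

In these runs the combined GRU+LSTM model has the lowest MAE at all four horizons, while `GraphWaveNet_LSTM` has the lowest RMSE at 24h, 48h and 72h. The ranking therefore depends on which metric is reported.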


### STGCN Models

In [3]:
import numpy as np
import pandas as pd
from pathlib import Path

DATASET_STRICT = Path("artifacts/pems_graph_dataset_strict.npz")
assert DATASET_STRICT.exists(), f"Missing: {DATASET_STRICT}. Check artifacts folder."

d = np.load(DATASET_STRICT, allow_pickle=True)

X = d["X"].astype(np.float32)           # (T,N,F)
Y = d["Y"].astype(np.float32)           # (T,N)
stations = d["stations"]
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"]
val_starts   = d["val_starts"]
test_starts  = d["test_starts"]

IN_LEN  = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean = d["flow_mean"].astype(np.float32)
flow_std  = d["flow_std"].astype(np.float32)
speed_mean = d["speed_mean"].astype(np.float32)
speed_std  = d["speed_std"].astype(np.float32)

print("Loaded strict dataset")
print("X:", X.shape, "Y:", Y.shape)
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)
print("starts:", len(train_starts), len(val_starts), len(test_starts))


Loaded strict dataset
X: (2208, 1821, 6) Y: (2208, 1821)
IN_LEN: 24 OUT_LEN: 72
starts: 1009 289 673


In [4]:
print("Keys:", d.files)

Keys: ['X', 'Y', 'A', 'stations', 'timestamps', 'train_starts', 'val_starts', 'test_starts', 'in_len', 'out_len', 'flow_mean', 'flow_std', 'speed_mean', 'speed_std']


In [5]:
A = d["A"].astype(np.float32)
print("A:", A.shape, "density:", float((A > 0).mean()))

A: (1821, 1821) density: 0.0023693916932872663


In [6]:
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

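def row_normalize_dense(A, eps=1e-6):
    """Divide each row of A by its row sum (eps guards against empty rows)."""
    deg = A.sum(axis=1, keepdims=True)  # named `deg` so it does not shadow the dataset handle `d`
    return A / (deg + eps)

def dense_to_sparse(A, device):
    """Convert a dense (N,N) NumPy matrix into a coalesced torch sparse COO tensor."""
    idx = np.nonzero(A)
    values = A[idx].astype(np.float32)
    indices = np.vstack(idx)  # shape (2, nnz): row indices, then column indices
    return torch.sparse_coo_tensor(
        torch.tensor(indices, dtype=torch.long, device=device),
        torch.tensor(values, dtype=torch.float32, device=device),
        size=A.shape,
        device=device
    ).coalesce()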

A_rw  = row_normalize_dense(A)
A_rwT = row_normalize_dense(A.T)
supports = [dense_to_sparse(A_rw, DEVICE), dense_to_sparse(A_rwT, DEVICE)]

print("Supports nnz:", [int(s._nnz()) for s in supports])


DEVICE: cuda
Supports nnz: [7856, 7856]
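
The two supports above are the row-normalized adjacency and the row-normalized transpose, which for a directed graph weight outgoing and incoming neighbours respectively. A minimal sketch on a hypothetical 3-node directed graph (`A_toy` is made up for illustration) shows what the normalization does and why both supports report the same number of nonzeros:

```python
import numpy as np


def row_normalize_dense(A, eps=1e-6):
    """Divide each row by its row sum; eps keeps all-zero rows at zero."""
    deg = A.sum(axis=1, keepdims=True)
    return A / (deg + eps)


# Hypothetical directed graph: 0 -> 1, 0 -> 2, 1 -> 2 (node 2 has no outgoing edge).
A_toy = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [0, 0, 0]], dtype=np.float32)

A_fwd = row_normalize_dense(A_toy)    # weights over outgoing edges
A_bwd = row_normalize_dense(A_toy.T)  # weights over incoming edges

print(np.round(A_fwd, 3))
print(np.round(A_bwd, 3))
print(np.count_nonzero(A_fwd), np.count_nonzero(A_bwd))
```

Rows with at least one edge sum to approximately 1, and a row with no edges stays all zero instead of producing a division by zero. Transposing only moves entries, so the nonzero counts of the two supports match, as in the `Supports nnz` output above.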


In [7]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)

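# Pre-scale once: z-score flow (channel 0) and speed (channel 1) with the stored statistics.
# The remaining channels of X are left unchanged.
X_scaled = X.copy()
X_scaled[:, :, 0] = (X_scaled[:, :, 0] - flow_mean[None, :]) / flow_std[None, :]
X_scaled[:, :, 1] = (X_scaled[:, :, 1] - speed_mean[None, :]) / speed_std[None, :]

# Targets are scaled with the same flow statistics as input channel 0
Y_scaled = (Y - flow_mean[None, :]) / flow_std[None, :]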

X_fnt = np.transpose(X_scaled, (2, 1, 0)).copy()  # (F,N,T)
TF_all = time_encoding(pd.DatetimeIndex(timestamps))

class FastPemsWindowDataset(Dataset):
    def __init__(self, X_fnt, Y_scaled, TF_all, starts, in_len, out_len):
        self.X_fnt = X_fnt
        self.Ys = Y_scaled
        self.TF = TF_all
        self.starts = starts.astype(np.int32)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])
        x = self.X_fnt[:, :, t:t+self.in_len]  # (F,N,IN)
        y = self.Ys[t+self.in_len:t+self.in_len+self.out_len, :]  # (OUT,N)
        tf = self.TF[t+self.in_len:t+self.in_len+self.out_len, :] # (OUT,4)
        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(tf)

BATCH_SIZE = 8
train_ds = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, train_starts, IN_LEN, OUT_LEN)
val_ds   = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, val_starts,   IN_LEN, OUT_LEN)
test_ds  = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, test_starts,  IN_LEN, OUT_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True, persistent_workers=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])
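
The window indexing in `FastPemsWindowDataset.__getitem__` and the loader sizes can be checked with plain arithmetic. The sketch below hard-codes the values printed by the notebook (T=2208, IN_LEN=24, OUT_LEN=72, batch size 8, and 1009/289/673 window starts). The helper `window_bounds` is a hypothetical name introduced only for this check.

```python
import math

# Values printed by the notebook cells above.
T, IN_LEN, OUT_LEN, BATCH_SIZE = 2208, 24, 72, 8
n_starts = {"train": 1009, "val": 289, "test": 673}


def window_bounds(t, in_len=IN_LEN, out_len=OUT_LEN):
    """Half-open index ranges (input, target) for a window starting at t."""
    return (t, t + in_len), (t + in_len, t + in_len + out_len)


# Input and target ranges are adjacent and never overlap.
print(window_bounds(0))  # ((0, 24), (24, 96))

# Largest start index whose target still fits inside the T time steps.
last_valid_start = T - IN_LEN - OUT_LEN
print(last_valid_start)  # 2112

# Batches per epoch for each loader (the final batch may be partial).
batches = {name: math.ceil(n / BATCH_SIZE) for name, n in n_starts.items()}
print(batches)  # {'train': 127, 'val': 37, 'test': 85}
```

Because the target range begins exactly where the input range ends, a window never sees its own labels. The loaders use the `DataLoader` default `drop_last=False`, so a partial final batch still counts toward the per-epoch total.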


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- Sparse node convolution (safe for CUDA sparse mm; uses float32 internally) ----
class NConv(nn.Module):
    def forward(self, x, A_sp):
        # x: (B,C,N,T), A_sp: sparse (N,N)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)     # (N, B*C*T)
        out = torch.sparse.mm(A_sp, x_r.float())       # float32 for sparse-mm
        out = out.reshape(N, B, C, T).permute(1, 2, 0, 3)
        return out.to(dtype=x.dtype)

class SimpleGraphConv(nn.Module):
    """
    STGCN-style graph conv: concat [x, A1x, A2x, ...] then 1x1 conv
    """
    def __init__(self, c_in, c_out, supports, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.mlp = nn.Conv2d(c_in * (1 + len(supports)), c_out, kernel_size=(1,1))
        self.dropout = dropout

    def forward(self, x):
        out = [x]
        for A in self.supports:
            out.append(self.nconv(x, A))
        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h

class CausalTemporalGLU(nn.Module):
    """
    Causal temporal convolution + GLU gating (STGCN temporal block).
    Preserves time length via left padding.
    """
    def __init__(self, c_in, c_out, kernel_size=3, dilation=1, dropout=0.0):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv2d(
            c_in, 2*c_out,
            kernel_size=(1, kernel_size),
            dilation=(1, dilation)
        )
        self.dropout = dropout

    def forward(self, x):
        # x: (B,C,N,T)
        x = F.pad(x, (self.pad, 0, 0, 0))  # left pad time only
        h = self.conv(x)                   # (B,2C,N,T)
        a, b = torch.chunk(h, 2, dim=1)
        out = a * torch.sigmoid(b)         # GLU
        out = F.dropout(out, p=self.dropout, training=self.training)
        return out

class STConvBlock(nn.Module):
    """
    TemporalGLU -> GraphConv -> TemporalGLU -> BN
    """
    def __init__(self, c_in, c_out, supports, kt=3, dropout=0.1):
        super().__init__()
        self.temp1 = CausalTemporalGLU(c_in, c_out, kernel_size=kt, dropout=dropout)
        self.gconv = SimpleGraphConv(c_out, c_out, supports, dropout=dropout)
        self.temp2 = CausalTemporalGLU(c_out, c_out, kernel_size=kt, dropout=dropout)
        self.bn = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.temp1(x)
        x = self.gconv(x)
        x = self.temp2(x)
        x = self.bn(x)
        return x

class STGCNEncoder(nn.Module):
    """
    Encodes x -> (B, C, N, T)
    """
    def __init__(
        self,
        in_dim,
        supports,
        hidden=64,
        end_channels=128,
        blocks=2,
        kt=3,
        dropout=0.1,
    ):
        super().__init__()
        self.start = nn.Conv2d(in_dim, hidden, kernel_size=(1,1))
        self.blocks = nn.ModuleList([
            STConvBlock(hidden, hidden, supports, kt=kt, dropout=dropout)
            for _ in range(blocks)
        ])
        self.end = nn.Conv2d(hidden, end_channels, kernel_size=(1,1))

    def forward(self, x):
        # x: (B,F,N,T)
        x = self.start(x)
        for blk in self.blocks:
            x = blk(x)
        x = F.relu(self.end(x))
        return x  # (B,end_channels,N,T)

class STGCNForecast(nn.Module):
    """
    STGCN + optional GRU/LSTM per-node over time + time-aware horizon head.
    Output: (B, OUT_LEN, N)
    """
    def __init__(
        self,
        num_nodes,
        in_dim,
        out_len,
        supports,
        hidden=64,
        end_channels=128,
        st_blocks=2,
        kt=3,
        dropout=0.1,
        use_gru=False,
        use_lstm=False,
        rnn_hidden=128,
    ):
        super().__init__()
        self.out_len = out_len
        self.use_gru = use_gru
        self.use_lstm = use_lstm

        self.encoder = STGCNEncoder(
            in_dim=in_dim,
            supports=supports,
            hidden=hidden,
            end_channels=end_channels,
            blocks=st_blocks,
            kt=kt,
            dropout=dropout,
        )

        # RNNs operate on (B*N, T, C)
        if use_gru:
            self.gru = nn.GRU(end_channels, rnn_hidden, batch_first=True)
        else:
            self.gru = None

        if use_lstm:
            self.lstm = nn.LSTM(
                (rnn_hidden if use_gru else end_channels),
                rnn_hidden,
                batch_first=True
            )
        else:
            self.lstm = None

        final_dim = rnn_hidden if (use_gru or use_lstm) else end_channels

        self.time_embed = nn.Linear(4, final_dim)
        self.horizon_out = nn.Linear(final_dim, 1)

    def forward(self, x, tf_future):
        """
        x: (B,F,N,T_in)
        tf_future: (B, OUT_LEN, 4)
        """
        h = self.encoder(x)  # (B,C,N,T)
        B, C, N, T = h.shape

        # per-node sequences
        seq = h.permute(0, 2, 3, 1).contiguous().view(B*N, T, C)  # (B*N,T,C)

        if self.gru is not None:
            seq, _ = self.gru(seq)           # (B*N,T,H)

        if self.lstm is not None:
            seq, _ = self.lstm(seq)          # (B*N,T,H)

        last = seq[:, -1, :]                 # (B*N,D)
        z = last.view(B, N, -1)              # (B,N,D)

        te = self.time_embed(tf_future)      # (B,OUT,D)
        out = F.relu(z.unsqueeze(1) + te.unsqueeze(2))  # (B,OUT,N,D)
        out = self.horizon_out(out).squeeze(-1)         # (B,OUT,N)
        return out
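
The `CausalTemporalGLU` docstring says the block preserves the time length via left padding. A minimal sketch of that property on a bare `nn.Conv2d` follows; the toy shapes and the dilation of 2 are assumptions chosen for illustration and are not the notebook's settings. It checks that the output keeps the input length and that perturbing the last time step leaves earlier outputs unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

kernel_size, dilation = 3, 2                 # toy values (assumption)
pad = (kernel_size - 1) * dilation           # same formula as CausalTemporalGLU
conv = nn.Conv2d(1, 1, kernel_size=(1, kernel_size), dilation=(1, dilation))

x = torch.randn(1, 1, 4, 10)                 # (B, C, N, T)
y = conv(F.pad(x, (pad, 0, 0, 0)))           # left-pad the time axis only

# Perturb only the final time step; a causal conv must not change earlier outputs.
x2 = x.clone()
x2[..., -1] += 1.0
y2 = conv(F.pad(x2, (pad, 0, 0, 0)))

same_length = y.shape[-1] == x.shape[-1]
past_unchanged = torch.allclose(y[..., :-1], y2[..., :-1])
print("same length:", same_length, "| past unchanged:", past_unchanged)
```

The same check can be run on a `CausalTemporalGLU` instance in eval mode (so dropout is off), since the gating is applied pointwise and does not mix time steps.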


In [11]:
# Recommended starting hyperparams
ST_HIDDEN = 64
ST_END = 128
ST_BLOCKS = 2
KT = 3
DROP = 0.1
RNN_H = 128  # if slow/memory heavy, drop to 64

stgcn_base = STGCNForecast(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    hidden=ST_HIDDEN,
    end_channels=ST_END,
    st_blocks=ST_BLOCKS,
    kt=KT,
    dropout=DROP,
    use_gru=False,
    use_lstm=False,
    rnn_hidden=RNN_H,
).to(DEVICE)

stgcn_gru = STGCNForecast(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    hidden=ST_HIDDEN,
    end_channels=ST_END,
    st_blocks=ST_BLOCKS,
    kt=KT,
    dropout=DROP,
    use_gru=True,
    use_lstm=False,
    rnn_hidden=RNN_H,
).to(DEVICE)

stgcn_lstm = STGCNForecast(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    hidden=ST_HIDDEN,
    end_channels=ST_END,
    st_blocks=ST_BLOCKS,
    kt=KT,
    dropout=DROP,
    use_gru=False,
    use_lstm=True,
    rnn_hidden=RNN_H,
).to(DEVICE)

stgcn_gru_lstm = STGCNForecast(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    hidden=ST_HIDDEN,
    end_channels=ST_END,
    st_blocks=ST_BLOCKS,
    kt=KT,
    dropout=DROP,
    use_gru=True,
    use_lstm=True,
    rnn_hidden=RNN_H,
).to(DEVICE)


In [12]:
xb, yb, tfb = next(iter(train_loader))
out = stgcn_base(xb.to(DEVICE), tfb.to(DEVICE))
print(out.shape)  # (B, OUT_LEN, N)


torch.Size([8, 72, 1821])


In [13]:
model_name = "STGCN"
model = stgcn_base
# Train and export this model with train_and_save_best / run_experiment_and_save

In [14]:
model_name = "STGCN_GRU"
model = stgcn_gru


In [15]:
model_name = "STGCN_LSTM"
model = stgcn_lstm


In [25]:
import numpy as np
import torch
from tqdm.auto import tqdm

# Horizons reported in paper
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h - 1 for h in EVAL_HORIZONS], device=DEVICE)

# Convert scaled error -> original units using std
flow_std_t = torch.tensor(flow_std, dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    """
    Evaluates MAE/RMSE at specific horizons in ORIGINAL units.
    Assumes:
      - yb is scaled target
      - model output is scaled predictions
    """
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=True):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # (B, OUT_LEN, N), scaled

        # Original units error at selected horizons
        err = (pred[:, h_idx, :] - yb[:, h_idx, :]) * flow_std_t  # (B, H, N)

        abs_err = err.abs()
        sq_err  = err * err

        for k, h in enumerate(EVAL_HORIZONS):
            acc[h]["abs"]   += float(abs_err[:, k, :].sum())
            acc[h]["sq"]    += float(sq_err[:, k, :].sum())
            acc[h]["count"] += abs_err[:, k, :].numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae  = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}

    return metrics


In [None]:
import json
import time
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# ---------- File/Run utilities ----------
def make_run_dir(model_name: str, root: str = "artifacts/runs") -> Path:
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe = "".join([c if c.isalnum() or c in "-_." else "_" for c in model_name])
    run_dir = Path(root) / f"{ts}_{safe}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(path: Path, obj: dict):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=str)

def append_summary_csv(summary_csv: Path, row: dict):
    summary_csv.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([row])
    if summary_csv.exists():
        old = pd.read_csv(summary_csv)
        out = pd.concat([old, df], ignore_index=True)
    else:
        out = df
    out.to_csv(summary_csv, index=False)

def _excel_engine():
    try:
        import xlsxwriter  # noqa: F401
        return "xlsxwriter"
    except Exception:
        return "openpyxl"

def export_preds_to_excel(
    excel_path: Path,
    pred: np.ndarray,
    true: np.ndarray,
    times: np.ndarray,
    horizons,
    stations,
    max_stations_excel: int | None = 300,
):
    engine = _excel_engine()
    stations = [str(s) for s in stations]

    if max_stations_excel is not None:
        stations = stations[:max_stations_excel]
        pred = pred[:, :, :max_stations_excel]
        true = true[:, :, :max_stations_excel]

    with pd.ExcelWriter(excel_path, engine=engine) as writer:
        pd.DataFrame({"station": stations}).to_excel(writer, sheet_name="stations", index=False)
        for k, h in enumerate(horizons):
            tcol = pd.to_datetime(times[:, k])

            pred_df = pd.DataFrame(pred[:, k, :], columns=stations)
            pred_df.insert(0, "target_time", tcol)
            pred_df.to_excel(writer, sheet_name=f"pred_h{h}", index=False)

            true_df = pd.DataFrame(true[:, k, :], columns=stations)
            true_df.insert(0, "target_time", tcol)
            true_df.to_excel(writer, sheet_name=f"true_h{h}", index=False)

# ---------- Trainer (uses your eval_horizons_fast) ----------
def train_and_save_best(
    model,
    model_name: str,
    run_dir: Path,
    train_loader,
    val_loader,
    eval_fn,   # eval_horizons_fast
    device: str,
    epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    clip=5.0,
    patience=6,
    eval_every=2,
):
    run_dir.mkdir(parents=True, exist_ok=True)
    save_json(run_dir / "config.json", {
        "model_name": model_name,
        "epochs": epochs,
        "lr": lr,
        "weight_decay": weight_decay,
        "clip": clip,
        "patience": patience,
        "eval_every": eval_every,
        "timestamp": str(datetime.now()),
    })

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    best_val_metrics = None
    bad = 0

    history = []
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=True):
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            tfb = tfb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()
            running += float(loss.item())

        row = {"epoch": epoch, "train_loss": running / max(len(train_loader), 1)}
        print(f"\nEpoch {epoch}: train_loss={row['train_loss']:.6f}")

        if (epoch % eval_every == 0) or (epoch == epochs):
            val_metrics = eval_fn(model, val_loader)
            val_avg = float(np.mean([val_metrics[h]["MAE"] for h in val_metrics]))
            row["val_avg_MAE"] = val_avg
            for h in sorted(val_metrics.keys()):
                row[f"val_MAE_{h}h"] = float(val_metrics[h]["MAE"])
                row[f"val_RMSE_{h}h"] = float(val_metrics[h]["RMSE"])
            print(f"Epoch {epoch}: val_avg_MAE={val_avg:.3f}")

            if val_avg < best_score:
                best_score = val_avg
                best_val_metrics = val_metrics
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

                torch.save(
                    {"model_name": model_name, "epoch": epoch, "best_val_avg_MAE": best_score, "state_dict": best_state},
                    run_dir / "best.pt"
                )
                bad = 0
            else:
                bad += 1
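                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    history.append(row)
                    # Persist the final epoch too: the write below is skipped by `break`
                    pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)
                    break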

        history.append(row)
        pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)

    # load best checkpoint
    best_path = run_dir / "best.pt"
    if best_path.exists():
        ckpt = torch.load(best_path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])

    save_json(run_dir / "train_summary.json", {
        "train_seconds": time.time() - t0,
        "best_val_avg_MAE": best_score
    })
    return model, pd.read_csv(run_dir / "history.csv"), best_val_metrics

# ---------- Prediction collector (selected horizons) ----------
@torch.inference_mode()
def collect_predictions_for_horizons(
    model,
    loader,
    horizons,
    stations,
    timestamps,
    in_len,
    flow_mean,
    flow_std,
    device,
):
    horizons = list(horizons)
    h_idx = np.array([h - 1 for h in horizons], dtype=np.int32)

    ds = loader.dataset
    if not hasattr(ds, "starts"):
        raise AttributeError("loader.dataset must have attribute `.starts` to reconstruct timestamps per window.")
    starts = ds.starts

    flow_mean = flow_mean.astype(np.float32)
    flow_std  = flow_std.astype(np.float32)

    preds_list, trues_list, times_list = [], [], []
    offset = 0

    for xb, yb, tfb in tqdm(loader, desc="Collect preds", leave=True):
        B = xb.shape[0]
        xb = xb.to(device, non_blocking=True)
        tfb = tfb.to(device, non_blocking=True)

        pred_scaled = model(xb, tfb).detach().cpu().numpy().astype(np.float32)  # (B, OUT, N)
        true_scaled = yb.detach().cpu().numpy().astype(np.float32)

        pred_sel = pred_scaled[:, h_idx, :] * flow_std[None, None, :] + flow_mean[None, None, :]
        true_sel = true_scaled[:, h_idx, :] * flow_std[None, None, :] + flow_mean[None, None, :]

        preds_list.append(pred_sel.astype(np.float32))
        trues_list.append(true_sel.astype(np.float32))

        batch_times = np.zeros((B, len(horizons)), dtype="datetime64[ns]")
        for j in range(B):
            t0 = int(starts[offset + j])
            target_indices = t0 + int(in_len) + h_idx
            batch_times[j, :] = np.array(pd.to_datetime(timestamps[target_indices]).astype("datetime64[ns]"))
        times_list.append(batch_times)

        offset += B

    pred = np.concatenate(preds_list, axis=0)
    true = np.concatenate(trues_list, axis=0)
    times = np.concatenate(times_list, axis=0)
    return pred, true, times, horizons

# ---------- One-call experiment runner ----------
def run_experiment_and_save(
    model_name: str,
    model,
    device: str,
    train_loader,
    val_loader,
    test_loader,
    eval_fn,             # eval_horizons_fast
    print_metrics_fn,    # print_metrics
    stations,
    timestamps,
    in_len,
    flow_mean,
    flow_std,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12, 24, 48, 72),
    max_stations_excel=300,
):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    # 1) Train
    model = model.to(device)
    model, history_df, best_val_metrics = train_and_save_best(
        model=model,
        model_name=model_name,
        run_dir=run_dir,
        train_loader=train_loader,
        val_loader=val_loader,
        eval_fn=eval_fn,
        device=device,
        epochs=epochs,
        patience=patience,
        eval_every=eval_every,
    )

    # 2) Test metrics saved immediately
    print("Evaluating on TEST set...")
    test_metrics = eval_fn(model, test_loader)
    print_metrics_fn(f"{model_name} — Test", test_metrics)
    save_json(run_dir / "test_metrics.json", {"test_metrics": test_metrics})
    pd.DataFrame([{
        "model_name": model_name,
        "run_dir": str(run_dir),
        **{f"test_MAE_{h}h": float(test_metrics[h]["MAE"]) for h in sorted(test_metrics.keys())},
        **{f"test_RMSE_{h}h": float(test_metrics[h]["RMSE"]) for h in sorted(test_metrics.keys())},
    }]).to_csv(run_dir / "test_metrics.csv", index=False)

    # master summary
    row = {"model_name": model_name, "run_dir": str(run_dir)}
    for h in sorted(test_metrics.keys()):
        row[f"test_MAE_{h}h"] = float(test_metrics[h]["MAE"])
        row[f"test_RMSE_{h}h"] = float(test_metrics[h]["RMSE"])
    append_summary_csv(Path("artifacts/results_summary.csv"), row)

    # 3) Predictions (NPZ always), Excel best-effort
    try:
        pred, true, times, horizons = collect_predictions_for_horizons(
            model=model,
            loader=test_loader,
            horizons=horizons_to_save,
            stations=stations,
            timestamps=timestamps,
            in_len=in_len,
            flow_mean=flow_mean,
            flow_std=flow_std,
            device=device,
        )

        np.savez_compressed(
            run_dir / "test_pred_true_selected_horizons.npz",
            pred=pred, true=true, times=times,
            horizons=np.array(horizons, dtype=np.int32),
            stations=stations,
        )

        export_preds_to_excel(
            excel_path=run_dir / "test_pred_true_selected_horizons.xlsx",
            pred=pred, true=true, times=times,
            horizons=horizons, stations=stations,
            max_stations_excel=max_stations_excel,
        )
    except Exception as e:
        print("Prediction export failed (metrics already saved):", repr(e))
        with open(run_dir / "pred_export_error.txt", "w") as f:
            f.write(repr(e))

    print("\nDONE. Saved outputs in:", run_dir)
    return run_dir, test_metrics
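
The runner writes `pred`, `true`, `times`, `horizons`, and `stations` into `test_pred_true_selected_horizons.npz`. A small sketch of reading such a file back and recomputing one horizon's MAE is shown here. It writes a toy archive to a temporary directory first, so every array in it is synthetic and the station labels are placeholders.

```python
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "test_pred_true_selected_horizons.npz"

    # Toy stand-ins with the same layout: (windows, horizons, stations)
    np.savez_compressed(
        path,
        pred=np.zeros((5, 4, 3), dtype=np.float32),
        true=np.ones((5, 4, 3), dtype=np.float32),
        times=np.zeros((5, 4), dtype="datetime64[ns]"),
        horizons=np.array([12, 24, 48, 72], dtype=np.int32),
        stations=np.array(["a", "b", "c"]),   # placeholder labels
    )

    with np.load(path) as d:
        horizons = d["horizons"].tolist()
        k = horizons.index(24)                # position of the 24-step horizon
        mae_24 = float(np.abs(d["pred"][:, k, :] - d["true"][:, k, :]).mean())

print("horizons:", horizons, "| MAE@24:", mae_24)
```

If `stations` was saved as an object array, `np.load` needs `allow_pickle=True` to read that key.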


In [26]:
model_name = "STGCN"
model = stgcn_base

run_dir, test_metrics = run_experiment_and_save(
    model_name=model_name,
    model=model,
    device=DEVICE,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    eval_fn=eval_horizons_fast,
    print_metrics_fn=print_metrics,
    stations=stations,
    timestamps=timestamps,
    in_len=IN_LEN,
    flow_mean=flow_mean,
    flow_std=flow_std,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12, 24, 48, 72),
    max_stations_excel=300,
)


Run dir: artifacts/runs/20260209_015638_STGCN


Epoch 1: train_loss=0.197108
Epoch 2: train_loss=0.127995
Epoch 2: val_avg_MAE=7263.597
Epoch 3: train_loss=0.113510
Epoch 4: train_loss=0.105284
Epoch 4: val_avg_MAE=29297.267
Epoch 5: train_loss=0.100544
Epoch 6: train_loss=0.097170
Epoch 6: val_avg_MAE=24280.974
Epoch 7: train_loss=0.094620
Epoch 8: train_loss=0.093069
Epoch 8: val_avg_MAE=17905.704
Epoch 9: train_loss=0.091367
Epoch 10: train_loss=0.090104
Epoch 10: val_avg_MAE=29436.580
Epoch 11: train_loss=0.089176
Epoch 12: train_loss=0.088238
Epoch 12: val_avg_MAE=37901.065
Epoch 13: train_loss=0.087290
Epoch 14: train_loss=0.087152
Epoch 14: val_avg_MAE=36249.339

Early stopping. Best val_avg_MAE=7263.597
Evaluating on TEST set...


STGCN — Test
   12h  MAE=8775.112  RMSE=295355.097
   24h  MAE=8771.100  RMSE=295357.880
   48h  MAE=8771.417  RMSE=295357.973
   72h  MAE=8770.536  RMSE=295358.391



DONE. Saved outputs in: artifacts/runs/20260209_015638_STGCN
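
The test RMSE above (about 295,000) is more than 30 times the MAE (about 8,770), and the validation MAE swings between evaluations while the training loss falls steadily. A gap of that size usually means a few extreme errors dominate the squared term. One possible check, sketched on synthetic data only, is to rank stations by their share of the total squared error. The helper `squared_error_share` and the toy error matrix are hypothetical and do not reproduce the run above.

```python
import numpy as np

def squared_error_share(err: np.ndarray, top_k: int = 5):
    """err: (num_windows, num_stations) errors in original units.
    Returns the top_k station indices and their share of total squared error."""
    per_station = (err ** 2).sum(axis=0)
    order = np.argsort(per_station)[::-1][:top_k]
    share = per_station[order] / per_station.sum()
    return order, share

rng = np.random.default_rng(0)
err_toy = rng.normal(scale=50.0, size=(200, 100))   # ordinary errors
err_toy[:, 7] = 1e6                                  # one synthetic broken station

mae = np.abs(err_toy).mean()
rmse = np.sqrt((err_toy ** 2).mean())
top, share = squared_error_share(err_toy)

print(f"MAE={mae:.1f}  RMSE={rmse:.1f}")
print("top stations:", top.tolist())
print("share of squared error:", np.round(share, 4).tolist())
```

If a real run shows the same concentration, the flagged stations are worth inspecting for near-zero `flow_std` or corrupt readings before comparing models.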


# STGCN Models (Baseline + Ablations)

This notebook trains and evaluates:
- STGCN
- STGCN-GRU
- STGCN-LSTM
- STGCN-GRU-LSTM

We use the same dataset artifact, splits, scaling, and evaluation horizons as the GraphWaveNet experiments.
All outputs are saved under `artifacts/runs/` so results are not lost across Paperspace sessions.
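
The four variants listed above differ only in the `use_gru` / `use_lstm` flags passed to `STGCNForecast`. A compact way to keep that mapping in one place is sketched here. The table name `ABLATIONS`, the helper `head_input_dim`, and the run name `STGCN_GRU_LSTM` are illustrative assumptions; the helper mirrors how `final_dim` is chosen in the model.

```python
# Hypothetical lookup table: run name -> flags for STGCNForecast
ABLATIONS = {
    "STGCN":          {"use_gru": False, "use_lstm": False},
    "STGCN_GRU":      {"use_gru": True,  "use_lstm": False},
    "STGCN_LSTM":     {"use_gru": False, "use_lstm": True},
    "STGCN_GRU_LSTM": {"use_gru": True,  "use_lstm": True},
}

def head_input_dim(use_gru: bool, use_lstm: bool,
                   end_channels: int = 128, rnn_hidden: int = 128) -> int:
    """Feature width seen by the horizon head (mirrors `final_dim`)."""
    return rnn_hidden if (use_gru or use_lstm) else end_channels

for name, flags in ABLATIONS.items():
    print(f"{name:<15} head input dim = {head_input_dim(**flags)}")
```

With the notebook's defaults (`ST_END = 128`, `RNN_H = 128`) every variant feeds a 128-wide vector to the head, so the width only differs if `RNN_H` is lowered.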


In [28]:
import os
import json
import time
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Seeding (cuDNN autotuning stays on for speed, so GPU runs may not be bit-for-bit reproducible)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000
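
Seeding NumPy and torch fixes weight initialization, but the shuffle order of a `DataLoader` is drawn from a torch random generator as well. One option, sketched here under the assumption that a fixed batch order is wanted, is to pass an explicitly seeded `torch.Generator`. The helper `first_epoch_order` and the toy dataset are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def first_epoch_order(seed: int) -> list:
    """Return the sample order of the first shuffled epoch for a given seed."""
    ds = TensorDataset(torch.arange(20))
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
    return [int(v) for (batch,) in loader for v in batch]

order_a = first_epoch_order(42)
order_b = first_epoch_order(42)
print("identical order:", order_a == order_b)
```

The same `generator=` argument can be added to the training loader; the validation and test loaders use `shuffle=False` and are unaffected.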


In [29]:
DATASET_PATH = Path("artifacts/pems_graph_dataset_strict.npz")
if not DATASET_PATH.exists():
    # fallback (if strict doesn't exist)
    DATASET_PATH = Path("artifacts/pems_graph_dataset.npz")

assert DATASET_PATH.exists(), f"Dataset file not found: {DATASET_PATH}"

d = np.load(DATASET_PATH, allow_pickle=True)
print("Loaded:", DATASET_PATH)
print("Keys:", d.files)

X = d["X"].astype(np.float32)                 # (T, N, F)
Y = d["Y"].astype(np.float32)                 # (T, N)
stations = d["stations"]
timestamps = pd.to_datetime(d["timestamps"])

train_starts = d["train_starts"].astype(np.int32)
val_starts   = d["val_starts"].astype(np.int32)
test_starts  = d["test_starts"].astype(np.int32)

IN_LEN  = int(d["in_len"][0])
OUT_LEN = int(d["out_len"][0])

flow_mean  = d["flow_mean"].astype(np.float32)
flow_std   = d["flow_std"].astype(np.float32)
speed_mean = d["speed_mean"].astype(np.float32)
speed_std  = d["speed_std"].astype(np.float32)

print("\nShapes:")
print("X:", X.shape, "(T,N,F)")
print("Y:", Y.shape, "(T,N)")
print("Stations:", len(stations))
print("Timestamp range:", timestamps.min(), "→", timestamps.max())
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)
print("Window starts:", len(train_starts), len(val_starts), len(test_starts))

# Adjacency matrix (if present)
A = None
if "A" in d.files:
    A = d["A"].astype(np.float32)
    print("A:", A.shape, "density:", float((A > 0).mean()))
else:
    print("WARNING: 'A' not in dataset npz. You must rebuild adjacency supports separately.")


Loaded: artifacts/pems_graph_dataset_strict.npz
Keys: ['X', 'Y', 'A', 'stations', 'timestamps', 'train_starts', 'val_starts', 'test_starts', 'in_len', 'out_len', 'flow_mean', 'flow_std', 'speed_mean', 'speed_std']

Shapes:
X: (2208, 1821, 6) (T,N,F)
Y: (2208, 1821) (T,N)
Stations: 1821
Timestamp range: 2024-10-01 00:00:00 → 2024-12-31 23:00:00
IN_LEN: 24 OUT_LEN: 72
Window starts: 1009 289 673
A: (1821, 1821) density: 0.0023693916932872663


In [30]:
def dense_to_sparse(A_dense: np.ndarray, device: str):
    idx = np.nonzero(A_dense)
    values = A_dense[idx].astype(np.float32)
    indices = np.vstack(idx)
    sp = torch.sparse_coo_tensor(
        torch.tensor(indices, dtype=torch.long, device=device),
        torch.tensor(values, dtype=torch.float32, device=device),
        size=A_dense.shape,
        device=device
    ).coalesce()
    return sp

def row_normalize(A_dense: np.ndarray, eps=1e-6):
    d = A_dense.sum(axis=1, keepdims=True)
    return A_dense / (d + eps)

assert A is not None, "Adjacency A not found in NPZ. Rebuild A first, then come back here."

# Add self-loops (important for stability in many GNNs)
A_hat = A + np.eye(A.shape[0], dtype=np.float32)

# Keep same style as earlier (directed supports)
A_rw  = row_normalize(A_hat)
A_rwT = row_normalize(A_hat.T)

supports = [dense_to_sparse(A_rw, DEVICE), dense_to_sparse(A_rwT, DEVICE)]
print("Supports nnz:", [int(s._nnz()) for s in supports])


Supports nnz: [7852, 7852]
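
As a sanity check on the support construction, the following sketch repeats the same steps on a 3-node toy graph (the adjacency values are made up). It confirms that each row of the normalized matrix sums to approximately 1 and that the sparse product matches the dense one.

```python
import numpy as np
import torch

def row_normalize(A_dense: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    d = A_dense.sum(axis=1, keepdims=True)
    return A_dense / (d + eps)

# Toy directed graph (assumption), then self-loops as in the cell above
A_toy = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [1, 1, 0]], dtype=np.float32)
A_hat = A_toy + np.eye(3, dtype=np.float32)
A_rw = row_normalize(A_hat)
row_sums = A_rw.sum(axis=1)

# Same COO construction as dense_to_sparse, kept on CPU here
idx = np.nonzero(A_rw)
A_sp = torch.sparse_coo_tensor(
    torch.tensor(np.vstack(idx), dtype=torch.long),
    torch.tensor(A_rw[idx], dtype=torch.float32),
    size=A_rw.shape,
).coalesce()

x = torch.randn(3, 5)
sparse_out = torch.sparse.mm(A_sp, x)
dense_out = torch.from_numpy(A_rw) @ x

print("row sums:", row_sums)
print("sparse == dense:", torch.allclose(sparse_out, dense_out, atol=1e-6))
```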


In [31]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)

TF_all = time_encoding(pd.DatetimeIndex(timestamps))

# Scale X channels: assume channel0=flow, channel1=speed (as in your pipeline)
X_scaled = X.copy()
X_scaled[:, :, 0] = (X_scaled[:, :, 0] - flow_mean[None, :]) / flow_std[None, :]
X_scaled[:, :, 1] = (X_scaled[:, :, 1] - speed_mean[None, :]) / speed_std[None, :]

# Scale targets (flow)
Y_scaled = (Y - flow_mean[None, :]) / flow_std[None, :]

# Store X as (F, N, T) for fast slicing
X_fnt = np.transpose(X_scaled, (2, 1, 0)).copy()  # (F,N,T)
print("X_fnt:", X_fnt.shape, "Y_scaled:", Y_scaled.shape, "TF_all:", TF_all.shape)


X_fnt: (6, 1821, 2208) Y_scaled: (2208, 1821) TF_all: (2208, 4)
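
The scaling above divides by `flow_std` and `speed_std` directly. If any station has a zero or near-zero standard deviation on the training split, the scaled values become infinite or extremely large. Whether that happens in this dataset is not shown here, so the following is only a defensive sketch; the helper `safe_std`, its threshold, and the toy arrays are assumptions.

```python
import numpy as np

def safe_std(std, eps: float = 1e-3) -> np.ndarray:
    """Replace near-zero standard deviations with 1.0 so scaling stays finite."""
    std = np.asarray(std, dtype=np.float32)
    return np.where(std < eps, np.float32(1.0), std)

# Toy statistics: stations 1 and 2 are (almost) constant
flow_std_toy = np.array([120.0, 0.0, 1e-6, 45.0], dtype=np.float32)
flow_mean_toy = np.array([90.0, 5.0, 5.0, 50.0], dtype=np.float32)
y_toy = np.array([[100.0, 5.0, 5.0, 40.0]], dtype=np.float32)

std_used = safe_std(flow_std_toy)
scaled = (y_toy - flow_mean_toy[None, :]) / std_used[None, :]
restored = scaled * std_used[None, :] + flow_mean_toy[None, :]

print("std used:", std_used)
print("scaled:", scaled)
```

If a guard like this is adopted, the same adjusted array has to be used when converting errors and predictions back to original units, otherwise the inverse transform no longer matches.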


In [32]:
class FastPemsWindowDataset(Dataset):
    """
    Returns:
      x  = (F, N, IN_LEN)
      y  = (OUT_LEN, N)
      tf = (OUT_LEN, 4)
    """
    def __init__(self, X_fnt, Y_scaled, TF_all, starts, in_len, out_len):
        self.X_fnt = X_fnt
        self.Ys = Y_scaled
        self.TF = TF_all
        self.starts = starts.astype(np.int32)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])
        x = self.X_fnt[:, :, t:t+self.in_len]
        y = self.Ys[t+self.in_len:t+self.in_len+self.out_len, :]
        tf = self.TF[t+self.in_len:t+self.in_len+self.out_len, :]
        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(tf)

BATCH_SIZE = 8
NUM_WORKERS = 2

train_ds = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, train_starts, IN_LEN, OUT_LEN)
val_ds   = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, val_starts,   IN_LEN, OUT_LEN)
test_ds  = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, test_starts,  IN_LEN, OUT_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=(NUM_WORKERS > 0))
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=(NUM_WORKERS > 0))
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=(NUM_WORKERS > 0))

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)
# Expect: x=(B,F,N,IN_LEN), y=(B,OUT_LEN,N), tf=(B,OUT_LEN,4)


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])
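
The slicing in `__getitem__` implies a simple index layout per window: inputs cover `[t, t+IN_LEN)` and targets cover `[t+IN_LEN, t+IN_LEN+OUT_LEN)`. A worked example with the values printed earlier (`IN_LEN = 24`, `OUT_LEN = 72`, `T = 2208`) follows; the helper names and the start index 100 are illustrative.

```python
IN_LEN, OUT_LEN, T = 24, 72, 2208   # values printed when the dataset was loaded

def window_indices(t: int, in_len: int = IN_LEN, out_len: int = OUT_LEN):
    """Time indices read as input and as target for a window starting at t."""
    inputs = range(t, t + in_len)
    targets = range(t + in_len, t + in_len + out_len)
    return inputs, targets

def target_index(t: int, horizon: int, in_len: int = IN_LEN) -> int:
    """Absolute time index of the `horizon`-th forecast step (1-based)."""
    return t + in_len + (horizon - 1)

inputs, targets = window_indices(100)
max_valid_start = T - IN_LEN - OUT_LEN   # last start whose targets fit in the series

print("inputs :", inputs[0], "to", inputs[-1])
print("targets:", targets[0], "to", targets[-1])
print("horizon 12 index:", target_index(100, 12))
print("max valid start :", max_valid_start)
```

The `horizon - 1` offset is the same one used for `h_idx` in the evaluation cell, so horizon 72 points at the last target step of each window.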


In [33]:
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h - 1 for h in EVAL_HORIZONS], device=DEVICE)
flow_std_t = torch.tensor(flow_std, dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=True):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # scaled (B,OUT,N)

        err = (pred[:, h_idx, :] - yb[:, h_idx, :]) * flow_std_t  # original units
        abs_err = err.abs()
        sq_err  = err * err

        for k, h in enumerate(EVAL_HORIZONS):
            acc[h]["abs"] += float(abs_err[:, k, :].sum())
            acc[h]["sq"]  += float(sq_err[:, k, :].sum())
            acc[h]["count"] += abs_err[:, k, :].numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics
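
`eval_horizons_fast` accumulates error sums and element counts across batches and divides once at the end. The sketch below, on synthetic errors, shows why: the pooled result equals the metric over the full set even when the last batch is smaller, whereas averaging per-batch RMSE values is a different quantity in general.

```python
import numpy as np

rng = np.random.default_rng(0)
errors = rng.normal(size=(10, 3))          # 10 windows, 3 stations (toy)
batches = [errors[:8], errors[8:]]         # final batch is smaller

# Pooled accumulation, as in eval_horizons_fast
abs_sum = sum(np.abs(b).sum() for b in batches)
sq_sum = sum((b * b).sum() for b in batches)
count = sum(b.size for b in batches)
mae = abs_sum / count
rmse = (sq_sum / count) ** 0.5

# Alternative shown for contrast: mean of per-batch RMSE values
mean_of_batch_rmse = float(np.mean([np.sqrt((b * b).mean()) for b in batches]))

print(f"pooled MAE={mae:.4f}  pooled RMSE={rmse:.4f}")
print(f"mean of per-batch RMSE={mean_of_batch_rmse:.4f}")
```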


In [34]:
def make_run_dir(model_name: str, root: str = "artifacts/runs") -> Path:
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe = "".join([c if c.isalnum() or c in "-_." else "_" for c in model_name])
    run_dir = Path(root) / f"{ts}_{safe}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(path: Path, obj: dict):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=str)

def append_summary_csv(summary_csv: Path, row: dict):
    summary_csv.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([row])
    if summary_csv.exists():
        old = pd.read_csv(summary_csv)
        out = pd.concat([old, df], ignore_index=True)
    else:
        out = df
    out.to_csv(summary_csv, index=False)

def _excel_engine():
    try:
        import xlsxwriter  # noqa: F401
        return "xlsxwriter"
    except Exception:
        return "openpyxl"

def export_preds_to_excel(excel_path: Path, pred, true, times, horizons, stations, max_stations_excel=300):
    engine = _excel_engine()
    stations = [str(s) for s in stations]

    if max_stations_excel is not None:
        stations = stations[:max_stations_excel]
        pred = pred[:, :, :max_stations_excel]
        true = true[:, :, :max_stations_excel]

    with pd.ExcelWriter(excel_path, engine=engine) as writer:
        pd.DataFrame({"station": stations}).to_excel(writer, sheet_name="stations", index=False)
        for k, h in enumerate(horizons):
            tcol = pd.to_datetime(times[:, k])

            pred_df = pd.DataFrame(pred[:, k, :], columns=stations)
            pred_df.insert(0, "target_time", tcol)
            pred_df.to_excel(writer, sheet_name=f"pred_h{h}", index=False)

            true_df = pd.DataFrame(true[:, k, :], columns=stations)
            true_df.insert(0, "target_time", tcol)
            true_df.to_excel(writer, sheet_name=f"true_h{h}", index=False)

@torch.inference_mode()
def collect_predictions_for_horizons(model, loader, horizons, stations, timestamps, in_len, flow_mean, flow_std):
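    model.eval()  # disable dropout; BatchNorm uses its running statistics
    horizons = list(horizons)
    h_idx_np = np.array([h - 1 for h in horizons], dtype=np.int32)

    # Timestamps are recovered from ds.starts by position, so the loader must NOT shuffle.
    ds = loader.dataset
    starts = ds.starts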

    preds_list, trues_list, times_list = [], [], []
    offset = 0

    for xb, yb, tfb in tqdm(loader, desc="Collect preds", leave=True):
        B = xb.shape[0]
        xb = xb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred_scaled = model(xb, tfb).detach().cpu().numpy().astype(np.float32)  # (B,OUT,N)
        true_scaled = yb.detach().cpu().numpy().astype(np.float32)

        pred_sel = pred_scaled[:, h_idx_np, :] * flow_std[None, None, :] + flow_mean[None, None, :]
        true_sel = true_scaled[:, h_idx_np, :] * flow_std[None, None, :] + flow_mean[None, None, :]

        preds_list.append(pred_sel)
        trues_list.append(true_sel)

        batch_times = np.zeros((B, len(horizons)), dtype="datetime64[ns]")
        for j in range(B):
            t0 = int(starts[offset + j])
            target_indices = t0 + int(in_len) + h_idx_np
            batch_times[j, :] = np.array(pd.to_datetime(timestamps[target_indices]).astype("datetime64[ns]"))
        times_list.append(batch_times)

        offset += B

    pred = np.concatenate(preds_list, axis=0)
    true = np.concatenate(trues_list, axis=0)
    times = np.concatenate(times_list, axis=0)
    return pred, true, times, horizons

def train_and_save_best(model, model_name, run_dir, epochs=40, lr=1e-3, weight_decay=1e-4, clip=5.0, patience=6, eval_every=2):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    bad = 0
    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=True):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()
            running += float(loss.item())

        row = {"epoch": epoch, "train_loss": running / max(len(train_loader), 1)}
        print(f"\nEpoch {epoch}: train_loss={row['train_loss']:.6f}")

        if (epoch % eval_every == 0) or (epoch == epochs):
            val_metrics = eval_horizons_fast(model, val_loader)
            val_avg = float(np.mean([val_metrics[h]["MAE"] for h in val_metrics]))
            row["val_avg_MAE"] = val_avg
            print(f"Epoch {epoch}: val_avg_MAE={val_avg:.3f}")
            print_metrics("Validation metrics", val_metrics)

            if val_avg < best_score:
                best_score = val_avg
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                torch.save({"model_name": model_name, "best_val_avg_MAE": best_score, "state_dict": best_state}, run_dir / "best.pt")
                bad = 0
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
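                    history.append(row)
                    # Flush history before leaving; otherwise the final epoch is missing from history.csv.
                    pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)
                    break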

        history.append(row)
        pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def run_experiment_and_save(model_name, model, epochs=40, patience=6, eval_every=2, horizons_to_save=(12,24,48,72), max_stations_excel=300):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    model = model.to(DEVICE)

    # Train
    model = train_and_save_best(model, model_name, run_dir, epochs=epochs, patience=patience, eval_every=eval_every)

    # Test metrics (save immediately)
    print("Evaluating on TEST set...")
    test_metrics = eval_horizons_fast(model, test_loader)
    print_metrics(f"{model_name} — Test", test_metrics)
    save_json(run_dir / "test_metrics.json", {"test_metrics": test_metrics})

    row = {"model_name": model_name, "run_dir": str(run_dir)}
    for h in sorted(test_metrics.keys()):
        row[f"test_MAE_{h}h"] = float(test_metrics[h]["MAE"])
        row[f"test_RMSE_{h}h"] = float(test_metrics[h]["RMSE"])
    append_summary_csv(Path("artifacts/results_summary.csv"), row)
    pd.DataFrame([row]).to_csv(run_dir / "test_metrics.csv", index=False)

    # Predictions export (best effort)
    try:
        pred, true, times, horizons = collect_predictions_for_horizons(
            model, test_loader, horizons_to_save, stations, timestamps, IN_LEN, flow_mean, flow_std
        )
        np.savez_compressed(run_dir / "test_pred_true_selected_horizons.npz",
                            pred=pred, true=true, times=times,
                            horizons=np.array(horizons, dtype=np.int32), stations=stations)
        export_preds_to_excel(run_dir / "test_pred_true_selected_horizons.xlsx",
                              pred, true, times, horizons, stations,
                              max_stations_excel=max_stations_excel)
    except Exception as e:
        print("Prediction export failed (metrics already saved):", repr(e))
        with open(run_dir / "pred_export_error.txt", "w") as f:
            f.write(repr(e))

    print("\nDONE. Saved to:", run_dir)
    return run_dir


In [35]:
class NConv(nn.Module):
    def forward(self, x, A_sp):
        # x: (B,C,N,T)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1)
        out = torch.sparse.mm(A_sp, x_r.float())  # keep float32 for sparse op
        out = out.reshape(N, B, C, T).permute(1, 2, 0, 3)
        return out.to(dtype=x.dtype)

class SimpleGraphConv(nn.Module):
    def __init__(self, c_in, c_out, supports, dropout=0.0):
        super().__init__()
        self.nconv = NConv()
        self.supports = supports
        self.mlp = nn.Conv2d(c_in * (1 + len(supports)), c_out, kernel_size=(1,1))
        self.dropout = dropout

    def forward(self, x):
        out = [x]
        for A in self.supports:
            out.append(self.nconv(x, A))
        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return h

class CausalTemporalGLU(nn.Module):
    def __init__(self, c_in, c_out, kernel_size=3, dropout=0.0):
        super().__init__()
        self.pad = kernel_size - 1
        self.conv = nn.Conv2d(c_in, 2*c_out, kernel_size=(1, kernel_size))
        self.dropout = dropout

    def forward(self, x):
        x = F.pad(x, (self.pad, 0, 0, 0))  # left pad time
        h = self.conv(x)
        a, b = torch.chunk(h, 2, dim=1)
        out = a * torch.sigmoid(b)
        return F.dropout(out, p=self.dropout, training=self.training)

class STConvBlock(nn.Module):
    def __init__(self, c_in, c_out, supports, kt=3, dropout=0.1):
        super().__init__()
        self.t1 = CausalTemporalGLU(c_in, c_out, kernel_size=kt, dropout=dropout)
        self.g  = SimpleGraphConv(c_out, c_out, supports, dropout=dropout)
        self.t2 = CausalTemporalGLU(c_out, c_out, kernel_size=kt, dropout=dropout)
        self.bn = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.t1(x)
        x = self.g(x)
        x = self.t2(x)
        return self.bn(x)

class STGCNEncoder(nn.Module):
    def __init__(self, in_dim, supports, hidden=64, end_channels=128, blocks=2, kt=3, dropout=0.1):
        super().__init__()
        self.start = nn.Conv2d(in_dim, hidden, kernel_size=(1,1))
        self.blocks = nn.ModuleList([STConvBlock(hidden, hidden, supports, kt=kt, dropout=dropout) for _ in range(blocks)])
        self.end = nn.Conv2d(hidden, end_channels, kernel_size=(1,1))

    def forward(self, x):
        x = self.start(x)
        for blk in self.blocks:
            x = blk(x)
        return F.relu(self.end(x))  # (B,C,N,T)

class STGCNForecast(nn.Module):
    def __init__(self, num_nodes, in_dim, out_len, supports,
                 hidden=64, end_channels=128, st_blocks=2, kt=3, dropout=0.1,
                 use_gru=False, use_lstm=False, rnn_hidden=128):
        super().__init__()
        self.out_len = out_len
        self.use_gru = use_gru
        self.use_lstm = use_lstm

        self.encoder = STGCNEncoder(in_dim, supports, hidden=hidden, end_channels=end_channels, blocks=st_blocks, kt=kt, dropout=dropout)

        if use_gru:
            self.gru = nn.GRU(end_channels, rnn_hidden, batch_first=True)
        else:
            self.gru = None

        if use_lstm:
            self.lstm = nn.LSTM((rnn_hidden if use_gru else end_channels), rnn_hidden, batch_first=True)
        else:
            self.lstm = None

        final_dim = rnn_hidden if (use_gru or use_lstm) else end_channels
        self.time_embed = nn.Linear(4, final_dim)
        self.horizon_out = nn.Linear(final_dim, 1)

    def forward(self, x, tf_future):
        h = self.encoder(x)                  # (B,C,N,T)
        B, C, N, T = h.shape

        seq = h.permute(0, 2, 3, 1).contiguous().view(B*N, T, C)  # (B*N,T,C)

        if self.gru is not None:
            seq, _ = self.gru(seq)
        if self.lstm is not None:
            seq, _ = self.lstm(seq)

        last = seq[:, -1, :]                 # (B*N,D)
        z = last.view(B, N, -1)              # (B,N,D)

        te = self.time_embed(tf_future)      # (B,OUT,D)
        out = F.relu(z.unsqueeze(1) + te.unsqueeze(2))  # (B,OUT,N,D)
        out = self.horizon_out(out).squeeze(-1)         # (B,OUT,N)
        return out


In [36]:
ST_HIDDEN = 64
ST_END = 128
ST_BLOCKS = 2
KT = 3
DROP = 0.1
RNN_H = 128

stgcn_base = STGCNForecast(
    num_nodes=X.shape[1],
    in_dim=X.shape[2],
    out_len=OUT_LEN,
    supports=supports,
    hidden=ST_HIDDEN,
    end_channels=ST_END,
    st_blocks=ST_BLOCKS,
    kt=KT,
    dropout=DROP,
    use_gru=False,
    use_lstm=False,
    rnn_hidden=RNN_H,
).to(DEVICE)

# sanity forward
xb, yb, tfb = next(iter(train_loader))
out = stgcn_base(xb.to(DEVICE), tfb.to(DEVICE))
print("Forward out:", out.shape)  # (B, OUT_LEN, N)


Forward out: torch.Size([8, 72, 1821])


In [40]:
T = X.shape[0]

def check_starts(name, starts):
    mx = int(starts.max())
    mn = int(starts.min())
    end = mx + IN_LEN + OUT_LEN
    print(f"{name}: len={len(starts)}, min={mn}, max={mx}, max_end={end}, T={T}")
    assert mn >= 0
    assert end <= T, f"{name} has out-of-range windows: max_start+IN+OUT={end} > T={T}"

check_starts("train_starts", train_starts)
check_starts("val_starts", val_starts)
check_starts("test_starts", test_starts)
print("starts are valid")


train_starts: len=1009, min=0, max=1008, max_end=1104, T=2208
val_starts: len=289, min=1080, max=1368, max_end=1464, T=2208
test_starts: len=673, min=1440, max=2112, max_end=2208, T=2208
starts are valid


In [41]:
x0, y0, tf0 = train_ds[0]
print("x0:", x0.shape, "y0:", y0.shape, "tf0:", tf0.shape)
print("dataset __getitem__ works")


x0: torch.Size([6, 1821, 24]) y0: torch.Size([72, 1821]) tf0: torch.Size([72, 4])
dataset __getitem__ works


In [42]:
BATCH_SIZE = 8

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,          # key fix: load batches in the main process (no worker subprocesses)
    pin_memory=False,       # conservative setting; can be re-enabled later
    persistent_workers=False
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False
)

test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False
)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)
print("DataLoader works with num_workers=0")


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])
DataLoader works with num_workers=0


In [43]:
run_dir = run_experiment_and_save(
    model_name="STGCN",
    model=stgcn_base,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12, 24, 48, 72),
    max_stations_excel=300
)


Run dir: artifacts/runs/20260209_025423_STGCN


Epoch 1: train_loss=0.122658

Epoch 2: train_loss=0.106960
Epoch 2: val_avg_MAE=24947.003

Validation metrics
   12h  MAE=24944.116  RMSE=787686.593
   24h  MAE=24941.407  RMSE=787687.508
   48h  MAE=24950.916  RMSE=787687.290
   72h  MAE=24951.572  RMSE=787687.659

Epoch 3: train_loss=0.099522

Epoch 4: train_loss=0.095018
Epoch 4: val_avg_MAE=4632.463

Validation metrics
   12h  MAE=4631.190  RMSE=122150.499
   24h  MAE=4623.557  RMSE=122142.468
   48h  MAE=4636.045  RMSE=122142.950
   72h  MAE=4639.059  RMSE=122142.454

Epoch 5: train_loss=0.092440

Epoch 6: train_loss=0.090638
Epoch 6: val_avg_MAE=29694.880

Validation metrics
   12h  MAE=29694.358  RMSE=969844.303
   24h  MAE=29683.950  RMSE=969841.750
   48h  MAE=29698.549  RMSE=969841.826
   72h  MAE=29702.661  RMSE=969841.022

Epoch 7: train_loss=0.089389

Epoch 8: train_loss=0.088278
Epoch 8: val_avg_MAE=15737.964

Validation metrics
   12h  MAE=15737.883  RMSE=503787.702
   24h  MAE=15727.043  RMSE=503785.292
   48h  MAE=15742.061  RMSE=503785.437
   72h  MAE=15744.869  RMSE=503785.055

Epoch 9: train_loss=0.087326

Epoch 10: train_loss=0.086552
Epoch 10: val_avg_MAE=8177.184

Validation metrics
   12h  MAE=8176.299  RMSE=266203.593
   24h  MAE=8165.281  RMSE=266201.527
   48h  MAE=8181.872  RMSE=266201.591
   72h  MAE=8185.282  RMSE=266200.853

Epoch 11: train_loss=0.086058

Epoch 12: train_loss=0.085822
Epoch 12: val_avg_MAE=13813.572

Validation metrics
   12h  MAE=13815.922  RMSE=466099.101
   24h  MAE=13800.936  RMSE=466099.831
   48h  MAE=13816.460  RMSE=466099.816
   72h  MAE=13820.970  RMSE=466098.654

Epoch 13: train_loss=0.084875

Epoch 14: train_loss=0.084417
Epoch 14: val_avg_MAE=18132.726

Validation metrics
   12h  MAE=18137.434  RMSE=561861.961
   24h  MAE=18119.242  RMSE=561864.220
   48h  MAE=18134.502  RMSE=561864.219
   72h  MAE=18139.727  RMSE=561863.250

Epoch 15: train_loss=0.084169

Epoch 16: train_loss=0.083938
Epoch 16: val_avg_MAE=5477.790

Validation metrics
   12h  MAE=5469.164  RMSE=202127.674
   24h  MAE=5469.919  RMSE=202124.106
   48h  MAE=5484.458  RMSE=202123.825
   72h  MAE=5487.620  RMSE=202123.465

Early stopping. Best val_avg_MAE=4632.463
Evaluating on TEST set...



STGCN — Test
   12h  MAE=4932.438  RMSE=122701.834
   24h  MAE=4919.577  RMSE=122693.714
   48h  MAE=4923.405  RMSE=122693.767
   72h  MAE=4926.036  RMSE=122693.390




DONE. Saved to: artifacts/runs/20260209_025423_STGCN
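
The stopping point in the log above follows from how `patience` interacts with `eval_every`: the counter only advances on evaluation epochs, so `patience=6` with `eval_every=2` tolerates twelve training epochs without improvement. The sketch below replays the validation scores printed in this run through the same rule. The function name `early_stop_epoch` is hypothetical and the code is an illustration, not the training loop itself.

```python
def early_stop_epoch(val_scores, patience):
    """val_scores: list of (epoch, score) in evaluation order; lower is better."""
    best, best_epoch, bad = float("inf"), None, 0
    for epoch, score in val_scores:
        if score < best:
            best, best_epoch, bad = score, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return epoch, best_epoch, best
    return None, best_epoch, best

# val_avg_MAE values copied from the training log above
val_log = [(2, 24947.003), (4, 4632.463), (6, 29694.880), (8, 15737.964),
           (10, 8177.184), (12, 13813.572), (14, 18132.726), (16, 5477.790)]

stop_epoch, best_epoch, best_score = early_stop_epoch(val_log, patience=6)
print(stop_epoch, best_epoch, best_score)  # 16 4 4632.463
```

The replay stops at epoch 16 with the best checkpoint from epoch 4, matching the "Early stopping" line in the log.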


## Revisiting the STGCN Models

In [46]:
import os, json, time
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

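# ---------------- Seeding ----------------
# Note: cuDNN benchmark mode is enabled below, which favors speed over
# bit-for-bit reproducibility (unlike set_seed(..., deterministic=True) earlier).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True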

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))


Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


In [47]:
DATA_PATH = Path("artifacts/pems_graph_dataset_strict.npz")
assert DATA_PATH.exists(), f"Missing {DATA_PATH}. Rebuild the dataset first."

data = np.load(DATA_PATH, allow_pickle=True)
print("Loaded:", DATA_PATH)
print("Keys:", list(data.keys()))

X_raw = data["X"]            # (T, N, F)
Y_raw = data["Y"]            # (T, N)  (flow target)
A = data["A"]                # (N, N)
stations = data["stations"]
timestamps = data["timestamps"]

train_starts = data["train_starts"]
val_starts   = data["val_starts"]
test_starts  = data["test_starts"]

IN_LEN  = int(data["in_len"].item())    # .item() also handles values stored as 1-element arrays
OUT_LEN = int(data["out_len"].item())

flow_mean = data["flow_mean"]   # (N,)
flow_std  = data["flow_std"]    # (N,)
speed_mean = data["speed_mean"] # (N,)
speed_std  = data["speed_std"]  # (N,)

T, N, Fdim = X_raw.shape
print("\nShapes:")
print("X_raw:", X_raw.shape, "(T,N,F)")
print("Y_raw:", Y_raw.shape, "(T,N)")
print("A:", A.shape)
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)
print("train/val/test starts:", len(train_starts), len(val_starts), len(test_starts))


Loaded: artifacts/pems_graph_dataset_strict.npz
Keys: ['X', 'Y', 'A', 'stations', 'timestamps', 'train_starts', 'val_starts', 'test_starts', 'in_len', 'out_len', 'flow_mean', 'flow_std', 'speed_mean', 'speed_std']

Shapes:
X_raw: (2208, 1821, 6) (T,N,F)
Y_raw: (2208, 1821) (T,N)
A: (1821, 1821)
IN_LEN: 24 OUT_LEN: 72
train/val/test starts: 1009 289 673



In [48]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)

dt_idx = pd.to_datetime(timestamps)
TF_all = time_encoding(dt_idx)         # (T,4)

# ----- scale inputs -----
X_scaled = X_raw.astype(np.float32).copy()
X_scaled[:, :, 0] = (X_scaled[:, :, 0] - flow_mean[None, :]) / (flow_std[None, :] + 1e-6)
X_scaled[:, :, 1] = (X_scaled[:, :, 1] - speed_mean[None, :]) / (speed_std[None, :] + 1e-6)

# ----- scale targets (flow) -----
Y_scaled = (Y_raw.astype(np.float32) - flow_mean[None, :]) / (flow_std[None, :] + 1e-6)

# Store for fast slicing as (F,N,T)
X_fnt = np.transpose(X_scaled, (2, 1, 0)).copy()  # (F,N,T)

print("X_fnt:", X_fnt.shape, "Y_scaled:", Y_scaled.shape, "TF_all:", TF_all.shape)
print("Sanity (Y_scaled mean/std approx):", float(Y_scaled.mean()), float(Y_scaled.std()))


X_fnt: (6, 1821, 2208) Y_scaled: (2208, 1821) TF_all: (2208, 4)
Sanity (Y_scaled mean/std approx): -780.4212036132812 30666.189453125
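
The sanity check above prints a mean near -780 and a standard deviation above 30,000, far from the roughly 0 and 1 expected after z-scoring. One possible cause, which is an assumption here and has not been verified against this dataset, is that some stations have a near-zero train-split `flow_std`, so dividing by `std + 1e-6` inflates their values by several orders of magnitude. The toy sketch below reproduces that effect on synthetic data and shows one hedged mitigation, a lower bound on the standard deviation. The `min_std` parameter and the threshold `1e-3` are hypothetical choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T_train, T_all, N = 50, 80, 3

# Synthetic flows; station 2 is flat during the training period only.
Y = rng.normal(100.0, 10.0, size=(T_all, N))
Y[:T_train, 2] = 5.0

# Train-only statistics, as in the leakage-safe protocol.
mean = Y[:T_train].mean(axis=0)
std = Y[:T_train].std(axis=0)

def scale(Y, mean, std, eps=1e-6, min_std=None):
    """Z-score Y per station; optionally floor the std (hypothetical mitigation)."""
    s = std if min_std is None else np.maximum(std, min_std)
    return (Y - mean) / (s + eps)

naive = scale(Y, mean, std)
floored = scale(Y, mean, std, min_std=1.0)

# Stations whose train std is close to zero are the ones that blow up.
bad_stations = np.flatnonzero(std < 1e-3)
print("degenerate stations:", bad_stations)
print("max |scaled| naive  :", np.abs(naive).max())
print("max |scaled| floored:", np.abs(floored).max())
```

Running the same `np.flatnonzero(flow_std < threshold)` check on the real `flow_std` would show whether this explanation applies. If it does, the affected stations could be floored, dropped, or handled separately, and the unscaling step would need to use the same floored value.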


In [49]:
class FastPeMSWindowDataset(Dataset):
    def __init__(self, X_fnt, Y_scaled, TF_all, starts, in_len, out_len):
        self.X_fnt = X_fnt
        self.Y = Y_scaled
        self.TF = TF_all
        self.starts = starts.astype(np.int64)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])
        x = self.X_fnt[:, :, t:t+self.in_len]  # (F,N,IN_LEN)
        y = self.Y[t+self.in_len:t+self.in_len+self.out_len, :]  # (OUT_LEN,N)
        tf = self.TF[t+self.in_len:t+self.in_len+self.out_len, :]  # (OUT_LEN,4)
        return (
            torch.from_numpy(x).float(),
            torch.from_numpy(y).float(),
            torch.from_numpy(tf).float()
        )

train_ds = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, train_starts, IN_LEN, OUT_LEN)
val_ds   = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, val_starts,   IN_LEN, OUT_LEN)
test_ds  = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, test_starts,  IN_LEN, OUT_LEN)

BATCH_SIZE = 8
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])


In [50]:
def dense_to_sparse(A_dense: np.ndarray, device: str):
    idx = np.nonzero(A_dense)
    indices = torch.from_numpy(np.vstack(idx)).long()
    values  = torch.from_numpy(A_dense[idx].astype(np.float32))
    sp = torch.sparse_coo_tensor(indices, values, size=A_dense.shape, device=device)
    return sp.coalesce()

def scaled_laplacian(A: np.ndarray) -> np.ndarray:
    """
    Build scaled Laplacian L_tilde = (2/lambda_max)*L - I.
    We follow the common approximation lambda_max ≈ 2.
    """
    A = A.astype(np.float32)
    # Make undirected for STGCN (common choice)
    A = np.maximum(A, A.T)

    # Add self loops
    A = A + np.eye(A.shape[0], dtype=np.float32)

    d = A.sum(axis=1)
    # Use an explicit `out` buffer: with `where=` alone, skipped entries are left uninitialized.
    d_inv_sqrt = np.zeros_like(d)
    np.power(d, -0.5, out=d_inv_sqrt, where=(d > 0))

    A_norm = (d_inv_sqrt[:, None] * A) * d_inv_sqrt[None, :]
    L = np.eye(A.shape[0], dtype=np.float32) - A_norm

    lambda_max = 2.0
    L_tilde = (2.0 / lambda_max) * L - np.eye(A.shape[0], dtype=np.float32)
    return L_tilde

L_tilde = scaled_laplacian(A)
L_sp = dense_to_sparse(L_tilde, DEVICE)
print("L_sp nnz:", int(L_sp._nnz()))


L_sp nnz: 7856
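
With the approximation `lambda_max = 2`, the scaled Laplacian simplifies: `L_tilde = L - I`, which equals the negated normalized adjacency (with self loops). The sketch below checks this on a three-node toy graph. It mirrors the steps of `scaled_laplacian` in plain NumPy, assumes non-negative edge weights so every degree is at least 1 after adding self loops, and uses the hypothetical name `scaled_laplacian_toy`.

```python
import numpy as np

def scaled_laplacian_toy(A, lambda_max=2.0):
    """Return (L_tilde, A_norm) for a small dense adjacency matrix."""
    A = np.maximum(A, A.T).astype(np.float64)   # symmetrize
    A = A + np.eye(A.shape[0])                  # add self loops
    d = A.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(d)               # safe: d >= 1 for non-negative weights
    A_norm = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    I = np.eye(A.shape[0])
    L = I - A_norm
    return (2.0 / lambda_max) * L - I, A_norm

# Directed path 0 -> 1 -> 2; symmetrization turns it into an undirected path.
A_toy = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [0, 0, 0]], dtype=float)

L_tilde_toy, A_norm_toy = scaled_laplacian_toy(A_toy)
eigvals = np.linalg.eigvalsh(L_tilde_toy)

print(np.allclose(L_tilde_toy, -A_norm_toy))
print(eigvals)
```

The eigenvalues stay within [-1, 1], which is the range the Chebyshev recurrence expects. Because the diagonal of `L_tilde` is non-zero, the `nnz` count of the sparse matrix includes one entry per node in addition to the symmetrized edges.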


In [51]:
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h-1 for h in EVAL_HORIZONS], device=DEVICE)

flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # MUST be scaled outputs (B,OUT_LEN,N)

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        # selected horizons
        pred_sel = pred_u[:, h_idx, :]
        true_sel = true_u[:, h_idx, :]
        for i, h in enumerate(EVAL_HORIZONS):
            err = pred_sel[:, i, :] - true_sel[:, i, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics
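
`eval_horizons_fast` accumulates error sums and element counts across batches and divides once at the end. The sketch below illustrates why that matters, using NumPy arrays in place of tensors: summing first reproduces the global MAE and RMSE even when the last batch is shorter, whereas averaging per-batch RMSE values generally does not. The function name `streaming_mae_rmse` and the random data are illustrative assumptions.

```python
import numpy as np

def streaming_mae_rmse(batches):
    """Accumulate absolute and squared error sums over (pred, true) batches."""
    abs_sum, sq_sum, count = 0.0, 0.0, 0
    for pred, true in batches:
        err = pred - true
        abs_sum += float(np.abs(err).sum())
        sq_sum += float((err ** 2).sum())
        count += err.size
    return abs_sum / count, (sq_sum / count) ** 0.5

rng = np.random.default_rng(0)
pred = rng.normal(size=(19, 5))
true = rng.normal(size=(19, 5))

# Uneven batches (8 + 8 + 3 samples), like a short final batch from a DataLoader.
batches = [(pred[i:i + 8], true[i:i + 8]) for i in range(0, 19, 8)]

mae, rmse = streaming_mae_rmse(batches)

err = pred - true
global_mae = np.abs(err).mean()
global_rmse = np.sqrt((err ** 2).mean())

# Averaging per-batch RMSE weights the short batch too heavily and is not the same quantity.
naive_rmse = np.mean([np.sqrt(((p - t) ** 2).mean()) for p, t in batches])

print(np.isclose(mae, global_mae), np.isclose(rmse, global_rmse))
print(global_rmse, naive_rmse)
```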


In [52]:
class NConv(nn.Module):
    """Sparse matrix multiply along node dimension."""
    def forward(self, x, A_sp):
        # x: (B, C, N, T)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1).float()      # (N, B*C*T)
        x_r = torch.sparse.mm(A_sp, x_r)                         # (N, B*C*T)
        x_out = x_r.reshape(N, B, C, T).permute(1, 2, 0, 3)      # (B, C, N, T)
        return x_out

class ChebGraphConv(nn.Module):
    """
    Chebyshev graph conv using recurrence:
      T0(X)=X
      T1(X)=L~ X
      Tk(X)=2 L~ T_{k-1}(X) - T_{k-2}(X)
    Then 1x1 conv mixes the K stacks.
    """
    def __init__(self, c_in, c_out, K, L_sp):
        super().__init__()
        self.K = K
        self.L_sp = L_sp
        self.nconv = NConv()
        self.mlp = nn.Conv2d(c_in * K, c_out, kernel_size=(1,1))

    def forward(self, x):
        # x: (B,C,N,T)
        out = [x]
        if self.K > 1:
            x1 = self.nconv(x, self.L_sp)
            out.append(x1)
        for k in range(2, self.K):
            x2 = 2.0 * self.nconv(out[-1], self.L_sp) - out[-2]
            out.append(x2)

        h = torch.cat(out, dim=1)  # (B, C*K, N, T)
        return self.mlp(h)

class TemporalGLU(nn.Module):
    """Temporal convolution + GLU gating. No padding -> time shrinks."""
    def __init__(self, c_in, c_out, kt):
        super().__init__()
        self.kt = kt
        self.conv = nn.Conv2d(c_in, 2*c_out, kernel_size=(1, kt))

    def forward(self, x):
        # x: (B,C,N,T)
        z = self.conv(x)                 # (B,2C,N,T-kt+1)
        P, Q = torch.chunk(z, 2, dim=1)  # each (B,C,N,T')
        return P * torch.sigmoid(Q)

class STConvBlock(nn.Module):
    """
    STGCN block: TemporalGLU -> ChebGraphConv -> ReLU -> TemporalGLU
    + residual (time-aligned) + LayerNorm over channels
    """
    def __init__(self, c_in, c_t, c_s, c_out, kt, Ks, L_sp, dropout=0.0):
        super().__init__()
        self.temporal1 = TemporalGLU(c_in, c_t, kt)
        self.graphconv = ChebGraphConv(c_t, c_s, Ks, L_sp)
        self.temporal2 = TemporalGLU(c_s, c_out, kt)

        self.res_conv = None
        if c_in != c_out:
            self.res_conv = nn.Conv2d(c_in, c_out, kernel_size=(1,1))

        self.ln = nn.LayerNorm(c_out)
        self.drop = nn.Dropout(dropout)

        self.kt = kt

    def forward(self, x):
        # x: (B,C_in,N,T)
        x_in = x
        x = self.temporal1(x)            # (B,c_t,N,T1)
        x = self.graphconv(x)            # (B,c_s,N,T1)
        x = F.relu(x)
        x = self.temporal2(x)            # (B,c_out,N,T2)

        # residual: align last T2 timesteps
        T2 = x.shape[-1]
        res = x_in[..., -T2:]
        if self.res_conv is not None:
            res = self.res_conv(res)
        x = x + res

        x = self.drop(x)

        # LayerNorm over channels (per node per time)
        x = x.permute(0, 2, 3, 1)        # (B,N,T,C)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2)        # (B,C,N,T)
        return x

class STGCN_MultiHorizon(nn.Module):
    """
    STGCN encoder + multi-horizon head.
    Output is (B, OUT_LEN, N) in SCALED space (no unscale inside).
    """
    def __init__(self, num_nodes, in_dim, out_len, L_sp,
                 kt=3, Ks=3, dropout=0.1,
                 c_t=64, c_s=16, c_out=64, blocks=2, in_len=None):
        super().__init__()
        self.out_len = out_len

        layers = []
        c_in = in_dim
        for _ in range(blocks):
            layers.append(STConvBlock(c_in, c_t=c_t, c_s=c_s, c_out=c_out,
                                      kt=kt, Ks=Ks, L_sp=L_sp, dropout=dropout))
            c_in = c_out
        self.blocks = nn.ModuleList(layers)

        # Each block applies two unpadded temporal convs, so time shrinks by blocks * 2*(kt-1).
        self.head = None
        self.c_out = c_out
        if in_len is not None:
            # Build the head eagerly so its parameters exist before the optimizer is created.
            T_rem = int(in_len) - blocks * 2 * (kt - 1)
            assert T_rem >= 1, f"in_len={in_len} is too short for {blocks} blocks with kt={kt}"
            self._build_head(T_rem)

    def _build_head(self, T_rem):
        # Collapse time dimension into 1, output channels = out_len
        self.head = nn.Conv2d(self.c_out, self.out_len, kernel_size=(1, T_rem))

    def forward(self, x, tf_future=None):
        # x: (B,F,N,IN_LEN)
        for blk in self.blocks:
            x = blk(x)

        T_rem = x.shape[-1]
        if self.head is None:
            # Lazy fallback: run one forward pass BEFORE building the optimizer,
            # otherwise the head's parameters are never registered with it.
            self._build_head(T_rem)
            self.head = self.head.to(x.device)

        y = self.head(x)       # (B,OUT_LEN,N,1)
        y = y.squeeze(-1)      # (B,OUT_LEN,N)
        return y
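
Because `TemporalGLU` uses no padding, each ST block shortens the time axis by `2 * (kt - 1)` steps, and the head's kernel width must equal whatever remains. The small helper below works through that arithmetic for the configuration used in this notebook. The function name `remaining_time_steps` is hypothetical, and the sketch only illustrates the bookkeeping.

```python
def remaining_time_steps(in_len: int, blocks: int, kt: int) -> int:
    """Time steps left after `blocks` ST blocks, each with two unpadded temporal convs of width kt."""
    t_rem = in_len - blocks * 2 * (kt - 1)
    if t_rem < 1:
        raise ValueError(
            f"in_len={in_len} is too short for {blocks} blocks with kt={kt}"
        )
    return t_rem

# IN_LEN = 24, two blocks, kernel width 3: 24 - 2 * 2 * 2 = 16
print(remaining_time_steps(in_len=24, blocks=2, kt=3))  # 16
```

Knowing this value up front makes it possible to construct the output head before training starts, so that its weights are included when the optimizer collects `model.parameters()`.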


In [53]:
ART_DIR = Path("artifacts")
RUNS_DIR = ART_DIR / "runs"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

def make_run_dir(model_name: str) -> Path:
    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = RUNS_DIR / f"{ts}_{model_name}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(path: Path, obj: dict):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def metrics_to_flat_row(model_name: str, split: str, metrics: dict) -> dict:
    row = {"model_name": model_name, "split": split}
    for h in EVAL_HORIZONS:
        row[f"{split}_MAE_{h}h"] = metrics[h]["MAE"]
        row[f"{split}_RMSE_{h}h"] = metrics[h]["RMSE"]
    row[f"{split}_avg_MAE"] = avg_mae(metrics)
    return row

def append_results_summary(row: dict, out_csv: Path = ART_DIR/"results_summary.csv"):
    df_new = pd.DataFrame([row])
    if out_csv.exists():
        df_old = pd.read_csv(out_csv)
        df = pd.concat([df_old, df_new], ignore_index=True)
    else:
        df = df_new
    df.to_csv(out_csv, index=False)
    return out_csv

@torch.inference_mode()
def collect_predictions_selected_horizons(model, loader, horizons=(12,24,48,72)):
    model.eval()
    h_idx_local = torch.tensor([h-1 for h in horizons], device=DEVICE)
    preds_all = []
    trues_all = []
    for xb, yb, tfb in tqdm(loader, desc="Collect preds", leave=False):
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        tfb = tfb.to(DEVICE)
        pred = model(xb, tfb)  # scaled

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        preds_all.append(pred_u[:, h_idx_local, :].detach().cpu().numpy())
        trues_all.append(true_u[:, h_idx_local, :].detach().cpu().numpy())

    preds_all = np.concatenate(preds_all, axis=0)  # (Btot, Hsel, N)
    trues_all = np.concatenate(trues_all, axis=0)
    return preds_all, trues_all, horizons

def save_predictions_excel(run_dir: Path, preds, trues, horizons, stations, max_stations=300):
    N_total = preds.shape[-1]
    N_use = min(max_stations, N_total)
    st_sel = stations[:N_use]

    out_xlsx = run_dir / "test_pred_true_selected_horizons.xlsx"
    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:
        for hi, h in enumerate(horizons):
            df = pd.DataFrame({
                "station": np.repeat(st_sel, preds.shape[0]),
                "sample":  np.tile(np.arange(preds.shape[0]), N_use),
                "true":    trues[:, hi, :N_use].T.reshape(-1),
                "pred":    preds[:, hi, :N_use].T.reshape(-1),
            })
            df.to_excel(writer, sheet_name=f"h{h}", index=False)
    return out_xlsx

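def train_and_save_best(
    model, model_name: str, run_dir: Path,
    epochs=40, lr=1e-3, weight_decay=1e-4, clip=5.0,
    patience=6, eval_every=2
):
    """
    Train with Adam and MSE loss on scaled targets, evaluating on the validation
    set every `eval_every` epochs. The state with the lowest average validation
    MAE (`avg_mae`) is saved to `run_dir / "best.pt"` and reloaded before returning.
    `patience` counts consecutive evaluations without improvement, not epochs.
    """
    model = model.to(DEVICE)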
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.MSELoss()

    best_score = float("inf")
    best_state = None
    bad = 0

    history = []

    for epoch in range(1, epochs+1):
        model.train()
        run_loss = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            tfb = tfb.to(DEVICE)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)               # scaled
            loss = loss_fn(pred, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            run_loss += float(loss.item())

        row = {"epoch": epoch, "train_loss": run_loss / max(1, len(train_loader))}
        history.append(row)

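        # Evaluate every eval_every epochs, and always on the final epoch,
        # so best_state is set even when epochs is not a multiple of eval_every
        if epoch % eval_every == 0 or epoch == epochs: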
            val_m = eval_horizons_fast(model, val_loader)
            score = avg_mae(val_m)

            print(f"\nEpoch {epoch}: train_loss={row['train_loss']:.6f} val_avg_MAE={score:.3f}")
            print_metrics("VAL", val_m)

            row.update({f"val_MAE_{h}h": val_m[h]["MAE"] for h in EVAL_HORIZONS})
            row.update({f"val_RMSE_{h}h": val_m[h]["RMSE"] for h in EVAL_HORIZONS})
            row["val_avg_MAE"] = score

            if score < best_score:
                best_score = score
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                bad = 0
                torch.save(best_state, run_dir / "best.pt")
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    break

    # Save history
    hist_df = pd.DataFrame(history)
    hist_df.to_csv(run_dir / "history.csv", index=False)
    print("Saved history:", run_dir / "history.csv")

    # Load best
    assert best_state is not None, "best_state is None (evaluation never ran?)"
    model.load_state_dict(best_state)
    return model, hist_df

def run_experiment_and_save(
    model_name: str,
    model: nn.Module,
    epochs=40, patience=6, eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    # Train
    model, history_df = train_and_save_best(
        model=model,
        model_name=model_name,
        run_dir=run_dir,
        epochs=epochs,
        patience=patience,
        eval_every=eval_every,
    )

    # Test metrics
    print("\nEvaluating on TEST set...")
    test_m = eval_horizons_fast(model, test_loader)
    print_metrics(f"{model_name} — TEST", test_m)

    # Save test metrics
    save_json(run_dir / "test_metrics.json", test_m)
    pd.DataFrame([metrics_to_flat_row(model_name, "test", test_m)]).to_csv(run_dir / "test_metrics.csv", index=False)

    # Collect & save predictions
    preds, trues, horizons = collect_predictions_selected_horizons(model, test_loader, horizons=horizons_to_save)
    np.savez_compressed(run_dir / "test_pred_true_selected_horizons.npz",
                        preds=preds, trues=trues, horizons=np.array(horizons))
    out_xlsx = save_predictions_excel(run_dir, preds, trues, horizons, stations, max_stations=max_stations_excel)

    # Update master summary CSV
    summary_row = metrics_to_flat_row(model_name, "test", test_m)
    out_summary = append_results_summary(summary_row)

    print("\nSaved run outputs to:", run_dir)
    print(" - best checkpoint:", run_dir / "best.pt")
    print(" - history:", run_dir / "history.csv")
    print(" - test metrics:", run_dir / "test_metrics.json")
    print(" - predictions (npz):", run_dir / "test_pred_true_selected_horizons.npz")
    print(" - predictions (xlsx):", out_xlsx)
    print(" - master summary:", out_summary)

    return run_dir
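
The `.npz` file written by `run_experiment_and_save` stores `preds`, `trues`, and `horizons`, so per-horizon errors can be recomputed later without reloading the model. The following is a minimal sketch, not part of the pipeline above: `per_horizon_errors` is a hypothetical helper, and the demo writes a small synthetic file to a temporary directory instead of reading a real run. It averages over all samples and stations with no masking, so its numbers may differ from `eval_horizons_fast` if that function handles missing values differently.

```python
import tempfile
from pathlib import Path

import numpy as np

def per_horizon_errors(npz_path):
    """Hypothetical helper: unmasked MAE/RMSE per saved horizon."""
    with np.load(npz_path) as data:
        preds = data["preds"]        # (Btot, Hsel, N), original units
        trues = data["trues"]        # (Btot, Hsel, N)
        horizons = data["horizons"]  # (Hsel,)
    err = preds - trues
    out = {}
    for hi, h in enumerate(horizons):
        e = err[:, hi, :]
        out[int(h)] = {
            "MAE": float(np.abs(e).mean()),
            "RMSE": float(np.sqrt((e ** 2).mean())),
        }
    return out

# Synthetic demo: constant errors of +2 and -3 at two horizons.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "test_pred_true_selected_horizons.npz"
    trues = np.zeros((4, 2, 3), dtype=np.float32)
    preds = trues.copy()
    preds[:, 0, :] += 2.0
    preds[:, 1, :] -= 3.0
    np.savez_compressed(path, preds=preds, trues=trues, horizons=np.array([12, 24]))
    demo = per_horizon_errors(path)

print(demo)
```

For a real run, the path would be `run_dir / "test_pred_true_selected_horizons.npz"`.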


In [54]:
stgcn_base = STGCN_MultiHorizon(
    num_nodes=N,
    in_dim=Fdim,
    out_len=OUT_LEN,
    L_sp=L_sp,
    kt=3,       # temporal kernel size
    Ks=3,       # Chebyshev polynomial order
    dropout=0.1,
    c_t=64, c_s=16, c_out=64,
    blocks=2
).to(DEVICE)

# Sanity check: one forward pass on a single training batch
xb, yb, tfb = next(iter(train_loader))
with torch.no_grad():
    out = stgcn_base(xb.to(DEVICE), tfb.to(DEVICE))
print("Forward output:", out.shape, "scaled pred stats:", float(out.mean()), float(out.std()))

run_dir = run_experiment_and_save(
    model_name="STGCN",
    model=stgcn_base,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
)


Forward output: torch.Size([8, 72, 1821]) scaled pred stats: 0.04451674595475197 0.5800633430480957
Run dir: artifacts/runs/20260209_045856_STGCN


Epoch 2: train_loss=0.250966 val_avg_MAE=162.807

VAL
   12h  MAE=147.694  RMSE=297.933
   24h  MAE=150.905  RMSE=307.435
   48h  MAE=165.279  RMSE=321.227
   72h  MAE=187.351  RMSE=345.004


Epoch 4: train_loss=0.198355 val_avg_MAE=151.319

VAL
   12h  MAE=126.232  RMSE=256.539
   24h  MAE=136.581  RMSE=282.147
   48h  MAE=158.525  RMSE=310.717
   72h  MAE=183.936  RMSE=353.417


Epoch 6: train_loss=0.185316 val_avg_MAE=150.038

VAL
   12h  MAE=123.796  RMSE=256.978
   24h  MAE=138.322  RMSE=281.302
   48h  MAE=161.329  RMSE=320.515
   72h  MAE=176.704  RMSE=342.250


Epoch 8: train_loss=0.177419 val_avg_MAE=151.114

VAL
   12h  MAE=125.383  RMSE=257.735
   24h  MAE=133.466  RMSE=277.751
   48h  MAE=162.796  RMSE=322.952
   72h  MAE=182.812  RMSE=353.390


Epoch 10: train_loss=0.173663 val_avg_MAE=147.846

VAL
   12h  MAE=127.529  RMSE=263.337
   24h  MAE=133.106  RMSE=271.559
   48h  MAE=155.514  RMSE=311.491
   72h  MAE=175.235  RMSE=343.807


Epoch 12: train_loss=0.170295 val_avg_MAE=143.184

VAL
   12h  MAE=123.229  RMSE=252.882
   24h  MAE=127.686  RMSE=265.887
   48h  MAE=156.364  RMSE=311.306
   72h  MAE=165.456  RMSE=322.152


Epoch 14: train_loss=0.166737 val_avg_MAE=143.280

VAL
   12h  MAE=117.789  RMSE=245.680
   24h  MAE=128.483  RMSE=265.882
   48h  MAE=157.306  RMSE=307.850
   72h  MAE=169.542  RMSE=335.639


Epoch 16: train_loss=0.164811 val_avg_MAE=144.770

VAL
   12h  MAE=114.700  RMSE=243.654
   24h  MAE=137.070  RMSE=273.928
   48h  MAE=161.213  RMSE=310.414
   72h  MAE=166.098  RMSE=318.674


Epoch 18: train_loss=0.164994 val_avg_MAE=144.036

VAL
   12h  MAE=120.426  RMSE=247.997
   24h  MAE=132.155  RMSE=274.025
   48h  MAE=157.831  RMSE=313.536
   72h  MAE=165.730  RMSE=328.414


Epoch 20: train_loss=0.162928 val_avg_MAE=147.952

VAL
   12h  MAE=111.936  RMSE=241.293
   24h  MAE=139.887  RMSE=280.909
   48h  MAE=165.217  RMSE=321.430
   72h  MAE=174.767  RMSE=338.671


Epoch 22: train_loss=0.161761 val_avg_MAE=146.622

VAL
   12h  MAE=122.780  RMSE=248.747
   24h  MAE=128.076  RMSE=263.646
   48h  MAE=162.222  RMSE=320.803
   72h  MAE=173.412  RMSE=339.150


Epoch 24: train_loss=0.161647 val_avg_MAE=140.516

VAL
   12h  MAE=122.143  RMSE=251.693
   24h  MAE=123.930  RMSE=259.899
   48h  MAE=151.445  RMSE=302.201
   72h  MAE=164.544  RMSE=326.233


Epoch 26: train_loss=0.159943 val_avg_MAE=149.155

VAL
   12h  MAE=123.250  RMSE=255.677
   24h  MAE=138.902  RMSE=280.275
   48h  MAE=163.449  RMSE=319.669
   72h  MAE=171.018  RMSE=334.956


Epoch 28: train_loss=0.159359 val_avg_MAE=142.014

VAL
   12h  MAE=111.286  RMSE=235.258
   24h  MAE=130.740  RMSE=267.831
   48h  MAE=158.796  RMSE=312.691
   72h  MAE=167.233  RMSE=330.366


Epoch 30: train_loss=0.158898 val_avg_MAE=137.910

VAL
   12h  MAE=117.151  RMSE=244.578
   24h  MAE=126.157  RMSE=265.732
   48h  MAE=151.016  RMSE=301.190
   72h  MAE=157.317  RMSE=310.011


Epoch 32: train_loss=0.158708 val_avg_MAE=136.992

VAL
   12h  MAE=115.324  RMSE=238.849
   24h  MAE=124.163  RMSE=258.722
   48h  MAE=149.866  RMSE=301.408
   72h  MAE=158.613  RMSE=313.007


Epoch 34: train_loss=0.157839 val_avg_MAE=142.080

VAL
   12h  MAE=113.515  RMSE=239.150
   24h  MAE=125.888  RMSE=264.526
   48h  MAE=156.802  RMSE=313.278
   72h  MAE=172.115  RMSE=339.594


Epoch 36: train_loss=0.158052 val_avg_MAE=143.181

VAL
   12h  MAE=117.767  RMSE=245.907
   24h  MAE=137.043  RMSE=282.010
   48h  MAE=155.687  RMSE=307.704
   72h  MAE=162.226  RMSE=317.296


Epoch 38: train_loss=0.156788 val_avg_MAE=143.783

VAL
   12h  MAE=118.535  RMSE=243.556
   24h  MAE=128.388  RMSE=266.130
   48h  MAE=156.430  RMSE=315.717
   72h  MAE=171.780  RMSE=342.920


Epoch 40: train_loss=0.156351 val_avg_MAE=140.454

VAL
   12h  MAE=112.206  RMSE=237.664
   24h  MAE=132.444  RMSE=273.634
   48h  MAE=152.302  RMSE=303.295
   72h  MAE=164.863  RMSE=323.979
Saved history: artifacts/runs/20260209_045856_STGCN/history.csv

Evaluating on TEST set...


STGCN — TEST
   12h  MAE=119.126  RMSE=243.996
   24h  MAE=121.787  RMSE=247.606
   48h  MAE=139.610  RMSE=289.304
   72h  MAE=148.542  RMSE=300.345


Saved run outputs to: artifacts/runs/20260209_045856_STGCN
 - best checkpoint: artifacts/runs/20260209_045856_STGCN/best.pt
 - history: artifacts/runs/20260209_045856_STGCN/history.csv
 - test metrics: artifacts/runs/20260209_045856_STGCN/test_metrics.json
 - predictions (npz): artifacts/runs/20260209_045856_STGCN/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260209_045856_STGCN/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
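
The validation log above can be replayed to see how `patience` interacts with `eval_every` in `train_and_save_best`. The sketch below re-implements only the early-stopping counter; `replay_early_stopping` is a hypothetical helper, and the scores are the `val_avg_MAE` values printed for epochs 2 through 40. Because the counter advances only at evaluations, `patience=6` with `eval_every=2` stops after 6 consecutive non-improving evaluations, which is 12 epochs.

```python
# val_avg_MAE values copied from the STGCN log above (epochs 2, 4, ..., 40)
VAL_AVG_MAE = [
    162.807, 151.319, 150.038, 151.114, 147.846,
    143.184, 143.280, 144.770, 144.036, 147.952,
    146.622, 140.516, 149.155, 142.014, 137.910,
    136.992, 142.080, 143.181, 143.783, 140.454,
]

def replay_early_stopping(val_scores, patience, eval_every):
    """
    Hypothetical helper: replay the best-score/patience counter over a
    list of validation scores, one score per evaluation.
    Returns (best_epoch, best_score, stop_epoch); stop_epoch is None
    if early stopping never triggers.
    """
    best_score = float("inf")
    best_epoch = None
    bad = 0
    for i, score in enumerate(val_scores, start=1):
        epoch = i * eval_every
        if score < best_score:
            best_score, best_epoch, bad = score, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return best_epoch, best_score, epoch
    return best_epoch, best_score, None

as_run = replay_early_stopping(VAL_AVG_MAE, patience=6, eval_every=2)
stricter = replay_early_stopping(VAL_AVG_MAE, patience=5, eval_every=2)
print("patience=6:", as_run)
print("patience=5:", stricter)
```

With the settings used in this run, early stopping never triggers and the best checkpoint comes from epoch 32 (`val_avg_MAE=136.992`). Replaying the same scores with `patience=5` would stop at epoch 22 and keep the epoch-12 checkpoint (`143.184`), so the longest plateau in this log, epochs 14 to 22, is just within the configured patience.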
