In [1]:
import torch
import numpy as np
import pandas as pd

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


Torch version: 2.1.1+cu121
CUDA available: True
GPU: Quadro P5000


In [2]:
from pathlib import Path

DATA_PATH = Path("artifacts/pems_graph_dataset_strict.npz")
assert DATA_PATH.exists(), f"Missing dataset: {DATA_PATH}"

data = np.load(DATA_PATH, allow_pickle=True)
print("Loaded:", DATA_PATH)
print("Keys:", list(data.keys()))

X = data["X"].astype(np.float32)          # (T,N,F)
Y = data["Y"].astype(np.float32)          # (T,N)
A = data["A"].astype(np.float32)          # (N,N)

stations = data["stations"]
timestamps = pd.to_datetime(data["timestamps"])

train_starts = data["train_starts"]
val_starts   = data["val_starts"]
test_starts  = data["test_starts"]

IN_LEN  = int(data["in_len"])
OUT_LEN = int(data["out_len"])

flow_mean  = data["flow_mean"]
flow_std   = data["flow_std"]
speed_mean = data["speed_mean"]
speed_std  = data["speed_std"]

T, N, Fdim = X.shape
print("X:", X.shape, "Y:", Y.shape)
print("A:", A.shape)
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)
print("Stations:", len(stations))
print("Time range:", timestamps.min(), "→", timestamps.max())


Loaded: artifacts/pems_graph_dataset_strict.npz
Keys: ['X', 'Y', 'A', 'stations', 'timestamps', 'train_starts', 'val_starts', 'test_starts', 'in_len', 'out_len', 'flow_mean', 'flow_std', 'speed_mean', 'speed_std']
X: (2208, 1821, 6) Y: (2208, 1821)
A: (1821, 1821)
IN_LEN: 24 OUT_LEN: 72
Stations: 1821
Time range: 2024-10-01 00:00:00 → 2024-12-31 23:00:00


  IN_LEN  = int(data["in_len"])
  OUT_LEN = int(data["out_len"])


In [3]:
def scaled_laplacian(A):
    A = np.maximum(A, A.T)                     # undirected
    A = A + np.eye(A.shape[0], dtype=np.float32)

    d = A.sum(axis=1)
    d_inv_sqrt = np.power(d, -0.5, where=(d > 0))
    d_inv_sqrt[~np.isfinite(d_inv_sqrt)] = 0.0

    A_norm = (d_inv_sqrt[:, None] * A) * d_inv_sqrt[None, :]
    L = np.eye(A.shape[0], dtype=np.float32) - A_norm

    lambda_max = 2.0
    L_tilde = (2.0 / lambda_max) * L - np.eye(A.shape[0], dtype=np.float32)
    return L_tilde

def dense_to_sparse(A_dense, device):
    idx = np.nonzero(A_dense)
    indices = torch.tensor(np.vstack(idx), dtype=torch.long)
    values = torch.tensor(A_dense[idx], dtype=torch.float32)
    return torch.sparse_coo_tensor(
        indices, values, size=A_dense.shape, device=device
    ).coalesce()

L_tilde = scaled_laplacian(A)
L_sp = dense_to_sparse(L_tilde, DEVICE)

print("L_sp nnz:", int(L_sp._nnz()))


L_sp nnz: 7856


In [4]:
class STGCNDataset(torch.utils.data.Dataset):
    def __init__(self, X, Y, starts, in_len, out_len,
                 flow_mean, flow_std, speed_mean, speed_std):
        self.X = X
        self.Y = Y
        self.starts = starts.astype(int)
        self.in_len = in_len
        self.out_len = out_len
        self.flow_mean = flow_mean
        self.flow_std = flow_std
        self.speed_mean = speed_mean
        self.speed_std = speed_std

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = self.starts[idx]

        x = self.X[t:t+self.in_len].copy()      # (IN,N,F)
        y = self.Y[t+self.in_len:t+self.in_len+self.out_len].copy()

        # scale
        x[:, :, 0] = (x[:, :, 0] - self.flow_mean) / self.flow_std
        x[:, :, 1] = (x[:, :, 1] - self.speed_mean) / self.speed_std
        y = (y - self.flow_mean) / self.flow_std

        x = np.transpose(x, (2,1,0))            # (F,N,IN)

        return (
            torch.tensor(x, dtype=torch.float32),
            torch.tensor(y, dtype=torch.float32)
        )

train_ds = STGCNDataset(X, Y, train_starts, IN_LEN, OUT_LEN,
                         flow_mean, flow_std, speed_mean, speed_std)
val_ds   = STGCNDataset(X, Y, val_starts, IN_LEN, OUT_LEN,
                         flow_mean, flow_std, speed_mean, speed_std)
test_ds  = STGCNDataset(X, Y, test_starts, IN_LEN, OUT_LEN,
                         flow_mean, flow_std, speed_mean, speed_std)

BATCH_SIZE = 8

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, pin_memory=False
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False
)

xb, yb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape)


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821])


In [5]:
IN_LEN  = int(np.array(data["in_len"]).item())
OUT_LEN = int(np.array(data["out_len"]).item())


In [6]:
from tqdm import tqdm   


In [7]:
import os, json, time, gc
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm

EVAL_HORIZONS = [12, 24, 48, 72]

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

def _unscale(y_scaled, flow_mean_t, flow_std_t):
    return y_scaled * flow_std_t + flow_mean_t

@torch.inference_mode()
def eval_horizons_fast(model, loader, device, flow_mean, flow_std, eval_horizons=EVAL_HORIZONS):
    model.eval()
    flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=device).view(1, 1, -1)
    flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=device).view(1, 1, -1)
    h_idx = torch.tensor([h - 1 for h in eval_horizons], device=device)

    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in eval_horizons}

    for batch in tqdm(loader, desc="Eval", leave=False):
        if len(batch) == 2:
            xb, yb = batch
            tfb = None
        else:
            xb, yb, tfb = batch

        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        if tfb is not None:
            tfb = tfb.to(device, non_blocking=True)

        pred = model(xb, tfb) if tfb is not None else model(xb)  # scaled

        pred_u = _unscale(pred, flow_mean_t, flow_std_t)
        true_u = _unscale(yb,   flow_mean_t, flow_std_t)

        pred_h = pred_u[:, h_idx, :]  # (B, H, N)
        true_h = true_u[:, h_idx, :]
        err = pred_h - true_h

        for i, h in enumerate(eval_horizons):
            e = err[:, i, :]
            acc[h]["abs"] += float(e.abs().sum().item())
            acc[h]["sq"]  += float((e * e).sum().item())
            acc[h]["count"] += e.numel()

    metrics = {}
    for h in eval_horizons:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": float(mae), "RMSE": float(rmse)}
    return metrics

def make_run_dir(model_name, base_dir="artifacts/runs"):
    base = Path(base_dir)
    base.mkdir(parents=True, exist_ok=True)
    stamp = time.strftime("%Y%m%d_%H%M%S")
    run_dir = base / f"{stamp}_{model_name}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(path: Path, obj):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def metrics_to_flat_row(metrics, prefix):
    row = {}
    for h in sorted(metrics.keys()):
        row[f"{prefix}_MAE_{h}h"]  = metrics[h]["MAE"]
        row[f"{prefix}_RMSE_{h}h"] = metrics[h]["RMSE"]
    return row

def append_master_summary(row_dict, master_csv="artifacts/results_summary.csv"):
    master = Path(master_csv)
    master.parent.mkdir(parents=True, exist_ok=True)
    df_new = pd.DataFrame([row_dict])
    if master.exists():
        df_old = pd.read_csv(master)
        df = pd.concat([df_old, df_new], ignore_index=True)
    else:
        df = df_new
    df.to_csv(master, index=False)
    return master

@torch.inference_mode()
def collect_preds_true_selected(model, loader, device, flow_mean, flow_std,
                               horizons_to_save=(12, 24, 48, 72),
                               stations_all=None,
                               timestamps_all=None,
                               in_len=None,
                               max_stations=300):
    """
    Returns:
      pred_u: (S, H, M)
      true_u: (S, H, M)
      horizons: list
      station_ids: (M,)
      ts_h: (S, H) timestamps for each sample/horizon if possible else None
    """
    model.eval()
    flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=device).view(1, 1, -1)
    flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=device).view(1, 1, -1)
    horizons = list(horizons_to_save)
    h_idx = torch.tensor([h - 1 for h in horizons], device=device)

    # station subset
    N = len(stations_all) if stations_all is not None else None
    if (max_stations is None) or (N is None) or (max_stations >= N):
        sel = None
        station_ids = np.array(stations_all) if stations_all is not None else None
    else:
        sel = np.arange(max_stations, dtype=int)
        station_ids = np.array(stations_all)[sel]

    preds, trues = [], []

    for batch in tqdm(loader, desc="Collect preds", leave=False):
        if len(batch) == 2:
            xb, yb = batch
            tfb = None
        else:
            xb, yb, tfb = batch

        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        if tfb is not None:
            tfb = tfb.to(device, non_blocking=True)

        pred = model(xb, tfb) if tfb is not None else model(xb)  # scaled

        pred_u = _unscale(pred, flow_mean_t, flow_std_t)[:, h_idx, :]  # (B,H,N)
        true_u = _unscale(yb,   flow_mean_t, flow_std_t)[:, h_idx, :]

        if sel is not None:
            pred_u = pred_u[:, :, sel]
            true_u = true_u[:, :, sel]

        preds.append(pred_u.detach().cpu())
        trues.append(true_u.detach().cpu())

    pred_u = torch.cat(preds, dim=0).numpy()
    true_u = torch.cat(trues, dim=0).numpy()

    # timestamps for each sample/horizon (optional)
    ts_h = None
    if (timestamps_all is not None) and (in_len is not None) and hasattr(loader.dataset, "starts"):
        starts = np.array(loader.dataset.starts, dtype=int)  # (S,)
        # For each horizon h, timestamp = timestamps[start + in_len + (h-1)]
        ts_h = np.zeros((len(starts), len(horizons)), dtype="datetime64[ns]")
        ts_all = pd.to_datetime(timestamps_all).to_numpy()
        for j, h in enumerate(horizons):
            ts_h[:, j] = ts_all[starts + in_len + (h - 1)]

    return pred_u, true_u, horizons, station_ids, ts_h

def save_preds_to_excel_and_csv(run_dir: Path, pred_u, true_u, horizons, station_ids, ts_h=None):
    # NPZ (always)
    npz_path = run_dir / "test_pred_true_selected_horizons.npz"
    np.savez_compressed(
        npz_path,
        pred=pred_u,
        true=true_u,
        horizons=np.array(horizons, dtype=int),
        stations=np.array(station_ids) if station_ids is not None else None,
        timestamps=ts_h
    )

    # Excel + CSV per horizon (readable)
    xlsx_path = run_dir / "test_pred_true_selected_horizons.xlsx"
    csv_dir = run_dir / "preds_csv"
    csv_dir.mkdir(parents=True, exist_ok=True)

    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        for j, h in enumerate(horizons):
            cols = [str(s) for s in station_ids] if station_ids is not None else [f"node_{i}" for i in range(pred_u.shape[2])]
            df_pred = pd.DataFrame(pred_u[:, j, :], columns=cols)
            df_true = pd.DataFrame(true_u[:, j, :], columns=cols)

            if ts_h is not None:
                df_pred.insert(0, "timestamp", pd.to_datetime(ts_h[:, j]))
                df_true.insert(0, "timestamp", pd.to_datetime(ts_h[:, j]))

            df_pred.to_excel(writer, sheet_name=f"pred_{h}h", index=False)
            df_true.to_excel(writer, sheet_name=f"true_{h}h", index=False)

            # CSV versions too
            df_pred.to_csv(csv_dir / f"pred_{h}h.csv", index=False)
            df_true.to_csv(csv_dir / f"true_{h}h.csv", index=False)

    return npz_path, xlsx_path, csv_dir

def train_and_save_best(model, model_name, run_dir: Path,
                        train_loader, val_loader,
                        device,
                        flow_mean, flow_std,
                        epochs=40, lr=1e-3, weight_decay=1e-4, clip=5.0,
                        patience=6, eval_every=2):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best_score = float("inf")
    best_state = None
    bad = 0
    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=False):
            if len(batch) == 2:
                xb, yb = batch
                tfb = None
            else:
                xb, yb, tfb = batch

            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            if tfb is not None:
                tfb = tfb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb) if tfb is not None else model(xb)  # scaled
            loss = loss_fn(pred, yb)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()
            running += float(loss.item())

        train_loss = running / max(1, len(train_loader))

        if epoch % eval_every == 0:
            val_metrics = eval_horizons_fast(model, val_loader, device, flow_mean, flow_std)
            score = avg_mae(val_metrics)

            print(f"\nEpoch {epoch}: train_loss={train_loss:.6f} val_avg_MAE={score:.3f}")
            print_metrics("VAL", val_metrics)

            history.append({"epoch": epoch, "train_loss": train_loss, "val_avg_MAE": score, **metrics_to_flat_row(val_metrics, "val")})
            pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)

            if score < best_score:
                best_score = score
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                torch.save(best_state, run_dir / "best.pt")
                bad = 0
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    break

    if best_state is None:
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        torch.save(best_state, run_dir / "best.pt")

    model.load_state_dict(best_state)
    return model

def run_experiment_and_save(model_name, model,
                            train_loader, val_loader, test_loader,
                            device,
                            flow_mean, flow_std,
                            stations, timestamps,
                            in_len,
                            epochs=40, patience=6, eval_every=2,
                            horizons_to_save=(12, 24, 48, 72),
                            max_stations_excel=300):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    model = model.to(device)

    # Train (and keep saving best.pt + history.csv)
    model = train_and_save_best(
        model=model,
        model_name=model_name,
        run_dir=run_dir,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        flow_mean=flow_mean,
        flow_std=flow_std,
        epochs=epochs,
        patience=patience,
        eval_every=eval_every
    )

    # TEST
    print("\nEvaluating on TEST set...")
    test_metrics = eval_horizons_fast(model, test_loader, device, flow_mean, flow_std)
    print_metrics(f"{model_name} — TEST", test_metrics)

    save_json(run_dir / "test_metrics.json", test_metrics)
    pd.DataFrame([metrics_to_flat_row(test_metrics, "test")]).to_csv(run_dir / "test_metrics.csv", index=False)

    # Save preds/true for selected horizons + subset of stations (NPZ + XLSX + CSV)
    pred_u, true_u, horizons, station_ids, ts_h = collect_preds_true_selected(
        model=model,
        loader=test_loader,
        device=device,
        flow_mean=flow_mean,
        flow_std=flow_std,
        horizons_to_save=horizons_to_save,
        stations_all=stations,
        timestamps_all=timestamps,
        in_len=in_len,
        max_stations=max_stations_excel
    )

    npz_path, xlsx_path, csv_dir = save_preds_to_excel_and_csv(run_dir, pred_u, true_u, horizons, station_ids, ts_h)

    # Append to master summary
    row = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model_name": model_name,
        "run_dir": str(run_dir),
        **metrics_to_flat_row(test_metrics, "test")
    }
    master = append_master_summary(row)

    print("\nSaved run outputs to:", run_dir)
    print(" - best checkpoint:", run_dir / "best.pt")
    print(" - history:", run_dir / "history.csv")
    print(" - test metrics (json):", run_dir / "test_metrics.json")
    print(" - test metrics (csv):", run_dir / "test_metrics.csv")
    print(" - predictions (npz):", npz_path)
    print(" - predictions (xlsx):", xlsx_path)
    print(" - predictions (csv folder):", csv_dir)
    print(" - master summary:", master)

    return run_dir


In [8]:
class STGCN_RNNHead(nn.Module):
    """
    Wrap a base STGCN model and refine its multi-horizon output using GRU/LSTM.
    Base model must output (B, OUT_LEN, N) in *scaled* space.
    """
    def __init__(self, base_model: nn.Module, out_len: int, rnn_hidden: int = 128,
                 use_gru: bool = False, use_lstm: bool = False):
        super().__init__()
        assert use_gru or use_lstm, "Turn on at least one of use_gru/use_lstm."
        self.base = base_model
        self.out_len = out_len

        in_dim = 1  # we feed the STGCN scalar output sequence per node

        self.gru = nn.GRU(in_dim, rnn_hidden, batch_first=True) if use_gru else None
        rnn_in = rnn_hidden if use_gru else in_dim

        self.lstm = nn.LSTM(rnn_in, rnn_hidden, batch_first=True) if use_lstm else None
        rnn_out = rnn_hidden if use_lstm else rnn_in

        self.proj = nn.Linear(rnn_out, 1)

    def forward(self, x, tf=None):
        y0 = self.base(x, tf) if tf is not None else self.base(x)   # (B, T, N)
        B, T, N = y0.shape

        seq = y0.permute(0, 2, 1).contiguous().view(B * N, T, 1)     # (B*N, T, 1)

        out = seq
        if self.gru is not None:
            out, _ = self.gru(out)
        if self.lstm is not None:
            out, _ = self.lstm(out)

        out = self.proj(out)                                        # (B*N, T, 1)
        out = out.view(B, N, T).permute(0, 2, 1).contiguous()        # (B, T, N)
        return out


In [9]:
def build_stgcn():
    return STGCN_MultiHorizon(
        num_nodes=N,
        in_dim=Fdim,
        out_len=OUT_LEN,
        L_sp=L_sp,
        kt=3,
        Ks=3,
        dropout=0.1,
        c_t=64, c_s=16, c_out=64,
        blocks=2
    )


In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------
# Helpers: sparse node-mix
# -------------------------
def nconv_sparse(x: torch.Tensor, A_sp: torch.Tensor) -> torch.Tensor:
    """
    x: (B, C, N, T)
    A_sp: sparse (N, N)
    returns: (B, C, N, T)
    """
    B, C, N, T = x.shape
    x_r = x.permute(2, 0, 1, 3).reshape(N, -1)          # (N, B*C*T)
    x_r = torch.sparse.mm(A_sp, x_r)                    # (N, B*C*T)
    x_out = x_r.reshape(N, B, C, T).permute(1, 2, 0, 3) # (B,C,N,T)
    return x_out


class TemporalConvGLU(nn.Module):
    """
    Causal temporal conv with GLU gating.
    Input/Output: (B,C,N,T)
    """
    def __init__(self, c_in: int, c_out: int, kt: int, dropout: float):
        super().__init__()
        self.kt = kt
        self.conv = nn.Conv2d(c_in, 2*c_out, kernel_size=(1, kt), bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # causal pad on the left in time dimension (last dim)
        x = F.pad(x, (self.kt - 1, 0, 0, 0))     # pad time: (left, right, top, bottom) for last two dims
        z = self.conv(x)                         # (B, 2*Cout, N, T)
        a, b = z.chunk(2, dim=1)
        out = a * torch.sigmoid(b)
        return self.dropout(out)


class ChebGraphConv(nn.Module):
    """
    Chebyshev graph convolution with Ks terms.
    Uses sparse scaled Laplacian L_sp (N,N).
    """
    def __init__(self, c_in: int, c_out: int, Ks: int, L_sp: torch.Tensor):
        super().__init__()
        assert Ks >= 1
        self.Ks = Ks
        self.L_sp = L_sp
        self.theta = nn.Conv2d(Ks * c_in, c_out, kernel_size=(1, 1), bias=True)

    def forward(self, x):
        # x: (B,C,N,T)
        out = [x]  # T0
        if self.Ks > 1:
            x1 = nconv_sparse(x, self.L_sp)  # T1
            out.append(x1)
            for _ in range(2, self.Ks):
                x2 = 2 * nconv_sparse(out[-1], self.L_sp) - out[-2]  # Tk
                out.append(x2)

        x_cat = torch.cat(out, dim=1)  # (B, Ks*C, N, T)
        return self.theta(x_cat)


class STConvBlock(nn.Module):
    """
    ST block: TemporalConv -> GraphConv -> TemporalConv (+ residual)
    """
    def __init__(self, c_in, c_t, c_s, c_out, kt, Ks, L_sp, dropout):
        super().__init__()
        self.temp1 = TemporalConvGLU(c_in,  c_t,  kt=kt, dropout=dropout)
        self.gconv = ChebGraphConv(c_t, c_s, Ks=Ks, L_sp=L_sp)
        self.temp2 = TemporalConvGLU(c_s,  c_out, kt=kt, dropout=dropout)

        self.res = None
        if c_in != c_out:
            self.res = nn.Conv2d(c_in, c_out, kernel_size=(1,1))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x_in = x
        x = self.temp1(x)
        x = F.relu(self.gconv(x))
        x = self.temp2(x)

        if self.res is not None:
            x_in = self.res(x_in)

        x = F.relu(x + x_in)
        return self.dropout(x)


class STGCN_MultiHorizon(nn.Module):
    """
    Multi-horizon forecaster:
      encode past window -> take last time state -> project to OUT_LEN for each node
    forward(x, tf) keeps signature compatible (tf is ignored here).
    """
    def __init__(
        self,
        num_nodes: int,
        in_dim: int,
        out_len: int,
        L_sp: torch.Tensor,
        kt: int = 3,
        Ks: int = 3,
        dropout: float = 0.1,
        c_t: int = 64,
        c_s: int = 16,
        c_out: int = 64,
        blocks: int = 2,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_dim = in_dim
        self.out_len = out_len
        self.c_out = c_out

        layers = []
        c_in = in_dim
        for _ in range(blocks):
            layers.append(STConvBlock(
                c_in=c_in, c_t=c_t, c_s=c_s, c_out=c_out,
                kt=kt, Ks=Ks, L_sp=L_sp, dropout=dropout
            ))
            c_in = c_out
        self.blocks = nn.ModuleList(layers)

        # node-wise linear map: (B, c_out, N) -> (B, out_len, N)
        self.head = nn.Conv1d(c_out, out_len, kernel_size=1)

    def encode(self, x):
        # x: (B,F,N,T)
        h = x
        for blk in self.blocks:
            h = blk(h)  # (B,c_out,N,T)
        return h

    def forward(self, x, tf=None):
        h = self.encode(x)
        h_last = h[:, :, :, -1]           # (B, c_out, N)
        out = self.head(h_last)           # (B, out_len, N)
        return out


In [12]:
class STGCN_RNNHead(nn.Module):
    """
    Wrap STGCN encoder with optional GRU and/or LSTM over the encoder time sequence per node.
    If both are enabled: GRU -> LSTM (stacked).
    """
    def __init__(
        self,
        base: STGCN_MultiHorizon,
        out_len: int,
        rnn_hidden: int = 128,
        use_gru: bool = False,
        use_lstm: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert use_gru or use_lstm, "Enable at least one of GRU/LSTM"
        self.base = base
        self.out_len = out_len
        self.use_gru = use_gru
        self.use_lstm = use_lstm
        self.rnn_hidden = rnn_hidden

        enc_dim = base.c_out

        self.gru = None
        self.lstm = None

        if use_gru:
            self.gru = nn.GRU(
                input_size=enc_dim,
                hidden_size=rnn_hidden,
                num_layers=1,
                batch_first=True,
                dropout=0.0
            )

        if use_lstm:
            lstm_in = rnn_hidden if use_gru else enc_dim
            self.lstm = nn.LSTM(
                input_size=lstm_in,
                hidden_size=rnn_hidden,
                num_layers=1,
                batch_first=True,
                dropout=0.0
            )

        self.drop = nn.Dropout(dropout)
        self.head = nn.Conv1d(rnn_hidden, out_len, kernel_size=1)

    def forward(self, x, tf=None):
        # Encode: (B,C,N,T)
        feat = self.base.encode(x)
        B, C, N, T = feat.shape

        # per-node sequences: (B*N, T, C)
        seq = feat.permute(0, 2, 3, 1).contiguous().view(B*N, T, C)

        if self.use_gru:
            seq, _ = self.gru(seq)  # (B*N, T, H)

        if self.use_lstm:
            seq, _ = self.lstm(seq) # (B*N, T, H)

        last = seq[:, -1, :]                    # (B*N, H)
        last = self.drop(last)
        last = last.view(B, N, self.rnn_hidden).permute(0, 2, 1)  # (B,H,N)

        out = self.head(last)                   # (B,out_len,N)
        return out


In [13]:
import numpy as np
import pandas as pd
from pathlib import Path

def save_pred_true_csv_long(
    out_csv: Path,
    pred_u: np.ndarray,   # (S, H, N)
    true_u: np.ndarray,   # (S, H, N)
    horizons: list[int],
    station_ids: list[str] | None = None,
    max_stations: int = 300,
):
    """
    Saves long-form CSV:
      sample, horizon, station, y_true, y_pred
    To keep it sane, we limit to first max_stations.
    """
    S, H, N = pred_u.shape
    assert H == len(horizons)

    take = min(N, max_stations)
    stations = station_ids[:take] if station_ids is not None else [f"node_{i}" for i in range(take)]

    rows = []
    for hi, h in enumerate(horizons):
        # (S, take)
        p = pred_u[:, hi, :take]
        t = true_u[:, hi, :take]

        # build long rows efficiently
        sample_idx = np.repeat(np.arange(S), take)
        station_col = np.tile(np.array(stations, dtype=object), S)

        df_h = pd.DataFrame({
            "sample": sample_idx,
            "horizon_h": h,
            "station": station_col,
            "y_true": t.reshape(-1),
            "y_pred": p.reshape(-1),
        })
        rows.append(df_h)

    df = pd.concat(rows, ignore_index=True)
    df.to_csv(out_csv, index=False)
    return out_csv


### WORKING ON STGCN AGAIN

In [1]:
import os, json, time
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

# ---------------- Repro ----------------
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Torch:", torch.__version__)
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

Torch: 2.1.1+cu121
Device: cuda
GPU: Quadro P5000


In [2]:
DATA_PATH = Path("artifacts/pems_graph_dataset_strict.npz")
assert DATA_PATH.exists(), f"Missing {DATA_PATH}. Rebuild the dataset first."

data = np.load(DATA_PATH, allow_pickle=True)
print("Loaded:", DATA_PATH)
print("Keys:", list(data.keys()))

X_raw = data["X"]            # (T, N, F)
Y_raw = data["Y"]            # (T, N)  (flow target)
A = data["A"]                # (N, N)
stations = data["stations"]
timestamps = data["timestamps"]

train_starts = data["train_starts"]
val_starts   = data["val_starts"]
test_starts  = data["test_starts"]

IN_LEN  = int(data["in_len"])
OUT_LEN = int(data["out_len"])

flow_mean = data["flow_mean"]   # (N,)
flow_std  = data["flow_std"]    # (N,)
speed_mean = data["speed_mean"] # (N,)
speed_std  = data["speed_std"]  # (N,)

T, N, Fdim = X_raw.shape
print("\nShapes:")
print("X_raw:", X_raw.shape, "(T,N,F)")
print("Y_raw:", Y_raw.shape, "(T,N)")
print("A:", A.shape)
print("IN_LEN:", IN_LEN, "OUT_LEN:", OUT_LEN)
print("train/val/test starts:", len(train_starts), len(val_starts), len(test_starts))

Loaded: artifacts/pems_graph_dataset_strict.npz
Keys: ['X', 'Y', 'A', 'stations', 'timestamps', 'train_starts', 'val_starts', 'test_starts', 'in_len', 'out_len', 'flow_mean', 'flow_std', 'speed_mean', 'speed_std']

Shapes:
X_raw: (2208, 1821, 6) (T,N,F)
Y_raw: (2208, 1821) (T,N)
A: (1821, 1821)
IN_LEN: 24 OUT_LEN: 72
train/val/test starts: 1009 289 673


  IN_LEN  = int(data["in_len"])
  OUT_LEN = int(data["out_len"])


In [3]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)

dt_idx = pd.to_datetime(timestamps)
TF_all = time_encoding(dt_idx)         # (T,4)

# ----- scale inputs -----
X_scaled = X_raw.astype(np.float32).copy()
X_scaled[:, :, 0] = (X_scaled[:, :, 0] - flow_mean[None, :]) / (flow_std[None, :] + 1e-6)
X_scaled[:, :, 1] = (X_scaled[:, :, 1] - speed_mean[None, :]) / (speed_std[None, :] + 1e-6)

# ----- scale targets (flow) -----
Y_scaled = (Y_raw.astype(np.float32) - flow_mean[None, :]) / (flow_std[None, :] + 1e-6)

# Store for fast slicing as (F,N,T)
X_fnt = np.transpose(X_scaled, (2, 1, 0)).copy()  # (F,N,T)

print("X_fnt:", X_fnt.shape, "Y_scaled:", Y_scaled.shape, "TF_all:", TF_all.shape)
print("Sanity (Y_scaled mean/std approx):", float(Y_scaled.mean()), float(Y_scaled.std()))

X_fnt: (6, 1821, 2208) Y_scaled: (2208, 1821) TF_all: (2208, 4)
Sanity (Y_scaled mean/std approx): -780.4212036132812 30666.189453125


In [4]:
class FastPeMSWindowDataset(Dataset):
    def __init__(self, X_fnt, Y_scaled, TF_all, starts, in_len, out_len):
        self.X_fnt = X_fnt
        self.Y = Y_scaled
        self.TF = TF_all
        self.starts = starts.astype(np.int64)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        t = int(self.starts[idx])
        x = self.X_fnt[:, :, t:t+self.in_len]  # (F,N,IN_LEN)
        y = self.Y[t+self.in_len:t+self.in_len+self.out_len, :]  # (OUT_LEN,N)
        tf = self.TF[t+self.in_len:t+self.in_len+self.out_len, :]  # (OUT_LEN,4)
        return (
            torch.from_numpy(x).float(),
            torch.from_numpy(y).float(),
            torch.from_numpy(tf).float()
        )

train_ds = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, train_starts, IN_LEN, OUT_LEN)
val_ds   = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, val_starts,   IN_LEN, OUT_LEN)
test_ds  = FastPeMSWindowDataset(X_fnt, Y_scaled, TF_all, test_starts,  IN_LEN, OUT_LEN)

BATCH_SIZE = 8
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)

Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])


In [5]:
def dense_to_sparse(A_dense: np.ndarray, device: str):
    idx = np.nonzero(A_dense)
    indices = torch.from_numpy(np.vstack(idx)).long()
    values  = torch.from_numpy(A_dense[idx].astype(np.float32))
    sp = torch.sparse_coo_tensor(indices, values, size=A_dense.shape, device=device)
    return sp.coalesce()

def scaled_laplacian(A: np.ndarray) -> np.ndarray:
    """
    Build scaled Laplacian L_tilde = (2/lambda_max)*L - I.
    We follow the common approximation lambda_max ≈ 2. :contentReference[oaicite:3]{index=3}
    """
    A = A.astype(np.float32)
    # Make undirected for STGCN (common choice)
    A = np.maximum(A, A.T)

    # Add self loops
    A = A + np.eye(A.shape[0], dtype=np.float32)

    d = A.sum(axis=1)
    d_inv_sqrt = np.power(d, -0.5, where=(d > 0))
    d_inv_sqrt[~np.isfinite(d_inv_sqrt)] = 0.0

    A_norm = (d_inv_sqrt[:, None] * A) * d_inv_sqrt[None, :]
    L = np.eye(A.shape[0], dtype=np.float32) - A_norm

    lambda_max = 2.0
    L_tilde = (2.0 / lambda_max) * L - np.eye(A.shape[0], dtype=np.float32)
    return L_tilde

L_tilde = scaled_laplacian(A)
L_sp = dense_to_sparse(L_tilde, DEVICE)
print("L_sp nnz:", int(L_sp._nnz()))

L_sp nnz: 7856


In [6]:
EVAL_HORIZONS = [12, 24, 48, 72]
h_idx = torch.tensor([h-1 for h in EVAL_HORIZONS], device=DEVICE)

flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=False):
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # MUST be scaled outputs (B,OUT_LEN,N)

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        # selected horizons
        pred_sel = pred_u[:, h_idx, :]
        true_sel = true_u[:, h_idx, :]
        for i, h in enumerate(EVAL_HORIZONS):
            err = pred_sel[:, i, :] - true_sel[:, i, :]
            acc[h]["abs"] += float(err.abs().sum())
            acc[h]["sq"]  += float((err ** 2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": mae, "RMSE": rmse}
    return metrics

In [7]:
class NConv(nn.Module):
    """Sparse matrix multiply along node dimension."""
    def forward(self, x, A_sp):
        # x: (B, C, N, T)
        B, C, N, T = x.shape
        x_r = x.permute(2, 0, 1, 3).reshape(N, -1).float()      # (N, B*C*T)
        x_r = torch.sparse.mm(A_sp, x_r)                         # (N, B*C*T)
        x_out = x_r.reshape(N, B, C, T).permute(1, 2, 0, 3)      # (B, C, N, T)
        return x_out

class ChebGraphConv(nn.Module):
    """
    Chebyshev graph conv using recurrence:
      T0(X)=X
      T1(X)=L~ X
      Tk(X)=2 L~ T_{k-1}(X) - T_{k-2}(X)
    Then 1x1 conv mixes the K stacks.
    """
    def __init__(self, c_in, c_out, K, L_sp):
        super().__init__()
        self.K = K
        self.L_sp = L_sp
        self.nconv = NConv()
        self.mlp = nn.Conv2d(c_in * K, c_out, kernel_size=(1,1))

    def forward(self, x):
        # x: (B,C,N,T)
        out = [x]
        if self.K > 1:
            x1 = self.nconv(x, self.L_sp)
            out.append(x1)
        for k in range(2, self.K):
            x2 = 2.0 * self.nconv(out[-1], self.L_sp) - out[-2]
            out.append(x2)

        h = torch.cat(out, dim=1)  # (B, C*K, N, T)
        return self.mlp(h)

class TemporalGLU(nn.Module):
    """Temporal convolution + GLU gating. No padding -> time shrinks."""
    def __init__(self, c_in, c_out, kt):
        super().__init__()
        self.kt = kt
        self.conv = nn.Conv2d(c_in, 2*c_out, kernel_size=(1, kt))

    def forward(self, x):
        # x: (B,C,N,T)
        z = self.conv(x)                 # (B,2C,N,T-kt+1)
        P, Q = torch.chunk(z, 2, dim=1)  # each (B,C,N,T')
        return P * torch.sigmoid(Q)

class STConvBlock(nn.Module):
    """
    STGCN block: TemporalGLU -> ChebGraphConv -> ReLU -> TemporalGLU
    + residual (time-aligned) + LayerNorm over channels
    """
    def __init__(self, c_in, c_t, c_s, c_out, kt, Ks, L_sp, dropout=0.0):
        super().__init__()
        self.temporal1 = TemporalGLU(c_in, c_t, kt)
        self.graphconv = ChebGraphConv(c_t, c_s, Ks, L_sp)
        self.temporal2 = TemporalGLU(c_s, c_out, kt)

        self.res_conv = None
        if c_in != c_out:
            self.res_conv = nn.Conv2d(c_in, c_out, kernel_size=(1,1))

        self.ln = nn.LayerNorm(c_out)
        self.drop = nn.Dropout(dropout)

        self.kt = kt

    def forward(self, x):
        # x: (B,C_in,N,T)
        x_in = x
        x = self.temporal1(x)            # (B,c_t,N,T1)
        x = self.graphconv(x)            # (B,c_s,N,T1)
        x = F.relu(x)
        x = self.temporal2(x)            # (B,c_out,N,T2)

        # residual: align last T2 timesteps
        T2 = x.shape[-1]
        res = x_in[..., -T2:]
        if self.res_conv is not None:
            res = self.res_conv(res)
        x = x + res

        x = self.drop(x)

        # LayerNorm over channels (per node per time)
        x = x.permute(0, 2, 3, 1)        # (B,N,T,C)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2)        # (B,C,N,T)
        return x

class STGCN_MultiHorizon(nn.Module):
    """
    STGCN encoder + multi-horizon head.
    Output is (B, OUT_LEN, N) in SCALED space (no unscale inside).
    """
    def __init__(self, num_nodes, in_dim, out_len, L_sp,
                 kt=3, Ks=3, dropout=0.1,
                 c_t=64, c_s=16, c_out=64, blocks=2):
        super().__init__()
        self.out_len = out_len

        layers = []
        c_in = in_dim
        for _ in range(blocks):
            layers.append(STConvBlock(c_in, c_t=c_t, c_s=c_s, c_out=c_out,
                                      kt=kt, Ks=Ks, L_sp=L_sp, dropout=dropout))
            c_in = c_out
        self.blocks = nn.ModuleList(layers)

        # After blocks, time is reduced by blocks * 2*(kt-1)
        # We will infer the remaining time at runtime and build head lazily if needed.
        self.head = None
        self.c_out = c_out

    def _build_head(self, T_rem):
        # Collapse time dimension into 1, output channels = out_len
        self.head = nn.Conv2d(self.c_out, self.out_len, kernel_size=(1, T_rem))

    def forward(self, x, tf_future=None):
        # x: (B,F,N,IN_LEN)
        for blk in self.blocks:
            x = blk(x)

        T_rem = x.shape[-1]
        if self.head is None:
            self._build_head(T_rem)
            self.head = self.head.to(x.device)

        y = self.head(x)       # (B,OUT_LEN,N,1)
        y = y.squeeze(-1)      # (B,OUT_LEN,N)
        return y

In [8]:
ART_DIR = Path("artifacts")
RUNS_DIR = ART_DIR / "runs"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

def make_run_dir(model_name: str) -> Path:
    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = RUNS_DIR / f"{ts}_{model_name}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def save_json(path: Path, obj: dict):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def metrics_to_flat_row(model_name: str, split: str, metrics: dict) -> dict:
    row = {"model_name": model_name, "split": split}
    for h in EVAL_HORIZONS:
        row[f"{split}_MAE_{h}h"] = metrics[h]["MAE"]
        row[f"{split}_RMSE_{h}h"] = metrics[h]["RMSE"]
    row[f"{split}_avg_MAE"] = avg_mae(metrics)
    return row

def append_results_summary(row: dict, out_csv: Path = ART_DIR/"results_summary.csv"):
    df_new = pd.DataFrame([row])
    if out_csv.exists():
        df_old = pd.read_csv(out_csv)
        df = pd.concat([df_old, df_new], ignore_index=True)
    else:
        df = df_new
    df.to_csv(out_csv, index=False)
    return out_csv

@torch.inference_mode()
def collect_predictions_selected_horizons(model, loader, horizons=(12,24,48,72)):
    model.eval()
    h_idx_local = torch.tensor([h-1 for h in horizons], device=DEVICE)
    preds_all = []
    trues_all = []
    for xb, yb, tfb in tqdm(loader, desc="Collect preds", leave=False):
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        tfb = tfb.to(DEVICE)
        pred = model(xb, tfb)  # scaled

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        preds_all.append(pred_u[:, h_idx_local, :].detach().cpu().numpy())
        trues_all.append(true_u[:, h_idx_local, :].detach().cpu().numpy())

    preds_all = np.concatenate(preds_all, axis=0)  # (Btot, Hsel, N)
    trues_all = np.concatenate(trues_all, axis=0)
    return preds_all, trues_all, horizons

def save_predictions_excel(run_dir: Path, preds, trues, horizons, stations, max_stations=300):
    N_total = preds.shape[-1]
    N_use = min(max_stations, N_total)
    st_sel = stations[:N_use]

    out_xlsx = run_dir / "test_pred_true_selected_horizons.xlsx"
    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:
        for hi, h in enumerate(horizons):
            df = pd.DataFrame({
                "station": np.repeat(st_sel, preds.shape[0]),
                "sample":  np.tile(np.arange(preds.shape[0]), N_use),
                "true":    trues[:, hi, :N_use].T.reshape(-1),
                "pred":    preds[:, hi, :N_use].T.reshape(-1),
            })
            df.to_excel(writer, sheet_name=f"h{h}", index=False)
    return out_xlsx

def train_and_save_best(
    model, model_name: str, run_dir: Path,
    epochs=40, lr=1e-3, weight_decay=1e-4, clip=5.0,
    patience=6, eval_every=2
):
    model = model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.MSELoss()

    best_score = float("inf")
    best_state = None
    bad = 0

    history = []

    for epoch in range(1, epochs+1):
        model.train()
        run_loss = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            tfb = tfb.to(DEVICE)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)               # scaled
            loss = loss_fn(pred, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()

            run_loss += float(loss.item())

        row = {"epoch": epoch, "train_loss": run_loss / max(1, len(train_loader))}
        history.append(row)

        # Evaluate every eval_every epochs
        if epoch % eval_every == 0:
            val_m = eval_horizons_fast(model, val_loader)
            score = avg_mae(val_m)

            print(f"\nEpoch {epoch}: train_loss={row['train_loss']:.6f} val_avg_MAE={score:.3f}")
            print_metrics("VAL", val_m)

            row.update({f"val_MAE_{h}h": val_m[h]["MAE"] for h in EVAL_HORIZONS})
            row.update({f"val_RMSE_{h}h": val_m[h]["RMSE"] for h in EVAL_HORIZONS})
            row["val_avg_MAE"] = score

            if score < best_score:
                best_score = score
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                bad = 0
                torch.save(best_state, run_dir / "best.pt")
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best_score:.3f}")
                    break

    # Save history
    hist_df = pd.DataFrame(history)
    hist_df.to_csv(run_dir / "history.csv", index=False)
    print("Saved history:", run_dir / "history.csv")

    # Load best
    assert best_state is not None, "best_state is None (evaluation never ran?)"
    model.load_state_dict(best_state)
    return model, hist_df

def run_experiment_and_save(
    model_name: str,
    model: nn.Module,
    epochs=40, patience=6, eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    # Train
    model, history_df = train_and_save_best(
        model=model,
        model_name=model_name,
        run_dir=run_dir,
        epochs=epochs,
        patience=patience,
        eval_every=eval_every,
    )

    # Test metrics
    print("\nEvaluating on TEST set...")
    test_m = eval_horizons_fast(model, test_loader)
    print_metrics(f"{model_name} — TEST", test_m)

    # Save test metrics
    save_json(run_dir / "test_metrics.json", test_m)
    pd.DataFrame([metrics_to_flat_row(model_name, "test", test_m)]).to_csv(run_dir / "test_metrics.csv", index=False)

    # Collect & save predictions
    preds, trues, horizons = collect_predictions_selected_horizons(model, test_loader, horizons=horizons_to_save)
    np.savez_compressed(run_dir / "test_pred_true_selected_horizons.npz",
                        preds=preds, trues=trues, horizons=np.array(horizons))
    out_xlsx = save_predictions_excel(run_dir, preds, trues, horizons, stations, max_stations=max_stations_excel)

    # Update master summary CSV
    summary_row = metrics_to_flat_row(model_name, "test", test_m)
    out_summary = append_results_summary(summary_row)

    print("\nSaved run outputs to:", run_dir)
    print(" - best checkpoint:", run_dir / "best.pt")
    print(" - history:", run_dir / "history.csv")
    print(" - test metrics:", run_dir / "test_metrics.json")
    print(" - predictions (npz):", run_dir / "test_pred_true_selected_horizons.npz")
    print(" - predictions (xlsx):", out_xlsx)
    print(" - master summary:", out_summary)

    return run_dir

In [9]:
from torch.utils.data import DataLoader

BATCH_SIZE_ABL = 8  

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE_ABL, shuffle=True, num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE_ABL, shuffle=False, num_workers=0, pin_memory=False)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE_ABL, shuffle=False, num_workers=0, pin_memory=False)

xb, yb, tfb = next(iter(train_loader))
print("Batch check:", xb.shape, yb.shape, tfb.shape)


Batch check: torch.Size([8, 6, 1821, 24]) torch.Size([8, 72, 1821]) torch.Size([8, 72, 4])


In [10]:
import torch
import torch.nn as nn
import numpy as np

# ----------------------------
# Assumes you already have:
# N, Fdim, OUT_LEN, L_sp, DEVICE
# and STGCN_MultiHorizon class defined
# ----------------------------

def build_stgcn_backbone():
    # EXACT same hyperparams as your STGCN baseline
    return STGCN_MultiHorizon(
        num_nodes=N,
        in_dim=Fdim,
        out_len=OUT_LEN,
        L_sp=L_sp,
        kt=3,
        Ks=3,
        dropout=0.1,
        c_t=64, c_s=16, c_out=64,
        blocks=2
    )

class HorizonRNNRefinement(nn.Module):
    """
    Takes baseline multi-horizon predictions y0: (B, T, N),
    runs an RNN across horizon dimension T for each node independently,
    and outputs y = y0 + delta (residual refinement).
    """
    def __init__(self, out_len, mode="gru", hidden=32, dropout=0.1):
        super().__init__()
        self.out_len = out_len
        self.mode = mode
        self.drop = nn.Dropout(dropout)

        if mode == "gru":
            self.rnn = nn.GRU(input_size=1, hidden_size=hidden, num_layers=1, batch_first=True)
            self.proj = nn.Linear(hidden, 1)

        elif mode == "lstm":
            self.rnn = nn.LSTM(input_size=1, hidden_size=hidden, num_layers=1, batch_first=True)
            self.proj = nn.Linear(hidden, 1)

        elif mode == "gru_lstm":
            self.gru  = nn.GRU(input_size=1, hidden_size=hidden, num_layers=1, batch_first=True)
            self.lstm = nn.LSTM(input_size=hidden, hidden_size=hidden, num_layers=1, batch_first=True)
            self.proj = nn.Linear(hidden, 1)

        else:
            raise ValueError(f"Unknown mode={mode}")

    def forward(self, y0):
        # y0: (B, T, N)
        B, T, Nn = y0.shape
        assert T == self.out_len

        # reshape into (B*N, T, 1)
        seq = y0.permute(0, 2, 1).contiguous().view(B * Nn, T, 1)

        if self.mode in ("gru", "lstm"):
            out, _ = self.rnn(seq)
            out = self.drop(out)
        else:
            out, _ = self.gru(seq)
            out = self.drop(out)
            out, _ = self.lstm(out)
            out = self.drop(out)

        delta = self.proj(out)  # (B*N, T, 1)
        delta = delta.view(B, Nn, T).permute(0, 2, 1).contiguous()  # (B, T, N)

        return y0 + delta

class STGCN_WithHorizonRNN(nn.Module):
    def __init__(self, backbone, rnn_mode="gru", rnn_hidden=32, dropout=0.1):
        super().__init__()
        self.backbone = backbone
        self.head = HorizonRNNRefinement(
            out_len=OUT_LEN,
            mode=rnn_mode,
            hidden=rnn_hidden,
            dropout=dropout
        )

    def forward(self, x, tf):
        y0 = self.backbone(x, tf)      # (B, OUT_LEN, N)
        y  = self.head(y0)             # refined (B, OUT_LEN, N)
        return y


In [11]:
RNN_H = 64   

# STGCN + GRU
stgcn_gru = STGCN_WithHorizonRNN(
    backbone=build_stgcn_backbone(),
    rnn_mode="gru",
    rnn_hidden=RNN_H,
    dropout=0.1
).to(DEVICE)

# STGCN + LSTM
stgcn_lstm = STGCN_WithHorizonRNN(
    backbone=build_stgcn_backbone(),
    rnn_mode="lstm",
    rnn_hidden=RNN_H,
    dropout=0.1
).to(DEVICE)

# STGCN + GRU + LSTM
stgcn_gru_lstm = STGCN_WithHorizonRNN(
    backbone=build_stgcn_backbone(),
    rnn_mode="gru_lstm",
    rnn_hidden=RNN_H,
    dropout=0.1
).to(DEVICE)

# Sanity forward on one batch
xb, yb, tfb = next(iter(train_loader))
with torch.no_grad():
    o1 = stgcn_gru(xb.to(DEVICE), tfb.to(DEVICE))
    o2 = stgcn_lstm(xb.to(DEVICE), tfb.to(DEVICE))
    o3 = stgcn_gru_lstm(xb.to(DEVICE), tfb.to(DEVICE))

print("GRU out:", o1.shape, float(o1.mean()), float(o1.std()))
print("LSTM out:", o2.shape, float(o2.mean()), float(o2.std()))
print("GRU+LSTM out:", o3.shape, float(o3.mean()), float(o3.std()))


GRU out: torch.Size([8, 72, 1821]) 0.02864772267639637 0.5788347721099854
LSTM out: torch.Size([8, 72, 1821]) 0.12592457234859467 0.6218409538269043
GRU+LSTM out: torch.Size([8, 72, 1821]) 0.08967868238687515 0.5888320207595825


In [12]:
#  make station subset selection reproducible 
np.random.seed(42)
torch.manual_seed(42)

run_dir_gru = run_experiment_and_save(
    model_name="STGCN_GRU",
    model=stgcn_gru,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
)

np.random.seed(42)
torch.manual_seed(42)

run_dir_lstm = run_experiment_and_save(
    model_name="STGCN_LSTM",
    model=stgcn_lstm,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
)

np.random.seed(42)
torch.manual_seed(42)

run_dir_gru_lstm = run_experiment_and_save(
    model_name="STGCN_GRU_LSTM",
    model=stgcn_gru_lstm,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12,24,48,72),
    max_stations_excel=300
)

print("Done.")
print("GRU run dir:", run_dir_gru)
print("LSTM run dir:", run_dir_lstm)
print("GRU_LSTM run dir:", run_dir_gru_lstm)


Run dir: artifacts/runs/20260210_205540_STGCN_GRU


Train 1/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 2/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 2: train_loss=0.233463 val_avg_MAE=162.356

VAL
   12h  MAE=132.676  RMSE=274.626
   24h  MAE=145.894  RMSE=296.259
   48h  MAE=169.784  RMSE=326.262
   72h  MAE=201.069  RMSE=372.833


Train 3/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 4/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 4: train_loss=0.192946 val_avg_MAE=157.003

VAL
   12h  MAE=130.237  RMSE=264.035
   24h  MAE=141.342  RMSE=288.684
   48h  MAE=169.147  RMSE=331.933
   72h  MAE=187.287  RMSE=358.725


Train 5/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 6/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 6: train_loss=0.181731 val_avg_MAE=145.411

VAL
   12h  MAE=123.244  RMSE=256.278
   24h  MAE=130.935  RMSE=274.027
   48h  MAE=161.921  RMSE=320.886
   72h  MAE=165.542  RMSE=327.881


Train 7/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 8/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 8: train_loss=0.176248 val_avg_MAE=155.406

VAL
   12h  MAE=122.552  RMSE=253.167
   24h  MAE=145.858  RMSE=298.626
   48h  MAE=168.445  RMSE=332.267
   72h  MAE=184.768  RMSE=365.919


Train 9/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 10/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 10: train_loss=0.172655 val_avg_MAE=147.587

VAL
   12h  MAE=121.398  RMSE=254.440
   24h  MAE=135.759  RMSE=283.503
   48h  MAE=156.392  RMSE=313.810
   72h  MAE=176.799  RMSE=349.246


Train 11/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 12/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 12: train_loss=0.170455 val_avg_MAE=142.511

VAL
   12h  MAE=118.069  RMSE=246.943
   24h  MAE=125.894  RMSE=267.419
   48h  MAE=155.732  RMSE=312.486
   72h  MAE=170.348  RMSE=336.480


Train 13/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 14/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 14: train_loss=0.167309 val_avg_MAE=138.601

VAL
   12h  MAE=115.004  RMSE=243.286
   24h  MAE=126.489  RMSE=268.424
   48h  MAE=151.298  RMSE=303.059
   72h  MAE=161.613  RMSE=319.477


Train 15/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 16/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 16: train_loss=0.165813 val_avg_MAE=140.444

VAL
   12h  MAE=112.601  RMSE=240.774
   24h  MAE=130.684  RMSE=273.027
   48h  MAE=155.929  RMSE=313.941
   72h  MAE=162.563  RMSE=324.785


Train 17/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 18/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 18: train_loss=0.164232 val_avg_MAE=142.060

VAL
   12h  MAE=110.060  RMSE=237.707
   24h  MAE=126.093  RMSE=268.763
   48h  MAE=159.701  RMSE=318.785
   72h  MAE=172.388  RMSE=338.958


Train 19/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 20/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 20: train_loss=0.162565 val_avg_MAE=146.843

VAL
   12h  MAE=119.711  RMSE=246.898
   24h  MAE=134.406  RMSE=273.631
   48h  MAE=155.861  RMSE=311.123
   72h  MAE=177.395  RMSE=350.573


Train 21/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 22/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 22: train_loss=0.161220 val_avg_MAE=145.467

VAL
   12h  MAE=120.149  RMSE=247.475
   24h  MAE=132.424  RMSE=277.767
   48h  MAE=158.536  RMSE=322.430
   72h  MAE=170.758  RMSE=341.814


Train 23/40:   0%|          | 0/127 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Train 36/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 36: train_loss=0.157849 val_avg_MAE=141.897

VAL
   12h  MAE=107.517  RMSE=232.973
   24h  MAE=133.671  RMSE=278.473
   48h  MAE=154.960  RMSE=311.373
   72h  MAE=171.438  RMSE=344.917


Train 37/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 38/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 38: train_loss=0.158196 val_avg_MAE=140.529

VAL
   12h  MAE=114.924  RMSE=242.604
   24h  MAE=127.911  RMSE=269.247
   48h  MAE=153.328  RMSE=312.856
   72h  MAE=165.951  RMSE=333.921


Train 39/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 40/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 40: train_loss=0.157007 val_avg_MAE=140.659

VAL
   12h  MAE=114.564  RMSE=244.315
   24h  MAE=127.082  RMSE=267.980
   48h  MAE=152.742  RMSE=307.844
   72h  MAE=168.250  RMSE=335.281
Saved history: artifacts/runs/20260210_205540_STGCN_GRU/history.csv

Evaluating on TEST set...


Eval:   0%|          | 0/85 [00:00<?, ?it/s]


STGCN_GRU — TEST
   12h  MAE=114.700  RMSE=237.349
   24h  MAE=121.738  RMSE=246.057
   48h  MAE=139.977  RMSE=288.806
   72h  MAE=154.095  RMSE=311.351


Collect preds:   0%|          | 0/85 [00:00<?, ?it/s]


Saved run outputs to: artifacts/runs/20260210_205540_STGCN_GRU
 - best checkpoint: artifacts/runs/20260210_205540_STGCN_GRU/best.pt
 - history: artifacts/runs/20260210_205540_STGCN_GRU/history.csv
 - test metrics: artifacts/runs/20260210_205540_STGCN_GRU/test_metrics.json
 - predictions (npz): artifacts/runs/20260210_205540_STGCN_GRU/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260210_205540_STGCN_GRU/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
Run dir: artifacts/runs/20260210_213135_STGCN_LSTM


Train 1/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 2/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 2: train_loss=0.235438 val_avg_MAE=168.156

VAL
   12h  MAE=136.834  RMSE=282.608
   24h  MAE=156.041  RMSE=312.736
   48h  MAE=186.112  RMSE=347.875
   72h  MAE=193.638  RMSE=366.169


Train 3/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 4/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 4: train_loss=0.191357 val_avg_MAE=156.985

VAL
   12h  MAE=133.067  RMSE=262.449
   24h  MAE=140.911  RMSE=283.656
   48h  MAE=170.360  RMSE=325.826
   72h  MAE=183.601  RMSE=351.599


Train 5/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 6/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 6: train_loss=0.179752 val_avg_MAE=146.464

VAL
   12h  MAE=121.773  RMSE=254.973
   24h  MAE=130.611  RMSE=275.136
   48h  MAE=163.391  RMSE=324.446
   72h  MAE=170.083  RMSE=333.749


Train 7/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 8/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 8: train_loss=0.174459 val_avg_MAE=152.235

VAL
   12h  MAE=122.576  RMSE=254.195
   24h  MAE=145.103  RMSE=294.936
   48h  MAE=166.064  RMSE=327.203
   72h  MAE=175.195  RMSE=343.373


Train 9/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 10/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 10: train_loss=0.171299 val_avg_MAE=147.710

VAL
   12h  MAE=120.365  RMSE=248.124
   24h  MAE=138.321  RMSE=285.497
   48h  MAE=158.271  RMSE=316.608
   72h  MAE=173.882  RMSE=344.818


Train 11/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 12/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 12: train_loss=0.169861 val_avg_MAE=142.583

VAL
   12h  MAE=119.498  RMSE=248.133
   24h  MAE=126.954  RMSE=268.936
   48h  MAE=152.218  RMSE=305.648
   72h  MAE=171.664  RMSE=340.248


Train 13/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 14/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 14: train_loss=0.166665 val_avg_MAE=137.085

VAL
   12h  MAE=116.327  RMSE=246.383
   24h  MAE=126.377  RMSE=269.919
   48h  MAE=147.226  RMSE=295.715
   72h  MAE=158.411  RMSE=314.506


Train 15/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 16/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 16: train_loss=0.165718 val_avg_MAE=140.395

VAL
   12h  MAE=116.427  RMSE=244.096
   24h  MAE=131.026  RMSE=273.229
   48h  MAE=153.786  RMSE=311.070
   72h  MAE=160.340  RMSE=325.649


Train 17/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 18/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 18: train_loss=0.163657 val_avg_MAE=146.375

VAL
   12h  MAE=115.364  RMSE=248.303
   24h  MAE=131.535  RMSE=278.240
   48h  MAE=164.478  RMSE=329.032
   72h  MAE=174.122  RMSE=343.663


Train 19/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 20/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 20: train_loss=0.162586 val_avg_MAE=144.401

VAL
   12h  MAE=119.664  RMSE=244.301
   24h  MAE=130.797  RMSE=268.998
   48h  MAE=154.198  RMSE=306.704
   72h  MAE=172.945  RMSE=342.690


Train 21/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 22/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 22: train_loss=0.160441 val_avg_MAE=145.354

VAL
   12h  MAE=122.735  RMSE=252.082
   24h  MAE=131.826  RMSE=276.570
   48h  MAE=156.517  RMSE=318.339
   72h  MAE=170.337  RMSE=342.086


Train 23/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 24/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 24: train_loss=0.159786 val_avg_MAE=140.623

VAL
   12h  MAE=111.748  RMSE=233.924
   24h  MAE=133.183  RMSE=273.733
   48h  MAE=155.823  RMSE=311.206
   72h  MAE=161.738  RMSE=323.364


Train 25/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 26/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 26: train_loss=0.159669 val_avg_MAE=137.312

VAL
   12h  MAE=114.973  RMSE=243.987
   24h  MAE=122.287  RMSE=260.352
   48h  MAE=148.387  RMSE=301.416
   72h  MAE=163.603  RMSE=328.830

Early stopping. Best val_avg_MAE=137.085
Saved history: artifacts/runs/20260210_213135_STGCN_LSTM/history.csv

Evaluating on TEST set...


Eval:   0%|          | 0/85 [00:00<?, ?it/s]


STGCN_LSTM — TEST
   12h  MAE=119.302  RMSE=241.473
   24h  MAE=124.602  RMSE=255.679
   48h  MAE=139.211  RMSE=284.156
   72h  MAE=152.125  RMSE=302.663


Collect preds:   0%|          | 0/85 [00:00<?, ?it/s]


Saved run outputs to: artifacts/runs/20260210_213135_STGCN_LSTM
 - best checkpoint: artifacts/runs/20260210_213135_STGCN_LSTM/best.pt
 - history: artifacts/runs/20260210_213135_STGCN_LSTM/history.csv
 - test metrics: artifacts/runs/20260210_213135_STGCN_LSTM/test_metrics.json
 - predictions (npz): artifacts/runs/20260210_213135_STGCN_LSTM/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260210_213135_STGCN_LSTM/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
Run dir: artifacts/runs/20260210_215633_STGCN_GRU_LSTM


Train 1/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 2/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 2: train_loss=0.237603 val_avg_MAE=162.906

VAL
   12h  MAE=144.636  RMSE=282.637
   24h  MAE=149.053  RMSE=303.920
   48h  MAE=172.479  RMSE=331.723
   72h  MAE=185.455  RMSE=355.479


Train 3/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 4/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 4: train_loss=0.193425 val_avg_MAE=154.668

VAL
   12h  MAE=132.365  RMSE=270.265
   24h  MAE=136.452  RMSE=283.101
   48h  MAE=163.750  RMSE=323.023
   72h  MAE=186.106  RMSE=360.543


Train 5/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 6/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 6: train_loss=0.183178 val_avg_MAE=148.843

VAL
   12h  MAE=127.880  RMSE=260.625
   24h  MAE=133.700  RMSE=280.580
   48h  MAE=163.274  RMSE=327.813
   72h  MAE=170.518  RMSE=335.863


Train 7/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 8/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 8: train_loss=0.177234 val_avg_MAE=154.520

VAL
   12h  MAE=127.858  RMSE=255.679
   24h  MAE=148.648  RMSE=303.584
   48h  MAE=166.251  RMSE=327.664
   72h  MAE=175.323  RMSE=347.749


Train 9/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 10/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 10: train_loss=0.173290 val_avg_MAE=146.593

VAL
   12h  MAE=122.953  RMSE=256.028
   24h  MAE=139.710  RMSE=291.586
   48h  MAE=154.361  RMSE=310.631
   72h  MAE=169.346  RMSE=334.173


Train 11/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 12/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 12: train_loss=0.171537 val_avg_MAE=143.401

VAL
   12h  MAE=118.147  RMSE=242.184
   24h  MAE=125.730  RMSE=268.846
   48h  MAE=159.139  RMSE=319.977
   72h  MAE=170.588  RMSE=339.899


Train 13/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 14/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 14: train_loss=0.169009 val_avg_MAE=138.996

VAL
   12h  MAE=116.082  RMSE=241.626
   24h  MAE=124.626  RMSE=263.935
   48h  MAE=151.468  RMSE=301.849
   72h  MAE=163.809  RMSE=321.111


Train 15/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 16/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 16: train_loss=0.168379 val_avg_MAE=142.159

VAL
   12h  MAE=117.564  RMSE=250.414
   24h  MAE=132.743  RMSE=278.390
   48h  MAE=155.823  RMSE=316.785
   72h  MAE=162.506  RMSE=326.917


Train 17/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 18/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 18: train_loss=0.165653 val_avg_MAE=143.009

VAL
   12h  MAE=114.471  RMSE=242.480
   24h  MAE=125.567  RMSE=266.700
   48h  MAE=161.444  RMSE=322.430
   72h  MAE=170.555  RMSE=336.321


Train 19/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 20/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 20: train_loss=0.165318 val_avg_MAE=148.338

VAL
   12h  MAE=123.690  RMSE=252.784
   24h  MAE=137.117  RMSE=279.523
   48h  MAE=157.508  RMSE=316.159
   72h  MAE=175.038  RMSE=350.283


Train 21/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 22/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 22: train_loss=0.162441 val_avg_MAE=145.536

VAL
   12h  MAE=125.369  RMSE=255.188
   24h  MAE=130.941  RMSE=275.790
   48h  MAE=156.627  RMSE=318.608
   72h  MAE=169.206  RMSE=338.761


Train 23/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 24/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 24: train_loss=0.161498 val_avg_MAE=140.612

VAL
   12h  MAE=113.087  RMSE=239.323
   24h  MAE=130.494  RMSE=272.679
   48h  MAE=155.995  RMSE=314.110
   72h  MAE=162.871  RMSE=324.030


Train 25/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 26/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 26: train_loss=0.161323 val_avg_MAE=142.125

VAL
   12h  MAE=118.401  RMSE=250.192
   24h  MAE=123.616  RMSE=260.885
   48h  MAE=155.419  RMSE=310.612
   72h  MAE=171.063  RMSE=338.146

Early stopping. Best val_avg_MAE=138.996
Saved history: artifacts/runs/20260210_215633_STGCN_GRU_LSTM/history.csv

Evaluating on TEST set...


Eval:   0%|          | 0/85 [00:00<?, ?it/s]


STGCN_GRU_LSTM — TEST
   12h  MAE=119.174  RMSE=240.949
   24h  MAE=122.301  RMSE=249.547
   48h  MAE=141.759  RMSE=288.465
   72h  MAE=154.342  RMSE=308.134


Collect preds:   0%|          | 0/85 [00:00<?, ?it/s]


Saved run outputs to: artifacts/runs/20260210_215633_STGCN_GRU_LSTM
 - best checkpoint: artifacts/runs/20260210_215633_STGCN_GRU_LSTM/best.pt
 - history: artifacts/runs/20260210_215633_STGCN_GRU_LSTM/history.csv
 - test metrics: artifacts/runs/20260210_215633_STGCN_GRU_LSTM/test_metrics.json
 - predictions (npz): artifacts/runs/20260210_215633_STGCN_GRU_LSTM/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260210_215633_STGCN_GRU_LSTM/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
Done.
GRU run dir: artifacts/runs/20260210_205540_STGCN_GRU
LSTM run dir: artifacts/runs/20260210_213135_STGCN_LSTM
GRU_LSTM run dir: artifacts/runs/20260210_215633_STGCN_GRU_LSTM


In [13]:
import pandas as pd
from pathlib import Path

def xlsx_to_csvs(run_dir):
    run_dir = Path(run_dir)
    xlsx_path = run_dir / "test_pred_true_selected_horizons.xlsx"
    if not xlsx_path.exists():
        print("No xlsx found:", xlsx_path)
        return

    xls = pd.ExcelFile(xlsx_path)
    for sheet in xls.sheet_names:
        df = pd.read_excel(xlsx_path, sheet_name=sheet)
        out = run_dir / f"{sheet}.csv"
        df.to_csv(out, index=False)
    print("Exported CSVs to:", run_dir)


xlsx_to_csvs(run_dir_gru)
xlsx_to_csvs(run_dir_lstm)
xlsx_to_csvs(run_dir_gru_lstm)


Exported CSVs to: artifacts/runs/20260210_205540_STGCN_GRU
Exported CSVs to: artifacts/runs/20260210_213135_STGCN_LSTM
Exported CSVs to: artifacts/runs/20260210_215633_STGCN_GRU_LSTM


In [14]:
import pandas as pd

df = pd.read_csv("artifacts/results_summary.csv")

horizons = [12, 24, 48, 72]
for h in horizons:
    df[f"test_MAE_{h}h"] = pd.to_numeric(df[f"test_MAE_{h}h"], errors="coerce")

df["avg_MAE"] = df[[f"test_MAE_{h}h" for h in horizons]].mean(axis=1)

cols = ["model_name", "avg_MAE"] + [f"test_MAE_{h}h" for h in horizons]
display(df.sort_values("avg_MAE")[cols].head(25))


Unnamed: 0,model_name,avg_MAE,test_MAE_12h,test_MAE_24h,test_MAE_48h,test_MAE_72h
13,RandomForest,119.585777,114.454689,115.199486,123.086647,125.602287
2,GraphWaveNet_GRU_LSTM,130.347263,119.049412,125.123518,135.325328,141.890796
1,GraphWaveNet_LSTM,131.859503,123.367641,125.953695,135.830582,142.286094
5,STGCN,132.266313,119.125651,121.787452,139.610473,148.541676
8,STGCN_GRU,132.610577,114.689314,121.745046,139.941089,154.066858
14,STGCN_GRU,132.627357,114.699978,121.737924,139.976878,154.094648
6,STGCN,132.881293,124.865714,121.199576,136.894008,148.565873
7,STGCN,132.881293,124.865714,121.199576,136.894008,148.565873
0,GraphWaveNet_GRU,133.121911,122.688777,127.520536,138.531665,143.746668
15,STGCN_LSTM,133.809814,119.301886,124.601723,139.211051,152.124597


# Baseline Models

In [36]:
import os
# IMPORTANT: limit CPU thread explosions (helps stop Paperspace kernels from dying)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import json
import gc
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn


In [37]:
!pip -q install scikit-learn joblib

[0m

In [38]:
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.ensemble import RandomForestRegressor
from joblib import Parallel, delayed


In [39]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

DATA_PATH = Path("artifacts/pems_graph_dataset_strict.npz")
assert DATA_PATH.exists(), f"Missing: {DATA_PATH}"

data = np.load(DATA_PATH, allow_pickle=True)

X = data["X"].astype(np.float32)         # (T, N, F)
Y = data["Y"].astype(np.float32)         # (T, N)   raw flow
A = data["A"].astype(np.float32)         # (N, N)
stations = data["stations"]
timestamps = data["timestamps"]

train_starts = data["train_starts"].astype(np.int64)
val_starts   = data["val_starts"].astype(np.int64)
test_starts  = data["test_starts"].astype(np.int64)

IN_LEN  = int(np.array(data["in_len"]).item())
OUT_LEN = int(np.array(data["out_len"]).item())

flow_mean  = data["flow_mean"].astype(np.float32)   # (N,)
flow_std   = data["flow_std"].astype(np.float32)    # (N,)
speed_mean = data["speed_mean"].astype(np.float32)  # (N,)
speed_std  = data["speed_std"].astype(np.float32)   # (N,)

T, N, Fdim = X.shape
print("X:", X.shape, "(T,N,F)")
print("Y:", Y.shape, "(T,N)")
print("IN_LEN/OUT_LEN:", IN_LEN, OUT_LEN)
print("N stations:", N)
print("train/val/test starts:", len(train_starts), len(val_starts), len(test_starts))


Device: cuda
GPU: Quadro P5000
X: (2208, 1821, 6) (T,N,F)
Y: (2208, 1821) (T,N)
IN_LEN/OUT_LEN: 24 72
N stations: 1821
train/val/test starts: 1009 289 673


In [40]:
def time_encoding(dt_index: pd.DatetimeIndex) -> np.ndarray:
    hours = dt_index.hour.values
    dow   = dt_index.dayofweek.values
    hour_sin = np.sin(2*np.pi*hours/24.0)
    hour_cos = np.cos(2*np.pi*hours/24.0)
    dow_sin  = np.sin(2*np.pi*dow/7.0)
    dow_cos  = np.cos(2*np.pi*dow/7.0)
    return np.stack([hour_sin, hour_cos, dow_sin, dow_cos], axis=1).astype(np.float32)

TF_all = time_encoding(pd.to_datetime(timestamps))  # (T,4)
print("TF_all:", TF_all.shape)


TF_all: (2208, 4)


In [41]:
# Avoid divide-by-zero
flow_std  = np.maximum(flow_std,  1e-6).astype(np.float32)
speed_std = np.maximum(speed_std, 1e-6).astype(np.float32)

X_scaled = X.copy()
# assume channel0=flow, channel1=speed (your pipeline)
X_scaled[:, :, 0] = (X_scaled[:, :, 0] - flow_mean[None, :])  / flow_std[None, :]
X_scaled[:, :, 1] = (X_scaled[:, :, 1] - speed_mean[None, :]) / speed_std[None, :]

Y_scaled = (Y - flow_mean[None, :]) / flow_std[None, :]

print("Sanity (scaled) flow mean/std ~ 0/1 on TRAIN-ish slice:")
print("Y_scaled mean/std:", float(Y_scaled.mean()), float(Y_scaled.std()))


Sanity (scaled) flow mean/std ~ 0/1 on TRAIN-ish slice:
Y_scaled mean/std: -781.403564453125 30704.76953125


In [42]:
X_fnt = np.transpose(X_scaled, (2, 1, 0)).copy()  # (F, N, T)

class FastPemsWindowDataset(torch.utils.data.Dataset):
    def __init__(self, X_fnt, Y_scaled, TF_all, starts, in_len, out_len):
        self.X_fnt = X_fnt
        self.Y = Y_scaled
        self.TF = TF_all
        self.starts = np.asarray(starts, dtype=np.int64)
        self.in_len = int(in_len)
        self.out_len = int(out_len)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, i):
        s = int(self.starts[i])
        x = self.X_fnt[:, :, s:s+self.in_len]                    # (F,N,IN)
        y = self.Y[s+self.in_len:s+self.in_len+self.out_len, :]  # (OUT,N)
        tf = self.TF[s+self.in_len:s+self.in_len+self.out_len]   # (OUT,4)
        return torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(tf)

train_ds = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, train_starts, IN_LEN, OUT_LEN)
val_ds   = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, val_starts,   IN_LEN, OUT_LEN)
test_ds  = FastPemsWindowDataset(X_fnt, Y_scaled, TF_all, test_starts,  IN_LEN, OUT_LEN)

BATCH_SIZE = 8
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=False)
val_loader   = torch.utils.data.DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
test_loader  = torch.utils.data.DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

xb, yb, tfb = next(iter(train_loader))
print("Batch x:", xb.shape, "Batch y:", yb.shape, "Batch tf:", tfb.shape)


Batch x: torch.Size([8, 6, 1821, 24]) Batch y: torch.Size([8, 72, 1821]) Batch tf: torch.Size([8, 72, 4])


In [43]:
EVAL_HORIZONS = [12, 24, 48, 72]
H = len(EVAL_HORIZONS)

flow_mean_t = torch.tensor(flow_mean, dtype=torch.float32, device=DEVICE).view(1, 1, -1)
flow_std_t  = torch.tensor(flow_std,  dtype=torch.float32, device=DEVICE).view(1, 1, -1)

def print_metrics(title, metrics):
    print("\n" + title)
    for h in sorted(metrics.keys()):
        print(f"  {h:>3}h  MAE={metrics[h]['MAE']:.3f}  RMSE={metrics[h]['RMSE']:.3f}")

def avg_mae(metrics):
    return float(np.mean([metrics[h]["MAE"] for h in metrics]))

@torch.inference_mode()
def eval_horizons_fast(model, loader):
    model.eval()
    acc = {h: {"abs": 0.0, "sq": 0.0, "count": 0} for h in EVAL_HORIZONS}

    for xb, yb, tfb in tqdm(loader, desc="Eval", leave=False):
        xb  = xb.to(DEVICE, non_blocking=True)
        yb  = yb.to(DEVICE, non_blocking=True)
        tfb = tfb.to(DEVICE, non_blocking=True)

        pred = model(xb, tfb)  # scaled (B,OUT,N)

        pred_u = pred * flow_std_t + flow_mean_t
        true_u = yb   * flow_std_t + flow_mean_t

        for h in EVAL_HORIZONS:
            idx = h - 1
            err = pred_u[:, idx, :] - true_u[:, idx, :]
            acc[h]["abs"]   += float(err.abs().sum())
            acc[h]["sq"]    += float((err**2).sum())
            acc[h]["count"] += err.numel()

    metrics = {}
    for h in EVAL_HORIZONS:
        mae = acc[h]["abs"] / acc[h]["count"]
        rmse = (acc[h]["sq"] / acc[h]["count"]) ** 0.5
        metrics[h] = {"MAE": float(mae), "RMSE": float(rmse)}
    return metrics

def make_run_dir(model_name: str) -> Path:
    ts = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path("artifacts/runs") / f"{ts}_{model_name}"
    run_dir.mkdir(parents=True, exist_ok=False)
    return run_dir

def save_metrics_files(run_dir: Path, split_name: str, metrics: dict):
    (run_dir / f"{split_name}_metrics.json").write_text(json.dumps(metrics, indent=2))
    rows = []
    for h in sorted(metrics.keys()):
        rows.append({"horizon": h, "MAE": metrics[h]["MAE"], "RMSE": metrics[h]["RMSE"]})
    pd.DataFrame(rows).to_csv(run_dir / f"{split_name}_metrics.csv", index=False)

def append_results_summary(model_name: str, run_dir: Path, test_metrics: dict):
    summary_path = Path("artifacts/results_summary.csv")
    row = {
        "timestamp": pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
        "model_name": model_name,
        "run_dir": str(run_dir),
    }
    for h in EVAL_HORIZONS:
        row[f"test_MAE_{h}h"] = test_metrics[h]["MAE"]
        row[f"test_RMSE_{h}h"] = test_metrics[h]["RMSE"]

    df_new = pd.DataFrame([row])
    if summary_path.exists():
        df_old = pd.read_csv(summary_path)
        df = pd.concat([df_old, df_new], ignore_index=True)
    else:
        df = df_new
    df.to_csv(summary_path, index=False)
    return summary_path

def save_preds_npz_and_csv_subset(
    run_dir: Path,
    pred_u: np.ndarray,   # (S,N,H)
    true_u: np.ndarray,   # (S,N,H)
    starts: np.ndarray,
    max_stations_csv: int = 300,
):
    # NPZ (full)
    np.savez_compressed(
        run_dir / "test_pred_true_selected_horizons.npz",
        pred=pred_u.astype(np.float32),
        true=true_u.astype(np.float32),
        horizons=np.array(EVAL_HORIZONS, dtype=np.int64),
        starts=starts.astype(np.int64),
        stations=stations,
        timestamps=timestamps
    )

    # CSV subset (manageable)
    K = min(max_stations_csv, N)
    keep = np.arange(K)

    frames = []
    for j, h in enumerate(EVAL_HORIZONS):
        idx = starts + IN_LEN + (h - 1)
        ts_h = pd.to_datetime(timestamps[idx])

        df_h = pd.DataFrame({
            "start_idx": np.repeat(starts, K),
            "timestamp": np.repeat(ts_h, K),
            "station": np.tile(np.array(stations)[keep], len(starts)),
            "horizon_h": h,
            "y_true": true_u[:, keep, j].reshape(-1),
            "y_pred": pred_u[:, keep, j].reshape(-1),
        })
        frames.append(df_h)

    df_out = pd.concat(frames, ignore_index=True)
    df_out.to_csv(run_dir / "test_pred_true_selected_horizons.csv", index=False)
    return run_dir / "test_pred_true_selected_horizons.csv"


## LSTM MODEL

In [44]:
class LSTM_Baseline(nn.Module):
    def __init__(self, in_dim: int, out_len: int, hidden: int = 64, layers: int = 1, dropout: float = 0.1, tf_dim: int = 4):
        super().__init__()
        self.out_len = out_len
        self.hidden = hidden
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden, num_layers=layers, dropout=(dropout if layers > 1 else 0.0), batch_first=True)
        self.head = nn.Linear(hidden + tf_dim, 1)

    def forward(self, x, tf):
        # x: (B,F,N,IN) -> (B,N,IN,F) -> (B*N, IN, F)
        B, F, Nn, INL = x.shape
        x_seq = x.permute(0, 2, 3, 1).contiguous().view(B * Nn, INL, F)

        out, (h, c) = self.lstm(x_seq)
        h_last = h[-1]  # (B*N, hidden)

        # tf: (B, OUT, 4) -> repeat per node -> (B*N, OUT, 4)
        tf_rep = tf.unsqueeze(1).expand(B, Nn, self.out_len, tf.shape[-1]).contiguous().view(B * Nn, self.out_len, tf.shape[-1])

        h_rep = h_last.unsqueeze(1).expand(B * Nn, self.out_len, self.hidden)
        z = torch.cat([h_rep, tf_rep], dim=-1)           # (B*N, OUT, hidden+4)
        y = self.head(z).squeeze(-1)                    # (B*N, OUT)
        y = y.view(B, Nn, self.out_len).permute(0, 2, 1) # (B, OUT, N)
        return y

def train_torch_and_save(model_name: str, model: nn.Module, epochs=40, lr=1e-3, weight_decay=1e-4, clip=5.0, patience=6, eval_every=2):
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    model = model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.SmoothL1Loss(beta=1.0)

    best = float("inf")
    bad = 0
    best_state = None
    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0

        for xb, yb, tfb in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=False):
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            tfb = tfb.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            pred = model(xb, tfb)
            loss = loss_fn(pred, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()
            running += float(loss.item())

        if epoch % eval_every == 0:
            val_m = eval_horizons_fast(model, val_loader)
            score = avg_mae(val_m)
            print(f"\nEpoch {epoch}: train_loss={running/len(train_loader):.6f}  val_avg_MAE={score:.3f}")
            print_metrics("VAL", val_m)

            history.append({"epoch": epoch, "train_loss": running/len(train_loader), "val_avg_MAE": score, **{f"val_MAE_{h}h": val_m[h]["MAE"] for h in EVAL_HORIZONS}})

            if score < best:
                best = score
                bad = 0
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            else:
                bad += 1
                if bad >= patience:
                    print(f"\nEarly stopping. Best val_avg_MAE={best:.3f}")
                    break

    # save history
    if len(history) > 0:
        pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)

    # load best + save checkpoint
    if best_state is not None:
        model.load_state_dict(best_state)
    torch.save(model.state_dict(), run_dir / "best.pt")

    # TEST
    print("\nEvaluating on TEST set...")
    test_m = eval_horizons_fast(model, test_loader)
    print_metrics(f"{model_name} — TEST", test_m)

    save_metrics_files(run_dir, "test", test_m)

    # Collect & save preds (selected horizons only)
    S = len(test_starts)
    pred_u = np.zeros((S, N, len(EVAL_HORIZONS)), dtype=np.float32)
    true_u = np.zeros((S, N, len(EVAL_HORIZONS)), dtype=np.float32)

    model.eval()
    pos = 0
    with torch.inference_mode():
        for xb, yb, tfb in tqdm(test_loader, desc="Collect preds", leave=False):
            bsz = xb.shape[0]
            xb  = xb.to(DEVICE)
            yb  = yb.to(DEVICE)
            tfb = tfb.to(DEVICE)

            pred = model(xb, tfb)  # scaled (B,OUT,N)
            pred_u_b = (pred * flow_std_t + flow_mean_t)  # (B,OUT,N)
            true_u_b = (yb   * flow_std_t + flow_mean_t)

            for j, h in enumerate(EVAL_HORIZONS):
                idx = h - 1
                pred_u[pos:pos+bsz, :, j] = pred_u_b[:, idx, :].detach().cpu().numpy()
                true_u[pos:pos+bsz, :, j] = true_u_b[:, idx, :].detach().cpu().numpy()

            pos += bsz

    csv_path = save_preds_npz_and_csv_subset(run_dir, pred_u, true_u, test_starts, max_stations_csv=300)
    summary_path = append_results_summary(model_name, run_dir, test_m)

    print("\nSaved run outputs to:", run_dir)
    print(" - best checkpoint:", run_dir / "best.pt")
    print(" - history:", run_dir / "history.csv")
    print(" - test metrics:", run_dir / "test_metrics.json")
    print(" - preds npz:", run_dir / "test_pred_true_selected_horizons.npz")
    print(" - preds csv:", csv_path)
    print(" - master summary:", summary_path)

    return run_dir


In [45]:
lstm_base = LSTM_Baseline(in_dim=Fdim, out_len=OUT_LEN, hidden=64, layers=1, dropout=0.1).to(DEVICE)
run_dir_lstm = train_torch_and_save("LSTM", lstm_base, epochs=40, patience=6, eval_every=2)
run_dir_lstm

Run dir: artifacts/runs/20260210_194456_LSTM


                                                             


Epoch 2: train_loss=0.320809  val_avg_MAE=382.517

VAL
   12h  MAE=379.311  RMSE=625.654
   24h  MAE=381.042  RMSE=630.496
   48h  MAE=386.313  RMSE=638.276
   72h  MAE=383.400  RMSE=630.289


                                                             


Epoch 4: train_loss=0.236424  val_avg_MAE=312.220

VAL
   12h  MAE=314.541  RMSE=533.124
   24h  MAE=310.456  RMSE=528.506
   48h  MAE=313.938  RMSE=532.925
   72h  MAE=309.946  RMSE=521.925


                                                             


Epoch 6: train_loss=0.191218  val_avg_MAE=268.685

VAL
   12h  MAE=270.014  RMSE=471.836
   24h  MAE=267.554  RMSE=467.962
   48h  MAE=270.475  RMSE=471.248
   72h  MAE=266.697  RMSE=459.570


                                                             


Epoch 8: train_loss=0.169478  val_avg_MAE=243.285

VAL
   12h  MAE=241.783  RMSE=433.174
   24h  MAE=242.627  RMSE=433.923
   48h  MAE=246.278  RMSE=437.987
   72h  MAE=242.451  RMSE=425.507


                                                              


Epoch 10: train_loss=0.159286  val_avg_MAE=228.872

VAL
   12h  MAE=225.890  RMSE=411.183
   24h  MAE=228.787  RMSE=414.217
   48h  MAE=231.621  RMSE=416.383
   72h  MAE=229.188  RMSE=406.236


                                                              


Epoch 12: train_loss=0.155571  val_avg_MAE=219.229

VAL
   12h  MAE=217.691  RMSE=402.907
   24h  MAE=218.839  RMSE=404.333
   48h  MAE=221.824  RMSE=406.237
   72h  MAE=218.563  RMSE=393.787


                                                              


Epoch 14: train_loss=0.154543  val_avg_MAE=217.548

VAL
   12h  MAE=215.619  RMSE=398.803
   24h  MAE=217.312  RMSE=401.362
   48h  MAE=220.151  RMSE=402.976
   72h  MAE=217.110  RMSE=391.180


                                                              


Epoch 16: train_loss=0.154310  val_avg_MAE=215.735

VAL
   12h  MAE=214.697  RMSE=396.537
   24h  MAE=214.984  RMSE=396.953
   48h  MAE=217.842  RMSE=399.114
   72h  MAE=215.418  RMSE=388.827


                                                              


Epoch 18: train_loss=0.153910  val_avg_MAE=216.894

VAL
   12h  MAE=214.726  RMSE=396.012
   24h  MAE=216.812  RMSE=398.659
   48h  MAE=219.204  RMSE=399.618
   72h  MAE=216.832  RMSE=389.739


                                                              


Epoch 20: train_loss=0.153893  val_avg_MAE=214.605

VAL
   12h  MAE=210.363  RMSE=393.016
   24h  MAE=215.191  RMSE=399.678
   48h  MAE=218.016  RMSE=401.261
   72h  MAE=214.850  RMSE=389.219


                                                              


Epoch 22: train_loss=0.153931  val_avg_MAE=217.337

VAL
   12h  MAE=214.923  RMSE=395.916
   24h  MAE=216.774  RMSE=398.869
   48h  MAE=220.158  RMSE=401.589
   72h  MAE=217.494  RMSE=391.058


                                                              


Epoch 24: train_loss=0.154166  val_avg_MAE=216.516

VAL
   12h  MAE=213.300  RMSE=394.374
   24h  MAE=216.472  RMSE=398.559
   48h  MAE=219.286  RMSE=400.382
   72h  MAE=217.006  RMSE=390.458


                                                              


Epoch 26: train_loss=0.154025  val_avg_MAE=212.017

VAL
   12h  MAE=211.293  RMSE=394.835
   24h  MAE=211.839  RMSE=396.043
   48h  MAE=214.018  RMSE=396.692
   72h  MAE=210.917  RMSE=384.676


                                                              


Epoch 28: train_loss=0.153823  val_avg_MAE=220.093

VAL
   12h  MAE=216.945  RMSE=395.549
   24h  MAE=219.505  RMSE=399.129
   48h  MAE=222.564  RMSE=401.864
   72h  MAE=221.357  RMSE=394.701


                                                              


Epoch 30: train_loss=0.153642  val_avg_MAE=220.669

VAL
   12h  MAE=217.018  RMSE=397.284
   24h  MAE=221.059  RMSE=402.909
   48h  MAE=223.714  RMSE=404.069
   72h  MAE=220.886  RMSE=393.742


                                                              


Epoch 32: train_loss=0.153735  val_avg_MAE=219.903

VAL
   12h  MAE=217.648  RMSE=399.058
   24h  MAE=219.960  RMSE=402.868
   48h  MAE=222.741  RMSE=404.134
   72h  MAE=219.262  RMSE=392.515


                                                              


Epoch 34: train_loss=0.153863  val_avg_MAE=218.836

VAL
   12h  MAE=216.376  RMSE=398.738
   24h  MAE=219.371  RMSE=403.235
   48h  MAE=221.700  RMSE=403.637
   72h  MAE=217.898  RMSE=391.222


                                                              


Epoch 36: train_loss=0.153885  val_avg_MAE=216.502

VAL
   12h  MAE=213.170  RMSE=392.736
   24h  MAE=215.992  RMSE=397.322
   48h  MAE=219.361  RMSE=400.568
   72h  MAE=217.483  RMSE=391.558


                                                              


Epoch 38: train_loss=0.153954  val_avg_MAE=215.830

VAL
   12h  MAE=212.955  RMSE=393.432
   24h  MAE=215.556  RMSE=397.488
   48h  MAE=218.603  RMSE=400.005
   72h  MAE=216.204  RMSE=389.975

Early stopping. Best val_avg_MAE=212.017

Evaluating on TEST set...


                                                     


LSTM — TEST
   12h  MAE=217.363  RMSE=403.856
   24h  MAE=212.865  RMSE=399.008
   48h  MAE=214.150  RMSE=399.032
   72h  MAE=212.684  RMSE=392.795


                                                              


Saved run outputs to: artifacts/runs/20260210_194456_LSTM
 - best checkpoint: artifacts/runs/20260210_194456_LSTM/best.pt
 - history: artifacts/runs/20260210_194456_LSTM/history.csv
 - test metrics: artifacts/runs/20260210_194456_LSTM/test_metrics.json
 - preds npz: artifacts/runs/20260210_194456_LSTM/test_pred_true_selected_horizons.npz
 - preds csv: artifacts/runs/20260210_194456_LSTM/test_pred_true_selected_horizons.csv
 - master summary: artifacts/results_summary.csv


PosixPath('artifacts/runs/20260210_194456_LSTM')

## Elastic Net Linear Regression 

In [46]:
H_OFF = np.array([h - 1 for h in EVAL_HORIZONS], dtype=np.int64)

def node_features_and_targets(node: int, starts: np.ndarray):
    # past window features (scaled) for this node
    Xn = X_scaled[:, node, :]  # (T,F)
    win = np.lib.stride_tricks.sliding_window_view(Xn, window_shape=IN_LEN, axis=0)  # (T-IN_LEN+1, IN_LEN, F)
    X_hist = win[starts].reshape(len(starts), -1)  # (S, IN_LEN*F)

    # future time features for all horizons (S,H,4) -> (S, H*4)
    idx = starts[:, None] + IN_LEN + H_OFF[None, :]
    X_tf = TF_all[idx].reshape(len(starts), -1)

    X_feat = np.concatenate([X_hist, X_tf], axis=1)

    # targets (scaled) (S,H)
    y = Y_scaled[idx, node]
    return X_feat, y.astype(np.float32)

def run_elasticnet_baseline(alpha=1e-3, l1_ratio=0.5, jobs=4):
    model_name = "ElasticNet"
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    S_test = len(test_starts)
    pred_scaled = np.zeros((S_test, N, H), dtype=np.float32)
    true_scaled = np.zeros((S_test, N, H), dtype=np.float32)

    def fit_one(node: int):
        Xtr, ytr = node_features_and_targets(node, train_starts)
        Xte, yte = node_features_and_targets(node, test_starts)

        # Elastic net + scaling
        mdl = make_pipeline(
            StandardScaler(),
            MultiTaskElasticNet(alpha=alpha, l1_ratio=l1_ratio, max_iter=5000, random_state=42)
        )
        mdl.fit(Xtr, ytr)
        pred = mdl.predict(Xte).astype(np.float32)  # (S_test,H)
        return node, pred, yte

    nodes = list(range(N))
    results = Parallel(n_jobs=jobs, prefer="threads")(
        delayed(fit_one)(node) for node in tqdm(nodes, desc="ElasticNet per-node")
    )

    for node, pred, yte in results:
        pred_scaled[:, node, :] = pred
        true_scaled[:, node, :] = yte

    # unscale
    pred_u = pred_scaled * flow_std[None, :, None] + flow_mean[None, :, None]
    true_u = true_scaled * flow_std[None, :, None] + flow_mean[None, :, None]

    # metrics
    metrics = {}
    for j, h in enumerate(EVAL_HORIZONS):
        err = pred_u[:, :, j] - true_u[:, :, j]
        metrics[h] = {
            "MAE": float(np.abs(err).mean()),
            "RMSE": float(np.sqrt((err**2).mean())),
        }

    print_metrics("ElasticNet — TEST", metrics)
    save_metrics_files(run_dir, "test", metrics)

    csv_path = save_preds_npz_and_csv_subset(run_dir, pred_u, true_u, test_starts, max_stations_csv=300)
    summary_path = append_results_summary(model_name, run_dir, metrics)

    print("\nSaved:", run_dir)
    print(" - test metrics:", run_dir / "test_metrics.json")
    print(" - preds npz:", run_dir / "test_pred_true_selected_horizons.npz")
    print(" - preds csv:", csv_path)
    print(" - master summary:", summary_path)

    return run_dir


In [47]:
run_dir_en = run_elasticnet_baseline(alpha=1e-3, l1_ratio=0.5, jobs=4)
run_dir_en

Run dir: artifacts/runs/20260210_195631_ElasticNet


  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(
  ) = cd_fast.enet_coordinate_descent_multi_task(



ElasticNet — TEST
   12h  MAE=151.053  RMSE=306.839
   24h  MAE=148.707  RMSE=304.068
   48h  MAE=149.399  RMSE=305.192
   72h  MAE=141.125  RMSE=295.535

Saved: artifacts/runs/20260210_195631_ElasticNet
 - test metrics: artifacts/runs/20260210_195631_ElasticNet/test_metrics.json
 - preds npz: artifacts/runs/20260210_195631_ElasticNet/test_pred_true_selected_horizons.npz
 - preds csv: artifacts/runs/20260210_195631_ElasticNet/test_pred_true_selected_horizons.csv
 - master summary: artifacts/results_summary.csv


PosixPath('artifacts/runs/20260210_195631_ElasticNet')

## Random Forest 

In [49]:
def run_random_forest_baseline(
    n_estimators=50,
    max_depth=20,
    min_samples_leaf=5,
    max_features="sqrt",
    jobs=4
):
    model_name = "RandomForest"
    run_dir = make_run_dir(model_name)
    print("Run dir:", run_dir)

    S_test = len(test_starts)
    pred_scaled = np.zeros((S_test, N, H), dtype=np.float32)
    true_scaled = np.zeros((S_test, N, H), dtype=np.float32)

    def fit_one(node: int):
        Xtr, ytr = node_features_and_targets(node, train_starts)
        Xte, yte = node_features_and_targets(node, test_starts)

        mdl = RandomForestRegressor(
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_leaf=min_samples_leaf,
            max_features=max_features,
            n_jobs=1,             # IMPORTANT: keep 1; we parallelize across nodes
            random_state=42
        )
        mdl.fit(Xtr, ytr)
        pred = mdl.predict(Xte).astype(np.float32)
        return node, pred, yte

    nodes = list(range(N))
    results = Parallel(n_jobs=jobs, prefer="threads")(
        delayed(fit_one)(node) for node in tqdm(nodes, desc="RF per-node")
    )

    for node, pred, yte in results:
        pred_scaled[:, node, :] = pred
        true_scaled[:, node, :] = yte

    pred_u = pred_scaled * flow_std[None, :, None] + flow_mean[None, :, None]
    true_u = true_scaled * flow_std[None, :, None] + flow_mean[None, :, None]

    metrics = {}
    for j, h in enumerate(EVAL_HORIZONS):
        err = pred_u[:, :, j] - true_u[:, :, j]
        metrics[h] = {
            "MAE": float(np.abs(err).mean()),
            "RMSE": float(np.sqrt((err**2).mean())),
        }

    print_metrics("RandomForest — TEST", metrics)
    save_metrics_files(run_dir, "test", metrics)

    csv_path = save_preds_npz_and_csv_subset(run_dir, pred_u, true_u, test_starts, max_stations_csv=300)
    summary_path = append_results_summary(model_name, run_dir, metrics)

    print("\nSaved:", run_dir)
    print(" - test metrics:", run_dir / "test_metrics.json")
    print(" - preds npz:", run_dir / "test_pred_true_selected_horizons.npz")
    print(" - preds csv:", csv_path)
    print(" - master summary:", summary_path)

    return run_dir


In [50]:
run_dir_rf = run_random_forest_baseline(n_estimators=50, max_depth=20, min_samples_leaf=5, jobs=4)
run_dir_rf


Run dir: artifacts/runs/20260210_200319_RandomForest


RF per-node: 100%|██████████| 1821/1821 [03:09<00:00,  9.63it/s]



RandomForest — TEST
   12h  MAE=114.455  RMSE=259.444
   24h  MAE=115.199  RMSE=263.879
   48h  MAE=123.087  RMSE=279.939
   72h  MAE=125.602  RMSE=284.720

Saved: artifacts/runs/20260210_200319_RandomForest
 - test metrics: artifacts/runs/20260210_200319_RandomForest/test_metrics.json
 - preds npz: artifacts/runs/20260210_200319_RandomForest/test_pred_true_selected_horizons.npz
 - preds csv: artifacts/runs/20260210_200319_RandomForest/test_pred_true_selected_horizons.csv
 - master summary: artifacts/results_summary.csv


PosixPath('artifacts/runs/20260210_200319_RandomForest')

In [51]:
df = pd.read_csv("artifacts/results_summary.csv")

horizons = [12, 24, 48, 72]
for h in horizons:
    df[f"test_MAE_{h}h"] = pd.to_numeric(df[f"test_MAE_{h}h"], errors="coerce")

df["avg_MAE"] = df[[f"test_MAE_{h}h" for h in horizons]].mean(axis=1)

cols = ["timestamp","model_name","avg_MAE"] + [f"test_MAE_{h}h" for h in horizons]
display(df.sort_values("avg_MAE")[cols].head(50))


Unnamed: 0,timestamp,model_name,avg_MAE,test_MAE_12h,test_MAE_24h,test_MAE_48h,test_MAE_72h
13,2026-02-10 20:06:36,RandomForest,119.585777,114.454689,115.199486,123.086647,125.602287
2,,GraphWaveNet_GRU_LSTM,130.347263,119.049412,125.123518,135.325328,141.890796
1,,GraphWaveNet_LSTM,131.859503,123.367641,125.953695,135.830582,142.286094
5,,STGCN,132.266313,119.125651,121.787452,139.610473,148.541676
8,,STGCN_GRU,132.610577,114.689314,121.745046,139.941089,154.066858
6,,STGCN,132.881293,124.865714,121.199576,136.894008,148.565873
7,,STGCN,132.881293,124.865714,121.199576,136.894008,148.565873
0,,GraphWaveNet_GRU,133.121911,122.688777,127.520536,138.531665,143.746668
9,,STGCN_LSTM,133.812095,119.279658,124.610721,139.245704,152.112297
10,,STGCN_GRU_LSTM,135.100299,118.894072,123.719928,142.578867,155.208328


## CNN-GRU-LSTM (Literature Model)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN_GRU_LSTM_MultiHorizon(nn.Module):
    """
    Per-node temporal model:
      (Conv1D over time) -> GRU -> LSTM -> Linear(out_len)
    Works on x: (B, F, N, IN_LEN)  and tf: (B, OUT_LEN, tf_dim)
    Returns: (B, OUT_LEN, N)

    node_chunk: splits nodes to avoid OOM when N is large (like 1821).
    """
    def __init__(
        self,
        in_dim: int,
        out_len: int,
        tf_dim: int = 4,
        conv_channels: int = 32,
        conv_kernel: int = 3,
        gru_hidden: int = 64,
        lstm_hidden: int = 64,
        dropout: float = 0.1,
        use_time_bias: bool = True,
        node_chunk: int = 256,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_len = out_len
        self.tf_dim = tf_dim
        self.use_time_bias = use_time_bias
        self.node_chunk = node_chunk

        pad = conv_kernel // 2

        # Temporal CNN encoder (per node)
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels=in_dim, out_channels=conv_channels, kernel_size=conv_kernel, padding=pad),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(in_channels=conv_channels, out_channels=conv_channels, kernel_size=conv_kernel, padding=pad),
            nn.ReLU(),
        )

        # Stacked RNN encoder (per node)
        self.gru = nn.GRU(
            input_size=conv_channels,
            hidden_size=gru_hidden,
            num_layers=1,
            batch_first=True,
        )
        self.lstm = nn.LSTM(
            input_size=gru_hidden,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
        )

        # Map final hidden -> all horizons directly
        self.h_to_out = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lstm_hidden, out_len),
        )

        # Optional: horizon/time bias from tf (same for all nodes in a sample)
        if use_time_bias:
            self.tf_to_bias = nn.Sequential(
                nn.Linear(tf_dim, 32),
                nn.ReLU(),
                nn.Linear(32, 1),   # per-horizon scalar
            )

    def forward(self, x, tf):
        """
        x:  (B, F, N, IN_LEN)
        tf: (B, OUT_LEN, tf_dim)
        """
        B, Fdim, N, IN_LEN = x.shape
        device = x.device
        dtype = x.dtype

        # time bias: (B, OUT_LEN) — computed once, broadcast across nodes
        if self.use_time_bias:
            tf = tf.to(device=device, dtype=dtype)
            time_bias = self.tf_to_bias(tf).squeeze(-1)  # (B, OUT_LEN)
        else:
            time_bias = None

        out = torch.empty((B, self.out_len, N), device=device, dtype=dtype)

        # process nodes in chunks to avoid OOM
        for s in range(0, N, self.node_chunk):
            e = min(N, s + self.node_chunk)
            Nc = e - s

            # x_chunk: (B, F, Nc, IN_LEN)
            x_chunk = x[:, :, s:e, :]

            # reshape to per-node sequences: (B*Nc, F, IN_LEN)
            x_seq = x_chunk.permute(0, 2, 1, 3).contiguous().view(B * Nc, Fdim, IN_LEN)

            # CNN over time -> (B*Nc, C, IN_LEN)
            z = self.cnn(x_seq)

            # RNN expects (B*Nc, IN_LEN, C)
            z = z.transpose(1, 2).contiguous()

            z, _ = self.gru(z)                 # (B*Nc, IN_LEN, gru_hidden)
            z, (h, c) = self.lstm(z)           # h: (1, B*Nc, lstm_hidden)
            h_last = h[-1]                     # (B*Nc, lstm_hidden)

            pred = self.h_to_out(h_last)       # (B*Nc, OUT_LEN)

            # reshape to (B, OUT_LEN, Nc)
            pred = pred.view(B, Nc, self.out_len).permute(0, 2, 1).contiguous()

            if time_bias is not None:
                pred = pred + time_bias.unsqueeze(-1)  # broadcast over nodes

            out[:, :, s:e] = pred

        return out


In [16]:
# Make the model
cnn_gru_lstm = CNN_GRU_LSTM_MultiHorizon(
    in_dim=Fdim,
    out_len=OUT_LEN,
    tf_dim=4,
    conv_channels=32,
    conv_kernel=3,
    gru_hidden=64,
    lstm_hidden=64,
    dropout=0.1,
    use_time_bias=True,
    node_chunk=256,   # if you still OOM, drop to 128
).to(DEVICE)

# Sanity forward
xb, yb, tfb = next(iter(train_loader))
with torch.no_grad():
    out = cnn_gru_lstm(xb.to(DEVICE), tfb.to(DEVICE))
print("Forward:", out.shape, "mean/std:", float(out.mean()), float(out.std()))

# Train + save (same as you did for STGCN/GWN)
run_dir = run_experiment_and_save(
    model_name="CNN_GRU_LSTM",
    model=cnn_gru_lstm,
    epochs=40,
    patience=6,
    eval_every=2,
    horizons_to_save=(12, 24, 48, 72),
    max_stations_excel=300
)

print("Saved to:", run_dir)


Forward: torch.Size([8, 72, 1821]) mean/std: -0.10008469969034195 0.21270188689231873
Run dir: artifacts/runs/20260210_223029_CNN_GRU_LSTM


Train 1/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 2/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 2: train_loss=0.279704 val_avg_MAE=187.348

VAL
   12h  MAE=183.550  RMSE=333.511
   24h  MAE=184.685  RMSE=343.514
   48h  MAE=187.308  RMSE=349.380
   72h  MAE=193.849  RMSE=355.000


Train 3/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 4/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 4: train_loss=0.245881 val_avg_MAE=178.873

VAL
   12h  MAE=160.190  RMSE=302.606
   24h  MAE=174.787  RMSE=322.889
   48h  MAE=187.620  RMSE=342.224
   72h  MAE=192.896  RMSE=350.863


Train 5/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 6/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 6: train_loss=0.225988 val_avg_MAE=175.368

VAL
   12h  MAE=164.786  RMSE=310.027
   24h  MAE=164.573  RMSE=307.617
   48h  MAE=182.529  RMSE=337.886
   72h  MAE=189.584  RMSE=348.248


Train 7/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 8/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 8: train_loss=0.211273 val_avg_MAE=165.319

VAL
   12h  MAE=149.580  RMSE=286.344
   24h  MAE=157.797  RMSE=301.509
   48h  MAE=174.591  RMSE=329.848
   72h  MAE=179.308  RMSE=336.321


Train 9/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 10/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 10: train_loss=0.202945 val_avg_MAE=163.730

VAL
   12h  MAE=151.945  RMSE=292.227
   24h  MAE=153.064  RMSE=295.175
   48h  MAE=171.462  RMSE=326.036
   72h  MAE=178.449  RMSE=335.917


Train 11/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 12/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 12: train_loss=0.198216 val_avg_MAE=155.706

VAL
   12h  MAE=136.625  RMSE=266.437
   24h  MAE=144.763  RMSE=286.044
   48h  MAE=167.613  RMSE=320.622
   72h  MAE=173.824  RMSE=329.629


Train 13/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 14/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 14: train_loss=0.192930 val_avg_MAE=155.351

VAL
   12h  MAE=141.359  RMSE=274.892
   24h  MAE=143.427  RMSE=283.277
   48h  MAE=165.045  RMSE=319.274
   72h  MAE=171.573  RMSE=327.310


Train 15/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 16/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 16: train_loss=0.190082 val_avg_MAE=151.296

VAL
   12h  MAE=136.104  RMSE=268.312
   24h  MAE=140.233  RMSE=278.416
   48h  MAE=160.874  RMSE=311.844
   72h  MAE=167.973  RMSE=322.489


Train 17/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 18/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 18: train_loss=0.185419 val_avg_MAE=148.935

VAL
   12h  MAE=131.536  RMSE=258.236
   24h  MAE=137.515  RMSE=274.154
   48h  MAE=160.364  RMSE=310.237
   72h  MAE=166.323  RMSE=318.238


Train 19/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 20/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 20: train_loss=0.183914 val_avg_MAE=150.846

VAL
   12h  MAE=131.105  RMSE=259.312
   24h  MAE=140.069  RMSE=278.347
   48h  MAE=164.003  RMSE=314.785
   72h  MAE=168.208  RMSE=320.337


Train 21/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 22/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 22: train_loss=0.180233 val_avg_MAE=152.782

VAL
   12h  MAE=129.453  RMSE=251.578
   24h  MAE=138.289  RMSE=276.005
   48h  MAE=168.168  RMSE=323.403
   72h  MAE=175.219  RMSE=333.212


Train 23/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 24/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 24: train_loss=0.178551 val_avg_MAE=146.026

VAL
   12h  MAE=133.525  RMSE=262.309
   24h  MAE=136.011  RMSE=271.153
   48h  MAE=154.218  RMSE=300.318
   72h  MAE=160.350  RMSE=309.459


Train 25/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 26/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 26: train_loss=0.177030 val_avg_MAE=148.570

VAL
   12h  MAE=135.140  RMSE=266.850
   24h  MAE=135.999  RMSE=273.156
   48h  MAE=158.709  RMSE=309.698
   72h  MAE=164.434  RMSE=319.642


Train 27/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 28/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 28: train_loss=0.175294 val_avg_MAE=141.994

VAL
   12h  MAE=128.672  RMSE=253.429
   24h  MAE=130.524  RMSE=265.118
   48h  MAE=151.308  RMSE=297.043
   72h  MAE=157.473  RMSE=305.931


Train 29/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 30/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 30: train_loss=0.174034 val_avg_MAE=145.617

VAL
   12h  MAE=129.858  RMSE=257.859
   24h  MAE=132.704  RMSE=268.412
   48h  MAE=155.743  RMSE=305.075
   72h  MAE=164.161  RMSE=320.209


Train 31/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 32/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 32: train_loss=0.172956 val_avg_MAE=145.367

VAL
   12h  MAE=129.445  RMSE=256.178
   24h  MAE=132.911  RMSE=270.741
   48h  MAE=156.525  RMSE=307.543
   72h  MAE=162.586  RMSE=316.788


Train 33/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 34/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 34: train_loss=0.171843 val_avg_MAE=143.985

VAL
   12h  MAE=120.119  RMSE=239.259
   24h  MAE=134.027  RMSE=273.253
   48h  MAE=158.614  RMSE=312.432
   72h  MAE=163.180  RMSE=316.732


Train 35/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 36/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 36: train_loss=0.171044 val_avg_MAE=147.171

VAL
   12h  MAE=130.143  RMSE=258.121
   24h  MAE=133.319  RMSE=270.872
   48h  MAE=157.858  RMSE=310.634
   72h  MAE=167.366  RMSE=326.320


Train 37/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 38/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 38: train_loss=0.169688 val_avg_MAE=144.147

VAL
   12h  MAE=121.013  RMSE=241.299
   24h  MAE=131.525  RMSE=268.733
   48h  MAE=158.283  RMSE=311.138
   72h  MAE=165.768  RMSE=320.365


Train 39/40:   0%|          | 0/127 [00:00<?, ?it/s]

Train 40/40:   0%|          | 0/127 [00:00<?, ?it/s]

Eval:   0%|          | 0/37 [00:00<?, ?it/s]


Epoch 40: train_loss=0.169250 val_avg_MAE=144.258

VAL
   12h  MAE=119.834  RMSE=241.589
   24h  MAE=133.103  RMSE=272.535
   48h  MAE=157.770  RMSE=308.974
   72h  MAE=166.324  RMSE=321.353

Early stopping. Best val_avg_MAE=141.994
Saved history: artifacts/runs/20260210_223029_CNN_GRU_LSTM/history.csv

Evaluating on TEST set...


Eval:   0%|          | 0/85 [00:00<?, ?it/s]


CNN_GRU_LSTM — TEST
   12h  MAE=132.169  RMSE=254.404
   24h  MAE=126.465  RMSE=251.612
   48h  MAE=142.666  RMSE=285.668
   72h  MAE=152.196  RMSE=297.030


Collect preds:   0%|          | 0/85 [00:00<?, ?it/s]


Saved run outputs to: artifacts/runs/20260210_223029_CNN_GRU_LSTM
 - best checkpoint: artifacts/runs/20260210_223029_CNN_GRU_LSTM/best.pt
 - history: artifacts/runs/20260210_223029_CNN_GRU_LSTM/history.csv
 - test metrics: artifacts/runs/20260210_223029_CNN_GRU_LSTM/test_metrics.json
 - predictions (npz): artifacts/runs/20260210_223029_CNN_GRU_LSTM/test_pred_true_selected_horizons.npz
 - predictions (xlsx): artifacts/runs/20260210_223029_CNN_GRU_LSTM/test_pred_true_selected_horizons.xlsx
 - master summary: artifacts/results_summary.csv
Saved to: artifacts/runs/20260210_223029_CNN_GRU_LSTM
