In [2]:
#   1) Rebuilds the SAME features/graph/time-tensor as deterministic GS-ADR h=24 run
#   2) Loads ckpt_pignn_h24.pt + scaler_pignn_h24.npz
#   3) Adds a heteroscedastic aleatoric head (sigma) and trains with Gaussian NLL
#      - Freeze mean/physics/embeddings for a few epochs (sigma-only warmup)
#      - Then fine-tune all parameters lightly
#   4) MC-dropout inference (M=30) = epistemic + aleatoric + exceedance probability
#   5) Isotonic calibration on VAL for p_exceed


import os, math
import numpy as np
import pandas as pd

from sklearn.neighbors import NearestNeighbors
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import (
    brier_score_loss, precision_recall_curve, average_precision_score,
    precision_score, recall_score, f1_score, confusion_matrix
)

import torch
import torch.nn as nn
import torch.nn.functional as F

DATA_PATH = "dataset_2023_2025.csv"

CKPT_DET_PATH   = "ckpt_pignn_h24.pt"
SCALER_PATH     = "scaler_pignn_h24.npz"         
OUT_PRED_UQ     = "predictions_uq_h24.csv"
OUT_MET_UQ      = "metrics_uq_h24.csv"
CKPT_UQ_PATH    = "ckpt_pignn_uq_h24.pt"

# Graph
K_GRAPH = 4
ELL_KM  = 120.0

# Horizon
H = 24
DT = 1.0

# Feature design (must match the deterministic model
PM_LAGS   = [1, 2, 6, 24]
CHEMS     = ['carbon_monoxide', 'nitrogen_dioxide', 'sulphur_dioxide', 'ozone']
CHEM_LAGS = [1, 6, 24]

# Splits (same strict time split)
TRAIN_END = "2024-12-31 23:00:00"
VAL_END   = "2025-06-30 23:00:00"
TEST_END  = "2025-11-23 23:00:00"

# Constraints (apply on mean mu)
LAMBDA_INEQ    = 1.0
LAMBDA_NONNEG  = 0.05

# Model capacity (must match deterministic)
DROPOUT = 0.20
EMB_DIM = 8
HIDDEN  = 96

# Training
SEED = 42
BATCH_TIMES = 128

EPOCHS_TOTAL = 25
FREEZE_MEAN_EPOCHS = 3     # sigma-only warmup
PATIENCE = 5

LR_SIGMA  = 2e-3           # during freeze
LR_FINE   = 5e-4           # after unfreeze (smaller)
WEIGHT_DECAY = 1e-6

# UQ inference
M_MC = 30

# Policy threshold
T_EXCEED   = 65.0
TAU_POLICY = 0.80          # policy (strict); we will also compute tau_star on VAL

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

np.random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------------
# Helpers
# -----------------------------
def time_split_times(times):
    times = pd.to_datetime(times)
    tr = times[times <= pd.Timestamp(TRAIN_END)]
    va = times[(times > pd.Timestamp(TRAIN_END)) & (times <= pd.Timestamp(VAL_END))]
    te = times[(times > pd.Timestamp(VAL_END)) & (times <= pd.Timestamp(TEST_END))]
    return tr, va, te

def build_knn_graph(city_meta, k=4, ell_km=120.0):
    coords = city_meta[['lat','lon']].to_numpy()
    coords_rad = np.radians(coords)

    nnm = NearestNeighbors(n_neighbors=k+1, metric='haversine')
    nnm.fit(coords_rad)
    dists, idxs = nnm.kneighbors(coords_rad)

    n = len(city_meta)
    W = np.zeros((n,n), dtype=np.float32)
    for i in range(n):
        for t in range(1, k+1):
            j = idxs[i, t]
            dist_km = dists[i, t] * 6371.0
            w = np.exp(-(dist_km/ell_km)**2)
            W[i, j] = max(W[i, j], w)
            W[j, i] = max(W[j, i], w)
    D = np.diag(W.sum(axis=1))
    L = (D - W).astype(np.float32)
    return W, L

def hinge_pos(x): return F.relu(x)

def enable_dropout(model: nn.Module):
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

def normal_cdf(z):
    return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

def gaussian_nll(mu, sigma, y):
    # sigma positive; returns (B,N,1)
    var = sigma**2
    return 0.5 * (torch.log(var + 1e-12) + (y - mu)**2 / (var + 1e-12))

def ece_score(p, y, n_bins=10):
    p = np.asarray(p); y = np.asarray(y)
    bins = np.linspace(0.0, 1.0, n_bins+1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        mask = (p >= lo) & (p < hi) if i < n_bins-1 else (p >= lo) & (p <= hi)
        if mask.sum() == 0:
            continue
        conf = p[mask].mean()
        acc  = y[mask].mean()
        ece += (mask.sum() / len(p)) * abs(acc - conf)
    return float(ece)

def pick_tau_by_val(df_val, p_col):
    y = df_val["z_true"].values.astype(int)
    p = df_val[p_col].values.astype(float)
    taus = np.linspace(0.05, 0.95, 19)
    best = (-1, 0.5)
    for tau in taus:
        yhat = (p >= tau).astype(int)
        f1 = f1_score(y, yhat, zero_division=0)
        if f1 > best[0]:
            best = (f1, float(tau))
    return float(best[1]), float(best[0])

def compute_metrics(df_split, p_col, tau):
    y = df_split["z_true"].values.astype(int)
    p = df_split[p_col].values.astype(float)

    bs = brier_score_loss(y, p)
    ece = ece_score(p, y, n_bins=10)

    auprc = average_precision_score(y, p)

    yhat = (p >= tau).astype(int)
    P = precision_score(y, yhat, zero_division=0)
    R = recall_score(y, yhat, zero_division=0)
    F1 = f1_score(y, yhat, zero_division=0)
    tn, fp, fn, tp = confusion_matrix(y, yhat).ravel()

    return {
        "p_col": p_col,
        "tau": float(tau),
        "brier": float(bs),
        "ece": float(ece),
        "auprc": float(auprc),
        "precision": float(P),
        "recall": float(R),
        "f1": float(F1),
        "tp": int(tp), "fp": int(fp), "tn": int(tn), "fn": int(fn),
    }

# -----------------------------
# Load deterministic scaler
# -----------------------------
class SavedStandardScaler:
    def __init__(self, mean_, scale_):
        self.mean_ = mean_.astype(np.float32)
        self.scale_ = scale_.astype(np.float32)
        self.scale_[self.scale_ == 0] = 1.0

    def transform(self, X):
        # X (..., F)
        return (X - self.mean_) / self.scale_

sc_npz = np.load(SCALER_PATH)
scaler = SavedStandardScaler(sc_npz["mean_"], sc_npz["scale_"])
print("Loaded scaler:", SCALER_PATH)

# Load & build features (must match DET)

df = pd.read_csv(DATA_PATH)
df["datetime"] = pd.to_datetime(df["datetime"])

keep_cols = [
    "city_id","city_name","lat","lon","datetime",
    "pm2_5","pm10",
] + CHEMS
df = df[keep_cols].copy().sort_values(["city_id","datetime"]).reset_index(drop=True)

# lags (history only)
for lag in PM_LAGS:
    df[f"pm2_5_lag{lag}"] = df.groupby("city_id")["pm2_5"].shift(lag)
for c in CHEMS:
    for lag in CHEM_LAGS:
        df[f"{c}_lag{lag}"] = df.groupby("city_id")[c].shift(lag)

# targets at h=24
df[f"y_h{H}"] = df.groupby("city_id")["pm2_5"].shift(-H)
df[f"pm10_h{H}"] = df.groupby("city_id")["pm10"].shift(-H)

needed = [f"pm2_5_lag{l}" for l in PM_LAGS] + \
         [f"{c}_lag{l}" for c in CHEMS for l in CHEM_LAGS] + \
         [f"y_h{H}", f"pm10_h{H}"]
df = df.dropna(subset=needed).copy()

# city order
city_meta = df[["city_id","city_name","lat","lon"]].drop_duplicates().sort_values("city_id").reset_index(drop=True)
city_ids = city_meta["city_id"].to_numpy()
cid_to_idx = {cid:i for i,cid in enumerate(city_ids)}
N = len(city_ids)

# graph
W_np, L_np = build_knn_graph(city_meta, k=K_GRAPH, ell_km=ELL_KM)
W = torch.tensor(W_np, device=DEVICE)
L = torch.tensor(L_np, device=DEVICE)
deg = W.sum(dim=1, keepdim=True).clamp_min(1e-6)
A_norm = W / deg

print(f"Device: {DEVICE} | Cities: {N} | k={K_GRAPH} | Avg deg={(W_np>0).sum(axis=1).mean():.2f}")

# time tensor (target-time cyclical at t+h, no leakage)
def build_time_tensor(df_in, horizon_h):
    df2 = df_in.copy()
    df2["cid_idx"] = df2["city_id"].map(cid_to_idx)
    df2 = df2.sort_values(["datetime","cid_idx"])

    # target-time cyc features at t+h
    dt_tgt = df2["datetime"] + pd.to_timedelta(horizon_h, unit="h")
    hour = dt_tgt.dt.hour.values
    doy  = dt_tgt.dt.dayofyear.values

    df2["hour_sin_tgt"] = np.sin(2*np.pi*hour/24.0)
    df2["hour_cos_tgt"] = np.cos(2*np.pi*hour/24.0)
    df2["doy_sin_tgt"]  = np.sin(2*np.pi*doy/365.0)
    df2["doy_cos_tgt"]  = np.cos(2*np.pi*doy/365.0)

    X_COLS = []
    X_COLS += CHEMS
    X_COLS += [f"{c}_lag{l}" for c in CHEMS for l in CHEM_LAGS]
    X_COLS += [f"pm2_5_lag{l}" for l in PM_LAGS]
    X_COLS += ["doy_sin_tgt","doy_cos_tgt","hour_sin_tgt","hour_cos_tgt"]
    X_COLS += ["lat","lon"]

    # keep only full times (all cities)
    counts = df2.groupby("datetime")["cid_idx"].nunique()
    full_times = counts[counts == N].index
    df2 = df2[df2["datetime"].isin(full_times)].copy()

    times = pd.to_datetime(np.sort(df2["datetime"].unique()))
    T = len(times)

    X = np.zeros((T, N, len(X_COLS)), dtype=np.float32)
    c0 = np.zeros((T, N, 1), dtype=np.float32)
    y  = np.zeros((T, N, 1), dtype=np.float32)
    pm10_tgt = np.zeros((T, N, 1), dtype=np.float32)

    g = df2.groupby("datetime", sort=True)
    for ti, t in enumerate(times):
        gt = g.get_group(t).sort_values("cid_idx")
        X[ti,:,:] = gt[X_COLS].to_numpy(np.float32)
        c0[ti,:,0] = gt["pm2_5"].to_numpy(np.float32)
        y[ti,:,0]  = gt[f"y_h{horizon_h}"].to_numpy(np.float32)
        pm10_tgt[ti,:,0] = gt[f"pm10_h{horizon_h}"].to_numpy(np.float32)

    return times, X, c0, y, pm10_tgt, X_COLS

times, X, c0, y, pm10_tgt, X_COLS = build_time_tensor(df, H)

tr_times, va_times, te_times = time_split_times(times)
time_to_i = {t:i for i,t in enumerate(times)}
tr_idx = np.array([time_to_i[t] for t in tr_times], dtype=int)
va_idx = np.array([time_to_i[t] for t in va_times], dtype=int)
te_idx = np.array([time_to_i[t] for t in te_times], dtype=int)

# apply saved scaler (train-only) from deterministic stage
X = scaler.transform(X).astype(np.float32)

# Model (must match DET mean pathway) + sigma head
class Backbone(nn.Module):
    def __init__(self, in_dim, hidden=96, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
    def forward(self, z):
        return self.net(z)

class PIGNN_UQ(nn.Module):
    """
    Mean: same PI-GNN implicit macro-step with learned source mean s_mean
    Sigma: heteroscedastic head from same backbone (aleatoric)
    Epistemic: MC-dropout at inference
    """
    def __init__(self, x_dim, L, A_norm, n_cities, emb_dim=8, hidden=96, dt=1.0, dropout=0.2):
        super().__init__()
        self.L = L
        self.A = A_norm
        self.dt = dt

        self.city_emb = nn.Embedding(n_cities, emb_dim)
        self.register_buffer("city_idx", torch.arange(n_cities, dtype=torch.long))

        # input to backbone: [x, x_nb, c0, c_nb, emb]
        self.in_dim = 2*x_dim + 2*1 + emb_dim
        self.backbone = Backbone(self.in_dim, hidden=hidden, dropout=dropout)

        # mean head (source term mean)
        self.mean_head = nn.Linear(hidden, 1)

        # aleatoric head
        self.logsig_head = nn.Linear(hidden, 1)
        # init sigma small (important for stability)
        nn.init.zeros_(self.logsig_head.weight)
        nn.init.constant_(self.logsig_head.bias, -2.0)  # softplus(-2) ~ 0.13

        # physics params
        self.D_raw = nn.Parameter(torch.tensor(0.0))
        self.k_raw = nn.Parameter(torch.tensor(-1.0))

    def D(self): return F.softplus(self.D_raw) + 1e-6
    def k(self): return F.softplus(self.k_raw) + 1e-6

    def forward(self, c0, x, steps):
        # neighbor agg
        x_nb = torch.einsum("ij,bjk->bik", self.A, x)
        c_nb = torch.einsum("ij,bjk->bik", self.A, c0)

        emb = self.city_emb(self.city_idx).unsqueeze(0).expand(x.shape[0], -1, -1)

        z = torch.cat([x, x_nb, c0, c_nb, emb], dim=-1)
        h = self.backbone(z)

        s_mean = self.mean_head(h)
        sigma  = F.softplus(self.logsig_head(h)) + 1e-3   # positive

        # implicit macro-step (same as DET)
        D = self.D()
        k = self.k()
        hdt = steps * self.dt

        Nn = self.L.shape[0]
        I = torch.eye(Nn, device=c0.device, dtype=c0.dtype)
        A_sys = I + hdt * (D * self.L + k * I)
        A_b = A_sys.unsqueeze(0).expand(c0.shape[0], -1, -1)
        rhs = c0 + hdt * s_mean

        mu = torch.linalg.solve(A_b, rhs)  # (B,N,1)
        return mu, sigma

# -----------------------------
# Load deterministic checkpoint into UQ model (warm-start mean & physics)
# -----------------------------
model = PIGNN_UQ(
    x_dim=X.shape[-1], L=L, A_norm=A_norm, n_cities=N,
    emb_dim=EMB_DIM, hidden=HIDDEN, dt=DT, dropout=DROPOUT
).to(DEVICE)

ckpt = torch.load(CKPT_DET_PATH, map_location=DEVICE)

# The deterministic stage likely saved either:
#  - a full model.state_dict(), or
#  - a dict with key "state_dict".
state = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt

# Map compatible keys:
missing, unexpected = [], []

def load_if_exists(dst_key, src_key):
    if src_key in state and dst_key in model.state_dict():
        model.state_dict()[dst_key].copy_(state[src_key])
        return True
    return False

sd = model.state_dict()

# embeddings
load_if_exists("city_emb.weight", "city_emb.weight")

# physics params
load_if_exists("D_raw", "D_raw")
load_if_exists("k_raw", "k_raw")

# If deterministic had a single SourceNet with .backbone and .mean_head we can match:
# try several likely key patterns
key_map_attempts = [
    # DET (newer) style
    ("backbone.net.0.weight", "source.backbone.0.weight"),
    ("backbone.net.0.bias",   "source.backbone.0.bias"),
    ("backbone.net.3.weight", "source.backbone.3.weight"),
    ("backbone.net.3.bias",   "source.backbone.3.bias"),
    ("mean_head.weight",      "source.mean_head.weight"),
    ("mean_head.bias",        "source.mean_head.bias"),

    # DET (older) style: source is nn.Sequential and last is Linear->1
    ("backbone.net.0.weight", "source.net.0.weight"),
    ("backbone.net.0.bias",   "source.net.0.bias"),
    ("backbone.net.3.weight", "source.net.3.weight"),
    ("backbone.net.3.bias",   "source.net.3.bias"),
    ("mean_head.weight",      "source.net.6.weight"),
    ("mean_head.bias",        "source.net.6.bias"),
]

loaded_any = False
for dst_k, src_k in key_map_attempts:
    if dst_k in sd and src_k in state:
        sd[dst_k].copy_(state[src_k])
        loaded_any = True

model.load_state_dict(sd, strict=False)

print("Loaded DET warm-start from:", CKPT_DET_PATH)
print("Note: logsigma_head is NEW and will be trained.")

# Freeze/unfreeze helpers
def set_requires_grad(module, flag: bool):
    for p in module.parameters():
        p.requires_grad = flag

def freeze_mean_parts(m):
    set_requires_grad(m.city_emb, False)
    set_requires_grad(m.backbone, False)
    set_requires_grad(m.mean_head, False)
    m.D_raw.requires_grad = False
    m.k_raw.requires_grad = False
    # sigma head trainable
    set_requires_grad(m.logsig_head, True)

def unfreeze_all(m):
    set_requires_grad(m, True)

# Training loop
@torch.no_grad()
def eval_val(idx):
    model.eval()
    losses, viols = [], []
    for b0 in range(0, len(idx), BATCH_TIMES):
        ii = idx[b0:b0+BATCH_TIMES]
        c0b   = torch.from_numpy(c0[ii]).to(DEVICE)
        Xb    = torch.from_numpy(X[ii]).to(DEVICE)
        yb    = torch.from_numpy(y[ii]).to(DEVICE)
        pm10b = torch.from_numpy(pm10_tgt[ii]).to(DEVICE)

        mu, sigma = model(c0b, Xb, H)

        nll = gaussian_nll(mu, sigma, yb).mean()
        loss = nll
        loss = loss + LAMBDA_INEQ * hinge_pos(mu - pm10b).mean()
        loss = loss + LAMBDA_NONNEG * hinge_pos(-mu).mean()

        losses.append(loss.item())
        viols.append(((mu - pm10b) > 0).float().mean().item())
    return float(np.mean(losses)), float(np.mean(viols))

best_val = 1e18
best_state = None
bad = 0

for ep in range(1, EPOCHS_TOTAL+1):
    if ep <= FREEZE_MEAN_EPOCHS:
        freeze_mean_parts(model)
        lr = LR_SIGMA
    else:
        unfreeze_all(model)
        lr = LR_FINE

    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=lr, weight_decay=WEIGHT_DECAY)

    model.train()
    np.random.shuffle(tr_idx)
    tr_losses = []

    for b0 in range(0, len(tr_idx), BATCH_TIMES):
        ii = tr_idx[b0:b0+BATCH_TIMES]
        c0b   = torch.from_numpy(c0[ii]).to(DEVICE)
        Xb    = torch.from_numpy(X[ii]).to(DEVICE)
        yb    = torch.from_numpy(y[ii]).to(DEVICE)
        pm10b = torch.from_numpy(pm10_tgt[ii]).to(DEVICE)

        mu, sigma = model(c0b, Xb, H)

        nll = gaussian_nll(mu, sigma, yb).mean()

        loss = nll
        loss = loss + LAMBDA_INEQ * hinge_pos(mu - pm10b).mean()
        loss = loss + LAMBDA_NONNEG * hinge_pos(-mu).mean()

        # mild regularizer to avoid sigma exploding early
        loss = loss + 1e-4 * (sigma.mean())

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        tr_losses.append(loss.item())

    val_loss, val_viol = eval_val(va_idx)
    print(f"[UQ h=24] ep {ep:02d} | train {np.mean(tr_losses):.4f} | val_loss {val_loss:.4f} | "
          f"val_viol {100*val_viol:.2f}% | D {model.D().item():.4f} k {model.k().item():.4f} | lr {lr:g}")

    if val_loss < best_val - 1e-4:
        best_val = val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        bad = 0
    else:
        bad += 1
        if bad >= PATIENCE:
            print(f"Early stop. Best val_loss={best_val:.4f}")
            break

if best_state is not None:
    model.load_state_dict(best_state)

torch.save({"state_dict": model.state_dict()}, CKPT_UQ_PATH)
print("Saved UQ checkpoint:", CKPT_UQ_PATH)

# MC inference => mu_mean, sigma components, exceedance probability
@torch.no_grad()
def mc_predict(idx, split_name):
    model.eval()
    enable_dropout(model)  # MC-dropout ON

    Tn = len(idx)
    mu_mean = np.zeros((Tn, N), dtype=np.float32)
    p_raw   = np.zeros((Tn, N), dtype=np.float32)
    epi_var = np.zeros((Tn, N), dtype=np.float32)
    ale_var = np.zeros((Tn, N), dtype=np.float32)

    y_true = y[idx,:,0].astype(np.float32)
    pm10_true = pm10_tgt[idx,:,0].astype(np.float32)

    for b0 in range(0, Tn, BATCH_TIMES):
        sl = slice(b0, min(b0+BATCH_TIMES, Tn))
        ii = idx[sl]
        c0b = torch.from_numpy(c0[ii]).to(DEVICE)
        Xb  = torch.from_numpy(X[ii]).to(DEVICE)

        mu_sum   = torch.zeros((len(ii), N, 1), device=DEVICE)
        mu2_sum  = torch.zeros((len(ii), N, 1), device=DEVICE)
        sig2_sum = torch.zeros((len(ii), N, 1), device=DEVICE)
        p_sum    = torch.zeros((len(ii), N, 1), device=DEVICE)

        T_thr = torch.tensor(T_EXCEED, device=DEVICE).view(1,1,1)

        for _ in range(M_MC):
            mu_m, sig_m = model(c0b, Xb, H)  # (B,N,1)
            mu_sum  += mu_m
            mu2_sum += mu_m**2
            sig2_sum += sig_m**2

            z = (T_thr - mu_m) / sig_m
            p_m = 1.0 - normal_cdf(z)
            p_sum += p_m

        mu_bar = mu_sum / M_MC
        mu_var = (mu2_sum / M_MC - mu_bar**2).clamp_min(0.0)  # epistemic variance of mean
        ale    = sig2_sum / M_MC                              # aleatoric variance

        pbar = (p_sum / M_MC).clamp(0.0, 1.0)

        mu_mean[sl,:] = mu_bar.squeeze(-1).cpu().numpy()
        p_raw[sl,:]   = pbar.squeeze(-1).cpu().numpy()
        epi_var[sl,:] = mu_var.squeeze(-1).cpu().numpy()
        ale_var[sl,:] = ale.squeeze(-1).cpu().numpy()

    rows = []
    for local_ti, global_ti in enumerate(idx):
        t = times[global_ti]  # issuance time
        for n in range(N):
            rows.append({
                "city_id": int(city_ids[n]),
                "datetime": pd.Timestamp(t),
                "split": split_name,
                "horizon_h": H,
                "y_true": float(y_true[local_ti, n]),
                "pm10_true": float(pm10_true[local_ti, n]),
                "mu_mean": float(mu_mean[local_ti, n]),
                "p_exceed_raw": float(np.clip(p_raw[local_ti, n], 0.0, 1.0)),
                "sigma_epi": float(math.sqrt(max(epi_var[local_ti, n], 0.0))),
                "sigma_ale": float(math.sqrt(max(ale_var[local_ti, n], 0.0))),
                "sigma_total": float(math.sqrt(max(epi_var[local_ti, n] + ale_var[local_ti, n], 0.0))),
                "z_true": int(y_true[local_ti, n] > T_EXCEED),
                # optional: keep lat/lon for maps
                "lat": float(city_meta.loc[n, "lat"]),
                "lon": float(city_meta.loc[n, "lon"]),
            })
    return pd.DataFrame(rows)

print("\nMC inference on VAL...")
df_val = mc_predict(va_idx, "val")
print("VAL rows:", len(df_val))

print("\nMC inference on TEST...")
df_test = mc_predict(te_idx, "test")
print("TEST rows:", len(df_test))

# Calibration (isotonic on VAL) + metrics
iso = IsotonicRegression(out_of_bounds="clip")
iso.fit(df_val["p_exceed_raw"].values, df_val["z_true"].values)

df_val["p_exceed_cal"]  = iso.transform(df_val["p_exceed_raw"].values)
df_test["p_exceed_cal"] = iso.transform(df_test["p_exceed_raw"].values)

# Save predictions
pred_all = pd.concat([df_val, df_test], ignore_index=True)
pred_all.to_csv(OUT_PRED_UQ, index=False)
print("\nSaved:", OUT_PRED_UQ)

# Choose tau_star on VAL (maximize F1) and report both policy tau and tau_star
tau_star, f1_star = pick_tau_by_val(df_val, "p_exceed_cal")
print(f"tau_star (VAL max-F1) = {tau_star:.2f} (F1={f1_star:.3f}) | tau_policy={TAU_POLICY:.2f}")

met_rows = []
for split_name, d in [("val", df_val), ("test", df_test)]:
    for pcol in ["p_exceed_raw", "p_exceed_cal"]:
        # metrics at tau_policy and tau_star
        for tau in [TAU_POLICY, tau_star]:
            m = compute_metrics(d, pcol, tau)
            m.update({
                "split": split_name,
                "T_exceed": float(T_EXCEED),
                "tau_policy": float(TAU_POLICY),
                "tau_star_val": float(tau_star),
            })
            met_rows.append(m)

met = pd.DataFrame(met_rows).sort_values(["split","p_col","tau"]).reset_index(drop=True)
met.to_csv(OUT_MET_UQ, index=False)
print("Saved:", OUT_MET_UQ)
print(met)

print("\nDone.")

Loaded scaler: scaler_pignn_h24.npz
Device: cpu | Cities: 29 | k=4 | Avg deg=4.90
Loaded DET warm-start from: ckpt_pignn_h24.pt
Note: logsigma_head is NEW and will be trained.
[UQ h=24] ep 01 | train 96.7584 | val_loss 7.3220 | val_viol 33.44% | D 0.5019 k 0.3113 | lr 0.002
[UQ h=24] ep 02 | train 4.6529 | val_loss 7.0159 | val_viol 33.44% | D 0.5019 k 0.3113 | lr 0.002
[UQ h=24] ep 03 | train 4.4459 | val_loss 6.9707 | val_viol 33.44% | D 0.5019 k 0.3113 | lr 0.002
[UQ h=24] ep 04 | train 3.5661 | val_loss 4.2010 | val_viol 13.31% | D 0.5247 k 0.3087 | lr 0.0005
[UQ h=24] ep 05 | train 3.4876 | val_loss 4.0969 | val_viol 11.23% | D 0.5459 k 0.3052 | lr 0.0005
[UQ h=24] ep 06 | train 3.4617 | val_loss 4.1992 | val_viol 12.89% | D 0.5621 k 0.3025 | lr 0.0005
[UQ h=24] ep 07 | train 3.4423 | val_loss 4.4552 | val_viol 17.21% | D 0.5780 k 0.3000 | lr 0.0005
[UQ h=24] ep 08 | train 3.4299 | val_loss 4.2353 | val_viol 13.87% | D 0.5894 k 0.2984 | lr 0.0005
[UQ h=24] ep 09 | train 3.4162 | v