# Evaluation Notebook — AQ-NBEATS Exog+Series Model

Adapted from the paper's `Evaluate.ipynb` for `AnyQuantileForecasterExogWithSeries`.

**Changes vs the original notebook:**
- Loads `AnyQuantileForecasterExogWithSeries` instead of `AnyQuantileForecaster`
- Uses `EMHIRESUnivariateDataModule` (has `series_id` and exog support)
- Checkpoint pattern matches `nbeatsaq-exog-series` naming convention
- Computes CRPS on all 11 fixed eval quantiles (no random/deterministic split needed)
- Optionally loads and applies the isotonic calibrator from `calibrate.py`
- Saves per-series results in the same pickle format as the original for downstream compatibility

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from model.models import AnyQuantileForecasterExogWithSeries
from dataset.datasets import EMHIRESUnivariateDataModule
from metrics import SMAPE, MAPE, CRPS
from utils.model_factory import instantiate

import warnings
warnings.filterwarnings('ignore')

# Expose only GPU 0 to this process.
# NOTE(review): this presumably has to run before CUDA is first initialised to have any effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Root folder for the artefacts written further down (plot, per-series pickles, CSV summary).
RESULTS_DIR = './results'

## 1. Configuration

In [None]:
# ── Checkpoint discovery ───────────────────────────────────────────────────────
# Every run directory whose name starts with EXPERIMENT_NAME is searched.
EXPERIMENT_NAME = "nbeatsaq-exog-series"
CHECKPOINT_NAME = "model-epoch=13.ckpt"  # use "*" to pick up every saved epoch
checkpoint_pattern = f"lightning_logs/{EXPERIMENT_NAME}*/checkpoints/{CHECKPOINT_NAME}"

# ── Eval quantiles (fixed grid, same as plot_interactive.py) ──────────────────
EVAL_QUANTILES = [
    0.01, 0.05, 0.10, 0.25, 0.40, 0.50,
    0.60, 0.75, 0.90, 0.95, 0.99,
]

# ── Calibrator (optional — set to None to skip) ───────────────────────────────
CALIBRATOR_PATH = "results/calibration/calibrator.pkl"

# ── Dataset split (must match training config) ────────────────────────────────
SPLIT_BOUNDARIES = ["2006-01-01", "2017-12-30", "2018-01-01", "2019-01-01"]

# Resolve the pattern to a sorted list of checkpoint files and report it.
model_list = glob(checkpoint_pattern)
model_list.sort()
print(f"Pattern : {checkpoint_pattern}")
print(f"Found   : {len(model_list)} checkpoint(s)")
for m in model_list:
    print(f"  {m}")

# Nothing below can run without at least one checkpoint, so stop here.
if len(model_list) == 0:
    raise FileNotFoundError(f"No checkpoints found: {checkpoint_pattern}")

## 2. Load Config & Dataset

In [None]:
import yaml

CONFIG_PATH = Path("config/nbeatsaq-exog-series.yaml")

def load_cfg(config_path: Path):
    """Load the experiment configuration as an OmegaConf object.

    ``!!python/tuple`` tags are stripped from the YAML text before it is
    parsed with ``yaml.safe_load``, so tagged sequences are read as plain
    lists.

    Args:
        config_path: Location of the experiment YAML file.

    Returns:
        The OmegaConf config built from the file, or from a minimal
        hard-coded fallback when the file does not exist.
    """
    if config_path.exists():
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(config_path, encoding="utf-8") as f:
            raw = f.read().replace("!!python/tuple", "")
        return OmegaConf.create(yaml.safe_load(raw))

    # Report the fallback explicitly: these hyper-parameters may not match the
    # checkpoint, and the loaders use strict=False, so a mismatch could
    # otherwise go unnoticed.
    print(f"⚠️  Config file not found: {config_path} — using the built-in fallback config.")
    return OmegaConf.create({
        "model": {
            "input_horizon_len": 168, "max_norm": True,
            "num_series": 35, "series_embed_dim": 32, "series_embed_scale": 0.08,
        },
        "dataset": {
            "name": "MHLV", "train_batch_size": 512, "eval_batch_size": 512,
            "num_workers": 0, "persistent_workers": False,
            "horizon_length": 24, "history_length": 168,
            "split_boundaries": SPLIT_BOUNDARIES,
            "fillna": "ffill", "train_step": 1, "eval_step": 24,
        },
    })

cfg = load_cfg(CONFIG_PATH)
# The split defined in this notebook takes precedence over the one in the file.
cfg.dataset.split_boundaries = SPLIT_BOUNDARIES

# Build the data module from the dataset section of the config. Worker
# settings are fixed to in-process loading rather than read from the config.
ds_cfg = cfg.dataset
dm = EMHIRESUnivariateDataModule(
    name=ds_cfg.name,
    train_batch_size=ds_cfg.train_batch_size,
    eval_batch_size=ds_cfg.eval_batch_size,
    num_workers=0,
    persistent_workers=False,
    horizon_length=ds_cfg.horizon_length,
    history_length=ds_cfg.history_length,
    split_boundaries=ds_cfg.split_boundaries,
    fillna=ds_cfg.fillna,
    train_step=ds_cfg.train_step,
    eval_step=ds_cfg.eval_step,
)

# Only the test split is needed for evaluation.
dm.setup(stage="test")
test_loader = dm.test_dataloader()
print(f"Test samples: {len(dm.test_dataset)}")

In [None]:
# Smoke test: load the first checkpoint on CPU and run a single forward pass
# at the median quantile before starting the full evaluation.
test_ckpt = model_list[0]
print(f"Testing checkpoint: {test_ckpt}")

# NOTE: weights_only=False performs a full unpickle of the checkpoint, which
# can execute arbitrary code — only load checkpoints from a trusted source.
model_test = AnyQuantileForecasterExogWithSeries.load_from_checkpoint(
    test_ckpt,
    cfg=cfg,
    strict=False,
    map_location='cpu',
    weights_only=False,
)
model_test.eval()

# Take one batch and request quantile 0.5 for every row in it.
test_batch = next(iter(test_loader))
test_batch['quantiles'] = torch.full((test_batch['history'].shape[0], 1), 0.5, dtype=torch.float32)

# The broad catch is deliberate: this cell only reports whether the forward
# pass works; the real evaluation below re-raises on failure.
try:
    # Inference only, so do not build an autograd graph.
    with torch.no_grad():
        output = model_test(test_batch)
    print(f"✅ Forward works! Output shape: {output.shape}")
except Exception as e:
    print(f"❌ Forward failed: {e}")

## 3. Collect Ground-Truth Metadata

In [None]:
# Mirror the original notebook's metadata collection: one DataFrame row per
# test window, holding whichever of these batch fields are present.
META_KEYS = ("target", "history", "series_id")

dfs = []
for b in tqdm(test_loader, desc="Collecting metadata"):
    # One list entry per window so each DataFrame cell holds a whole array.
    row = {k: list(b[k].cpu().numpy()) for k in META_KEYS if k in b}
    dfs.append(pd.DataFrame.from_dict(row))

df = pd.concat(dfs, axis=0, ignore_index=True)
print(f"Total rows : {len(df)}")
print(f"Series IDs : {sorted(df.series_id.unique().astype(int))}")
df.head(3)

## 4. Generate Predictions (all seeds → ensemble)

In [None]:
@torch.no_grad()
def predict_checkpoint(ckpt_path: str, cfg, dataloader, quantiles: list, device):
    """Run one checkpoint over a dataloader at each requested quantile level.

    The model is queried once per quantile level for every batch, and the
    per-level outputs are stacked on a new trailing axis.

    Args:
        ckpt_path: Path of the Lightning checkpoint to load.
        cfg: Config object forwarded to ``load_from_checkpoint``.
        dataloader: Iterable of batch dicts; each must contain ``"history"``.
        quantiles: Quantile levels to query, in output order.
        device: Device the model and the batch tensors are moved to.

    Returns:
        CPU tensor of shape [N, H, Q].
        NOTE(review): assumes the model returns [B, H] or [B, H, 1] per
        query — confirm against the model's forward.
    """
    # NOTE: weights_only=False fully unpickles the checkpoint — trusted files only.
    model = AnyQuantileForecasterExogWithSeries.load_from_checkpoint(
        ckpt_path,
        cfg=cfg,
        strict=False,
        map_location=device,
        weights_only=False,
    )
    model.eval().to(device)

    def _forward_at(batch, level):
        # Query the model at a single quantile level for the whole batch.
        n_rows = batch["history"].shape[0]
        batch["quantiles"] = torch.full((n_rows, 1), level, device=device, dtype=torch.float32)
        out = model(batch)
        # Drop a trailing axis from 3-D outputs so every level stacks alike.
        return (out.squeeze(-1) if out.dim() == 3 else out).cpu()

    chunks = []
    for raw_batch in tqdm(dataloader, desc="  batches", leave=False):
        # Move tensors to the target device; leave other entries untouched.
        on_device = {
            key: (val.to(device) if isinstance(val, torch.Tensor) else val)
            for key, val in raw_batch.items()
        }
        levels_out = [_forward_at(on_device, q_val) for q_val in quantiles]
        chunks.append(torch.stack(levels_out, dim=-1))

    return torch.cat(chunks, dim=0)

In [None]:
# Setup safe globals ONCE before loading any checkpoints.
# NOTE(review): the checkpoint loaders in this notebook pass
# weights_only=False, so this allow-list presumably only matters if that flag
# is switched back to True — confirm whether the cell is still needed.
import torch.serialization
from omegaconf import DictConfig, ListConfig
from omegaconf.base import ContainerMetadata
import typing
import collections

_omegaconf_types = [DictConfig, ListConfig, ContainerMetadata]
_builtin_types = [int, float, str, bool, list, dict, tuple, set]
_typing_types = [typing.Any]
_collection_types = [collections.defaultdict, collections.OrderedDict]

# Register everything in a single call, in the same order as before.
torch.serialization.add_safe_globals(
    _omegaconf_types + _builtin_types + _typing_types + _collection_types
)

print("✅ Safe globals configured")

# Use the GPU when one is visible, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Device: {device}")

# One [N, H, Q] prediction tensor per checkpoint.
per_seed_preds = []
for i, ckpt in enumerate(model_list):
    print(f"\n[{i+1}/{len(model_list)}] {Path(ckpt).parent.parent.name}")
    try:
        preds = predict_checkpoint(ckpt, cfg, test_loader, EVAL_QUANTILES, device)
        per_seed_preds.append(preds)
        print(f"  ✅ Shape: {preds.shape}")
    except Exception as e:
        # Report which checkpoint failed, then abort the cell.
        print(f"  ❌ Error loading checkpoint: {e}")
        raise

# Ensemble: average across checkpoints, then sort along the quantile axis so
# the averaged quantiles are non-decreasing.
predictions_ensemble = torch.stack(per_seed_preds).mean(dim=0).sort(dim=-1).values
print(f"\n✅ Ensemble predictions shape: {predictions_ensemble.shape}")
print(f"   [N_samples, H_horizon, Q_quantiles] = {predictions_ensemble.shape}")

## 5. (Optional) Apply Isotonic Calibration

In [None]:
# The calibrator class has to be importable before its pickle is loaded.
from calibrate import IsotonicCalibrator

# Defaults used when no calibrator is available: raw ensemble, no calibrator.
calibrator = None
predictions_calibrated = predictions_ensemble.clone()

calibrator_available = bool(CALIBRATOR_PATH) and Path(CALIBRATOR_PATH).exists()

if not calibrator_available:
    print("ℹ️  No calibrator found — using raw predictions.")
    print("   Run calibrate.py with --save-calibrator to generate one.")
else:
    # NOTE: pickle.load can execute arbitrary code — only open trusted files.
    with open(CALIBRATOR_PATH, "rb") as f:
        calibrator = pickle.load(f)
    print(f"✅ Calibrator loaded from {CALIBRATOR_PATH}")

    # Ask the calibrator which raw quantile levels to query for each nominal level.
    corrected_q = calibrator.corrected_query_quantiles(EVAL_QUANTILES)
    print(f"   Nominal  : {[f'{q:.3f}' for q in EVAL_QUANTILES]}")
    print(f"   Corrected: {[f'{q:.3f}' for q in corrected_q]}")

    # Re-run every checkpoint at the corrected levels.
    corrected_levels = corrected_q.tolist()
    cal_preds_list = []
    for i, ckpt in enumerate(model_list):
        print(f"\n  Re-querying [{i+1}/{len(model_list)}] at corrected quantiles…")
        cal_preds_list.append(
            predict_checkpoint(ckpt, cfg, test_loader, corrected_levels, device)
        )

    # Same ensembling as for the raw predictions: average, then sort.
    predictions_calibrated = torch.stack(cal_preds_list).mean(dim=0).sort(dim=-1).values
    print(f"\nCalibrated ensemble shape: {predictions_calibrated.shape}")

## 6. CRPS & Coverage Metrics

In [None]:
q_tensor = torch.tensor(EVAL_QUANTILES, dtype=torch.float32)

def compute_all_metrics(preds_tensor, df_meta, quantiles_list, label=""):
    """Compute CRPS, MAPE, sMAPE, MAE and per-quantile coverage.

    Non-finite targets (NaN / ±inf) are excluded from every metric: CRPS
    skips any sample containing one, the remaining metrics drop the
    individual points.

    Args:
        preds_tensor: CPU tensor of predictions, indexed as [N, H, Q].
        df_meta: DataFrame whose ``target`` column holds one length-H array
            per row, in the same row order as ``preds_tensor``.
        quantiles_list: The Q quantile levels of the last prediction axis.
        label: Prefix shown on the CRPS progress bar.

    Returns:
        Dict with float entries ``crps``, ``mape``, ``smape``, ``mae`` and a
        ``coverage`` dict mapping each quantile level to the fraction of
        valid targets at or below the predicted quantile.
    """
    q_t = torch.tensor(quantiles_list, dtype=torch.float32)
    crps_metric  = CRPS()
    mape_metric  = MAPE()
    smape_metric = SMAPE()

    preds_np  = preds_tensor.numpy()           # [N, H, Q]
    targets   = np.array(list(df_meta.target)) # [N, H]
    finite    = np.isfinite(targets)           # [N, H] mask of usable targets

    # --- CRPS (using paper's metric class), fed one sample at a time ---
    for i in tqdm(range(len(df_meta)), desc=f"{label} CRPS", leave=False):
        tgt = torch.tensor(targets[i], dtype=torch.float32)
        if torch.isinf(tgt).any() or torch.isnan(tgt).any():
            continue
        pred = torch.tensor(preds_np[i], dtype=torch.float32)   # [H, Q]
        crps_metric.update(
            preds=pred[None],      # [1, H, Q]
            target=tgt[None],      # [1, H]
            q=q_t[None],           # [1, Q]
        )

    # --- Point forecast metrics (median = Q0.50) ---
    # Fall back to the middle level when 0.50 is not part of the grid.
    mid_idx     = quantiles_list.index(0.50) if 0.50 in quantiles_list else len(quantiles_list)//2
    median_np   = preds_np[:, :, mid_idx]                                      # [N, H]
    pred_median = torch.tensor(median_np, dtype=torch.float32)                 # [N, H]
    target_t    = torch.tensor(targets, dtype=torch.float32)                   # [N, H]
    valid       = ~(torch.isnan(target_t) | torch.isinf(target_t))
    mape_metric.update(pred_median[valid], target_t[valid])
    smape_metric.update(pred_median[valid], target_t[valid])

    # --- Coverage per quantile ---
    # Restrict to finite targets: a NaN target compares False and would
    # otherwise be counted as a miss, biasing coverage downwards.
    coverage = {}
    for i, q in enumerate(quantiles_list):
        pred_q = preds_np[:, :, i]  # [N, H]
        coverage[q] = np.mean(targets[finite] <= pred_q[finite])

    # MAE over finite targets only, so a single inf cannot make it infinite.
    mae = float(np.mean(np.abs(median_np[finite] - targets[finite])))

    return {
        "crps"    : crps_metric.compute().item(),
        "mape"    : mape_metric.compute().item(),
        "smape"   : smape_metric.compute().item(),
        "mae"     : mae,
        "coverage": coverage,
    }


# Raw ensemble is always scored.
print("Computing metrics on raw ensemble…")
metrics_raw = compute_all_metrics(predictions_ensemble, df, EVAL_QUANTILES, label="Raw")

# The calibrated ensemble is scored only when a calibrator was loaded;
# metrics_cal stays None otherwise so later cells can test for it.
metrics_cal = None
if calibrator is not None:
    print("\nComputing metrics on calibrated ensemble…")
    metrics_cal = compute_all_metrics(predictions_calibrated, df, EVAL_QUANTILES, label="Cal")

In [None]:
def print_metrics(metrics, label=""):
    """Print a metrics dict as a small text report.

    Args:
        metrics: Dict with ``crps``, ``mae``, ``mape``, ``smape`` and a
            ``coverage`` dict mapping quantile level to empirical coverage.
        label: Title printed between the two rule lines.
    """
    # Title block.
    print(f"\n{'='*60}")
    print(f"  {label}")
    print(f"{'='*60}")

    # Scalar metrics; MAPE and sMAPE are multiplied by 100 for display.
    print(f"  CRPS  : {metrics['crps']:.3f} MW")
    print(f"  MAE   : {metrics['mae']:.3f} MW")
    print(f"  MAPE  : {metrics['mape']*100:.2f}%")
    print(f"  sMAPE : {metrics['smape']*100:.2f}%")

    # Coverage table: nominal level, empirical coverage, signed difference.
    print(f"\n  {'Quantile':>10}  {'Coverage':>10}  {'Error':>10}")
    print(f"  {'-'*34}")
    abs_errors = []
    for q, cov in metrics["coverage"].items():
        print(f"  {q:>10.3f}  {cov:>10.3f}  {cov-q:>+10.3f}")
        abs_errors.append(abs(cov - q))

    # Mean absolute coverage error over all levels.
    mace = np.mean(abs_errors)
    print(f"\n  MACE  : {mace:.4f}")

# Report the raw ensemble, then the calibrated one when it was computed
# (metrics_cal is None when no calibrator was loaded).
print_metrics(metrics_raw, "RAW ENSEMBLE")
if metrics_cal:
    print_metrics(metrics_cal, "CALIBRATED ENSEMBLE")

## 7. Normalised CRPS (N-CRPS) — Paper's Primary Metric

Replicates Eq. (18) from the paper: normalises by mean country demand so all 35 series contribute equally.

In [None]:
def compute_ncrps(preds_np, df_meta, quantiles_list):
    """Compute the normalised CRPS (N-CRPS), overall and per series.

    For each series c:
        ncrps_c = 100 * mean_{i,h,q} pinball(y, f) / y_bar_c
    and the overall score is the unweighted mean of ``ncrps_c`` over series.
    Replicates Eq. (18) in the paper.

    Non-finite targets (NaN / ±inf) are left out of both the pinball average
    and ``y_bar_c``; a series with no finite target, or with a mean target of
    zero, is left out of the result.

    Args:
        preds_np: Array of predictions, indexed as [N, H, Q].
        df_meta: DataFrame with a ``target`` column (one length-H array per
            row) and a ``series_id`` column, row-aligned with ``preds_np``.
        quantiles_list: The Q quantile levels of the last prediction axis.

    Returns:
        Tuple ``(overall, per_series)`` where ``per_series`` maps
        ``int(series_id)`` to that series' N-CRPS.
    """
    q_arr   = np.array(quantiles_list)
    targets = np.array(list(df_meta.target))      # [N, H]
    series  = df_meta.series_id.values
    country_ncrps = {}

    for sid in np.unique(series):
        mask   = series == sid
        tgt    = targets[mask]                    # [Nc, H]
        pred   = preds_np[mask]                   # [Nc, H, Q]

        # Keep only the (window, step) points with a usable target.
        finite = np.isfinite(tgt)                 # [Nc, H]
        if not finite.any():
            continue
        y_bar = tgt[finite].mean()
        if y_bar == 0:
            continue

        # Pinball loss on the valid points: [M, Q]
        err = tgt[finite][:, np.newaxis] - pred[finite]
        pb  = np.where(err >= 0, q_arr * err, (q_arr - 1) * err)
        country_ncrps[int(sid)] = pb.mean() / y_bar * 100

    ncrps_all = list(country_ncrps.values())
    return np.mean(ncrps_all), country_ncrps


# N-CRPS of the raw ensemble.
preds_np_raw = predictions_ensemble.numpy()
ncrps_raw, ncrps_per_country_raw = compute_ncrps(preds_np_raw, df, EVAL_QUANTILES)
print(f"N-CRPS (raw ensemble)      : {ncrps_raw:.4f}")

# N-CRPS of the calibrated ensemble, only when a calibrator was loaded.
if calibrator is not None:
    preds_np_cal = predictions_calibrated.numpy()
    ncrps_cal, ncrps_per_country_cal = compute_ncrps(preds_np_cal, df, EVAL_QUANTILES)
    print(f"N-CRPS (calibrated ensemble): {ncrps_cal:.4f}")

# Compare vs paper baseline (AQ-NBEATS: 211.22 CRPS, 1.84 N-CRPS)
print(f"\nPaper AQ-NBEATS baseline   : N-CRPS 1.84")
print(f"Your model (raw)           : N-CRPS {ncrps_raw:.4f}")
# Same `is not None` guard as where ncrps_cal is defined, so the line is
# printed exactly when the value exists, whatever the calibrator's truthiness.
if calibrator is not None:
    print(f"Your model (calibrated)    : N-CRPS {ncrps_cal:.4f}")

In [None]:
# Per-country N-CRPS breakdown: one bar per series ID, with the paper's
# AQ-NBEATS score as a horizontal reference line.
import matplotlib.pyplot as plt

countries = sorted(ncrps_per_country_raw.keys())
vals_raw  = [ncrps_per_country_raw[c] for c in countries]

fig, ax = plt.subplots(figsize=(16, 5))
x = np.arange(len(countries))
ax.bar(x, vals_raw, color="#2196F3", alpha=0.8, label="Raw")
# Same `is not None` guard as where ncrps_per_country_cal is defined.
if calibrator is not None:
    vals_cal = [ncrps_per_country_cal[c] for c in countries]
    # Drawn semi-transparent over the raw bars at the same x positions.
    ax.bar(x, vals_cal, color="#4CAF50", alpha=0.6, label="Calibrated")
ax.axhline(1.84, color="red", lw=1.5, ls="--", label="Paper AQ-NBEATS (1.84)")
ax.set_xticks(x)
ax.set_xticklabels(countries, rotation=45, fontsize=8)
ax.set(xlabel="Series ID", ylabel="N-CRPS", title="Per-Country N-CRPS")
ax.legend(); ax.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
# This is the first cell that writes into RESULTS_DIR, so make sure the
# folder exists; savefig does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)
plt.savefig(f"{RESULTS_DIR}/ncrps_per_country.png", dpi=150, bbox_inches="tight")
plt.show()

## 8. Save Per-Series Results (compatible with paper format)

In [None]:
# Decide which predictions to save: calibrated when a calibrator was loaded,
# raw otherwise.
predictions_to_save = (
    predictions_calibrated if calibrator is not None else predictions_ensemble
)
preds_np_save = predictions_to_save.numpy()  # [N, H, Q]

# The output folder is labelled with `calibrated=…`, so the per-seed forecasts
# written into it must come from the matching source: the re-queried per-seed
# predictions when a calibrator was applied, the raw ones otherwise.
seed_preds_to_save = cal_preds_list if calibrator is not None else per_seed_preds

RESULTS_PATH = os.path.join(RESULTS_DIR, f"MHLV/ExogSeries-calibrated={calibrator is not None}")
os.makedirs(RESULTS_PATH, exist_ok=True)

targets_all = np.array(list(df.target))      # [N, H]
q_arr       = np.array(EVAL_QUANTILES)       # [Q]

# Save one pickle per seed and series (compatibility with downstream scripts).
# Only the first seed's file also carries the actuals and quantile levels.
for worker_idx, seed_preds in enumerate(seed_preds_to_save):
    p = seed_preds.numpy()  # [N, H, Q]

    for series_id in df.series_id.unique():
        mask       = df.series_id == series_id
        p_series   = p[mask.values]              # [Nc, H, Q]

        forec = pd.DataFrame({f"forec{worker_idx+1}": p_series.ravel()})
        if worker_idx == 0:
            # NOTE(review): nan_to_num maps NaN -> 0.0 by default, while
            # posinf=np.nan maps +inf -> NaN; kept unchanged for format
            # compatibility — confirm this is the intended encoding.
            target_series = np.nan_to_num(targets_all[mask.values], posinf=np.nan)  # [Nc, H]
            # Repeat targets and quantile levels to the prediction shape
            # [Nc, H, Q] so all three columns flatten in the same order.
            target_rep    = np.repeat(target_series[..., None], len(EVAL_QUANTILES), axis=-1)
            quant_rep     = np.broadcast_to(q_arr, p_series.shape).copy()

            forec["actuals"] = target_rep.ravel()
            forec["quants"]  = quant_rep.ravel()
            forec = forec[["actuals", "quants", "forec1"]]

        out_path = os.path.join(RESULTS_PATH, f"e1w{worker_idx+1}_{int(series_id)}.pickle")
        forec.to_pickle(out_path)

print(f"✅ Results saved to {RESULTS_PATH}/")

## 9. Final Summary vs Paper Baselines

In [None]:
# Reference scores hard-coded here as the paper's baselines.
paper_baselines = {
    "Naive"    : {"crps": 502.62, "ncrps": 3.95, "mape": 5.08},
    "ARIMA"    : {"crps": 353.49, "ncrps": 2.83, "mape": 3.74},
    "ES"       : {"crps": 325.80, "ncrps": 2.64, "mape": 3.48},
    "WaveNet"  : {"crps": 293.38, "ncrps": 2.52, "mape": 3.39},
    "AQ-ESRNN" : {"crps": 195.94, "ncrps": 1.72, "mape": 2.32},
    "AQ-NBEATS": {"crps": 211.22, "ncrps": 1.84, "mape": 2.47},
}

# Baseline rows first, in dictionary order.
rows = [
    {"Model": name, "CRPS": v["crps"], "N-CRPS": v["ncrps"], "MAPE %": v["mape"]}
    for name, v in paper_baselines.items()
]

# Then this notebook's results; the calibrated row only when it was computed.
our_results = [("Exog+Series (raw)", metrics_raw, ncrps_raw)]
if metrics_cal:
    our_results.append(("Exog+Series (calibrated)", metrics_cal, ncrps_cal))

for model_name, model_metrics, model_ncrps in our_results:
    rows.append({
        "Model"  : model_name,
        "CRPS"   : model_metrics["crps"],
        "N-CRPS" : model_ncrps,
        # Multiplied by 100 to sit alongside the baselines' MAPE % column.
        "MAPE %" : model_metrics["mape"] * 100,
    })

# Best (lowest) CRPS at the top.
summary_df = pd.DataFrame(rows).set_index("Model").sort_values("CRPS")

# Highlight our rows
def highlight_ours(row):
    """Return one CSS string per cell of ``row`` for use with ``Styler.apply``.

    Rows whose index label contains "Exog" get a light green background;
    all other rows get empty styles.
    """
    css = "background-color: #e8f5e9" if "Exog" in row.name else ""
    return [css] * len(row)

# Render the table with our rows highlighted (axis=1 applies the styler row
# by row) and fixed decimal places per column.
# NOTE(review): `display` is not imported in this file — presumably the
# IPython/Jupyter built-in; confirm before running outside a notebook.
display(summary_df.style.apply(highlight_ours, axis=1).format(
    {"CRPS": "{:.2f}", "N-CRPS": "{:.4f}", "MAPE %": "{:.2f}"}
))

In [None]:
# Write the benchmark table to CSV, creating the results folder if needed.
os.makedirs(RESULTS_DIR, exist_ok=True)
summary_df.to_csv(f"{RESULTS_DIR}/benchmark_summary.csv")
print(f"✅ Summary saved to {RESULTS_DIR}/benchmark_summary.csv")