# Benchmark forecasting — validation + plots

Sections:
1. Setup & configuration
2. Load & prepare data
3. Validation (temporal holdout + metrics + calibration)
4. Plots (fit on full data + forecasts by category)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from forecasting import (
    load_dataset,
    prepare_dataset,
    ModelConfig,
    SamplingConfig,
    temporal_holdout,
    crps_score,
    point_error,
    fit,
    generate_forecast,
)

from plotting import (
    Theme,
    make_mpl_style,
    apply_style,
    plot_calibration_curve,
    plot_category_forecast,
)


## 1) Setup & configuration

In [None]:
# Configuration cell: input location, model + sampler settings, global plot styling.

# Input CSV, relative to the notebook's working directory.
DATA_PATH = "benchmark_data_processed/all_normalized_updated_benchmarks.csv"

# Model settings. `top_n` is also reused when preparing the dataset (next cell).
CFG = ModelConfig(sigmoid="harvey", joint=True, top_n=3)

# Sampler settings: draws, tuning steps, target acceptance rate, RNG seed, progress bar.
SAMP = SamplingConfig(
    draws=2000, tune=1000, target_accept=0.9, seed=42, progressbar=True
)

# Figure scale factor: 1.0 suits A4 / slides; for double-column layouts try 1.3–1.6.
PLOT_SCALE = 1.0

# Install the scaled matplotlib style on top of the current settings (reset=False).
apply_style(
    make_mpl_style(scale=PLOT_SCALE),
    reset=False,
)

# Plot theme; customise with THEME = Theme().with_overrides(...).
THEME = Theme()


## 2) Load & prepare data

In [None]:
# Read the raw benchmark table from DATA_PATH ...
raw = load_dataset(DATA_PATH)
# ... and prepare it using the same top_n as the model configuration.
data = prepare_dataset(raw, top_n=CFG.top_n)

# Last expression: rendered as the cell output for a quick look at the prepared rows.
data.head()


## 3) Validation — temporal holdout

In [None]:
# Temporal holdout split at CUTOFF_DATE (see temporal_holdout for how the split is made).
CUTOFF_DATE = pd.to_datetime("2025-01-01")

# NOTE(review): this passes `raw`, not the prepared `data`; presumably temporal_holdout
# prepares each split itself (avoiding look-ahead from post-cutoff points) — confirm.
idata_val = temporal_holdout(
    raw, cutoff_date=CUTOFF_DATE, cfg=CFG, samp=SAMP, min_train_points=3
)

# Probabilistic score first, then point-error metrics (labels padded to align).
print("CRPS:", crps_score(idata_val))
for label, metric in (("RMSE:", "RMSE"), ("MAE :", "MAE")):
    print(label, point_error(idata_val, metric=metric))


In [None]:
# Calibration curve for the holdout fit, drawn on a square 6x6-inch axes.
fig, ax = plt.subplots(figsize=(6, 6))
plot_calibration_curve(
    idata_val,
    ax=ax,
    n_points=20,  # resolution of the calibration curve
)
plt.show()


## 4) Plots — fit on full data and forecast

In [None]:
# Forecast horizon for the full-data fit.
END_DATE = pd.to_datetime("2030-03-01")

# Refit on the complete prepared dataset (no holdout this time).
idata, model = fit(data, CFG, SAMP)

# Project forward to END_DATE: n_points=250 output points, ci_level=0.8 interval level.
forecast_df = generate_forecast(
    idata, model, prepared_frontier=data, end_date=END_DATE, n_points=250, ci_level=0.8
)

# Last expression: rendered as the cell output.
forecast_df.head()


In [None]:
# ---- Per-category forecast plots ----
# One figure per category, showing that category's observed rows and forecast rows.
# If the frames cannot both be split by category, a single combined figure labelled
# "all" is drawn instead.
has_obs_category = "category" in data.columns
has_pred_category = "category" in forecast_df.columns
split_by_category = has_obs_category and has_pred_category

if has_obs_category and not has_pred_category:
    # Filtering forecast_df by category would raise KeyError; plot everything together.
    print("forecast_df has no 'category' column; plotting all categories in one figure.")

if split_by_category:
    categories = list(data["category"].dropna().unique())
else:
    categories = ["all"]

for cat in categories:
    # Branch on the split flag rather than on `cat == "all"`, so a genuine category
    # named "all" is still filtered instead of silently showing every row.
    if split_by_category:
        obs_cat = data.loc[data["category"] == cat]
        pred_cat = forecast_df.loc[forecast_df["category"] == cat]
    else:
        obs_cat, pred_cat = data, forecast_df

    plot_category_forecast(
        observed=obs_cat,
        forecast=pred_cat,
        end_date=END_DATE,
        category_label=cat,
        theme=THEME,
        scale=PLOT_SCALE,
        figsize=(7, 4),
    )
    plt.show()
