
# 04 — Model Comparison: Baselines vs XGB and STL+XGB

**Project:** Sea Surface Temperature Anomaly Forecasting (SSTA)  
**Notebook:** 04_models.ipynb  
**Purpose:** Concise, academically styled comparison of multiple models using artifacts (metrics and predictions) from a rolling-origin evaluation.

## Scope
- Load **metrics** and **predictions** from:
  - Baselines (`metrics_<project>.csv`, `preds_<project>.parquet`)
  - XGB direct (`metrics_xgb_<project>.csv`, `preds_xgb_<project>.parquet`)
  - STL+XGB hybrid (`metrics_stl_xgb_<project>.csv`, `preds_stl_xgb_<project>.parquet`)
- Build horizon-wise **leaderboards** (MAE, RMSE, sMAPE, MASE).
- Plot **MAE by model per horizon** and **forecast vs. actuals** overlays.
- Provide interpretation prompts for reporting.



## 0. Setup & configuration

Point to your processed artifacts directory and project name.  
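If `<stem>_<PROJECT>.csv` / `<stem>_<PROJECT>.parquet` is missing, the loader falls back to the alphabetically last matching `<stem>_*` file in `PROCESSED_DIR`. Check the printed file names to confirm that the intended artifacts were picked up.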


In [None]:

# --- Configuration ---
# Project tag embedded in artifact file names, e.g. metrics_<PROJECT>.csv.
PROJECT = "plymouth"     # <- change if your files use a different project name
# Directory that holds the processed metrics/predictions artifacts.
PROCESSED_DIR = "data/processed"

import warnings

# Silence library warnings so notebook output stays readable.
warnings.simplefilter("ignore")

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Model-family tags that extend a shorter artifact stem (``metrics`` ->
# ``metrics_xgb`` / ``metrics_stl_xgb``). The fallback glob ``<stem>_*`` would
# otherwise also match those files and could load the wrong family's artifact.
_FAMILY_TAGS = ("xgb", "stl_xgb")


def _detect_artifact(stem: str, project: str, suffix: str, processed_dir=None) -> Path:
    """Resolve ``<stem>_<project><suffix>`` inside the processed directory.

    If the exact file is absent, fall back to the lexicographically last
    ``<stem>_*<suffix>`` file, excluding files that belong to a longer stem
    (another model family, see ``_FAMILY_TAGS``).

    Args:
        stem: Artifact prefix, e.g. ``"metrics"`` or ``"preds_xgb"``.
        project: Project tag used in the file name.
        suffix: File extension including the dot (``".csv"``/``".parquet"``).
        processed_dir: Directory to search; defaults to ``PROCESSED_DIR``.

    Raises:
        FileNotFoundError: if neither the exact file nor any fallback exists.
    """
    base = Path(PROCESSED_DIR if processed_dir is None else processed_dir)
    cand = base / f"{stem}_{project}{suffix}"
    if cand.exists():
        return cand
    prefix = f"{stem}_"
    foreign = tuple(f"{tag}_" for tag in _FAMILY_TAGS)
    matches = sorted(
        p for p in base.glob(f"{prefix}*{suffix}")
        # e.g. for stem "metrics", skip "metrics_xgb_*" and "metrics_stl_xgb_*".
        if not p.name[len(prefix):].startswith(foreign)
    )
    if not matches:
        raise FileNotFoundError(f"No {stem}_<project>{suffix} found in {base}")
    # Surfacing the substitution matters: the fallback may belong to another project.
    print(f"[warn] {cand.name} not found; falling back to {matches[-1].name}")
    return matches[-1]


def detect_metrics_path(stem: str, project: str, processed_dir=None) -> Path:
    """Locate the metrics CSV ``<stem>_<project>.csv`` (with fallback)."""
    return _detect_artifact(stem, project, ".csv", processed_dir)


def detect_preds_path(stem: str, project: str, processed_dir=None) -> Path:
    """Locate the predictions parquet ``<stem>_<project>.parquet`` (with fallback)."""
    return _detect_artifact(stem, project, ".parquet", processed_dir)

# Resolve the metrics/predictions artifact paths for each model family
# (baselines, XGB direct, STL+XGB hybrid).
m_base, p_base = detect_metrics_path("metrics", PROJECT), detect_preds_path("preds", PROJECT)
m_xgb, p_xgb = detect_metrics_path("metrics_xgb", PROJECT), detect_preds_path("preds_xgb", PROJECT)
m_stl, p_stl = (
    detect_metrics_path("metrics_stl_xgb", PROJECT),
    detect_preds_path("preds_stl_xgb", PROJECT),
)

print("Loaded:")
for _metrics_path, _preds_path in ((m_base, p_base), (m_xgb, p_xgb), (m_stl, p_stl)):
    print(" ", _metrics_path.name, _preds_path.name)

# Read the three metrics tables in family order.
metrics_base, metrics_xgb, metrics_stl = (pd.read_csv(path) for path in (m_base, m_xgb, m_stl))

def load_preds(path: Path) -> pd.DataFrame:
    """Read a predictions parquet and return it tidy and ordered.

    ``date`` is parsed to datetimes; unparseable dates become NaT and those
    rows are dropped. Rows are sorted by whichever of ``date``, ``horizon``
    and ``model`` are present. Sorting only by present columns keeps files
    without a ``model`` label loadable, so the caller can fill in a default
    name afterwards.

    Args:
        path: Location of the parquet file.

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    df = pd.read_parquet(path)
    df["date"] = pd.to_datetime(df["date"], errors="coerce")
    sort_keys = [c for c in ("date", "horizon", "model") if c in df.columns]
    return df.dropna(subset=["date"]).sort_values(sort_keys).reset_index(drop=True)

preds_base = load_preds(p_base)
preds_xgb  = load_preds(p_xgb)
preds_stl  = load_preds(p_stl)


def _ensure_model_column(df: pd.DataFrame, default: str) -> None:
    """Give ``df`` a complete ``model`` label column, in place.

    Adds the column filled with ``default`` when it is absent; otherwise fills
    missing (NaN/None) labels with ``default``. NaN labels would otherwise
    never match a model name in later lookups and would show up as "nan" in plots.
    """
    if "model" in df.columns:
        df["model"] = df["model"].fillna(default)
    else:
        df["model"] = default


# Normalize model names for neat plots (XGB-family artifacts may lack a label).
for _df, _default in (
    (metrics_xgb, "xgb_direct"),
    (metrics_stl, "stl+xgb_resid"),
    (preds_xgb, "xgb_direct"),
    (preds_stl, "stl+xgb_resid"),
):
    _ensure_model_column(_df, _default)



## 1. Unified leaderboards

We concatenate all metrics and produce per-horizon leaderboards (sorted by MAE).


In [None]:

# Stack all three metric tables (columns missing from a table become NaN),
# then order rows by horizon and, within a horizon, by ascending MAE.
metrics_all = (
    pd.concat([metrics_base, metrics_xgb, metrics_stl], ignore_index=True, sort=False)
    .sort_values(["horizon", "MAE"])
    .reset_index(drop=True)
)

print("Unified metrics (top 12 rows):")
print(metrics_all.head(12).to_string(index=False))

def leaderboard_for_h(h, metrics=None):
    """Return the leaderboard for horizon ``h``, best (lowest MAE) first.

    Args:
        h: Forecast horizon to select.
        metrics: Metrics table to rank; defaults to the notebook-level
            ``metrics_all``.

    Returns:
        A new DataFrame with the ``horizon``, ``model``, ``MAE``, ``RMSE``,
        ``sMAPE`` and ``MASE`` columns, keeping only those that exist. Not
        every artifact reports every metric, and a hard column list would
        raise ``KeyError``.
    """
    if metrics is None:
        metrics = metrics_all
    cols = [c for c in ("horizon", "model", "MAE", "RMSE", "sMAPE", "MASE") if c in metrics.columns]
    board = metrics.loc[metrics["horizon"] == h, cols]
    # Stable sort: ties keep input order, so the "best model" pick is deterministic.
    return board.sort_values("MAE", kind="mergesort")

# Print one leaderboard per forecast horizon, shortest horizon first.
for horizon in sorted(metrics_all["horizon"].unique()):
    header = "\n=== LEADERBOARD: Horizon +%d months ===" % horizon
    print(header)
    print(leaderboard_for_h(horizon).to_string(index=False))



## 2. MAE by model per horizon

One bar chart of MAE per horizon (a separate figure for each), with models in ascending-MAE order.


In [None]:

def plot_mae_by_model(h):
    """Show a bar chart of every model's MAE at forecast horizon ``h``.

    Bars follow the row order of ``metrics_all`` for that horizon; model names
    label the x axis.
    """
    rows = metrics_all.loc[metrics_all["horizon"] == h]
    names = rows["model"].tolist()
    maes = rows["MAE"].tolist()
    ticks = range(len(names))

    plt.figure(figsize=(8, 3))
    plt.bar(ticks, maes)
    plt.xticks(ticks, names, rotation=20, ha="right")
    plt.ylabel("MAE (°C)")
    plt.title(f"MAE by Model — Horizon +{h} months")
    plt.tight_layout()
    plt.show()

# One MAE bar chart per horizon, in ascending horizon order.
for horizon in map(int, sorted(metrics_all["horizon"].unique())):
    plot_mae_by_model(horizon)



## 3. Forecast vs. actuals (overlay)

We overlay forecasts from the **best model** (lowest MAE) at each horizon against the actuals.  
(Forecasts and actuals are taken from the predictions file of the model family the best model belongs to.)


In [None]:

def get_preds_for_model(name_keyword: str, *, base=None, xgb=None, stl=None):
    """Return the predictions table that holds forecasts for ``name_keyword``.

    Keyword routing proposes a candidate family first: names containing
    "stl" map to STL+XGB, otherwise names containing "xgb" map to XGB, and
    everything else maps to the baselines. A table whose ``model`` column
    contains the exact name is preferred, checking the routed candidate
    first. If no table contains it, the routed candidate is returned (the
    original behaviour).

    Args:
        name_keyword: Model name as it appears in the metrics tables.
        base, xgb, stl: Optional predictions tables that override the
            notebook-level ``preds_base``, ``preds_xgb`` and ``preds_stl``.

    Returns:
        One of the three predictions DataFrames (not a copy).
    """
    base = preds_base if base is None else base
    xgb = preds_xgb if xgb is None else xgb
    stl = preds_stl if stl is None else stl

    if "stl" in name_keyword:
        routed = stl
    elif "xgb" in name_keyword:
        routed = xgb
    else:
        routed = base

    # Exact membership beats substring routing (e.g. "XGB direct" vs "xgb").
    for df in (routed, base, xgb, stl):
        if "model" in df.columns and (df["model"] == name_keyword).any():
            return df
    return routed

def plot_best_overlay(h):
    """Plot actual SSTA against the lowest-MAE model's forecasts at horizon ``h``.

    The best model is the top row of ``leaderboard_for_h(h)``. Its
    predictions, and the actuals (``y_true``), come from the predictions
    table chosen by ``get_preds_for_model``. The function prints a message and
    draws nothing when the leaderboard or the matching prediction rows are
    empty, instead of producing a chart with only the actuals line.
    """
    board = leaderboard_for_h(h)
    if board.empty:
        print(f"No metrics for horizon={h}.")
        return
    best_model = str(board.iloc[0]["model"])
    source = get_preds_for_model(best_model)
    p = source[source["horizon"] == h]
    if p.empty:
        print(f"No predictions for horizon={h} in {best_model}.")
        return
    pm = p[p["model"] == best_model].sort_values("date")
    if pm.empty:
        # The routed table has this horizon but not this model's rows.
        print(f"No rows for model {best_model!r} at horizon={h}; skipping overlay.")
        return

    # Actuals (from the selected preds table); one point per distinct (date, y_true).
    actual = p[["date", "y_true"]].drop_duplicates().sort_values("date")
    plt.figure(figsize=(10, 4))
    plt.plot(actual["date"], actual["y_true"], label="Actual SSTA")
    plt.plot(pm["date"], pm["y_pred"], label=best_model)
    plt.title(f"Best Model vs Actuals — Horizon +{h} months")
    plt.xlabel("Date")
    plt.ylabel("Anomaly (°C)")
    plt.legend()
    plt.tight_layout()
    plt.show()

# One best-model overlay per horizon, in ascending horizon order.
for horizon in map(int, sorted(metrics_all["horizon"].unique())):
    plot_best_overlay(horizon)



## 4. Relative improvement

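Compare each model’s MAE to the **seasonal naïve** baseline at the same horizon:
$\text{Improvement}\,(\%) = 100 \times \left(1 - \frac{\text{MAE}_{\text{model}}}{\text{MAE}_{\text{seasonal naïve}}}\right)$.
Positive values mean the model beats the baseline.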


In [None]:

def improvement_vs_seasonal_naive(h, metrics=None):
    """Print (and return) each model's MAE improvement over seasonal naïve at ``h``.

    Improvement is ``100 * (1 - MAE_model / MAE_seasonal_naive)``; positive
    values mean the model beats the baseline.

    Args:
        h: Forecast horizon to evaluate.
        metrics: Table with ``horizon``, ``model`` and ``MAE`` columns;
            defaults to the notebook-level ``metrics_all``.

    Returns:
        The printed table sorted by MAE, or ``None`` when there is no usable
        seasonal naïve baseline (missing, zero, or non-finite MAE).
    """
    if metrics is None:
        metrics = metrics_all
    m = metrics[metrics["horizon"] == h].copy()
    # astype(str) maps missing labels to "nan" so they simply never match,
    # instead of yielding a NaN mask that pandas refuses to index with.
    labels = m["model"].astype(str)
    is_base = labels == "seasonal_naive"
    if not is_base.any():
        # Fallback: literal substring match for suffixed baseline names.
        is_base = labels.str.contains("seasonal_naive", regex=False)
    base = m.loc[is_base, "MAE"]
    if base.empty:
        print(f"Seasonal naïve not found for horizon {h}.")
        return None
    b = float(base.iloc[0])
    if not np.isfinite(b) or b == 0.0:
        print(f"Seasonal naïve MAE is {b} for horizon {h}; relative improvement is undefined.")
        return None
    m = m.assign(Improvement_vs_Seasonal_Naive_pct=100.0 * (1.0 - m["MAE"] / b))
    table = m[["model", "MAE", "Improvement_vs_Seasonal_Naive_pct"]].sort_values("MAE")
    print(table.to_string(index=False))
    return table

# Report relative improvement over seasonal naïve, one horizon at a time.
for horizon in sorted(metrics_all["horizon"].unique()):
    print(f"\n=== Improvement vs Seasonal Naïve — h=+{horizon} ===")
    improvement_vs_seasonal_naive(int(horizon))



## 5. Interpretation prompts
- **Winners by horizon:** Which models achieve the lowest MAE at each horizon (e.g., +1, +3, +6 months)? Are the gains consistent across horizons?
- **Practical significance:** Are MAE improvements (°C) large relative to typical SSTA variability?
- **Robustness:** Do results hold across subperiods (e.g., 1982–2000 vs 2001–present)?
- **Complexity vs. gain:** Do advanced models materially outperform seasonal naïve and SARIMA?
- **Next steps:** Tune hyperparameters, add lagged climate indices as features, or test multi-horizon direct models.
