In [None]:
import os
import json
import yaml
import pandas as pd
import numpy as np
import wandb

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from src.constants import model_to_abbr, dataset_to_name
from src.constants import MODELS, BASELINES, DL

from tqdm import tqdm
from collections import defaultdict
from typing import Tuple
from pathlib import Path

os.environ.setdefault(
    "WANDB_CACHE_DIR", "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/artifacts"
)


def model_to_lbw(
    dataset: str,
    model: str,
    params_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/config/params",
) -> int:
    yaml_file = Path(params_path) / model / dataset / "lbw.yaml"
    if not yaml_file.is_file():
        raise FileNotFoundError(f"Missing file: {yaml_file}")
    with yaml_file.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    return int(data["look_back_window"])


def get_metrics(
    runs: list,
    metrics_to_keep: list[str] = [
        "MSE",
        "MAE",
        "DIRACC",
        "MASE",
        "ND",
        "NRMSE",
        "SMAPE",
        "TMAE",
        "TMSE",
    ],
    artifacts_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/artifacts",
) -> Tuple[dict, dict, dict]:
    """
    Preprocess the run metrics for the wandb training runs in the runs list.
    Returns two dataframes, the first storing the mean and the second the standard deviation for
    each metric in the METRICS list defined in src/constants.py

    returns: - raw metric dictionary with keys 'model' -> 'look_back_window' -> 'prediction_window' -> 'fold_nr' -> 'seed'
             - mean dict, taking mean over folds and seeds  keys = 'model' -> 'look_back_window' -> 'prediction_window'
             - std dict, taking std over folds and seeds keys = 'model' -> 'look_back_window' -> 'prediction_window'
    """

    metrics = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
        )
    )
    model_run_counts = defaultdict(int)
    for run in tqdm(runs):
        config = run.config
        model_name = config["model"]["name"]
        dataset_name = config["dataset"]["name"]
        is_global = (
            dataset_name in ["dalia", "wildppg", "ieee"]
        ) or dataset_name == "lmitbih"
        look_back_window = config["look_back_window"]
        prediction_window = config["prediction_window"]
        seed = config["seed"]
        if is_global:
            fold = config["folds"]["fold_nr"]
            summary = run.summary._json_dict
            filtered_summary = {k: summary[k] for k in summary if k in metrics_to_keep}
            metrics[model_name][look_back_window][prediction_window][fold][seed] = (
                filtered_summary
            )
        else:
            raw_artifact = next(
                (a for a in run.logged_artifacts() if "raw_metrics" in a.name), None
            )
            if raw_artifact is None:
                print(f"No raw_metrics table for run {run.name}")
                continue
            else:
                art_dir = Path(artifacts_path) / str(raw_artifact.name).replace(
                    ":", "-"
                )

                if not os.path.exists(art_dir):
                    raw_artifact.download()

                json_path = art_dir / "raw_metrics.table.json"
                with open(json_path, "r", encoding="utf-8") as f:
                    obj = json.load(f)
                data = obj["data"]
                cols = obj["columns"]
                df = pd.DataFrame(data, columns=cols)
                for fold, row in df.iterrows():
                    d = row.to_dict()
                    metrics[model_name][look_back_window][prediction_window][fold][
                        seed
                    ] = d

        model_run_counts[model_name] += 1

    for k, v in model_run_counts.items():
        print(f"Model {k}: {v}")

    processed_metrics_mean = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    processed_metrics_std = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for model, v in metrics.items():
        for lbw, w in v.items():
            for pw, fold_dict in w.items():
                metric_list = defaultdict(list)
                for fold_nr, seed_dict in fold_dict.items():
                    for seed, metric_dict in seed_dict.items():
                        for metric_name, metric_value in metric_dict.items():
                            if metric_name in metrics_to_keep:
                                if isinstance(metric_value, str):
                                    print(
                                        f"VALUE IS STRING {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                elif np.isinf(metric_value):
                                    print(
                                        f"VALUE IS INF {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                elif np.isnan(metric_value):
                                    print(
                                        f"VALUE IS NAN {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                else:
                                    metric_list[metric_name].append(metric_value)

                mean = {
                    metric_name: float(np.mean(v))
                    for metric_name, v in metric_list.items()
                }
                std = {
                    metric_name: float(np.std(v))
                    for metric_name, v in metric_list.items()
                }
                processed_metrics_mean[model][lbw][pw] = mean
                processed_metrics_std[model][lbw][pw] = std
    return metrics, processed_metrics_mean, processed_metrics_std


def get_runs(
    dataset: str,
    look_back_window: list[int],
    prediction_window: list[int],
    models: list[str],
    start_time: str = "2025-6-12",
    feature: str = "mean",
    local_norm_endo_only: bool = False,
    local_norm: list[str] = ["local_z"],
    predictions: bool = False,
) -> list:

    conditions = [
        {"config.use_prediction_callback": predictions},
        {"config.local_norm_endo_only": local_norm_endo_only},
        {"config.feature.name": {"$in": [feature]}},
        {"config.dataset.name": {"$in": [dataset]}},
        {"config.look_back_window": {"$in": look_back_window}},
        {"config.prediction_window": {"$in": prediction_window}},
        {"config.model.name": {"$in": models}},
        {"config.local_norm": {"$in": [local_norm]}},
        {"state": "finished"},
        {"created_at": {"$gte": start_time}},
    ]

    filters = {"$and": conditions}

    api = wandb.Api()
    runs = api.runs("c_keusch/thesis", filters=filters)
    runs = list(runs)
    print(f"Found {len(runs)} runs.")

    # assert len(runs) > 0, "No runs were found!"
    # assert len(runs) % 3 == 0, "Attention, length of runs is not divisible by 3!"
    return runs



In [5]:
datasets = ["dalia", "wildppg", "ieee"]
look_back_window = [30]
prediction_window = [3]
models = MODELS
start_time = "2025-10-28"
feature = "mean"
local_norms = ["local_z", "difference"]
dataset_metrics = {} 

for dataset in datasets:
    metrics = {}
    _, global_only_metrics, _ = get_metrics(get_runs(dataset=dataset, look_back_window=look_back_window, prediction_window=prediction_window, models=models, feature=feature, local_norm_endo_only=False, local_norm="lnone",start_time=start_time))
    metrics["global_only"] = global_only_metrics
    for local_norm in local_norms:
        for local_norm_endo_only in [True, False]:
            _, metr, _ = get_metrics(get_runs(dataset=dataset, look_back_window=look_back_window, prediction_window=prediction_window, models=models, feature=feature, local_norm_endo_only=local_norm_endo_only, local_norm=local_norm, start_time=start_time))
            metrics[f"{local_norm}_{local_norm_endo_only}"] = metr
    dataset_metrics[dataset] = metrics

Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 5816.77it/s]

Model xgboost: 3
Model linear: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model nbeatsx: 3
Model gpt4ts: 3
Model timexer: 3
Model mlp: 3
Model gp: 3
Model kalmanfilter: 3
Model mole: 3
Model msar: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 82049.73it/s]

Model linear: 3
Model xgboost: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 7525.34it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model mole: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 2626.95it/s]

Model xgboost: 3
Model linear: 3
Model gpt4ts: 3
Model simpletm: 3
Model nbeatsx: 3
Model timesnet: 3
Model timexer: 3
Model patchtst: 3
Model adamshyper: 3
Model mlp: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 20852.36it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 21535.55it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 11282.23it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 20904.33it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mlp: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





In [6]:
experiments = ["global_only", "local_z_True", "local_z_False", "difference_True", "difference_False"]

def compute_matrix(models: list[str], metrics: dict):
    all_model_values: list[list[float]] = []
    for model in models:
        values: list[float] = []
        for experiment in experiments:
            values.append(metrics[experiment][model][30][3]["MAE"])
        all_model_values.append(values)
    
    return np.array(all_model_values)


datasets = ["wildppg"]
norms = [
    "GZ",
    "GZ + Inst(Endo)",
    "GZ + Inst",
    "GZ + Diff(Endo)",
    "GZ + Diff",
]

data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(BASELINES, metrics),
        compute_matrix(DL, metrics),
    )

In [8]:
datasets = ["dalia", "wildppg", "ieee"]
norms = ["GZ", "GZ + Inst(Endo)", "GZ + Inst", "GZ + Diff(Endo)", "GZ + Diff"]

# Build matrices (n_models × n_norms)
data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(BASELINES, metrics),
        compute_matrix(DL, metrics),
    )

# consistent range for readability
all_vals = np.concatenate([np.concatenate([v[0], v[1]], axis=0) for v in data.values()], axis=0)
ymin, ymax = np.nanmin(all_vals), np.nanmax(all_vals)
pad = 0.05 * (ymax - ymin if ymax > ymin else 1.0)
yrange = [ymin - pad, ymax + pad]

# --- Subplots (disable shared y-axes)
fig = make_subplots(
    rows=2,
    cols=len(datasets),
    shared_yaxes=False,
    subplot_titles=[dataset_to_name[d] for d in datasets],
    horizontal_spacing=0.06,
    vertical_spacing=0.15
)

for col, d in enumerate(datasets, start=1):
    base_z, dl_z = data[d]

    # Row 1: Baselines
    for j, norm in enumerate(norms):
        fig.add_trace(
            go.Bar(
                x=[model_to_abbr[m] for m in BASELINES],
                y=base_z[:, j],
                name=norm,
                legendgroup=norm,
                showlegend=(col == 1),  # show legend only for first dataset column
                # hovertemplate="Model: %{x}<br>Norm: %{customdata}<br>MAE: %{y:.3f}<extra></extra>",
                # customdata=[norm] * len(BASELINES)
            ),
            row=1, col=col
        )

    # Row 2: Deep Learning
    for j, norm in enumerate(norms):
        fig.add_trace(
            go.Bar(
                x=[model_to_abbr[m] for m in DL],
                y=dl_z[:, j],
                name=norm,
                legendgroup=norm,
                showlegend=(col == 1),
                # hovertemplate="Model: %{x}<br>Norm: %{customdata}<br>MAE: %{y:.3f}<extra></extra>",
                # customdata=[norm] * len(DL)
            ),
            row=2, col=col
        )

# Layout
fig.update_layout(
    barmode="group",
    height=800,
    width=1400,
    template="plotly_white",
    margin=dict(t=80, l=60, r=40, b=40),

)

fig.show()

In [11]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from src.constants import model_to_abbr, dataset_to_name

# --------------------
# Inputs
# --------------------
datasets = ["dalia", "wildppg", "ieee"]
norms = ["GZ", "GZ + Inst(Endo)", "GZ + Inst", "GZ + Diff(Endo)", "GZ + Diff"]

# Fixed colors per normalization (consistent across all rows)
norm_colors = {
    "GZ":               "#4E79A7",
    "GZ + Inst(Endo)":  "#F28E2B",
    "GZ + Inst":        "#E15759",
    "GZ + Diff(Endo)":  "#76B7B2",
    "GZ + Diff":        "#59A14F",
}

bm = [b for b in BASELINES if b != "msar"]
# Build matrices: each is (n_models, n_norms)
data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(bm, metrics),  # baselines (first 7)
        compute_matrix(DL, metrics),         # deep learning (last 7)
    )

# X order: 7 baselines + 7 DL (make sure MSAR is in BASELINES)
x_models_baseline = [model_to_abbr[m] for m in bm]   # includes MSAR
x_models_dl       = [model_to_abbr[m] for m in DL]
x_models_all      = x_models_baseline + x_models_dl

# Global y range (or compute per-row if you prefer)
all_vals = []
for d in datasets:
    bz, dz = data[d]
    all_vals.append(np.concatenate([bz, dz], axis=0))
all_vals = np.concatenate(all_vals, axis=0)
ymin, ymax = float(np.nanmin(all_vals)), float(np.nanmax(all_vals))
pad = 0.05 * (ymax - ymin if ymax > ymin else 1.0)
yrange = [ymin - pad, ymax + pad]

# --------------------
# Figure: 3 rows (one per dataset), 1 column
# --------------------
fig = make_subplots(
    rows=len(datasets), cols=1,
    shared_xaxes=False, shared_yaxes=False,
    subplot_titles=[dataset_to_name[d] for d in datasets],
    vertical_spacing=0.10
)

# Show each normalization once in legend (top row)
legend_shown_for_norm = {n: False for n in norms}

for r, d in enumerate(datasets, start=1):
    base_z, dl_z = data[d]          # shapes: (7, 5) each
    z_full = np.vstack([base_z, dl_z])  # shape: (14, 5)

    # One trace per normalization; Plotly will group bars by x category
    for j, norm in enumerate(norms):
        fig.add_trace(
            go.Bar(
                x=x_models_all,
                y=z_full[:, j],
                name=norm,
                legendgroup=norm,
                # ⚠️ DO NOT set offsetgroup/alignmentgroup here
                marker_color=norm_colors[norm],
                showlegend=(not legend_shown_for_norm[norm] and r == 1),
                hovertemplate="Model: %{x}<br>Norm: " + norm + "<br>MAE: %{y:.3f}<extra></extra>",
            ),
            row=r, col=1
        )
        if r == 1:
            legend_shown_for_norm[norm] = True

    # Axes per row
    fig.update_yaxes(title_text="MAE ↓", row=r, col=1)
    fig.update_xaxes(
        title_text="Model",
        categoryorder="array",
        categoryarray=x_models_all,
        row=r, col=1
    )

    # Optional visual separator between first 7 (Baselines) and last 7 (DL)
    # For categorical axes, a vertical line between index 6 and 7:
    fig.add_vline(
        x=6.5, line_width=1, line_dash="dash", line_color="rgba(0,0,0,0.35)",
        row=r, col=1
    )

fig.update_layout(
    barmode="group",          # group bars by model (x)
    bargap=0.20,
    bargroupgap=0.10,
    height=1000,
    width=1400,
    template="plotly_white",
    margin=dict(t=90, l=60, r=40, b=60),
    legend=dict(
        title="Normalization",
        orientation="h",
        yanchor="bottom",
        y=1.03,
        xanchor="left",
        x=0
    )
)

fig.show()


In [12]:
def get_best_value(
    metric_dict: dict,
    std_dict: dict,
    lbw: int,
    pw: int,
    dataset: str,
    models: list[str],
    metric: str,
) -> Tuple[float, float]:
    best_val: float = np.inf
    best_std: float = np.inf
    for model in models:
        val = (
            metric_dict[model][lbw][pw][metric]
            if metric in metric_dict[model][lbw][pw]
            else np.nan
        )
        print(val)
        if val < best_val:
            best_val = val
            best_std = std_dict[model][lbw][pw][metric]

    return best_val, best_std


def get_df(
    look_back_window: list[int],
    metric_dict: dict,
    std_dict: dict,
    dataset: str,
    models: list[str],
    metric: str,
    pw: int = 3,
):
    means: list[float] = []
    stds: list[float] = []
    for lbw in look_back_window:
        mean, std = get_best_value(
            metric_dict, std_dict, lbw, pw, dataset, models, metric
        )
        means.append(mean)
        stds.append(std)

    d = {"means": means, "stds": stds}
    df = pd.DataFrame.from_dict(d)
    return df


In [None]:

datasets = ["dalia", "wildppg", "ieee"]
models = MODELS
look_back_window = [5, 10, 20, 30, 60]
prediction_window = [3]
metric = "MAE"
start_time = "2025-11-01"


assert len(prediction_window) == 1
pw = prediction_window[0]

dataset_metrics = {}

for dataset in datasets:
    runs_exo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature="mean",
        start_time=start_time,
        local_norm_endo_only=False,
    )
    _, exo_mean, exo_std = get_metrics(runs_exo)

    runs_endo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature="none",
        start_time=start_time,
        local_norm_endo_only=False,
    )
    _, endo_mean, endo_std = get_metrics(runs_endo)

    dataset_metrics[dataset] = (exo_mean, exo_std, endo_mean, endo_std)

Found 105 runs.


100%|██████████| 105/105 [00:00<00:00, 26461.69it/s]

Model simpletm: 15
Model patchtst: 15
Model timexer: 15
Model adamshyper: 15
Model timesnet: 15
Model gpt4ts: 15
Model nbeatsx: 15





Found 105 runs.


100%|██████████| 105/105 [00:00<00:00, 16396.81it/s]

Model simpletm: 15
Model timesnet: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model adamshyper: 15
Model nbeatsx: 15





KeyboardInterrupt: 

In [None]:
from src.constants import (
    BASELINES,
    DL,
    model_colors,
)

use_std = False

TITLE_SIZE = 22 * 2
AXIS_TITLE_SIZE = 40
TICK_SIZE = 32
LEGEND_SIZE = 40
LINE_WIDTH = 8
MARKER_SIZE = 16
LINE_OPACITY = 0.9

SUBPLOT_TITLE_SIZE = 40
LEGEND_Y = -0.18

titles = [f"<b>{dataset_to_name[d]}</b>" for d in datasets]
fig = make_subplots(
    rows=1, cols=len(datasets), subplot_titles=titles, horizontal_spacing=0.08
)


for j, dataset in enumerate(datasets, start=1):

    exo_mean, exo_std, endo_mean, endo_std = dataset_metrics[dataset]

    # best per horizon for baselines and DL
    exo_bl_df = get_df(
        look_back_window, exo_mean, exo_std, dataset, BASELINES, metric, pw
    )
    endo_bl_df = get_df(
        look_back_window, endo_mean, endo_std, dataset, BASELINES, metric, pw
    )
    exo_dl_df = get_df(look_back_window, exo_mean, exo_std, dataset, DL, metric, pw)
    endo_dl_df = get_df(
        look_back_window, endo_mean, endo_std, dataset, DL, metric, pw
    )

    # make indices match horizons (1,3,5,10,20) for clarity
    for _df in (exo_bl_df, endo_bl_df, exo_dl_df, endo_dl_df):
        _df.index = look_back_window

    x = look_back_window
    x_labels = [str(x_val) for x_val in look_back_window]
    bl_color = model_colors["baseline"]
    dl_color = model_colors["dl"]
    showlegend = j == 1

    def add_line(x, y, stds, name, color, dash):
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines+markers",
                line=dict(width=LINE_WIDTH, color=color, dash=dash),
                marker=dict(size=MARKER_SIZE),
                name=name,
                showlegend=showlegend,
                error_y=(
                    dict(type="data", array=stds, visible=True) if use_std else None
                ),
            ),
            row=1,
            col=j,
        )

    # Baselines: Exo solid, Endo dashed
    add_line(
        x,
        exo_bl_df["means"].values,
        exo_bl_df["stds"].values,
        "Baselines — Exo (best)",
        bl_color,
        "solid",
    )
    add_line(
        x,
        endo_bl_df["means"].values,
        endo_bl_df["stds"].values,
        "Baselines — Endo (best)",
        bl_color,
        "dash",
    )

    # DL: Exo solid, Endo dashed
    add_line(
        x,
        exo_dl_df["means"].values,
        exo_dl_df["stds"].values,
        "DL — Exo (best)",
        dl_color,
        "solid",
    )
    add_line(
        x,
        endo_dl_df["means"].values,
        endo_dl_df["stds"].values,
        "DL — Endo (best)",
        dl_color,
        "dash",
    )

    # axes
    fig.update_xaxes(
        title_text="Lookback Window",
        type="category",
        categoryorder="array",
        categoryarray=x_labels,
        row=1,
        col=j,
    )
    fig.update_yaxes(title_text=f"Best {metric}", row=1, col=j)

# layout & legend
fig.update_annotations(font=dict(size=SUBPLOT_TITLE_SIZE))
fig.update_layout(
    legend=dict(
        orientation="h", x=0.5, xanchor="center", y=LEGEND_Y, yanchor="top"
    ),
    margin=dict(b=120),
)
fig.update_xaxes(title_font=dict(size=14))
fig.update_yaxes(title_font=dict(size=14))

subplot_size = 500  # pixels per subplot
rows = 1
cols = len(datasets)

total_width = cols * subplot_size
total_height = rows * subplot_size

fig.update_layout(
    width=total_width,
    height=total_height,
)

fig.show()

nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
