In [24]:
import os
import json
import yaml
import pandas as pd
import numpy as np
import wandb

from tqdm import tqdm
from collections import defaultdict
from typing import Tuple
from pathlib import Path
from src.constants import MODELS, BASELINES, DL

# Give wandb a default cache directory inside the project's artifacts folder,
# but only when the environment has not already configured one.
if "WANDB_CACHE_DIR" not in os.environ:
    os.environ["WANDB_CACHE_DIR"] = (
        "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/artifacts"
    )


def model_to_lbw(
    dataset: str,
    model: str,
    params_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/config/params",
) -> int:
    """Return the configured look-back window for ``model`` on ``dataset``.

    The value is read from ``<params_path>/<model>/<dataset>/lbw.yaml`` under the
    ``look_back_window`` key and converted to ``int``.

    Raises:
        FileNotFoundError: if the YAML file does not exist.
        KeyError: if the file has no ``look_back_window`` entry (an empty file
            is treated as an empty mapping).
    """
    config_path = Path(params_path).joinpath(model, dataset, "lbw.yaml")
    if config_path.is_file():
        with open(config_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
        return int((parsed or {})["look_back_window"])
    raise FileNotFoundError(f"Missing file: {config_path}")


# Datasets whose runs report per-fold metrics directly in the run summary; runs on
# any other dataset are read from their logged ``raw_metrics`` table artifact.
_GLOBAL_DATASETS = frozenset({"dalia", "wildppg", "ieee", "lmitbih"})


def get_metrics(
    runs: list,
    metrics_to_keep: Tuple[str, ...] = (
        "MSE",
        "MAE",
        "DIRACC",
        "MASE",
        "ND",
        "NRMSE",
        "SMAPE",
        "TMAE",
        "TMSE",
    ),
    artifacts_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/artifacts",
) -> Tuple[dict, dict, dict]:
    """
    Collect and aggregate evaluation metrics for a list of wandb runs.

    Runs on a "global" dataset (dalia, wildppg, ieee, lmitbih) contribute one fold
    each, read from the run summary and restricted to ``metrics_to_keep``. All
    other runs contribute one entry per row of their logged ``raw_metrics`` table
    artifact, using the row index as the fold number. Only metric names in
    ``metrics_to_keep`` with finite numeric values enter the mean/std; rejected
    values are printed and skipped.

    Args:
        runs: wandb run objects (e.g. the list returned by ``get_runs``).
        metrics_to_keep: names of the metrics to keep (any iterable of str).
        artifacts_path: directory under which ``raw_metrics`` artifacts are
            cached / downloaded.

    Returns:
        A tuple of three nested dicts:
        - raw metrics, keyed 'model' -> 'look_back_window' -> 'prediction_window'
          -> 'fold_nr' -> 'seed' -> {metric_name: value}
        - mean over folds and seeds, keyed 'model' -> 'look_back_window'
          -> 'prediction_window' -> {metric_name: mean}
        - population standard deviation (ddof=0) over folds and seeds, with the
          same keys as the mean dict.
    """
    keep = set(metrics_to_keep)
    metrics = _collect_raw_metrics(runs, keep, Path(artifacts_path))
    processed_metrics_mean, processed_metrics_std = _aggregate_metrics(metrics, keep)
    return metrics, processed_metrics_mean, processed_metrics_std


def _collect_raw_metrics(runs: list, keep: set, artifacts_dir: Path) -> dict:
    """Build the nested raw-metric dict for ``runs`` and print per-model run counts."""
    metrics = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
        )
    )
    model_run_counts = defaultdict(int)
    for run in tqdm(runs):
        config = run.config
        model_name = config["model"]["name"]
        look_back_window = config["look_back_window"]
        prediction_window = config["prediction_window"]
        seed = config["seed"]

        if config["dataset"]["name"] in _GLOBAL_DATASETS:
            summary = run.summary._json_dict
            fold_entries = [
                (
                    config["folds"]["fold_nr"],
                    {k: v for k, v in summary.items() if k in keep},
                )
            ]
        else:
            table = _load_raw_metrics_table(run, artifacts_dir)
            if table is None:
                print(f"No raw_metrics table for run {run.name}")
                continue
            fold_entries = [(fold, row.to_dict()) for fold, row in table.iterrows()]

        # Index lazily so skipped runs / empty tables leave no empty entries
        # behind in the defaultdict (they would otherwise show up in the aggregates).
        for fold, values in fold_entries:
            metrics[model_name][look_back_window][prediction_window][fold][seed] = values

        model_run_counts[model_name] += 1

    for name, count in model_run_counts.items():
        print(f"Model {name}: {count}")
    return metrics


def _load_raw_metrics_table(run, artifacts_dir: Path):
    """Return the run's ``raw_metrics`` table as a DataFrame, or None if it has none.

    The artifact is downloaded into ``artifacts_dir`` only when its JSON file is
    not already present there.
    """
    raw_artifact = next(
        (a for a in run.logged_artifacts() if "raw_metrics" in a.name), None
    )
    if raw_artifact is None:
        return None
    # Colons (as in "name:v0") are replaced because they are not valid in
    # Windows file names.
    art_dir = artifacts_dir / str(raw_artifact.name).replace(":", "-")
    json_path = art_dir / "raw_metrics.table.json"
    if not json_path.is_file():
        # Download explicitly into art_dir: without `root`, wandb picks its own
        # default location, which need not match the path read below.
        raw_artifact.download(root=str(art_dir))
    with json_path.open("r", encoding="utf-8") as f:
        obj = json.load(f)
    return pd.DataFrame(obj["data"], columns=obj["columns"])


def _aggregate_metrics(metrics: dict, keep: set) -> Tuple[dict, dict]:
    """Mean and std of every kept, finite numeric metric over all folds and seeds."""
    processed_metrics_mean = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    processed_metrics_std = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for model, by_lbw in metrics.items():
        for lbw, by_pw in by_lbw.items():
            for pw, fold_dict in by_pw.items():
                metric_list = defaultdict(list)
                for fold_nr, seed_dict in fold_dict.items():
                    for seed, metric_dict in seed_dict.items():
                        context = (
                            f"for model {model} lbw {lbw} pw {pw} "
                            f"seed {seed} fold {fold_nr}"
                        )
                        for metric_name, metric_value in metric_dict.items():
                            if metric_name in keep and _is_finite_number(
                                metric_value, context
                            ):
                                metric_list[metric_name].append(metric_value)

                processed_metrics_mean[model][lbw][pw] = {
                    name: float(np.mean(vals)) for name, vals in metric_list.items()
                }
                processed_metrics_std[model][lbw][pw] = {
                    name: float(np.std(vals)) for name, vals in metric_list.items()
                }
    return processed_metrics_mean, processed_metrics_std


def _is_finite_number(value, context: str) -> bool:
    """Return True if ``value`` is a finite number; otherwise print why and return False."""
    if isinstance(value, str):
        print(f"VALUE IS STRING {value} {context}")
        return False
    try:
        if np.isinf(value):
            print(f"VALUE IS INF {value} {context}")
            return False
        if np.isnan(value):
            print(f"VALUE IS NAN {value} {context}")
            return False
    except TypeError:
        # e.g. None (a JSON null in the table) or a nested dict from the summary;
        # previously this crashed the whole aggregation.
        print(f"VALUE IS NOT NUMERIC {value!r} {context}")
        return False
    return True


def get_runs(
    dataset: str,
    look_back_window: list[int],
    prediction_window: list[int],
    models: list[str],
    start_time: str = "2025-06-12",
    feature: str = "mean",
    local_norm_endo_only: bool = False,
    local_norm: str = "local_z",
    predictions: bool = False,
    project: str = "c_keusch/thesis",
) -> list:
    """
    Fetch finished wandb runs that match one experiment configuration.

    All arguments are combined into a single ``$and`` filter on the run config,
    together with ``state == "finished"`` and ``created_at >= start_time``.

    Args:
        dataset: required ``config.dataset.name``.
        look_back_window: accepted ``config.look_back_window`` values.
        prediction_window: accepted ``config.prediction_window`` values.
        models: accepted ``config.model.name`` values.
        start_time: lower bound on the run creation time. Pass a zero-padded ISO
            date (YYYY-MM-DD) so it also orders correctly if compared as a string.
        feature: required ``config.feature.name``.
        local_norm_endo_only: required ``config.local_norm_endo_only``.
        local_norm: required ``config.local_norm``.
        predictions: required ``config.use_prediction_callback``.
        project: wandb ``entity/project`` path to query.

    Returns:
        The matching runs, materialized as a list.
    """
    conditions = [
        {"config.use_prediction_callback": predictions},
        {"config.local_norm_endo_only": local_norm_endo_only},
        {"config.feature.name": {"$in": [feature]}},
        {"config.dataset.name": {"$in": [dataset]}},
        {"config.look_back_window": {"$in": look_back_window}},
        {"config.prediction_window": {"$in": prediction_window}},
        {"config.model.name": {"$in": models}},
        {"config.local_norm": {"$in": [local_norm]}},
        {"state": "finished"},
        {"created_at": {"$gte": start_time}},
    ]

    api = wandb.Api()
    runs = list(api.runs(project, filters={"$and": conditions}))
    print(f"Found {len(runs)} runs.")
    return runs



In [69]:
datasets = ["dalia", "wildppg", "ieee"]
look_back_window = [30]
prediction_window = [3]
models = MODELS
start_time = "2025-10-28"
feature = "mean"
local_norms = ["local_z", "difference"]
dataset_metrics = {}

# Query arguments shared by every experiment; only the dataset and the
# normalization settings change between calls.
shared_query = dict(
    look_back_window=look_back_window,
    prediction_window=prediction_window,
    models=models,
    feature=feature,
    start_time=start_time,
)

for dataset in datasets:
    # experiment name -> mean metrics (model -> lbw -> pw -> metric name -> value)
    metrics = {
        "global_only": get_metrics(
            get_runs(
                dataset=dataset,
                local_norm_endo_only=False,
                local_norm="lnone",
                **shared_query,
            )
        )[1]
    }
    for local_norm in local_norms:
        for local_norm_endo_only in (True, False):
            runs = get_runs(
                dataset=dataset,
                local_norm_endo_only=local_norm_endo_only,
                local_norm=local_norm,
                **shared_query,
            )
            metrics[f"{local_norm}_{local_norm_endo_only}"] = get_metrics(runs)[1]
    dataset_metrics[dataset] = metrics

Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]


Model xgboost: 3
Model linear: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model nbeatsx: 3
Model gpt4ts: 3
Model timexer: 3
Model mlp: 3
Model gp: 3
Model kalmanfilter: 3
Model mole: 3
Model msar: 3
Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 27452.20it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model mole: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 2673.68it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model gpt4ts: 3
Model simpletm: 3
Model nbeatsx: 3
Model timesnet: 3
Model timexer: 3
Model patchtst: 3
Model adamshyper: 3
Model mlp: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 26526.24it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mlp: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 7372.59it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 7879.09it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 17786.83it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





In [74]:
experiments = ["global_only", "local_z_True", "local_z_False", "difference_True", "difference_False"]


def compute_matrix(
    models: list[str],
    metrics: dict,
    lbw: int = 30,
    pw: int = 3,
    metric: str = "MASE",
    experiment_keys=None,
) -> np.ndarray:
    """Tabulate one metric for every (model, experiment) pair.

    Args:
        models: model names; one row per model, in this order.
        metrics: experiment name -> model -> look-back window -> prediction
            window -> {metric name: value}, e.g. ``dataset_metrics[<dataset>]``.
        lbw: look-back window to read (default 30).
        pw: prediction window to read (default 3).
        metric: metric name to read (default "MASE").
        experiment_keys: experiment names, one column each; defaults to the
            module-level ``experiments`` list.

    Returns:
        Float array of shape (len(models), len(experiment_keys)); the shape is
        kept even when ``models`` is empty.

    Raises:
        KeyError: if a requested entry is missing from ``metrics``.
    """
    if experiment_keys is None:
        experiment_keys = experiments
    rows = [
        [metrics[exp][model][lbw][pw][metric] for exp in experiment_keys]
        for model in models
    ]
    # reshape keeps the 2-D shape for an empty model list (np.array([]) is 1-D).
    return np.array(rows, dtype=float).reshape(len(models), len(experiment_keys))

# Sanity check: one row per baseline model, one column per experiment.
np.shape(compute_matrix(BASELINES, metrics))

(7, 5)

In [91]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from src.constants import model_to_abbr, dataset_to_name

datasets = ["wildppg"]
norms = [
    "GZ",
    "GZ + Inst(Endo)",
    "GZ + Inst",
    "GZ + Diff(Endo)",
    "GZ + Diff",
]
# compute_matrix tabulates MASE, so label the hover text accordingly.
metric_label = "MASE"

# dataset -> (baseline matrix, DL matrix), each of shape (n_models, n_norms)
data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(BASELINES, metrics),
        compute_matrix(DL, metrics),
    )

# Shared colour range for all panels; the upper bound is the 90th percentile so a
# few very large errors do not wash out the rest of the colour scale.
all_vals = np.concatenate([np.concatenate([v[0], v[1]], axis=0) for v in data.values()], axis=0)
zmin, zmax = np.percentile(all_vals, [0, 90])
print(zmin, zmax)

fig = make_subplots(
    rows=2, cols=len(datasets),
    subplot_titles=[dataset_to_name[d] for d in datasets],
    horizontal_spacing=0.06, vertical_spacing=0.15
)

hovertemplate = f"Model: %{{y}}<br>Norm: %{{x}}<br>{metric_label}: %{{z}}<extra></extra>"

for col, d in enumerate(datasets, start=1):
    base_z, dl_z = data[d]
    # Row 1: baselines, row 2: deep-learning models.
    for row, (z, model_names) in enumerate(((base_z, BASELINES), (dl_z, DL)), start=1):
        fig.add_trace(
            go.Heatmap(
                z=z,
                x=norms,
                y=[model_to_abbr[m] for m in model_names],
                zmin=zmin, zmax=zmax,
                hovertemplate=hovertemplate,
                # All panels share zmin/zmax, so a single colour bar suffices;
                # drawing one per trace stacks them on top of each other.
                showscale=(row == 1 and col == 1),
            ),
            row=row, col=col
        )

# One x-axis title per column of the bottom row (follows the number of datasets).
for c in range(1, len(datasets) + 1):
    fig.update_xaxes(title="Normalization", row=2, col=c)
fig.update_yaxes(title="Model", row=1, col=1)
fig.update_yaxes(title="Model", row=2, col=1)

fig.update_layout(
    height=800,
    width=1200,
    template="plotly_white",
    margin=dict(t=80, l=60, r=40, b=40)
)

fig.show()


0.7567278544108073 0.9697676440080008


In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from src.constants import model_to_abbr, dataset_to_name

# --------------------
# Inputs
# --------------------
datasets = ["wildppg"]
norms = [
    "GZ",
    "GZ + Inst(Endo)",
    "GZ + Inst",
    "GZ + Diff(Endo)",
    "GZ + Diff",
]
# compute_matrix tabulates MASE, so label axes and hover text accordingly.
metric_label = "MASE"
# Baselines to plot, without msar. A filtered copy is used instead of
# BASELINES.remove("msar"): list.remove() returns None (so compute_matrix would
# receive None and fail) and it would permanently mutate the shared constant.
baselines = [m for m in BASELINES if m != "msar"]

# Build matrices: shape = (n_models, n_norms)
data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(baselines, metrics),  # (len(baselines), len(norms))
        compute_matrix(DL, metrics),         # (len(DL),        len(norms))
    )

# Global y-range across everything (for consistent scaling), padded by 5%
all_vals = np.concatenate([np.concatenate([v[0], v[1]], axis=0) for v in data.values()], axis=0)
ymin, ymax = float(np.nanmin(all_vals)), float(np.nanmax(all_vals))
pad = 0.05 * (ymax - ymin if ymax > ymin else 1.0)
yrange = [ymin - pad, ymax + pad]

# --------------------
# Subplots (2 rows: Baselines / DL)
# --------------------
fig = make_subplots(
    rows=2,
    cols=len(datasets),
    subplot_titles=[dataset_to_name[d] for d in datasets],
    horizontal_spacing=0.06,
    vertical_spacing=0.15
)

hovertemplate = (
    "Model: %{customdata}<br>"
    "Norm: %{x}<br>"
    f"{metric_label}: %{{y:.3f}}<extra></extra>"
)

# (subplot row, models in that row, legend group, legend-name suffix)
panels = [
    (1, baselines, "baselines", "BL"),
    (2, DL, "dl", "DL"),
]

# Grouped bars per normalization; one trace per model
for col, d in enumerate(datasets, start=1):
    matrices = data[d]  # (baseline matrix, DL matrix), each (n_models, n_norms)
    for (row, model_names, group, suffix), z in zip(panels, matrices):
        for i, m in enumerate(model_names):
            abbr = model_to_abbr[m]
            fig.add_trace(
                go.Bar(
                    x=norms,
                    y=z[i, :],
                    name=f"{abbr} ({suffix})",
                    legendgroup=group,
                    showlegend=(col == 1),  # legend entries only from the first column
                    hovertemplate=hovertemplate,
                    customdata=[abbr] * len(norms),
                ),
                row=row, col=col
            )
        fig.update_yaxes(title=f"{metric_label} ↓", range=yrange, row=row, col=col)
    fig.update_xaxes(title="Normalization", row=2, col=col, categoryorder="array", categoryarray=norms)

# Layout
fig.update_layout(
    barmode="group",
    height=800,
    width=1200,
    template="plotly_white",
    margin=dict(t=80, l=60, r=40, b=40),
    legend=dict(tracegroupgap=8, orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0)
)

fig.show()
