In [40]:
import os
import json
import yaml
import pandas as pd
import numpy as np
import wandb
import plotly.graph_objects as go

from plotly.subplots import make_subplots
from tqdm import tqdm
from collections import defaultdict
from typing import Tuple
from pathlib import Path
from numpy.typing import NDArray

from src.constants import model_to_abbr, dataset_to_name
from src.constants import MODELS, BASELINES, DL, model_colors, model_to_name

In [3]:
os.environ.setdefault(
    "WANDB_CACHE_DIR", "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/src/notebooks/artifacts"
)

def model_to_lbw(
    dataset: str,
    model: str,
    params_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/config/params",
) -> int:
    yaml_file = Path(params_path) / model / dataset / "lbw.yaml"
    if not yaml_file.is_file():
        raise FileNotFoundError(f"Missing file: {yaml_file}")
    with yaml_file.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    return int(data["look_back_window"])


def get_metrics(
    runs: list,
    metrics_to_keep: list[str] = [
        "MSE",
        "MAE",
        "DIRACC",
        "MASE",
        "ND",
        "NRMSE",
        "SMAPE",
        "TMAE",
        "TMSE",
    ],
    artifacts_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/src/notebooks/artifacts",
    verbose: bool = True,
) -> Tuple[dict, dict, dict]:
    """
    Preprocess the run metrics for the wandb training runs in the runs list.
    Returns two dataframes, the first storing the mean and the second the standard deviation for
    each metric in the METRICS list defined in src/constants.py

    returns: - raw metric dictionary with keys 'model' -> 'look_back_window' -> 'prediction_window' -> 'fold_nr' -> 'seed'
             - mean dict, taking mean over folds and seeds  keys = 'model' -> 'look_back_window' -> 'prediction_window'
             - std dict, taking std over folds and seeds keys = 'model' -> 'look_back_window' -> 'prediction_window'
    """

    metrics = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
        )
    )
    model_run_counts = defaultdict(int)
    for run in tqdm(runs):
        config = run.config
        model_name = config["model"]["name"]
        dataset_name = config["dataset"]["name"]
        is_global = (
            dataset_name in ["dalia", "wildppg", "ieee"]
        ) or dataset_name == "lmitbih"
        look_back_window = config["look_back_window"]
        prediction_window = config["prediction_window"]
        seed = config["seed"]
        if is_global:
            fold = config["folds"]["fold_nr"]
            summary = run.summary._json_dict
            filtered_summary = {k: summary[k] for k in summary if k in metrics_to_keep}
            metrics[model_name][look_back_window][prediction_window][fold][seed] = (
                filtered_summary
            )
        else:
            raw_artifact = next(
                (a for a in run.logged_artifacts() if "raw_metrics" in a.name), None
            )
            if raw_artifact is None:
                print(f"No raw_metrics table for run {run.name}")
                continue
            else:
                art_dir = Path(artifacts_path) / str(raw_artifact.name).replace(
                    ":", "-"
                )

                if not os.path.exists(art_dir):
                    raw_artifact.download()

                json_path = art_dir / "raw_metrics.table.json"
                with open(json_path, "r", encoding="utf-8") as f:
                    obj = json.load(f)
                data = obj["data"]
                cols = obj["columns"]
                df = pd.DataFrame(data, columns=cols)
                for fold, row in df.iterrows():
                    d = row.to_dict()
                    metrics[model_name][look_back_window][prediction_window][fold][
                        seed
                    ] = d

        model_run_counts[model_name] += 1
    if verbose:
        for k, v in model_run_counts.items():
            print(f"Model {k}: {v}")

    processed_metrics_mean = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    processed_metrics_std = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for model, v in metrics.items():
        for lbw, w in v.items():
            for pw, fold_dict in w.items():
                metric_list = defaultdict(list)
                for fold_nr, seed_dict in fold_dict.items():
                    for seed, metric_dict in seed_dict.items():
                        for metric_name, metric_value in metric_dict.items():
                            if metric_name in metrics_to_keep:
                                if isinstance(metric_value, str):
                                    print(
                                        f"VALUE IS STRING {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                elif np.isinf(metric_value):
                                    print(
                                        f"VALUE IS INF {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                elif np.isnan(metric_value):
                                    print(
                                        f"VALUE IS NAN {metric_value} for model {model} lbw {lbw} pw {pw} seed {seed} fold {fold_nr}"
                                    )
                                else:
                                    metric_list[metric_name].append(metric_value)

                mean = {
                    metric_name: float(np.mean(v))
                    for metric_name, v in metric_list.items()
                }
                std = {
                    metric_name: float(np.std(v))
                    for metric_name, v in metric_list.items()
                }
                processed_metrics_mean[model][lbw][pw] = mean
                processed_metrics_std[model][lbw][pw] = std
    return metrics, processed_metrics_mean, processed_metrics_std


def get_runs(
    dataset: str,
    look_back_window: list[int],
    prediction_window: list[int],
    models: list[str],
    feature: list[str] = ["mean"],
    local_norm_endo_only: bool = False,
    local_norm: list[str] = ["local_z"],
    predictions: bool = False,
    start_time: str = "2025-6-12",
    end_time: str = "2025-6-12",
) -> list:

    conditions = [
        {"config.use_prediction_callback": predictions},
        {"config.local_norm_endo_only": local_norm_endo_only},
        {"config.feature.name": {"$in": feature}},
        {"config.dataset.name": {"$in": [dataset]}},
        {"config.look_back_window": {"$in": look_back_window}},
        {"config.prediction_window": {"$in": prediction_window}},
        {"config.model.name": {"$in": models}},
        {"config.local_norm": {"$in": local_norm}},
        {"state": "finished"},
        {"created_at": {"$gte": start_time}},
        {"created_at": {"$lt": end_time}}
    ]

    filters = {"$and": conditions}

    api = wandb.Api()
    runs = api.runs("c_keusch/thesis", filters=filters)
    runs = list(runs)
    print(f"Found {len(runs)} runs.")

    return runs



In [59]:
datasets = ["dalia", "wildppg", "ieee"]
look_back_window = [30]
prediction_window = [3]
models = MODELS
start_time = "2025-10-28"
end_time = "2025-10-30"
feature = "mean"
local_norms = ["local_z", "difference"]
dataset_metrics = {} 

for dataset in datasets:
    metrics = {}
    _, global_only_metrics, _ = get_metrics(get_runs(dataset=dataset, look_back_window=look_back_window, prediction_window=prediction_window, models=models, feature=[feature], local_norm_endo_only=False, local_norm=["lnone"],start_time=start_time, end_time=end_time))
    metrics["global_only"] = global_only_metrics
    for local_norm in local_norms:
        for local_norm_endo_only in [True, False]:
            _, metr, _ = get_metrics(get_runs(dataset=dataset, look_back_window=look_back_window, prediction_window=prediction_window, models=models, feature=[feature], local_norm_endo_only=local_norm_endo_only, local_norm=[local_norm], start_time=start_time, end_time=end_time))
            metrics[f"{local_norm}_{local_norm_endo_only}"] = metr
    dataset_metrics[dataset] = metrics

Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 40228.54it/s]

Model xgboost: 3
Model linear: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model nbeatsx: 3
Model gpt4ts: 3
Model timexer: 3
Model mlp: 3
Model gp: 3
Model kalmanfilter: 3
Model mole: 3
Model msar: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 7711.47it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 9218.73it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model msar: 3
Model timexer: 3
Model mole: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 3856.07it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 19246.23it/s]

Model xgboost: 3
Model linear: 3
Model gpt4ts: 3
Model simpletm: 3
Model nbeatsx: 3
Model timesnet: 3
Model timexer: 3
Model patchtst: 3
Model adamshyper: 3
Model mlp: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 20632.56it/s]

Model linear: 3
Model xgboost: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model mole: 3
Model msar: 3
Model timexer: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model mlp: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model gpt4ts: 3
Model kalmanfilter: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model gp: 3
Model mlp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 1817.72it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
Model mlp: 3
Model kalmanfilter: 3
Model gp: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<?, ?it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





Found 42 runs.


100%|██████████| 42/42 [00:00<00:00, 20941.60it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3





In [60]:
experiments = ["global_only", "local_z_True", "local_z_False", "difference_True", "difference_False"]

def compute_matrix(models: list[str], metrics: dict):
    all_model_values: list[list[float]] = []
    for model in models:
        values: list[float] = []
        for experiment in experiments:
            values.append(metrics[experiment][model][30][3]["MAE"])
        all_model_values.append(values)
    
    return np.array(all_model_values)


datasets = ["wildppg"]
norms = [
    "GZ",
    "GZ + Inst(Endo)",
    "GZ + Inst",
    "GZ + Diff(Endo)",
    "GZ + Diff",
]

data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(BASELINES, metrics),
        compute_matrix(DL, metrics),
    )

In [68]:
datasets = ["dalia", "wildppg", "ieee"]
norms = ["GZ", "GZ + Inst(Endo)", "GZ + Inst", "GZ + Diff(Endo)", "GZ + Diff"]

norm_colors = {
    "GZ":               "#4E79A7",
    "GZ + Inst(Endo)":  "#F28E2B",
    "GZ + Inst":        "#E15759",
    "GZ + Diff(Endo)":  "#76B7B2",
    "GZ + Diff":        "#59A14F",
}

#bm = [b for b in BASELINES if b != "msar"]
bm = BASELINES.copy()
data = {}
for d in datasets:
    metrics = dataset_metrics[d]
    data[d] = (
        compute_matrix(bm, metrics),
        compute_matrix(DL, metrics),         
    )

x_models_baseline = [model_to_abbr[m] for m in bm]
x_models_dl       = [model_to_abbr[m] for m in DL]
x_models_all      = x_models_baseline + x_models_dl

all_vals = []
for d in datasets:
    bz, dz = data[d]
    all_vals.append(np.concatenate([bz, dz], axis=0))
all_vals = np.concatenate(all_vals, axis=0)
ymin, ymax = float(np.nanmin(all_vals)), float(np.nanmax(all_vals))
pad = 0.05 * (ymax - ymin if ymax > ymin else 1.0)
yrange = [ymin - pad, ymax + pad]

fig = make_subplots(
    rows=len(datasets), cols=1,
    shared_xaxes=False, shared_yaxes=False,
    subplot_titles=[f"<b>{dataset_to_name[d]} </b>" for d in datasets],
    vertical_spacing=0.10
)

legend_shown_for_norm = {n: False for n in norms}

for r, d in enumerate(datasets, start=1):
    base_z, dl_z = data[d]          
    z_full = np.vstack([base_z, dl_z])

    for j, norm in enumerate(norms):
        fig.add_trace(
            go.Bar(
                x=x_models_all,
                y=z_full[:, j],
                name=norm,
                legendgroup=norm,
                marker_color=norm_colors[norm],
                showlegend=(not legend_shown_for_norm[norm] and r == 1),
                hovertemplate="Model: %{x}<br>Norm: " + norm + "<br>MAE: %{y:.3f}<extra></extra>",
            ),
            row=r, col=1
        )
        if r == 1:
            legend_shown_for_norm[norm] = True

    fig.update_yaxes(title_text="MAE", row=r, col=1)
    fig.update_xaxes(
        title_text="Model",
        categoryorder="array",
        categoryarray=x_models_all,
        row=r, col=1
    )

    fig.add_vline(
        x=6.5, line_width=1, line_dash="dash", line_color="rgba(0,0,0,0.35)",
        row=r, col=1
    )

fig.update_layout(
    barmode="group",          
    bargap=0.20,
    bargroupgap=0.10,
    height=1000,
    width=1400,
    template="plotly_white",
    margin=dict(t=90, l=60, r=40, b=60),
    legend=dict(
        title="Normalization",
        orientation="h",
        yanchor="bottom",
        y=-0.1,
        xanchor="left",
        x=0.25
    )
)

fig.show()


# Lookback & Horizon Ablation

In [48]:
def get_best_value(
    metric_dict: dict,
    std_dict: dict,
    lbw: int,
    pw: int,
    dataset: str,
    models: list[str],
    metric: str,
) -> Tuple[float, float]:
    best_val: float = np.inf
    best_std: float = np.inf
    for model in models:
        val = (
            metric_dict[model][lbw][pw][metric]
            if metric in metric_dict[model][lbw][pw]
            else np.nan
        )
        if val < best_val:
            best_val = val
            best_std = std_dict[model][lbw][pw][metric]

    return best_val, best_std


def get_df(
    look_back_window: list[int],
    metric_dict: dict,
    std_dict: dict,
    dataset: str,
    models: list[str],
    metric: str,
    pw: int = 3,
):
    means: list[float] = []
    stds: list[float] = []
    for lbw in look_back_window:
        mean, std = get_best_value(
            metric_dict, std_dict, lbw, pw, dataset, models, metric
        )
        means.append(mean)
        stds.append(std)

    d = {"means": means, "stds": stds}
    df = pd.DataFrame.from_dict(d)
    return df



def get_df_pw(
    lbw: int,
    metric_dict: dict,
    std_dict: dict,
    dataset: str,
    models: list[str],
    metric: str,
    pws: list[int] = [1,3,5,10,20],
):
    means: list[float] = []
    stds: list[float] = []
    for pw in pws:
        mean, std = get_best_value(
            metric_dict, std_dict, lbw, pw, dataset, models, metric
        )
        means.append(mean)
        stds.append(std)

    d = {"means": means, "stds": stds}
    df = pd.DataFrame.from_dict(d)
    return df

In [42]:
datasets = ["dalia", "wildppg", "ieee"]
models = MODELS
look_back_window = [5, 10, 20, 30, 60]
prediction_window = [3]

start_time = "2025-11-01"
end_time = "2025-11-02"


assert len(prediction_window) == 1
pw = prediction_window[0]

lbw_metrics = {}

for dataset in datasets:
    runs_exo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["mean"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, exo_mean, exo_std = get_metrics(runs_exo)

    runs_endo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["none"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, endo_mean, endo_std = get_metrics(runs_endo)

    lbw_metrics[dataset] = (exo_mean, exo_std, endo_mean, endo_std)

Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 13209.61it/s]

Model linear: 15
Model xgboost: 15
Model simpletm: 15
Model patchtst: 15
Model timexer: 15
Model adamshyper: 15
Model timesnet: 15
Model mlp: 15
Model gpt4ts: 15
Model kalmanfilter: 15
Model nbeatsx: 15
Model gp: 15
Model mole: 15
Model msar: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<?, ?it/s]


Model xgboost: 15
Model linear: 15
Model simpletm: 15
Model timesnet: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model adamshyper: 15
Model mole: 15
Model mlp: 15
Model nbeatsx: 15
Model gp: 15
Model msar: 15
Model kalmanfilter: 15
Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 13338.84it/s]

Model linear: 15
Model xgboost: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<?, ?it/s]

Model linear: 15
Model xgboost: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 13159.09it/s]

Model xgboost: 15
Model linear: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 10427.17it/s]

Model xgboost: 15
Model linear: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





In [43]:
start_time = "2025-11-01"
end_time = "2025-11-05"

datasets = ["dalia", "wildppg", "ieee"]
models = MODELS
look_back_window = [20]
prediction_window = [1,3,5,10,20]

assert len(look_back_window) == 1
pw = look_back_window[0]

pw_metrics = {}

for dataset in datasets:
    runs_exo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["mean"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, exo_mean, exo_std = get_metrics(runs_exo)

    runs_endo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["none"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, endo_mean, endo_std = get_metrics(runs_endo)

    pw_metrics[dataset] = (exo_mean, exo_std, endo_mean, endo_std)

Found 210 runs.


100%|██████████| 210/210 [00:00<?, ?it/s]

Model linear: 15
Model xgboost: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 17052.99it/s]

Model xgboost: 15
Model linear: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 15611.55it/s]

Model linear: 15
Model xgboost: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model kalmanfilter: 15
Model msar: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<?, ?it/s]

Model linear: 15
Model xgboost: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 16102.45it/s]

Model xgboost: 15
Model linear: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





Found 210 runs.


100%|██████████| 210/210 [00:00<00:00, 142340.63it/s]

Model xgboost: 15
Model linear: 15
Model timesnet: 15
Model simpletm: 15
Model adamshyper: 15
Model patchtst: 15
Model timexer: 15
Model gpt4ts: 15
Model nbeatsx: 15
Model mole: 15
Model msar: 15
Model kalmanfilter: 15
Model gp: 15
Model mlp: 15





In [57]:
use_std = True

TITLE_SIZE = 60
AXIS_TITLE_SIZE = 35
TICK_SIZE = 32
LEGEND_SIZE = 60
LINE_WIDTH = 3
MARKER_SIZE = 5
LINE_OPACITY = 0.8

SUBPLOT_TITLE_SIZE = 15
LEGEND_Y = -0.18

titles = [f"<b>{dataset_to_name[d]}</b>" for d in datasets]
fig = make_subplots(
    rows=2, cols=len(datasets), subplot_titles=titles, horizontal_spacing=0.08, vertical_spacing=0.1, row_titles=["Lookback Ablation", "Horizon Ablation"]
)

metric = "MAE"


for j, dataset in enumerate(datasets, start=1):

    exo_mean, exo_std, endo_mean, endo_std = lbw_metrics[dataset]
    x_lbw = [5, 10, 20, 30, 60]
    pw = 3

    exo_bl_df = get_df(
        x_lbw, exo_mean, exo_std, dataset, BASELINES, metric, pw
    )
    endo_bl_df = get_df(
        x_lbw, endo_mean, endo_std, dataset, BASELINES, metric, pw
    )
    exo_dl_df = get_df(x_lbw, exo_mean, exo_std, dataset, DL, metric, pw)
    endo_dl_df = get_df(
        x_lbw, endo_mean, endo_std, dataset, DL, metric, pw
    )

    # for _df in (exo_bl_df, endo_bl_df, exo_dl_df, endo_dl_df):
    #     _df.index = x_lbw

    x_labels = [str(x_val) for x_val in x_lbw]
    bl_color = model_colors["baseline"]
    dl_color = model_colors["dl"]

    def add_line(x, y, stds, name, color, dash, row: int = 1):
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines+markers",
                line=dict(width=LINE_WIDTH, color=color, dash=dash),
                marker=dict(size=MARKER_SIZE),
                name=name,
                showlegend=(j == 1) and (row == 1),
                error_y=(
                    dict(type="data", array=stds, visible=True) if use_std else None
                ),
            ),
            row=row,
            col=j,
        )
    
    # lbw ablation
    add_line(
        x_lbw,
        exo_bl_df["means"].values,
        exo_bl_df["stds"].values,
        "Baselines — Exo (best)",
        bl_color,
        "solid",
    )
    add_line(
        x_lbw,
        endo_bl_df["means"].values,
        endo_bl_df["stds"].values,
        "Baselines — Endo (best)",
        bl_color,
        "dash",
    )

    add_line(
        x_lbw,
        exo_dl_df["means"].values,
        exo_dl_df["stds"].values,
        "DL — Exo (best)",
        dl_color,
        "solid",
    )
    add_line(
        x_lbw,
        endo_dl_df["means"].values,
        endo_dl_df["stds"].values,
        "DL — Endo (best)",
        dl_color,
        "dash",
    )

    fig.update_xaxes(
        title_text="Lookback Window",
        type="category",
        categoryorder="array",
        categoryarray=x_labels,
        row=1,
        col=j,
    )
    fig.update_yaxes(title_text=f"Best {metric}", row=1, col=j)

    # pw ablation 
    exo_mean, exo_std, endo_mean, endo_std = pw_metrics[dataset]
    x_pw = [1, 3, 5, 10, 20]
    lbw = 20
    pw_labels = [str(x_val) for x_val in x_pw]

    exo_bl_df = get_df_pw(
        lbw, exo_mean, exo_std, dataset, BASELINES, metric, x_pw
    )
    endo_bl_df = get_df_pw(
        lbw, endo_mean, endo_std, dataset, BASELINES, metric, x_pw
    )
    exo_dl_df = get_df_pw(lbw, exo_mean, exo_std, dataset, DL, metric, x_pw)
    endo_dl_df = get_df_pw(
        lbw, endo_mean, endo_std, dataset, DL, metric, x_pw
    )

    add_line(
        x_pw,
        exo_bl_df["means"].values,
        exo_bl_df["stds"].values,
        "Baselines — Exo (best)",
        bl_color,
        "solid",
        row=2
    )
    add_line(
        x_pw,
        endo_bl_df["means"].values,
        endo_bl_df["stds"].values,
        "Baselines — Endo (best)",
        bl_color,
        "dash",
        row=2
    )

    add_line(
        x_pw,
        exo_dl_df["means"].values,
        exo_dl_df["stds"].values,
        "DL — Exo (best)",
        dl_color,
        "solid",
        row=2
    )
    add_line(
        x_pw,
        endo_dl_df["means"].values,
        endo_dl_df["stds"].values,
        "DL — Endo (best)",
        dl_color,
        "dash",
        row=2
    )

    fig.update_xaxes(
        title_text="Prediction Window",
        type="category",
        categoryorder="array",
        categoryarray=pw_labels,
        row=2,
        col=j,
    )
    fig.update_yaxes(title_text=f"Best {metric}", row=2, col=j)


# layout & legend
fig.update_annotations(font=dict(size=SUBPLOT_TITLE_SIZE))
fig.update_layout(
    legend=dict(
        orientation="h", x=0.5, xanchor="center", y=LEGEND_Y, yanchor="top"
    ),
    margin=dict(b=120),
)
fig.update_xaxes(title_font=dict(size=14))
fig.update_yaxes(title_font=dict(size=14))

subplot_size = 700  # pixels per subplot
rows = 1
cols = len(datasets)

total_width = cols * subplot_size
total_height = rows * subplot_size

fig.update_layout(
    width=total_width,
    height=total_height,
)

fig.show()

# Feature Ablation

In [36]:
datasets = ["dalia", "wildppg"]
look_back_window = [30]
prediction_window = [3]
start_time = "2025-10-28"
end_time = "2025-11-05"
features = ["none","mean", "rms_last2s_rms_jerk_centroid", "catch22"] 
model_to_norm = [("difference", BASELINES), ("local_z", DL)]
feature_metrics = {}
for dataset in datasets:
    d_metrics = {}
    for feature in features:
        print(f"{dataset} {feature}")
        metr = {}
        for (norm, models) in model_to_norm:
            print(norm)
            runs = get_runs(dataset, look_back_window=look_back_window, prediction_window=prediction_window, models=models, start_time=start_time, end_time=end_time, feature=[feature], local_norm_endo_only=False if feature == "none" else True, local_norm=[norm])
            _, c_metr, _ = get_metrics(runs)
            for k, v in c_metr.items():
                metr[k] = v
        d_metrics[feature] = metr
    feature_metrics[dataset] = d_metrics

dalia none
difference
Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
dalia mean
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
dalia rms_last2s_rms_jerk_centroid
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 2303.06it/s]

Model xgboost: 3
Model linear: 3
Model mole: 3
Model mlp: 3
Model gp: 3
Model kalmanfilter: 3
Model msar: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model simpletm: 3
Model timesnet: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
dalia catch22
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model xgboost: 3
Model linear: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
wildppg none
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 2527.99it/s]

Model xgboost: 3
Model linear: 3
Model mole: 3
Model msar: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
wildppg mean
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 20299.70it/s]

Model xgboost: 3
Model linear: 3
Model msar: 3
Model mole: 3
Model kalmanfilter: 3
Model gp: 3
Model mlp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 1884.48it/s]

Model timesnet: 3
Model simpletm: 3
Model adamshyper: 3
Model patchtst: 3
Model timexer: 3
Model gpt4ts: 3
Model nbeatsx: 3
wildppg rms_last2s_rms_jerk_centroid
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 4848.11it/s]

Model linear: 3
Model xgboost: 3
Model kalmanfilter: 3
Model mlp: 3
Model mole: 3
Model msar: 3
Model gp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<?, ?it/s]

Model timexer: 3
Model nbeatsx: 3
Model gpt4ts: 3
Model adamshyper: 3
Model timesnet: 3
Model simpletm: 3
Model patchtst: 3
wildppg catch22
difference





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 5049.61it/s]

Model linear: 3
Model xgboost: 3
Model mole: 3
Model mlp: 3
Model kalmanfilter: 3
Model msar: 3
Model gp: 3
local_z





Found 21 runs.


100%|██████████| 21/21 [00:00<00:00, 5146.99it/s]

Model simpletm: 3
Model timexer: 3
Model gpt4ts: 3
Model patchtst: 3
Model timesnet: 3
Model nbeatsx: 3
Model adamshyper: 3





In [37]:
import plotly.colors as pc

feature_to_name = {"none": "HR only", "mean": "Mean IMU", "rms_last2s_rms_jerk_centroid": "RMS + Last2s + Jerk + Centroid", "catch22": "Catch22"}

metric = "MAE"
look_back_window = 30
prediction_window = 3

models = MODELS.copy()
models.remove("msar")

# Fixed colors per feature (stable across subplots)
feature_colors = {
    feature: pc.qualitative.Plotly[i % len(pc.qualitative.Plotly)]
    for i, feature in enumerate(features)
}

fig = make_subplots(
    rows=len(datasets),
    cols=1,
    shared_xaxes=True,
    subplot_titles=[f"<b>{dataset_to_name[dataset]}</b>" for dataset in datasets],
    vertical_spacing=0.08,
)

for row, dataset in enumerate(datasets, start=1):
    for feature in features:
        x_vals = []
        y_vals = []

        metr_feature = feature_metrics[dataset][feature]

        for model in models:
            if (
                model in metr_feature
                and look_back_window in metr_feature[model]
                and prediction_window in metr_feature[model][look_back_window]
                and metric in metr_feature[model][look_back_window][prediction_window]
            ):
                val = metr_feature[model][look_back_window][prediction_window][metric]
            else:
                val = np.nan

            x_vals.append(model_to_abbr[model])
            y_vals.append(val)

        fig.add_trace(
            go.Bar(
                x=x_vals,
                y=y_vals,
                name=feature_to_name[feature],
                showlegend=(row == 1),
                legendgroup=feature,
                marker_color=feature_colors[feature],  # fixed color
            ),
            row=row,
            col=1,
        )

fig.update_layout(
    barmode="group",
    height=300 * len(datasets),
    title_text="<b>Feature Ablation MAE Performance</b>",
    title_x = 0.5,
    template="plotly_white",
    legend_title_text="Feature",
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-0.2,          # move underneath plots
        xanchor="center",
        x=0.5
    )
)

fig.update_xaxes(title_text="Model", row=len(datasets), col=1)

for r in range(1, len(datasets) + 1):
    fig.update_yaxes(title_text=metric, row=r, col=1)

fig.show()




# Local Results

In [42]:
start_time = "2025-11-06"
end_time = "2025-11-07"

datasets = ["ldalia", "lwildppg"]
models = MODELS
look_back_window = [30]
prediction_window = [3]

dataset_metrics = {}

for dataset in datasets:
    runs_exo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["mean"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, exo_mean, exo_std = get_metrics(runs_exo)

    runs_endo = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=["none"],
        start_time=start_time,
        end_time=end_time,
        local_norm_endo_only=False,
        local_norm=["local_z", "difference"]
    )
    _, endo_mean, endo_std = get_metrics(runs_endo)

    dataset_metrics[dataset] = (exo_mean, exo_std, endo_mean, endo_std)

Found 14 runs.


  7%|▋         | 1/14 [00:00<00:08,  1.50it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
 14%|█▍        | 2/14 [00:02<00:15,  1.26s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 21%|██▏       | 3/14 [00:04<00:16,  1.52s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 29%|██▊       | 4/14 [00:05<00:16,  1.64s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 36%|███▌      | 5/14 [00:08<00:17,  1.93s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 43%|████▎     | 6/14 [00:10<00:16,  2.11s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 50%|█████     | 7/14 [00:12<00:14,  2.06s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 57%|█████▋    | 8/14 [00:14<00:11,  1.97s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 64%|██████▍   | 9/14 [00:16<00:09,  1.91s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 71%|███████▏  | 10/14 [00:18<00:07,  1.96s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 79%|███████▊  | 11/14 [00:20<00:06,  2.09s/it][

Model patchtst: 1
Model simpletm: 1
Model timexer: 1
Model linear: 1
Model adamshyper: 1
Model timesnet: 1
Model xgboost: 1
Model nbeatsx: 1
Model gpt4ts: 1
Model kalmanfilter: 1
Model mlp: 1
Model msar: 1
Model gp: 1
Model mole: 1
Found 14 runs.


  0%|          | 0/14 [00:00<?, ?it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  7%|▋         | 1/14 [00:02<00:31,  2.40s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 14%|█▍        | 2/14 [00:05<00:30,  2.56s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 21%|██▏       | 3/14 [00:07<00:26,  2.44s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 29%|██▊       | 4/14 [00:09<00:24,  2.45s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 36%|███▌      | 5/14 [00:12<00:21,  2.40s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 43%|████▎     | 6/14 [00:14<00:19,  2.43s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 50%|█████     | 7/14 [00:17<00:17,  2.51s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 57%|█████▋    | 8/14 [00:19<00:14,  2.46s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 64%|██████▍   | 9/14 [00:22<00:12,  2.43s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 71%|███████▏  | 10/14 [00:24<00:09,  2.43s/it][34m[1mwa

Model timexer: 1
Model simpletm: 1
Model nbeatsx: 1
Model patchtst: 1
Model gpt4ts: 1
Model adamshyper: 1
Model xgboost: 1
Model timesnet: 1
Model linear: 1
Model kalmanfilter: 1
Model mole: 1
Model mlp: 1
Model gp: 1
Model msar: 1
Found 14 runs.


  0%|          | 0/14 [00:00<?, ?it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  7%|▋         | 1/14 [00:02<00:31,  2.43s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 14%|█▍        | 2/14 [00:05<00:30,  2.54s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 21%|██▏       | 3/14 [00:07<00:25,  2.34s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 29%|██▊       | 4/14 [00:09<00:24,  2.47s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 36%|███▌      | 5/14 [00:12<00:22,  2.49s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 43%|████▎     | 6/14 [00:14<00:18,  2.25s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 50%|█████     | 7/14 [00:16<00:16,  2.30s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 57%|█████▋    | 8/14 [00:18<00:12,  2.14s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 64%|██████▍   | 9/14 [00:20<00:11,  2.23s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 71%|███████▏  | 10/14 [00:22<00:08,  2.11s/it][34m[1mwa

Model linear: 1
Model timesnet: 1
Model xgboost: 1
Model simpletm: 1
Model nbeatsx: 1
Model timexer: 1
Model gpt4ts: 1
Model patchtst: 1
Model adamshyper: 1
Model mlp: 1
Model msar: 1
Model kalmanfilter: 1
Model mole: 1
Model gp: 1
Found 14 runs.


  0%|          | 0/14 [00:00<?, ?it/s][34m[1mwandb[0m:   1 of 1 files downloaded.  
  7%|▋         | 1/14 [00:02<00:32,  2.47s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 14%|█▍        | 2/14 [00:04<00:28,  2.40s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 21%|██▏       | 3/14 [00:07<00:27,  2.49s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 29%|██▊       | 4/14 [00:09<00:24,  2.48s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 36%|███▌      | 5/14 [00:12<00:22,  2.46s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 43%|████▎     | 6/14 [00:14<00:19,  2.45s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 50%|█████     | 7/14 [00:16<00:16,  2.38s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 57%|█████▋    | 8/14 [00:18<00:13,  2.21s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 64%|██████▍   | 9/14 [00:20<00:10,  2.15s/it][34m[1mwandb[0m:   1 of 1 files downloaded.  
 71%|███████▏  | 10/14 [00:23<00:08,  2.25s/it][34m[1mwa

Model simpletm: 1
Model timesnet: 1
Model xgboost: 1
Model linear: 1
Model nbeatsx: 1
Model timexer: 1
Model gpt4ts: 1
Model patchtst: 1
Model adamshyper: 1
Model mlp: 1
Model kalmanfilter: 1
Model msar: 1
Model mole: 1
Model gp: 1





In [52]:
def build_df(metrics: dict, lbw: int = 30, pw: int = 3, metric: str = "MAE"):
    datasets = list(metrics.keys())
    
    columns = [model_to_abbr[model] for model in MODELS]  
    index = []  
    data = []   

    for dataset in datasets:
        exo_mean, _, endo_mean, _ = metrics[dataset]
        row_endo = []
        row_exo = []
        row_imprv = []

        for model in MODELS:
            endo_val = endo_mean[model][lbw][pw][metric]
            exo_val  = exo_mean[model][lbw][pw][metric]
            imprv    = 100 * ((endo_val - exo_val) / endo_val)

            row_endo.append(endo_val)
            row_exo.append(exo_val)
            row_imprv.append(imprv)

        index.append((dataset, "Endo"))
        data.append(row_endo)

        index.append((dataset, "Exo"))
        data.append(row_exo)

        index.append((dataset, "Imprv"))
        data.append(row_imprv)

    index = pd.MultiIndex.from_tuples(index, names=["Dataset", "Metric"])
    df = pd.DataFrame(data, index=index, columns=columns)
    return df

df = build_df(dataset_metrics)

latex_str = df.to_latex(index=True, 
                        header=True, 
                        float_format="%.3f"  
                       )
print(latex_str)

\begin{tabular}{llrrrrrrrrrrrrrr}
\toprule
 &  & LR & MoLE & MSAR & KF & XGB & GP & MLP & TNET & STM & AMSH & PTST & TXER & GPT & NBX \\
Dataset & Metric &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
\midrule
\multirow[t]{3}{*}{ldalia} & Endo & 2.142 & 2.283 & 2.207 & 2.251 & 2.452 & 2.188 & 2.446 & 2.612 & 2.340 & 2.403 & 4.699 & 2.561 & 3.335 & 2.239 \\
 & Exo & 2.121 & 2.280 & 2.146 & 2.265 & 2.444 & 2.281 & 2.379 & 2.725 & 2.308 & 2.366 & 4.716 & 2.540 & 3.417 & 2.236 \\
 & Imprv & 0.965 & 0.156 & 2.739 & -0.629 & 0.337 & -4.270 & 2.737 & -4.325 & 1.366 & 1.542 & -0.364 & 0.790 & -2.438 & 0.108 \\
\cline{1-16}
\multirow[t]{3}{*}{lwildppg} & Endo & 1.555 & 1.623 & 1.568 & 1.594 & 1.527 & 1.552 & 1.624 & 1.626 & 1.641 & 1.646 & 1.581 & 1.803 & 2.219 & 1.698 \\
 & Exo & 1.546 & 1.622 & 1.579 & 1.596 & 1.518 & 1.573 & 1.602 & 1.707 & 1.620 & 1.628 & 1.583 & 1.776 & 2.204 & 1.711 \\
 & Imprv & 0.585 & 0.085 & -0.706 & -0.136 & 0.586 & -1.371 & 1.359 & -4.993 & 1.251 & 1.095 & -0.128 & 1.

# Qualitative Plots

In [None]:
# STYLING
TITLE_SIZE = 28
SUBTITLE_SIZE = 20
AXIS_TITLE_SIZE = 16
TICK_SIZE = 8
LEGEND_SIZE = 16
EXO_ENDO_LEGEND_SIZE = 10
GT_LINE_WIDTH = 3.8
PRED_LINE_WIDTH = 2.4
PRED_OPACITY = 0.8
GT_MARKER_SIZE = 6

best_model = {"dalia": "xgboost", "wildppg": "timesnet", "ieee": "linear"}

start_time = "2025-11-01"
end_time = "2025-11-20"

def load_predictions(
    dataset: str,
    models: list[str],
    look_back_window: list[int],
    prediction_window: list[int],
    feature: str,
    metric: str = "MAE",
    artifacts_path: str = "C:/Users/cleme/ETH/Master/Thesis/ns-forecast/src/notebooks/artifacts",
) -> dict:
    runs = get_runs(
        dataset,
        look_back_window,
        prediction_window,
        models,
        feature=[feature],
        predictions=True,
        start_time=start_time,
        end_time=end_time
    )

    loaded_preds = defaultdict(dict)

    for run in runs:
        # download artifcats
        config = run.config
        model = config["model"]["name"]
        print(f"Processing {model_to_name[model]}")
        raw_artifact = next(
            (a for a in run.logged_artifacts() if "predictions" in a.name), None
        )
        if raw_artifact is None:
            print(f"No predictions for run {run.name}")
            continue
        else:
            art_dir = Path(artifacts_path) / str(raw_artifact.name).replace(":", "-")

            if not os.path.exists(art_dir):
                raw_artifact.download()

            json_path = art_dir / "test_predictions_metrics.json"
            npz_path = art_dir / "test_predictions.npz"
            with open(json_path, "r", encoding="utf-8") as f:
                obj = json.load(f)
            metrics = obj[metric]
            np_arrays = np.load(npz_path, allow_pickle=True)
            loaded_preds[model]["metrics"] = metrics
            loaded_preds[model]["arrays"] = np_arrays

            print(f"Successfully loaded {art_dir}")

    return loaded_preds


def get_index_and_gt(
    loaded_preds, dataset: str, plot_type: str, window: int
) -> Tuple[int, NDArray[np.float32], NDArray[np.float32]]:
    print(loaded_preds.keys())
    best_model_metrics = loaded_preds[best_model[dataset]]["metrics"]
    sorted_idx = np.argsort(best_model_metrics)
    if plot_type == "worst":
        worst_idx = sorted_idx[-1]
    elif plot_type == "best":
        worst_idx = sorted_idx[0]
    else:
        worst_idx = sorted_idx[len(sorted_idx) // 2]

    print(f"worst value: {best_model_metrics[worst_idx]}")

    x = np.arange(window) * 2
    any_preds = loaded_preds[best_model[dataset]]["arrays"]
    gt_series = any_preds["gt_series"]
    n_windows_per_series = [len(s) - window + 1 for s in gt_series]
    cum_lengths = np.cumsum([0] + n_windows_per_series)
    gt_idx = np.searchsorted(cum_lengths, worst_idx, side="right") - 1
    window_pos = worst_idx - cum_lengths[gt_idx]
    gt = gt_series[gt_idx][window_pos : window_pos + window, :]

    return worst_idx, gt, x


def get_index_and_gt_exo(
    loaded_endo,
    loaded_exo,
    dataset: str,
    plot_type: str,
    window: int,
    use_imprv: bool = False,
) -> Tuple[int, NDArray[np.float32], NDArray[np.float32]]:
    endo_metrics = np.array(loaded_endo[best_model[dataset]]["metrics"])
    exo_metrics = np.array(loaded_exo[best_model[dataset]]["metrics"])
    diff = exo_metrics - endo_metrics
    imprv = diff / endo_metrics

    fig = go.Figure()
    fig.add_histogram(
        x=imprv,
        nbinsx=30,
        opacity=0.75,
    )
    fig.show()
    sorted_idx = np.argsort(imprv if use_imprv else diff)
    if plot_type == "worst":
        worst_idx = sorted_idx[-1]
    elif plot_type == "best":
        worst_idx = sorted_idx[0]
    else:
        worst_idx = sorted_idx[len(sorted_idx) // 2]

    print(f"worst value: {diff[worst_idx]}")

    x = np.arange(window) * 2
    exo_preds = loaded_exo[best_model[dataset]]["arrays"]
    gt_series = exo_preds["gt_series"]
    n_windows_per_series = [len(s) - window + 1 for s in gt_series]
    cum_lengths = np.cumsum([0] + n_windows_per_series)
    gt_idx = np.searchsorted(cum_lengths, worst_idx, side="right") - 1
    window_pos = worst_idx - cum_lengths[gt_idx]
    gt = gt_series[gt_idx][window_pos : window_pos + window, :]

    return worst_idx, gt, x


def plot_predictions(
    datasets: list[str],
    look_back_window: list[int] = [20],
    prediction_window: list[int] = [10],
    models: list[str] = MODELS,
    feature: str = "none",
    plot_type: str = "worst",
):
    assert len(look_back_window) == 1
    assert len(prediction_window) == 1
    lbw = look_back_window[0]
    pw = prediction_window[0]

    window = lbw + pw

    print(feature)

    fig = make_subplots(
        rows=1 if feature == "none" else 2,
        cols=len(datasets),
        shared_xaxes=True,
        subplot_titles=[f"<b>{dataset_to_name[d]}</b>" for d in datasets],
        horizontal_spacing=0.03,
        vertical_spacing=0.1,
        row_heights=[1.0] if feature == "none" else [0.8, 0.2],
    )

    for col_idx, dataset in enumerate(datasets, start=1):
        loaded_preds = load_predictions(
            dataset, models, look_back_window, prediction_window, feature=feature
        )

        worst_idx, gt, x = get_index_and_gt(loaded_preds, dataset, plot_type, window)

        print(f"{worst_idx=}")

        fig.add_trace(
            go.Scatter(
                x=x,
                y=gt[:, 0],
                mode="lines+markers",  # dots + line
                name="Ground truth",
                legendgroup=f"{dataset}-gt",
                showlegend=(col_idx == 1),
                line=dict(width=GT_LINE_WIDTH, color="blue"),
                marker=dict(size=GT_MARKER_SIZE, symbol="circle"),
            ),
            row=1,
            col=col_idx,
        )

        if  feature == "mean":
            activity = gt[:, 1]
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=activity,
                    mode="lines",  # dots + line
                    name="Activity",
                    legendgroup=f"{dataset}-act",
                    showlegend=(col_idx == 1),
                    line=dict(width=GT_LINE_WIDTH, color="red"),
                ),
                row=2,
                col=col_idx,
            )

        # Add a vertical boundary at LBW
        fig.add_vline(
            x=(lbw - 1) * 2,
            line_dash="dash",
            line_width=1,
            line_color="black",
            row=1,
            col=col_idx,
        )

        for model in loaded_preds.keys():
            # load model artifact & plot prediction
            arrays = loaded_preds[model]["arrays"]
            preds = arrays["preds"]
            print(f"Length of preds {len(preds)} for {model}")
            pred = preds[worst_idx][:, 0]

            y_pred_full = np.full(window, np.nan, dtype=float)
            y_pred_full[lbw:] = pred[:pw]
            y_pred_full[lbw - 1] = gt[lbw - 1, 0]

            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y_pred_full,
                    mode="lines",
                    name=f"{model_to_name[model]}",
                    legendgroup=f"{model_to_name[model]}",
                    showlegend=(col_idx == 1),  # one legend entry across subplots
                    line=dict(width=PRED_LINE_WIDTH, color=model_colors[model]),
                    opacity=PRED_OPACITY,
                ),
                row=1,
                col=col_idx,
            )

    fig.update_layout(
        template="seaborn",
        height=400,  # tweak as you like
        width=1200,  # tweak as you like
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.4,  # push legend below plot
            xanchor="center",
            x=0.5,
            font=dict(size=LEGEND_SIZE),
            bgcolor="rgba(0,0,0,0)",
        ),
        margin=dict(l=80, r=30, t=90, b=130),
        font=dict(size=16),  # base font (ticks inherit unless overridden)
    )

    # Axis titles & tick label sizes
    fig.update_xaxes(
        title_text="Time (s)",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=1 if feature == "none" else 2,
    )
    fig.update_yaxes(
        title_text="HR",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=1,
        col=1,
    )
    fig.update_yaxes(
        title_text="IMU",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=2,
        col=1,
    )

    for col in range(2, len(datasets) + 1):
        fig.update_yaxes(
            tickfont=dict(size=TICK_SIZE),
            col=col,
        )

    # Make all subplot titles bigger
    for a in fig.layout.annotations:
        a.font = dict(size=SUBTITLE_SIZE)

    fig.show()


def _plot_trace(
    fig,
    loaded_preds,
    model: str,
    exo_type: str,
    gt: NDArray[np.float32],
    x: NDArray[np.float32],
    window: int,
    worst_idx: int,
    col_idx: int,
    lbw: int,
    pw: int,
) -> None:
    assert exo_type in ["Endo", "Exo"]
    # load model artifact & plot prediction
    arrays = loaded_preds[model]["arrays"]
    preds = arrays["preds"]
    pred = preds[worst_idx][:, 0]

    y_pred_full = np.full(window, np.nan, dtype=float)
    y_pred_full[lbw:] = pred[:pw]
    y_pred_full[lbw - 1] = gt[lbw - 1]

    mode = "lines+markers" if exo_type == "Exo" else "lines"
    dash = "solid" if exo_type == "Exo" else "dash"

    fig.add_trace(
        go.Scatter(
            x=x,
            y=y_pred_full,
            mode=mode,
            name=f"{model_to_name[model]} {exo_type}",
            legendgroup=f"{model_to_name[model]} {exo_type}",
            showlegend=True,
            line=dict(
                width=PRED_LINE_WIDTH,
                color=model_colors[model],
                dash=dash,
            ),
            opacity=PRED_OPACITY,
        ),
        row=1,
        col=col_idx,
    )


def plot_best_exo_improvement(
    datasets: list[str],
    look_back_window: list[int] = [20],
    prediction_window: list[int] = [10],
    models: list[str] = MODELS,
    metric: str = "MAE",
    plot_type: str = "worst",
    use_imprv: bool = True,
):
    assert len(look_back_window) == 1
    assert len(prediction_window) == 1
    lbw = look_back_window[0]
    pw = prediction_window[0]

    window = lbw + pw

    fig = make_subplots(
        rows=2,
        cols=len(datasets),
        shared_xaxes=True,
        subplot_titles=[f"<b>{dataset_to_name[d]}</b>" for d in datasets],
        horizontal_spacing=0.03,
        vertical_spacing=0.10,
        row_heights=[0.8, 0.2],
    )

    for col_idx, dataset in enumerate(datasets, start=1):
        loaded_exo = load_predictions(
            dataset,
            [best_model[dataset]],
            look_back_window,
            prediction_window,
            "mean",
        )
        loaded_endo = load_predictions(
            dataset,
            [best_model[dataset]],
            look_back_window,
            prediction_window,
            "none",
        )

        idx, gt, x = get_index_and_gt_exo(
            loaded_exo=loaded_exo,
            loaded_endo=loaded_endo,
            dataset=dataset,
            plot_type=plot_type,
            window=window,
            use_imprv=use_imprv,
        )

        hr = gt[:, 0]

        fig.add_trace(
            go.Scatter(
                x=x,
                y=hr,
                mode="lines+markers",  # dots + line
                name="Ground truth",
                legendgroup=f"{dataset}-gt",
                showlegend=(col_idx == 1),
                line=dict(width=GT_LINE_WIDTH, color="blue"),
                marker=dict(size=GT_MARKER_SIZE, symbol="circle"),
            ),
            row=1,
            col=col_idx,
        )

        activity = gt[:, 1]
        fig.add_trace(
            go.Scatter(
                x=x,
                y=activity,
                mode="lines",  # dots + line
                name="Activity",
                legendgroup=f"{dataset}-act",
                showlegend=(col_idx == 1),
                line=dict(width=GT_LINE_WIDTH, color="red"),
            ),
            row=2,
            col=col_idx,
        )

        # Add a vertical boundary at LBW
        fig.add_vline(
            x=(lbw - 1) * 2,
            line_dash="dash",
            line_width=1,
            line_color="black",
            # row=1,
            col=col_idx,
        )

        _plot_trace(
            fig,
            loaded_endo,
            best_model[dataset],
            "Endo",
            hr,
            x,
            window,
            idx,
            col_idx,
            lbw,
            pw,
        )
        _plot_trace(
            fig,
            loaded_exo,
            best_model[dataset],
            "Exo",
            hr,
            x,
            window,
            idx,
            col_idx,
            lbw,
            pw,
        )

    fig.update_layout(
        template="seaborn",
        height=400,  # tweak as you like
        width=1200,  # tweak as you like
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.4,  # push legend below plot
            xanchor="center",
            x=0.5,
            font=dict(size=EXO_ENDO_LEGEND_SIZE),
            bgcolor="rgba(0,0,0,0)",
        ),
        margin=dict(l=80, r=30, t=90, b=130),
        font=dict(size=16),  # base font (ticks inherit unless overridden)
    )

    # Axis titles & tick label sizes
    fig.update_xaxes(
        title_text="Time (s)",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=2,
    )
    fig.update_yaxes(
        title_text="HR",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=1,
        col=1,
    )
    fig.update_yaxes(
        title_text="IMU",
        title_font=dict(size=AXIS_TITLE_SIZE),
        tickfont=dict(size=TICK_SIZE),
        row=2,
        col=1,
    )

    for col in range(2, len(datasets) + 1):
        fig.update_yaxes(
            tickfont=dict(size=TICK_SIZE),
            col=col,
        )

    # Make all subplot titles bigger
    for a in fig.layout.annotations:
        a.font = dict(size=SUBTITLE_SIZE)

    fig.show()


In [87]:
datasets = ["dalia", "wildppg", "ieee"]

fig = make_subplots(rows=len(datasets), cols=3)

for dataset in datasets:
    for type in ("worst", "median", "best"):
        plot_predictions([dataset], [10], [3], ["linear", "xgboost", "timesnet"], "mean", type)

mean
Found 1 runs.
Processing TimesNet
Successfully loaded C:\Users\cleme\ETH\Master\Thesis\ns-forecast\src\notebooks\artifacts\test_predictions-njd6gzng-v0
dict_keys(['timesnet'])


KeyError: 'metrics'

In [77]:
plot_predictions(["wildppg"], [10], [3], ["xgboost", "mole", "simpletm", "timesnet", "nbeatsx"], "mean", "worst")

# plot_best_exo_improvement(["ieee"],[10], [3], plot_type="median", use_imprv=False)

mean
Found 3 runs.
Processing SimpleTM
Successfully loaded C:\Users\cleme\ETH\Master\Thesis\ns-forecast\src\notebooks\artifacts\test_predictions-b4mi63c1-v0
Processing TimesNet
Successfully loaded C:\Users\cleme\ETH\Master\Thesis\ns-forecast\src\notebooks\artifacts\test_predictions-ywbw4cry-v0
Processing NBeatsX


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m:   2 of 2 files downloaded.  


Successfully loaded C:\Users\cleme\ETH\Master\Thesis\ns-forecast\src\notebooks\artifacts\test_predictions-t70slw1r-v0
worst value: 27.972320556640625
worst_idx=np.int64(59874)
Length of preds 98504 for simpletm
Length of preds 98504 for timesnet
Length of preds 98504 for nbeatsx
