In [2]:
import numpy as np 
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np


import os
import argparse
import pickle
import numpy as np
import scipy
import pycatch22

from pathlib import Path
from tqdm import tqdm
from numpy.lib.stride_tricks import sliding_window_view
from numpy.typing import NDArray
from scipy.io import loadmat
from scipy.signal import butter, sosfiltfilt



In [None]:
# Load the pre-computed arrays for the ECG inspection plot below.
path = "../../data/"
hr, ecg, ecgpks, rrs, rrxs = (
    np.load(path + file_name)
    for file_name in ("hr.npy", "ecg.npy", "ecgpks.npy", "rrs.npy", "rrxs.npy")
)

In [None]:
# Inspect one window of the ECG together with its detected peaks, the
# per-beat rate derived from the RR intervals and the windowed heart rate.
fs   = 128    # sampling rate used to convert sample indices to seconds
low  = 0      # first sample of the window (inclusive)
high = 1920   # last sample of the window (exclusive)
seconds = (high - low) // fs
assert 0 <= low < high, "low/high must define a valid window"

y = ecg[low:high]
N = len(y)  # follows the actual slice, so `t` and `y` always have equal length
t = np.arange(N) / fs

# Keep only the peaks that fall inside the window and re-base them to it.
ecgpks = np.asarray(ecgpks, dtype=int)
mask = (ecgpks >= low) & (ecgpks < high)
peak_idx_abs = ecgpks[mask]
peak_idx_win = peak_idx_abs - low
peak_t = peak_idx_win / fs
peak_y = y[peak_idx_win]

# NOTE(review): indexing `rrs`/`rrxs` by peak count assumes there is exactly
# one RR entry between every pair of consecutive peaks — confirm against the
# code that produced the .npy files.
n_prev_peaks = np.sum(ecgpks < low)
n_peaks_window = np.sum(mask)
rr_slice = slice(n_prev_peaks, n_prev_peaks + n_peaks_window - 1)
window_rrs = 60 * (fs / rrs[rr_slice])  # was a literal 128; use the configured rate
rrs_t = (rrxs[rr_slice] - low) / fs

fig = make_subplots(rows=4, cols=1, shared_xaxes=True)

# Row 1: raw ECG.
fig.add_trace(
    go.Scatter(x=t, y=y), row=1, col=1
)

# Row 2: ECG with the detected peaks overlaid.
fig.add_trace(
    go.Scatter(x=t, y=y), row=2, col=1
)
fig.add_trace(go.Scatter(
    x=peak_t,
    y=peak_y,
    mode="markers",
    marker=dict(color="red", size=10),
    name="peaks"
), row=2, col=1)

# Row 3: one bar per RR interval.
# NOTE(review): `window_rrs` is 60 * fs / rr and is divided by `fs` again
# here; the intended unit is unclear, so the original value is kept.
fig.add_trace(
    go.Bar(
        x=rrs_t,
        y=window_rrs / fs,
        marker_color="green",
        name="rrs",
        width=0.05  # controls bar width
    ),
    row=3, col=1
)

# Row 4: windowed heart rate (8 s windows, 2 s stride), drawn at the window
# centres. The x vector now has exactly n_hr entries (it used to have
# n_hr + 2), and n_hr can no longer become negative for short windows.
# NOTE(review): `hr[:n_hr]` is only aligned with the window when low == 0.
n_hr = max(0, (seconds - 8) // 2 + 1)
x_hr = 4 + 2 * np.arange(n_hr)

fig.add_trace(go.Scatter(
        x=x_hr,
        y=hr[:n_hr],
        mode="markers",
        marker=dict(color="red", size=10),
        name="hr"
    ), row=4, col=1)

fig.update_layout(height=2000)

fig.show()

In [9]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a digital Butterworth band-pass filter.

    The two cut-off frequencies are divided by the Nyquist frequency
    (0.5 * fs) before being handed to scipy. The filter is returned as
    second-order sections.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return scipy.signal.butter(
        order, normalized_band, btype="band", analog=False, output="sos"
    )


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """
    Band-pass filter `data` with the Butterworth design from `butter_bandpass`.

    The filter is applied with `sosfiltfilt`, i.e. forward and backward.
    """
    return scipy.signal.sosfiltfilt(
        butter_bandpass(lowcut, highcut, fs, order=order), data
    )


def load_wildppg_participant(path: Path):
    """
    Read one WildPPG participant .mat file and unwrap it into nested dicts.

    After loading, "id" and "notes" are reduced to their first element
    ("notes" becomes "" when empty), and each body location ("sternum",
    "head", "wrist", "ankle") is replaced by a dict of the form
    {sensor_name: {field_name: value}}. Every field value is unwrapped one
    level; the "fs" field is unwrapped one level further.
    """
    raw = scipy.io.loadmat(path)
    raw["id"] = raw["id"][0]
    raw["notes"] = raw["notes"][0] if len(raw["notes"]) != 0 else ""

    for location in ("sternum", "head", "wrist", "ankle"):
        location_record = raw[location][0]
        cleaned = {}
        for sensor, wrapped in zip(location_record.dtype.names, location_record[0]):
            sensor_record = wrapped[0][0]
            fields = {}
            for key, value in zip(sensor_record.dtype.names, sensor_record):
                # "fs" is nested one level deeper than the other fields
                fields[key] = value[0][0] if key == "fs" else value[0]
            cleaned[sensor] = fields
        raw[location] = cleaned
    return raw


def panPeakDetect(detection, fs: int):
    """
    Jiapu Pan and Willis J. Tompkins.
    A Real-Time QRS Detection Algorithm.
    In: IEEE Transactions on Biomedical Engineering
    BME-32.3 (1985), pp. 230–236.

    Original implementation by Luis Howell luisbhowell@gmail.com, Bernd Porr, bernd.porr@glasgow.ac.uk, DOI: 10.5281/zenodo.3353396

    Classifies the local maxima of `detection` as signal or noise peaks using
    two running level estimates (SPKI for signal, NPKI for noise) and the
    thresholds derived from them, with a search-back step for missed beats.

    Parameters:
        detection: 1-D sequence in which peaks are searched.
        fs: sampling rate in samples per second; it is used in slice bounds,
            so it must be an integer.

    Returns:
        List of sample indices of the accepted peaks.
    """
    # Minimum spacing (in samples) between candidate peaks: 0.25 s.
    min_distance = int(0.25 * fs)

    # Index 0 is a sentinel so the first distance check has a predecessor;
    # it is removed again before returning.
    signal_peaks = [0]
    noise_peaks = []

    # Running estimates of the signal-peak and noise-peak levels.
    SPKI = 0.0
    NPKI = 0.0

    threshold_I1 = 0.0
    threshold_I2 = 0.0

    # Gap (in samples) above which a search-back is triggered; 0 disables it.
    RR_missed = 0
    # Positions within `peaks` of the accepted peaks (parallel to signal_peaks[1:]
    # until a search-back inserts an extra entry into signal_peaks).
    indexes = []

    missed_peaks = []
    peaks = scipy.signal.find_peaks(detection, distance=min_distance)[0]

    # Weight of the newest peak in the running SPKI / NPKI averages.
    thres_weight = 0.125

    for index, peak in enumerate(peaks):
        if peak > 4 * fs and threshold_I1 > max(
            detection[peak - 4 * fs : peak]
        ):  # reset thresholds if we do not see any peaks anymore
            SPKI_n = max(detection[peak - 4 * fs : peak])
            # NOTE(review): this divides by SPKI, which starts at 0.0 — confirm
            # the branch cannot be reached before a signal peak was accepted.
            NPKI = min(
                NPKI * SPKI_n / SPKI, np.percentile(detection[peak - 4 * fs : peak], 80)
            )
            SPKI = SPKI_n
            threshold_I1 = NPKI + 0.25 * (SPKI - NPKI)
            threshold_I2 = 0.5 * threshold_I1

        # Accept the candidate if it exceeds the primary threshold and lies more
        # than 0.3 s after the previously accepted peak.
        if detection[peak] > threshold_I1 and (peak - signal_peaks[-1]) > 0.3 * fs:
            signal_peaks.append(peak)
            indexes.append(index)
            SPKI = (
                thres_weight * detection[signal_peaks[-1]] + (1 - thres_weight) * SPKI
            )
            if RR_missed != 0:
                # Search-back: the gap to the previous accepted peak is too long,
                # so look for a candidate in between using the lower threshold.
                if signal_peaks[-1] - signal_peaks[-2] > RR_missed:
                    missed_section_peaks = peaks[indexes[-2] + 1 : indexes[-1]]
                    missed_section_peaks2 = []
                    for missed_peak in missed_section_peaks:
                        if (
                            missed_peak - signal_peaks[-2] > min_distance
                            and signal_peaks[-1] - missed_peak > min_distance
                            and detection[missed_peak] > threshold_I2
                        ):
                            missed_section_peaks2.append(missed_peak)

                    if len(missed_section_peaks2) > 0:
                        # Keep the strongest candidate and insert it before the
                        # peak that was just accepted.
                        signal_missed = [detection[i] for i in missed_section_peaks2]
                        index_max = np.argmax(signal_missed)
                        missed_peak = missed_section_peaks2[index_max]
                        missed_peaks.append(missed_peak)
                        signal_peaks.append(signal_peaks[-1])
                        signal_peaks[-2] = missed_peak
            # After more than 100 accepted peaks the averages adapt more slowly.
            if len(signal_peaks) > 100 and thres_weight > 0.1:
                thres_weight = 0.0125

        else:
            # Everything else is treated as noise and updates the noise level.
            noise_peaks.append(peak)
            NPKI = thres_weight * detection[noise_peaks[-1]] + (1 - thres_weight) * NPKI

        threshold_I1 = NPKI + 0.25 * (SPKI - NPKI)
        threshold_I2 = 0.5 * threshold_I1

        # Search-back limit: 1.66 times the mean of the last eight intervals.
        if len(signal_peaks) > 8:
            RR = np.diff(signal_peaks[-9:])
            RR_ave = int(np.mean(RR))
            RR_missed = int(1.66 * RR_ave)

    # Drop the sentinel inserted at the start.
    signal_peaks.pop(0)

    return signal_peaks


def pan_tompkins_detector(unfiltered_ecg, sr):
    """
    R-peak detector following
    Jiapu Pan and Willis J. Tompkins, "A Real-Time QRS Detection Algorithm",
    IEEE Transactions on Biomedical Engineering BME-32.3 (1985), pp. 230-236.

    Original implementation by Luis Howell luisbhowell@gmail.com, Bernd Porr, bernd.porr@glasgow.ac.uk, DOI: 10.5281/zenodo.3353396

    Pipeline: 5-15 Hz band-pass -> first difference -> squaring -> moving
    average over 150 ms -> peak picking with `panPeakDetect` -> refinement on
    a 7.5-20 Hz band-passed copy -> snap to the raw-signal maximum within
    +/- 2 samples. Returns the peak sample indices as a numpy array.
    """
    qrs_seconds = 0.150  # maximum QRS duration in seconds
    qrs_samples = int(qrs_seconds * sr)

    band_limited = butter_bandpass_filter(unfiltered_ecg, 5, 15, sr, order=1)
    slope = np.diff(band_limited)
    energy = slope * slope

    integrated = scipy.ndimage.uniform_filter1d(energy, size=qrs_samples)
    # Cap the integrated signal during motion artefacts so that they do not
    # distort the adaptive thresholds.
    ceiling = scipy.ndimage.maximum_filter1d(band_limited, size=qrs_samples)[:-1] / 400
    integrated = np.where(integrated < ceiling, integrated, ceiling)

    # Blank the start-up region (two QRS durations).
    integrated[: int(qrs_seconds * sr * 2)] = 0

    refine_signal = butter_bandpass_filter(unfiltered_ecg, 7.5, 20, sr, order=1)

    coarse_peaks = panPeakDetect(integrated, sr)
    refined_peaks = [
        centre
        - qrs_samples
        + np.argmax(refine_signal[centre - qrs_samples : centre + qrs_samples + 1])
        for centre in coarse_peaks
    ]
    # adjust by at most 2 samples to hit raw data max
    raw_peaks = [
        centre - 2 + np.argmax(unfiltered_ecg[centre - 2 : centre + 3])
        for centre in refined_peaks
    ]
    return np.asarray(raw_peaks)


def quotient_filter(hbpeaks, outlier_over=5, sampling_rate=128, tol=0.8):
    """
    Quotient filter for heartbeat peaks, similar to the one described in
    Piskorski, J., Guzik, P. (2005), "Filtering Poincare plots".

    A run of `outlier_over` consecutive peaks is accepted when its shortest
    interval is longer than `tol` times its longest interval and the mean
    heart rate of the run lies strictly between 35 and 185 bpm. Peaks of
    accepted runs are kept (each peak once, in increasing order).

    Returns three numpy arrays: the kept peaks, the RR intervals between
    consecutively kept peaks that fall within the tolerance band of the run
    being processed, and the midpoints of those intervals.
    """
    kept_beats, kept_rrs, kept_rr_mid = [], [], []
    n_gaps = outlier_over - 1

    for start in range(len(hbpeaks[:-n_gaps])):
        gaps = [
            hbpeaks[k] - hbpeaks[k - 1] for k in range(start + 1, start + outlier_over)
        ]
        shortest, longest = min(gaps), max(gaps)
        bpm = 60 / ((sum(gaps)) / (n_gaps * sampling_rate))
        if not (shortest > longest * tol and bpm > 35 and bpm < 185):
            continue  # noisy run

        for beat in hbpeaks[start : start + outlier_over]:
            # skip peaks that are already stored
            if kept_beats and not beat > kept_beats[-1]:
                continue
            kept_beats.append(beat)
            if len(kept_beats) < 2:
                continue
            rr = kept_beats[-1] - kept_beats[-2]
            if longest * tol < rr and rr < shortest / tol:
                kept_rrs.append(rr)
                kept_rr_mid.append((kept_beats[-1] + kept_beats[-2]) / 2)

    return np.array(kept_beats), np.array(kept_rrs), np.array(kept_rr_mid)



In [10]:
person = 0  # position (in directory-iteration order) of the participant file to process


def filter(
    outlier_over: int = 5,
    tol: float = 0.75,
    datadir: str = "C:/Users/cleme/ETH/Master/Thesis/data/WildPPG/data",
):
    """
    Compute windowed heart rate and mean acceleration for one WildPPG participant.

    The participant is the `person`-th entry (module-level variable) returned
    by iterating over `datadir`. R-peaks are detected on the sternum ECG,
    cleaned with `quotient_filter`, and the heart rate is computed on 8 s
    windows with a 2 s stride.

    Parameters:
        outlier_over: run length forwarded to `quotient_filter`.
        tol: tolerance forwarded to `quotient_filter`.
        datadir: directory holding the participant .mat files.

    Returns:
        (hrs, imus, rrs):
        hrs  - heart rate in bpm per window, NaN where fewer than two RR
               intervals fall inside the window;
        imus - mean of the ankle accelerometer magnitude per window;
        rrs  - the filtered RR intervals (in samples).

    Raises:
        IndexError: if `datadir` has no entry at position `person`.
    """
    winsize = 8  # 8s window size
    stride = 2  # 2s stride
    for pidx, p in enumerate(Path(datadir).iterdir()):
        if pidx != person:
            continue
        print(pidx, " load ", p)
        part_data = load_wildppg_participant(p.absolute())

        x = part_data["ankle"]["acc_x"]["v"]
        y = part_data["ankle"]["acc_y"]["v"]
        z = part_data["ankle"]["acc_z"]["v"]
        imu = np.sqrt(x**2 + y**2 + z**2)

        # Plain Python int: fs is used in range() steps and slice bounds, where
        # a small fixed-width numpy integer could overflow.
        fs = int(part_data["sternum"]["ecg"]["fs"])

        r_peaks = pan_tompkins_detector(part_data["sternum"]["ecg"]["v"], fs)
        # Pass the real sampling rate; otherwise the bpm plausibility check in
        # quotient_filter silently assumes its default of 128 Hz.
        ecgpks_filt, rrs, rrxs = quotient_filter(
            r_peaks, outlier_over=outlier_over, sampling_rate=fs, tol=tol
        )

        win_len = winsize * fs
        # No surviving peaks -> no windows (instead of max() failing on empty input).
        last_peak = int(ecgpks_filt.max()) if ecgpks_filt.size else 0

        hrs = []
        imus = []
        for win_s in tqdm(range(0, last_peak, stride * fs)):
            rr_in_win = rrs[
                np.logical_and(
                    rrxs > win_s,
                    rrxs < win_s + win_len,
                )
            ]
            if len(rr_in_win) > 1:  # at least 2
                hrs.append(60 * len(rr_in_win) / (np.sum(rr_in_win) / fs))
            else:
                hrs.append(np.nan)  # invalid / noisy ecg

            # NOTE(review): the accelerometer is sliced with ECG sample indices,
            # which assumes both sensors share the same sampling rate — confirm.
            imus.append(np.mean(imu[win_s: win_s + win_len]))
        return np.array(hrs), np.array(imus), np.array(rrs)

    raise IndexError(f"no participant file at index {person} in {datadir}")



def nan_run_lengths(x: np.ndarray):
    """
    Lengths of the consecutive NaN runs in a 1D array, in order of appearance.

    Example: [1, nan, nan, 2, nan] -> [2, 1]
    An input without NaNs yields an empty integer array.
    """
    missing = np.isnan(x)
    if not missing.any():
        return np.array([], dtype=int)

    # Pad with "not NaN" on both sides so that runs touching either end of
    # the array still produce a rising and a falling edge.
    padded = np.concatenate(([False], missing, [False])).astype(int)
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return run_ends - run_starts


In [13]:
# Sweep the quotient-filter settings and compare the resulting heart-rate
# series (row 1) against the movement intensity (row 2).
combinations = [(5, 0.75), (5, 0.8), (7,0.75), (7,0.8)]
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True)
hrs = []
rrs = []
for out, tol in combinations:
    hr, imu, rr = filter(outlier_over=out, tol=tol)
    hrs.append(hr)
    rrs.append(rr)
    time = list(range(len(hr)))
    # Name each trace so the four settings can be told apart in the legend.
    fig.add_trace(
        go.Scatter(x=time, y=hr, name=f"HR (outlier_over={out}, tol={tol})"),
        row=1, col=1,
    )
    nan_runs = nan_run_lengths(hr)
    n_nans = np.isnan(hr).sum()
    print(f"Number of NANS: {n_nans} | Fraction of NANS: {n_nans / len(hr)}")
    print(f"Length of runs: {nan_runs} |# runs {len(nan_runs)} # below 4: {(nan_runs < 4).sum()}")

# The IMU series comes from the last filter() call.
fig.add_trace(go.Scatter(x=time, y=imu, name="IMU magnitude (window mean)"), row=2, col=1)
fig.update_layout(height=2000)
fig

0  load  C:\Users\cleme\ETH\Master\Thesis\data\WildPPG\data\WildPPG_Part_an0.mat


100%|██████████| 22184/22184 [00:01<00:00, 18807.83it/s]


Number of NANS: 139 | Fraction of NANS: 0.006265777136675081
Length of runs: [ 4  1  3  2  1  1  5 14  1  3  2  3  4  2  1  2  1  1 11  7  1  3  3  1
  8  8 15  1  3  1  7  7  4  3  1  4] |# runs 36 # below 4: 23
0  load  C:\Users\cleme\ETH\Master\Thesis\data\WildPPG\data\WildPPG_Part_an0.mat


100%|██████████| 22184/22184 [00:01<00:00, 17678.47it/s]


Number of NANS: 190 | Fraction of NANS: 0.00856473133790119
Length of runs: [ 4  1  6  2  1  1  2  5 20  2  3  1  2  2  3  1  4  1  2  1  2  5 23  7
  1  2  3  3  1  8  8 17  3  1  5  1  7  1  9  7  3  4  5] |# runs 43 # below 4: 26
0  load  C:\Users\cleme\ETH\Master\Thesis\data\WildPPG\data\WildPPG_Part_an0.mat


100%|██████████| 22184/22184 [00:01<00:00, 17365.65it/s]


Number of NANS: 303 | Fraction of NANS: 0.01365849260728453
Length of runs: [ 1  4  1  6  2  1  4  6  7 19  1  4  7  1  2  1  3 11  6  1  4  7 45  7
  4  2  4  1  3  9  1  1  8  8 21  6  1  5 19  5  2  9 12  8  1  5 17] |# runs 47 # below 4: 18
0  load  C:\Users\cleme\ETH\Master\Thesis\data\WildPPG\data\WildPPG_Part_an0.mat


100%|██████████| 22184/22184 [00:01<00:00, 12147.67it/s]


Number of NANS: 336 | Fraction of NANS: 0.015146051208077894
Length of runs: [ 1  4  1  6  2  1  4  6  7 20  2  4  7  1  2  1  3 11  6  1  1  4  7 45
  7  8  4  2  2  9  1  3  9  4  1  4  9 33  7  1  5 19  1  5  3  9 12  8
  1  5 17] |# runs 51 # below 4: 20


In [None]:
# Poincaré plots (RR[n] against RR[n+1]), one panel per entry in `rrs`.
# `rrs` must be the list of RR arrays collected by the parameter sweep, not
# the array loaded from rrs.npy at the top of the notebook.
# The column count follows len(rrs) instead of a hard-coded 4, so changing
# the number of swept settings no longer breaks the figure.
fig = make_subplots(rows=1, cols=len(rrs), shared_xaxes=True, shared_yaxes=True)
for i, rr in enumerate(rrs, start=1):
    fig.add_trace(go.Scatter(x=rr[:-1], y=rr[1:], mode="markers"), row=1, col=i)
fig

In [14]:
import statsmodels.api as sm
from scipy.interpolate import UnivariateSpline


def statsmodels_kalman_impute(series):
    """
    Fit a local-level unobserved-components model and return its smoothed level.

    Parameters:
        series: 1-D sequence of observations; it is converted to a float array.

    Returns:
        The first row of the fitted model's smoothed state, i.e. one value
        per input sample. Note that every sample is replaced by the smoothed
        estimate, not only the missing ones.
    """
    x = np.asarray(series, dtype=float)

    mod = sm.tsa.UnobservedComponents(x, level='local level')
    res = mod.fit(disp=False)
    return res.smoothed_state[0]  # smoothed series



def spline_interpolate(hr: np.ndarray, s: float = 0.5):
    """
    Fit a smoothing spline to the finite samples of a 1D series and evaluate
    it at every index, which also fills the non-finite positions.

    s = smoothing factor (higher = smoother)

    With fewer than four finite samples no spline is fitted and a copy of
    the input is returned.
    """
    positions = np.arange(len(hr))
    finite = np.isfinite(hr)
    if np.count_nonzero(finite) < 4:
        # not enough data to fit a spline
        return hr.copy()

    # Only the valid points take part in the fit.
    fitted = UnivariateSpline(positions[finite], hr[finite], s=s)
    return fitted(positions)


In [15]:
# Compare the mildest and the strictest filter setting, the spline-smoothed
# strict series and the movement intensity on a shared time axis.
hr = hrs[-1]
comp = hrs[0]
# Alternative smoother: statsmodels_kalman_impute(hr)
smoothed = spline_interpolate(hr)

fig = make_subplots(rows=4, cols=1, shared_xaxes=True)
for row, series in enumerate((comp, hr, smoothed, imu), start=1):
    fig.add_trace(go.Scatter(x=time, y=series), row=row, col=1)
fig.update_layout(height=1600)
fig

In [None]:
# DaLiA subject S1: heart rate next to the first column of the wrist and
# chest "mean" features, plus their correlation with the heart rate.
path = "C:/Users/cleme/ETH/Master/Thesis/data/euler/dalia_filtered_preprocessed/"

data = np.load(path + "S1.npz")

fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
hr = data["hr"]
wrist_mean = data["wrist_mean"][:, 0]
chest_mean = data["chest_mean"][:, 0]

x = list(range(len(hr)))

fig.add_trace(go.Scatter(x=x, y=hr), row=1, col=1)
fig.add_trace(go.Scatter(x=x, y=wrist_mean), row=2, col=1)
fig.add_trace(go.Scatter(x=x, y=chest_mean), row=3, col=1)

fig.update_layout(height=1200)
# A bare `fig` only renders when it is the last expression of the cell; the
# prints below follow it, so the figure has to be shown explicitly.
fig.show()
print(f"wrist {np.corrcoef(hr, wrist_mean)}")
print(f"chest {np.corrcoef(hr, chest_mean)}")

In [None]:
from scipy.io import loadmat 
import plotly.express as px

# IEEE PPG training recording 1: first 30 s of channel 0 of "sig" together
# with the reference BPM trace.
path = "C:/Users/cleme/ETH/Master/Thesis/data/euler/IEEEPPG/Training_data/Training_data/"
seconds = 30 
fs = 125

mat_file = loadmat(path + "DATA_01_TYPE01.mat")
hr_mat = loadmat(path + "DATA_01_TYPE01_BPMtrace.mat")

hr = hr_mat["BPM0"][:, 0]
ecg = mat_file["sig"][0, : (fs * seconds)]
x = np.arange(len(ecg)) / fs
# one BPM value every 2 s, drawn from t = 8 s onwards
hr_x = np.arange(0, seconds, 2) + 8

fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(x=x, y=ecg), row=1, col=1)
fig.add_trace(go.Scatter(x=hr_x, y=hr[: len(hr_x)]), row=2, col=1)
fig.update_layout(height=1000)

# Normalization Experiment

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# --------------------
# Config
# --------------------
datasets = ["WildPPG", "Capture24", "DaLiA"]
norms = [
    "Global Z",
    "Global Z + Instance",
    "Global Z + Diff",
    "Global Z + Diff (Endo)",
    "Global Z + Instance (Endo)",
]

# 14 models: first 7 = Baselines, last 7 = DL
models = ["LR","MoLE","MSAR","KF","XGB","GP","MLP","TNET","STM","AMSH","PTST","TXER","GPT","NBX"]
baseline_models = models[:7]
dl_models = models[7:]

# --------------------
# Mock data (replace with the real result arrays)
# Shape per dataset: (7 models, 5 norms) for both groups
# --------------------
rng = np.random.default_rng(123)

def mock_block(n_models, n_norms, loc=10, scale=1.8):
    """Draw an (n_models, n_norms) block of mock MAE values, clipped to [5, 16]."""
    z = np.round(rng.normal(loc=loc, scale=scale, size=(n_models, n_norms)), 2)
    return np.clip(z, 5.0, 16.0)  # MAE bounds

# Build data dict: per dataset -> (baselines_z, dl_z)
data = {}
for d in datasets:
    # Slightly different centers per dataset for variety
    center_shift = {"WildPPG": 10.0, "Capture24": 9.5, "DaLiA": 10.5}[d]
    data[d] = (
        mock_block(len(baseline_models), len(norms), loc=center_shift, scale=1.8),
        mock_block(len(dl_models), len(norms), loc=center_shift - 0.2, scale=1.8),
    )

# Global zmin/zmax for a consistent color range across all heatmaps
all_vals = np.concatenate([np.concatenate([v[0], v[1]], axis=0) for v in data.values()], axis=0)
zmin, zmax = float(np.min(all_vals)), float(np.max(all_vals))

# --------------------
# Figure with 2 rows (Baselines / DL) × 3 columns (datasets)
# --------------------
fig = make_subplots(
    rows=2, cols=3,
    subplot_titles=(
        [f"{d} • Baselines" for d in datasets] +
        [f"{d} • Deep Learning" for d in datasets]
    ),
    horizontal_spacing=0.06, vertical_spacing=0.15
)

# One color axis per column, shared by the two heatmaps of that column.
coloraxes = ["coloraxis", "coloraxis2", "coloraxis3"]

for col, d in enumerate(datasets, start=1):
    base_z, dl_z = data[d]

    # Row 1: baselines, row 2: deep-learning models.
    for row, (z_block, model_names) in enumerate(
        ((base_z, baseline_models), (dl_z, dl_models)), start=1
    ):
        fig.add_trace(
            go.Heatmap(
                z=z_block,
                x=norms,
                y=model_names,
                # The color range lives on the shared color axis; trace-level
                # zmin/zmax/showscale have no effect once `coloraxis` is set.
                coloraxis=coloraxes[col-1],
                hovertemplate="Model: %{y}<br>Norm: %{x}<br>MAE: %{z}<extra></extra>",
            ),
            row=row, col=col
        )

# Apply the global range to every color axis so that colors are comparable
# across all six panels. The range is identical everywhere, so a single
# colorbar (the last column's) is enough.
fig.update_layout(
    **{
        axis_name: dict(
            cmin=zmin,
            cmax=zmax,
            showscale=(axis_name == coloraxes[-1]),
            colorbar=dict(title=dict(text="MAE")),
        )
        for axis_name in coloraxes
    }
)

# Axis labels
for c in range(1, 4):
    fig.update_xaxes(title="Normalization", row=2, col=c)  # bottom row x-axis titles
    fig.update_yaxes(title="Model", row=1, col=c)
    fig.update_yaxes(title="Model", row=2, col=c)

fig.update_layout(
    height=800,
    width=1200,
    title="Normalization Experiments (MAE ↓): Baselines vs Deep Learning across Datasets",
    template="plotly_white",
    margin=dict(t=80, l=60, r=40, b=40)
)

fig.show()


# Feature Engineering Experiment

In [6]:
import numpy as np
import pandas as pd
import plotly.express as px

# Experiment grid
models = ["LR","MoLE","MSAR","KF","XGB","GP","MLP","TNET","STM","AMSH","PTST","TXER","GPT","NBX"]
feature_sets = ["No Exo", "Mean", "catch22", "Other"]
datasets = ["WildPPG", "Capture24", "DaLiA"]

# Mock MAE values with shape (model, feature set, dataset), clipped to [5, 16]
rng = np.random.default_rng(42)
grid_shape = (len(models), len(feature_sets), len(datasets))
values = np.clip(np.round(rng.normal(loc=10, scale=2, size=grid_shape), 2), 5.0, 16.0)

# Long-form table: one row per (dataset, model, feature set)
rows = [
    (dset, model, feature, values[i, j, d_idx])
    for d_idx, dset in enumerate(datasets)
    for i, model in enumerate(models)
    for j, feature in enumerate(feature_sets)
]
df = pd.DataFrame(rows, columns=["Dataset", "Model", "Feature Set", "MAE"])

# Keep the models in the order given above
df["Model"] = pd.Categorical(df["Model"], categories=models, ordered=True)

# Grouped bars, one facet per dataset
fig = px.bar(
    df,
    x="Model",
    y="MAE",
    color="Feature Set",
    facet_col="Dataset",
    barmode="group",
    text="MAE",
    template="plotly_white",
    height=520,
)

fig.update_traces(textposition="outside", cliponaxis=False)
fig.update_layout(
    title="Time-Series Forecasting: Model × Feature Set × Dataset (MAE ↓)",
    xaxis_title="Model",
    yaxis_title="MAE (lower is better)",
    legend_title="Feature Set",
    bargap=0.15,
    bargroupgap=0.06,
)
# Facet titles read "Dataset=<name>"; keep only the name
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

fig.show()
