### Imports

In [None]:
import os
import numpy as np
import pandas as pd
from typing import Optional
import matplotlib.pyplot as plt

### Evaluation and Plotting Scripts

In [None]:
def _mean_accuracy_lookup(df):
    """Map (dataset_version, evaluation, split) -> mean accuracy over the matching rows of *df*."""
    return (
        df.groupby(["dataset_version", "evaluation", "split"], sort=False)["accuracy"]
        .mean()
        .to_dict()
    )


def evaluate_selection_strategy(csv_path: str, plot: bool = True):
    """
    Compare a validation-driven model-selection rule against always using the
    benchmark model or always using the alt model.

    For every dataset version the rule is:
      1. if the benchmark model's VAL accuracy on that version is >= the
         baseline (benchmark model VAL accuracy on '__benchmark__'), keep the
         benchmark model;
      2. otherwise pick the alt model when its VAL accuracy is >= the
         benchmark model's VAL accuracy on that version, else keep the
         benchmark model.
    The TEST accuracy of the picked model is what gets averaged.

    Parameters
    ----------
    csv_path : str
        Path to results CSV with columns:
        ['dataset_version','evaluation','split','accuracy','accuracy_std','dataset_distance_to_benchmark']
    plot : bool
        If True, plots a bar chart of average accuracies (skipped when no
        dataset version has complete data).

    Returns
    -------
    summary : dict
        {
          'baseline_val_accuracy': float,
          'avg_selection_strategy': float,
          'avg_benchmark_only': float,
          'avg_alt_only': float,
          'num_versions': int,
          'num_selected_benchmark': int,
          'num_selected_alt': int,
          'num_bench_ge_baseline_val': int,
          'num_bench_worse_but_alt_even_worse': int
        }
        The three averages are NaN when no version could be evaluated.
    per_version : pd.DataFrame
        One row per dataset_version (excluding '__benchmark__'),
        with columns: ['dataset_version','bench_val','bench_test','alt_test','chosen','chosen_test_acc'].
        Versions lacking any of the four (model, split) combinations are skipped.

    Raises
    ------
    ValueError
        If the CSV has no '__benchmark__' / 'benchmark_model' / 'val' row.
    """
    per_version_columns = [
        "dataset_version", "bench_val", "bench_test",
        "alt_test", "chosen", "chosen_test_acc",
    ]
    # (evaluation, split) pairs every version must provide; order matches the unpacking below.
    required_cells = [
        ("benchmark_model", "val"),
        ("benchmark_model", "test"),
        ("alt_model", "val"),
        ("alt_model", "test"),
    ]

    df = pd.read_csv(csv_path)
    mean_acc = _mean_accuracy_lookup(df)

    # 1) Baseline: benchmark model on __benchmark__ validation
    baseline = mean_acc.get(("__benchmark__", "benchmark_model", "val"))
    if baseline is None:
        raise ValueError("Baseline row not found: __benchmark__/benchmark_model/val")
    baseline = float(baseline)

    # 2) Build a per-version table (exclude the __benchmark__ row).
    #    NaN version labels are dropped: they can never match a row and would break sorting.
    versions = sorted(
        v for v in df["dataset_version"].dropna().unique() if v != "__benchmark__"
    )

    rows = []
    missing = []

    # Counters for the two situations in which the benchmark model is kept
    count_bench_ge_baseline = 0  # benchmark VAL on the version >= baseline VAL
    count_bench_worse_but_alt_even_worse = 0  # benchmark below baseline, but alt VAL lower still

    for version in versions:
        cells = [mean_acc.get((version, evaluation, split)) for evaluation, split in required_cells]
        if any(cell is None for cell in cells):
            missing.append(version)
            continue
        bench_val, bench_test, alt_val, alt_test = (float(cell) for cell in cells)

        if bench_val >= baseline:
            chosen, chosen_acc = "benchmark_model", bench_test
            count_bench_ge_baseline += 1
        elif alt_val >= bench_val:
            chosen, chosen_acc = "alt_model", alt_test
        else:
            chosen, chosen_acc = "benchmark_model", bench_test
            count_bench_worse_but_alt_even_worse += 1

        rows.append({
            "dataset_version": version,
            "bench_val": bench_val,
            "bench_test": bench_test,
            "alt_test": alt_test,
            "chosen": chosen,
            "chosen_test_acc": chosen_acc,
        })

    # Passing the column list explicitly keeps sort_values from raising KeyError
    # when `rows` is empty (an empty DataFrame would otherwise have no columns).
    per_version = (
        pd.DataFrame(rows, columns=per_version_columns)
        .sort_values("dataset_version")
        .reset_index(drop=True)
    )

    if not per_version.empty:
        avg_selection = float(per_version["chosen_test_acc"].mean())
        avg_bench_only = float(per_version["bench_test"].mean())
        avg_alt_only = float(per_version["alt_test"].mean())
        n_sel_bench = int((per_version["chosen"] == "benchmark_model").sum())
        n_sel_alt = int((per_version["chosen"] == "alt_model").sum())
    else:
        avg_selection = avg_bench_only = avg_alt_only = float("nan")
        n_sel_bench = n_sel_alt = 0

    # 3) Print results
    print(f"Baseline (benchmark __benchmark__ VAL): {baseline:.4f}")
    print(f"Versions included: {len(per_version)}")
    if missing:
        print(f"[WARN] Skipped {len(missing)} versions due to missing rows: {missing[:5]}{' ...' if len(missing)>5 else ''}")
    print(f"Selected benchmark for {n_sel_bench} versions; alt for {n_sel_alt} versions.")
    print("Trigger counts:")
    print(f"  • Benchmark ≥ baseline on new VAL (picked benchmark): {count_bench_ge_baseline}")
    print(f"  • Benchmark < baseline on new VAL but alt VAL worse (still picked benchmark): {count_bench_worse_but_alt_even_worse}")
    print(f"\nAverage TEST accuracy (selection strategy): {avg_selection:.4f}")
    print(f"Average TEST accuracy (benchmark-only):     {avg_bench_only:.4f}")
    print(f"Average TEST accuracy (alt-only):           {avg_alt_only:.4f}")

    # 4) Plot (nothing meaningful to draw when every average is NaN)
    if plot and per_version.empty:
        print("[WARN] No dataset versions with complete rows; skipping plot.")
    elif plot:
        import matplotlib.pyplot as plt

        labels = ["Selection strategy", "Benchmark-only", "Alt-only"]
        values = [avg_selection, avg_bench_only, avg_alt_only]
        plt.figure(figsize=(6, 4))
        plt.bar(labels, values)
        plt.ylabel("Average TEST accuracy")
        plt.title("Average Test Accuracy by Strategy")
        for position, value in enumerate(values):
            # small vertical offset so the label sits just above its bar
            plt.text(position, value + 0.002, f"{value:.3f}", ha="center", va="bottom")
        plt.tight_layout()
        plt.show()

    summary = {
        "baseline_val_accuracy": baseline,
        "avg_selection_strategy": avg_selection,
        "avg_benchmark_only": avg_bench_only,
        "avg_alt_only": avg_alt_only,
        "num_versions": len(per_version),
        "num_selected_benchmark": n_sel_bench,
        "num_selected_alt": n_sel_alt,
        "num_bench_ge_baseline_val": count_bench_ge_baseline,
        "num_bench_worse_but_alt_even_worse": count_bench_worse_but_alt_even_worse,
    }
    return summary, per_version

In [None]:
def plot_results(
    df: pd.DataFrame,
    *,
    split: str = "test",   # 'val' or 'test'
    title: str = "Performance vs. Distance to Benchmark Dataset",
    legend_benchmark: str = "Benchmark model",
    legend_alt: str = "Alt model",
    save_csv_path: Optional[str] = None,
    show: bool = True,
):
    """
    Scatter-plot accuracy against dataset distance for one split, with one
    marker series per model ('benchmark_model' and 'alt_model').

    Accepted input layouts:
      - per-fold results with columns
        ['dataset_version','fold','evaluation','split','dataset_distance_to_benchmark','accuracy'];
        accuracy is averaged per (dataset_version, evaluation);
      - summary results with columns
        ['evaluation','split','dataset_distance_to_benchmark'] plus either
        'accuracy_mean' or, if that is absent, 'accuracy' (used as the mean).

    Parameters
    ----------
    df : pd.DataFrame
        Results table in one of the layouts above. None or an empty frame
        prints a notice and returns.
    split : str
        Value of the 'split' column to plot.
    title : str
        Plot title; the upper-cased split name is appended.
    legend_benchmark, legend_alt : str
        Legend labels for the two model series.
    save_csv_path : Optional[str]
        If given, the unmodified input `df` is written there after plotting.
    show : bool
        If True, call plt.show() once the figure is built.

    Returns
    -------
    None
    """
    if df is None or df.empty:
        print("No results to plot.")
        return

    selected = df[df["split"] == split]

    # detect per-fold vs summary
    if "fold" in df.columns and "accuracy" in df.columns:
        # average per version/evaluation for the chosen split
        plot_df = (
            selected
            .groupby(["dataset_version", "evaluation"], as_index=False)
            .agg(
                accuracy_mean=("accuracy", "mean"),
                dataset_distance_to_benchmark=("dataset_distance_to_benchmark", "first"),
            )
        )
    else:
        plot_df = selected
        # Summary tables may name the value column 'accuracy' instead of
        # 'accuracy_mean'; normalise so the plotting code below finds it.
        if "accuracy_mean" not in plot_df.columns and "accuracy" in plot_df.columns:
            plot_df = plot_df.rename(columns={"accuracy": "accuracy_mean"})

    df_bench = plot_df[plot_df["evaluation"] == "benchmark_model"]
    df_alt   = plot_df[plot_df["evaluation"] == "alt_model"]

    if df_bench.empty and df_alt.empty:
        print(f"No data to plot for split='{split}'.")
        return

    # Imported only once there is something to draw, so the no-data paths
    # above do not require matplotlib.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(7, 5))
    if not df_bench.empty:
        plt.scatter(
            df_bench["dataset_distance_to_benchmark"], df_bench["accuracy_mean"],
            label=legend_benchmark, marker="o"
        )
    if not df_alt.empty:
        plt.scatter(
            df_alt["dataset_distance_to_benchmark"], df_alt["accuracy_mean"],
            label=legend_alt, marker="^"
        )

    plt.xlabel("Dataset distance to benchmark (0–1)")
    plt.ylabel(f"{split.capitalize()} accuracy (mean over folds)")
    plt.title(title + f" — {split.upper()} split")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    if show:
        plt.show()

    if save_csv_path:
        df.to_csv(save_csv_path, index=False)
        print(f"Saved results to: {save_csv_path}")

### Parametrization

In [None]:
# Classifier implementations compared in this notebook (project-local modules)
from classifiers.WDI_1NN import WDI_1NN
from classifiers.ACM_SVM import ACM_SVM
from classifiers.CASIM import CASIM
from classifiers.EAC_1NN import EAC_1NN
from classifiers.MBW_LR import MBW_LR

# ---------------- Hyperparameters, one object per classifier
wdi_1nn_params = dict(template_threshold=0.5)
acm_svm_params = None  # no parameter dict is supplied for ACM-SVM
casim_params = dict(
    num_features=672,
    n_estimators=1,
    n_jobs_multirocket=1,
    random_state=42,
    alphas=np.logspace(-3, 3, 10),
)
# NOTE(review): the original line carried 0.0667 as a commented-out alternative value.
eac_1nn_params = dict(attenuation_coefficient_per_min=0.001)
mbw_lr_params = dict(
    penalty=None,
    fit_intercept=False,
    solver="lbfgs",
    multi_class="ovr",
    decision_bounds=True,
    confidence_interval=1.96,
)

# Registry of models to evaluate. Each entry is
# (classifier class, params, display name, boolean flag); the evaluation
# loops unpack the last element as `use_argmin`.
models_to_evaluate = [
    (WDI_1NN, wdi_1nn_params, "WDI-1NN", True),
    (CASIM, casim_params, "CASIM", False),
    (EAC_1NN, eac_1nn_params, "EAC-1NN", True),
    (MBW_LR, mbw_lr_params, "MBW-LR", False),
    (ACM_SVM, acm_svm_params, "ACM-SVM", False),
]

### Experiments

#### FCC Evaluation

In [None]:
# Draw the distance-vs-accuracy scatter for every registered model whose
# results CSV exists under results/ for the listed datasets and splits.
datasets = ["fcc"]
splits = ["test"]

for dataset in datasets:
    for _cls, _params, model_name, _flag in models_to_evaluate:
        # file names use the lower-cased model name with '-' replaced by '_'
        file_stem = model_name.lower().replace('-', '_')
        csv_path = f"results/{dataset}_results_{file_stem}.csv"
        if os.path.isfile(csv_path):
            print(f"\n[INFO] Loading and plotting: {csv_path}")
            df = pd.read_csv(csv_path)
            for split in splits:
                plot_title = f"{model_name} — {dataset.upper()} — Distance vs Performance"
                try:
                    plot_results(df, split=split, title=plot_title, show=True)
                except Exception as e:
                    # keep going so one bad file/split does not stop the remaining plots
                    print(f"[ERROR] Failed to plot {csv_path} (split={split}): {e}")
        else:
            print(f"[WARN] Results CSV not found: {csv_path}. Skipping.")

In [None]:
# Run the selection-strategy evaluation on each model's saved "fcc" results;
# models without a results file are reported and skipped.
for model, model_params, model_name, use_argmin in models_to_evaluate:
    print(f"\n=== Evaluating selection strategy for model: {model_name} ===")
    file_stem = model_name.lower().replace('-', '_')
    csv_path = f"results/fcc_results_{file_stem}.csv"
    if os.path.isfile(csv_path):
        summary, per_version = evaluate_selection_strategy(csv_path, plot=True)
    else:
        print(f"[WARN] Results CSV not found: {csv_path}. Skipping.")

#### TEP Evaluation

In [None]:
# Draw the distance-vs-accuracy scatter for every registered model whose
# results CSV exists under results/ for the listed datasets and splits.
datasets = ["tep"]
splits = ["test"]

for dataset in datasets:
    for _cls, _params, model_name, _flag in models_to_evaluate:
        # file names use the lower-cased model name with '-' replaced by '_'
        file_stem = model_name.lower().replace('-', '_')
        csv_path = f"results/{dataset}_results_{file_stem}.csv"
        if os.path.isfile(csv_path):
            print(f"\n[INFO] Loading and plotting: {csv_path}")
            df = pd.read_csv(csv_path)
            for split in splits:
                plot_title = f"{model_name} — {dataset.upper()} — Distance vs Performance"
                try:
                    plot_results(df, split=split, title=plot_title, show=True)
                except Exception as e:
                    # keep going so one bad file/split does not stop the remaining plots
                    print(f"[ERROR] Failed to plot {csv_path} (split={split}): {e}")
        else:
            print(f"[WARN] Results CSV not found: {csv_path}. Skipping.")

In [None]:
# Run the selection-strategy evaluation on each model's saved "tep" results;
# models without a results file are reported and skipped.
for model, model_params, model_name, use_argmin in models_to_evaluate:
    print(f"\n=== Evaluating selection strategy for model: {model_name} ===")
    file_stem = model_name.lower().replace('-', '_')
    csv_path = f"results/tep_results_{file_stem}.csv"
    if os.path.isfile(csv_path):
        summary, per_version = evaluate_selection_strategy(csv_path, plot=True)
    else:
        print(f"[WARN] Results CSV not found: {csv_path}. Skipping.")