This code takes the subject-level prediction scores (logits) produced by several runs and averages them, forming a simple ensemble.

In [2]:
import numpy as np
import os
import pandas as pd
from pathlib import Path
import re
import itertools
import torch
from util.metrics import calculate_metrics

# Do not truncate long cell contents when DataFrames are displayed.
pd.set_option('display.max_colwidth', None)

# Model-family switches. A value of 1 adds that family's results directory
# (one per selected channel) to the ensemble; any other value leaves it out.
# The trailing comment names the path-dictionary key suffix each flag selects.
unet2d =0  # "path2d"
unet1d = 1  # "path1d"
unet1d_0 = 0  # "path1d0"
unet2d_0 = 0  # "path2d0"
OP = 0  # "pathOP"
ENC = 0  # "pathEN"
W2V = 0  # "pathW2V"
HUB=0  # "pathHUB"

# --------------------
# Helpers: summary statistics and subject alignment
# --------------------
def np_mean_std(xs):
    """Return ``(mean, sample_std)`` of *xs* as plain Python floats.

    The standard deviation uses ``ddof=1``. When *xs* holds fewer than two
    values the deviation is reported as ``0.0``.
    """
    arr = np.asarray(xs, dtype=float)
    spread = 0.0
    if len(arr) > 1:
        spread = float(arr.std(ddof=1))
    return float(arr.mean()), spread

def print_fold_stats(name, vals):
    """Return the mean of *vals* as computed by ``np_mean_std``.

    *name* labels the metric; it is only used by the print statement below,
    which is currently disabled, so nothing is written to stdout.
    """
    fold_mean, fold_std = np_mean_std(vals)
    # print(f"{name}: {fold_mean:.4f} ± {fold_std:.4f}")
    return fold_mean

def ensure_same_subject_order(dfs, subj_col=0):
    """Return copies of *dfs* whose rows all follow the first frame's subject order.

    Args:
        dfs: Sequence of DataFrames, one per model, covering the same subjects.
        subj_col: Positional index of the column holding the subject ID.

    Returns:
        A list of new DataFrames (fresh 0..n-1 index, same columns as the
        inputs) whose rows are ordered like the first frame. The input frames
        are not modified. An empty *dfs* yields an empty list.

    Raises:
        ValueError: If a frame has a different number of subjects than the
            first one, a different set of subject IDs, or if subject IDs are
            duplicated (row alignment would be ambiguous).
    """
    if not dfs:
        return []

    subj_lists = [df.iloc[:, subj_col].tolist() for df in dfs]
    ref = subj_lists[0]
    ref_set = set(ref)

    # Check identical length
    for i, sl in enumerate(subj_lists):
        if len(sl) != len(ref):
            raise ValueError(f"[Fold align error] Path {i} has {len(sl)} subjects, expected {len(ref)}.")

    # Check identical contents
    for i, sl in enumerate(subj_lists):
        sl_set = set(sl)
        if sl_set != ref_set:
            missing = ref_set - sl_set
            extra = sl_set - ref_set
            raise ValueError(
                f"[Fold align error] Path {i} subject mismatch:\n"
                f"Missing: {sorted(missing)}\nExtra: {sorted(extra)}"
            )

    # Equal lengths and equal ID sets mean a duplicate in the reference implies
    # duplicates everywhere; label-based reordering would then multiply rows.
    if len(ref_set) != len(ref):
        raise ValueError("[Fold align error] Duplicate subject IDs; cannot align rows.")

    # Reorder by position rather than via a temporary column, so the caller's
    # frames are left untouched.
    aligned = []
    for df, sl in zip(dfs, subj_lists):
        row_of = {subject: pos for pos, subject in enumerate(sl)}
        aligned.append(df.iloc[[row_of[subject] for subject in ref]].reset_index(drop=True))
    return aligned

# --------------------
# Config
# --------------------
# Alternative multi-split configuration:
# splits = ['xx42', 'xx52', 'xx01', 'xx02','xx03']
# Each entry is substituted into the "saved{split}" part of the results paths.
splits = ['xx3']
# Substituted into the optimizer directory name (e.g. "adam_s{seed}").
seed=31
# NOTE(review): not referenced anywhere in the visible code.
print_paths = 0
# Digits after the decimal point in the formatted "mean (std)" table cells.
DECIMALS = 2

def natural_key(p):
    """Return a natural-sort key for the path-like *p*, based on ``p.name``.

    The name is split into alternating text and digit runs; digit runs become
    ``int`` and text runs are lower-cased, so ``fold2`` sorts before ``fold10``.
    """
    tokens = re.split(r'(\d+)', p.name)
    # With a single capturing group, re.split alternates text, digits, text, ...
    # so odd positions are exactly the ``\d+`` matches. Classifying by position
    # instead of ``str.isdigit()`` avoids calling int() on characters such as
    # superscript digits, which isdigit() accepts but int() rejects.
    return [int(tok) if idx % 2 else tok.lower() for idx, tok in enumerate(tokens)]

def safe_iterdir(path: Path) -> list:
    """Return the sub-directories of *path*, ordered by ``natural_key``.

    Returns an empty list when *path* does not exist or is not a directory
    (for example a regular file), instead of letting ``iterdir`` raise.
    """
    if not path.is_dir():
        return []
    return sorted((child for child in path.iterdir() if child.is_dir()), key=natural_key)

# Build per-split path dicts (kept identical to your naming scheme)
def build_per_split_paths(split):
    """Map ``"c{channel}_{model}"`` keys to result directories for *split*.

    Every path embeds *split* (as ``saved{split}``) and the module-level
    ``seed``. The version directory is hard-coded per channel and model. Four
    of the channel-5 entries carry no trailing slash, as in the original table.
    """
    root = f'/home/sparc/dev/code/saved_results/saved{split}'
    u2d = f'{root}/unetssl2d_cnn/adam_s{seed}'
    u1d = f'{root}/unetssl_cnn/adam_s{seed}'
    opera = f'{root}/opera_ce/adam_s{seed}'
    encodec = f'{root}/encodec_cnn/adamw_s{seed}'
    w2v = f'{root}/wav2vec_cnn/adam_s{seed}'
    hubert = f'{root}/hubert_cnn/adam_s{seed}'

    return {
        # chan1
        "c1_path2d": f'{u2d}/ver138/ch1/',
        "c1_path2d0": f'{u2d}/ver139/ch1/',
        "c1_path1d": f'{u1d}/ver1/ch1/',
        "c1_path1d0": f'{u1d}/ver142/ch1/',
        "c1_pathOP": f'{opera}/ver0/ch1/',
        "c1_pathEN": f'{encodec}/ver56/ch1/',
        "c1_pathW2V": f'{w2v}/ver53/ch1/',
        "c1_pathHUB": f'{hubert}/ver0/ch1/',
        # chan2
        "c2_path2d": f'{u2d}/ver294/ch2/',
        "c2_path2d0": f'{u2d}/ver295/ch2/',
        "c2_path1d": f'{u1d}/ver1/ch2/',
        "c2_path1d0": f'{u1d}/ver373/ch2/',
        "c2_pathOP": f'{opera}/ver32/ch2/',
        "c2_pathEN": f'{encodec}/ver173/ch2/',
        "c2_pathW2V": f'{w2v}/ver161/ch2/',
        "c2_pathHUB": f'{hubert}/ver36/ch2/',
        # chan3
        "c3_path2d": f'{u2d}/ver552/ch3/',
        "c3_path2d0": f'{u2d}/ver553/ch3/',
        "c3_path1d": f'{u1d}/ver1/ch3/',
        "c3_path1d0": f'{u1d}/ver628/ch3/',
        "c3_pathOP": f'{opera}/ver64/ch3/',
        "c3_pathEN": f'{encodec}/ver334/ch3/',
        "c3_pathW2V": f'{w2v}/ver323/ch3/',
        "c3_pathHUB": f'{hubert}/ver69/ch3/',
        # chan4
        "c4_path2d": f'{u2d}/ver1028/ch4/',
        "c4_path2d0": f'{u2d}/ver1029/ch4/',
        "c4_path1d": f'{u1d}/ver1/ch4/',
        "c4_path1d0": f'{u1d}/ver845/ch4/',
        "c4_pathOP": f'{opera}/ver97/ch4/',
        "c4_pathEN": f'{encodec}/ver479/ch4/',
        "c4_pathW2V": f'{w2v}/ver443/ch4/',
        "c4_pathHUB": f'{hubert}/ver100/ch4/',
        # chan5 (the first four entries have no trailing slash)
        "c5_path2d": f'{u2d}/ver1152/ch5',
        "c5_path2d0": f'{u2d}/ver1153/ch5',
        "c5_path1d": f'{u1d}/ver1/ch5',
        "c5_path1d0": f'{u1d}/ver975/ch5',
        "c5_pathOP": f'{opera}/ver144/ch5/',
        "c5_pathEN": f'{encodec}/ver632/ch5/',
        "c5_pathW2V": f'{w2v}/ver608/ch5/',
        "c5_pathHUB": f'{hubert}/ver130/ch5/',
        # chan6
        "c6_path2d": f'{u2d}/ver1374/ch6/',
        "c6_path2d0": f'{u2d}/ver1375/ch6/',
        "c6_path1d": f'{u1d}/ver1/ch6/',
        "c6_path1d0": f'{u1d}/ver1339/ch6/',
        "c6_pathOP": f'{opera}/ver160/ch6/',
        "c6_pathEN": f'{encodec}/ver749/ch6/',
        "c6_pathW2V": f'{w2v}/ver755/ch6/',
        "c6_pathHUB": f'{hubert}/ver163/ch6/',
    }

# Every non-empty subset of channels 1..6 (sizes 1 through 6), smallest sizes
# first and in lexicographic order within each size.
combos = list(
    itertools.chain.from_iterable(
        itertools.combinations([1, 2, 3, 4, 5, 6], size) for size in range(1, 7)
    )
)

def mean_std_fmt(values, decimals=2, percent=True):
    """Format *values* as ``"mean (std)"`` with *decimals* digits.

    Uses ``np_mean_std`` (sample std, ddof=1). When *percent* is true, both
    numbers are multiplied by 100 before formatting.
    """
    mu, sd = np_mean_std(values)  # uses ddof=1
    if percent:
        mu *= 100
        sd *= 100
    return f"{mu:.{decimals}f} ({sd:.{decimals}f})"


# Loaded fold frames per results directory. Every channel combination reuses
# the same directories, so each CSV is read (and each warning printed) once.
_fold_frames_cache = {}


def _load_fold_frames(path):
    """Return the ``subject_pred.csv`` frames under *path*, one per fold directory.

    Fold directories are visited in ``safe_iterdir`` order. Unreadable CSVs are
    skipped with a warning rather than silently.
    """
    if path not in _fold_frames_cache:
        frames = []
        for fold_dir in safe_iterdir(Path(path)):
            csv_path = fold_dir / 'subject_pred.csv'
            if not csv_path.is_file():
                continue
            try:
                frames.append(pd.read_csv(csv_path))
            except Exception as exc:  # best effort: report and keep going
                # NOTE(review): a skipped fold shifts the fold positions of this
                # model relative to the others; the subject check in
                # ensure_same_subject_order is the only guard against that.
                print(f"[WARN] could not read {csv_path}: {exc}")
        if not frames:
            print(f"[WARN] no fold predictions under {path}; model left out of the ensemble")
        _fold_frames_cache[path] = frames
    return _fold_frames_cache[path]


# (flag, key suffix) pairs, in the order models are added for each channel.
_model_switches = (
    (unet2d, "path2d"),
    (unet1d, "path1d"),
    (unet1d_0, "path1d0"),
    (unet2d_0, "path2d0"),
    (OP, "pathOP"),
    (ENC, "pathEN"),
    (W2V, "pathW2V"),
    (HUB, "pathHUB"),
)
_active_suffixes = [suffix for flag, suffix in _model_switches if flag == 1]

# The path table depends only on the split, so build it once per split.
_split_paths = {split: build_per_split_paths(split) for split in splits}

all_rows = []

for combo in combos:
    chan_comb = list(combo)  # e.g., [1,2,3]
    accP, senP, speP, mccP = [], [], [], []  # per-split means (each split is mean over folds)

    for split in splits:
        loc = _split_paths[split]

        # One results directory per (channel, enabled model family).
        paths = [loc[f"c{c}_{suffix}"] for c in chan_comb for suffix in _active_suffixes]

        # Per-path lists of per-fold dataframes; paths without data are dropped.
        all_paths_folds = [frames for frames in map(_load_fold_frames, paths) if frames]

        if not all_paths_folds:
            continue  # no data for this split/combination

        # Normalize fold counts across paths
        num_paths = len(all_paths_folds)
        num_folds = min(len(f) for f in all_paths_folds)
        if not all(len(f) == num_folds for f in all_paths_folds):
            print(f"[WARN] {chan_comb} split {split}: varying folds; using {num_folds}")

        # Fusion per fold: equal-weight average of the models' scores.
        fusion_fold_metrics = {"acc": [], "sen": [], "spe": [], "mcc": [], "f1p": [], "f1n": []}
        w = [1.0 / num_paths] * num_paths

        for f_idx in range(num_folds):
            fold_dfs = [path_folds[f_idx] for path_folds in all_paths_folds]
            fold_dfs = ensure_same_subject_order(fold_dfs, subj_col=0)

            # Assumes column 1 holds the label and column 4 a probability-like
            # score in [0, 1] (rounding only makes sense for probabilities, not
            # raw logits) — TODO confirm against the subject_pred.csv writer.
            labels = torch.as_tensor(fold_dfs[0].iloc[:, 1].to_numpy())
            logits_list = [torch.as_tensor(df.iloc[:, 4].to_numpy(), dtype=torch.float32) for df in fold_dfs]
            avg_prob = sum(weight * scores for weight, scores in zip(w, logits_list))
            pred = avg_prob.round()  # nearest integer; an exact 0.5 rounds to 0

            fold_metrics = calculate_metrics(labels, pred)  # acc, sen, spe, mcc, f1p, f1n
            for metric_name, metric_value in zip(fusion_fold_metrics, fold_metrics):
                fusion_fold_metrics[metric_name].append(metric_value)

        # Per-split means (over folds)
        accP.append(print_fold_stats('acc', fusion_fold_metrics['acc']))
        senP.append(print_fold_stats('sen', fusion_fold_metrics['sen']))
        speP.append(print_fold_stats('spe', fusion_fold_metrics['spe']))
        mccP.append(print_fold_stats('mcc', fusion_fold_metrics['mcc']))

    chan_str = "-".join(str(x) for x in chan_comb)

    if not accP:
        # No split produced any predictions; a row of NaNs would only add noise.
        print(f"[WARN] {chan_str}: no data in any split; row omitted")
        continue

    # After all splits → mean (std) across splits
    acc_str = mean_std_fmt(accP, DECIMALS, percent=True)
    sen_str = mean_std_fmt(senP, DECIMALS, percent=True)
    spe_str = mean_std_fmt(speP, DECIMALS, percent=True)
    mcc_str = mean_std_fmt(mccP, DECIMALS, percent=False)

    all_rows.append([chan_str, acc_str, sen_str, spe_str, mcc_str])

# --------------------
# Emit LaTeX table
# --------------------
caption = "Subject-level performance across all channel combinations (mean (std) over splits)"
label = "subject_level_all_combos"
colspec = "c|cccc"


def _combo_sort_key(row):
    """Order rows by number of channels, then numerically by channel list."""
    channels = [int(tok) for tok in row[0].split('-')]
    return len(channels), channels


# Table preamble up to and including the header rule; ends with a newline.
_table_head = "\n".join([
    r"\begin{table}[t]",
    r"\centering",
    r"\small",
    r"\setlength{\tabcolsep}{2pt}",
    r"\renewcommand{\arraystretch}{1.0}",
    r"\caption{" + caption + "}",
    r"\label{" + label + "}",
    r"\begin{tabular}{" + colspec + "}",
    r"\hline",
    r"\hline",
    r"CH & ACC & SENS & SPEC & MCC \\",
    r"\hline",
    "",
])

# Closing rules and environments; ends with a newline.
_table_tail = "\n".join([
    r"\hline",
    r"\hline",
    r"\end{tabular}",
    r"\end{table}",
    "",
])

all_rows_sorted = sorted(all_rows, key=_combo_sort_key)

latex_lines = [_table_head]
# One "a & b & c \\" line per channel combination.
latex_lines.extend(" & ".join(row) + " \\\\\n" for row in all_rows_sorted)
latex_lines.append(_table_tail)

latex_table = "".join(latex_lines)
print(latex_table)


\begin{table}[t]
\centering
\small
\setlength{\tabcolsep}{2pt}
\renewcommand{\arraystretch}{1.0}
\caption{Subject-level performance across all channel combinations (mean (std) over splits)}
\label{subject_level_all_combos}
\begin{tabular}{c|cccc}
\hline
\hline
CH & ACC & SENS & SPEC & MCC \\
\hline
1 & 72.97 (0.00) & 81.86 (0.00) & 63.15 (0.00) & 0.46 (0.00) \\
2 & 71.28 (0.00) & 78.12 (0.00) & 63.92 (0.00) & 0.44 (0.00) \\
3 & 73.94 (0.00) & 81.07 (0.00) & 65.99 (0.00) & 0.49 (0.00) \\
4 & 74.33 (0.00) & 82.49 (0.00) & 65.32 (0.00) & 0.50 (0.00) \\
5 & 65.50 (0.00) & 77.27 (0.00) & 52.56 (0.00) & 0.32 (0.00) \\
6 & 72.28 (0.00) & 73.97 (0.00) & 70.17 (0.00) & 0.45 (0.00) \\
1-2 & 74.64 (0.00) & 82.51 (0.00) & 66.01 (0.00) & 0.50 (0.00) \\
1-3 & 75.64 (0.00) & 82.99 (0.00) & 67.36 (0.00) & 0.52 (0.00) \\
1-4 & 77.38 (0.00) & 85.05 (0.00) & 68.84 (0.00) & 0.55 (0.00) \\
1-5 & 71.25 (0.00) & 79.81 (0.00) & 61.77 (0.00) & 0.43 (0.00) \\
1-6 & 75.98 (0.00) & 78.47 (0.00) & 73.05 (0.00) & 0