In [None]:
# === Class-wise metrics with 95% bootstrap CIs + confusion partner + confidence ===
# - Expects columns: keystep_id (int true), pred_keystep_id (int pred), keystep_label (class name),
#   all_preds (stringified logits [[...]])
# - Produces: ./analysis/<modality>_classwise_metrics_with_CIs.csv and ./analysis/<modality>_overall_summary.csv
# - Deps: pandas, numpy, scikit-learn (metrics); matplotlib, seaborn (plots)

import pandas as pd
import numpy as np
import ast
from pathlib import Path
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score

import matplotlib.pyplot as plt

import seaborn as sns

In [None]:

# ---- CONFIG ----
# Predictions CSV to analyse; it is read with pd.read_csv in the LOAD section.
CSV_PATH = "./results/model_id_job_1173491_task_classification_on_20250715-163830/preds.csv"
modality = "resnet_ego_imu"  # for output filenames (prefix of the CSVs written to ./analysis/)

N_BOOT = 300          # bootstrap resamples for CIs
ALPHA = 0.05          # 95% CI
RANDOM_SEED = 42      # seeds np.random.default_rng so the bootstrap CIs are reproducible
MIN_SUPPORT = 1       # filter ultra-rare classes if desired (e.g., 5 or 10)

# ---- HELPERS ----
def safe_parse_logits(x: "Union[str, Sequence[float], np.ndarray]") -> np.ndarray:
    """Parse '[[...]]' or '[...]' into a float array of logits.

    Args:
        x: A stringified (possibly nested) list of numbers, or an
           already-parsed list/array, which is accepted as-is.

    Returns:
        A float ndarray. Every leading axis of length 1 is stripped, so
        '[[a, b]]' and '[[[a, b]]]' both come back as the 1D array [a, b].
        Input with a leading axis longer than 1 keeps its shape.

    Raises:
        ValueError / SyntaxError: from ast.literal_eval when a string is not
        a valid Python literal.
    """
    # Only strings need parsing; lists/arrays go straight to numpy.
    parsed = ast.literal_eval(x) if isinstance(x, str) else x
    arr = np.array(parsed, dtype=float)
    # Peel off wrapper dimensions one at a time (e.g. a batch axis of size 1).
    while arr.ndim > 1 and arr.shape[0] == 1:
        arr = arr[0]
    return arr

def softmax(z: np.ndarray) -> np.ndarray:
    """Return exp(z) normalised to sum to 1 over all elements of `z`.

    The global maximum is subtracted before exponentiating so that large
    logits do not overflow.
    """
    exp_shifted = np.exp(np.subtract(z, np.max(z)))
    return exp_shifted / exp_shifted.sum()

def bootstrap_classwise_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, n_classes: int,
    n_boot: int, alpha: float, rng: np.random.Generator
):
    """Percentile-bootstrap intervals for per-class precision/recall/F1/support.

    Each of the `n_boot` rounds draws len(y_true) sample indices with
    replacement from `rng`, recomputes the per-class scores on that resample
    (labels 0..n_classes-1, undefined ratios scored as 0) and records them.

    Returns:
        (ci_prec, ci_rec, ci_f1, ci_sup): four dicts mapping class id to a
        (low, high) tuple holding the alpha/2 and 1-alpha/2 percentiles.
    """
    class_ids = list(range(n_classes))
    n_samples = len(y_true)

    # One table per metric (precision, recall, f1, support):
    # rows are resamples, columns are classes.
    tables = [np.empty((n_boot, n_classes), dtype=float) for _ in range(4)]

    for b in range(n_boot):
        draw = rng.integers(0, n_samples, size=n_samples)
        scores = precision_recall_fscore_support(
            y_true[draw], y_pred[draw], labels=class_ids, zero_division=0
        )
        for table, values in zip(tables, scores):
            table[b, :] = values

    lo_q = 100 * (alpha/2)
    hi_q = 100 * (1 - alpha/2)

    def _interval(table: np.ndarray) -> dict:
        # Percentiles are taken per class over that class's resampled values.
        return {
            k: (np.percentile(table[:, k], lo_q), np.percentile(table[:, k], hi_q))
            for k in class_ids
        }

    ci_prec, ci_rec, ci_f1, ci_sup = (_interval(t) for t in tables)
    return ci_prec, ci_rec, ci_f1, ci_sup

def top_confusion_partner(cm: np.ndarray) -> Dict[int, Tuple[int, int, float]]:
    """
    Find, for every true class (row of `cm`), the wrong class it is most
    often predicted as.

    Returns a dict: true class i -> (j, count, frac_of_errors), where j is the
    off-diagonal column with the largest count in row i and frac_of_errors is
    count divided by the row's total off-diagonal count. Rows without any
    errors map to (-1, 0, 0.0).
    """
    partners = {}
    for true_cls, counts in enumerate(cm):
        wrong = np.array(counts)      # copy, so the caller's matrix is untouched
        wrong[true_cls] = 0           # ignore correct predictions
        n_wrong = wrong.sum()
        if n_wrong != 0:
            best = int(wrong.argmax())
            hits = int(wrong[best])
            partners[true_cls] = (best, hits, float(hits / n_wrong))
        else:
            partners[true_cls] = (-1, 0, 0.0)
    return partners

def classwise_confidence_stats(y_true: np.ndarray, y_pred: np.ndarray, conf: np.ndarray, n_classes: int):
    """Per true class, the median of `conf` over correct and over incorrect predictions.

    Returns:
        (med_correct, med_incorrect): dicts keyed by class id 0..n_classes-1.
        A value is NaN when the class has no samples, or no samples of that
        kind (correct / incorrect).
    """
    def _median_or_nan(values: np.ndarray) -> float:
        return float(np.median(values)) if values.size else np.nan

    med_correct, med_incorrect = {}, {}
    for cls in range(n_classes):
        in_class = (y_true == cls)
        if not np.any(in_class):
            med_correct[cls] = np.nan
            med_incorrect[cls] = np.nan
            continue
        cls_conf = conf[in_class]
        hit = (y_pred[in_class] == cls)
        med_correct[cls] = _median_or_nan(cls_conf[hit])
        med_incorrect[cls] = _median_or_nan(cls_conf[~hit])
    return med_correct, med_incorrect

# ---- LOAD ----
df = pd.read_csv(CSV_PATH)

# Ground-truth and predicted class ids as integer arrays
y_true = df["keystep_id"].astype(int).to_numpy()
y_pred = df["pred_keystep_id"].astype(int).to_numpy()

# Logit strings -> one probability row per sample; the class count is taken
# from the length of the first logit vector
logits_list = [safe_parse_logits(raw) for raw in df["all_preds"]]
num_classes = len(logits_list[0])
probs = np.vstack(list(map(softmax, logits_list)))
pred_from_logits = probs.argmax(axis=1)
max_conf = probs.max(axis=1)

# Class id -> label, using the first label seen for each id
id_to_label = {
    int(class_id): group["keystep_label"].iloc[0]
    for class_id, group in df.groupby("keystep_id")
}

# ---- METRICS ----
class_ids = list(range(num_classes))
P, R, F1, S = precision_recall_fscore_support(y_true, y_pred, labels=class_ids, zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
top_conf = top_confusion_partner(cm)
med_c, med_i = classwise_confidence_stats(y_true, y_pred, max_conf, num_classes)

# Bootstrap CIs (seeded generator -> reproducible intervals)
rng = np.random.default_rng(RANDOM_SEED)
ci_prec, ci_rec, ci_f1, ci_sup = bootstrap_classwise_metrics(
    y_true, y_pred, num_classes, N_BOOT, ALPHA, rng
)

# ---- REPORT TABLE ----
# One row per class with support >= MIN_SUPPORT: point estimates, bootstrap CI
# bounds, the top confusion partner and the median confidences.
# The column list is passed to the DataFrame constructor so that the table still
# has its columns (and can be sorted) when every class is filtered out.
REPORT_COLUMNS = [
    "class_id", "label", "support",
    "precision", "precision_CI_low", "precision_CI_high",
    "recall", "recall_CI_low", "recall_CI_high",
    "f1", "f1_CI_low", "f1_CI_high",
    "top_confused_with_id", "top_confused_with_label",
    "top_confusion_fraction_of_errors",
    "median_conf_correct", "median_conf_incorrect",
]

rows = []
for k in range(num_classes):
    support = int(S[k])
    if support < MIN_SUPPORT:
        continue
    label = id_to_label.get(k, f"Class_{k}")
    j, cnt, frac = top_conf.get(k, (-1, 0, 0.0))
    if j >= 0:
        # id_to_label only knows classes that occur as ground truth. A partner that
        # is predicted but never true gets the same "Class_<id>" fallback as above,
        # so it is not mistaken for the "no errors" placeholder.
        partner_label = id_to_label.get(j, f"Class_{j}")
    else:
        partner_label = "—"   # this class has no misclassifications
    rows.append({
        "class_id": k,
        "label": label,
        "support": support,
        "precision": P[k],
        "precision_CI_low": ci_prec[k][0],
        "precision_CI_high": ci_prec[k][1],
        "recall": R[k],
        "recall_CI_low": ci_rec[k][0],
        "recall_CI_high": ci_rec[k][1],
        "f1": F1[k],
        "f1_CI_low": ci_f1[k][0],
        "f1_CI_high": ci_f1[k][1],
        "top_confused_with_id": (j if j >= 0 else np.nan),
        "top_confused_with_label": partner_label,
        "top_confusion_fraction_of_errors": frac,
        "median_conf_correct": med_c[k],
        "median_conf_incorrect": med_i[k],
    })

# Worst F1 first; ties broken by larger support.
report_df = (
    pd.DataFrame(rows, columns=REPORT_COLUMNS)
    .sort_values(by=["f1", "support"], ascending=[True, False])
    .reset_index(drop=True)
)

# ---- MACRO SUMMARY ----
# Unweighted (NaN-ignoring) means of the per-class scores, plus overall accuracy.
overall_acc = accuracy_score(y_true, y_pred)
macro_p, macro_r, macro_f1 = (np.nanmean(scores) for scores in (P, R, F1))
# Share of rows whose stored prediction differs from the argmax of the logits.
mismatch_rate = float((pred_from_logits != y_pred).mean())

summary_fields = {
    "num_samples": len(df),
    "num_classes": num_classes,
    "overall_accuracy": overall_acc,
    "macro_precision": macro_p,
    "macro_recall": macro_r,
    "macro_f1": macro_f1,
    "pred_vs_logits_mismatch_rate": mismatch_rate,
    "bootstrap_resamples": N_BOOT,
    "alpha": ALPHA,
}
summary = pd.DataFrame([summary_fields])

# ---- SAVE + DISPLAY ----
# Write both tables to ./analysis/ (created if missing); filenames are prefixed
# with `modality` so runs for different modalities do not overwrite each other.
out_dir = Path("./analysis/")
out_dir.mkdir(parents=True, exist_ok=True)
classwise_path = out_dir / f"{modality}_classwise_metrics_with_CIs.csv"
summary_path = out_dir / f"{modality}_overall_summary.csv"
report_df.to_csv(classwise_path, index=False)
summary.to_csv(summary_path, index=False)

print("Saved:")
print(f"- {classwise_path}")
print(f"- {summary_path}")

# Nice display in Jupyter if helper is available
try:
    from caas_jupyter_tools import display_dataframe_to_user
    display_dataframe_to_user("Class-wise metrics with 95% bootstrap CI", report_df)
    display_dataframe_to_user("Overall summary", summary)
except Exception:
    # Best-effort fallback when the helper is missing or fails.
    # NOTE(review): `display` is not imported in this file — presumably the
    # IPython/Jupyter builtin; confirm before running outside a notebook.
    display(report_df.head(20))
    display(summary)


In [None]:

def plot_top_confusion_partners(report_df: pd.DataFrame,
                                top_n: int = 10,
                                support_min: int = 1,
                                title: str = None):
    """
    Plot the worst `top_n` classes by F1 and show, for each, the top confusion partner
    and the fraction of that class's errors that go to that partner.

    Classes without a confusion partner (no label, or a zero/non-numeric error
    fraction, i.e. the class has no misclassifications) are left out.

    Expected columns in report_df:
      - 'label' (str): class name
      - 'f1' (float)
      - 'support' (int)
      - 'top_confused_with_label' (str)
      - 'top_confusion_fraction_of_errors' (float in [0,1])

    Args:
        report_df: per-class report table.
        top_n: number of lowest-F1 classes to show.
        support_min: minimum 'support' a class needs to be considered.
        title: optional plot title; a default is generated when omitted.

    Returns:
        (fig, ax) of the horizontal bar chart.

    Raises:
        ValueError: if no class is left after filtering.
    """
    df = report_df.copy()
    df = df[df['support'] >= support_min]

    # Keep only classes that really have a partner. The label alone cannot tell:
    # the report marks "no partner" with a placeholder string rather than NaN, so
    # an isna() check never drops those rows. A class has a partner exactly when
    # some of its samples are misclassified, i.e. when the fraction is positive.
    frac_all = pd.to_numeric(df['top_confusion_fraction_of_errors'], errors='coerce')
    df = df[df['top_confused_with_label'].notna() & (frac_all > 0)]
    df = df.sort_values(['f1', 'support'], ascending=[True, False]).head(top_n)

    if df.empty:
        raise ValueError("No classes meet the filter criteria (support_min, top_n, or missing columns).")

    # y labels like "<class> → <partner>" (class name only if the partner is not a string)
    y_labels = [
        f"{str(lab)} → {str(partner)}" if isinstance(partner, str) else str(lab)
        for lab, partner in zip(df['label'], df['top_confused_with_label'])
    ]

    frac = df['top_confusion_fraction_of_errors'].astype(float).clip(lower=0, upper=1)
    f1_vals = df['f1'].astype(float)

    fig, ax = plt.subplots(figsize=(10, max(4, 0.5 * len(df))))
    # Reversed positions: the first (worst) row gets the largest y and is drawn at the top.
    y_pos = np.arange(len(df))[::-1]
    ax.barh(y_pos, frac, alpha=0.9)   # default matplotlib color; no explicit color

    # Percentage just right of each bar, F1 value inside the bar.
    for ypos, v, f1 in zip(y_pos, frac.values, f1_vals.values):
        ax.text(v + 0.01, ypos, f"{v*100:.0f}%", va='center')
        ax.text(min(0.98, max(0.02, v/2)), ypos, f"F1={f1:.2f}", va='center', ha='center', fontsize=9, color='white')

    ax.set_yticks(y_pos)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel("Fraction of this class's errors going to top confusion partner")
    ax.set_xlim(0, 1)

    ttl = title or f"Top Confusion Partners for Bottom {len(df)} Classes (by F1)"
    ax.set_title(ttl)
    ax.grid(axis='x', linestyle='--', alpha=0.4)
    plt.tight_layout()
    return fig, ax

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

def plot_top_confusion_partners_clean(report_df: pd.DataFrame,
                                      top_n: int = 10,
                                      support_min: int = 1,
                                      title: str = None):
    """
    Plot the top confusion partners for the bottom-N classes by F1,
    showing only the percentage of errors to the top confusion partner (no F1 text).

    Classes without a confusion partner (no label, or a zero/non-numeric error
    fraction, i.e. the class has no misclassifications) are left out.

    Args:
        report_df: per-class report with columns 'label', 'f1', 'support',
                   'top_confused_with_label', 'top_confusion_fraction_of_errors'.
        top_n: number of lowest-F1 classes to show.
        support_min: minimum 'support' a class needs to be considered.
        title: optional plot title; a default is generated when omitted.

    Returns:
        (fig, ax) of the horizontal bar chart.

    Raises:
        ValueError: if no class is left after filtering.
    """
    df = report_df.copy()
    df = df[df['support'] >= support_min]

    # Keep only classes that really have a partner. The report marks "no partner"
    # with a placeholder string rather than NaN, so isna() alone never drops those
    # rows; a positive error fraction is the reliable signal.
    frac_all = pd.to_numeric(df['top_confusion_fraction_of_errors'], errors='coerce')
    df = df[df['top_confused_with_label'].notna() & (frac_all > 0)]
    df = df.sort_values(['f1', 'support'], ascending=[True, False]).head(top_n)

    if df.empty:
        raise ValueError("No valid classes to plot after filtering.")

    # y-axis labels: "true → partner"
    y_labels = [
        f"{str(lab)} → {str(partner)}"
        for lab, partner in zip(df['label'], df['top_confused_with_label'])
    ]

    frac = df['top_confusion_fraction_of_errors'].astype(float).clip(0, 1)

    fig, ax = plt.subplots(figsize=(10, max(4, 0.5 * len(df))))
    # Reversed positions: the first (worst) row gets the largest y and is drawn at the top.
    y_pos = np.arange(len(df))[::-1]
    ax.barh(y_pos, frac, alpha=0.9)

    # Annotate bars with percentages only
    for ypos, v in zip(y_pos, frac.values):
        ax.text(v + 0.01, ypos, f"{v*100:.0f}%", va='center')

    ax.set_yticks(y_pos)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel("Fraction of this class's errors going to top confusion partner")
    ax.set_xlim(0, 1)

    ttl = title or f"Top Confusion Partners for Bottom {len(df)} Classes (by F1)"
    ax.set_title(ttl)
    ax.grid(axis='x', linestyle='--', alpha=0.4)
    plt.tight_layout()
    return fig, ax

# Example usage (works with any per-class report CSV written to ./analysis/):
# report_df = pd.read_csv("./analysis/<modality>_classwise_metrics_with_CIs.csv")
# plot_top_confusion_partners(report_df, top_n=10, support_min=5)
# plt.show()

# Per-class report read from ./analysis/ego_classwise_metrics_with_CIs.csv.
# NOTE(review): presumably written by an earlier run of the metrics cell with
# modality = "ego" (video-only model) — confirm.
ego_report_df = pd.read_csv("./analysis/ego_classwise_metrics_with_CIs.csv")

# 10 lowest-F1 classes among those with support >= 5.
plot_top_confusion_partners_clean(ego_report_df, top_n=10, support_min=5)
plt.show()


In [None]:
# Per-class report for the "resnet_ego_imu" modality (the filename prefix set in CONFIG).
ego_imu_report_df = pd.read_csv("./analysis/resnet_ego_imu_classwise_metrics_with_CIs.csv")

# Same plot and settings as for the ego report, for side-by-side reading.
plot_top_confusion_partners_clean(ego_imu_report_df, top_n=10, support_min=5)
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_partner_reduction_stacked(
    report_video: pd.DataFrame,
    report_imu: pd.DataFrame,
    top_n: int = 10,
    support_min: int = 1,
    min_errors: int = 3,          # require ≥ this many errors in BOTH; set 0 to keep all
    key_on: str = "class_id",
    rank_by: str = "video",       # "video" | "max" | "delta"
    title: str = None,
    zero_eps: float = 0.004       # tiny sliver so 0% bars are still visible
):
    """Bars: Video-only (TOP), Video+IMU (BOTTOM). Δ=(IMU−Video) shown to the right of IMU bar.

    Both reports need the columns [key_on, 'label', 'support', 'recall', 'f1',
    'top_confused_with_label', 'top_confusion_fraction_of_errors']; they are
    inner-joined on (key_on, 'label').

    Args:
        report_video: per-class report of the video-only model.
        report_imu: per-class report of the video+IMU model.
        top_n: number of classes to plot after ranking.
        support_min: minimum support required in both reports.
        min_errors: minimum error count (support * (1 - recall), rounded)
                    required in both reports; 0 disables the filter.
        key_on: class id column used for the join.
        rank_by: "video" (largest video fraction first, ties by lower video F1),
                 "max" (largest of the two fractions first) or
                 "delta" (most negative IMU−Video difference first).
        title: optional plot title.
        zero_eps: bar width drawn for 0% / missing fractions.

    Returns:
        (fig, ax, M) where M is the joined, filtered and ranked table.

    Raises:
        ValueError: on an unknown `rank_by`, or when no class survives the filters.
    """
    # Fail fast on a bad ranking mode instead of after the merge/filter work.
    if rank_by not in ("video", "max", "delta"):
        raise ValueError("rank_by must be one of {'video','max','delta'}")

    # --- join & prep ---
    cols = [key_on, "label", "support",
            "top_confused_with_label", "top_confusion_fraction_of_errors",
            "f1", "recall"]

    A = report_video[cols].copy().rename(columns={
        "support":"support_video",
        "top_confused_with_label":"partner_video",
        "top_confusion_fraction_of_errors":"frac_video",
        "f1":"f1_video",
        "recall":"recall_video",
    })
    B = report_imu[cols].copy().rename(columns={
        "support":"support_imu",
        "top_confused_with_label":"partner_imu",
        "top_confusion_fraction_of_errors":"frac_imu",
        "f1":"f1_imu",
        "recall":"recall_imu",
    })

    M = pd.merge(A, B, on=[key_on, "label"], how="inner")
    # error counts reconstructed from support and recall
    M["errors_video"] = (M["support_video"] * (1 - M["recall_video"])).round().astype("Int64")
    M["errors_imu"]   = (M["support_imu"]   * (1 - M["recall_imu"])).round().astype("Int64")

    # fractions
    M["frac_video"] = pd.to_numeric(M["frac_video"], errors="coerce").clip(0, 1)
    M["frac_imu"]   = pd.to_numeric(M["frac_imu"],   errors="coerce").clip(0, 1)

    # filters
    M = M[(M["support_video"] >= support_min) & (M["support_imu"] >= support_min)].copy()
    if M.empty:
        # Checked regardless of min_errors: an empty table cannot be plotted and
        # would otherwise only fail later at the tick-label step.
        raise ValueError("No shared classes with support >= support_min in both reports; lower support_min.")
    if min_errors > 0:
        M = M[(M["errors_video"] >= min_errors) & (M["errors_imu"] >= min_errors)].copy()
        if M.empty:
            raise ValueError("No classes after min_errors filter; lower min_errors or set to 0.")

    # Δ (negative = improvement with IMU)
    M["delta"] = (M["frac_imu"] - M["frac_video"]).astype(float)

    # ranking
    if rank_by == "video":
        M = M.sort_values(["frac_video", "f1_video"], ascending=[False, True])
    elif rank_by == "max":
        M["frac_max"] = M[["frac_video","frac_imu"]].max(axis=1)
        M = M.sort_values(["frac_max"], ascending=False)
    else:  # "delta"
        M = M.sort_values(["delta"], ascending=True)  # most negative (best reduction) first

    M = M.head(top_n).reset_index(drop=True)

    # plotting arrays
    n = len(M)
    y = np.arange(n)*0.7      # row centres; the y-axis is inverted below so row 0 is on top
    h = 0.14                  # small vertical offset so the paired bars sit close

    fv_true = M["frac_video"].to_numpy(float)
    fi_true = M["frac_imu"].to_numpy(float)

    # draw 0% bars as tiny slivers so both bars are always visible
    fv_plot = np.where(np.isnan(fv_true) | (fv_true == 0.0), zero_eps, fv_true)
    fi_plot = np.where(np.isnan(fi_true) | (fi_true == 0.0), zero_eps, fi_true)

    fig, ax = plt.subplots(figsize=(9, max(4.6, 0.55*n)))

    # Video-only (TOP bar), Video+IMU (BOTTOM bar)
    ax.barh(y - h, fv_plot, height=0.25, label="Video-only")
    ax.barh(y + h, fi_plot, height=0.25, label="Video + IMU")

    # annotate the true % on bar ends (not the sliver width)
    for i in range(n):
        ax.text(fv_plot[i] + 0.01, y[i] - h, f"{fv_true[i]*100:.0f}%", va="center")
        ax.text(fi_plot[i] + 0.01, y[i] + h, f"{fi_true[i]*100:.0f}%", va="center")
        # Δ to the RIGHT of the IMU bar
        ax.text(min(1.0, fi_plot[i] + 0.1), y[i] + h, f"Δ={(fi_true[i]-fv_true[i])*100:+.0f}%",
                va="center", ha="left", fontsize=9)

    # labels & cosmetics
    y_labels = [
        f"GT: {lab}\nvideo→{p_video} | vid+IMU→{p_imu}"
        for lab, p_video, p_imu in zip(M["label"], M["partner_video"], M["partner_imu"])
    ]
    ax.set_yticks(y)
    ax.set_yticklabels(y_labels)
    ax.invert_yaxis()  # put first row at top
    ax.set_xlabel("Fraction of this class's errors going to top confusion partner")
    ax.set_xlim(0, 1.09)  # little room for Δ text
    ax.legend(loc="lower right")
    ax.grid(axis="x", linestyle="--", alpha=0.35)
    ax.set_title(title or "Top Confusion Partner Concentration: Video (top) vs Video+IMU (bottom)")
    plt.tight_layout()
    return fig, ax, M


In [None]:
# Compare the ego report against the ego+IMU report: the 5 classes (support >= 5 in
# both) with the most negative Δ, i.e. the largest drop in confusion concentration.
plot_confusion_partner_reduction_stacked(ego_report_df, ego_imu_report_df, top_n=5, support_min=5, rank_by="delta")
plt.show()

In [None]:
def _wrap_label_two_lines(text: str, max_len: int = 7) -> str:
    """Break `text` into two lines at its middle word boundary if longer than `max_len`."""
    if len(text) <= max_len:
        return text
    words = text.split(' ')
    if len(words) < 2:
        return text  # a single long word cannot be wrapped
    mid = len(words) // 2
    return ' '.join(words[:mid]) + '\n' + ' '.join(words[mid:])


def plot_worse_with_imu(
    report_video: pd.DataFrame,
    report_imu: pd.DataFrame,
    top_k: int = 3,
    support_min: int = 1,
    min_errors: int = 3,          # require ≥ this many errors in BOTH to avoid 100% single-error artifacts
    key_on: str = "class_id",
    title: str = None,
    zero_eps: float = 0.004,      # tiny sliver so 0% bars still visible
    selected_keystep_ids: list = None,  # list of specific keystep IDs to plot
    save_path: str = "./analysis/worse_with_imu_confusion_concentration.png",
):
    """
    Show the classes where adding IMU made confusion *more concentrated*:
      Δ = (IMU − Video) on 'top_confusion_fraction_of_errors' > 0.
    Plots the worst `top_k` classes with tight paired bars and Δ next to the IMU bar.

    Expects each report to have columns:
      [key_on, 'label', 'support', 'recall', 'top_confused_with_label', 'top_confusion_fraction_of_errors']

    Args:
        report_video: per-class report of the video-only model.
        report_imu: per-class report of the video+IMU model.
        top_k: number of classes (largest Δ first) to plot when
               `selected_keystep_ids` is None.
        support_min: minimum support required in both reports.
        min_errors: minimum error count (support * (1 - recall), rounded)
                    required in both reports; 0 disables the filter.
        key_on: class id column used for the join and for the id selection.
        title: optional plot title; no title is drawn when omitted.
        zero_eps: bar width drawn for 0% / non-finite fractions.
        selected_keystep_ids: Optional list of specific keystep IDs to plot.
                             If provided, only these classes will be plotted (ignoring top_k),
                             and only those that pass the filters and have Δ > 0.
                             If None, plots top_k classes by delta.
        save_path: where the figure is written (dpi=300); parent directories are
                   created as needed. Pass None to skip saving.

    Returns:
        (fig, ax, W) where W is the table of plotted classes, sorted by Δ descending.

    Raises:
        ValueError: if no class survives the filters, no class has Δ > 0, or none
                    of the selected ids is among the classes with Δ > 0.
    """

    # --- prepare & join ---
    cols = [key_on, "label", "support", "recall",
            "top_confused_with_label", "top_confusion_fraction_of_errors"]

    A = report_video[cols].copy().rename(columns={
        "support":"support_video",
        "recall":"recall_video",
        "top_confused_with_label":"partner_video",
        "top_confusion_fraction_of_errors":"frac_video",
    })
    B = report_imu[cols].copy().rename(columns={
        "support":"support_imu",
        "recall":"recall_imu",
        "top_confused_with_label":"partner_imu",
        "top_confusion_fraction_of_errors":"frac_imu",
    })

    M = pd.merge(A, B, on=[key_on, "label"], how="inner")

    # convert labels to title case for better display and replace '_' with ' '
    M["label"] = M["label"].str.replace('_', ' ').str.title()
    M["partner_video"] = M["partner_video"].str.replace('_', ' ').str.title()
    M["partner_imu"] = M["partner_imu"].str.replace('_', ' ').str.title()

    # Error counts (reconstructed from support and recall) for filtering
    M["errors_video"] = (M["support_video"] * (1 - M["recall_video"])).round().astype("Int64")
    M["errors_imu"]   = (M["support_imu"]   * (1 - M["recall_imu"])).round().astype("Int64")

    # Clean/clip fractions
    M["frac_video"] = pd.to_numeric(M["frac_video"], errors="coerce").clip(0, 1)
    M["frac_imu"]   = pd.to_numeric(M["frac_imu"],   errors="coerce").clip(0, 1)

    # Basic filters
    M = M[(M["support_video"] >= support_min) & (M["support_imu"] >= support_min)].copy()
    if min_errors > 0:
        M = M[(M["errors_video"] >= min_errors) & (M["errors_imu"] >= min_errors)].copy()
    if M.empty:
        raise ValueError("No classes after filters; relax min_errors/support_min.")

    # Δ > 0 means *worse with IMU* (more concentrated confusion)
    M["delta"] = (M["frac_imu"] - M["frac_video"]).astype(float)

    # Filter to only positive deltas
    M = M[M["delta"] > 0].copy()

    if M.empty:
        raise ValueError("No classes where IMU increased confusion concentration (Δ<=0 everywhere).")

    # Select classes based on selected_keystep_ids or top_k
    if selected_keystep_ids is not None:
        W = M[M[key_on].isin(selected_keystep_ids)].copy()
        if W.empty:
            raise ValueError(f"None of the selected keystep IDs {selected_keystep_ids} found in filtered data with Δ>0.")
        # Sort by delta descending for consistent ordering
        W = W.sort_values("delta", ascending=False).reset_index(drop=True)
    else:
        W = M.sort_values("delta", ascending=False).head(top_k).reset_index(drop=True)

    # --- plotting ---
    n = len(W)
    y = np.arange(n) * 0.5    # tighter spacing between class rows
    h = 0.08                   # tight pairing separation

    fv_true = W["frac_video"].to_numpy(float)
    fi_true = W["frac_imu"].to_numpy(float)

    # make zero bars visible as slivers
    fv = np.where((~np.isfinite(fv_true)) | (fv_true == 0.0), zero_eps, fv_true)
    fi = np.where((~np.isfinite(fi_true)) | (fi_true == 0.0), zero_eps, fi_true)

    fig, ax = plt.subplots(figsize=(9, max(3.5, 0.5*n)))
    sns.reset_orig()  # restore matplotlib's original rcParams in case seaborn styling was applied earlier

    # Two colours from seaborn's "tab10" palette
    colors = sns.color_palette("tab10", 2)

    # Video (top bar) and IMU (bottom bar)
    ax.barh(y - h, fv, height=0.16, label="Video-only", color=colors[0])
    ax.barh(y + h, fi, height=0.16, label="Video + IMU", color=colors[1])

    # annotate % and Δ to the right of bars, and predicted class ON the bars
    for i in range(n):
        # Video bar: percentage and predicted class
        ax.text(fv[i] + 0.01, y[i] - h, f"{fv_true[i]*100:.0f}%", va="center")
        ax.text(fv[i]/2, y[i] - h, f"{W.iloc[i]['partner_video']}",
                va="center", ha="center", fontsize=11, fontweight="bold", color="white")

        # IMU bar: percentage, predicted class, and delta
        ax.text(fi[i] + 0.01, y[i] + h, f"{fi_true[i]*100:.0f}%", va="center")
        ax.text(fi[i]/2, y[i] + h, f"{W.iloc[i]['partner_imu']}",
                va="center", ha="center", fontsize=11, fontweight="bold", color="white")
        ax.text(min(1.0, fi[i] + 0.06), y[i] + h, f"Δ={(fi_true[i]-fv_true[i])*100:+.0f}%",
                va="center", ha="left", fontsize=11, color='red', fontweight="bold")

    # y labels: show only GT class, wrapped onto two lines when long
    y_labels = W["label"].apply(_wrap_label_two_lines)

    ax.set_yticks(y)
    ax.set_yticklabels(y_labels, fontsize=11)
    ax.invert_yaxis()
    ax.set_ylabel("Ground Truth Class", fontsize=14)

    ax.set_xlabel("Confusion Concentration (Fraction of Errors)", fontsize=14)
    ax.set_xlim(0, 1.06)
    ax.legend(loc="upper right", frameon=True)
    ax.grid(axis="x", linestyle="--", alpha=0.35)
    if title:
        ax.set_title(title)

    plt.tight_layout()

    # save figure (creating the target directory first so savefig cannot fail on it)
    if save_path:
        out_file = Path(save_path)
        out_file.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_file, dpi=300)
    return fig, ax, W

# Example usage:
# Plot top 2 (default behavior)
# plot_worse_with_imu(ego_report_df, ego_imu_report_df, top_k=2, support_min=3)

# Plot specific keystep IDs
# plot_worse_with_imu(ego_report_df, ego_imu_report_df, 
#                     selected_keystep_ids=[5, 12, 18], 
#                     support_min=3)

# Plot keysteps 15 and 25 only; because selected_keystep_ids is given, top_k is ignored.
plot_worse_with_imu(ego_report_df, ego_imu_report_df, top_k=2, support_min=3, selected_keystep_ids=[15,25])

In [None]:
import pandas as pd
import numpy as np

def get_top_worse_classes_with_examples(
    report_video: pd.DataFrame,
    report_imu: pd.DataFrame,
    preds_video: pd.DataFrame,
    preds_imu: pd.DataFrame,
    top_k: int = 3,
    support_min: int = 1,
    min_errors: int = 3,             # require ≥ this many errors in BOTH models
    key_on: str = "class_id",
    # prediction CSV column names (override if yours differ)
    true_id_col: str = "keystep_id",
    pred_id_col: str = "pred_keystep_id",
    start_frame_col: str = "start_frame",
    end_frame_col: str = "end_frame",
    trial_col: str = "trial_id",
    subject_col: str = "subject_id",
    max_examples_per_class: int = 15
):
    """
    Find the classes where IMU increased confusion concentration and collect
    misclassified IMU examples for them.

    Args:
        report_video / report_imu: per-class reports with columns
            [key_on, 'label', 'support', 'recall', 'top_confused_with_label',
             'top_confusion_fraction_of_errors'].
        preds_video: accepted for interface symmetry; not used by this function.
        preds_imu: per-sample predictions of the IMU model.
        top_k: number of classes (largest Δ first) to return.
        support_min: minimum support required in both reports.
        min_errors: minimum error count (support * (1 - recall), rounded)
                    required in both reports; 0 disables the filter.
        key_on: class id column of the reports.
        true_id_col, pred_id_col, start_frame_col, end_frame_col, trial_col,
        subject_col: column names in `preds_imu`; missing ones are skipped.
        max_examples_per_class: cap on example rows per class.

    Returns:
      worst_classes: DataFrame with the top_k classes where IMU increased confusion concentration
                     columns: [class_id,label,partner_video,frac_video,partner_imu,frac_imu,delta,errors_video,errors_imu]
      examples:      DataFrame with rows from preds_imu for those classes, restricted to
                     instances where IMU predicted the (IMU) top confusion partner. If the
                     partner label cannot be mapped to an id, or no such row exists, all
                     mispredicted rows of the class are used instead.
                     columns: [delta,class_id,label,partner_imu,keystep_id,start_frame,end_frame,trial_id,subject_id]

    Raises:
        ValueError: if no class survives the filters or no class has Δ > 0.
    """

    # ---- build comparison table ----
    cols = [key_on, "label", "support", "recall",
            "top_confused_with_label", "top_confusion_fraction_of_errors"]

    A = report_video[cols].copy().rename(columns={
        "support":"support_video",
        "recall":"recall_video",
        "top_confused_with_label":"partner_video",
        "top_confusion_fraction_of_errors":"frac_video",
    })
    B = report_imu[cols].copy().rename(columns={
        "support":"support_imu",
        "recall":"recall_imu",
        "top_confused_with_label":"partner_imu",
        "top_confusion_fraction_of_errors":"frac_imu",
    })

    M = pd.merge(A, B, on=[key_on, "label"], how="inner")
    # numeric clean-up
    M["frac_video"] = pd.to_numeric(M["frac_video"], errors="coerce").clip(0, 1)
    M["frac_imu"]   = pd.to_numeric(M["frac_imu"],   errors="coerce").clip(0, 1)
    # error counts for filtering
    M["errors_video"] = (M["support_video"] * (1 - M["recall_video"])).round().astype("Int64")
    M["errors_imu"]   = (M["support_imu"]   * (1 - M["recall_imu"])).round().astype("Int64")

    # basic filters
    M = M[(M["support_video"] >= support_min) & (M["support_imu"] >= support_min)].copy()
    if min_errors > 0:
        M = M[(M["errors_video"] >= min_errors) & (M["errors_imu"] >= min_errors)].copy()

    if M.empty:
        raise ValueError("No shared classes after filters; relax support_min/min_errors.")

    # Δ > 0  => IMU increases confusion concentration (worse)
    M["delta"] = (M["frac_imu"] - M["frac_video"]).astype(float)
    worst_classes = M[M["delta"] > 0].sort_values("delta", ascending=False).head(top_k).copy()

    if worst_classes.empty:
        raise ValueError("No classes where IMU increased confusion concentration (Δ <= 0).")

    # ---- map partner labels to IDs so we can pull examples from preds_imu ----
    # Video report first, then IMU on top, so the IMU report wins for shared labels.
    label_to_id = dict(zip(report_video["label"], report_video[key_on]))
    label_to_id.update(zip(report_imu["label"], report_imu[key_on]))

    has_pred_col = pred_id_col in preds_imu.columns
    keep_cols = [true_id_col, start_frame_col, end_frame_col, trial_col, subject_col]

    rows = []
    for _, r in worst_classes.iterrows():
        cls_id = int(r[key_on])
        cls_lab = r["label"]
        partner_lab = r["partner_imu"]
        partner_id = label_to_id.get(partner_lab, None)

        of_class = preds_imu[preds_imu[true_id_col] == cls_id]
        if has_pred_col:
            # Start from the mispredictions only, so correctly classified rows can
            # never end up in the examples — not even when the partner label has no
            # known id (placeholder label, or a class missing from both reports).
            wrong = of_class[of_class[pred_id_col] != cls_id]
            picked = wrong[wrong[pred_id_col] == partner_id] if partner_id is not None else wrong.iloc[0:0]
            if picked.empty:
                picked = wrong
        else:
            # No prediction column to filter on: keep all rows of the class.
            picked = of_class

        # select just the requested fields; degrade gracefully if some are missing
        present = [c for c in keep_cols if c in picked.columns]
        picked = picked[present].copy()

        # attach context for ease of triage (final order: delta, key_on, label, partner_imu, ...)
        picked.insert(0, "partner_imu", partner_lab)
        picked.insert(0, "label", cls_lab)
        picked.insert(0, key_on, cls_id)
        picked.insert(0, "delta", r["delta"])

        # cap the number of examples per class
        rows.append(picked.head(max_examples_per_class))

    examples = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()

    # normalise the prediction column names in the output
    rename_map = {
        true_id_col: "keystep_id",
        start_frame_col: "start_frame",
        end_frame_col: "end_frame",
        trial_col: "trial_id",
        subject_col: "subject_id",
    }
    examples = examples.rename(columns={k: v for k, v in rename_map.items() if k in examples.columns})

    # reorder for readability, keeping only the columns that are present
    desired = ["delta", key_on, "label", "partner_imu", "keystep_id",
               "start_frame", "end_frame", "trial_id", "subject_id"]
    examples = examples[[c for c in desired if c in examples.columns]]

    return worst_classes[[
        key_on, "label", "partner_video", "frac_video",
        "partner_imu", "frac_imu", "delta", "errors_video", "errors_imu"
    ]], examples

# Compare the two reports and pull example segments from the IMU model's predictions.
# With top_k=30, support_min=1 and min_errors=1 the filters are loose, so this
# returns up to 30 classes with Δ > 0.
worst_classes, examples = get_top_worse_classes_with_examples(
    ego_report_df,
    ego_imu_report_df,
    preds_video=pd.read_csv("./results/model_id_job_1173414_task_classification_on_20250715-145545/preds.csv"),
preds_imu=pd.read_csv("./results/model_id_job_1173491_task_classification_on_20250715-163830/preds.csv"),
    top_k=30,
    support_min=1,
    min_errors=1
)

print("Worst classes where IMU increased confusion concentration:")
display(worst_classes)
print("\nExamples from IMU model where it predicted the top confusion partner:")
display(examples)

# save examples to CSV
# NOTE(review): assumes ./analysis/ already exists (the metrics cell creates it) — confirm.
examples.to_csv("./analysis/worse_with_imu_examples.csv", index=False)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- helpers (include once, outside the loop) ---
def _ensure_frame(df: pd.DataFrame, frame_col: str = "frame") -> pd.DataFrame:
    if frame_col not in df.columns:
        df = df.reset_index(drop=True)
        df[frame_col] = np.arange(len(df), dtype=int)
    return df

def _acc_energy(df: pd.DataFrame, x="x_01", y="y_01", z="z_01"):
    if all(c in df.columns for c in (x, y, z)):
        return np.sqrt(df[x]**2 + df[y]**2 + df[z]**2)
    # fallback: try any accel-looking numeric columns
    cand = [c for c in df.columns if c.lower() not in {"timestamp", "frame"}]
    cand = [c for c in cand if df[c].dtype.kind in "fc"]
    if len(cand) >= 3:
        arr = df[cand[:3]].to_numpy(dtype=float)
        return np.sqrt((arr**2).sum(axis=1))
    return None

def _zscore_rolling(x: pd.Series, win: int = 5) -> pd.Series:
    x = (x - x.mean()) / (x.std() + 1e-8)
    return x.rolling(win, center=True, min_periods=1).median()


def analyze_imu_energy(imu_data: pd.DataFrame,
                       imu_segment: pd.DataFrame,
                       row: dict, loc: str = None):
    """
    Analyze IMU accelerometer energy in a labeled segment vs full file.

    Prints motion diagnostics for the segment, plots the z-scored, smoothed
    energy in a +/-30-frame context around the labeled window, saves the figure
    to ``analysis/<subject>_<trial>_ks<keystep>_<start>-<end>.png`` and shows it.

    Args:
      imu_data:    Full IMU DataFrame. Energy is computed by `_acc_energy`
                   (columns 'x_01','y_01','z_01' when present).
      imu_segment: Segment DataFrame (subset of imu_data) for the labeled window.
      row:         dict-like metadata (a dict or a DataFrame row). 'label',
                   'start_frame' and 'end_frame' are required; 'keystep_id',
                   'subject_id' and 'trial_id' are optional and only used for
                   the annotation and the output filename.
      loc:         Position of the diagnostics textbox: "upper left", "middle"
                   or "upper right". Any other value (including the default
                   None) draws no textbox.

    Returns:
      None. If no accelerometer energy can be computed, a warning is printed
      and nothing is plotted.
    """
    # Make sure both full IMU file and this segment have 'frame'
    imu_data = _ensure_frame(imu_data)
    imu_segment = _ensure_frame(imu_segment)

    # Accelerometer magnitude (energy proxy); the full-file value is only needed
    # to confirm that the context plot below can be computed at all.
    E_full = _acc_energy(imu_data)
    E_seg = _acc_energy(imu_segment)
    if E_full is None or E_seg is None:
        print("⚠️ Could not compute accel energy (missing x_01/y_01/z_01).")
        return

    # ---- Diagnostics on the labeled segment (z-scored and smoothed within the segment) ----
    Es = _zscore_rolling(pd.Series(E_seg, index=imu_segment.index))
    E = Es.to_numpy()
    frames_seg = imu_segment["frame"].to_numpy()

    if len(E) > 0:
        med = np.median(E)
        iqr = np.percentile(E, 75) - np.percentile(E, 25)
        # Robust "high energy" threshold; fall back to a unit spread for (near-)flat segments
        thr = med + 0.5 * (iqr if iqr > 1e-8 else 1.0)
        frac_high = float((E > thr).mean())
        pk = int(np.argmax(E))
        # Relative position of the energy peak inside the window (0=start, 1=end)
        peak_rel = (frames_seg[pk] - frames_seg[0]) / max(1, (frames_seg[-1] - frames_seg[0]))
    else:
        # Empty segment: there is nothing to measure
        frac_high = np.nan
        peak_rel = np.nan

    # NaN compares False against every threshold, so it must be handled explicitly;
    # otherwise an empty segment would be reported as HIGH motion with a LATE peak.
    if np.isnan(frac_high):
        motion_level = "N/A"
    else:
        motion_level = "LOW" if frac_high < 0.25 else ("MED" if frac_high < 0.6 else "HIGH")
    if np.isnan(peak_rel):
        phase_hint = "N/A"
    else:
        phase_hint = "EARLY" if peak_rel < 0.33 else ("MID" if peak_rel < 0.66 else "LATE")

    print(f"IMU energy → frac_high={frac_high:.2f}, peak_rel={peak_rel:.2f} (0=start,1=end)  "
        f"→ Motion:{motion_level}, Peak:{phase_hint}")

    # ---- Single context plot with everything annotated ----
    N = 30  # context frames on each side
    s0, s1 = int(row.get('start_frame')), int(row.get('end_frame'))
    f0 = max(0, s0 - N)
    f1 = s1 + N

    # Energy is re-standardized over the context slice, not over the whole file
    ctx = imu_data[(imu_data["frame"] >= f0) & (imu_data["frame"] <= f1)].copy()
    Ec = _zscore_rolling(pd.Series(_acc_energy(ctx), index=ctx.index))

    # NOTE(review): reset_orig() restores matplotlib's original rcParams, which
    # undoes the set_style() call right before it; kept as-is to preserve the look.
    sns.set_style("whitegrid")
    sns.reset_orig()

    fig, ax = plt.subplots(figsize=(9, 3))
    colors = sns.color_palette("tab10")

    # 1) context energy
    ax.plot(ctx["frame"], Ec, label="Accel energy (z, smoothed)", color=colors[0], linewidth=1.5)

    # 2) shaded labeled window with start/end markers
    ax.axvspan(s0, s1, color=colors[2], alpha=0.15, label="Labeled window")
    ax.axvline(s0, color='k', lw=0.8, alpha=0.35)
    ax.axvline(s1, color='k', lw=0.8, alpha=0.35)

    # 3) compact on-plot textbox with diagnostics
    keystep_name = row['label'].replace('_', ' ').title()
    txt = (f"KS {row.get('keystep_id')} • {row.get('subject_id','?')}-trial:{row.get('trial_id','?')}\n"
        f"Frames {s0}-{s1} | Motion:{motion_level} | Peak:{phase_hint}\n"
        f"frac_high={frac_high:.2f} | peak_rel={peak_rel:.2f}")

    # loc -> (x position in axes coordinates, horizontal alignment)
    text_anchors = {
        "upper left": (0.02, "left"),
        "middle": (0.33, "left"),
        "upper right": (0.95, "right"),
    }
    if loc in text_anchors:
        x_pos, h_align = text_anchors[loc]
        ax.text(x_pos, 0.95, txt, transform=ax.transAxes, va="top", ha=h_align,
            bbox=dict(boxstyle="round,pad=0.35", fc="white", alpha=0.35, ec="0.2"), fontsize=10)

    ax.set_title(f"Smartwatch IMU Energy • GT: {keystep_name}")
    ax.set_xlabel("Frame", fontsize=14)
    ax.set_ylabel("Z-scored energy", fontsize=14)
    ax.legend(loc="lower left", fontsize=11, ncol=2, frameon=True)
    ax.grid(axis="y", linestyle="--", alpha=0.25)
    fig.tight_layout()

    # ---- save figure ----
    os.makedirs("analysis", exist_ok=True)
    subj = str(row.get('subject_id','na')).replace('/', '-')
    tri  = str(row.get('trial_id','na')).replace('/', '-')
    ksid = str(row.get('keystep_id','na'))
    fname = f"{subj}_{tri}_ks{ksid}_{s0}-{s1}.png"
    out_path = os.path.join("analysis", fname)
    # Save before showing so the written file does not depend on what the
    # active backend does to the figure once it has been displayed.
    fig.savefig(out_path, dpi=150)
    plt.show()
    plt.close(fig)
    print(f"Saved figure → {out_path}")

In [None]:
# check examples to diagnose
# For each confusion example: print its metadata, build the smartwatch IMU path for the
# keystep, load the file, cut out the labeled frame window and run analyze_imu_energy.
# Only ONE keystep branch is active at a time; the commented-out branches below are
# alternative keystep/scenario path templates that are toggled by hand.
for _, row in examples.iterrows():
    print(f"Class ID: {row['class_id']}, Label: {row['label']}, Partner: {row['partner_imu']}, "
          f"Keystep ID: {row['keystep_id']}, Trial: {row.get('trial_id', 'N/A')}, "
          f"Subject: {row.get('subject_id', 'N/A')}, Frames: {row.get('start_frame', 'N/A')}-{row.get('end_frame', 'N/A')}")
    subject_id = row.get('subject_id', 'N/A')
    trial_id = row.get('trial_id', 'N/A')

    keystep_id = row['keystep_id']

    print(f"Loading IMU data for keystep ID {keystep_id} from subject {subject_id}, trial {trial_id}...")

    # if keystep_id == 63:
    #     trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/stroke/{trial_id}/"
    #     imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_stroke_{trial_id}_synchronized_smartwatch_01.csv"

    # if keystep_id == 25:
    #     trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/chest_pain/{trial_id}/"
    #     imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_chestpain_{trial_id}_sync_smartwatch.csv"

    # Active branch: keystep 26, looked up under the chest_pain scenario folder.
    if keystep_id == 26:
        trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/chest_pain/{trial_id}/"
        imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_chestpain_{trial_id}_sync_smartwatch.csv"

    # if keystep_id == 0: # approach patient
    #     trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/stroke/{trial_id}/"
    #     imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_stroke_{trial_id}_synchronized_smartwatch_02.csv"

    # if keystep_id == 0: # approach patient
    #     trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/stroke/{trial_id}/"
    #     imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_stroke_{trial_id}_sync_smartwatch.csv"


    # if keystep_id == 15: # no action
    #     trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/cardiac_arrest/{trial_id}/"
    #     imu_path = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_cardiacarrest_{trial_id}_sync_smartwatch.csv"


    # This `else` pairs with the active `if keystep_id == 26` above: every other
    # keystep is skipped, so `imu_path` is always defined past this point.
    else:
        print(f"Skipping keystep ID {keystep_id} for IMU analysis.")
        continue

    # load IMU data
    # NOTE(review): the try block also wraps analyze_imu_energy, so plotting errors
    # are reported with the same "Failed to load IMU data" message.
    try:
        imu_data = pd.read_csv(imu_path)
        print(f"Loaded IMU data from: {imu_path}")
        # get the relevant frames
        start_frame = row.get('start_frame', None)
        end_frame = row.get('end_frame', None)
        if start_frame is not None and end_frame is not None:
            # USE INDEX IF frame column does not exist
            if 'frame' not in imu_data.columns:
                imu_data = imu_data.reset_index().rename(columns={'index': 'frame'})
            # Inclusive window [start_frame, end_frame]
            imu_segment = imu_data[(imu_data['frame'] >= start_frame) & (imu_data['frame'] <= end_frame)]
            print(f"IMU data segment for frames {start_frame}-{end_frame}:")
            # print(imu_segment)

            # # basic statistics
            # print("IMU segment statistics:")

            # # plot IMU data segment
            # plt.figure(figsize=(10, 6))
            # for col in imu_segment.columns:
            #     if col != 'frame' and col != 'timestamp':
            #         plt.plot(imu_segment['frame'], imu_segment[col], label=col)
            # plt.title(f"IMU Data Segment for Keystep ID {row['keystep_id']}")
            # plt.xlabel("Frame")
            # plt.ylabel("Sensor Values")
            # plt.legend()
            # plt.show()

            # No `loc` is passed, so the figure is drawn without the diagnostics textbox.
            analyze_imu_energy(imu_data, imu_segment, row)
        else:
            print("Start or end frame missing; cannot extract segment.")

    except Exception as e:
        print(f"Failed to load IMU data from: {imu_path}. Error: {e}")
    print("-----")

    # Stop after the first example that was not skipped (one figure per run).
    break



In [None]:
# keystep 25 imu data segment
# Manual spot checks: each section below loads one smartwatch CSV, converts hand-entered
# start/end times to frame indices (x30, presumably 30 frames per second - confirm), cuts the
# inclusive frame window and plots it with analyze_imu_energy.

# --- Section 1: keystep 25 (place_v3_lead), subject cars_2, stroke scenario, trial 0 ---
subject_id = "cars_2"
scenario_id = "stroke"
trial_id = 0
imu_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/{scenario_id}/{trial_id}/smartwatch_data/{subject_id.replace('_','')}_{scenario_id}_{trial_id}_synchronized_smartwatch_01.csv"

ks_25_start = int(375.3253*30)  # convert to frames
ks_25_end   = int(377.77643*30)

imu_data_25 = pd.read_csv(imu_path)
imu_data_25 = _ensure_frame(imu_data_25)
imu_segment_25 = imu_data_25[(imu_data_25['frame'] >= ks_25_start) & (imu_data_25['frame'] <= ks_25_end)]

analyze_imu_energy(imu_data_25, imu_segment_25, {
    'keystep_id': 25,
    'label': 'place_v3_lead',
    'start_frame': ks_25_start,
    'end_frame': ks_25_end,
    'subject_id': subject_id,
    'trial_id': trial_id,
}, loc="upper left")




# --- Section 2: keystep 26 (place_v4_lead), subject ms1, chest_pain scenario, trial 6 ---
# The path below is written out literally (the f-string has no placeholders).
subject_id = "ms1"
scenario_id = "chest_pain"
trial_id = 6
imu_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/ms1/chest_pain/6/smartwatch_data/ms1_chestpain_6_sync_smartwatch.csv"

# Path variants kept for reference:
# "file_path": "/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/cars_1/chest_pain/0/smartwatch_data/cars1_chestpain_0_sync_smartwatch.csv"
    # "/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/cars_1/chestpain/0/smartwatch_data/cars1_chestpain_0_sync_smartwatch.csv"

# /standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/cars1/chestpain/0/smartwatch_data/cars1_chestpain_0_sync_smartwatch.csv
ks_26_start =  int(305.531*30)  
ks_26_end   =  int(313.20165*30)

imu_data_26 = pd.read_csv(imu_path)
imu_data_26 = _ensure_frame(imu_data_26)
imu_segment_26 = imu_data_26[(imu_data_26['frame'] >= ks_26_start) & (imu_data_26['frame'] <= ks_26_end)]   
analyze_imu_energy(imu_data_26, imu_segment_26, {
    'keystep_id': 26,
    'label': 'place_v4_lead',
    'start_frame': ks_26_start,
    'end_frame': ks_26_end,
    'subject_id': subject_id,
    'trial_id': trial_id,
}, loc="upper left")




# --- Section 3: keystep 25 (place_v3_lead) in the SAME ms1 / chest_pain / trial 6 file ---
# `imu_path`, `subject_id` and `trial_id` are unchanged from section 2, so the same CSV is
# read again; `ks_25_*` and `imu_*_25` from section 1 are overwritten here.
# Source annotation for this window:
# {
#     "keystep_id": "2_2uGfna2P",
#     "start_t": 314.784,
#     "end_t": 322.11136,
#     "label": "place_v3_lead",
#     "class_id": 25
# },
ks_25_start =  int( 314.784*30)  
ks_25_end   =  int(322.11136*30)

imu_data_25 = pd.read_csv(imu_path)
imu_data_25 = _ensure_frame(imu_data_25)
imu_segment_25 = imu_data_25[(imu_data_25['frame'] >= ks_25_start) & (imu_data_25['frame'] <= ks_25_end)]
analyze_imu_energy(imu_data_25, imu_segment_25, {
    'keystep_id': 25,
    'label': 'place_v3_lead',
    'start_frame': ks_25_start,
    'end_frame': ks_25_end,
    'subject_id': subject_id,
    'trial_id': trial_id,
}, loc="upper left")







# --- Disabled: keystep 15 (no_action), subject cars_1, chest_pain scenario, trial 0 ---
# ks_15_start = int(26.251*30)
# ks_15_end   = int(28.68495*30)
# subject_id = "cars_1"
# scenario_id = "chest_pain"
# ks_15_label = "no_action"
# trial_id = 0
# imu_path = "/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/cars_1/chest_pain/0/smartwatch_data/cars1_chestpain_0_sync_smartwatch.csv"
# imu_data_15 = pd.read_csv(imu_path)
# imu_data_15 = _ensure_frame(imu_data_15)
# imu_segment_15 = imu_data_15[(imu_data_15['frame'] >= ks_15_start) & (imu_data_15['frame'] <= ks_15_end)]
# analyze_imu_energy(imu_data_15, imu_segment_15, {
#     'keystep_id': 15,
#     'label': ks_15_label,
#     'start_frame': ks_15_start,
#     'end_frame': ks_15_end,
#     'subject_id': subject_id,
#     'trial_id': trial_id,
# }, loc="upper right")


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ----------------- helpers (define once) -----------------
def _ensure_frame(df: pd.DataFrame, frame_col: str = "frame") -> pd.DataFrame:
    """Guarantee a `frame_col` column, numbering rows 0..len(df)-1 if it is absent.

    Frames that already carry the column are returned untouched; otherwise a
    re-indexed frame with the new column is returned and the input is left as is.
    """
    needs_frames = frame_col not in df.columns
    if needs_frames:
        df = df.reset_index(drop=True)
        row_numbers = np.arange(len(df), dtype=int)
        df[frame_col] = row_numbers
    return df

def _acc_energy(df: pd.DataFrame, x="x_01", y="y_01", z="z_01") -> pd.Series:
    """Return sqrt(x^2 + y^2 + z^2) for the three named columns of `df`.

    When any of the named columns is missing, the first three float/complex-dtype
    columns (excluding "timestamp"/"frame", case-insensitive) are used instead and
    the result is a NumPy array rather than a Series. Returns None when fewer than
    three candidate columns exist.
    """
    wanted = (x, y, z)
    if all(name in df.columns for name in wanted):
        squares = [df[name] ** 2 for name in wanted]
        return np.sqrt(squares[0] + squares[1] + squares[2])

    excluded = {"timestamp", "frame"}
    candidates = [
        col for col in df.columns
        if col.lower() not in excluded and df[col].dtype.kind in "fc"
    ]
    if len(candidates) < 3:
        return None
    values = df[candidates[:3]].to_numpy(dtype=float)
    return np.sqrt((values ** 2).sum(axis=1))

def _zscore_rolling(x: pd.Series, win: int = 5) -> pd.Series:
    """Standardize `x` (global mean/std) and apply a centered rolling median of width `win`.

    The small epsilon on the standard deviation avoids division by zero for
    constant input; edge positions use whatever samples fall inside the window.
    """
    z = (x - x.mean()) / (x.std() + 1e-8)
    smoothed = z.rolling(win, center=True, min_periods=1).median()
    return smoothed

def _motion_bin(frac_high: float) -> str:
    """Bucket the fraction of high-energy frames into "LOW", "MED" or "HIGH".

    Returns "N/A" for None/NaN (the value used for an empty segment), mirroring
    `_phase_bin`. Without this check NaN fails both comparisons and would be
    reported as "HIGH".
    """
    if frac_high is None or np.isnan(frac_high):
        return "N/A"
    if frac_high < 0.25:
        return "LOW"
    return "MED" if frac_high < 0.6 else "HIGH"

def _phase_bin(peak_rel: float) -> str:
    """Bucket the relative peak position (0=start, 1=end) into "EARLY", "MID" or "LATE".

    None or NaN yields "N/A".
    """
    undefined = peak_rel is None or np.isnan(peak_rel)
    if undefined:
        return "N/A"
    if peak_rel < 0.33:
        return "EARLY"
    if peak_rel < 0.66:
        return "MID"
    return "LATE"

# ----------------- main plotting utility -----------------
def plot_trial_confused_segments(
    imu_path: str,
    subject_id: str,
    trial_id: str,
    keystep_segments: list[dict],
    context_pad: int = 30,         # frames before/after the window
    dpi: int = 150,
    out_dir: str = "analysis",
    out_name: str = None,          # None -> "<subject_id>_<trial_id>_segments.png"
):
    """
    Make a single figure with one stacked subplot per segment, each showing:
      - context energy (z-scored, smoothed),
      - shaded labeled window,
      - context threshold (median + 0.5*IQR over context),
      - segment threshold (median + 0.5*IQR over segment),
      - start/end vertical lines,
      - on-plot diagnostics textbox (motion level, peak phase, frac_high, peak_rel).

    Args:
      imu_path:         CSV file with the IMU recording of the trial.
      subject_id:       Used for the figure title and the default filename.
      trial_id:         Used for the figure title and the default filename.
      keystep_segments: One dict per segment with integer-convertible
                        "keystep_id", "start_frame" and "end_frame".
      context_pad:      Frames of context plotted before and after each window.
      dpi:              Resolution of the saved figure.
      out_dir:          Output folder (created if missing).
      out_name:         Output filename. The default None resolves to
                        "<subject_id>_<trial_id>_segments.png" from THIS call's
                        arguments. (An f-string default would be evaluated once,
                        at definition time, from whatever globals exist then.)

    Returns:
      (fig, axes) for optional further tweaking; the figure is left open.

    Raises:
      ValueError: if keystep_segments is empty or no accelerometer energy can be
                  computed from the file.
    """
    # Validate the cheap input first so nothing is read or created for an empty request
    n = len(keystep_segments)
    if n == 0:
        raise ValueError("keystep_segments is empty.")

    # Load IMU
    imu_data = pd.read_csv(imu_path)
    print(f"Loaded IMU data from: {imu_path}")
    imu_data = _ensure_frame(imu_data)

    # Fail early if the file has no usable accelerometer columns at all
    if _acc_energy(imu_data) is None:
        raise ValueError("Could not compute accel energy (missing x_01/y_01/z_01 and no numeric fallback).")

    fig, axes = plt.subplots(
        nrows=n, ncols=1,
        figsize=(11, max(3.5, 2.7 * n)),
        sharex=False, sharey=False
    )
    if n == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single row

    for ax, seg in zip(axes, keystep_segments):
        ks = int(seg["keystep_id"])
        s0 = int(seg["start_frame"])
        s1 = int(seg["end_frame"])

        # Context slice (energy is standardized within this slice, not the whole file)
        f0 = max(0, s0 - context_pad)
        f1 = s1 + context_pad
        ctx = imu_data[(imu_data["frame"] >= f0) & (imu_data["frame"] <= f1)].copy()
        Ec = _zscore_rolling(pd.Series(_acc_energy(ctx), index=ctx.index))

        # Segment slice (within context for consistency), standardized on its own
        seg_mask = (ctx["frame"] >= s0) & (ctx["frame"] <= s1)
        seg_ctx = ctx[seg_mask].copy()
        Es = _zscore_rolling(pd.Series(_acc_energy(seg_ctx), index=seg_ctx.index))

        # Context threshold (robust; unit spread fallback for near-flat signals)
        ctx_med = np.nanmedian(Ec)
        ctx_iqr = np.nanpercentile(Ec, 75) - np.nanpercentile(Ec, 25)
        ctx_thr = ctx_med + 0.5 * (ctx_iqr if ctx_iqr > 1e-8 else 1.0)

        # Segment threshold and diagnostics
        E = Es.to_numpy(dtype=float)
        if len(E) > 0:
            seg_med = np.nanmedian(E)
            seg_iqr = np.nanpercentile(E, 75) - np.nanpercentile(E, 25)
            seg_thr = seg_med + 0.5 * (seg_iqr if seg_iqr > 1e-8 else 1.0)
            frac_high = float((E > seg_thr).mean())
            # Peak is located on the RAW magnitude within the segment (more intuitive)
            E_raw = _acc_energy(seg_ctx)
            pk = int(np.nanargmax(E_raw))
            frames_seg = seg_ctx["frame"].to_numpy()
            peak_frame = frames_seg[pk]
            peak_rel = (peak_frame - frames_seg[0]) / max(1, (frames_seg[-1] - frames_seg[0]))
        else:
            # No IMU rows inside the labeled window: nothing to measure
            seg_thr = np.nan
            frac_high = np.nan
            peak_rel = np.nan

        motion_level = _motion_bin(frac_high)
        phase_hint = _phase_bin(peak_rel)

        # --------- plot on one axes ---------
        ax.plot(ctx["frame"], Ec, label="Accel energy (z, smoothed)")
        ax.axvspan(s0, s1, color="orange", alpha=0.15, label="Labeled window")
        ax.axhline(ctx_thr, ls="--", alpha=0.35, label="Context thr")
        if np.isfinite(seg_thr):
            ax.axhline(seg_thr, ls=":", alpha=0.6, label="Segment thr")
        ax.axvline(s0, color='k', lw=0.8, alpha=0.35)
        ax.axvline(s1, color='k', lw=0.8, alpha=0.35)

        # Diagnostics textbox
        txt = (f"KS {ks} • Frames {s0}-{s1} • Pad±{context_pad}\n"
               f"Motion:{motion_level} | Peak:{phase_hint}\n"
               f"frac_high={frac_high:.2f} | peak_rel={peak_rel:.2f}")
        ax.text(
            0.01, 0.98, txt, transform=ax.transAxes, va="top", ha="left",
            bbox=dict(boxstyle="round,pad=0.35", fc="white", alpha=0.85, ec="0.5"),
            fontsize=9
        )

        ax.set_ylabel("Z-scored energy")
        ax.grid(axis="y", linestyle="--", alpha=0.25)

    axes[-1].set_xlabel("Frame")

    # Put a single legend on the first axes only (cleaner)
    handles, labels = axes[0].get_legend_handles_labels()
    axes[0].legend(handles, labels, loc="upper right", fontsize=8, ncol=2)

    title = f"IMU Context Energy • {subject_id}/{trial_id} • {len(keystep_segments)} segments"
    fig.suptitle(title, y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    os.makedirs(out_dir, exist_ok=True)
    if out_name is None:
        out_name = f"{subject_id}_{trial_id}_segments.png"
    out_path = os.path.join(out_dir, out_name)
    fig.savefig(out_path, dpi=dpi)
    print(f"Saved figure → {out_path}")
    return fig, axes

# ----------------- example usage -----------------
# Subject ms1, chest_pain scenario, trial 7. These names are plain notebook globals,
# so they stay defined for whatever runs after this cell.
subject_id = "ms1"
trial_id   = "7"
trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/chest_pain/{trial_id}/"
imu_path   = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_chestpain_{trial_id}_sync_smartwatch.csv"

# One dict per segment to plot; keys are the strings read by plot_trial_confused_segments
# ("keystep_id", "start_frame", "end_frame"), values are frame indices into the IMU file.
keystep_segments = [
    {"keystep_id": 25, "start_frame": 4471, "end_frame": 4691},
    {"keystep_id": 26, "start_frame": 4100, "end_frame": 4218},
]

# run
# out_dir/out_name are left at their defaults.
plot_trial_confused_segments(
    imu_path=imu_path,
    subject_id=subject_id,
    trial_id=trial_id,
    keystep_segments=keystep_segments,
    context_pad=30,   # adjust if you want more/less context
    dpi=150
)


In [None]:
# Python 3.9
import os
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------- helpers (raw, no post-processing) ----------
def _ensure_frame(df: pd.DataFrame, frame_col: str = "frame") -> pd.DataFrame:
    """Return `df` with a `frame_col` column; rows are numbered 0..len(df)-1 when it is missing.

    An input that already has the column comes back as the same object. Otherwise
    the column is appended to an index-reset frame and the input stays untouched.
    """
    if frame_col in df.columns:
        return df
    renumbered = df.reset_index(drop=True)
    # Append as the last column, exactly like a plain column assignment would
    renumbered.insert(len(renumbered.columns), frame_col, np.arange(len(renumbered), dtype=int))
    return renumbered

def _acc_magnitude(df: pd.DataFrame, x="sw_value_X_Axis", y="sw_value_Y_Axis", z="sw_value_Z_Axis") -> np.ndarray:
    """Return the raw per-row magnitude sqrt(x^2 + y^2 + z^2) as a 1-D float array.

    Args:
      df: DataFrame holding the three accelerometer axis columns.
      x, y, z: Column names of the three axes.

    Raises:
      ValueError: if any requested column is absent. The message names the
                  columns that were actually requested and those that are missing.
    """
    axes = [x, y, z]
    missing = [c for c in axes if c not in df.columns]
    if missing:
        raise ValueError(
            f"Expected accelerometer columns: {', '.join(map(str, axes))}. "
            f"Missing: {', '.join(map(str, missing))}."
        )
    arr = df[axes].to_numpy(dtype=float)
    return np.sqrt((arr ** 2).sum(axis=1))

def _overlap_at_lag(a: np.ndarray, b: np.ndarray, lag: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return the equally long, aligned slices of `a` and `b` when `b` is shifted by `lag`.

    A non-negative lag skips the first `lag` samples of `b`; a negative lag skips
    the first `-lag` samples of `a`. Overlaps of one sample or fewer yield two
    empty arrays.
    """
    sig_a = np.asarray(a, dtype=float).ravel()
    sig_b = np.asarray(b, dtype=float).ravel()

    # Offset into each signal: only one of the two is non-zero
    start_a = max(0, -lag)
    start_b = max(0, lag)
    length = min(len(sig_a) - start_a, len(sig_b) - start_b)
    if length <= 1:
        return np.array([]), np.array([])
    return sig_a[start_a:start_a + length], sig_b[start_b:start_b + length]

def _max_norm_xcorr(a: np.ndarray, b: np.ndarray, max_lag: int) -> Tuple[float, int]:
    """Scan lags in [-max_lag, max_lag] and return (best correlation, lag at which it occurs).

    For each lag the overlapping slices are mean-centered and their normalized
    dot product is taken (1e-8 is added to the denominator). Ties keep the
    smallest lag. If no lag produces a finite score, (0.0, 0) is returned.
    """
    sig_a = np.asarray(a, dtype=float).ravel()
    sig_b = np.asarray(b, dtype=float).ravel()

    top_score = -np.inf
    top_lag = 0
    for shift in range(-max_lag, max_lag + 1):
        win_a, win_b = _overlap_at_lag(sig_a, sig_b, shift)
        if win_a.size > 1:
            dev_a = win_a - win_a.mean()
            dev_b = win_b - win_b.mean()
            scale = (np.linalg.norm(dev_a) * np.linalg.norm(dev_b)) + 1e-8
            score = float(np.dot(dev_a, dev_b)) / scale
            if score > top_score:
                top_score, top_lag = score, shift

    if np.isfinite(top_score):
        return top_score, top_lag
    # very short signals: no usable overlap at any lag
    return 0.0, 0

def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between `a` and `b` (no centering); 1e-8 guards a zero norm."""
    vec_a = np.asarray(a, dtype=float)
    vec_b = np.asarray(b, dtype=float)
    dot = float(np.dot(vec_a, vec_b))
    scale = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b)) + 1e-8
    return dot / scale

def _dtw_distance_raw(a: np.ndarray, b: np.ndarray) -> float:
    """Dynamic-time-warping distance between two 1-D signals.

    Uses squared differences as the local cost and the three standard steps
    (insertion, deletion, match). The result is the square root of the total
    path cost divided by len(a) + len(b).
    """
    len_a, len_b = len(a), len(b)
    # acc[i, j] = cheapest cost of aligning the first i samples of a with the first j of b
    acc = np.full((len_a + 1, len_b + 1), np.inf, dtype=float)
    acc[0, 0] = 0.0
    for i in range(1, len_a + 1):
        sample_a = a[i - 1]
        for j in range(1, len_b + 1):
            sample_b = b[j - 1]
            local = (sample_a - sample_b) ** 2
            cheapest_prev = min(acc[i - 1, j], acc[i, j - 1], acc[i - 1, j - 1])
            acc[i, j] = local + cheapest_prev
    total = np.sqrt(acc[len_a, len_b])
    return float(total / (len_a + len_b))

def _cosine_centered(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of the mean-centered signals (i.e. Pearson correlation)."""
    dev_a = a - a.mean()
    dev_b = b - b.mean()
    dot = float(np.dot(dev_a, dev_b))
    scale = float(np.linalg.norm(dev_a) * np.linalg.norm(dev_b)) + 1e-8
    return dot / scale

def _cosine_delta(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of the first differences of `a` and `b`.

    Differencing shortens each signal by one sample, so NaN is returned when
    either input has fewer than two samples.
    """
    too_short = len(a) < 2 or len(b) < 2
    if too_short:
        return np.nan
    step_a = np.diff(a).astype(float)
    step_b = np.diff(b).astype(float)
    dot = float(np.dot(step_a, step_b))
    scale = float(np.linalg.norm(step_a) * np.linalg.norm(step_b)) + 1e-8
    return dot / scale


# ---------- main (raw plotting + raw similarities) ----------
def _best_lag_similarity(a: np.ndarray, b: np.ndarray) -> Tuple[float, float]:
    """Return (cosine over the best-lag overlap, DTW distance) for two raw magnitude signals."""
    # Search lags up to 15% of the longer signal, but at least one frame either way
    max_lag = max(1, int(0.15 * max(len(a), len(b))))
    _, lag = _max_norm_xcorr(a, b, max_lag=max_lag)
    ao, bo = _overlap_at_lag(a, b, lag)
    cos = _cosine_similarity(ao, bo) if len(ao) > 1 else np.nan
    return cos, _dtw_distance_raw(a, b)


def plot_trial_segment_similarity_raw(
    imu_path: str,
    subject_id: str,
    trial_id: str,
    keystep_segments: List[Dict[str, object]],
    x_pad: int = 30,      # frames before/after for context plotting only
    dpi: int = 150,
    out_dir: str = "analysis",
    out_name: str = None,
    print_tables: bool = False,    # set True if you still want the CSV/tables
):
    """
    Plot raw accelerometer magnitude for each segment in its context (±x_pad frames),
    vertically stacked. No smoothing/z-scoring/resampling.

    Only the FIRST subplot (reference) shows keystep identity (no similarity stats).
    Subsequent subplots show Cosine and DTW vs the FIRST segment.

    Args:
      imu_path:         CSV file with the IMU recording of the trial.
      subject_id:       Used for the title and default output filenames.
      trial_id:         Used for the title and default output filenames.
      keystep_segments: One dict per segment with "keystep_id", "start_frame",
                        "end_frame" and optionally "keystep_name"
                        (defaults to "KS<keystep_id>").
      x_pad:            Frames of context drawn around each window (plot only;
                        similarities use the labeled window alone).
      dpi:              Resolution of the saved figure.
      out_dir:          Output folder (created if missing).
      out_name:         Figure filename; None resolves to
                        "<subject_id>_<trial_id>_segments_similarity_raw.png".
      print_tables:     When True, also print and save the pairwise cosine/DTW
                        tables as CSV files in `out_dir`.

    Returns:
      (fig, axes). The figure has already been saved, shown and closed.

    Raises:
      ValueError: if no segments are given, a segment has end_frame < start_frame,
                  a segment selects no IMU rows, or the accelerometer columns are missing.
    """
    # Validate before touching the filesystem or reading the recording
    if len(keystep_segments) == 0:
        raise ValueError("No segments provided.")

    os.makedirs(out_dir, exist_ok=True)
    imu = pd.read_csv(imu_path)
    imu = _ensure_frame(imu)

    # gather raw magnitudes and contexts
    segments = []  # {'ks','ks_name','s0','s1','ctx_frames','mag_ctx','mag_seg'}
    for seg in keystep_segments:
        ks = int(seg["keystep_id"])
        s0 = int(seg["start_frame"])
        s1 = int(seg["end_frame"])
        if s1 < s0:
            raise ValueError(f"Bad frames for KS {ks}: end_frame < start_frame")
        f0 = max(0, s0 - x_pad)
        f1 = s1 + x_pad
        ctx = imu[(imu["frame"] >= f0) & (imu["frame"] <= f1)].copy()
        seg_df = imu[(imu["frame"] >= s0) & (imu["frame"] <= s1)].copy()
        if seg_df.empty:
            raise ValueError(f"Empty segment for KS {ks}: frames {s0}-{s1}")

        segments.append({
            "ks": ks,
            "ks_name": seg.get("keystep_name", f"KS{ks}"),
            "s0": s0, "s1": s1,
            "ctx_frames": ctx["frame"].to_numpy(),
            "mag_ctx": _acc_magnitude(ctx),        # raw
            "mag_seg": _acc_magnitude(seg_df),     # raw
        })

    n = len(segments)

    # reference (first segment)
    ref = segments[0]["mag_seg"]
    ref_name = segments[0]["ks_name"]

    # Optional: compute pairwise (only if printing/saving tables)
    if print_tables:
        labels = [f"{d['ks_name']}:{d['s0']}-{d['s1']}" for d in segments]
        COS = np.zeros((n, n), dtype=float)
        DTW = np.zeros((n, n), dtype=float)
        for i, seg_i in enumerate(segments):
            for j, seg_j in enumerate(segments):
                COS[i, j], DTW[i, j] = _best_lag_similarity(seg_i["mag_seg"], seg_j["mag_seg"])
        df_cos = pd.DataFrame(COS, index=labels, columns=labels)
        df_dtw = pd.DataFrame(DTW, index=labels, columns=labels)
        print("\nPairwise COSINE (raw, overlap at best lag):")
        print(df_cos.round(3))
        print("\nPairwise DTW distance (raw; lower=more similar):")
        print(df_dtw.round(3))
        # save CSVs
        base = f"{subject_id}_{trial_id}"
        df_cos.to_csv(os.path.join(out_dir, f"{base}_sim_cosine_raw.csv"))
        df_dtw.to_csv(os.path.join(out_dir, f"{base}_sim_dtw_raw.csv"))

    # figure
    fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(11, max(3.0, 2.5 * n)), sharex=False, sharey=False)
    if n == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single row

    for idx, (ax, seg) in enumerate(zip(axes, segments)):
        ks_name = seg["ks_name"]
        s0 = seg["s0"]
        s1 = seg["s1"]

        # plot raw magnitude in context
        ax.plot(seg["ctx_frames"], seg["mag_ctx"], label="Accel magnitude (raw)")
        ax.axvspan(s0, s1, color="orange", alpha=0.15, label="Labeled window")
        ax.axvline(s0, color='k', lw=0.8, alpha=0.35)
        ax.axvline(s1, color='k', lw=0.8, alpha=0.35)

        if idx == 0:
            # Reference only: keystep identity, no similarity numbers
            txt = f"{ks_name}  •  Frames {s0}-{s1}  (reference)"
        else:
            # Similarity vs reference: Cosine + DTW only
            cos, dtw = _best_lag_similarity(seg["mag_seg"], ref)
            txt = (f"{ks_name}  •  Frames {s0}-{s1}\n"
                   f"vs {ref_name}  →  cos={cos:.2f}  |  dtw={dtw:.3f}")

        ax.text(0.01, 0.98, txt, transform=ax.transAxes, va="top", ha="left",
                bbox=dict(boxstyle="round,pad=0.35", fc="white", alpha=0.85, ec="0.5"),
                fontsize=9)

        ax.set_ylabel("Accel mag (raw)")
        ax.grid(axis="y", linestyle="--", alpha=0.25)

    axes[-1].set_xlabel("Frame")
    handles, labels_ = axes[0].get_legend_handles_labels()
    axes[0].legend(handles, labels_, loc="upper right", fontsize=8, ncol=1)

    title = f"Raw Segment Similarity • {subject_id}/{trial_id} • {len(segments)} segments"
    fig.suptitle(title, y=0.995, fontsize=12)
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    if out_name is None:
        out_name = f"{subject_id}_{trial_id}_segments_similarity_raw.png"
    out_path = os.path.join(out_dir, out_name)
    fig.savefig(out_path, dpi=dpi)
    plt.show()
    plt.close(fig)
    print(f"Saved figure → {out_path}")

    return fig, axes



# Example: raw-magnitude similarity for subject ms1, chest_pain scenario, trial 7.
# These names are plain notebook globals and stay defined for whatever runs after this cell.
subject_id = "ms1"
trial_id   = "7"
trial_path = f"/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026/{subject_id}/chest_pain/{trial_id}/"
imu_path   = trial_path + f"smartwatch_data/{subject_id.replace('_','')}_chestpain_{trial_id}_sync_smartwatch.csv"

# The FIRST entry is the reference segment; every later entry is compared against it.
keystep_segments = [
    {"keystep_id": 25, "keystep_name":"place_v3_lead", "start_frame": 4471, "end_frame": 4691},
    {"keystep_id": 26, "keystep_name":"place_v4_lead", "start_frame": 4100, "end_frame": 4218},
    # add more if needed…
]

# print_tables is left at its default (False), so only the figure is produced.
plot_trial_segment_similarity_raw(
    imu_path=imu_path,
    subject_id=subject_id,
    trial_id=trial_id,
    keystep_segments=keystep_segments,
    x_pad=30,   # context for plotting only
    dpi=150,
    out_dir="analysis"
)



In [None]:
import pandas as pd
import ast
from typing import List, Dict

def extract_keystep_segments(csv_path: str, target_keysteps: List[int] = None) -> pd.DataFrame:
    """
    Load the predictions CSV and extract segments belonging to specific keystep IDs.

    Parameters
    ----------
    csv_path : str
        Path to your predictions CSV file. Must contain a 'keystep_id' column;
        every other column is optional and is used only when present.
    target_keysteps : list[int], optional
        List of keystep IDs to filter (default: [25, 26])

    Returns
    -------
    pd.DataFrame
        Filtered DataFrame containing only the desired keysteps, restricted to the
        known metadata columns that exist in the CSV. When 'all_preds' is present,
        a 'parsed_preds' column holds the parsed prediction list per row
        ([] when parsing fails).

    Raises
    ------
    ValueError
        If the CSV has no 'keystep_id' column.
    """
    # A None default avoids sharing one mutable list object across calls
    if target_keysteps is None:
        target_keysteps = [25, 26]

    # --- Load and basic cleanup ---
    df = pd.read_csv(csv_path)
    if "keystep_id" not in df.columns:
        raise ValueError("CSV must contain a 'keystep_id' column.")

    # --- Filter by keystep IDs ---
    df_filtered = df[df["keystep_id"].isin(target_keysteps)].copy()

    # --- Parse all_preds safely ---
    def parse_preds(x):
        """Parse '[[...]]' or '[...]' into a flat list; return [] for anything unparsable."""
        try:
            arr = ast.literal_eval(x)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Malformed or non-string cell (e.g. NaN): keep the row, drop the predictions
            return []
        if not isinstance(arr, list) or len(arr) == 0:
            return []
        # Nested '[[...]]' -> inner list; an already flat '[...]' is kept whole
        return arr[0] if isinstance(arr[0], list) else arr

    if "all_preds" in df_filtered.columns:
        df_filtered["parsed_preds"] = df_filtered["all_preds"].apply(parse_preds)

    # --- Useful metadata fields ---
    cols_keep = [
        "keystep_label", "keystep_id", "start_frame", "end_frame",
        "subject_id", "trial_id", "pred_keystep_id", "parsed_preds"
    ]
    existing_cols = [c for c in cols_keep if c in df_filtered.columns]
    # .copy() so the column assignments below act on an independent frame
    df_filtered = df_filtered[existing_cols].copy()

    # --- Convert frames to int for consistency ---
    for frame_col in ("start_frame", "end_frame"):
        if frame_col in df_filtered.columns:
            df_filtered[frame_col] = df_filtered[frame_col].astype(int)

    print(f"Extracted {len(df_filtered)} segments for keysteps {target_keysteps}.")
    preview_cols = [
        c for c in ("subject_id", "trial_id", "keystep_id", "start_frame", "end_frame")
        if c in df_filtered.columns
    ]
    print(df_filtered[preview_cols].head())

    return df_filtered

def find_imu_data(subject_id, trial_id):
    """
    Load the smartwatch IMU CSV recorded for one subject/trial.

    The trial directory is looked up under the ``chest_pain`` scenario first
    and then under ``stroke``. The first ``.csv`` file (in sorted name order)
    found in its ``smartwatch_data`` folder is read.

    Parameters
    ----------
    subject_id, trial_id
        Values interpolated into the directory path.

    Returns
    -------
    The loaded table after being passed through ``_ensure_frame``
    (defined elsewhere in this notebook; presumably it guarantees a
    ``frame`` column, which the calling cell filters on).

    Raises
    ------
    FileNotFoundError
        If neither scenario directory exists, or the directory holds no CSV.
    """
    data_root = "/standard/UVA-DSA/NIST EMS Project Data/EgoEMS_AAAI2026"

    for scenario in ("chest_pain", "stroke"):
        imu_path = f"{data_root}/{subject_id}/{scenario}/{trial_id}/smartwatch_data/"
        if os.path.exists(imu_path):
            break
    else:
        raise FileNotFoundError(f"IMU data path not found for subject {subject_id}, trial {trial_id}.")

    # os.listdir returns entries in arbitrary order; sort so the chosen file
    # is the same on every run and every filesystem.
    imu_files = sorted(f for f in os.listdir(imu_path) if f.endswith(".csv"))
    if not imu_files:
        raise FileNotFoundError(f"No CSV files found in {imu_path}.")

    imu_csv = os.path.join(imu_path, imu_files[0])
    imu_data = pd.read_csv(imu_csv)

    # ensure frame column
    imu_data = _ensure_frame(imu_data)

    return imu_data


In [None]:
# NOTE: despite its name, imu_csv_path is the model *predictions* CSV; the IMU
# recordings themselves are located per subject/trial by find_imu_data().
imu_csv_path = "./results/model_id_job_1173491_task_classification_on_20250715-163830/preds.csv"
segments_25_26 = extract_keystep_segments(imu_csv_path, target_keysteps=[25, 26])

print("\n")

ks_25_segments = []
ks_26_segments = []

# Several segments usually come from the same recording, so keep each loaded
# IMU table and read every (subject, trial) CSV from disk only once.
_imu_cache = {}

# Slice the IMU recording of each predicted segment and bucket it by keystep.
for _, row in segments_25_26.iterrows():
    print(f"KS{row.keystep_id} | {row.subject_id}/{row.trial_id} | "
          f"Frames {row.start_frame}-{row.end_frame} | Pred→{row.pred_keystep_id}")

    # find the imu data for this subject/trial
    trial_key = (row.subject_id, row.trial_id)
    if trial_key not in _imu_cache:
        _imu_cache[trial_key] = find_imu_data(row.subject_id, row.trial_id)
    imu_data = _imu_cache[trial_key]

    # Frame range is inclusive on both ends.
    imu_segment = imu_data[(imu_data['frame'] >= row.start_frame) & (imu_data['frame'] <= row.end_frame)]

    if row.keystep_id == 25:
        ks_25_segments.append(imu_segment)
    elif row.keystep_id == 26:
        ks_26_segments.append(imu_segment)

In [None]:
import numpy as np
import pandas as pd
from typing import List, Tuple

# ---------- signal helpers ----------
def _get_xyz_cols(df: pd.DataFrame) -> Tuple[str, str, str]:
    """
    Return the names of the X/Y/Z acceleration columns present in ``df``.

    The supported column schemas are tried in a fixed priority order and the
    first one whose three columns all exist is returned.

    Raises
    ------
    ValueError
        If none of the supported schemas is fully present; the message lists
        the columns that were found.
    """
    candidate_schemas = (
        ("x_01", "y_01", "z_01"),
        ("x_02", "y_02", "z_02"),
        ("x_03", "y_03", "z_03"),
        ("sw_value_X_Axis", "sw_value_Y_Axis", "sw_value_Z_Axis"),
    )
    for schema in candidate_schemas:
        if all(col in df.columns for col in schema):
            return schema
    # This helper runs once per segment, so the available columns are
    # reported only on failure rather than printed on every call.
    raise ValueError(
        "IMU DataFrame missing expected accel columns. "
        f"Found columns: {list(df.columns)}"
    )

def _acc_mag(df: pd.DataFrame) -> np.ndarray:
    """Return the per-row Euclidean magnitude of the three acceleration columns."""
    axis_cols = list(_get_xyz_cols(df))
    xyz = df[axis_cols].to_numpy(dtype=float)
    # sqrt(x^2 + y^2 + z^2), one value per row
    squared_sum = np.square(xyz).sum(axis=1)
    return np.sqrt(squared_sum)

def _overlap_at_lag(a: np.ndarray, b: np.ndarray, lag: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the equally long, aligned slices of ``a`` and ``b`` at a given lag.

    A non-negative ``lag`` skips the first ``lag`` samples of ``b``; a negative
    ``lag`` skips the first ``-lag`` samples of ``a``. When the overlap would
    contain one sample or fewer, a pair of empty arrays is returned.
    """
    a = np.asarray(a, float).ravel()
    b = np.asarray(b, float).ravel()

    # Exactly one of the two offsets is non-zero, depending on the lag's sign.
    start_a = max(-lag, 0)
    start_b = max(lag, 0)
    length = min(len(a) - start_a, len(b) - start_b)

    if length <= 1:
        return np.array([]), np.array([])
    return a[start_a:start_a + length], b[start_b:start_b + length]

def _max_norm_xcorr(a: np.ndarray, b: np.ndarray, max_lag: int) -> Tuple[float,int]:
    """
    Search lags in ``[-max_lag, max_lag]`` for the highest normalized correlation.

    For every lag the overlapping slices are mean-centered and their
    correlation (dot product over the product of norms, with a 1e-8 guard in
    the denominator) is computed. Returns ``(best_corr, best_lag)``; when no
    lag produces a usable overlap the result is ``(0.0, 0)``.
    """
    a = np.asarray(a, float).ravel()
    b = np.asarray(b, float).ravel()

    top_corr = -np.inf
    top_lag = 0
    for shift in range(-max_lag, max_lag + 1):
        seg_a, seg_b = _overlap_at_lag(a, b, shift)
        if seg_a.size > 1:
            centered_a = seg_a - seg_a.mean()
            centered_b = seg_b - seg_b.mean()
            scale = (np.linalg.norm(centered_a) * np.linalg.norm(centered_b)) + 1e-8
            score = float(np.dot(centered_a, centered_b)) / scale
            # strict ">" keeps the earliest lag among equal scores
            if score > top_corr:
                top_corr = score
                top_lag = shift

    if np.isfinite(top_corr):
        return top_corr, top_lag
    return 0.0, 0

def _cosine_centered(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of ``a`` and ``b`` after subtracting each one's mean.

    A 1e-8 term is added to the denominator so constant inputs give 0 instead
    of dividing by zero.
    """
    centered_a = a - a.mean()
    centered_b = b - b.mean()
    numerator = float(np.dot(centered_a, centered_b))
    scale = (np.linalg.norm(centered_a) * np.linalg.norm(centered_b)) + 1e-8
    return numerator / scale

def _dtw_raw(a: np.ndarray, b: np.ndarray) -> float:
    """
    Length-normalized dynamic-time-warping distance between two 1-D sequences.

    The local cost is the squared difference; steps may advance either
    sequence or both. The result is ``sqrt(accumulated cost) / (len(a) + len(b))``.

    Returns NaN when either sequence is empty: no alignment exists, and NaN is
    how the rest of this analysis marks an invalid pair, whereas the raw
    recurrence would produce ``inf`` (or a 0/0 warning when both are empty).
    """
    a = np.asarray(a, float).ravel()
    b = np.asarray(b, float).ravel()
    n, m = len(a), len(b)
    if n == 0 or m == 0:
        return float("nan")

    # D[i, j] = minimal accumulated cost aligning a[:i] with b[:j]
    D = np.full((n+1, m+1), np.inf); D[0,0] = 0.0
    for i in range(1, n+1):
        ai = a[i-1]
        for j in range(1, m+1):
            bj = b[j-1]
            cost = (ai - bj)**2
            D[i,j] = cost + min(D[i-1,j], D[i,j-1], D[i-1,j-1])
    return float(np.sqrt(D[n,m])/(n+m))

def _cliffs_delta(x: np.ndarray, y: np.ndarray) -> float:
    """
    Cliff's delta: a rank-based effect size in [-1, 1].

    Defined as ``(#{x_i > y_j} - #{x_i < y_j}) / (len(x) * len(y))`` over all
    pairs; tied pairs count for neither side. Returns 0.0 when either input is
    empty.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    nx, ny = x.size, y.size
    if nx == 0 or ny == 0:
        return 0.0

    y_sorted = np.sort(y)
    # For each x: how many y are strictly smaller / strictly larger.
    # side="left" vs side="right" is what excludes ties from both counts.
    n_y_below = np.searchsorted(y_sorted, x, side="left")
    n_y_above = ny - np.searchsorted(y_sorted, x, side="right")

    more = int(n_y_below.sum())
    less = int(n_y_above.sum())
    return (more - less) / (nx * ny)

# ---------- feature extraction for each segment ----------
def _segment_features(seg_df: pd.DataFrame) -> dict:
    """
    Summarize one IMU segment's acceleration magnitude as scalar features.

    Keys: ``mean``, ``std`` (population), ``energy`` (sum of squares),
    ``peak_rel`` (position of the maximum as a fraction of the segment length)
    and ``frac_high`` (share of samples above median + 0.5 * IQR, using 1.0 in
    place of the IQR when it is at most 1e-8). Every value is NaN for an empty
    segment.
    """
    magnitude = _acc_mag(seg_df)
    if magnitude.size == 0:
        return dict.fromkeys(("mean", "std", "energy", "peak_rel", "frac_high"), np.nan)

    median = np.median(magnitude)
    iqr = np.percentile(magnitude, 75) - np.percentile(magnitude, 25)
    # fall back to a unit spread for (near-)constant segments
    if iqr > 1e-8:
        spread = iqr
    else:
        spread = 1.0
    threshold = median + 0.5 * spread

    last_index = max(1, len(magnitude) - 1)
    features = {}
    features["mean"] = float(magnitude.mean())
    features["std"] = float(np.std(magnitude, ddof=0))
    features["energy"] = float(np.square(magnitude).sum())
    features["peak_rel"] = float(np.argmax(magnitude) / last_index)
    features["frac_high"] = float(np.mean(magnitude > threshold))
    return features

# ---------- compute distributions & similarities ----------
import os
import numpy as np
import pandas as pd
from typing import List, Tuple

# NOTE: analyze_ks25_ks26 relies on the helpers defined above in this cell:
#     _get_xyz_cols, _acc_mag, _overlap_at_lag, _max_norm_xcorr,
#     _cosine_centered, _dtw_raw, _cliffs_delta, _segment_features.

def analyze_ks25_ks26(
    ks25: List[pd.DataFrame],
    ks26: List[pd.DataFrame],
    print_examples: int = 5,
    out_dir: str = "analysis"
):
    """
    Compare IMU acceleration-magnitude segments of keystep 25 and keystep 26.

    For every segment the scalar features from ``_segment_features`` are
    collected. For every ordered pair of segments (within KS25, within KS26,
    and KS25 x KS26) two similarities are computed on the acceleration
    magnitudes: the centered cosine at the best cross-correlation lag (lag
    search limited to 15% of the longer segment) and the DTW distance.
    Within-class comparisons exclude a segment paired with itself.

    Parameters
    ----------
    ks25, ks26 : list[pd.DataFrame]
        IMU segments of each keystep.
    print_examples : int
        Maximum number of top cross-class pairs to report (nothing is printed).
    out_dir : str
        Directory (created if needed) receiving the CSV outputs.

    Returns
    -------
    (feats, sim_summary_df, effect_sizes_df, top_pairs_dtw_df, top_pairs_cos_df)

    Files written to ``out_dir``
    ----------------------------
    ks25_ks26_similarity_summary.csv, ks25_ks26_effect_sizes.csv,
    ks25_ks26_top_pairs_dtw.csv, ks25_ks26_top_pairs_cos.csv.
    The top-pair files are always written (header only when there are no valid
    pairs) so results from an earlier run never linger in ``out_dir``.

    Notes
    -----
    Non-finite similarity values (NaN from empty overlaps, inf from empty
    segments) are excluded from the summaries, effect sizes and top-pair
    rankings. An effect-size row is emitted only when both the cross-class and
    the within-class group contain at least one valid value.
    """
    os.makedirs(out_dir, exist_ok=True)

    def finite_only(vals) -> np.ndarray:
        # Drop NaN and +/-inf so they cannot distort medians or rankings.
        vals = np.asarray(vals, dtype=float)
        return vals[np.isfinite(vals)]

    def med_iqr_vals(x: np.ndarray) -> Tuple[float, float, float]:
        x = finite_only(x)
        if x.size == 0:
            return np.nan, np.nan, np.nan
        return float(np.median(x)), float(np.percentile(x, 25)), float(np.percentile(x, 75))

    def pairwise_stats(A: List[np.ndarray], B: List[np.ndarray], skip_self: bool = False):
        # skip_self=True is used for within-class comparisons (A is B): the
        # i == j pair is a segment against itself and carries no information.
        cos_c = []; dtw = []; pairs = []
        for i, a in enumerate(A):
            for j, b in enumerate(B):
                if skip_self and i == j:
                    continue
                max_lag = max(1, int(0.15*max(len(a), len(b))))
                _, lag = _max_norm_xcorr(a, b, max_lag)
                ao, bo = _overlap_at_lag(a, b, lag)
                cos_c.append(_cosine_centered(ao, bo) if len(ao) > 1 else np.nan)
                dtw.append(_dtw_raw(a, b))
                pairs.append((i, j))
        return np.array(cos_c, dtype=float), np.array(dtw, dtype=float), pairs

    # ---------- Per-segment features ----------
    rows = []
    for keystep, segments in ((25, ks25), (26, ks26)):
        for df in segments:
            f = _segment_features(df)
            f["keystep"] = keystep
            rows.append(f)
    feats = pd.DataFrame(rows)

    # ---------- Build magnitude lists ----------
    mags25 = [_acc_mag(df) for df in ks25]
    mags26 = [_acc_mag(df) for df in ks26]

    # ---------- Pairwise similarities ----------
    cos25_25, dtw25_25, _ = pairwise_stats(mags25, mags25, skip_self=True)
    cos26_26, dtw26_26, _ = pairwise_stats(mags26, mags26, skip_self=True)
    cos25_26, dtw25_26, pairs_25_26 = pairwise_stats(mags25, mags26)

    # ---------- Summaries to DataFrame ----------
    rows_summary = []
    for name, cos_vals, dtw_vals in [
        ("Within KS25", cos25_25, dtw25_25),
        ("Within KS26", cos26_26, dtw26_26),
        ("Cross KS25↔KS26", cos25_26, dtw25_26),
    ]:
        cos_med, cos_q25, cos_q75 = med_iqr_vals(cos_vals)
        dtw_med, dtw_q25, dtw_q75 = med_iqr_vals(dtw_vals)
        rows_summary.append({
            "group": name,
            "n_pairs_cos": int(np.isfinite(cos_vals).sum()),
            "centered_cos_median": cos_med,
            "centered_cos_q25": cos_q25,
            "centered_cos_q75": cos_q75,
            "n_pairs_dtw": int(np.isfinite(dtw_vals).sum()),
            "dtw_median": dtw_med,
            "dtw_q25": dtw_q25,
            "dtw_q75": dtw_q75,
        })
    sim_summary_df = pd.DataFrame(rows_summary)
    sim_summary_path = os.path.join(out_dir, "ks25_ks26_similarity_summary.csv")
    sim_summary_df.to_csv(sim_summary_path, index=False)

    # ---------- Effect sizes (Cliff’s δ: cross vs within) ----------
    # For DTW (lower=more similar) the sign is flipped so that a positive delta
    # always means cross-class pairs are "more similar" than within-class ones.
    effect_specs = [
        ("centered_cosine", "Within KS25", cos25_26, cos25_25, 1.0),
        ("centered_cosine", "Within KS26", cos25_26, cos26_26, 1.0),
        ("dtw", "Within KS25", dtw25_26, dtw25_25, -1.0),
        ("dtw", "Within KS26", dtw25_26, dtw26_26, -1.0),
    ]
    eff_rows = []
    for metric, reference, cross_vals, within_vals, sign in effect_specs:
        cross = finite_only(cross_vals)
        within = finite_only(within_vals)
        if cross.size == 0 or within.size == 0:
            # No valid pairs on one side: a delta would be meaningless.
            continue
        eff_rows.append({
            "metric": metric,
            "reference": reference,
            "delta_cross_vs_within": _cliffs_delta(sign * cross, sign * within),
        })
    effect_sizes_df = pd.DataFrame(
        eff_rows, columns=["metric", "reference", "delta_cross_vs_within"]
    )
    effect_sizes_path = os.path.join(out_dir, "ks25_ks26_effect_sizes.csv")
    effect_sizes_df.to_csv(effect_sizes_path, index=False)

    # ---------- Top cross-class pairs ----------
    def top_pairs(values: np.ndarray, column: str, descending: bool) -> pd.DataFrame:
        # Rank only finite values; stable sort keeps pair order among ties.
        valid = np.flatnonzero(np.isfinite(values))
        keys = -values[valid] if descending else values[valid]
        chosen = valid[np.argsort(keys, kind="stable")][:max(0, print_examples)]
        records = [
            {"ks25_index": int(pairs_25_26[idx][0]),
             "ks26_index": int(pairs_25_26[idx][1]),
             column: float(values[idx])}
            for idx in chosen
        ]
        return pd.DataFrame(records, columns=["ks25_index", "ks26_index", column])

    top_pairs_dtw_df = top_pairs(dtw25_26, "dtw", descending=False)
    top_pairs_dtw_df.to_csv(os.path.join(out_dir, "ks25_ks26_top_pairs_dtw.csv"), index=False)

    top_pairs_cos_df = top_pairs(cos25_26, "centered_cosine", descending=True)
    top_pairs_cos_df.to_csv(os.path.join(out_dir, "ks25_ks26_top_pairs_cos.csv"), index=False)

    # The per-segment features are returned (not written here); the caller saves them.
    return feats, sim_summary_df, effect_sizes_df, top_pairs_dtw_df, top_pairs_cos_df


# Run the KS25-vs-KS26 comparison on the segments collected above, then save
# the per-segment feature table next to the CSVs the analysis wrote itself.
_analysis_outputs = analyze_ks25_ks26(
    ks25=ks_25_segments,
    ks26=ks_26_segments,
    print_examples=5,
    out_dir="analysis",
)
feats_df, sim_df, eff_df, top_dtw_df, top_cos_df = _analysis_outputs

feats_df.to_csv("analysis/ks25_ks26_segment_features.csv", index=False)

# Report the expected output files on a single line.
_written_files = (
    "analysis/ks25_ks26_segment_features.csv,",
    "analysis/ks25_ks26_similarity_summary.csv,",
    "analysis/ks25_ks26_effect_sizes.csv,",
    "analysis/ks25_ks26_top_pairs_dtw.csv,",
    "analysis/ks25_ks26_top_pairs_cos.csv",
)
print("Wrote:", *_written_files)