<a href="https://colab.research.google.com/github/KravitzLab/FED3Analyses/blob/main/FED3_Bandit_V1_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook analyzes FED3 Bandit data
<br>
<img src="https://fed3bandit.readthedocs.io/en/latest/_static/fed3bandit_logo1.svg" width="200" />


Authors: Chantelle Murrell and Sebastian Alves<br>
Updated: 12-01-25  
Version 1.0.1

In [None]:
# @title Install libraries and import them {"run":"auto"}

import importlib.util
import subprocess
import sys

# Packages to ensure are installed (add others here if you like).
# Keys are the importable module names probed with find_spec(); values are the
# requirement strings handed to `pip install` (a VCS URL for fed3).
# NOTE(review): none of the requirements is version-pinned, so a fresh runtime
# may resolve different releases from one run to the next.
packages = {
    "fed3": "git+https://github.com/earnestt1234/fed3.git",
    "fed3bandit": "fed3bandit",
    "pingouin": "pingouin",
    "ipydatagrid": "ipydatagrid",
    "openpyxl": "openpyxl",
}

# Install only what is not already importable in this kernel.
for name, source in packages.items():
    if importlib.util.find_spec(name) is None:
        print(f"Installing {name}...")
        # sys.executable targets the interpreter running this kernel; check_call
        # raises CalledProcessError if pip exits non-zero.
        subprocess.check_call([sys.executable, "-m", "pip", "install", source])

# ----------------------------
# Imports
# ----------------------------
# Standard library
import copy
import io
import math
import os
import re
import shutil
import tempfile
import threading
import time
import warnings
import zipfile
import glob
from datetime import datetime, timedelta
from os.path import basename, splitext

# Third-party
from ipydatagrid import DataGrid, TextRenderer
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pingouin as pg
import fed3
import fed3.plot as fplot
import fed3bandit as f3b
from scipy.stats import f_oneway
import statsmodels.api as sm
from statsmodels.formula.api import ols
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.gridspec as gridspec
from google.colab import files
from tqdm.auto import tqdm

# ----------------------------
# Configuration
# ----------------------------
# Silence all warnings for the rest of the session (keeps widget output clean).
warnings.filterwarnings("ignore")

# Notebook-wide matplotlib defaults: compact figures, auto layout, and no
# top/right spines.
_plot_defaults = {
    'font.size': 12,
    'figure.autolayout': True,
    'figure.figsize': [6, 4],
    'figure.dpi': 100,
    'axes.spines.top': False,
    'axes.spines.right': False,
}
for _rc_key, _rc_value in _plot_defaults.items():
    plt.rcParams[_rc_key] = _rc_value

print("Packages installed and imports ready.")


In [None]:
#@title Upload files

# Reset caches to avoid duplicates if you re-run this cell.
# The three lists are filled in lockstep by the upload loop below: the loaded
# frame, its base filename, and its session type share the same position.
feds, loaded_files, session_types = [], [], []

def extract_session_type(csv_path, fallback="Unknown"):
    """Return the first non-blank session-type value stored in a CSV file.

    Column headers are whitespace-stripped and compared case-insensitively
    against 'session_type', 'session type', 'sessiontype' and 'session', in
    that order. The first candidate column holding a non-empty value wins.

    Parameters
    ----------
    csv_path : path of the CSV to inspect (the delimiter is sniffed).
    fallback : value returned when no usable column/value exists or when the
        file cannot be read at all.
    """
    candidates = ("session_type", "session type", "sessiontype", "session")
    try:
        table = pd.read_csv(csv_path, sep=None, engine='python', dtype=str)
        table.columns = [header.strip() for header in table.columns]
        by_folded = {header.casefold(): header for header in table.columns}
        for wanted in candidates:
            if wanted not in by_folded:
                continue
            cleaned = table[by_folded[wanted]].dropna().astype(str).str.strip()
            non_blank = cleaned[cleaned.ne("")]
            if not non_blank.empty:
                return non_blank.iloc[0]
    except Exception:
        # Best effort: an unreadable file simply yields the fallback.
        return fallback
    return fallback

# Files at or below this size are skipped as too small to load.
MIN_FED_CSV_BYTES = 1024


def _load_fed_csv_bytes(raw_bytes, display_name):
    """Load one FED3 CSV given as bytes and register it in the module caches.

    The bytes are written to a temporary .csv because both
    extract_session_type() and fed3.load() are called with a file path here.
    On success the frame, its base filename and its session type are appended
    to `feds`, `loaded_files` and `session_types` in lockstep. On failure the
    error is printed and nothing is appended. The temporary file is always
    removed.

    Parameters
    ----------
    raw_bytes : content of the CSV file.
    display_name : original file name (may include a folder prefix when it
        comes from inside a zip); used for messages and the stored name.
    """
    with tempfile.NamedTemporaryFile(mode="w+b", suffix=".csv", delete=False) as tmp:
        tmp.write(raw_bytes)
        tmp_path = tmp.name
    try:
        session_type = extract_session_type(tmp_path)
        df = fed3.load(tmp_path)
        df.name = os.path.basename(display_name)
        df.attrs = {"Session_type": session_type}
        feds.append(df)
        loaded_files.append(os.path.basename(display_name))
        session_types.append(session_type)
    except Exception as e:
        print(f"Error loading {display_name}: {e}")
    finally:
        os.remove(tmp_path)


uploaded = files.upload()

for name, data in uploaded.items():
    lowered_name = name.lower()
    if lowered_name.endswith(".zip"):
        try:
            with zipfile.ZipFile(io.BytesIO(data)) as zf:
                for zi in zf.infolist():
                    member = zi.filename
                    if not member.lower().endswith(".csv"):
                        continue
                    # macOS archives carry resource-fork entries
                    # (__MACOSX/..., ._name.csv) that are not FED3 data.
                    if member.startswith("__MACOSX/") or os.path.basename(member).startswith("._"):
                        continue
                    file_data = zf.read(zi)
                    if len(file_data) <= MIN_FED_CSV_BYTES:
                        continue
                    _load_fed_csv_bytes(file_data, member)
        except zipfile.BadZipFile as e:
            # One corrupt archive should not abort the remaining uploads.
            print(f"Error loading {name}: {e}")
    elif lowered_name.endswith(".csv"):
        if len(data) <= MIN_FED_CSV_BYTES:
            continue
        _load_fed_csv_bytes(data, name)

# NOTE(review): the message says session types were captured "for all", but
# extract_session_type() falls back to "Unknown" when a file has no usable
# session column, so some entries may be placeholders.
print(f"Loaded {len(loaded_files)} files. Session types captured for all.")
# Optional quick plot
if feds:
    try:
        # NOTE(review): inplace=True presumably modifies the frames held in
        # `feds`, which later cells reuse — confirm against the fed3 docs.
        fed3.as_aligned(feds, alignment="datetime", inplace=True)
        plt.figure(figsize=(8, 4))
        # One pellet line per file; the legend is removed after plotting.
        fplot.line(feds, y='pellets'); plt.legend().remove(); plt.tight_layout(); plt.show()
    except Exception as e:
        # The preview is a convenience only; a failure must not block loading.
        print(f"Plotting skipped: {e}")

In [None]:
# @title Build Key


import os, glob, io
import numpy as np
import pandas as pd
from ipydatagrid import DataGrid, TextRenderer
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from google.colab import files as colab_files
from google.colab import output as colab_output

# Require that the file-upload cell has already populated these lists.
# The message names the cell by its actual title ("Upload files").
assert 'loaded_files' in globals() and 'session_types' in globals(), \
    "Run the 'Upload files' cell first."

# Opt in to Colab's custom widget manager before any DataGrid is displayed
# (presumably required for third-party widgets such as ipydatagrid to render).
colab_output.enable_custom_widget_manager()

# ---------------------------
# Base: bare-bones Key_Df from loaded data
# ---------------------------
def _make_base_key_df():
    """Build the bare-bones key: one row per loaded file with its session type.

    Reads the module-level `loaded_files` and `session_types` lists.
    """
    base_columns = {
        "filename": loaded_files,
        "Session_type": session_types,
    }
    return pd.DataFrame(base_columns)

def _file_base(s):
    return os.path.splitext(os.path.basename(str(s)))[0].strip()

def _norm_base_lower(s):
    """Lower-cased _file_base(s): the normalized token used for filename joins."""
    stem = _file_base(s)
    return stem.lower()

# ---------------------------
# Key scanner: detect Mouse_ID or filename columns
# ---------------------------
def _scan_key_columns(df):
    """
    Returns dict:
      {
        'has_mouse': bool,
        'has_filename': bool,
        'filename_col': 'filename'|'File'|None,
        'msg': str
      }
    Accepts keys that have either Mouse_ID or a filename column (filename/File).
    """
    info = {'has_mouse': False, 'has_filename': False, 'filename_col': None, 'msg': ''}
    try:
        cols = [str(c).strip() for c in df.columns]
        has_mouse = 'Mouse_ID' in cols
        fname_col = 'filename' if 'filename' in cols else ('File' if 'File' in cols else None)
        info.update({
            'has_mouse': has_mouse,
            'has_filename': fname_col is not None,
            'filename_col': fname_col
        })
        if has_mouse:
            info['msg'] = "'Mouse_ID' found."
        elif fname_col:
            info['msg'] = f"'{fname_col}' found; will match on filename."
        else:
            info['msg'] = "Neither 'Mouse_ID' nor 'filename'/'File' found in provided key."
    except Exception as e:
        info['msg'] = f"Error while checking key: {e}"
    return info

# ---------------------------
# Read uploaded key (CSV/XLSX), accept Mouse_ID or filename
# ---------------------------
def _read_key_from_upload(name, content_bytes):
    """Return (df_or_None, message). Reads CSV/XLSX bytes from Colab upload."""
    ext = name.lower().rsplit('.', 1)[-1] if '.' in name else ''
    try:
        bio = io.BytesIO(content_bytes)
        if ext == 'xlsx':
            xls = pd.ExcelFile(bio, engine='openpyxl')
            frames = [pd.read_excel(xls, sheet_name=s) for s in xls.sheet_names]
            key_df = pd.concat(frames, ignore_index=True, sort=False)
        elif ext == 'csv':
            key_df = pd.read_csv(bio, sep=None, engine='python')
        else:
            return None, f"Unsupported key type .{ext}"

        key_df = key_df.copy()
        key_df.columns = [str(c).strip() for c in key_df.columns]
        scan = _scan_key_columns(key_df)
        if not (scan['has_mouse'] or scan['has_filename']):
            return None, scan['msg']

        # Normalize types/columns we might use later
        if scan['has_mouse']:
            globals()['KEY_MATCH_MODE'] = 'mouse_id'
            key_df['Mouse_ID'] = key_df['Mouse_ID'].astype(str).str.strip()

        if scan['has_filename']:
            globals()['KEY_MATCH_MODE'] = 'filename'
            fcol = scan['filename_col']
            key_df[fcol] = key_df[fcol].astype(str).str.strip()
            key_df['_key_file_base_lower'] = key_df[fcol].map(_norm_base_lower)
        else:
            globals()['KEY_MATCH_MODE'] = None

        # Persist a deterministic copy on disk for reproducibility
        fixed_path = f"_uploaded_key.{ext}"
        with open(fixed_path, "wb") as f:
            f.write(content_bytes)
        globals()['uploaded_key_path'] = fixed_path

        return key_df, f"Key loaded from upload ({name}) and saved to {fixed_path}. {scan['msg']}"
    except Exception as e:
        return None, f"Error reading uploaded key: {e}"

# ---------------------------
# Matching filename <-> Mouse_ID
# ---------------------------
def _match_mouse_id_to_filenames(filenames, key_df):
    """Return DataFrame: filename, Mouse_ID, match_status based on Mouse_ID substring in filename.

    Matching is case-insensitive against the extension-stripped base name.
    When several IDs occur in one filename the longest one wins; a tie between
    equally long IDs is reported as ambiguous and left unmatched.

    Blank IDs and the text that str() produces for missing cells ('nan',
    'none', '<na>') are ignored. Without this, a key with empty Mouse_ID rows
    would "match" any filename that merely contains the letters 'nan'.
    """
    placeholder_ids = {'', 'nan', 'none', '<na>'}
    base_names_lower = [_norm_base_lower(f) for f in filenames]

    id_text = key_df['Mouse_ID'].dropna().astype(str).map(str.strip)
    id_text = id_text[~id_text.str.lower().isin(placeholder_ids)]
    mouse_ids = id_text.unique().tolist()

    rows = []
    for fname, base in zip(filenames, base_names_lower):
        hits = [mid for mid in mouse_ids if str(mid).lower() in base]
        if len(hits) == 1:
            rows.append({"filename": fname, "Mouse_ID": hits[0], "match_status": "Matched (Mouse_ID in filename)"})
        elif len(hits) > 1:
            # Prefer the most specific (longest) ID, e.g. 'M12' over 'M1'.
            longest = max(len(str(h)) for h in hits)
            best = [h for h in hits if len(str(h)) == longest]
            if len(best) == 1:
                rows.append({"filename": fname, "Mouse_ID": best[0], "match_status": "Matched (longest Mouse_ID token)"})
            else:
                rows.append({"filename": fname, "Mouse_ID": None, "match_status": f"Ambiguous Mouse_ID: {hits}"})
        else:
            rows.append({"filename": fname, "Mouse_ID": None, "match_status": "Mouse_ID not found in filename"})
    return pd.DataFrame(rows)

# ---------------------------
# Build/Rematch function
# ---------------------------
# Shared output area that the key-building functions and button callbacks
# print their status messages into.
status_box = widgets.Output()
Key_Df = _make_base_key_df()  # start bare-bones (filename + Session_type only)

def _collapse_key_suffixes(df, keep_filename_from_base=True):
    """
    Remove duplicate columns created by merge suffixes.
    - If both 'col' and 'col_key' exist, keep the key version (rename col_key -> col, drop base col).
    - Special-case: if keep_filename_from_base and base == 'filename',
      drop 'filename_key' and keep the base 'filename'.
    """
    cols = list(df.columns)
    rename_map = {}
    drop_cols = []

    for col in cols:
        if col.endswith('_key'):
            base = col[:-4]
            if keep_filename_from_base and base == 'filename':
                # keep the 'filename' from the bare-bones (actual loaded_files)
                drop_cols.append(col)
            elif base in df.columns:
                # key version wins: drop base col, rename *_key -> base
                drop_cols.append(base)
                rename_map[col] = base
            else:
                # no base column; just strip "_key"
                rename_map[col] = base

    df = df.drop(columns=drop_cols, errors='ignore').rename(columns=rename_map)
    return df


def build_or_rematch_key_df(key_df=None, msg_hint=""):
    """
    Rebuild the global Key_Df from the loaded files and an optional key table.

    If key_df provided and valid:
      - If a filename column exists (filename/File), always use filename-based merge
        (case-insensitive, extension-stripped base names).
        - If Mouse_ID is also present, it is preserved from the key file.
      - Else, if only Mouse_ID exists, use substring-based Mouse_ID matching.
    Else: keep bare-bones.

    When a key is used, any duplicate column names between the bare-bones
    and the key are resolved so that the key's values win (except 'filename').

    Parameters
    ----------
    key_df : DataFrame or None. None resets Key_Df to the bare-bones table.
    msg_hint : optional text printed in the status box ahead of the summary.

    Side effects: replaces the global Key_Df and rewrites status_box.
    """
    global Key_Df
    files_df = _make_base_key_df().copy()
    files_df['_file_base_lower'] = files_df['filename'].map(_norm_base_lower)

    if key_df is None:
        # Pure bare-bones Key_Df
        Key_Df = files_df.drop(columns=['_file_base_lower']).copy()
        Key_Df["match_status"] = "No key"
        with status_box:
            clear_output(wait=True)
            print("Key status: No key provided; showing bare-bones Key_Df.")
        return

    # Identify key capabilities
    scan = _scan_key_columns(key_df)
    kd = key_df.copy()

    # Notes gathered while merging. They are printed together at the end so
    # that the final clear_output() does not erase them.
    notes = []

    # Helper: make a unique version of the key for whichever join we use
    def _dedup(df, subset_cols):
        dup_counts = df[subset_cols].astype(str).agg('|'.join, axis=1).value_counts()
        n_dups = int((dup_counts > 1).sum())
        if n_dups:
            notes.append(f"Note: {n_dups} duplicate key(s) on {subset_cols}; taking the first occurrence.")
        return df.drop_duplicates(subset=subset_cols, keep="first")

    # ---------------------------
    # 1) Prefer filename-based route whenever filename column exists
    #    (whether or not Mouse_ID is present)
    # ---------------------------
    if scan['has_filename']:
        fcol = scan['filename_col']

        # Normalize filename column in key
        kd[fcol] = kd[fcol].astype(str).str.strip()
        kd['_key_file_base_lower'] = kd[fcol].map(_norm_base_lower)

        # One row per normalized filename in the key
        key_unique = _dedup(kd, ['_key_file_base_lower'])

        # Merge by normalized basename (case-insensitive, extension-stripped).
        # The indicator column records which files found a key row. It must be
        # read BEFORE the suffix cleanup: when the key column is itself called
        # 'filename', the key's copy is dropped and only the always-populated
        # base 'filename' remains, so a notna() test on it marks every row as
        # matched.
        merged = files_df.merge(
            key_unique,
            left_on="_file_base_lower",
            right_on="_key_file_base_lower",
            how="left",
            suffixes=("", "_key"),
            indicator="_key_merge",
        )
        found_in_key = merged["_key_merge"].eq("both").to_numpy()
        merged = merged.drop(columns=['_file_base_lower', '_key_file_base_lower', '_key_merge'])

        # Clean up duplicate columns so key overwrites bare-bones where appropriate
        Key_Df = _collapse_key_suffixes(merged, keep_filename_from_base=True)

        # match_status for filename-based matching
        Key_Df["match_status"] = np.where(
            found_in_key,
            "Matched (filename)",
            "Filename not found in key"
        )

    # ---------------------------
    # 2) If no filename column but Mouse_ID exists, use substring route
    # ---------------------------
    elif scan['has_mouse']:
        # Mouse_ID route (fallback when no filename column exists)
        kd['Mouse_ID'] = kd['Mouse_ID'].astype(str).str.strip()
        key_unique = _dedup(kd, ['Mouse_ID'])

        matched = _match_mouse_id_to_filenames(
            files_df['filename'].tolist(),
            key_unique
        )[["filename", "Mouse_ID", "match_status"]]

        Key_Df = (
            files_df
            .merge(matched, on="filename", how="left")
            .merge(key_unique, on="Mouse_ID", how="left", suffixes=("", "_key"))
            .drop(columns=['_file_base_lower'])
        )

        # Clean up duplicate columns (e.g., Mouse_ID vs Mouse_ID_key)
        Key_Df = _collapse_key_suffixes(Key_Df, keep_filename_from_base=True)

    # ---------------------------
    # 3) Neither filename nor Mouse_ID in key
    # ---------------------------
    else:
        # Neither route available (shouldn't happen due to earlier check)
        Key_Df = files_df.drop(columns=['_file_base_lower']).copy()
        Key_Df["match_status"] = "Key missing Mouse_ID and filename columns"

    with status_box:
        clear_output(wait=True)
        if msg_hint:
            print(msg_hint)
        for note in notes:
            print(note)
        print(f"Merged key columns into Key_Df ({len(Key_Df)} rows, {len(Key_Df.columns)} cols).")

# ---------------------------
# Grid UI
# ---------------------------
def make_grid(df: pd.DataFrame):
    """Create an editable, cell-selectable DataGrid for `df` whose default
    renderer wraps long text."""
    grid_options = dict(
        editable=True,
        selection_mode='cell',
        layout={'height': '420px'},
        base_row_size=28,
        base_column_size=120,
    )
    table = DataGrid(df, **grid_options)
    table.default_renderer = TextRenderer(text_wrap=True)
    return table

def rebuild_grid(msg=""):
    """Replace the displayed grid with a fresh one built from the current Key_Df.

    Parameters
    ----------
    msg : optional text printed in the status box before the size summary.

    The new grid is swapped into the slot of the existing DataGrid inside
    `ui.children`; the other children are left untouched. The container is
    deliberately not rebuilt from module-level names: the plotting cell
    rebinds the global `controls`, and relisting the children by name after
    that cell has run would put the plot controls into this key UI.
    """
    global grid, ui
    df = Key_Df.copy().reset_index(drop=True)
    new_grid = make_grid(df)

    children = list(ui.children)
    grid_slots = [pos for pos, child in enumerate(children) if isinstance(child, DataGrid)]
    if grid_slots:
        children[grid_slots[0]] = new_grid
    else:
        # No grid is currently shown; add one rather than failing.
        children.append(new_grid)
    ui.children = tuple(children)

    grid = new_grid
    with status_box:
        if msg:
            print(msg)
        print(f"Grid now shows Key_Df ({len(df)} rows, {len(df.columns)} cols)")

# ---------------------------
# Colab-native upload button ONLY (no path UI)
# ---------------------------
# Action buttons for the key workflow; their click handlers are attached
# further down, once the callbacks have been defined.
upload_btn = widgets.Button(description="Upload", button_style="primary", layout=widgets.Layout(width="120px"))
reset_btn = widgets.Button(description="Reset Key", button_style="warning", layout=widgets.Layout(width="120px"))
download_button = widgets.Button(description='Download', button_style='success', layout=widgets.Layout(width="120px"))

def on_colab_upload(_):
    """Click handler: ask for a key file, merge it into Key_Df, refresh the grid.

    Only the first file of the upload is used. When the key cannot be read,
    Key_Df falls back to the bare-bones table and the reason is shown.
    """
    with status_box:
        clear_output(wait=True)
        print("Opening Colab upload dialog...")

    picked = colab_files.upload()  # opens the native Colab picker
    if not picked:
        with status_box:
            print("No file selected.")
        return

    key_name, key_bytes = next(iter(picked.items()))
    key_df, msg = _read_key_from_upload(key_name, key_bytes)
    status_line = f"Key status: {msg}"

    if key_df is not None:
        build_or_rematch_key_df(key_df, msg_hint=status_line)
        rebuild_grid()
        return

    build_or_rematch_key_df(None)
    rebuild_grid(status_line)

def _load_saved_key_from_disk():
    """Returns (key_df_or_None, message) from uploaded_key_path if present/valid."""
    key_path = globals().get('uploaded_key_path', None)
    if not (key_path and os.path.exists(key_path)):
        return None, "No saved key on disk to reload."
    try:
        ext = key_path.lower().rsplit('.', 1)[-1] if '.' in key_path else ''
        if ext == 'xlsx':
            xls = pd.ExcelFile(key_path, engine='openpyxl')
            frames = [pd.read_excel(xls, sheet_name=s) for s in xls.sheet_names]
            key_df = pd.concat(frames, ignore_index=True, sort=False)
        elif ext == 'csv':
            key_df = pd.read_csv(key_path, sep=None, engine='python')
        else:
            return None, f"Unsupported key type .{ext}"

        key_df = key_df.copy()
        key_df.columns = [str(c).strip() for c in key_df.columns]
        scan = _scan_key_columns(key_df)
        if not (scan['has_mouse'] or scan['has_filename']):
            return None, scan['msg']

        if scan['has_mouse']:
            key_df['Mouse_ID'] = key_df['Mouse_ID'].astype(str).str.strip()
        if scan['has_filename']:
            fcol = scan['filename_col']
            key_df[fcol] = key_df[fcol].astype(str).str.strip()
            key_df['_key_file_base_lower'] = key_df[fcol].map(_norm_base_lower)

        return key_df, f"Key reloaded from {os.path.basename(key_path)}. {scan['msg']}"
    except Exception as e:
        return None, f"Error reading saved uploaded key: {e}"

def on_reset(_):
    """
    HARD RESET: forget any saved key and show bare-bones Key_Df.
    Deletes _uploaded_key.(xlsx|csv) if present and clears uploaded_key_path.
    """
    import os

    # Forget the in-memory path first.
    globals().pop('uploaded_key_path', None)

    # Remove persisted copies; a file that is already absent is not an error.
    stale_files = [f"_uploaded_key.{ext}" for ext in ("xlsx", "csv")]
    for stale in stale_files:
        try:
            os.remove(stale)
        except FileNotFoundError:
            continue

    # Back to the bare-bones table.
    build_or_rematch_key_df(None)
    rebuild_grid("Reset: cleared saved key; showing bare-bones Key_Df.")

def download_df(_):
    """Click handler: write the current global Key_Df to /content/Key_Df.xlsx
    and start a browser download; progress and errors go to the status box.

    Note: this exports Key_Df as it stands, i.e. grid edits are included only
    after they have been applied to Key_Df.
    """
    global Key_Df

    def _report(text):
        # Replace whatever the status box currently shows.
        with status_box:
            clear_output(wait=True)
            print(text)

    _report("Saving latest Key_Df as XLSX ...")
    try:
        path = "/content/Key_Df.xlsx"
        Key_Df.to_excel(path, index=False, engine='openpyxl')
        colab_files.download(path)
        _report("Saved and downloading Key_Df.xlsx ...")
    except Exception as e:
        _report(f"Error while saving/downloading: {e}")

# Attach the click handlers defined above to their buttons.
upload_btn.on_click(on_colab_upload)
reset_btn.on_click(on_reset)
download_button.on_click(download_df)

# Top row of the key UI: label followed by the three action buttons.
upload_row = widgets.HBox([
    widgets.HTML("<b>Optional key:</b>"),
    upload_btn,
    reset_btn,
    download_button
])

# ---------------------------
# Edit / rematch / download controls
# ---------------------------
# Widgets for editing the table: a name field plus "Add Column", and
# "Apply Changes", which copies the grid's edited data back into Key_Df.
new_col_name    = widgets.Text(placeholder='Enter new column name', description='New Col:')
add_col_button  = widgets.Button(description='Add Column', button_style='info')
apply_button    = widgets.Button(description='Apply Changes', button_style='primary', layout=widgets.Layout(width="120px"))

def add_column(_):
    """Click handler: add an empty-string column named after the text field to
    Key_Df and refresh the grid. A blank or already existing name is rejected
    with a status message and nothing changes."""
    global Key_Df
    requested = new_col_name.value.strip()

    if not requested:
        problem = "Please enter a column name."
    elif requested in Key_Df.columns:
        problem = f"Column '{requested}' already exists."
    else:
        problem = None

    with status_box:
        clear_output(wait=True)
        if problem is not None:
            print(problem)
            return
        Key_Df[requested] = ""
        print(f"Added column '{requested}' to Key_Df.")
    rebuild_grid()

def apply_edits(_):
    """Click handler: copy the grid's current (edited) data into the global
    Key_Df with a fresh 0..n-1 index; failures are reported in the status box."""
    global Key_Df

    def _report(text):
        # Replace whatever the status box currently shows.
        with status_box:
            clear_output(wait=True)
            print(text)

    try:
        edited = grid.data.copy()
        Key_Df = edited.reset_index(drop=True)
        _report(f"Applied grid edits to Key_Df ({len(Key_Df)} rows, {len(Key_Df.columns)} cols).")
    except Exception as e:
        _report(f"Error applying edits: {e}")

# Attach the editing callbacks.
add_col_button.on_click(add_column)
apply_button.on_click(apply_edits)

# Row shown beneath the grid: new-column field, Add Column, Apply Changes.
controls = widgets.HBox([new_col_name, add_col_button, apply_button])

# ---------------------------
# Initialize UI
# ---------------------------
# Start from the bare-bones table every time this cell runs; no key is
# applied until the user uploads one.
Key_Df = _make_base_key_df()
Key_Df["match_status"] = "No key"

# The grid receives a copy, so in-grid edits reach Key_Df only through
# apply_edits().
grid = make_grid(Key_Df.copy().reset_index(drop=True))
ui = widgets.VBox([upload_row, grid, controls, status_box])
display(ui)


In [None]:
# @title Plot individual files
import os
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets

# optional Colab downloader
try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

# Alias of the list filled by the upload cell (same list object, not a copy).
files_list = feds
assert isinstance(files_list, list) and len(files_list) > 0, "No FED3 files loaded."
assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."

# metadata_df = copy of Key_Df
# This is a snapshot taken when the cell runs: later changes to Key_Df are
# not reflected until this cell is run again.
metadata_df = Key_Df.copy().reset_index(drop=True)
if 'filename' in metadata_df.columns:
    # Reduce to base names so rows can be compared with each frame's base name.
    metadata_df['filename'] = metadata_df['filename'].astype(str).map(os.path.basename)

def _coerce_numeric_col(df, col, clip_upper=None, na_map=None):
    if col not in df.columns:
        return
    s = df[col]
    if na_map:
        s = s.replace(na_map)
    s = pd.to_numeric(s, errors='coerce')
    if clip_upper is not None:
        s.loc[s > clip_upper] = np.nan
    df[col] = s

def _plot_file_core(file_index):
    """Plot the P(Left) traces for one loaded file.

    Draws the output of f3b.true_probs (black) and f3b.binned_paction
    (colored: red when the file's 'Sex' value starts with 'f', otherwise
    dodgerblue). The title is the Mouse_ID from metadata_df when available,
    else the file name.

    Parameters
    ----------
    file_index : position of the file in `files_list`.

    Returns
    -------
    (fig, suggested_name) where suggested_name is a filesystem-safe base name
    for saving, or (None, None) when the file has no rows.
    """
    source = files_list[file_index]
    # Read the name from the loaded frame itself (as _status does), not from
    # the copy: DataFrame.copy() only carries over custom attributes that the
    # class declares in `_metadata`.
    full_name = getattr(source, 'name', f"File_{file_index}")
    file_basename = os.path.basename(str(full_name))
    df = source.copy()

    # Nothing to plot; bail out before doing any metadata work.
    if len(df) == 0:
        print(f"[!] Empty dataframe for {file_basename}. Skipping plot.")
        return None, None

    # Preserve original index once
    if "Original_Timestamp" not in df.columns:
        df["Original_Timestamp"] = df.index

    # Attach metadata by filename (matching already done upstream)
    meta_row = None
    if 'filename' in metadata_df.columns:
        mr = metadata_df.loc[metadata_df['filename'] == file_basename]
        if not mr.empty:
            meta_row = mr.iloc[0]
    if meta_row is not None:
        for col in meta_row.index:
            if col == 'filename':
                continue
            if col not in df.columns:
                df[col] = meta_row[col]
            else:
                # Only fill columns that are entirely missing in the data;
                # never overwrite recorded values with key values.
                if pd.isna(df[col]).all() and pd.notna(meta_row[col]):
                    df[col] = meta_row[col]

    # Time + cleanup
    try:
        df['timestamp'] = pd.to_datetime(df.index)
    except Exception:
        # Index is not datetime-like; fall back to row numbers.
        df['timestamp'] = np.arange(len(df))

    _coerce_numeric_col(df, 'Poke_Time', clip_upper=2)
    _coerce_numeric_col(df, 'Retrieval_Time', na_map={"Timed_out": np.nan})

    # Behavioral traces (needs fed3bandit as f3b)
    true_left = f3b.true_probs(df, offset=5)[0]
    mouse_left = f3b.binned_paction(df, window=10)

    # Plot
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(np.arange(len(true_left)), true_left, color="black", linewidth=2, alpha=0.5)

    color = "dodgerblue"
    if 'Sex' in df.columns and pd.notna(df['Sex']).any():
        try:
            color = "red" if str(df['Sex'].iloc[0]).strip().lower().startswith("f") else "dodgerblue"
        except Exception:
            pass
    ax.plot(np.arange(len(mouse_left)), mouse_left, color=color, linewidth=3, alpha=0.7)

    ax.set_ylabel("P(Left)")
    ax.set_xlabel("Trial")
    ax.set_yticks([1, 0.5, 0])

    # Title = Mouse_ID (fallback to filename)
    has_mouse_id = (
        meta_row is not None
        and 'Mouse_ID' in meta_row
        and pd.notna(meta_row['Mouse_ID'])
    )
    title_text = str(meta_row['Mouse_ID']) if has_mouse_id else file_basename
    ax.set_title(title_text)

    sns.despine()
    plt.tight_layout()
    plt.show()

    # suggest a base filename for saving (anything but alnum, '-' and '_' becomes '_')
    safe_title = "".join(c if c.isalnum() or c in ("-", "_") else "_" for c in title_text)
    suggested = f"{safe_title}_pLeft"
    return fig, suggested

# ----- UI -----
N = len(files_list)
# continuous_update=True makes the slider notify its observers while it is
# being dragged, not only on release.
idx_slider = widgets.IntSlider(min=0, max=max(0, N-1), step=1, value=0, description='File', continuous_update=True)
status_lbl = widgets.HTML()
out = widgets.Output()

# save controls
fmt_dd   = widgets.Dropdown(options=["pdf", "png", "svg"], value="pdf", description="Format:", layout=widgets.Layout(width="180px"))
save_btn = widgets.Button(description="Save current plot", button_style="success", layout=widgets.Layout(width="200px"))

# Most recently rendered figure and its suggested base name; written by
# _render() and read by _save_current().
_last_fig = None
_last_name = "plot"

def _status(idx):
    """Return the HTML status line (index and base file name) for file `idx`."""
    label = os.path.basename(str(getattr(files_list[idx], 'name', f"File_{idx}")))
    return f"Index: <b>{idx}</b> &nbsp;|&nbsp; File: <code>{label}</code>"

def _render(*_):
    """Redraw the plot for the slider's current file and remember the figure
    and its suggested name for a later save. Extra arguments (the widget's
    change notification) are ignored."""
    global _last_fig, _last_name
    selected = int(idx_slider.value)
    status_lbl.value = _status(selected)
    out.clear_output()
    with out:
        _last_fig, proposed = _plot_file_core(selected)
        _last_name = proposed if proposed else "plot"

def _save_current(_):
    """Click handler: save the last rendered figure under figures/ with a
    timestamped name in the chosen format, and start a Colab download when
    the Colab helper is available."""
    if _last_fig is None:
        with out:
            print("No plot to save yet.")
        return

    outdir = "figures"
    os.makedirs(outdir, exist_ok=True)

    fmt = fmt_dd.value
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = os.path.join(outdir, f"{_last_name}_{ts}.{fmt}")

    _last_fig.savefig(path, dpi=300, bbox_inches="tight")
    with out:
        print(f"Saved: {path}")

    if colab_files is None:
        return
    colab_files.download(path)

# Connect the callbacks: the button saves, the slider re-renders on change.
save_btn.on_click(_save_current)
idx_slider.observe(_render, names='value')

# NOTE(review): this rebinds the module-level name `controls`, which the
# "Build Key" cell also defines and which its rebuild_grid() reads as a
# global. After this cell has run, a grid rebuild in the key UI picks up this
# HBox instead of the key-editing controls. A cell-specific name would avoid
# the collision.
controls = widgets.HBox([idx_slider, fmt_dd, save_btn], layout=widgets.Layout(gap="10px"))
display(widgets.VBox([controls, status_lbl, out]))

# Initial draw
_render()

In [None]:
# @title Analyze Bandit metrics
import os
import numpy as np
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from IPython.display import FileLink
import tqdm

# metadata_df is created by the "Plot individual files" cell as a copy of
# Key_Df, so that cell (or an equivalent assignment) has to run first.
assert 'metadata_df' in globals() and isinstance(metadata_df, pd.DataFrame), \
    "metadata_df not found. Build/rematch Key_Df and set metadata_df = Key_Df.copy() first."

# ---------- Small helpers ----------
def _basename(pathlike) -> str:
    s = str(pathlike).replace("\\", "/")
    return s.split("/")[-1]

def _file_base_lower(pathlike):
    return os.path.splitext(os.path.basename(str(pathlike)))[0].lower()

def _get_timestamp_series(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    """Return a pandas Series of timestamps for df (never an Index)."""
    if ts_col in df.columns:
        ts = pd.to_datetime(df[ts_col], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        return pd.Series(ts, index=df.index)
    for cand in ["DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            return pd.Series(ts, index=df.index)
    idx = df.index
    if isinstance(idx, pd.DatetimeIndex):
        return pd.Series(idx, index=df.index)
    return pd.to_datetime(pd.Series(idx, index=df.index), errors="coerce")

def _split_day_night(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    """Split df into (day_rows, night_rows) by timestamp hour.

    Day: 07:00–19:00 (hour >= 7 and < 19); Night: every other row that has a
    valid timestamp. Rows without a parseable timestamp are in neither part.
    """
    stamps = _get_timestamp_series(df, ts_col=ts_col)
    has_time = stamps.notna()
    hour_of_day = stamps.dt.hour
    is_day = has_time & (hour_of_day >= 7) & (hour_of_day < 19)
    is_night = has_time & ~is_day
    day_rows = df.loc[is_day]
    night_rows = df.loc[is_night]
    return day_rows, night_rows

def compute_withinbout_lose_shift(c_df, max_gap_s=120):
    """Fraction of poke -> poke transitions in which the side changed.

    A transition counts when a 'Left'/'Right' event is immediately followed,
    within `max_gap_s` seconds, by another 'Left'/'Right' event (i.e. not by a
    'Pellet'). Returns shifted / counted, or NaN when nothing qualifies, the
    'Event' column is missing, there are fewer than two rows, or any error
    occurs.
    """
    pokes = ("Left", "Right")
    try:
        if "Event" not in c_df.columns or len(c_df) < 2:
            return np.nan
        events = c_df["Event"].to_numpy()
        stamps = _get_timestamp_series(c_df).to_numpy()

        n_followups = 0
        n_switches = 0
        for pos, (here, after) in enumerate(zip(events[:-1], events[1:])):
            if here not in pokes:
                continue
            gap_s = (stamps[pos + 1] - stamps[pos]) / np.timedelta64(1, "s")
            if np.isnan(gap_s) or gap_s > max_gap_s:
                continue
            # A following 'Pellet' (or any non-poke event) is not counted.
            if after not in pokes:
                continue
            n_followups += 1
            if after != here:
                n_switches += 1

        if n_followups > 0:
            return n_switches / n_followups
        return np.nan
    except Exception:
        return np.nan

def compute_withinbout_win_stay(c_df, max_gap_s=120):
    """Fraction of rewarded pokes after which the same side was poked again.

    For every 'Pellet' event that has a neighbour on both sides, the poke just
    before it is compared with the event just after it, provided that the next
    event occurs within `max_gap_s` seconds of the pellet and both neighbours
    are 'Left'/'Right'. Returns same / counted, or NaN when nothing qualifies,
    the 'Event' column is missing, there are fewer than three rows, or any
    error occurs.
    """
    pokes = ("Left", "Right")
    try:
        if "Event" not in c_df.columns or len(c_df) < 3:
            return np.nan
        events = c_df["Event"].to_numpy()
        stamps = _get_timestamp_series(c_df).to_numpy()

        pellet_positions = [pos for pos in range(1, len(events) - 1) if events[pos] == "Pellet"]
        n_counted = 0
        n_stays = 0
        for pos in pellet_positions:
            before, after = events[pos - 1], events[pos + 1]
            gap_s = (stamps[pos + 1] - stamps[pos]) / np.timedelta64(1, "s")
            if np.isnan(gap_s) or gap_s > max_gap_s:
                continue
            if before not in pokes or after not in pokes:
                continue
            n_counted += 1
            if after == before:
                n_stays += 1

        if n_counted > 0:
            return n_stays / n_counted
        return np.nan
    except Exception:
        return np.nan

def compute_peak_accuracy(c_df):
    """Mean of the first 10 values returned by f3b.reversal_peh(c_df, (-10, 10), True).

    With fewer than 10 values the mean of all of them is used. Returns NaN for
    an empty result or on any error.
    NOTE(review): presumably the first 10 values are the trials preceding a
    reversal — confirm against the fed3bandit documentation.
    """
    try:
        rev_avg = f3b.reversal_peh(c_df, (-10, 10), True)
        n_points = len(rev_avg)
        if n_points == 0:
            return np.nan
        window = rev_avg[:10] if n_points >= 10 else rev_avg
        return float(np.mean(window))
    except Exception:
        return np.nan

def estimate_daily_pellets(c_df):
    """Extrapolate the session's pellet count to a 24-hour rate.

    The session length is the span between the earliest and latest non-null
    timestamps returned by ``_get_timestamp_series``. The pellet total comes
    from the "Pellet_Count" column when it holds numeric data (sum of its
    positive step-to-step increases, or last minus first when that sum is 0
    and the counter did not decrease overall); otherwise from the number of
    rows whose "Event" equals "Pellet".

    Returns:
        pellets / hours * 24, or NaN when fewer than two valid timestamps
        exist, the span is not positive, or no pellet total can be derived.
    """
    stamps = _get_timestamp_series(c_df).dropna()
    if stamps.size < 2:
        return np.nan
    span_hours = (stamps.max() - stamps.min()).total_seconds() / 3600.0
    if span_hours <= 0:
        return np.nan

    n_pellets = np.nan
    has_counter = "Pellet_Count" in c_df.columns and c_df["Pellet_Count"].notna().any()
    if has_counter:
        counter = pd.to_numeric(c_df["Pellet_Count"], errors="coerce")
        if counter.notna().any():
            # Only upward steps count, so a counter that drops does not subtract pellets.
            increments = counter.diff().fillna(0).clip(lower=0)
            n_pellets = float(increments.sum())
            first, last = counter.iloc[0], counter.iloc[-1]
            if n_pellets == 0 and last >= first:
                n_pellets = float(last - first)
    if pd.isna(n_pellets) and "Event" in c_df.columns:
        n_pellets = float((c_df["Event"] == "Pellet").sum())

    if pd.isna(n_pellets):
        return np.nan
    return (n_pellets / span_hours) * 24.0

# ---------- Sessions ----------
def _get_sessions():
    if 'feds' in globals() and isinstance(feds, (list, tuple)) and len(feds) > 0:
        return list(feds)
    raise RuntimeError("No FED3 sessions found. Expecting a non-empty 'feds' list.")

# Resolve the session list once; _get_sessions() raises RuntimeError when `feds` is unusable.
_sessions = _get_sessions()
# Positional indices into `_sessions` (0 .. len-1).
all_indices = list(range(len(_sessions)))

# ---------- Normalize metadata_df ----------
# Work on a copy so the notebook-level metadata_df is left untouched.
md = metadata_df.copy()
# Reduce every metadata filename to the form produced by `_basename`.
md['filename'] = md['filename'].astype(str).map(_basename)
if 'Mouse_ID' in md.columns:
    # Strip whitespace, but keep missing/blank IDs as real NaN. A plain
    # astype(str) would turn them into the literal text 'nan', which the
    # later `.isna()` / `.dropna()` checks cannot recognise as missing and
    # which would make every ID-less file merge onto the same metadata row.
    _mouse_raw = md['Mouse_ID']
    _mouse_txt = _mouse_raw.astype(str).str.strip()
    md['Mouse_ID'] = _mouse_txt.mask(_mouse_raw.isna() | (_mouse_txt == ""), np.nan)
else:
    md['Mouse_ID'] = np.nan  # ensure column exists

# ---------- Compute metrics for ALL files ----------
# The notebook's first cell binds `tqdm` with `from tqdm.auto import tqdm`
# (the progress-bar callable itself), in which case `tqdm.tqdm(...)` fails.
# Importing the callable under a private alias works whether `tqdm` currently
# names the callable or the package.
from tqdm.auto import tqdm as _progress_bar

# One dict of metrics per session; sessions whose overall metrics fail are skipped.
rows = []
for idx in _progress_bar(all_indices):
    c_df = _sessions[idx]
    # Sessions without a `name` attribute get a positional placeholder name.
    file_name = _basename(getattr(c_df, "name", f"File_{idx}"))

    try:
        pre_acc_all = compute_peak_accuracy(c_df)
        # Missing columns fall back to an empty float Series; non-numeric entries become NaN.
        clean_retrieval_time = pd.to_numeric(c_df.get("Retrieval_Time", pd.Series(dtype=float)), errors="coerce")
        clean_poke_time = pd.to_numeric(c_df.get("Poke_Time", pd.Series(dtype=float)), errors="coerce")
        # Only strictly positive poke durations enter the median.
        clean_poke_time = clean_poke_time[clean_poke_time > 0]

        row = {
            "filename": file_name,
            "PeakAccuracy": pre_acc_all,
            "Total_pellets": f3b.count_pellets(c_df),
            "Total_pokes": f3b.count_pokes(c_df),
            "PokesPerPellet": f3b.pokes_per_pellet(c_df),
            "RetrievalTime": clean_retrieval_time.median() if not clean_retrieval_time.empty else np.nan,
            "PokeTime": clean_poke_time.median() if not clean_poke_time.empty else np.nan,
            "Win-stay": compute_withinbout_win_stay(c_df),
            "Lose-shift": compute_withinbout_lose_shift(c_df),
            "daily pellets": estimate_daily_pellets(c_df),
        }
    except Exception as e:
        print(f"Failed to compute OVERALL metrics for {file_name} (idx {idx}): {e}")
        continue

    # Day/Night splits
    # A failure here only drops the Day/Night columns; the overall row is still kept.
    try:
        day_df, night_df = _split_day_night(c_df, ts_col="MM:DD:YYYY hh:mm:ss")
        row.update({
            "PeakAccuracy_Day": compute_peak_accuracy(day_df),
            "PeakAccuracy_Night": compute_peak_accuracy(night_df),
            "Win-stay_Day": compute_withinbout_win_stay(day_df),
            "Win-stay_Night": compute_withinbout_win_stay(night_df),
            "Lose-shift_Day": compute_withinbout_lose_shift(day_df),
            "Lose-shift_Night": compute_withinbout_lose_shift(night_df),
        })
    except Exception as e:
        print(f"Failed to compute DAY/NIGHT metrics for {file_name} (idx {idx}): {e}")

    rows.append(row)

Banditmetrics = pd.DataFrame(rows)
if Banditmetrics.empty:
    display(HTML("<b style='color:#b00'>No files to analyze.</b>"))
    raise SystemExit

# ---------- Attach Mouse_ID (primary) from metadata_df, then merge all metadata ----------
# Merging priority:
# 1) Map filename -> Mouse_ID from metadata_df
# Series.map needs a uniquely indexed lookup; a filename listed more than once
# in the metadata would otherwise raise, so keep its first row.
mouse_map = (
    md.drop_duplicates(subset=['filename'], keep='first')
      .set_index('filename')['Mouse_ID']
)
Banditmetrics['Mouse_ID'] = Banditmetrics['filename'].map(mouse_map)

# 2) Fallback: if still missing Mouse_ID, try substring match on metadata Mouse_IDs
if Banditmetrics['Mouse_ID'].isna().any():
    # Blank or placeholder IDs are excluded: an empty string is a substring of
    # every filename and 'nan'/'none' are stringified missing values.
    known_ids = [
        mid for mid in md['Mouse_ID'].dropna().unique().tolist()
        if str(mid).strip().lower() not in {"", "nan", "none"}
    ]
    for i, row in Banditmetrics.loc[Banditmetrics['Mouse_ID'].isna()].iterrows():
        base = _file_base_lower(row['filename'])
        hits = [mid for mid in known_ids if str(mid).lower() in base]
        if len(hits) == 1:
            Banditmetrics.at[i, 'Mouse_ID'] = hits[0]
        elif len(hits) > 1:
            # Several IDs match: accept the longest one, but only if it is unique.
            longest = max(len(str(h)) for h in hits)
            best = [h for h in hits if len(str(h)) == longest]
            if len(best) == 1:
                Banditmetrics.at[i, 'Mouse_ID'] = best[0]

# 3) Merge metadata: prefer Mouse_ID, fallback to filename if no Mouse_ID match
# One metadata row per Mouse_ID (first occurrence wins).
md_mouse_unique = md.drop_duplicates(subset=['Mouse_ID'], keep='first')
# Left merge keeps every metrics row; metadata columns whose names collide
# with metrics columns receive the `_md` suffix.
bm = Banditmetrics.merge(md_mouse_unique, on='Mouse_ID', how='left', suffixes=('', '_md'))

# A row needs the filename fallback when it has no Mouse_ID or when all of its
# `_md` columns are empty.
needs_fallback = bm['Mouse_ID'].isna() | bm.filter(regex='_md$').isna().all(axis=1)
if needs_fallback.any():
    md_file_unique = md.drop_duplicates(subset=['filename'], keep='first')
    # NOTE(review): the boolean mask is indexed like `bm`; applying it to
    # Banditmetrics assumes the merge above kept one row per metrics row with
    # the same index labels -- confirm.
    bm_fb = Banditmetrics.loc[needs_fallback].merge(
        md_file_unique, on='filename', how='left', suffixes=('', '_md')
    )
    # Fallback rows are appended after the rows matched by Mouse_ID; the index is rebuilt.
    bm = pd.concat([bm.loc[~needs_fallback], bm_fb], ignore_index=True)

# ---------- Session-type suffixing ----------
# Base metric names that get a per-session-type suffix in the exported CSV.
metric_cols = [
    "Win-stay", "Lose-shift", "PeakAccuracy", "Total_pellets", "Total_pokes",
    "PokesPerPellet", "RetrievalTime", "PokeTime", "daily pellets",
    "Win-stay_Day", "Win-stay_Night", "Lose-shift_Day", "Lose-shift_Night",
    "PeakAccuracy_Day", "PeakAccuracy_Night",
]

# Prefer metadata_df's Session_type; fallback to df.attrs or "Unknown"
if 'Session_type' in bm.columns:
    _sess_raw = bm['Session_type']
    session_series = _sess_raw.astype(str).str.strip()
    # Missing or blank values would otherwise become the suffix "nan" (or an
    # empty suffix); label them "Unknown" like the other branch does.
    session_series = session_series.mask(_sess_raw.isna() | (session_series == ""), "Unknown")
else:
    sess_map = {
        _basename(getattr(_sessions[i], "name", f"File_{i}")):
        (_sessions[i].attrs.get("Session_type") or "Unknown")
        for i in range(len(_sessions))
    }
    session_series = bm["filename"].map(sess_map).fillna("Unknown").astype(str)

# Whitespace runs become underscores so the value can be used inside a column name.
session_series = session_series.str.replace(r"\s+", "_", regex=True)
bm["_Session_type_for_csv"] = session_series

def with_session_suffix_for_csv(df, file_col="File", metrics=metric_cols, session_col="_Session_type_for_csv"):
    """Replace each metric column by one ``<metric>_<session>`` column per session type.

    For every name in ``metrics`` that is a column of ``df``, each distinct
    non-null value of ``session_col`` gets a column ``f"{metric}_{session}"``
    that holds the metric for rows of that session type and NaN elsewhere.
    The original metric column and ``session_col`` are removed.

    Args:
        df: table to reshape; it is copied, not modified.
        file_col: accepted for interface compatibility; not used by the body.
        metrics: base metric column names to split.
        session_col: column holding the session-type label.

    Returns:
        The reshaped copy of ``df``.
    """
    out = df.copy()
    for metric in metrics:
        if metric in out.columns:
            for sess in out[session_col].dropna().unique():
                target = f"{metric}_{sess}"
                rows_for_sess = out[session_col] == sess
                if target not in out.columns:
                    out[target] = np.nan
                out.loc[rows_for_sess, target] = out.loc[rows_for_sess, metric]
            out = out.drop(columns=[metric])
    return out.drop(columns=[session_col])

Banditmetrics_merged = bm.copy()  # optional to inspect in notebook
Banditmetrics_csv = with_session_suffix_for_csv(Banditmetrics_merged)

# ---------- Minimalist CSV: ID column + Session-type–suffixed metrics only ----------

def _norm_name(s: str) -> str:
    return re.sub(r'[^a-z0-9]+', '', str(s).lower())

# Decide whether the primary ID is Mouse_ID or filename
# KEY_MATCH_MODE, when set in the notebook, chooses the ID column explicitly.
match_mode = globals().get('KEY_MATCH_MODE', None)

if match_mode == 'filename':
    id_col = 'filename'
elif match_mode == 'mouse_id':
    id_col = 'Mouse_ID'
# Fallback if KEY_MATCH_MODE wasn't set for some reason:
# use Mouse_ID when it has at least one value, otherwise filename.
elif 'Mouse_ID' in Banditmetrics_csv.columns and Banditmetrics_csv['Mouse_ID'].notna().any():
    id_col = 'Mouse_ID'
elif 'filename' in Banditmetrics_csv.columns:
    id_col = 'filename'
else:
    raise ValueError("Neither 'Mouse_ID' nor 'filename' found in Banditmetrics_csv.")

# Metric columns: any Session_type–suffixed metric (e.g., Win-stay_FR1, PeakAccuracy_Probe)
def _metric_match(col: str) -> bool:
    """Return True when ``col`` starts with ``<base>_`` for some base name in ``metric_cols``."""
    prefixes = tuple(base + "_" for base in metric_cols)
    return col.startswith(prefixes)

# Columns of the suffixed table that carry a session-suffixed metric.
metric_keep = list(filter(_metric_match, Banditmetrics_csv.columns))

if len(metric_keep) == 0:
    raise RuntimeError("No session-suffixed metric columns matched; check 'metric_cols' and your column names.")

# Final order: ID column first, then all metric columns
cols_out = [id_col, *(c for c in metric_keep if c != id_col)]
Banditmetrics_csv = Banditmetrics_csv.loc[:, cols_out]

# ---------- Save CSV (new naming: strain_strainnum_task_L3.csv) ----------
# Use Gene / Gene_ID / Session_type from merged metadata
# The first merged row supplies the strain / task labels used in the file name.
example = Banditmetrics_merged.iloc[0]

# Strain name (Gene or Strain)
strain_name = str(
    example.get("Gene", example.get("Strain", "NA"))
).replace(" ", "_")

# Strain number (Gene_ID or Strain_ID), zero-padded to 3 digits
strain_num_raw = example.get("Gene_ID", example.get("Strain_ID", "NA"))
try:
    strain_num = f"{int(strain_num_raw):03d}"
except Exception:
    # Non-integer IDs are kept as text, left-padded with zeros to three characters.
    strain_num = str(strain_num_raw).zfill(3)

# Task (Session_type)
task_name = str(example.get("Session_type", "Unknown")).replace(" ", "_")

fname = f"{strain_name}_{strain_num}_{task_name}_L3.csv"

Banditmetrics_csv.to_csv(fname, index=False)
display(HTML(f"<b>✓ Saved metrics CSV to:</b> <code>{fname}</code>"))
try:
    # FileLink is not among the names the notebook's first cell imports from
    # IPython.display; import it here so the link is actually rendered instead
    # of being skipped by the except clause.
    from IPython.display import FileLink
    display(FileLink(fname))
except Exception:
    # The link is a convenience only; the CSV has already been written.
    pass

# ---------- Download button ----------
# Button that triggers the browser download of the CSV written above.
btn = widgets.Button(
    description=f"Download {os.path.basename(fname)}",
    icon="download",
    tooltip="Click to download the metrics CSV",
    layout=widgets.Layout(width="auto"),
)
# HTML area under the button used for progress / error messages.
status = widgets.HTML()

def _on_click(b):
    """Click handler: redraw the button and status line, then try a Colab download of `fname`."""
    clear_output(wait=True)
    display(btn, status)
    if not os.path.exists(fname):
        status.value = f"<b style='color:#b00'>File not found:</b> {fname}"
        return
    try:
        # Imported lazily so the handler can still report a message outside Colab.
        from google.colab import files as gfiles
        status.value = f"Starting download: <code>{os.path.basename(fname)}</code>…"
        gfiles.download(fname)
    except Exception:
        # Any failure (import or download) is reported as "not in Colab".
        status.value = (
            "Not running in Colab. File saved locally at:<br>"
            f"<code>{fname}</code><br>"
            "Use the link above to open it."
        )

display(btn, status)
btn.on_click(_on_click)


In [None]:
# @title Group for plotting

import os
import numpy as np
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- sanity ---
# Stop early when the metadata table is missing, None, or empty.
if 'metadata_df' not in globals() or metadata_df is None or metadata_df.empty:
    raise RuntimeError("metadata_df is missing or empty. Build metadata_df (copy of Key_Df) first.")

# Lower-cased column names that must not be offered as grouping columns.
EXCLUDE_LOWER = {"match_status"}   # everything else is allowed

def _build_file_column(df):
    if "filename" in df.columns:
        return df["filename"].apply(lambda p: os.path.basename(str(p)))
    if "FED3_from_file" in df.columns and "Date_from_file" in df.columns:
        return "FED" + df["FED3_from_file"].astype(str) + "_" + df["Date_from_file"].astype(str)
    if "FED3_from_file" in df.columns:
        return "FED" + df["FED3_from_file"].astype(str)
    return df.index.astype(str)

def _norm_val(x):
    s = str(x).strip()
    if s == "" or s.lower() in {"nan", "none"}:
        return "UNK"
    return s.upper()

def _build_group_row(row, ordered_cols):
    if not ordered_cols:
        return "ALL"
    return " | ".join(_norm_val(row[c]) for c in ordered_cols)

def build_mapping(ordered_cols):
    """Build a de-duplicated (filename, Group) table from ``metadata_df``.

    "filename" comes from ``_build_file_column`` and "Group" from
    ``_build_group_row`` applied row-wise with ``ordered_cols``. The result is
    sorted by Group, then filename, with a fresh index.
    """
    meta = metadata_df.copy()
    meta["filename"] = _build_file_column(meta)
    meta["Group"] = meta.apply(lambda rec: _build_group_row(rec, ordered_cols), axis=1)
    pairs = meta[["filename", "Group"]].dropna(subset=["filename"]).drop_duplicates()
    return pairs.sort_values(["Group", "filename"]).reset_index(drop=True)

def _unique_keep_order(seq):
    seen = set(); out = []
    for x in seq:
        if x not in seen:
            seen.add(x); out.append(x)
    return out

# ---------- UI (fixed sizes + grid) ----------
PX_W = "260px"   # list box width
PX_H = "160px"   # list box height
BTN_W = "160px"  # button column width
HDR_H = "28px"   # header cell height (consistent across all headers)

title = widgets.HTML("<h3>Select columns to group by for X and Hue, then reorder X to set hierarchy</h3>")

# Metadata columns offered for grouping, sorted case-insensitively.
# The sort key coerces with str() just like the filter does, so a non-string
# column label cannot make the sort raise.
all_cols = sorted(
    (c for c in metadata_df.columns if str(c).lower() not in EXCLUDE_LOWER),
    key=lambda c: str(c).lower(),
)

def header(text):
    """Return an HTML widget showing ``text`` as an <h4> inside a fixed-height, bottom-aligned box."""
    # The fixed height (HDR_H) and zero margin keep all headers aligned in the grid row.
    markup = (
        f"<div style='height:{HDR_H};display:flex;align-items:flex-end;'>"
        f'<h4 style="margin:0;">{text}</h4></div>'
    )
    return widgets.HTML(markup)

# Headers (row 1 of grid)
# One header widget per grid column.
available_hdr = header("Available")
actions_hdr   = header("Actions")
x_hdr         = header("X grouping")
hue_hdr       = header("Hue grouping")

# Widgets (row 2 of grid)
# Multi-select list of all candidate metadata columns; width/height are pinned
# through identical min/max values.
available = widgets.SelectMultiple(
    options=all_cols, value=tuple(), rows=14,
    layout=widgets.Layout(
        width=PX_W, height=PX_H, min_width=PX_W, max_width=PX_W,
        min_height=PX_H, max_height=PX_H, flex="0 0 auto"
    )
)

# Ordered list of columns chosen for the X grouping (starts empty).
right_x = widgets.Select(
    options=[], value=None, rows=8,
    layout=widgets.Layout(
        width=PX_W, height=PX_H, min_width=PX_W, max_width=PX_W,
        min_height=PX_H, max_height=PX_H, flex="0 0 auto"
    )
)

# List of columns chosen for the Hue grouping (starts empty).
right_hue = widgets.Select(
    options=[], value=None, rows=8,
    layout=widgets.Layout(
        width=PX_W, height=PX_H, min_width=PX_W, max_width=PX_W,
        min_height=PX_H, max_height=PX_H, flex="0 0 auto"
    )
)

# Buttons
btn_add_x    = widgets.Button(description="Add to X ▶", button_style='primary', layout=widgets.Layout(width=BTN_W))
btn_add_hue  = widgets.Button(description="Add to Hue ▶",button_style='primary', layout=widgets.Layout(width=BTN_W))
btn_clear    = widgets.Button(description="Clear", button_style='danger', layout=widgets.Layout(width=BTN_W))
btn_up       = widgets.Button(description="↑ Up (X only)", layout=widgets.Layout(width=BTN_W))
btn_down     = widgets.Button(description="↓ Down (X only)", layout=widgets.Layout(width=BTN_W))
# Created exactly once; it is not part of `controls_col`.
btn_build    = widgets.Button(description="Build Groups", button_style='success', layout=widgets.Layout(width=BTN_W))

# Vertical stack of the action buttons, sized to match the list boxes' height.
controls_col = widgets.VBox(
    [btn_add_x, btn_add_hue, btn_clear, btn_up, btn_down],
    layout=widgets.Layout(
        align_items="center",
        width=BTN_W, min_width=BTN_W, max_width=BTN_W,
        height=PX_H, min_height=PX_H, max_height=PX_H,
        flex="0 0 auto"
    )
)

# Output area for summaries.
output = widgets.Output()

# --- Callbacks ---
def on_add_x(_):
    """Append the columns selected in `available` to the X list (no repeats) and select the last entry."""
    picked = list(available.value)
    if picked:
        merged = _unique_keep_order([*right_x.options, *picked])
        # Clear the selection before swapping options, then select the last entry.
        right_x.value = None
        right_x.options = merged
        right_x.value = merged[-1] if merged else None

def on_add_hue(_):
    """Append the columns selected in `available` to the Hue list (no repeats) and select the last entry."""
    picked = list(available.value)
    if picked:
        merged = _unique_keep_order([*right_hue.options, *picked])
        # Clear the selection before swapping options, then select the last entry.
        right_hue.value = None
        right_hue.options = merged
        right_hue.value = merged[-1] if merged else None

def on_clear(_):
    """Empty both the X and the Hue selection lists."""
    for box in (right_x, right_hue):
        box.value = None
        box.options = []

def on_up(_):
    """Move the selected X column one position towards the top of the list."""
    chosen = right_x.value
    if chosen is None:
        return
    order = list(right_x.options)
    pos = order.index(chosen)
    if pos == 0:
        return  # already first
    order[pos - 1], order[pos] = order[pos], order[pos - 1]
    right_x.value = None
    right_x.options = order
    right_x.value = chosen

def on_down(_):
    """Move the selected X column one position towards the bottom of the list."""
    chosen = right_x.value
    if chosen is None:
        return
    order = list(right_x.options)
    pos = order.index(chosen)
    if pos >= len(order) - 1:
        return  # already last
    order[pos + 1], order[pos] = order[pos], order[pos + 1]
    right_x.value = None
    right_x.options = order
    right_x.value = chosen

def on_build(_):
    """Build the X / Hue / combined file-to-group tables, publish them as globals, and print summaries.

    Globals written: files_to_group_x, files_to_group_hue, files_to_group_both,
    selected_group_cols_x, selected_group_cols_hue.
    """
    with output:
        clear_output()
        # Current order of the two list boxes defines the grouping hierarchy.
        ordered_cols_x = list(right_x.options)
        ordered_cols_hue = list(right_hue.options)

        mapping_x = build_mapping(ordered_cols_x)
        mapping_hue = build_mapping(ordered_cols_hue)

        # Combined table: one row per distinct (filename, XGroup, HueGroup).
        _meta = metadata_df.copy()
        _meta["filename"] = _build_file_column(_meta)
        _meta["XGroup"] = _meta.apply(lambda r: _build_group_row(r, ordered_cols_x), axis=1)
        _meta["HueGroup"] = _meta.apply(lambda r: _build_group_row(r, ordered_cols_hue), axis=1)
        mapping_both = (
            _meta[["filename", "XGroup", "HueGroup"]]
            .dropna(subset=["filename"])
            .drop_duplicates()
            .sort_values(["XGroup", "HueGroup", "filename"])
            .reset_index(drop=True)
        )

        # Publish copies so later cells can read them without sharing state with this handler.
        globals()['files_to_group_x'] = mapping_x.copy()
        globals()['files_to_group_hue'] = mapping_hue.copy()
        globals()['files_to_group_both'] = mapping_both.copy()
        globals()['selected_group_cols_x'] = ordered_cols_x.copy()
        globals()['selected_group_cols_hue'] = ordered_cols_hue.copy()

        # Per-group counts of distinct files, largest group first.
        print("X-axis grouping (hierarchy):", ordered_cols_x if ordered_cols_x else ["ALL"])
        print(f"Total unique files (X map): {mapping_x['filename'].nunique()}")
        display(widgets.HTML("<b>X-group summary</b>"))
        display((mapping_x.groupby("Group", dropna=False)["filename"]
                 .nunique().sort_values(ascending=False)
                 .rename("UniqueFiles").to_frame()))

        print("\nHue grouping:", ordered_cols_hue if ordered_cols_hue else ["ALL"])
        print(f"Total unique files (Hue map): {mapping_hue['filename'].nunique()}")
        display(widgets.HTML("<b>Hue-group summary</b>"))
        display((mapping_hue.groupby("Group", dropna=False)["filename"]
                 .nunique().sort_values(ascending=False)
                 .rename("UniqueFiles").to_frame()))
        print("\nCombined mapping available as `files_to_group_both` (filename, XGroup, HueGroup)")

# Wire up
# Connect each button to its handler.
btn_add_x.on_click(on_add_x)
btn_add_hue.on_click(on_add_hue)
btn_clear.on_click(on_clear)
btn_up.on_click(on_up)
btn_down.on_click(on_down)
btn_build.on_click(on_build)

# ----- Grid layout -----
# Four fixed-width columns; children are listed row by row.
grid = widgets.GridBox(
    children=[
        available_hdr, actions_hdr, x_hdr, hue_hdr,     # row 1: headers
        available,     controls_col, right_x, right_hue # row 2: widgets
    ],
    layout=widgets.Layout(
        grid_template_columns=f"{PX_W} {BTN_W} {PX_W} {PX_W}",
        grid_template_rows="auto auto",
        grid_gap="6px 16px",
        align_items="flex-start",
        justify_items="flex-start",
        width="100%"
    )
)

# Full widget: title, grid, the build button on its own row, then the output area.
ui = widgets.VBox([title, grid, widgets.HBox([btn_build]), output])
display(ui)

In [None]:
# @title Plot metrics!

import os, time, shutil, re, itertools
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Stats
import pingouin as pg
import statsmodels.api as sm
from statsmodels.formula.api import ols

# Optional Colab download
try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

ALPHA = 0.6  # apply to both bars and dots

# -----------------------
# 0) Preconditions & source
# -----------------------
# Pick the first non-empty metrics table in this order:
# Banditmetrics_merged, Banditmetrics, Banditmetrics_csv. Always work on a copy.
if 'Banditmetrics_merged' in globals() and Banditmetrics_merged is not None and not Banditmetrics_merged.empty:
    bm = Banditmetrics_merged.copy()
elif 'Banditmetrics' in globals() and Banditmetrics is not None and not Banditmetrics.empty:
    bm = Banditmetrics.copy()
elif 'Banditmetrics_csv' in globals() and Banditmetrics_csv is not None and not Banditmetrics_csv.empty:
    bm = Banditmetrics_csv.copy()
else:
    raise RuntimeError("No metrics table found. Run the metrics cell first.")

# Accept a "File" column as a stand-in when "filename" is absent.
if "filename" not in bm.columns:
    if "File" in bm.columns:
        bm["filename"] = bm["File"].astype(str)
    else:
        raise RuntimeError("Metrics table must include a 'filename' column.")

# -----------------------
# Merge in XGroup/HueGroup from grouping widget
# -----------------------
def _basename_col(s):
    return os.path.basename(str(s))

def _src_name(df):
    if "filename" in df.columns:
        return "filename"
    return None

# Attach XGroup / HueGroup to the metrics table unless both are already present.
if ("XGroup" not in bm.columns) or ("HueGroup" not in bm.columns):
    if 'files_to_group_both' in globals() and files_to_group_both is not None and not files_to_group_both.empty:
        m = files_to_group_both.copy()
        m_src = _src_name(m)
        if m_src is None:
            raise RuntimeError("Grouping table must include 'filename' (or legacy 'File').")
        # Join on the bare file name so differing directory prefixes do not prevent a match.
        m["file_base"]  = m[m_src].apply(_basename_col)
        bm["file_base"] = bm["filename"].apply(_basename_col)
        bm = bm.merge(m[["file_base","XGroup","HueGroup"]], on="file_base", how="left").drop(columns=["file_base"])
        bm["XGroup"]   = bm["XGroup"].fillna("UNASSIGNED")
        bm["HueGroup"] = bm["HueGroup"].fillna("UNASSIGNED")
    elif 'files_to_group' in globals() and files_to_group is not None and not files_to_group.empty:
        # Legacy single-group mapping: its Group becomes XGroup, Hue is a constant.
        m = files_to_group.copy()
        m_src = _src_name(m)
        if m_src is None:
            raise RuntimeError("Grouping table must include 'filename' (or legacy 'File').")
        m["file_base"]  = m[m_src].apply(_basename_col)
        bm["file_base"] = bm["filename"].apply(_basename_col)
        # Merge under a temporary name so a pre-existing "Group" column in the
        # metrics table cannot collide with the mapping's column.
        _legacy = m[["file_base","Group"]].rename(columns={"Group": "_legacy_group"})
        bm = bm.merge(_legacy, on="file_base", how="left").drop(columns=["file_base"])
        # Fill gaps in the MERGED values. Assigning the mapping table's own
        # column here would align on row labels, not on files, and would
        # overwrite the merge result with unrelated groups.
        bm["Group"] = bm.pop("_legacy_group").fillna("UNASSIGNED")
        bm["XGroup"] = bm["Group"]
        bm["HueGroup"] = "ALL"
    else:
        raise RuntimeError("Missing X/Hue mapping. Run the grouping widget (Build Groups) first.")

# -----------------------
# 1) Melt to long format
# -----------------------
# Base names of the metrics this cell knows how to plot.
base_metric_names = [
    "Win-stay","Lose-shift","PeakAccuracy",
    "Total_pellets","Total_pokes","PokesPerPellet",
    "RetrievalTime","PokeTime","daily pellets",
    "Win-stay_Day","Win-stay_Night","Lose-shift_Day","Lose-shift_Night",
    "PeakAccuracy_Day","PeakAccuracy_Night",
]

# Numeric columns that are a base metric or a base metric followed by "_<suffix>".
metric_cols = []
for c in bm.columns:
    if pd.api.types.is_numeric_dtype(bm[c]):
        for base in base_metric_names:
            if c == base or c.startswith(base + "_"):
                metric_cols.append(c); break
# Drop repeats while keeping first-seen order.
seen = set(); metric_cols = [c for c in metric_cols if not (c in seen or seen.add(c))]
if not metric_cols:
    raise RuntimeError("No numeric metric columns found among expected Bandit metrics.")

# Identifier columns carried through the melt. "filename" (previously
# misspelled "flename", so it never matched) is listed last, which is the
# position the required-column loop below used to give it.
candidate_id_vars = ["Genotype","Sex","Strain","Start_Date","Mouse_ID","Session_type","XGroup","HueGroup","filename"]
id_vars = [c for c in candidate_id_vars if c in bm.columns]
for need in ["XGroup","HueGroup","filename"]:
    if need not in id_vars: id_vars.append(need)

# Long format: one row per (file, metric) with the metric name in "variable".
long_df = pd.melt(
    bm,
    id_vars=id_vars,
    value_vars=metric_cols,
    var_name="variable",
    value_name="value"
)

# -----------------------
# 2) Ordering helpers
# -----------------------

# Define hue order priority
# Hue labels listed here are ordered by list position (matched case-insensitively
# by _order_hue_groups); any other label sorts after them.
HUE_PRIORITY = ["Female", "Male", "F", "M", "Day", "Night", "Light", "Dark", "ALL", "UNASSIGNED"]

def _is_wt_group(g):
    u = str(g).strip().upper()
    tokens = [t for t in re.split(r'[^A-Z0-9]+', u) if t]
    WT_ALIASES = {"WT", "WILDTYPE", "CONTROL", "CTRL"}
    return any(t in WT_ALIASES for t in tokens)

def _hier_sort_key(g):
    """Sort key for an X-group label: wild-type groups first, then level by level.

    The key is (wt_rank, (is_blank, LEVEL)..., LABEL): wt_rank is 0 when the
    label or any of its levels is a wild-type alias, and within a level
    unassigned tokens sort after assigned ones.
    """
    levels = _x_levels(g)
    per_level = tuple(
        (1 if _is_unassigned_token(tok) else 0, str(tok).upper()) for tok in levels
    )
    has_wt = any(_is_wt_group(tok) for tok in levels) or _is_wt_group(g)
    return (0 if has_wt else 1,) + per_level + (str(g).upper(),)

def _is_unassigned_token(s):
    return (str(s).strip().upper() in {"", "UNASSIGNED", "NONE", "NA", "N/A"})

def _x_levels(xname):
    s = str(xname)
    parts = [p.strip() for p in s.split("|")]
    wanted = globals().get("selected_group_cols_x", None)
    if isinstance(wanted, (list, tuple)) and wanted:
        if len(parts) < len(wanted):
            parts += [""] * (len(wanted) - len(parts))
        else:
            parts = parts[:len(wanted)]
    return parts

def _order_x_groups(groups):
    return sorted(groups, key=_hier_sort_key)

def _choose_ref_group(order):
    for g in order:
        if _is_wt_group(g):
            return g
    return order[0] if order else None

def _order_hue_groups(hues):
    hp = globals().get("HUE_PRIORITY", ["Female", "Male", "F", "M", "ALL", "UNASSIGNED"])
    hp_lower = [p.lower() for p in hp]
    def _prio(h):
        u = str(h).strip()
        try:
            return (0, hp_lower.index(u.lower()), u.upper())
        except ValueError:
            return (1, u.upper())
    return sorted([h for h in hues if h is not None], key=_prio)

# -----------------------
# 3) Controls (left column: groups & colors)
# -----------------------
# Default bar colours, assigned to X groups in order and reused cyclically.
named_defaults = [
    "dodgerblue", "red", "green", "orange", "purple",
    "brown", "pink", "gray", "olive", "cyan"]

# Build the ordered list of XGroup levels
all_x_groups = long_df["XGroup"].dropna().unique().tolist()
if not all_x_groups:
    raise RuntimeError("No XGroup values found – check that the grouping step ran correctly.")

ordered_x = _order_x_groups(all_x_groups)

# Per-group widgets keyed by group label: include checkbox and colour text box.
x_checks, x_colors = {}, {}
group_rows = []
for i, g in enumerate(ordered_x):
    chk = widgets.Checkbox(value=True, description=g, indent=False, layout=widgets.Layout(width="260px"))
    col = widgets.Text(value=named_defaults[i % len(named_defaults)],
                       layout=widgets.Layout(width="120px"))
    x_checks[g] = chk
    x_colors[g] = col
    # more compact row
    group_rows.append(widgets.HBox([chk, widgets.Label(""), col],
                                   layout=widgets.Layout(align_items="center", height="28px")))

picker = widgets.VBox(group_rows, layout=widgets.Layout(gap="2px"))

# Bulk select / clear buttons for the group checkboxes.
btn_all  = widgets.Button(description="Select all", layout=widgets.Layout(width="140px"))
btn_none = widgets.Button(description="Clear", layout=widgets.Layout(width="140px"))
def _set_all(val):
    """Set every group checkbox to ``val``."""
    for c in x_checks.values(): c.value = val
btn_all.on_click(lambda _: _set_all(True))
btn_none.on_click(lambda _: _set_all(False))

# scrollable container for the (possibly long) group list
picker_container = widgets.Box([picker],
    layout=widgets.Layout(overflow="auto", max_height="420px",
                          border="1px solid #ddd", padding="6px", width="360px"))

# Left column: heading, bulk buttons, then the scrollable group list.
left_col = widgets.VBox([
    widgets.HTML("<b>Groups & Colors</b>"),
    widgets.HBox([btn_all, btn_none], layout=widgets.Layout(gap="8px")),
    picker_container
], layout=widgets.Layout(width="380px"))

# -----------------------
# 4) Comparison controls (right column)
# -----------------------
# Comparison mode: "ref" (against a reference group) or "pairs" (user-selected pairs).
mode_radio = widgets.ToggleButtons(
    options=[("Reference group", "ref"), ("Select Pairs", "pairs")],
    value="ref", description="", style={"button_width":"150px"},
    layout=widgets.Layout(width="320px")
)

# Reference group selector; starts on the group chosen by _choose_ref_group.
ref_dropdown = widgets.Dropdown(
    options=ordered_x, value=_choose_ref_group(ordered_x),
    description="Reference:", layout=widgets.Layout(width="320px")
)

def _pair_label(a,b): return f"{a} ⟷ {b}"
def _pair_value(a,b): return (a,b) if a <= b else (b,a)

# Multi-select list of group pairs; starts empty and is filled by _update_ref_and_pairs.
pairs_select = widgets.SelectMultiple(
    options=[], value=[], description="Pairs",
    layout=widgets.Layout(width="360px", height="320px")
)

def _selected_x():
    """Return the labels of all ticked X groups, in hierarchical order."""
    ticked = [name for name, box in x_checks.items() if box.value]
    return _order_x_groups(ticked)

def _pair_sort_key(a, b):
    """Sort key for a pair of X-group labels.

    Pairs are ranked by the level at which their labels first differ: a later
    first difference (groups that share more leading levels) sorts earlier.
    Ties are broken by the shared prefix, then by the two full level tuples.
    """
    levels_a, levels_b = _x_levels(a), _x_levels(b)
    depth = max(len(levels_a), len(levels_b))
    # Pad the shorter label with blanks so both have the same number of levels.
    levels_a = levels_a + [""] * (depth - len(levels_a))
    levels_b = levels_b + [""] * (depth - len(levels_b))
    first_diff = depth
    for i, (x, y) in enumerate(zip(levels_a, levels_b)):
        if x != y:
            first_diff = i
            break
    return (-first_diff, tuple(levels_a[:first_diff]), tuple(levels_a), tuple(levels_b))

def _update_ref_and_pairs(*_):
    """Refresh the reference dropdown and the pair list after the group selection changes.

    The dropdown is limited to the ticked groups (a "—" placeholder when none
    are ticked). The previously chosen reference is kept whenever it is still
    ticked; otherwise ``_choose_ref_group`` picks a new one. The pair list is
    rebuilt from all 2-combinations of the ticked groups, ordered by
    ``_pair_sort_key``.
    """
    sel = _selected_x()
    # Remember the user's choice before replacing the options.
    # NOTE(review): ipywidgets selection widgets are expected to re-select on
    # their own when `options` is replaced -- confirm for the installed version.
    previous_ref = ref_dropdown.value
    ref_dropdown.options = sel or ["—"]
    if sel:
        ref_dropdown.value = previous_ref if previous_ref in sel else _choose_ref_group(sel)
    else:
        ref_dropdown.value = None

    opts = [(_pair_label(a, b), _pair_value(a, b)) for a, b in itertools.combinations(sel, 2)]
    opts.sort(key=lambda kv: _pair_sort_key(*kv[1]))
    pairs_select.options = opts

# Re-derive the reference and pair choices whenever a group checkbox changes,
# and once now to populate them.
for cb in x_checks.values():
    cb.observe(_update_ref_and_pairs, names="value")
_update_ref_and_pairs()

# action buttons
plot_btn = widgets.Button(description="Plot", button_style="primary",
                          layout=widgets.Layout(width="160px"))
save_btn = widgets.Button(description="Save Plots", button_style="success",
                          layout=widgets.Layout(width="160px"))

# Right column layout
right_col = widgets.VBox([
    widgets.HTML("<b>Statistical comparisons</b>"),
    mode_radio,
    ref_dropdown,
    pairs_select,
    widgets.HBox([plot_btn, save_btn], layout=widgets.Layout(gap="8px"))
], layout=widgets.Layout(width="360px"))

# -----------------------
# 5) Labels for stats/legend
# -----------------------
def _grouping_label(which="X"):
    if which.lower().startswith("x"):
        cols = globals().get("selected_group_cols_x", [])
        default = "XGroup"
    else:
        cols = globals().get("selected_group_cols_hue", [])
        default = "HueGroup"
    cols = [str(c).strip() for c in (cols or []) if str(c).strip()]
    return " | ".join(cols) if cols else default

# -----------------------
# 6) Stats helpers (ANOVA with Hue)
# -----------------------
def _fmt_p(p):
    if not np.isfinite(p): return "n/a"
    return f"p = {p:.3f}" if p >= 0.001 else "p < 0.001"

def _anova_subset(df):
    """
    If >=2 Hue levels: two-way ANOVA (XGroup, HueGroup, interaction)
    Else: one-way ANOVA (XGroup).
    """
    out = {"p_x": np.nan, "p_h": np.nan, "p_int": np.nan, "n_h": 0, "ok": False, "err": None}
    d = df.dropna(subset=["value","XGroup"])
    if d.empty or d["XGroup"].nunique() < 2:
        out["err"] = "Too few groups"; return out
    n_h = d["HueGroup"].nunique(dropna=True); out["n_h"] = n_h
    try:
        if n_h >= 2:
            model = ols('value ~ C(XGroup) + C(HueGroup) + C(XGroup):C(HueGroup)', data=d).fit()
            an = sm.stats.anova_lm(model, typ=2)
            out["p_x"]   = float(an.loc['C(XGroup)','PR(>F)'])
            out["p_h"]   = float(an.loc['C(HueGroup)','PR(>F)'])
            out["p_int"] = float(an.loc['C(XGroup):C(HueGroup)','PR(>F)'])
            out["ok"] = True
        else:
            model = ols('value ~ C(XGroup)', data=d).fit()
            out["p_x"] = float(model.f_pvalue); out["ok"] = True
    except Exception as e:
        out["err"] = str(e)
    return out

def _stats_text(dfm, x_label, hue_label, *, mode="ref", ref_group=None, pair_list=None):
    """Build the multi-line statistics summary shown next to a plot.

    Args:
        dfm: long-format rows for one metric with "value", "XGroup", "HueGroup".
        x_label, hue_label: display names used in the text.
        mode: "pairs" together with a non-empty ``pair_list`` produces one ANOVA
            line per selected pair; any other combination produces one overall test.
        ref_group: accepted but not used by this function.
        pair_list: iterable of (a, b) X-group pairs for "pairs" mode.

    Returns:
        A string describing the test(s) and p-value(s), or a short message
        when there is too little data or a model fails.
    """
    df = dfm.dropna(subset=["value"]).copy()
    g_n = df["XGroup"].nunique(dropna=True)
    h_n = df["HueGroup"].nunique(dropna=True)

    # Pairwise mode: an independent ANOVA restricted to the rows of each pair.
    if mode == "pairs" and pair_list:
        lines = ["Selected pairwise ANOVA tests:"]
        for a,b in pair_list:
            sub = df[df["XGroup"].isin([a,b])]
            res = _anova_subset(sub)
            if not res["ok"]:
                lines.append(f"{a} vs {b}: {res['err'] or 'failed'}"); continue
            if res["n_h"] >= 2:
                lines.append(
                    f"{a} vs {b} (Two-way: {x_label}, {hue_label})  "
                    f"{x_label}: {_fmt_p(res['p_x'])} | {hue_label}: {_fmt_p(res['p_h'])} | "
                    f"{x_label}×{hue_label}: {_fmt_p(res['p_int'])}"
                )
            else:
                lines.append(f"{a} vs {b} (One-way {x_label}): {_fmt_p(res['p_x'])}")
        return "\n".join(lines)

    def fmt(p): return _fmt_p(p)
    # Exactly two X groups and at most one hue level: unpaired t-test.
    if g_n == 2 and h_n <= 1:
        g1, g2 = sorted(df["XGroup"].unique())
        v1 = df[df["XGroup"] == g1]["value"].dropna()
        v2 = df[df["XGroup"] == g2]["value"].dropna()
        # Each group needs at least two values.
        if len(v1) > 1 and len(v2) > 1:
            # NOTE(review): relies on pingouin naming the p-value column "p-val" -- confirm for the installed version.
            p = pg.ttest(v1, v2, paired=False)["p-val"].values[0]
            return f"t-test ({x_label}): {fmt(p)}\n{g1} vs {g2}"
        return "t-test: not enough data"

    # Two or more levels on both factors: two-way ANOVA with interaction (type 2).
    if g_n >= 2 and h_n >= 2:
        try:
            model = ols('value ~ C(XGroup) + C(HueGroup) + C(XGroup):C(HueGroup)', data=df).fit()
            an = sm.stats.anova_lm(model, typ=2)
            return (
                "Two-way ANOVA\n"
                f"{x_label}: {fmt(float(an.loc['C(XGroup)','PR(>F)']))}\n"
                f"{hue_label}: {fmt(float(an.loc['C(HueGroup)','PR(>F)']))}\n"
                f"{x_label}×{hue_label}: {fmt(float(an.loc['C(XGroup):C(HueGroup)','PR(>F)']))}"
            )
        except Exception as e:
            return f"ANOVA failed: {e}"

    # Remaining case with several X groups (more than two, single hue): one-way ANOVA.
    if g_n >= 2:
        try:
            model = ols('value ~ C(XGroup)', data=df).fit()
            return f"One-way ANOVA ({x_label}): {fmt(float(model.f_pvalue))}"
        except Exception as e:
            return f"One-way ANOVA failed: {e}"
    return "Too few groups for stats"

# -----------------------
# 7) Plotting helpers
# -----------------------
def _p_to_stars(p):
    if not np.isfinite(p): return ""
    if p < 1e-4: return "****"
    if p < 1e-3: return "***"
    if p < 1e-2: return "**"
    if p < 5e-2: return "*"
    return ""

def _dot_palette(hues):
    hues = list(hues)
    if len(hues) == 0: return {}
    if len(hues) == 1: return {hues[0]: "black"}
    if len(hues) == 2: return {hues[0]: "white", hues[1]: "black"}
    defaults = plt.rcParams.get('axes.prop_cycle', None)
    colors = defaults.by_key()['color'] if defaults else ["C0","C1","C2","C3","C4","C5","C6","C7","C8","C9"]
    return {h: colors[i % len(colors)] for i, h in enumerate(hues)}

def _draw_bracket(ax, x1, x2, y, h, text):
    ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1, c="black", zorder=5)
    ax.text((x1+x2)/2, y+h, text, ha="center", va="bottom", fontsize=16, fontweight="bold")

def _plot_metric_clean(df_metric, variable, x_color_map, *, mode="ref", ref_group=None, pair_list=None, return_fig=False):
    """Plot one metric as group-mean bars with individual dots, significance marks and a stats panel.

    Parameters
    ----------
    df_metric : DataFrame
        Rows for a single metric; the "XGroup", "HueGroup" and "value" columns are read.
    variable : str
        Metric name, used as the y-axis label.
    x_color_map : dict
        XGroup -> bar colour; groups missing from the map fall back to "tab:blue".
    mode : {"ref", "pairs"}
        "ref" puts stars above each group whose unpaired t-test against
        `ref_group` gives p < 0.05. "pairs" draws a bracket for each pair in
        `pair_list` whose XGroup effect from `_anova_subset` has p < 0.05.
    ref_group : optional
        Reference group for "ref" mode; replaced via `_choose_ref_group` when
        empty or not among the plotted groups.
    pair_list : iterable of (a, b), optional
        Group pairs to compare in "pairs" mode.
    return_fig : bool
        When True the Figure is returned without being shown; otherwise the
        figure is shown with `plt.show()`.

    Returns
    -------
    The Figure when `return_fig` is True; None when there is no X group to
    plot; otherwise the result of `plt.show()`.
    """
    dfm = df_metric.copy()
    order = _order_x_groups(dfm["XGroup"].dropna().unique().tolist())
    if not order: return None
    if (not ref_group) or (ref_group not in order):
        ref_group = _choose_ref_group(order)

    x_label_name   = _grouping_label("X")
    hue_label_name = _grouping_label("Hue")

    # Hue levels in a controlled order; the dot palette is keyed by those levels
    # (with two hues the helper maps them to white then black).
    hue_levels = _order_hue_groups(dfm["HueGroup"].dropna().unique().tolist())
    pal_dots = _dot_palette(hue_levels)

    width = max(2, 1 * len(order))
    height = 4.0
    fig, (ax_plot, ax_text) = plt.subplots(
        1, 2, figsize=(width/1.2, height), gridspec_kw={'width_ratios': [3, 1]}
    )

    # Bars show group means only; errorbar=None is the current spelling of the deprecated ci=None
    bar_palette = [x_color_map.get(g, "tab:blue") for g in order]
    sns.barplot(data=dfm, x="XGroup", y="value", order=order, errorbar=None, alpha=ALPHA, ax=ax_plot, palette=bar_palette)

    # Points
    sns.stripplot(
        data=dfm, x="XGroup", y="value",
        order=order,
        hue="HueGroup",
        hue_order=hue_levels,        # enforce hue order
        jitter=True, dodge=False, size=7,
        edgecolor="black", linewidth=1,
        palette=pal_dots,            # colors aligned to hue_order
        ax=ax_plot, zorder=3, alpha=ALPHA
    )
    # The hue legend is drawn in the right-hand panel instead (see below)
    if ax_plot.legend_ is not None:
        ax_plot.legend_.remove()

    # Annotations
    y_min, y_max = ax_plot.get_ylim()
    span = (y_max - y_min) if y_max > y_min else 1.0
    bump = 0.06 * span
    data_max = dfm["value"].max() if dfm["value"].notna().any() else y_max

    if mode == "ref" and (ref_group in order):
        ref_vals = dfm[dfm["XGroup"] == ref_group]["value"].dropna().to_numpy()
        for g in order:
            if g == ref_group: continue
            vals = dfm[dfm["XGroup"] == g]["value"].dropna().to_numpy()
            if len(vals) >= 2 and len(ref_vals) >= 2:
                try:
                    p = float(pg.ttest(vals, ref_vals, paired=False)["p-val"].values[0])
                except Exception:
                    # A failed test simply means no star for this group
                    p = np.nan
                if np.isfinite(p) and p < 0.05:
                    xloc = order.index(g)
                    gmax = dfm[dfm["XGroup"] == g]["value"].max()
                    y_star = (gmax if np.isfinite(gmax) else data_max) + bump
                    ax_plot.text(xloc, y_star, _p_to_stars(p),
                                 ha="center", va="bottom", fontsize=16, fontweight="bold")
                    # Grow the axis so the star is not clipped
                    y_max = max(y_max, y_star + bump)
        ax_plot.set_ylim(y_min, y_max)

    elif mode == "pairs" and pair_list:
        base = data_max + bump
        step = 0.12 * span
        k = 0  # number of brackets drawn so far; each new one is stacked one step higher
        for a,b in pair_list:
            if (a not in order) or (b not in order):
                continue
            sub = dfm[dfm["XGroup"].isin([a,b])].dropna(subset=["value"])
            if sub["XGroup"].nunique() < 2:
                continue
            res = _anova_subset(sub)
            # draw bracket ONLY if XGroup effect significant
            if res["ok"] and np.isfinite(res["p_x"]) and (res["p_x"] < 0.05):
                x1 = order.index(a); x2 = order.index(b)
                if x1 > x2: x1, x2 = x2, x1
                y_here = base + k * step
                _draw_bracket(ax_plot, x1, x2, y_here, 0.04 * span, _p_to_stars(res["p_x"]))
                y_max = max(y_max, y_here + 0.08 * span)
                k += 1
        ax_plot.set_ylim(y_min, y_max)

    ax_plot.set_title("")
    ax_plot.set_xlabel("")
    ax_plot.set_ylabel(variable)
    plt.setp(ax_plot.get_xticklabels(), rotation=45, ha='right')
    sns.despine(ax=ax_plot)

    # Right panel: stats text + Hue legend (same hue order as the dots)
    ax_text.axis("off")
    ax_text.text(
        0, 1,
        _stats_text(dfm, x_label_name, hue_label_name, mode=mode, ref_group=ref_group, pair_list=pair_list),
        va="top", ha="left", fontsize=12, transform=ax_text.transAxes
    )
    if len(hue_levels) >= 2:
        handles = [plt.Line2D([0],[0], marker='o', linestyle='None',
                              markerfacecolor=pal_dots[h], markeredgecolor='black', label=str(h))
                   for h in hue_levels]
        ax_text.legend(handles=handles, title=hue_label_name, loc="upper left", bbox_to_anchor=(0, 0.6))

    plt.tight_layout()
    return fig if return_fig else plt.show()

# -----------------------
# 8) Actions
# -----------------------
# Shared output area: both button callbacks below clear it and render into it.
out = widgets.Output()

def _selected_x_and_colors():
    """Return the selected X groups together with a {group: colour} map.

    Colours are read from the per-group widgets in `x_colors`; a blank entry
    falls back to "tab:blue".
    """
    chosen = _selected_x()
    colours = {grp: (x_colors[grp].value.strip() or "tab:blue") for grp in chosen}
    return chosen, colours

def _current_pairs():
    """Return the pairs currently selected in `pairs_select` as a new list."""
    return [pair for pair in pairs_select.value]

def _run_plots(_=None):
    """Callback: redraw every non-excluded metric for the selected X groups.

    Output goes to the shared `out` widget, which is cleared first. In "ref"
    mode the reference group comes from `ref_dropdown` (falling back to
    `_choose_ref_group` when it is not among the selected groups); otherwise
    the selected pairs are used, and nothing is plotted if none are selected.
    The argument is ignored, so the function can be called directly as well
    as used as a button handler.
    """
    with out:
        clear_output()
        sel_x, color_map = _selected_x_and_colors()
        if len(sel_x) < 1:
            print("Select at least one X group."); return

        mode = mode_radio.value
        # Resolve both mode-specific arguments once, so the plotting call below
        # is identical for every metric and never re-reads the widgets mid-loop.
        ref, pair_list = None, None
        if mode == "ref":
            ref = ref_dropdown.value if (ref_dropdown.value in sel_x) else _choose_ref_group(sel_x)
            print(f"Showing X groups: {sel_x}  |  reference for stars: {ref}")
        else:
            pair_list = _current_pairs()
            if not pair_list:
                print(f"Showing X groups: {sel_x}  |  no pairs selected (select at least one)."); return
            print(f"Showing X groups: {sel_x}  |  pairs: {pair_list}")
        plot_mode = "ref" if mode == "ref" else "pairs"

        # Day/Night variants are left out of this overview
        exclude = {"PeakAccuracy_Day","PeakAccuracy_Night",
                   "Win-stay_Day","Win-stay_Night",
                   "Lose-shift_Day","Lose-shift_Night"}
        metrics = [m for m in long_df["variable"].dropna().unique() if m not in exclude]

        for metric in metrics:
            subset = long_df[(long_df["variable"] == metric) & (long_df["XGroup"].isin(sel_x))]
            if subset["value"].dropna().empty:
                print(f"Skipping {metric} — no data for selected X groups."); continue
            present = set(subset["XGroup"].unique())
            _plot_metric_clean(
                subset, metric,
                x_color_map={g: color_map[g] for g in sel_x if g in present},
                mode=plot_mode, ref_group=ref, pair_list=pair_list
            )

def _save_plots(_=None):
    """Callback: save one PDF per non-excluded metric and bundle them in a zip.

    PDFs are written to the "metric_comparisons" folder. The zip contains only
    the figures written by this call, so PDFs left in the folder by an earlier
    save with a different selection do not leak into the new archive. When
    `colab_files` is not None the zip is also offered for download. Messages
    are printed to the shared `out` widget, which is cleared first.
    """
    import zipfile  # local import: only needed for building the archive

    with out:
        clear_output()
        sel_x, color_map = _selected_x_and_colors()
        if len(sel_x) < 1:
            print("Select at least one X group."); return

        mode = mode_radio.value
        ref = ref_dropdown.value if (mode == "ref") else None
        pair_list = _current_pairs() if (mode == "pairs") else None
        if mode == "pairs" and not pair_list:
            print("Select at least one pair before saving."); return

        outdir = "metric_comparisons"
        os.makedirs(outdir, exist_ok=True)
        saved_paths = []

        # Day/Night variants are left out of this overview
        exclude = {"PeakAccuracy_Day","PeakAccuracy_Night",
                   "Win-stay_Day","Win-stay_Night",
                   "Lose-shift_Day","Lose-shift_Night"}
        for metric in [m for m in long_df["variable"].dropna().unique() if m not in exclude]:
            subset = long_df[(long_df["variable"] == metric) & (long_df["XGroup"].isin(sel_x))]
            if subset["value"].dropna().empty: continue
            present = set(subset["XGroup"].unique())
            fig = _plot_metric_clean(
                subset, metric,
                x_color_map={g: color_map[g] for g in sel_x if g in present},
                mode=mode, ref_group=ref, pair_list=pair_list, return_fig=True
            )
            # The plot helper returns None when it has no X group to draw
            if fig is None:
                continue
            safe = metric.replace(" ", "_").replace("/", "-")
            path = os.path.join(outdir, f"{safe}.pdf")
            fig.savefig(path, dpi=300, bbox_inches="tight")
            plt.close(fig)
            if path not in saved_paths:
                saved_paths.append(path)

        if not saved_paths:
            print("No figures to save."); return
        zipname = f"metric_comparisons_{int(time.time())}.zip"
        with zipfile.ZipFile(zipname, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for path in saved_paths:
                # Store files at the archive root, as zipping the folder would
                archive.write(path, arcname=os.path.basename(path))
        if colab_files is not None:
            colab_files.download(zipname)
        print(f"Saved {zipname}")

# Register the plot / save callbacks on their buttons.
plot_btn.on_click(_run_plots)
save_btn.on_click(_save_plots)

# -----------------------
# 9) Assemble compact UI (two columns)
# -----------------------
# Toggle visibility of ref vs pairs widgets
def _toggle_controls(*_):
    """Show the reference dropdown in "ref" mode and the pair selector otherwise."""
    ref_mode = (mode_radio.value == "ref")
    # An empty display value leaves the widget visible; "none" hides it.
    ref_dropdown.layout.display = "" if ref_mode else "none"
    pairs_select.layout.display = "none" if ref_mode else ""
# Apply the initial visibility, then re-apply it whenever the mode selection changes.
_toggle_controls()
mode_radio.observe(lambda _: _toggle_controls(), names="value")

# Two-column layout container
row = widgets.HBox(
    [left_col, right_col],
    layout=widgets.Layout(
        justify_content="flex-start",   # keep columns together
        align_items="flex-start",
        gap="16px",
        width="auto"
    )
)

# Full UI: an (empty) heading, the two control columns, then the shared output area.
ui = widgets.VBox([
    widgets.HTML("<h3 style='margin-bottom:6px'></h3>"),
    row,
    out
], layout=widgets.Layout(width="auto"))

display(ui)

# Auto-run once
_run_plots()


In [None]:
# @title Plot Day vs Night metrics

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
from statsmodels.formula.api import ols

# ---- Config ----
# Metric base names; each one is plotted from its Day and Night variants (base + tag below).
BASES = ["Win-stay", "Lose-shift", "PeakAccuracy"]
# Suffixes that mark the Day / Night variant of a metric name.
DAY_TAG, NIGHT_TAG = "_Day", "_Night"
# Width of each Day / Night bar.
BAR_WIDTH = 0.36
# Opacity of the Night bars; the overlaid dots use the same value.
NIGHT_ALPHA = 0.6
# Dot size; the scatter marker area is DOT_SIZE**2 / 2.
DOT_SIZE = 7
# Outline width of the (unfilled) Day bars.
EDGE_LW_DAY = 2.0
# Edge width of the (filled) Night bars.
EDGE_LW_NIGHT = 1.0

# ---- Helpers (reuse from prior cell when available) ----
def _safe_order(groups):
    groups = [g for g in groups if pd.notna(g)]
    if '_order_x_groups' in globals():
        try:
            return _order_x_groups(groups)
        except Exception:
            pass
    return sorted(groups, key=lambda s: str(s).upper())

def _selected_x_groups():
    if 'x_checks' in globals() and isinstance(x_checks, dict) and len(x_checks):
        sel = [g for g, cb in x_checks.items() if getattr(cb, "value", False)]
        return _safe_order(sel)
    # fallback: all groups present
    if 'long_df' not in globals():
        raise RuntimeError("long_df not found")
    return _safe_order(long_df["XGroup"].dropna().unique().tolist())

def _color_map(x_groups):
    if 'x_colors' in globals() and isinstance(x_colors, dict) and len(x_colors):
        out = {}
        for g in x_groups:
            w = x_colors.get(g, None)
            val = getattr(w, "value", None) if w is not None else None
            out[g] = (val.strip() if isinstance(val, str) and val.strip() else "blue")
        return out
    defaults = ["blue","orange","green","red","purple","brown","pink","gray","olive","cyan"]
    return {g: defaults[i % len(defaults)] for i, g in enumerate(x_groups)}

def _ensure_daynight_view(df, base):
    """Return the rows of `df` for the Day and Night variants of `base`.

    The result is a copy. When it is non-empty, a "DayNight" column is added
    holding "Day" for variables ending in DAY_TAG and "Night" otherwise.
    """
    wanted = [f"{base}{DAY_TAG}", f"{base}{NIGHT_TAG}"]
    view = df[df["variable"].isin(wanted)].copy()
    if not view.empty:
        is_day = view["variable"].str.endswith(DAY_TAG)
        view["DayNight"] = np.where(is_day, "Day", "Night")
    return view

def _dot_palette_local(hues):
    if '_dot_palette' in globals():
        return _dot_palette(hues)
    # fallback palette
    hues = list(hues)
    if len(hues) == 0: return {}
    if len(hues) == 1: return {hues[0]: "black"}
    if len(hues) == 2: return {hues[0]: "white", hues[1]: "black"}
    defaults = plt.rcParams.get('axes.prop_cycle', None)
    colors = defaults.by_key()['color'] if defaults else ["C0","C1","C2","C3","C4","C5","C6","C7","C8","C9"]
    return {h: colors[i % len(colors)] for i, h in enumerate(hues)}

def _label_for(which="X"):
    if '_grouping_label' in globals():
        return _grouping_label(which)
    return "XGroup" if which.lower().startswith("x") else "HueGroup"

# ---- Stats text (ANOVA with Day/Night included) ----
def _anova_daynight_text(df_dn, x_label, hue_label):
    d = df_dn.dropna(subset=["value","XGroup","DayNight"]).copy()
    if d["XGroup"].nunique() < 2 or d["DayNight"].nunique() < 2:
        return "Too few groups or missing Day/Night to run ANOVA."

    has_hue = d["HueGroup"].nunique(dropna=True) >= 2
    try:
        if has_hue:
            model = ols('value ~ C(XGroup) + C(DayNight) + C(HueGroup) + '
                        'C(XGroup):C(DayNight) + C(XGroup):C(HueGroup) + '
                        'C(DayNight):C(HueGroup) + C(XGroup):C(DayNight):C(HueGroup)',
                        data=d).fit()
            an = sm.stats.anova_lm(model, typ=2)
            def pget(term):
                return float(an.loc[term, 'PR(>F)']) if term in an.index else np.nan
            def fmt(p):
                if not np.isfinite(p): return "n/a"
                return f"p = {p:.3f}" if p >= 0.001 else "p < 0.001"
            return (
                "Three-way ANOVA (XGroup, Day/Night, Hue)\n"
                f"{x_label}: {fmt(pget('C(XGroup)'))}\n"
                f"Day/Night: {fmt(pget('C(DayNight)'))}\n"
                f"{hue_label}: {fmt(pget('C(HueGroup)'))}\n"
                f"{x_label}×Day/Night: {fmt(pget('C(XGroup):C(DayNight)'))}\n"
                f"{x_label}×{hue_label}: {fmt(pget('C(XGroup):C(HueGroup)'))}\n"
                f"Day/Night×{hue_label}: {fmt(pget('C(DayNight):C(HueGroup)'))}\n"
                f"{x_label}×Day/Night×{hue_label}: {fmt(pget('C(XGroup):C(DayNight):C(HueGroup)'))}"
            )
        else:
            model = ols('value ~ C(XGroup) + C(DayNight) + C(XGroup):C(DayNight)', data=d).fit()
            an = sm.stats.anova_lm(model, typ=2)
            def fmt(p):
                if not np.isfinite(p): return "n/a"
                return f"p = {p:.3f}" if p >= 0.001 else "p < 0.001"
            return (
                "Two-way ANOVA (XGroup, Day/Night)\n"
                f"{_label_for('X')}: {fmt(float(an.loc['C(XGroup)','PR(>F)']))}\n"
                f"Day/Night: {fmt(float(an.loc['C(DayNight)','PR(>F)']))}\n"
                f"{_label_for('X')}×Day/Night: {fmt(float(an.loc['C(XGroup):C(DayNight)','PR(>F)']))}"
            )
    except Exception as e:
        return f"ANOVA failed: {e}"

# ---- Main plotting for a single base ----
def _plot_day_night_for_base(base, x_groups, color_map):
    """Plot Day vs Night values of one metric for each X group, with a stats panel.

    For every group a Day bar (outline only) and a Night bar (filled) show the
    group mean, individual values are overlaid as jittered dots coloured by
    HueGroup, and the right-hand panel shows the ANOVA text and a hue legend.

    Parameters
    ----------
    base : str
        Metric base name; its Day/Night rows are taken from `long_df`.
    x_groups : sequence
        X groups to plot, in display order.
    color_map : dict
        XGroup -> colour, with an entry for every group in `x_groups`.

    Prints a message and returns without plotting when no rows match.
    """
    # subset
    df_dn = _ensure_daynight_view(long_df, base)
    df_dn = df_dn[df_dn["XGroup"].isin(x_groups)].copy()
    if df_dn.empty:
        print(f"Skipping {base}: no Day/Night data for selected groups.")
        return

    # stats labels
    x_label = _label_for("X")
    hue_label = _label_for("Hue")

    # palette for dots by HueGroup
    hue_levels = df_dn["HueGroup"].dropna().unique().tolist()
    pal_dots = _dot_palette_local(hue_levels)

    # compute means per group/daynight
    means = (df_dn.dropna(subset=["value"])
                  .groupby(["XGroup","DayNight"], as_index=False)["value"]
                  .mean().rename(columns={"value":"mean"}))
    # arrange grid for fast lookup; the reindex guarantees both Day and Night columns exist
    grid = (means.set_index(["XGroup","DayNight"])["mean"]
                 .unstack("DayNight")
                 .reindex(index=x_groups, columns=["Day","Night"]))

    # figure layout: plot + stats panel
    width = max(4, 1.2 * len(x_groups))
    fig, (ax, ax_txt) = plt.subplots(1, 2, figsize=(width, 4.6), gridspec_kw={'width_ratios': [3, 1]})

    # bar positions: Day sits left of the tick, Night right of it
    x = np.arange(len(x_groups))
    off = BAR_WIDTH/2.0 + 0.02
    pos_day = x - off
    pos_night = x + off

    # Night bars (filled)
    for i, g in enumerate(x_groups):
        val = grid.loc[g, "Night"]
        if pd.notna(val):
            ax.bar(pos_night[i], val, width=BAR_WIDTH, color=color_map[g],
                   alpha=NIGHT_ALPHA, edgecolor="black", linewidth=EDGE_LW_NIGHT, zorder=2)

    # Day bars (outline only)
    for i, g in enumerate(x_groups):
        val = grid.loc[g, "Day"]
        if pd.notna(val):
            ax.bar(pos_day[i], val, width=BAR_WIDTH, facecolor=(0,0,0,0),
                   edgecolor=color_map[g], linewidth=EDGE_LW_DAY, zorder=3)

    # overlay individual dots by HueGroup at each bar position
    rng = np.random.default_rng(42)  # fixed seed keeps the jitter reproducible

    def _scatter_period(period, positions):
        # Draw the dots of one period ("Day" or "Night") at that period's bar positions.
        rows_here = df_dn[df_dn["DayNight"] == period].dropna(subset=["value"])
        for i, g in enumerate(x_groups):
            sub = rows_here[rows_here["XGroup"] == g]
            if sub.empty:
                continue
            for h in sub["HueGroup"].unique():
                hh = sub[sub["HueGroup"] == h]
                if hh.empty:
                    continue
                xs = positions[i] + rng.normal(0, 0.02, size=len(hh))
                ax.scatter(xs, hh["value"],
                           s=DOT_SIZE**2/2, facecolors=pal_dots.get(h, "black"),
                           edgecolors="black", linewidths=0.6, alpha=NIGHT_ALPHA, zorder=4)

    _scatter_period("Day", pos_day)
    _scatter_period("Night", pos_night)

    # axes cosmetics
    ax.set_xticks(x)
    ax.set_xticklabels(x_groups, rotation=45, ha="right")
    ax.set_ylabel(base)
    ax.set_title(f"{base}")
    sns.despine(ax=ax)

    # stats panel text
    ax_txt.axis("off")
    ax_txt.text(0, 1, _anova_daynight_text(df_dn, x_label, hue_label),
                va="top", ha="left", fontsize=12, transform=ax_txt.transAxes)

    if len(hue_levels) >= 2:
        handles = [plt.Line2D([0],[0], marker='o', linestyle='None',
                              markerfacecolor=pal_dots[h], markeredgecolor='black',
                              label=str(h)) for h in hue_levels]
        ax_txt.legend(handles=handles, title=hue_label,
                      loc="upper left", bbox_to_anchor=(0, 0.3),
                      frameon=False)
    plt.tight_layout()
    plt.show()

# ---- Run: plot for each base metric ----
if 'long_df' not in globals():
    raise RuntimeError("This cell expects long_df from the previous cell.")

# One figure per base metric, sharing the same group selection and colours.
Xsel = _selected_x_groups()
if Xsel:
    cmap = _color_map(Xsel)
    for base in BASES:
        _plot_day_night_for_base(base, Xsel, cmap)
else:
    print("No X groups selected or available.")

In [None]:
# @title Peak Accuracy

import os, numpy as np, pandas as pd, seaborn as sns, matplotlib.pyplot as plt

# --- Config ---
TRIALS = 11   # trials before/after switch (i.e., window = [-11, +11])

# --- Preconditions ---
_sessions = globals().get('feds')
if not isinstance(_sessions, (list, tuple)) or len(_sessions) == 0:
    raise RuntimeError("No FED3 sessions found in `feds`.")

_group_table = globals().get('files_to_group_both')
if _group_table is None or _group_table.empty:
    raise RuntimeError("`files_to_group_both` is missing/empty. Build Groups first.")

# filename -> XGroup map (match on basename)
grp = files_to_group_both.copy()
src_col = next((col for col in ("filename", "File") if col in grp.columns), None)
if src_col is None:
    raise RuntimeError("`files_to_group_both` must include 'filename' or 'File'.")

def basename(s):
    """Return the final path component of `s` (converted to str first)."""
    return os.path.basename(str(s))

grp["__base__"] = grp[src_col].astype(str).map(basename)
_complete_pairs = grp[["__base__", "XGroup"]].dropna()
fname_to_xgroup = dict(zip(_complete_pairs["__base__"], _complete_pairs["XGroup"]))

# Accumulates one dict per (trial window, timepoint); filled by the loop below.
rows = []

def _is_empty(x):
    if x is None: return True
    try:
        return len(x) == 0
    except Exception:
        try:
            return np.size(x) == 0
        except Exception:
            return True

# --- Build rev_df ---
# One row per (trial window, timepoint): Timepoint, Value and the session's X group.
for i, sess in enumerate(feds):
    base = basename(getattr(sess, "name", f"session_{i}"))
    xg = fname_to_xgroup.get(base, "UNASSIGNED")

    # compute peri-switch trials using the f3b helper
    try:
        peh = f3b.reversal_peh(sess, (-TRIALS, TRIALS), return_avg=False)
    except Exception as e:
        print(f"[skip] {base}: reversal_peh failed: {e}")
        continue

    if _is_empty(peh):
        continue

    # peh is treated as a list/array of trial-length arrays; flatten each defensively
    for tr in list(peh):
        arr = np.asarray(tr).ravel()
        # NOTE(review): index t is mapped to trial t - TRIALS + 1 and trial 0 is dropped below;
        # confirm this matches how reversal_peh aligns its window to the switch.
        rows.extend(
            {
                "Timepoint": t - TRIALS + 1,    # center around 0
                "Value": float(v) if np.isfinite(v) else np.nan,
                "Display_Group": xg,
            }
            for t, v in enumerate(arr)
        )

# Without any rows the frame would have no "Value" column and the filter below
# would fail with a KeyError instead of the intended message.
if not rows:
    raise RuntimeError("No peri-switch data produced.")

rev_df = pd.DataFrame(rows)
rev_df = rev_df[np.isfinite(rev_df["Value"])]
rev_df = rev_df[rev_df["Timepoint"] != 0] # optional: drop center trial
if rev_df.empty:
    raise RuntimeError("No peri-switch data produced.")


# WT/Control detector (same as other cell)
def _is_wt_group_label(g):
    u = str(g).strip().upper()
    toks = [t for t in re.split(r'[^A-Z0-9]+', u) if t]
    return any(t in {"WT", "WILDTYPE", "CONTROL", "CTRL"} for t in toks)

def _x_levels_local(g):
    s = str(g)
    return [p.strip() for p in s.split("|")]

def _hier_sort_key_local(g):
    """Sort key that puts WT/control groups first, then orders level by level.

    The key is (wt_rank, (is_blank, LEVEL), ..., LABEL): wt_rank is 0 when any
    level or the whole label is a WT/control label; within each level, blank
    placeholders ("", UNASSIGNED, NONE, NA, N/A) sort after real values.
    """
    blanks = {"", "UNASSIGNED", "NONE", "NA", "N/A"}
    levels = _x_levels_local(g)
    has_wt = any(_is_wt_group_label(tok) for tok in levels) or _is_wt_group_label(g)
    per_level = tuple(
        (int(str(tok).strip().upper() in blanks), str(tok).upper())
        for tok in levels
    )
    return (0 if has_wt else 1,) + per_level + (str(g).upper(),)

def _order_x_groups_local(groups):
    """Return the groups as a new list sorted by `_hier_sort_key_local`."""
    ordered = list(groups)
    ordered.sort(key=_hier_sort_key_local)
    return ordered

_present = rev_df["Display_Group"].dropna().unique().tolist()

# ORDER: if the first cell was run, follow its ordered_x; otherwise reproduce the same rule.
if "ordered_x" in globals():
    group_order = [g for g in ordered_x if g in _present] + [g for g in _present if g not in ordered_x]
else:
    group_order = _order_x_groups_local(_present)

# COLORS: pull from the first cell's x_colors widget map; fall back if unavailable.
palette_map = None
if "x_colors" in globals() and isinstance(x_colors, dict) and len(x_colors) > 0:
    def _col(g):
        """Colour chosen for group g in the widgets, or "tab:blue" when missing or blank."""
        try:
            v = x_colors[g].value
            v = v.strip() if isinstance(v, str) else ""
            return v if v else "tab:blue"
        except Exception:
            return "tab:blue"
    palette_map = {g: _col(g) for g in group_order}

# --- Plot (lines with SEM ribbons) ---
# group_order is used exactly as computed above; re-sorting it alphabetically here
# would discard the ordering taken from the grouping cell.
plt.figure(figsize=(10, 7))

ax = sns.lineplot(
    data=rev_df.sort_values(["Display_Group", "Timepoint"]),
    x="Timepoint",
    y="Value",
    hue="Display_Group",
    hue_order=group_order,
    palette=palette_map,
    estimator="mean",
    errorbar="se",
    n_boot=0,
    lw=2
)

# Switch marker and label
ax.axvline(x=0, color="black", linestyle="--", linewidth=1.25)
ax.set_ylim(bottom=0)
ymin, ymax = ax.get_ylim()
ax.text(0.5, ymin + 0.95*(ymax - ymin), "Switch", color="red",
        fontsize=12, ha="left", va="top")

ax.set_xlabel("Trials around switch")
ax.set_ylabel("Peak Accuracy")
ax.set_title("")
ax.legend(
    title="",
    frameon=False,
    loc="center left",
    bbox_to_anchor=(1.12, 0.5),
    borderaxespad=0.0
)

sns.despine()
plt.tight_layout()

# Embed TrueType fonts so text stays editable in Illustrator/Inkscape
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Date-stamped PDF inside the "figures" folder
outdir = "figures"
os.makedirs(outdir, exist_ok=True)
fname = f"peak_accuracy_{datetime.now():%Y-%m-%d}.pdf"
_pdf_path = os.path.join(outdir, fname)

fig = ax.get_figure()
# Saved with a transparent background
fig.savefig(_pdf_path, format="pdf", bbox_inches="tight", transparent=True)
files.download(_pdf_path)
plt.show()

In [None]:
# @title Cluster heatmap & PCA

import os, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# ---------- Pick Bandit source ----------
# First non-empty table wins, in this order of preference.
Bandit_src = None
for _table_name in ("Banditmetrics_merged", "Banditmetrics_csv", "Banditmetrics"):
    _candidate = globals().get(_table_name)
    if _candidate is not None and not _candidate.empty:
        Bandit_src = _candidate.copy()
        break
if Bandit_src is None:
    raise RuntimeError("No Bandit table found. Expect Banditmetrics_merged, Banditmetrics_csv, or Banditmetrics.")

# Ensure we have a File column (fallback from filename if needed)
if 'File' not in Bandit_src.columns:
    if 'filename' not in Bandit_src.columns:
        raise RuntimeError("Bandit table must include 'File' or 'filename' column.")
    Bandit_src = Bandit_src.copy()
    Bandit_src['File'] = Bandit_src['filename'].astype(str)

# ---------- Groups from the Group UI ----------
def _basename(p): return os.path.basename(str(p))

if 'files_to_group_both' in globals() and files_to_group_both is not None and not files_to_group_both.empty:
    grp_map = files_to_group_both.copy()
    src_col = 'filename' if 'filename' in grp_map.columns else ('File' if 'File' in grp_map.columns else None)
    if src_col is None:
        raise RuntimeError("`files_to_group_both` must include 'filename' or 'File'.")
    grp_map["File_base"] = grp_map[src_col].astype(str).apply(_basename)
    # Keep one row per file: a repeated key would duplicate Bandit rows in the left merge below
    grp_map = grp_map[["File_base","XGroup","HueGroup"]].drop_duplicates(subset="File_base", keep="first")
else:
    raise RuntimeError("No groups found. Run the 'Group for plotting' widget (two-column) and click 'Build Groups'.")

# ---------- Merge groups into a working table (ldf) ----------
ldf = Bandit_src.copy()
# Group labels come from the Group UI. If the source table already carries XGroup/HueGroup
# columns, drop them first; otherwise the merge would suffix them (XGroup_x / XGroup_y)
# and the lookups below would fail.
ldf = ldf.drop(columns=[c for c in ("XGroup", "HueGroup") if c in ldf.columns])
ldf["File_base"] = ldf["File"].astype(str).apply(_basename)
ldf = ldf.merge(grp_map, on="File_base", how="left")
# Files without a group assignment get explicit placeholder labels
ldf["XGroup"] = ldf["XGroup"].fillna("UNASSIGNED")
ldf["HueGroup"] = ldf["HueGroup"].fillna("ALL")

# ---------- STRICT METRIC SELECTION (whitelist patterns) ----------
# Base metric names you actually have in Banditmetrics
# Whitelist of metric base names; columns are later selected by matching these
# names (plus optional suffixes) against the table's column names.
_METRIC_BASES = [
    "Win-stay", "Lose-shift", "PeakAccuracy",
    "Total_pellets", "Total_pokes",
    "PokesPerPellet", "RetrievalTime", "PokeTime",
    "Daily_Pellets",
    # include any others you compute, e.g.:
    "Within_meal_pellet_rate",
]

# Build a regex that allows optional _Day/_Night and optional session-type suffixes
# e.g., PeakAccuracy, PeakAccuracy_Day, PeakAccuracy_FR1, PeakAccuracy_Day_FR1, etc.
def _metric_regex_from_bases(bases):
    safes = [re.escape(b) for b in bases]
    # ^(?:BASE1|BASE2)...(?:_(Day|Night))?(?:_.+)?$
    return re.compile(rf"^(?:{'|'.join(safes)})(?:_(?:Day|Night))?(?:_.+)?$")

_METRIC_RX = _metric_regex_from_bases(_METRIC_BASES)

# Identifier / grouping columns that must never be treated as metrics
ID_LIKE = {
    "File","filename","Mouse_ID","Strain","Sex","Genotype","Session_type",
    "File_base","XGroup","HueGroup"
}

# Select metric columns by name pattern only (not dtype)
metric_columns = []
for c in ldf.columns:
    if not isinstance(c, str) or c in ID_LIKE:
        continue
    if _METRIC_RX.match(c):
        metric_columns.append(c)
if len(metric_columns) == 0:
    raise RuntimeError("No metric columns matched the whitelist. Check names or extend _METRIC_BASES.")

# Coerce metric columns to numeric; unparseable entries become NaN
for c in metric_columns:
    ldf[c] = pd.to_numeric(ldf[c], errors="coerce")

# ---------- Build long & wide ----------
# Identifier columns carried through the melt (only those present in ldf)
id_keep = [c for c in ["File","filename","Mouse_ID","Strain","Sex","Genotype","Session_type","XGroup","HueGroup"] if c in ldf.columns]

# NOTE(review): this rebinds the notebook-wide `long_df` to a table whose name column is
# "metric"; the plotting callbacks defined in the earlier cell filter on long_df["variable"],
# so they would presumably fail if triggered after this cell runs. Consider a cell-local name.
long_df = ldf.melt(
    id_vars=id_keep,
    value_vars=metric_columns,
    var_name="metric",
    value_name="value"
).copy()

# Wide (one row per file/animal), average duplicates
# (Session_type is not part of the index, so rows differing only in it are averaged together)
wide_index = [c for c in ["File","filename","Mouse_ID","Strain","Sex","Genotype","XGroup","HueGroup"] if c in long_df.columns]
# NOTE(review): rows with a missing value in any index column are presumably dropped by
# pivot_table's default NaN handling; confirm that is intended.
wide_metrics = (
    long_df.pivot_table(
        index=wide_index,
        columns="metric",
        values="value",
        aggfunc="mean",
        observed=True
    )
    .reset_index()
)

# Metric columns are those not in the index
metric_columns = [c for c in wide_metrics.columns if c not in wide_index]

# ---------- Group means by (XGroup, HueGroup) for heatmap & PCA feature set ----------
if not {"XGroup", "HueGroup"}.issubset(wide_metrics.columns):
    raise RuntimeError("XGroup and HueGroup are required to cluster on both.")

_x_order = sorted(wide_metrics["XGroup"].astype(str).unique())
_h_order = sorted(wide_metrics["HueGroup"].astype(str).unique())

# Every X-by-Hue combination gets a row, including combinations with no data (all NaN)
_all_cells = pd.MultiIndex.from_product([_x_order, _h_order], names=["XGroup","HueGroup"])
_wide_str_groups = wide_metrics.assign(
    XGroup=wide_metrics["XGroup"].astype(str),
    HueGroup=wide_metrics["HueGroup"].astype(str),
)
group_means = (
    _wide_str_groups
    .groupby(["XGroup", "HueGroup"], dropna=False)[metric_columns]
    .mean()
    .reindex(_all_cells)
)

# Human-readable row labels "X | Hue"
group_means.index = [f"{x} | {h}" for x, h in group_means.index]

# ---------- Heatmap (min–max per column, with numeric annotations) ----------
def _fmt_cell(x):
    if pd.isna(x): return ""
    ax = abs(float(x))
    return f"{x:.0f}" if ax >= 100 else f"{x:.1f}" if ax >= 10 else f"{x:.2f}"

# Annotation text: the raw group means, formatted cell by cell.
# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map,
# so use the new name when it exists and fall back only on older pandas.
if hasattr(group_means, "map"):
    annot_data = group_means.map(_fmt_cell)
else:
    annot_data = group_means.applymap(_fmt_cell)

# Colour values: min–max scale each metric column on its own so that metrics
# with different units share one 0..1 colour range.
heatmap_scaled = group_means.copy()
for col in heatmap_scaled.columns:
    col_values = heatmap_scaled[col]
    col_min, col_max = col_values.min(), col_values.max()
    if pd.isna(col_min) or pd.isna(col_max):
        # Column holds no data at all: paint it at the bottom of the scale.
        heatmap_scaled[col] = 0.0
    elif col_max == col_min:
        # Constant column: mid-scale for cells that have a value. Cells that are
        # missing stay NaN so they are left blank, exactly as missing cells are
        # in the normally scaled columns (previously they were filled with 0.5).
        heatmap_scaled[col] = col_values.where(col_values.isna(), 0.5)
    else:
        heatmap_scaled[col] = (col_values - col_min) / (col_max - col_min)

# keep same row order for annotations
annot_data = annot_data.loc[heatmap_scaled.index]

# ---------- PCA (same metric set) ----------
pca_features = metric_columns[:]  # same set used in heatmap
# Only rows with a value for every metric can be used.
mouse_data = wide_metrics.dropna(subset=pca_features).copy()

# A 2-component PCA needs at least 2 rows and 2 features. Fail with a clear
# message instead of an opaque error from the scaler / PCA further down.
if len(mouse_data) < 2 or len(pca_features) < 2:
    raise RuntimeError(
        "PCA needs at least 2 rows with complete metrics and at least 2 metric columns "
        f"(found {len(mouse_data)} complete row(s) and {len(pca_features)} metric(s))."
    )

# Labels: prefer Mouse_ID, else filename (basename), else File, else the row index
if "Mouse_ID" in mouse_data.columns and mouse_data["Mouse_ID"].notna().any():
    labels = mouse_data["Mouse_ID"].astype(str)
elif "filename" in mouse_data.columns and mouse_data["filename"].notna().any():
    labels = mouse_data["filename"].astype(str).map(os.path.basename)
elif "File" in mouse_data.columns:
    labels = mouse_data["File"].astype(str)
else:
    # None of the identifier columns exist: fall back to the row index rather
    # than raising a KeyError on "File".
    labels = pd.Series(mouse_data.index.astype(str), index=mouse_data.index)
mouse_data["Label"] = labels

# Grouping columns are compared as strings in the plots; fill a default if one is absent.
for col, default in [("XGroup", "UNASSIGNED"), ("HueGroup", "ALL")]:
    if col in mouse_data.columns:
        mouse_data[col] = mouse_data[col].astype(str)
    else:
        mouse_data[col] = default

# Standardise every metric before projecting onto the first two principal components.
X = StandardScaler().fit_transform(mouse_data[pca_features].values)
pca = PCA(n_components=2)
pca_result = pca.fit_transform(X)

# Scores plus the labels / groups needed for plotting (positional, same row order as mouse_data).
pca_df = pd.DataFrame(pca_result, columns=["PC1", "PC2"])
pca_df["Label"]    = mouse_data["Label"].values
pca_df["XGroup"]   = mouse_data["XGroup"].values
pca_df["HueGroup"] = mouse_data["HueGroup"].values

# Loadings (top 8 by |PC1|), reshaped to long form for the grouped bar plot
loadings = pd.DataFrame(pca.components_.T, index=pca_features, columns=["PC1", "PC2"])
top8_features = loadings.reindex(loadings["PC1"].abs().sort_values(ascending=False).head(8).index)
loadings_melted = top8_features[["PC1", "PC2"]].reset_index().melt(id_vars="index", var_name="PC", value_name="Loading")

# ---------- Plot (heatmap on top; PCA + loadings below) ----------
# 2x2 grid: the top row is a single wide panel, the bottom row holds two panels.
fig = plt.figure(constrained_layout=True, figsize=(18, 12))
_grid_options = dict(
    height_ratios=[0.8, 1.0],
    width_ratios=[1.0, 0.8],
    hspace=0.1,
    wspace=0.1,
)
gs = gridspec.GridSpec(nrows=2, ncols=2, figure=fig, **_grid_options)

# Heatmap spans the whole top row; scatter sits bottom-left, loadings bottom-right.
ax_heat = fig.add_subplot(gs[0, :])
ax_scatter, ax_bar = (fig.add_subplot(gs[1, j]) for j in (0, 1))

# 1) Heatmap
# Colours come from the per-column scaled values; the text in each cell is the
# pre-formatted raw group mean (fmt="" because the annotations are already strings).
_heatmap_style = dict(cmap="Blues", alpha=0.5, linewidths=0.5, linecolor='gray', cbar=True)
sns.heatmap(
    heatmap_scaled,
    annot=annot_data.values,
    fmt="",
    ax=ax_heat,
    **_heatmap_style,
)
for _axis_name, _angle in (("x", 70), ("y", 0)):
    ax_heat.tick_params(axis=_axis_name, rotation=_angle)
ax_heat.set_title("", fontsize=16, color="darkblue")
ax_heat.set_xlabel("")
ax_heat.set_ylabel("")

# 2) PCA scatter: color by XGroup; marker encodes HueGroup
unique_x = sorted(pca_df["XGroup"].unique())

# The first two X groups get fixed colours; any further groups are drawn from
# "tab20" resampled to exactly the number of extra groups.
x_colors = dict(zip(unique_x, ["dodgerblue", "red"]))
extra_x = unique_x[2:]
if extra_x:
    # Uses plt.get_cmap / mpl.colors, which the notebook's import cell provides.
    # The previous cm.get_cmap / mcolors.to_hex relied on names that cell does
    # not bind, and matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("tab20", len(extra_x))
    for i, grp in enumerate(extra_x):
        x_colors[grp] = mpl.colors.to_hex(cmap(i))

# Marker scheme: one hue -> filled circles; two hues -> hollow vs filled circles;
# three or more -> a distinct (cycled) marker shape per hue, all filled.
all_hues = sorted(pca_df["HueGroup"].unique())
if len(all_hues) == 0:
    all_hues = ["ALL"]
if len(all_hues) == 1:
    hue_to_marker = {all_hues[0]: ("o", "filled")}
elif len(all_hues) == 2:
    hue_to_marker = {all_hues[0]: ("o", "hollow"), all_hues[1]: ("o", "filled")}
else:
    marker_cycle = ["o", "s", "^", "D", "P", "X", "*", "v", "<", ">"]
    hue_to_marker = {h: (marker_cycle[i % len(marker_cycle)], "filled") for i, h in enumerate(all_hues)}

# One scatter call per non-empty (XGroup, HueGroup) pair so each gets its own legend entry.
for xg in unique_x:
    sub_x = pca_df[pca_df["XGroup"] == xg]
    group_color = x_colors.get(xg, "black")
    for hg in all_hues:
        sub = sub_x[sub_x["HueGroup"] == hg]
        if sub.empty:
            continue
        marker, fill = hue_to_marker[hg]
        if fill == "hollow":
            # Outline only: colour the edge and leave the face empty.
            style = dict(edgecolors=group_color, facecolors="none", linewidth=1.2)
        else:
            style = dict(color=group_color, linewidth=0.8)
        ax_scatter.scatter(
            sub["PC1"], sub["PC2"],
            s=110, marker=marker, alpha=0.85,
            label=f"{xg} | {hg}",
            **style,
        )

# point labels (nudged slightly up/right so the text does not sit on the marker)
for pc1, pc2, point_label in zip(pca_df["PC1"], pca_df["PC2"], pca_df["Label"]):
    ax_scatter.text(
        pc1 + 0.05, pc2 + 0.12, str(point_label),
        fontsize=9, color="gray", alpha=0.6, ha="center", va="bottom"
    )

ax_scatter.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
ax_scatter.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")

# Legend entries sorted by label; keying on the label guarantees one entry per label.
handles, labels = ax_scatter.get_legend_handles_labels()
pairs = sorted(dict(zip(labels, handles)).items(), key=lambda item: item[0])
if pairs:
    sorted_labels, sorted_handles = zip(*pairs)
    ax_scatter.legend(sorted_handles, sorted_labels, frameon=True, title="XGroup | HueGroup", fontsize=9)

# 3) Loadings barplot
# Horizontal bars: one row per metric, with PC1 and PC2 loadings side by side.
sns.barplot(
    x="Loading", y="index", hue="PC",
    hue_order=["PC1", "PC2"],
    palette=["purple", "orange"],
    data=loadings_melted,
    alpha=0.5,
    ax=ax_bar,
)
# Dashed zero line separates positive from negative loadings.
ax_bar.axvline(0, color='gray', linestyle='--', linewidth=1)
ax_bar.set_title("Top 8 PC Loadings (metrics only)", fontsize=14)
ax_bar.set_xlabel("Loading Weight")
ax_bar.set_ylabel("")
ax_bar.legend(title="", frameon=False, fontsize=9)

plt.show()
