<a href="https://colab.research.google.com/github/KravitzLab/Murrell2025/blob/main/Murrell_2026_Fig2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Murrell 2026 Figure 2
<br>
<img src="https://fed3bandit.readthedocs.io/en/latest/_static/fed3bandit_logo1.svg" width="200" />

Authors: Chantelle Murrell<br>
Updated: 12-30-25  

In [None]:
# @title Install libraries and import them {"run":"auto"}

import importlib.util
import subprocess
import sys

# Packages to ensure are installed (add others here if you like)
# Importable module name -> pip requirement (PyPI name or VCS URL).
packages = dict(
    fed3="git+https://github.com/earnestt1234/fed3.git",
    fed3bandit="fed3bandit",
    pingouin="pingouin",
    ipydatagrid="ipydatagrid",
    openpyxl="openpyxl",
)

# Install only what cannot already be imported. The check is made lazily,
# one package at a time, so a module pulled in as a dependency of an earlier
# install is not installed a second time.
for name, source in packages.items():
    if importlib.util.find_spec(name) is not None:
        continue
    print(f"Installing {name}...")
    pip_command = [sys.executable, "-m", "pip", "install", source]
    subprocess.check_call(pip_command)

# ----------------------------
# Imports
# ----------------------------
# Standard library
import copy
import io
import math
import os
import re
import shutil
import tempfile
import threading
import time
import warnings
import zipfile
import requests
import glob
from datetime import datetime, timedelta
from os.path import basename, splitext

# Third-party
from ipydatagrid import DataGrid, TextRenderer
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg
import fed3
import fed3.plot as fplot
import fed3bandit as f3b
from scipy.stats import f_oneway
import statsmodels.api as sm
from statsmodels.formula.api import ols
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.gridspec as gridspec
from matplotlib.ticker import PercentFormatter
from google.colab import files
try:
    from tqdm.auto import tqdm   # nice in notebooks; falls back to std tqdm on console
except Exception:
    def tqdm(iterable=None, *args, **kwargs):
        """No-op stand-in for ``tqdm`` used when the real package is unavailable.

        Accepts (and ignores) the extra positional/keyword arguments that the
        real ``tqdm`` takes (``desc=``, ``total=``, ...) so call sites do not
        break when the progress bar is missing.

        Returns:
            The ``iterable`` argument, unchanged.
        """
        return iterable



# ----------------------------
# Configuration
# ----------------------------
# Silence every warning for the rest of the session (notebook-wide setting).
warnings.filterwarnings("ignore")

# Matplotlib defaults shared by all figures in this notebook, applied in one go.
plt.rcParams.update({
    'font.size': 12,
    'figure.autolayout': True,
    'figure.figsize': [6, 4],
    'figure.dpi': 100,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

print("Packages installed and imports ready.")


In [None]:
# @title Import FED3 FR1 Data

from urllib.request import urlretrieve
import os, zipfile, shutil
import pandas as pd
import numpy as np

# Enable Colab's custom widget manager when running on Colab; outside Colab the
# import fails and the error is deliberately ignored (best effort).
try:
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()
except Exception:
    pass

# Remote sources: the zipped FR1 sessions and the mouse key table.
zip_url = "https://github.com/KravitzLab/Murrell2025/raw/refs/heads/main/Data/FR1.zip"
key_url = "https://github.com/KravitzLab/Murrell2025/raw/refs/heads/main/Data/Murrell2026_Key.csv"

# Local working paths.
# NOTE(review): the "Bandit100" file/folder names look like leftovers from a
# sibling notebook; here they only name the local copy of FR1.zip — confirm
# before renaming.
zip_dir = "/content/Murrell2025_zipdata"
zip_path = os.path.join(zip_dir, "Bandit100.zip")
extract_root = os.path.join(zip_dir, "Bandit100_extracted")
key_path = os.path.join(zip_dir, "Murrell2026_Key.csv")

os.makedirs(zip_dir, exist_ok=True)

# download + unzip (fresh each run): remove any previous archive and
# extraction directory so stale files cannot be picked up
if os.path.exists(zip_path):
    os.remove(zip_path)
if os.path.isdir(extract_root):
    shutil.rmtree(extract_root)

print("Importing github.com/KravitzLab/Murrell2025/Data/FR1.zip ...")
urlretrieve(zip_url, zip_path)

os.makedirs(extract_root, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall(extract_root)

# if the zip contains a Bandit100 folder, use it; otherwise use extract_root
bandit_root = os.path.join(extract_root, "Bandit100")
local_parent_path = bandit_root if os.path.isdir(bandit_root) else extract_root

# load CSVs
# Parallel lists, one entry per loaded CSV:
#   feds          -> the loaded session frame
#   loaded_files  -> full path of the CSV
#   session_types -> first non-null Session_Type value (None if there is none)
feds, loaded_files, session_types = [], [], []

for dirpath, dirnames, filenames in os.walk(local_parent_path):
    # os.walk visits sub-directories in arbitrary (filesystem) order. Sorting
    # the list in place makes the traversal — and therefore the file index
    # used by the slider and the row order of Key_Df — reproducible.
    dirnames.sort()
    for file_name in sorted(filenames):
        if not file_name.lower().endswith(".csv"):
            continue

        file_path = os.path.join(dirpath, file_name)
        # The strain label is the name of the folder that holds the CSV.
        strain_name = os.path.basename(os.path.dirname(file_path))

        df = fed3.load(file_path)
        df.name = file_name
        df["Strain"] = strain_name
        df["SourceFile"] = file_name

        feds.append(df)
        loaded_files.append(file_path)

        st = df["Session_Type"].dropna().astype(str)
        session_types.append(st.iloc[0] if len(st) else None)

print(f"Loaded {len(feds)} CSV files.")

# download + load key
urlretrieve(key_url, key_path)

key_df = pd.read_csv(key_path, encoding="utf-8-sig")
key_df["Mouse_ID"] = key_df["Mouse_ID"].astype(str).str.strip()

# match Mouse_ID by substring in filename base
def _base_lower(p):
    """Return the lower-cased file name of path ``p`` without its extension."""
    return os.path.splitext(os.path.basename(p))[0].lower()


def _match_mouse_id(base, candidates):
    """Pick the Mouse_ID whose lower-cased text occurs in ``base``.

    When several IDs match (e.g. "M1" and "M10" both occur in "m10_fr1"), the
    longest one wins, because a shorter ID is frequently just a prefix of the
    real one. Ties keep the order of ``candidates``.

    Returns:
        The matching ID string, or None when nothing matches.
    """
    hits = [mid for mid in candidates if mid.lower() in base]
    return max(hits, key=len) if hits else None


files_df = pd.DataFrame({"filename": loaded_files, "Session_type": session_types})
files_df["_base"] = files_df["filename"].map(_base_lower)

# The astype(str) above turns missing IDs into the literal text "nan"; drop it
# (and empty strings) so it can never be substring-matched against a file name.
mouse_ids = [
    mid
    for mid in key_df["Mouse_ID"].dropna().astype(str).str.strip().unique().tolist()
    if mid and mid.lower() != "nan"
]

# One row per loaded file; explicit columns keep the merge below valid even
# when no file was loaded.
matched = pd.DataFrame({
    "filename": files_df["filename"].tolist(),
    "Mouse_ID": [_match_mouse_id(base, mouse_ids) for base in files_df["_base"]],
})

# File table + matched Mouse_ID + the key's per-mouse metadata.
Key_Df = (
    files_df.drop(columns=["_base"])
    .merge(matched, on="filename", how="left")
    .merge(key_df.drop_duplicates("Mouse_ID"), on="Mouse_ID", how="left")
)

# display
# Show the file/metadata table in an interactive grid with wrapped cell text.
_grid_frame = Key_Df.reset_index(drop=True)
grid = DataGrid(
    _grid_frame,
    base_column_size=120,
    base_row_size=28,
    editable=True,
    layout={"height": "420px"},
    selection_mode="cell",
)
grid.default_renderer = TextRenderer(text_wrap=True)
display(grid)

In [None]:
# @title Plot individual files (males are blue, females are red)
# ----- Inputs -----
# Fail fast when the data-import cell has not been run: `feds` must be a
# non-empty list and `Key_Df` a DataFrame, both already defined as globals.
assert 'feds' in globals() and isinstance(feds, list) and len(feds) > 0, "No FED3 files loaded."
assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."

TIMESTAMP_COL_CANON = "MM:DD:YYYY hh:mm:ss"  # primary target column name

# ----- Helpers -----
def _find_time_col(df):
    # exact match first
    if TIMESTAMP_COL_CANON in df.columns:
        return TIMESTAMP_COL_CANON
    # tolerant search (case/space-insensitive)
    lc = {str(c).strip().lower(): c for c in df.columns}
    for key in lc:
        if key.replace(" ", "") in {"mm:dd:yyyyhh:mm:ss", "mm:dd:yyyy_hh:mm:ss", "mm/dd/yyyyhh:mm:ss"}:
            return lc[key]
    return None

def _parse_ts(series):
    # robust parsing; coerce errors to NaT
    return pd.to_datetime(series, errors="coerce", infer_datetime_format=True)


# ----- Plotting
# Sessions to browse: the frames loaded by the import cell.
files_list = feds

# Working copy of Key_Df with a fresh 0..n-1 index; its `filename` column is
# reduced to the bare file name so it can be compared with each session's name.
metadata_df = Key_Df.copy().reset_index(drop=True)
_has_filename = 'filename' in metadata_df.columns
if _has_filename:
    _as_text = metadata_df['filename'].astype(str)
    metadata_df['filename'] = _as_text.map(os.path.basename)

def _coerce_numeric_col(df, col, clip_upper=None, na_map=None):
    if col not in df.columns:
        return
    s = df[col]
    if na_map:
        s = s.replace(na_map)
    s = pd.to_numeric(s, errors='coerce')
    if clip_upper is not None:
        s.loc[s > clip_upper] = np.nan
    df[col] = s

def _plot_file_core(file_index):
    """Plot the binned left-choice trace for one loaded session.

    Works on a copy of ``files_list[file_index]``: attaches the matching
    ``metadata_df`` row (matched on bare file name), coerces the timing
    columns to numeric, and draws the output of ``f3b.binned_paction`` as a
    single line coloured by sex (red when Sex starts with "f", otherwise
    dodgerblue).

    Args:
        file_index: position of the session in ``files_list``.

    Returns:
        None. The figure is shown with ``plt.show()``; an empty session only
        prints a notice.
    """
    df = files_list[file_index].copy()
    full_name = getattr(df, 'name', f"File_{file_index}")
    file_basename = os.path.basename(str(full_name))

    # Preserve original index once
    if "Original_Timestamp" not in df.columns:
        df["Original_Timestamp"] = df.index

    # Attach metadata by filename (matching already done upstream)
    meta_row = None
    if 'filename' in metadata_df.columns:
        mr = metadata_df.loc[metadata_df['filename'] == file_basename]
        if not mr.empty:
            meta_row = mr.iloc[0]
    if meta_row is not None:
        for col in meta_row.index:
            if col == 'filename':
                continue
            if col not in df.columns:
                df[col] = meta_row[col]
            else:
                # Only fill columns that carry no data at all in the session.
                if pd.isna(df[col]).all() and pd.notna(meta_row[col]):
                    df[col] = meta_row[col]

    # Time + cleanup
    try:
        df['timestamp'] = pd.to_datetime(df.index)
    except Exception:
        # Index is not datetime-like: fall back to a plain row counter.
        df['timestamp'] = np.arange(len(df))

    _coerce_numeric_col(df, 'Poke_Time', clip_upper=2)
    _coerce_numeric_col(df, 'Retrieval_Time', na_map={"Timed_out": np.nan})

    if len(df) == 0:
        print(f"[!] Empty cropped dataframe for {file_basename}. Skipping plot.")
        return

    # Behavioral trace (needs f3b)
    mouse_left = f3b.binned_paction(df, window=10)

    # Plot
    fig, ax = plt.subplots(figsize=(3.3, 3))
    # Single point at (1, 1): draws nothing visible but takes part in autoscaling.
    ax.plot(1, 1, color="black", linewidth=2, alpha=0.5)

    color = "dodgerblue"
    if 'Sex' in df.columns and pd.notna(df['Sex']).any():
        try:
            color = "red" if str(df['Sex'].iloc[0]).strip().lower().startswith("f") else "dodgerblue"
        except Exception:
            pass
    ax.plot(np.arange(len(mouse_left)), mouse_left, color=color, linewidth=3, alpha=0.7)

    # ---- Clean look: remove ticks, labels, spines ----
    ax.set_xlabel("")                 # no x label
    ax.set_ylabel("")                 # no y label
    ax.tick_params(axis="both", which="both",
                   bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labelleft=False)
    for s in ax.spines.values():      # remove axis lines
        s.set_visible(False)

    # ---- Textual y "labels" at y=1 and y=0 (not ticks) ----
    ax.text(-0.01, 1.0, "Left", transform=ax.get_yaxis_transform(),
            ha="right", va="center", fontsize = 12)
    ax.text(-0.01, 0.0, "Right",  transform=ax.get_yaxis_transform(),
            ha="right", va="center", fontsize = 12)

    # ---- Title-area arrow above the axes ----
    # spans full width; adjust y (1.08–1.15) if you need more/less space
    ax.annotate("",
                xy=(0.95, 1.05), xytext=(0.05, 1.05),
                xycoords="axes fraction",
                arrowprops=dict(arrowstyle="->", lw=5, color="0.6"))
    # Optional: small caption above the arrow (example: duration text)
    ax.text(0.4, 1.15, "1 day", transform=ax.transAxes, ha="left", va="bottom", color="0.5", fontsize = 12)
    sns.despine(left=True, bottom=True, top=True, right=True)
    plt.tight_layout()
    plt.show()

# ----- Simple UI: slider + status + output -----
# Number of sessions available to browse; the slider below ranges over 0..N-1.
N = len(files_list)
assert N > 0, "No files available after cropping."
idx_slider = widgets.IntSlider(min=0, max=max(0, N-1), step=1, value=0, description='File', continuous_update=True)
status_lbl = widgets.HTML()
out = widgets.Output()

def _status(idx):
    """Build the HTML status line (index, file name, row count) for session ``idx``."""
    name = getattr(files_list[idx], 'name', f"File_{idx}")
    return f"Index: <b>{idx}</b> &nbsp;|&nbsp; File: <code>{os.path.basename(str(name))}</code> &nbsp;|&nbsp; Rows: {len(files_list[idx])}"

def _render(*_):
    """Refresh the status line and redraw the plot for the slider's current value.

    Accepts and ignores any arguments so it can serve both as a widget
    observer callback and as a plain call.
    """
    idx = int(idx_slider.value)
    status_lbl.value = _status(idx)
    out.clear_output()
    with out:
        _plot_file_core(idx)

# Redraw whenever the slider value changes, then show the UI and draw once.
idx_slider.observe(_render, names='value')
display(widgets.VBox([idx_slider, status_lbl, out]))
_render()


In [None]:
# @title Analyze FR1 metrics
from pathlib import Path
import os
def _find_time_col(df):
    for c in ["MM:DD:YYYY hh:mm:ss", "DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if c in df.columns:
            return c
    return None

def _get_timestamp_series(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    import pandas as pd
    if ts_col in df.columns:
        ts = pd.to_datetime(df[ts_col], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        return pd.Series(ts, index=df.index)
    for cand in ["DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            return pd.Series(ts, index=df.index)
    idx = df.index
    if isinstance(idx, pd.DatetimeIndex):
        return pd.Series(idx, index=df.index)
    return pd.to_datetime(pd.Series(idx, index=df.index), errors="coerce")

def _crop_last_24h(df, hours=24):
    """Keep only the rows from the final ``hours`` hours of a session.

    Timestamps come from the column found by ``_find_time_col`` or, when there
    is none, from the index. The window is ``[latest - hours, latest]``, both
    ends inclusive.

    Args:
        df: session frame.
        hours: window length in hours (default 24).

    Returns:
        The cropped rows (taken from a copy of ``df``). ``df`` itself is
        returned unchanged when no usable timestamp exists. A ``name``
        attribute on ``df`` is carried over to the result.
    """
    dfc = df.copy()
    ts_col = _find_time_col(dfc)

    # `infer_datetime_format` is not passed: it is deprecated since pandas 2.0,
    # and inside the try-block below a rejected keyword would silently turn
    # cropping off.
    if ts_col is not None:
        ts_series = pd.to_datetime(dfc[ts_col], errors="coerce")
        ts_series = pd.Series(ts_series, index=dfc.index)
    else:
        try:
            ts_idx = pd.to_datetime(dfc.index, errors="coerce")
            ts_series = pd.Series(ts_idx, index=dfc.index)
        except Exception:
            return df  # no usable timestamps → return original

    if ts_series.isna().all():
        return df

    end = ts_series.max()
    if pd.isna(end):
        return df
    start = end - pd.Timedelta(hours=hours)

    mask = ts_series.between(start, end, inclusive="both")
    cropped = dfc.loc[mask]

    if hasattr(df, "name"):
        cropped.name = df.name
    return cropped

def build_feds_cropped(sessions):
    """Crop every session in ``sessions`` with ``_crop_last_24h``.

    Returns a new list in the same order as the input.
    """
    cropped_sessions = []
    for session in sessions:
        cropped_sessions.append(_crop_last_24h(session))
    return cropped_sessions

# Use it like this:
assert 'feds' in globals() and isinstance(feds, (list, tuple)) and len(feds) > 0, "No FED3 files available."

# Crop every session exactly once (cropping copies each frame, so repeating it
# only costs time), then prefer the cropped sessions downstream.
feds_cropped = build_feds_cropped(feds)
_sessions = list(feds_cropped) if len(feds_cropped) > 0 else list(feds)

# metadata_df from Key_Df (only when an earlier cell has not already built it)
if 'metadata_df' not in globals() or not isinstance(metadata_df, pd.DataFrame):
    assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."
    metadata_df = Key_Df.copy().reset_index(drop=True)

def _basename(pathlike) -> str:
    s = str(pathlike).replace("\\", "/")
    return s.split("/")[-1]

def _get_timestamp_series(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    if ts_col in df.columns:
        ts = pd.to_datetime(df[ts_col], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        return pd.Series(ts, index=df.index)
    for cand in ["DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            return pd.Series(ts, index=df.index)
    idx = df.index
    if isinstance(idx, pd.DatetimeIndex):
        return pd.Series(idx, index=df.index)
    return pd.to_datetime(pd.Series(idx, index=df.index), errors="coerce")

def _split_day_night(df, ts_col="MM:DD:YYYY hh:mm:ss", day_start=7, day_end=19):
    """Split ``df`` into (day_rows, night_rows) by hour of day.

    Args:
        df: session frame.
        ts_col: timestamp column passed on to ``_get_timestamp_series``.
        day_start: first hour counted as day, inclusive (default 7).
        day_end: first hour no longer counted as day, exclusive (default 19).

    Returns:
        Two frames. Rows with a valid timestamp whose hour lies in
        ``[day_start, day_end)`` are "day"; all other rows with a valid
        timestamp are "night". Rows without a valid timestamp appear in
        neither.
    """
    ts = _get_timestamp_series(df, ts_col=ts_col)
    valid = ts.notna()
    hrs = ts.dt.hour
    day_mask = valid & (hrs >= day_start) & (hrs < day_end)
    night_mask = valid & ~day_mask
    return df.loc[day_mask], df.loc[night_mask]

def compute_withinbout_lose_shift(c_df, max_gap_s=120):
    """Fraction of poke→poke transitions in which the side changes.

    Consecutive event pairs are counted when the first event is "Left" or
    "Right", the second one is also "Left" or "Right", and the two are at
    most ``max_gap_s`` seconds apart (pairs with an undefined gap are
    skipped).

    Returns:
        shifted / counted, or NaN when nothing is counted, the "Event" column
        is missing, there are fewer than two rows, or any error occurs.
    """
    try:
        if "Event" not in c_df.columns or len(c_df) < 2:
            return np.nan

        events = c_df["Event"].to_numpy()
        times = _get_timestamp_series(c_df).to_numpy()
        poke_sides = ("Left", "Right")
        one_second = np.timedelta64(1, "s")

        total = 0
        shifted = 0
        for i, (curr_evt, next_evt) in enumerate(zip(events[:-1], events[1:])):
            if curr_evt not in poke_sides:
                continue
            dt_s = (times[i + 1] - times[i]) / one_second
            if np.isnan(dt_s) or dt_s > max_gap_s:
                continue
            # A pellet (or any other non-poke event) next is not a transition.
            if next_evt not in poke_sides:
                continue
            total += 1
            if next_evt != curr_evt:
                shifted += 1

        if total > 0:
            return shifted / total
        return np.nan
    except Exception:
        return np.nan

def compute_withinbout_win_stay(c_df, max_gap_s=120):
    """Fraction of pellets after which the animal pokes the same side again.

    For every "Pellet" event that has a neighbour on both sides, the event
    before and the event after are compared. The pellet is counted when both
    neighbours are "Left"/"Right" and the following event occurs within
    ``max_gap_s`` seconds of the pellet.

    Returns:
        same / counted, or NaN when nothing is counted, the "Event" column is
        missing, there are fewer than three rows, or any error occurs.
    """
    try:
        if "Event" not in c_df.columns or len(c_df) < 3:
            return np.nan

        events = c_df["Event"].to_numpy()
        times = _get_timestamp_series(c_df).to_numpy()
        poke_sides = ("Left", "Right")

        total = 0
        same = 0
        for i in range(1, len(events) - 1):
            if events[i] != "Pellet":
                continue
            prev_event = events[i - 1]
            next_event = events[i + 1]
            dt_s = (times[i + 1] - times[i]) / np.timedelta64(1, "s")
            if np.isnan(dt_s) or dt_s > max_gap_s:
                continue
            if prev_event not in poke_sides or next_event not in poke_sides:
                continue
            total += 1
            if next_event == prev_event:
                same += 1

        if total > 0:
            return same / total
        return np.nan
    except Exception:
        return np.nan

def compute_peak_accuracy(c_df):
    """Percentage of pokes that were "Left" pokes.

    Computed as 100 * Left / (Left + Right) over the "Event" column.
    NOTE(review): despite the name, the value is an overall percentage for the
    frame passed in, not a peak over time — confirm the intended definition.

    Returns:
        A value in 0–100, or NaN when there are no pokes, the "Event" column
        is missing, the frame is empty, or any error occurs.
    """
    try:
        if "Event" not in c_df.columns or len(c_df) == 0:
            return np.nan

        event_col = c_df["Event"]
        n_left = (event_col == "Left").sum()
        n_pokes = n_left + (event_col == "Right").sum()

        if n_pokes > 0:
            return (n_left / n_pokes) * 100
        return np.nan
    except Exception:
        return np.nan


def estimate_daily_pellets(c_df):
    """Estimate pellets per 24 h from one session.

    The session length is the span between the earliest and latest valid
    timestamp. Pellets are counted from "Pellet_Count" when usable (sum of the
    positive step-to-step increases, falling back to last-minus-first when
    that sum is zero), otherwise from the number of "Pellet" rows in "Event".

    Returns:
        pellets / duration_hours * 24, or NaN when fewer than two valid
        timestamps exist, the span is not positive, or no pellet count can be
        derived.
    """
    valid_ts = _get_timestamp_series(c_df).dropna()
    if valid_ts.size < 2:
        return np.nan

    span = valid_ts.max() - valid_ts.min()
    duration_hours = span.total_seconds() / 3600.0
    if duration_hours <= 0:
        return np.nan

    pellet_events = np.nan
    has_counter = "Pellet_Count" in c_df.columns and c_df["Pellet_Count"].notna().any()
    if has_counter:
        pc = pd.to_numeric(c_df["Pellet_Count"], errors="coerce")
        if pc.notna().any():
            increases = pc.diff().fillna(0).clip(lower=0)
            pellet_events = float(increases.sum())
            first, last = pc.iloc[0], pc.iloc[-1]
            if pellet_events == 0 and last >= first:
                pellet_events = float(last - first)

    if pd.isna(pellet_events) and "Event" in c_df.columns:
        pellet_events = float((c_df["Event"] == "Pellet").sum())

    if pd.isna(pellet_events):
        return np.nan
    return (pellet_events / duration_hours) * 24.0

# ---------- Prepare metadata (merge once by filename) ----------
# Work on a copy so metadata_df itself stays untouched.
md = metadata_df.copy()

# Reduce paths to bare file names: that is the merge key used further down.
md['filename'] = md['filename'].astype(str).map(_basename)

# Mouse_ID as trimmed text; the column is created (all NaN) when absent.
if 'Mouse_ID' not in md.columns:
    md['Mouse_ID'] = np.nan
else:
    md['Mouse_ID'] = md['Mouse_ID'].astype(str).str.strip()

# Keep only metadata columns we care about; rename to avoid accidental dupes
# (add/remove columns as needed)
_wanted_meta = {"filename", "Mouse_ID", "Session_type", "Cohort", "Strain", "Sex"}
meta_keep = [c for c in md.columns if c in _wanted_meta]

# One metadata row per file; the first occurrence wins.
md_clean = md[meta_keep].drop_duplicates(subset=["filename"], keep="first")

# ---------- Compute metrics on the chosen sessions ----------
# One metrics dict per session; a session that raises is reported and skipped.
rows = []
for idx, c_df in enumerate(tqdm(_sessions)):
    file_name = _basename(getattr(c_df, "name", f"File_{idx}"))

    try:
        # Retrieval times: numeric only, keeping values below 5.
        retrieval_raw = c_df.get("Retrieval_Time", pd.Series(dtype=float))
        clean_retrieval_time = pd.to_numeric(retrieval_raw, errors="coerce")
        clean_retrieval_time = clean_retrieval_time[clean_retrieval_time < 5]

        # Poke times: numeric only, keeping strictly positive values.
        poke_raw = c_df.get("Poke_Time", pd.Series(dtype=float))
        clean_poke_time = pd.to_numeric(poke_raw, errors="coerce")
        clean_poke_time = clean_poke_time[clean_poke_time > 0]

        day_df, night_df = _split_day_night(c_df, ts_col="MM:DD:YYYY hh:mm:ss")

        row = {
            "filename": file_name,
            "PeakAccuracy": compute_peak_accuracy(c_df),
            "Total_pellets": f3b.count_pellets(c_df),
            "Total_pokes": f3b.count_pokes(c_df),
            "PokesPerPellet": f3b.pokes_per_pellet(c_df),
            "RetrievalTime": np.nan if clean_retrieval_time.empty else clean_retrieval_time.median(),
            "PokeTime": np.nan if clean_poke_time.empty else clean_poke_time.median(),
            "Win-stay": compute_withinbout_win_stay(c_df),
            "Lose-shift": compute_withinbout_lose_shift(c_df),
            "daily pellets": estimate_daily_pellets(c_df),
            "PeakAccuracy_Day": compute_peak_accuracy(day_df),
            "PeakAccuracy_Night": compute_peak_accuracy(night_df),
            "Win-stay_Day": compute_withinbout_win_stay(day_df),
            "Win-stay_Night": compute_withinbout_win_stay(night_df),
            "Lose-shift_Day": compute_withinbout_lose_shift(day_df),
            "Lose-shift_Night": compute_withinbout_lose_shift(night_df),
        }
        rows.append(row)

    except Exception as e:
        print(f"Failed on {file_name} (idx {idx}): {e}")


# Assemble the table, attach per-file metadata and drop repeated column names.
FR1_metrics = pd.DataFrame(rows).merge(md_clean, on="filename", how="left")
FR1_metrics = FR1_metrics.loc[:, ~FR1_metrics.columns.duplicated()]

# Written to the working directory so the download button can serve it.
csv_name = "FR1_metrics.csv"
FR1_metrics.to_csv(csv_name, index=False)

from google.colab import files
import ipywidgets as widgets
from IPython.display import display

def download_csv(b):
    """Button callback: send the saved metrics CSV to the browser.

    ``b`` is the clicked button passed in by ipywidgets; it is not used.
    """
    files.download(csv_name)

# Button that triggers the download of the CSV written just above.
download_button = widgets.Button(
    description="⬇️ Download summary stats (CSV)",
    button_style="primary",
)

download_button.on_click(download_csv)
display(download_button)


In [None]:
# @title Plot Female vs Male
import os, time, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output

try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

# -----------------------
# Config
# -----------------------
# Column that defines the two groups being compared.
GROUP_COL = "Sex"

# Metrics to plot, in panel order (row-major across the 2x4 grid).
metrics = [
    "daily pellets",
    "Total_pokes",
    "PokesPerPellet",
    "PokeTime",
    "PeakAccuracy",
    "Win-stay",
    "Lose-shift",
    "RetrievalTime",
]

# Bar colour per sex code.
COLOR_MAP = dict(F="red", M="dodgerblue")

# -----------------------
# Preconditions
# -----------------------
_fr1_table = globals().get("FR1_metrics")
if _fr1_table is None or _fr1_table.empty:
    raise RuntimeError("FR1_metrics is missing/empty. Run the metrics cell first.")

# Working copy; the metrics table itself is left untouched.
bm = FR1_metrics.copy()

# -----------------------
# Helper: normalize Sex to F/M/UNK
# -----------------------
def _norm_sex(x):
    s = str(x).strip().upper()
    if s in {"F", "FEMALE", "FEM"}: return "F"
    if s in {"M", "MALE"}: return "M"
    return "UNK"

# -----------------------
# Ensure we have Sex (merge from metadata_df if needed)
# -----------------------
if GROUP_COL not in bm.columns or bm[GROUP_COL].isna().all():

    if "metadata_df" not in globals() or metadata_df is None or metadata_df.empty:
        raise RuntimeError("Sex not found in FR1_metrics and metadata_df is missing/empty.")

    # An existing but all-NaN Sex column would collide with the one merged in
    # below (pandas would produce Sex_x / Sex_y and no "Sex"), so drop it first.
    bm = bm.drop(columns=[GROUP_COL], errors="ignore")

    meta = metadata_df.copy()

    # Normalize column names for robust lookup
    def _find_col(df, name):
        """Return the column of ``df`` matching ``name`` case-insensitively, else None."""
        lc = {str(c).strip().lower(): c for c in df.columns}
        return lc.get(name.lower(), None)

    meta_sex = _find_col(meta, "Sex")
    if meta_sex is None:
        raise RuntimeError("metadata_df does not contain a 'Sex' column (case-insensitive).")

    # Prefer merging by Mouse_ID if present in both
    meta_mouse = _find_col(meta, "Mouse_ID")
    bm_mouse   = _find_col(bm, "Mouse_ID")

    merged = None

    if meta_mouse is not None and bm_mouse is not None:
        meta_key = meta[[meta_mouse, meta_sex]].copy()
        meta_key.columns = ["Mouse_ID", "Sex"]
        meta_key["Mouse_ID"] = meta_key["Mouse_ID"].astype(str).str.strip()
        meta_key = meta_key.dropna(subset=["Mouse_ID"]).drop_duplicates("Mouse_ID")

        bm["Mouse_ID"] = bm[bm_mouse].astype(str).str.strip()
        merged = bm.merge(meta_key, on="Mouse_ID", how="left")

    # Fallback: merge by filename basename if Mouse_ID isn’t available
    if merged is None or merged["Sex"].isna().all():
        # build/ensure filename columns
        if "filename" not in bm.columns:
            if "File" in bm.columns:
                bm["filename"] = bm["File"].astype(str)
            else:
                raise RuntimeError("Need either Mouse_ID or filename/File in FR1_metrics to merge Sex.")

        if "filename" not in meta.columns:
            # if metadata_df already has filename, great; otherwise cannot fallback
            raise RuntimeError("Cannot fallback merge: metadata_df has no filename column.")

        bm["file_base"]   = bm["filename"].astype(str).apply(lambda p: os.path.basename(p))
        meta["file_base"] = meta["filename"].astype(str).apply(lambda p: os.path.basename(p))

        meta_key = meta[["file_base", meta_sex]].copy()
        meta_key.columns = ["file_base", "Sex"]
        meta_key = meta_key.dropna(subset=["file_base"]).drop_duplicates("file_base")

        merged = bm.merge(meta_key, on="file_base", how="left").drop(columns=["file_base"])

    bm = merged

# Normalize Sex values
# Normalize Sex values to F / M / UNK
bm[GROUP_COL] = bm[GROUP_COL].apply(_norm_sex)

# Keep only F/M for plotting (drop UNK)
bm = bm[bm[GROUP_COL].isin(["F", "M"])].copy()
if bm.empty:
    raise RuntimeError("After merging/normalizing Sex, no rows with Sex in {F, M} were found.")

# -----------------------
# Long format
# -----------------------
# Only metrics that actually exist as columns can be plotted.
value_vars = [m for m in metrics if m in bm.columns]
if not value_vars:
    raise RuntimeError("None of the expected metric columns were found in FR1_metrics.")

# One row per (session, metric); the identifier columns are carried along.
id_vars = [c for c in ["filename", "Mouse_ID", "Strain", GROUP_COL] if c in bm.columns]
long_df = bm.melt(id_vars=id_vars, value_vars=value_vars, var_name="metric", value_name="value")

# consistent order
groups = [g for g in ["F", "M"] if g in long_df[GROUP_COL].unique()]
if len(groups) < 2:
    raise RuntimeError(f"Need both F and M present to compare; found: {groups}")

# -----------------------
# Stats
# -----------------------
def welch_p(a, b):
    """p-value of an unpaired t-test between samples ``a`` and ``b``.

    Missing values are dropped first; NaN is returned when either sample has
    fewer than two values left. The test is delegated to ``pg.ttest`` with
    ``paired=False``.
    NOTE(review): whether Welch's correction is actually applied depends on
    pingouin's default ``correction`` setting — confirm against its docs.
    """
    sample_a, sample_b = (pd.Series(values, dtype=float).dropna() for values in (a, b))
    if min(len(sample_a), len(sample_b)) < 2:
        return np.nan
    result = pg.ttest(sample_a, sample_b, paired=False)
    return float(result["p-val"].iat[0])

# -----------------------
# Plot UI
# -----------------------
# Output area for the figure and messages, the save button, and a handle on
# the most recently drawn figure (None until run_plots has been called).
out = widgets.Output()
save_btn = widgets.Button(description="Save PDF", button_style="success")
_last_fig = None

def run_plots():
    """Draw the 2x4 grid of F-vs-M panels into ``out``.

    Each panel shows a bar (with no error bar requested) plus the individual
    points for one metric, with the ``welch_p`` p-value written above it.
    Metrics missing from the data leave their panel blank. The figure is
    remembered in ``_last_fig`` for saving.
    """
    global _last_fig
    with out:
        clear_output()

        fig, axes = plt.subplots(2, 4, figsize=(8, 6), constrained_layout=True)
        axes = axes.ravel()

        # Same palette for every panel.
        pal = {g: COLOR_MAP[g] for g in groups}
        first_group, second_group = groups[0], groups[1]

        for i, metric in enumerate(metrics):
            ax = axes[i]
            if metric not in value_vars:
                ax.set_axis_off()
                continue

            is_metric = long_df["metric"] == metric
            dfm = long_df[is_metric].dropna(subset=["value"])

            sns.barplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups, ci=None, alpha=0.6,
                palette=pal, ax=ax
            )
            sns.stripplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups,
                color="white", edgecolor="black",
                linewidth=1, size=6,
                alpha=0.35, jitter=True, ax=ax
            )

            p = welch_p(
                dfm.loc[dfm[GROUP_COL] == first_group, "value"],
                dfm.loc[dfm[GROUP_COL] == second_group, "value"],
            )
            if not np.isfinite(p):
                label = "p=NA"
            elif p < 0.001:
                label = "p<0.001"
            else:
                label = f"p={p:.3f}"

            ax.set_title("")
            ax.set_xlabel("")
            ax.set_ylabel(metric, fontsize=12)
            ax.text(0.5, 1.02, label, transform=ax.transAxes, ha="center", va="bottom")

            sns.despine(ax=ax)

        plt.show()
        _last_fig = fig

def save_plots(_=None):
    """Save the last drawn figure as a time-stamped PDF.

    Usable as a button callback (the argument is ignored). Prints a notice
    into ``out`` when no figure exists yet; on Colab the saved file is also
    offered for download.
    """
    if _last_fig is not None:
        fname = f"metrics_grid_{int(time.time())}.pdf"
        _last_fig.savefig(fname, dpi=300, bbox_inches="tight")
        with out:
            print(f"Saved {fname}")
        if colab_files is not None:
            colab_files.download(fname)
    else:
        with out:
            print("Nothing to save yet.")

# Hook up the button, show the widgets, then draw once.
save_btn.on_click(save_plots)

display(save_btn)
display(out)

run_plots()


In [None]:
#TODO ADD AVERAGE LEARNING CURVES























In [None]:
# TODO: Rename "Bandit80" to "FR1" in the cells below

In [None]:
#TODO ADD HEATMAP OF BANDIT/FR1

#@title Correlation plots
import os
import subprocess
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr

# Pre-computed summary tables, read straight from the repository.
# NOTE(review): `Bandit80metrics` is loaded from FR1_metrics.csv; the
# "Bandit80" naming here and below appears to be carried over from another
# notebook — confirm before renaming.
Bandit100metrics  = pd.read_csv(r"https://raw.githubusercontent.com/KravitzLab/Murrell2025/refs/heads/main/Data/SummaryStats/Bandit100_metrics.csv")
Bandit80metrics = pd.read_csv(r"https://raw.githubusercontent.com/KravitzLab/Murrell2025/refs/heads/main/Data/SummaryStats/FR1_metrics.csv")

print("Loaded Bandit80metrics:", Bandit80metrics.shape)
print("Loaded Bandit100_metrics:", Bandit100metrics.shape)

# -------------------------------------------------------------------
# 2) Match metrics to Key_Df by Mouse_ID
#    (assumes Key_Df already exists in the environment)
# -------------------------------------------------------------------
# Stop early with a clear message if the data-import cell was not run.
if "Key_Df" not in globals():
    raise RuntimeError("Key_Df is not defined. Build/rematch Key_Df first.")

# Ensure Mouse_ID is string and stripped in all tables
def _clean_mouse_id(df, col="Mouse_ID"):
    if col in df.columns:
        df[col] = df[col].astype(str).str.strip()
    return df

# Clean Mouse_ID on copies of all three tables.
# NOTE(review): this rebinds the notebook-wide `Key_Df` to the cleaned copy,
# so cells run afterwards see the cleaned version — confirm that is intended.
Key_Df = _clean_mouse_id(Key_Df.copy())
Bandit80metrics  = _clean_mouse_id(Bandit80metrics.copy())
# The cleaned Bandit100 table gets a new name (with an underscore); the
# originally loaded `Bandit100metrics` is left as it was.
Bandit100_metrics = _clean_mouse_id(Bandit100metrics.copy())

# Add suffixes to metric columns so we can have both in the same table
id_cols = {"Mouse_ID", "filename", "Session_type", "Cohort", "Strain", "Sex"}

def _add_suffix_to_metrics(df, suffix):
    df = df.copy()
    rename_map = {
        c: f"{c}{suffix}"
        for c in df.columns
        if c not in id_cols
    }
    return df.rename(columns=rename_map)

b80_suffixed  = _add_suffix_to_metrics(Bandit80metrics,  "_Bandit80")
b100_suffixed = _add_suffix_to_metrics(Bandit100_metrics, "_Bandit100")

# Merge the two metrics tables on Mouse_ID. Metric columns are already
# suffixed; "_dup" only applies to identifier columns present in both tables.
metrics_merged = pd.merge(
    b80_suffixed,
    b100_suffixed,
    on="Mouse_ID",
    how="outer",
    suffixes=("", "_dup")
)

print("Merged metrics table shape:", metrics_merged.shape)

# Merge metrics into Key_Df (every Key_Df row is kept)
Key_with_metrics = pd.merge(
    Key_Df,
    metrics_merged,
    on="Mouse_ID",
    how="left"
)

print("Key_with_metrics shape:", Key_with_metrics.shape)

# For convenience, use a copy of this DataFrame for correlations/plots:
df = Key_with_metrics.copy()

# In case Sex wasn't carried through (e.g. it was renamed by the merge because
# both sides had a Sex column), re-merge it from Key_Df.
if "Sex" not in df.columns and "Mouse_ID" in df.columns:
    sex_map = Key_Df[["Mouse_ID", "Sex"]].copy()
    sex_map["Mouse_ID"] = sex_map["Mouse_ID"].astype(str).str.strip()
    # Exactly one Sex per mouse: without this, a mouse that appears on several
    # Key_Df rows would multiply the rows of df and inflate n downstream.
    sex_map = sex_map.dropna(subset=["Sex"]).drop_duplicates("Mouse_ID")
    df["Mouse_ID"] = df["Mouse_ID"].astype(str).str.strip()
    df = df.merge(sex_map, on="Mouse_ID", how="left")

print("Columns in df:", df.columns.tolist())
# -------------------------------------------------------------------
# 3) Correlate Bandit80 vs Bandit100 metrics and plot
#    Columns expected to exist after suffixing:
#    PeakAccuracy_Bandit80, PeakAccuracy_Bandit100, etc.
# -------------------------------------------------------------------
# (x column, y column, panel label) for each correlation panel.
pairs = [
    ("PeakAccuracy_Bandit80", "PeakAccuracy_Bandit100", "PeakAccuracy"),
    ("Win-stay_Bandit80",     "Win-stay_Bandit100",     "Win-stay"),
    ("Lose-shift_Bandit80",   "Lose-shift_Bandit100",   "Lose-shift"),
]

# Per-subplot axis limits (edit these to taste)
# NOTE(review): the key "Peak Accuracy" (with a space) never matches the label
# "PeakAccuracy" used in `pairs`, so that panel currently keeps automatic
# limits. Check the scale of the PeakAccuracy columns (fraction vs percent)
# before aligning the key.
axis_limits = {
    "Peak Accuracy": {"x": (0.4, 1.0), "y": (0.4, 1.0)},
    "Win-stay":      {"x": (0.4, 1.0), "y": (0.4, 1.0)},
    "Lose-shift":    {"x": (0.0, 1.0), "y": (0.0, 1.0)},
}

# Same colours as the rest of the notebook: females red, males blue.
sex_palette = {"Male": "dodgerblue", "Female": "red"}

fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=False, sharey=False)

for ax, (col80, col100, label) in zip(axes, pairs):
    needed_cols = [col80, col100, "Sex"]
    missing = [c for c in needed_cols if c not in df.columns]
    if missing:
        print(f"Skipping {label}: missing columns {missing}")
        continue

    sub = df[needed_cols].dropna()
    if len(sub) < 2:
        print(f"{col80} vs {col100}: not enough data (n={len(sub)})")
        continue

    # Normalize Sex values to 'Male' / 'Female'; anything else becomes NaN
    sub = sub.copy()
    sub["Sex_plot"] = (
        sub["Sex"]
        .astype(str)
        .str.strip()
        .str.lower()
        .map({"m": "Male", "f": "Female", "male": "Male", "female": "Female"})
    )
    sub = sub.dropna(subset=["Sex_plot"])
    if len(sub) < 2:
        print(f"{col80} vs {col100}: not enough data after Sex mapping (n={len(sub)})")
        continue

    # Pearson correlation
    r, p = pearsonr(sub[col80], sub[col100])

    # Regression line in grey on this axis
    sns.regplot(
        data=sub,
        x=col80,
        y=col100,
        ci=None,
        scatter=False,
        line_kws={"color": "grey", "linewidth": 2},
        ax=ax,
    )

    # Sex-colored scatter on this axis
    sns.scatterplot(
        data=sub,
        x=col80,
        y=col100,
        hue="Sex_plot",
        palette=sex_palette,
        edgecolor="black",
        ax=ax,
        legend=(ax is axes[-1]),  # only show legend on last
    )

    # Set per-subplot axis limits
    lims = axis_limits.get(label, None)
    if lims is not None:
        ax.set_xlim(*lims["x"])
        ax.set_ylim(*lims["y"])

    ax.set_xlabel(f"{label} (Bandit80)")
    ax.set_ylabel(f"{label} (Bandit100)")
    ax.set_title(f"{label}\n n={len(sub)}, r={r:.3f}, p={p:.3g}")
    sns.despine()
    ax.grid (False)

plt.tight_layout()
plt.show()

In [None]:
#@title Correlation heatmap
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Cross-correlates every numeric "*_Bandit80" column of Key_with_metrics with
# every numeric "*_Bandit100" column (pairwise-complete Pearson r) and draws the
# result as a heatmap in which only |r| >= threshold cells are colored/annotated.

# ------------------------------------------------------------------
# 1. Ensure merged dataset exists
# ------------------------------------------------------------------
if "Key_with_metrics" not in globals():
    raise RuntimeError("Key_with_metrics not found — run merge first.")

df = Key_with_metrics.copy()

# ------------------------------------------------------------------
# 1b. Drop any columns containing "Day" or "Night", and "Total_pellets"
# ------------------------------------------------------------------
# isinstance guard: a non-string column label would otherwise raise TypeError
# on the substring test.
cols_to_drop = [
    c for c in df.columns
    if isinstance(c, str) and (("Day" in c) or ("Night" in c))
]
cols_to_drop.extend(["Total_pellets_Bandit100", "Total_pellets_Bandit80"])

df = df.drop(columns=cols_to_drop, errors="ignore")

# ------------------------------------------------------------------
# 2. Identify all Bandit80 and Bandit100 metric columns
# ------------------------------------------------------------------
suffix80 = "_Bandit80"
suffix100 = "_Bandit100"

b80_cols = sorted([c for c in df.columns if isinstance(c, str) and c.endswith(suffix80)])
b100_cols = sorted([c for c in df.columns if isinstance(c, str) and c.endswith(suffix100)])

# Keep only numeric columns
b80_cols = [c for c in b80_cols if pd.api.types.is_numeric_dtype(df[c])]
b100_cols = [c for c in b100_cols if pd.api.types.is_numeric_dtype(df[c])]

# Fail with a clear message instead of a cryptic error from the plotting call.
if not b80_cols or not b100_cols:
    raise RuntimeError(
        "No numeric metric columns to correlate "
        f"(found {len(b80_cols)} '*{suffix80}' and {len(b100_cols)} '*{suffix100}' columns)."
    )

# ------------------------------------------------------------------
# 3. Build a cross-correlation matrix: rows=Bandit100, columns=Bandit80
# ------------------------------------------------------------------
cross_corr = pd.DataFrame(index=b100_cols, columns=b80_cols, dtype=float)

for col100 in b100_cols:
    for col80 in b80_cols:
        # Pairwise-complete observations for this pair of metrics
        sub = df[[col100, col80]].dropna()
        if len(sub) < 2:
            cross_corr.loc[col100, col80] = np.nan
        else:
            cross_corr.loc[col100, col80] = sub[col100].corr(sub[col80])

# ------------------------------------------------------------------
# 3b. Order rows/columns by number of strong correlations (|r| >= 0.30)
# ------------------------------------------------------------------

threshold = 0.30

# Boolean mask of strong correlations (NaN compares False, i.e. "not strong")
strong_mask = cross_corr.abs() >= threshold

# Convert row/column names to base metric names. Every label is known to end
# with its suffix (see step 2), so slice the suffix off rather than replacing
# the substring wherever it occurs; this keeps the round trip in the reorder
# step below exact.
row_base = cross_corr.index.str[: -len(suffix100)]
col_base = cross_corr.columns.str[: -len(suffix80)]

# Count strong correlations per base metric. Summing first and grouping the
# resulting Series avoids DataFrame.groupby(axis=1), which is deprecated in
# pandas 2.1 and removed in pandas 3.0.
row_counts = strong_mask.sum(axis=1).groupby(row_base).sum()
col_counts = strong_mask.sum(axis=0).groupby(col_base).sum()

# Combine counts (rows + columns) → single ordering
total_counts = row_counts.add(col_counts, fill_value=0)

# Stable sort: metrics with equal counts keep a deterministic order.
base_order = total_counts.sort_values(ascending=False, kind="stable").index.tolist()

# Rebuild ordered row/column labels (a metric may exist in only one condition)
row_order = [f"{b}{suffix100}" for b in base_order if f"{b}{suffix100}" in cross_corr.index]
col_order = [f"{b}{suffix80}" for b in base_order if f"{b}{suffix80}" in cross_corr.columns]

# Reorder the matrix
cross_corr_sorted = cross_corr.loc[row_order, col_order]

# ------------------------------------------------------------------
# 4. Plot
# ------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(7, 5))

# Base layer: all correlations in greyscale, NO annotations, no colorbar
sns.heatmap(
    cross_corr_sorted,
    cmap="Greys_r",
    vmin=-1, vmax=1,
    center=0,
    annot=False,
    linewidths=0.5,
    cbar=False,
    ax=ax,
)

# Second layer: only strong correlations (|r| >= threshold), colored + annotated
strong_corr = cross_corr_sorted.where(cross_corr_sorted.abs() >= threshold)

sns.heatmap(
    strong_corr,
    cmap="coolwarm",
    vmin=-1, vmax=1,
    center=0,
    annot=True,
    fmt=".2f",
    linewidths=0.5,
    mask=strong_corr.isna(),  # leave weak/undefined cells showing the grey layer
    cbar_kws={"shrink": 0.6},
    ax=ax,
)

# Remove "_Bandit80" and "_Bandit100" from axis labels
clean_x = [label.get_text().replace("_Bandit80", "").replace("_Bandit100", "")
           for label in ax.get_xticklabels()]
clean_y = [label.get_text().replace("_Bandit80", "").replace("_Bandit100", "")
           for label in ax.get_yticklabels()]

ax.set_xticklabels(clean_x, rotation=45, ha="right")
ax.set_yticklabels(clean_y)

# Explicit axes calls: do not depend on which axes pyplot considers "current"
# after the colorbar has been added.
ax.set_title("")
ax.set_xlabel("Bandit80 Metrics")
ax.set_ylabel("Bandit100 Metrics")
fig.tight_layout()
plt.show()