<a href="https://colab.research.google.com/github/KravitzLab/Murrell2025/blob/main/Murrell_2026_Fig1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Murrell 2026 Figure 1
<br>
<img src="https://fed3bandit.readthedocs.io/en/latest/_static/fed3bandit_logo1.svg" width="200" />

Authors: Chantelle Murrell<br>
Updated: 12-30-25  

In [None]:
# @title Install libraries and import them {"run":"auto"}

import importlib.util
import subprocess
import sys

# Packages to ensure are installed, keyed by import name -> pip requirement.
# NOTE(review): none of these requirements are version-pinned, so a fresh
# runtime may pick up newer releases than the ones the figure was built with.
packages = {
    "fed3": "git+https://github.com/earnestt1234/fed3.git",
    "fed3bandit": "fed3bandit",
    "pingouin": "pingouin",
    "ipydatagrid": "ipydatagrid",
    "openpyxl": "openpyxl",
}

_installed_any = False
for name, source in packages.items():
    # Only install what cannot already be imported in this runtime.
    if importlib.util.find_spec(name) is None:
        print(f"Installing {name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", source])
        _installed_any = True

if _installed_any:
    # The import system caches directory listings; without this, packages that
    # were installed a moment ago may not be found by the imports that follow.
    importlib.invalidate_caches()

# ----------------------------
# Imports
# ----------------------------
# Standard library
import copy
import io
import math
import os
import re
import shutil
import tempfile
import threading
import time
import warnings
import zipfile
import requests
import glob
from datetime import datetime, timedelta
from os.path import basename, splitext

# Third-party
from ipydatagrid import DataGrid, TextRenderer
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pingouin as pg
import fed3
import fed3.plot as fplot
import fed3bandit as f3b
from scipy.stats import f_oneway
import statsmodels.api as sm
from statsmodels.formula.api import ols
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.gridspec as gridspec
from matplotlib.ticker import PercentFormatter
from google.colab import files
try:
    from tqdm.auto import tqdm   # nice in notebooks; falls back to std tqdm on console
except Exception:
    # Safe no-op fallback if tqdm isn't installed. It accepts (and ignores) the
    # same extra arguments as the real tqdm, e.g. desc=..., total=..., so call
    # sites written against tqdm keep working.
    def tqdm(x, *args, **kwargs):
        """Return the iterable ``x`` unchanged (progress reporting disabled)."""
        return x



# ----------------------------
# Configuration
# ----------------------------
# Suppress every warning category for the rest of the session.
warnings.filterwarnings("ignore")

# Global matplotlib defaults, applied in a single update call.
plt.rcParams.update({
    'font.size': 12,
    'figure.autolayout': True,
    'figure.figsize': [6, 4],
    'figure.dpi': 100,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

print("Packages installed and imports ready.")


In [None]:
# @title Import FED3 Bandit 100-0 Data

from urllib.request import urlretrieve
import os, zipfile, shutil
import pandas as pd
import numpy as np

# Enable Colab's custom widget manager when running on Colab; any failure
# (e.g. the google.colab package is absent) is deliberately ignored.
try:
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()
except Exception:
    pass

# Remote sources: the zipped session CSVs and the per-mouse key table.
zip_url = "https://github.com/KravitzLab/Murrell2025/raw/refs/heads/main/Data/Bandit100.zip"
key_url = "https://github.com/KravitzLab/Murrell2025/raw/refs/heads/main/Data/Murrell2026_Key.csv"

# Local working locations (hard-coded under /content).
zip_dir = "/content/Murrell2025_zipdata"
zip_path = os.path.join(zip_dir, "Bandit100.zip")
extract_root = os.path.join(zip_dir, "Bandit100_extracted")
key_path = os.path.join(zip_dir, "Murrell2026_Key.csv")

os.makedirs(zip_dir, exist_ok=True)

# download + unzip (fresh each run): remove any previous archive and
# extraction folder so stale files cannot leak into this run
if os.path.exists(zip_path):
    os.remove(zip_path)
if os.path.isdir(extract_root):
    shutil.rmtree(extract_root)

print("Importing github.com/KravitzLab/Murrell2025/Data/Bandit100.zip ...")
urlretrieve(zip_url, zip_path)

os.makedirs(extract_root, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall(extract_root)

# if the zip contains a Bandit100 folder, use it; otherwise use extract_root
bandit_root = os.path.join(extract_root, "Bandit100")
local_parent_path = bandit_root if os.path.isdir(bandit_root) else extract_root

# load CSVs
# Three parallel lists, one entry per loaded file:
#   feds          - the frame returned by fed3.load
#   loaded_files  - full path of the CSV
#   session_types - first non-null Session_Type value (None if unavailable)
feds, loaded_files, session_types = [], [], []

for dirpath, _dirnames, filenames in os.walk(local_parent_path):
    # Prune macOS archive metadata and sort in place so that os.walk visits
    # sub-folders in a deterministic order (file indices stay stable across runs).
    _dirnames[:] = sorted(d for d in _dirnames if d != "__MACOSX")

    for file_name in sorted(filenames):
        # Skip non-CSV files and "._*" resource-fork stubs, which are not FED3 data.
        if file_name.startswith("._") or not file_name.lower().endswith(".csv"):
            continue

        file_path = os.path.join(dirpath, file_name)
        # The immediate parent folder name is used as the strain label.
        strain_name = os.path.basename(os.path.dirname(file_path))

        df = fed3.load(file_path)
        df.name = file_name
        df["Strain"] = strain_name
        df["SourceFile"] = file_name

        feds.append(df)
        loaded_files.append(file_path)

        # Tolerate files without a Session_Type column instead of aborting the
        # whole import; the session type is only used for the overview table.
        if "Session_Type" in df.columns:
            st = df["Session_Type"].dropna().astype(str)
            session_types.append(st.iloc[0] if len(st) else None)
        else:
            session_types.append(None)

print(f"Loaded {len(feds)} CSV files.")

# download + load key
urlretrieve(key_url, key_path)

key_df = pd.read_csv(key_path, encoding="utf-8-sig")
# Strip whitespace from the IDs but keep missing IDs missing. A blanket
# astype(str) would turn NaN into the literal string "nan", which would then
# take part in the substring matching below.
key_df["Mouse_ID"] = key_df["Mouse_ID"].map(
    lambda v: v if pd.isna(v) else str(v).strip()
)

# match Mouse_ID by substring in filename base
def _base_lower(p):
    """Return the lower-cased file name of ``p`` without directory or extension."""
    return os.path.splitext(os.path.basename(p))[0].lower()

files_df = pd.DataFrame({"filename": loaded_files, "Session_type": session_types})
files_df["_base"] = files_df["filename"].map(_base_lower)

# Candidate IDs: unique, non-missing, non-empty.
mouse_ids = [mid for mid in key_df["Mouse_ID"].dropna().unique().tolist() if mid]

rows = []
for fname, base in zip(files_df["filename"], files_df["_base"]):
    hits = [mid for mid in mouse_ids if mid.lower() in base]
    # When several IDs are substrings of the file name (e.g. "..._67" and
    # "..._670"), the longest one is the most specific match.
    rows.append({"filename": fname, "Mouse_ID": max(hits, key=len) if hits else None})

# Explicit columns keep the merge key present even when no files were loaded.
matched = pd.DataFrame(rows, columns=["filename", "Mouse_ID"])

# One metadata row per real Mouse_ID. Rows without an ID are dropped so that
# unmatched files (Mouse_ID missing) cannot be joined to them.
_key_unique = key_df.dropna(subset=["Mouse_ID"]).drop_duplicates("Mouse_ID")

Key_Df = (
    files_df.drop(columns=["_base"])
    .merge(matched, on="filename", how="left")
    .merge(_key_unique, on="Mouse_ID", how="left")
)

# display
grid = DataGrid(
    Key_Df.reset_index(drop=True),
    editable=True,
    selection_mode="cell",
    layout={"height": "420px"},
    base_row_size=28,
    base_column_size=120,
)
grid.default_renderer = TextRenderer(text_wrap=True)
display(grid)

In [None]:
# @title Plot individual files (males are blue, females are red)
# ----- Inputs -----
# This cell needs the `feds` list and the `Key_Df` table to exist as globals;
# fail early with a readable message if they do not.
assert 'feds' in globals() and isinstance(feds, list) and len(feds) > 0, "No FED3 files loaded."
assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."

# Column header looked up first by _find_time_col below.
TIMESTAMP_COL_CANON = "MM:DD:YYYY hh:mm:ss"  # primary target column name

# ----- Helpers -----
def _find_time_col(df):
    """Return the name of the timestamp column in ``df``, or None.

    The canonical header (``TIMESTAMP_COL_CANON``) wins if present. Otherwise
    column names are compared after stripping, lower-casing and removing
    spaces against a small set of accepted spellings.
    """
    columns = df.columns
    if TIMESTAMP_COL_CANON in columns:
        return TIMESTAMP_COL_CANON

    accepted = ("mm:dd:yyyyhh:mm:ss", "mm:dd:yyyy_hh:mm:ss", "mm/dd/yyyyhh:mm:ss")

    # normalised name -> original column label (a later duplicate overwrites the label)
    normalized = {}
    for col in columns:
        normalized[str(col).strip().lower()] = col

    return next(
        (label for norm, label in normalized.items() if norm.replace(" ", "") in accepted),
        None,
    )

def _parse_ts(series):
    # robust parsing; coerce errors to NaT
    return pd.to_datetime(series, errors="coerce", infer_datetime_format=True)


# ----- Plotting
# The sessions browsed by the slider are the full (uncropped) `feds` list.
files_list = feds

# metadata_df = copy of Key_Df
metadata_df = Key_Df.copy().reset_index(drop=True)
# Reduce stored paths to bare file names so they can be compared with each
# session's `name` attribute in _plot_file_core.
if 'filename' in metadata_df.columns:
    metadata_df['filename'] = metadata_df['filename'].astype(str).map(os.path.basename)

def _coerce_numeric_col(df, col, clip_upper=None, na_map=None):
    if col not in df.columns:
        return
    s = df[col]
    if na_map:
        s = s.replace(na_map)
    s = pd.to_numeric(s, errors='coerce')
    if clip_upper is not None:
        s.loc[s > clip_upper] = np.nan
    df[col] = s

def _plot_file_core(file_index):
    """Draw one session: the task's left-probability trace and the mouse's binned choices.

    Parameters
    ----------
    file_index : int
        Position of the session in ``files_list``.

    The session is copied, metadata from ``metadata_df`` is attached by file
    name, two timing columns are cleaned, and a single figure is shown with
    ``plt.show()``. Nothing is returned.
    """
    df = files_list[file_index].copy()
    full_name = getattr(df, 'name', f"File_{file_index}")
    file_basename = os.path.basename(str(full_name))

    # Preserve original index once
    if "Original_Timestamp" not in df.columns:
        df["Original_Timestamp"] = df.index

    # Attach metadata by filename (matching already done upstream)
    meta_row = None
    if 'filename' in metadata_df.columns:
        mr = metadata_df.loc[metadata_df['filename'] == file_basename]
        if not mr.empty:
            # If several metadata rows share this file name, the first one is used.
            meta_row = mr.iloc[0]
    if meta_row is not None:
        for col in meta_row.index:
            if col == 'filename':
                continue
            if col not in df.columns:
                df[col] = meta_row[col]
            else:
                # Existing session columns win unless they are entirely missing.
                if pd.isna(df[col]).all() and pd.notna(meta_row[col]):
                    df[col] = meta_row[col]

    # Time + cleanup
    # 'timestamp' is not read again inside this function.
    try:
        df['timestamp'] = pd.to_datetime(df.index)
    except Exception:
        df['timestamp'] = np.arange(len(df))

    # Poke_Time values above 2 become NaN; "Timed_out" retrievals become NaN.
    _coerce_numeric_col(df, 'Poke_Time', clip_upper=2)
    _coerce_numeric_col(df, 'Retrieval_Time', na_map={"Timed_out": np.nan})

    # NOTE(review): the message below says "cropped", but no cropping is
    # performed in this function.
    if len(df) == 0:
        print(f"[!] Empty cropped dataframe for {file_basename}. Skipping plot.")
        return

    # Behavioral traces (needs f3b)
    true_left = f3b.true_probs(df, offset=5)[0]
    mouse_left = f3b.binned_paction(df, window=10)

    # Plot
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(np.arange(len(true_left)), true_left, color="black", linewidth=2, alpha=0.5)

    # Trace colour: red when the first Sex value starts with "f", otherwise blue
    # (blue is also the default when Sex is unavailable).
    color = "dodgerblue"
    if 'Sex' in df.columns and pd.notna(df['Sex']).any():
        try:
            color = "red" if str(df['Sex'].iloc[0]).strip().lower().startswith("f") else "dodgerblue"
        except Exception:
            pass
    ax.plot(np.arange(len(mouse_left)), mouse_left, color=color, linewidth=3, alpha=0.7)

    # ---- Clean look: remove ticks, labels, spines ----
    ax.set_xlabel("")                 # no x label
    ax.set_ylabel("")                 # no y label
    ax.tick_params(axis="both", which="both",
                   bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labelleft=False)
    for s in ax.spines.values():      # remove axis lines
        s.set_visible(False)

    # ---- Textual y "labels" at y=1 and y=0 (not ticks) ----
    # NOTE(review): both plotted traces are named *_left, which suggests y=1
    # means "left", yet y=1 is labelled "Right" here. Confirm the orientation
    # against what f3b.true_probs / f3b.binned_paction return.
    ax.text(-0.01, 1.0, "Right", transform=ax.get_yaxis_transform(),
            ha="right", va="center")
    ax.text(-0.01, 0.0, "Left",  transform=ax.get_yaxis_transform(),
            ha="right", va="center")

    # ---- Title-area arrow above the axes ----
    # Runs from 5% to 95% of the axes width at y=1.05 (axes fraction);
    # raise or lower y if you need more/less space.
    ax.annotate("",
                xy=(0.95, 1.05), xytext=(0.05, 1.05),
                xycoords="axes fraction",
                arrowprops=dict(arrowstyle="->", lw=5, color="0.6"))
    # Caption above the arrow. The "3 days" text is hard-coded, not derived
    # from the session's timestamps.
    ax.text(0.4, 1.15, "3 days", transform=ax.transAxes, ha="left", va="bottom", color="0.5")
    sns.despine(left=True, bottom=True, top=True, right=True)
    plt.tight_layout()
    plt.show()

# ----- Simple UI: slider + status + output -----
N = len(files_list)
# NOTE(review): the message mentions cropping, but files_list is the uncropped `feds` list.
assert N > 0, "No files available after cropping."
# continuous_update=True re-renders on every intermediate slider position while dragging.
idx_slider = widgets.IntSlider(min=0, max=max(0, N-1), step=1, value=0, description='File', continuous_update=True)
status_lbl = widgets.HTML()
out = widgets.Output()

def _status(idx):
    """Return the HTML status line (index, file name, row count) for session ``idx``."""
    name = getattr(files_list[idx], 'name', f"File_{idx}")
    return f"Index: <b>{idx}</b> &nbsp;|&nbsp; File: <code>{os.path.basename(str(name))}</code> &nbsp;|&nbsp; Rows: {len(files_list[idx])}"

def _render(*_):
    """Slider callback: refresh the status line and redraw the plot for the selected session."""
    idx = int(idx_slider.value)
    status_lbl.value = _status(idx)
    out.clear_output()
    with out:
        _plot_file_core(idx)

# Redraw whenever the slider value changes, and once immediately for the initial value.
idx_slider.observe(_render, names='value')
display(widgets.VBox([idx_slider, status_lbl, out]))
_render()


In [None]:
# @title Analyze Bandit metrics
from pathlib import Path
import os
def _find_time_col(df):
    for c in ["MM:DD:YYYY hh:mm:ss", "DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if c in df.columns:
            return c
    return None

def _get_timestamp_series(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    import pandas as pd
    if ts_col in df.columns:
        ts = pd.to_datetime(df[ts_col], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        return pd.Series(ts, index=df.index)
    for cand in ["DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            return pd.Series(ts, index=df.index)
    idx = df.index
    if isinstance(idx, pd.DatetimeIndex):
        return pd.Series(idx, index=df.index)
    return pd.to_datetime(pd.Series(idx, index=df.index), errors="coerce")

def _crop_last_24h(df):
    """Return the rows of ``df`` that fall within the final 24 hours of the session.

    Timestamps come from the column reported by ``_find_time_col`` or, if there
    is none, from the index. The window is ``[latest - 24h, latest]``, both
    ends inclusive. When no usable timestamp exists the original ``df`` is
    returned unchanged. A ``name`` attribute on ``df`` is carried over to the
    cropped frame.
    """
    dfc = df.copy()
    ts_col = _find_time_col(dfc)

    # `infer_datetime_format` is deliberately not passed: it is deprecated in
    # pandas 2.x, where format inference is the default.
    if ts_col is not None:
        ts_series = pd.Series(pd.to_datetime(dfc[ts_col], errors="coerce"), index=dfc.index)
    else:
        try:
            ts_series = pd.Series(pd.to_datetime(dfc.index, errors="coerce"), index=dfc.index)
        except Exception:
            return df  # no usable timestamps → return original

    # max() is NaT for an empty or all-NaT series, so this one check covers both.
    end = ts_series.max()
    if pd.isna(end):
        return df
    start = end - pd.Timedelta(hours=24)

    mask = ts_series.between(start, end, inclusive="both")
    cropped = dfc.loc[mask]

    if hasattr(df, "name"):
        cropped.name = df.name
    return cropped

def build_feds_cropped(sessions):
    """Crop each session in ``sessions`` to its last 24 h and return them as a list."""
    cropped_sessions = []
    for session in sessions:
        cropped_sessions.append(_crop_last_24h(session))
    return cropped_sessions

# ---------- Inputs ----------
assert 'feds' in globals() and isinstance(feds, (list, tuple)) and len(feds) > 0, "No FED3 files available."

# Crop every session to its final 24 h exactly once. `feds` is non-empty (see
# the assert above), so `feds_cropped` is too and is always what the metrics use.
feds_cropped = build_feds_cropped(feds)
_sessions = list(feds_cropped)

# metadata_df from Key_Df (only built here if an earlier cell has not already done so)
if 'metadata_df' not in globals() or not isinstance(metadata_df, pd.DataFrame):
    assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."
    metadata_df = Key_Df.copy().reset_index(drop=True)

def _basename(pathlike) -> str:
    s = str(pathlike).replace("\\", "/")
    return s.split("/")[-1]

def _get_timestamp_series(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    if ts_col in df.columns:
        ts = pd.to_datetime(df[ts_col], format="%m:%d:%Y %H:%M:%S", errors="coerce")
        return pd.Series(ts, index=df.index)
    for cand in ["DateTime", "Datetime", "Timestamp", "timestamp", "datetime"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            return pd.Series(ts, index=df.index)
    idx = df.index
    if isinstance(idx, pd.DatetimeIndex):
        return pd.Series(idx, index=df.index)
    return pd.to_datetime(pd.Series(idx, index=df.index), errors="coerce")

def _split_day_night(df, ts_col="MM:DD:YYYY hh:mm:ss"):
    """Split ``df`` into (day_rows, night_rows) by clock hour.

    "Day" is any row whose timestamp hour is 7 through 18; every other row
    with a valid timestamp is "night". Rows without a valid timestamp appear
    in neither part.
    # NOTE(review): presumably 07:00-19:00 mirrors the housing light cycle — confirm.
    """
    stamps = _get_timestamp_series(df, ts_col=ts_col)
    hour_of_day = stamps.dt.hour
    has_time = stamps.notna()

    in_day = has_time & hour_of_day.ge(7) & hour_of_day.lt(19)
    in_night = has_time & ~in_day

    day_rows = df.loc[in_day]
    night_rows = df.loc[in_night]
    return day_rows, night_rows

def compute_withinbout_lose_shift(c_df, max_gap_s=120):
    """Fraction of consecutive poke pairs in which the mouse switched sides.

    A pair is counted when row ``i`` is a "Left"/"Right" event, row ``i + 1``
    is also "Left"/"Right" (so no "Pellet" directly followed the first poke),
    and the two rows are at most ``max_gap_s`` seconds apart.

    Returns NaN when there is no "Event" column, fewer than two rows, no
    countable pair, or when any error occurs during the computation.
    """
    pokes = ("Left", "Right")
    try:
        if "Event" not in c_df.columns or len(c_df) < 2:
            return np.nan

        evts = c_df["Event"].to_numpy()
        stamps = _get_timestamp_series(c_df).to_numpy()
        one_second = np.timedelta64(1, "s")

        n_pairs = 0
        n_switched = 0
        for pos, (here, there) in enumerate(zip(evts[:-1], evts[1:])):
            if here not in pokes:
                continue
            gap_s = (stamps[pos + 1] - stamps[pos]) / one_second
            if np.isnan(gap_s) or gap_s > max_gap_s:
                continue
            # A following "Pellet" (or any other non-poke event) is not counted.
            if there in pokes:
                n_pairs += 1
                if there != here:
                    n_switched += 1

        if n_pairs > 0:
            return n_switched / n_pairs
        return np.nan
    except Exception:
        return np.nan

def compute_withinbout_win_stay(c_df, max_gap_s=120):
    """Fraction of pellets after which the mouse poked the same side again.

    For every interior "Pellet" row, the row before and the row after must
    both be "Left"/"Right" events, and the row after must come at most
    ``max_gap_s`` seconds after the pellet row. Among those cases the result
    is the share where both pokes are on the same side.

    Returns NaN when there is no "Event" column, fewer than three rows, no
    countable case, or when any error occurs during the computation.
    """
    pokes = ("Left", "Right")
    try:
        if "Event" not in c_df.columns or len(c_df) < 3:
            return np.nan

        evts = c_df["Event"].to_numpy()
        stamps = _get_timestamp_series(c_df).to_numpy()
        one_second = np.timedelta64(1, "s")

        n_cases = 0
        n_stayed = 0
        # First and last rows are skipped: each pellet needs a neighbour on both sides.
        for k in range(1, len(evts) - 1):
            if evts[k] != "Pellet":
                continue
            gap_s = (stamps[k + 1] - stamps[k]) / one_second
            if np.isnan(gap_s) or gap_s > max_gap_s:
                continue
            before, after = evts[k - 1], evts[k + 1]
            if before in pokes and after in pokes:
                n_cases += 1
                if after == before:
                    n_stayed += 1

        if n_cases > 0:
            return n_stayed / n_cases
        return np.nan
    except Exception:
        return np.nan

def compute_peak_accuracy(c_df):
    """Mean of the first 10 values of the averaged peri-reversal trace.

    The trace comes from ``f3b.reversal_peh(c_df, (-10, 10), True)``; if it has
    fewer than 10 values, all of them are averaged. Returns NaN for an empty
    trace or when the computation raises.
    # NOTE(review): with a (-10, 10) window the first 10 values are presumably
    # the trials before each reversal — confirm against fed3bandit.
    """
    try:
        peri_reversal = f3b.reversal_peh(c_df, (-10, 10), True)
        n_points = len(peri_reversal)
        if n_points == 0:
            return np.nan
        window = peri_reversal[:10] if n_points >= 10 else peri_reversal
        return float(np.mean(window))
    except Exception:
        return np.nan

def estimate_daily_pellets(c_df):
    """Estimate pellets per 24 h by scaling the session's pellet count by its duration.

    The duration is the span between the earliest and latest valid timestamp.
    Pellets are counted from the "Pellet_Count" column when it holds usable
    numbers, otherwise from the number of "Pellet" rows in "Event".

    Returns NaN when fewer than two valid timestamps exist, the duration is
    not positive, or no pellet count can be derived.
    """
    ts = _get_timestamp_series(c_df)
    valid_ts = ts.dropna()
    if valid_ts.size < 2:
        return np.nan
    duration_hours = (valid_ts.max() - valid_ts.min()).total_seconds() / 3600.0
    if duration_hours <= 0:
        return np.nan

    pellet_events = np.nan
    if "Pellet_Count" in c_df.columns and c_df["Pellet_Count"].notna().any():
        pc = pd.to_numeric(c_df["Pellet_Count"], errors="coerce")
        if pc.notna().any():
            # Sum only the increases of the counter; decreases are clipped to 0.
            diffs = pc.diff().fillna(0).clip(lower=0)
            pellet_events = float(diffs.sum())
            # If no increase was seen row-to-row (e.g. gaps of missing values
            # hide them), fall back to last-minus-first.
            if pellet_events == 0 and pc.iloc[-1] >= pc.iloc[0]:
                pellet_events = float(pc.iloc[-1] - pc.iloc[0])
    # Only used when Pellet_Count gave nothing at all.
    if (pd.isna(pellet_events)) and ("Event" in c_df.columns):
        pellet_events = float((c_df["Event"] == "Pellet").sum())

    if pd.isna(pellet_events):
        return np.nan
    # Pellets per hour, scaled to a 24 h day.
    return (pellet_events / duration_hours) * 24.0

# ---------- Prepare metadata (merge once by filename) ----------
md = metadata_df.copy()
md['filename'] = md['filename'].astype(str).map(_basename)
# NOTE(review): astype(str) turns a missing Mouse_ID into the string "nan".
if 'Mouse_ID' in md.columns:
    md['Mouse_ID'] = md['Mouse_ID'].astype(str).str.strip()
else:
    md['Mouse_ID'] = np.nan

# Keep only metadata columns we care about; rename to avoid accidental dupes
# (add/remove columns as needed)
meta_keep = [c for c in md.columns if c in {"filename", "Mouse_ID", "Session_type", "Cohort", "Strain", "Sex"}]
# One metadata row per file name; if two files share a name, the first row wins.
md_clean = md[meta_keep].drop_duplicates(subset=["filename"], keep="first")

# ---------- Compute metrics on the chosen sessions ----------
# One dict per session; a session that raises is reported and left out.
rows = []
for idx in tqdm(range(len(_sessions))):
    c_df = _sessions[idx]
    file_name = _basename(getattr(c_df, "name", f"File_{idx}"))

    try:
        # Retrieval times: numeric values below 5 only (non-numeric entries become NaN and drop out).
        clean_retrieval_time = pd.to_numeric(c_df.get("Retrieval_Time", pd.Series(dtype=float)), errors="coerce")
        clean_retrieval_time = clean_retrieval_time[clean_retrieval_time < 5]

        # Poke times: strictly positive numeric values only.
        clean_poke_time = pd.to_numeric(c_df.get("Poke_Time", pd.Series(dtype=float)), errors="coerce")
        clean_poke_time = clean_poke_time[clean_poke_time > 0]

        day_df, night_df = _split_day_night(c_df, ts_col="MM:DD:YYYY hh:mm:ss")

        row = {
            "filename": file_name,
            "PeakAccuracy": compute_peak_accuracy(c_df),
            "Total_pellets": f3b.count_pellets(c_df),
            "Total_pokes": f3b.count_pokes(c_df),
            "PokesPerPellet": f3b.pokes_per_pellet(c_df),
            "RetrievalTime": clean_retrieval_time.median() if not clean_retrieval_time.empty else np.nan,
            "PokeTime": clean_poke_time.median() if not clean_poke_time.empty else np.nan,
            "Win-stay": compute_withinbout_win_stay(c_df),
            "Lose-shift": compute_withinbout_lose_shift(c_df),
            "daily pellets": estimate_daily_pellets(c_df),
            "PeakAccuracy_Day": compute_peak_accuracy(day_df),
            "PeakAccuracy_Night": compute_peak_accuracy(night_df),
            "Win-stay_Day": compute_withinbout_win_stay(day_df),
            "Win-stay_Night": compute_withinbout_win_stay(night_df),
            "Lose-shift_Day": compute_withinbout_lose_shift(day_df),
            "Lose-shift_Night": compute_withinbout_lose_shift(night_df),
        }
        rows.append(row)

    except Exception as e:
        print(f"Failed on {file_name} (idx {idx}): {e}")


# NOTE(review): if every session failed, `rows` is empty and the frame has no
# "filename" column for the merge below — confirm whether that case needs a guard.
Bandit100_metrics = pd.DataFrame(rows)
Bandit100_metrics = Bandit100_metrics.merge(md_clean, on="filename", how="left")
# Drop any repeated column labels, keeping the first occurrence.
Bandit100_metrics = Bandit100_metrics.loc[:, ~Bandit100_metrics.columns.duplicated()]

# Written to the current working directory; the download button below serves this file.
csv_name = "Bandit100_metrics.csv"
Bandit100_metrics.to_csv(csv_name, index=False)

from google.colab import files
import ipywidgets as widgets
from IPython.display import display

def download_csv(b):
    """Button callback: hand the metrics CSV (``csv_name``) to Colab's file download.

    ``b`` is the argument passed by the button's click handler; it is not used.
    """
    files.download(csv_name)

download_button = widgets.Button(
    description="⬇️ Download summary stats (CSV)",
    button_style="primary",
)

# Wire the callback and show the button.
download_button.on_click(download_csv)
display(download_button)


In [None]:
# @title Plot Female vs Male (with FDR correction across plots)
import os, time, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output

try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

from statsmodels.stats.multitest import multipletests  # <-- FDR

# -----------------------
# Config
# -----------------------
# Column used to split the animals into the two compared groups.
GROUP_COL = "Sex"
# Metric columns of Bandit100_metrics to plot and test, in panel order.
metrics = [
    "daily pellets", "Total_pokes",
    "PeakAccuracy", "Win-stay",
]

# Bar colour per normalized sex label.
COLOR_MAP = {"F": "red", "M": "dodgerblue"}

# -----------------------
# Preconditions
# -----------------------
if "Bandit100_metrics" not in globals() or Bandit100_metrics is None or Bandit100_metrics.empty:
    raise RuntimeError("Bandit100_metrics is missing/empty. Run the metrics cell first.")

# Work on a copy so the metrics table produced by the previous cell is left untouched.
bm = Bandit100_metrics.copy()

# -----------------------
# Helper: normalize Sex to F/M/UNK
# -----------------------
def _norm_sex(x):
    s = str(x).strip().upper()
    if s in {"F", "FEMALE", "FEM"}: return "F"
    if s in {"M", "MALE"}: return "M"
    return "UNK"

# -----------------------
# Ensure we have Sex (merge from metadata_df if needed)
# -----------------------
if GROUP_COL not in bm.columns or bm[GROUP_COL].isna().all():

    if "metadata_df" not in globals() or metadata_df is None or metadata_df.empty:
        raise RuntimeError("Sex not found in Bandit100_metrics and metadata_df is missing/empty.")

    meta = metadata_df.copy()

    # An all-missing Sex column may already exist in bm. Remove it first:
    # otherwise the merges below would produce Sex_x / Sex_y and the
    # merged["Sex"] lookups would fail.
    bm = bm.drop(columns=[c for c in {GROUP_COL, "Sex"} if c in bm.columns])

    # Normalize column names for robust lookup
    def _find_col(df, name):
        """Return the column of ``df`` matching ``name`` ignoring case/outer spaces, else None."""
        lc = {str(c).strip().lower(): c for c in df.columns}
        return lc.get(name.lower(), None)

    meta_sex = _find_col(meta, "Sex")
    if meta_sex is None:
        raise RuntimeError("metadata_df does not contain a 'Sex' column (case-insensitive).")

    # Prefer merging by Mouse_ID if present in both
    meta_mouse = _find_col(meta, "Mouse_ID")
    bm_mouse   = _find_col(bm, "Mouse_ID")

    merged = None

    if meta_mouse is not None and bm_mouse is not None:
        meta_key = meta[[meta_mouse, meta_sex]].copy()
        meta_key.columns = ["Mouse_ID", "Sex"]
        meta_key["Mouse_ID"] = meta_key["Mouse_ID"].astype(str).str.strip()
        meta_key = meta_key.dropna(subset=["Mouse_ID"]).drop_duplicates("Mouse_ID")

        bm["Mouse_ID"] = bm[bm_mouse].astype(str).str.strip()
        merged = bm.merge(meta_key, on="Mouse_ID", how="left")

    # Fallback: merge by filename basename if Mouse_ID isn’t available
    # (or if the Mouse_ID merge produced no Sex values at all)
    if merged is None or merged["Sex"].isna().all():
        # build/ensure filename columns
        if "filename" not in bm.columns:
            if "File" in bm.columns:
                bm["filename"] = bm["File"].astype(str)
            else:
                raise RuntimeError("Need either Mouse_ID or filename/File in Bandit100_metrics to merge Sex.")

        if "filename" not in meta.columns:
            raise RuntimeError("Cannot fallback merge: metadata_df has no filename column.")

        bm["file_base"]   = bm["filename"].astype(str).apply(lambda p: os.path.basename(p))
        meta["file_base"] = meta["filename"].astype(str).apply(lambda p: os.path.basename(p))

        meta_key = meta[["file_base", meta_sex]].copy()
        meta_key.columns = ["file_base", "Sex"]
        meta_key = meta_key.dropna(subset=["file_base"]).drop_duplicates("file_base")

        merged = bm.merge(meta_key, on="file_base", how="left").drop(columns=["file_base"])

    bm = merged

# Normalize Sex values
bm[GROUP_COL] = bm[GROUP_COL].apply(_norm_sex)

# Keep only F/M for plotting (drop UNK)
bm = bm[bm[GROUP_COL].isin(["F", "M"])].copy()
if bm.empty:
    raise RuntimeError("After merging/normalizing Sex, no rows with Sex in {F, M} were found.")

# -----------------------
# Long format
# -----------------------
# Only metrics that actually exist as columns are melted; missing ones are
# handled later (blank panel / NA p-value).
value_vars = [m for m in metrics if m in bm.columns]
if not value_vars:
    raise RuntimeError("None of the expected metric columns were found in Bandit100_metrics.")

# One row per (session, metric) with the measurement in "value".
id_vars = [c for c in ["filename", "Mouse_ID", "Strain", GROUP_COL] if c in bm.columns]
long_df = bm.melt(id_vars=id_vars, value_vars=value_vars, var_name="metric", value_name="value")

# consistent order: F first, then M (groups[0] / groups[1] below rely on this)
groups = [g for g in ["F", "M"] if g in long_df[GROUP_COL].unique()]
if len(groups) < 2:
    raise RuntimeError(f"Need both F and M present to compare; found: {groups}")

# -----------------------
# Stats
# -----------------------
def welch_p(a, b):
    """Two-sided p-value of Welch's unequal-variance t-test between ``a`` and ``b``.

    Missing values are dropped from each sample first. Returns NaN when either
    sample has fewer than two values left.

    Welch's correction is applied unconditionally (``equal_var=False``), so the
    test matches the function's name for equal as well as unequal group sizes.
    """
    # Local import keeps this function self-contained within the cell.
    from scipy.stats import ttest_ind

    a = pd.Series(a, dtype=float).dropna()
    b = pd.Series(b, dtype=float).dropna()
    if len(a) < 2 or len(b) < 2:
        return np.nan
    return float(ttest_ind(a, b, equal_var=False).pvalue)

# -----------------------
# Precompute p-values + FDR correction across plotted metrics
# -----------------------
# metric name -> uncorrected p-value (NaN when the metric is absent or untestable)
raw_pvals = {}
for metric in metrics:
    if metric not in value_vars:
        raw_pvals[metric] = np.nan
        continue
    dfm = long_df[long_df["metric"] == metric].dropna(subset=["value"])
    a = dfm.loc[dfm[GROUP_COL] == groups[0], "value"]
    b = dfm.loc[dfm[GROUP_COL] == groups[1], "value"]
    raw_pvals[metric] = welch_p(a, b)

# Only finite p-values enter the correction, so the number of tests corrected
# for is the number of metrics that could actually be tested.
valid_metrics = [m for m, p in raw_pvals.items() if np.isfinite(p)]
valid_pvals = [raw_pvals[m] for m in valid_metrics]

# metric name -> Benjamini-Hochberg ("fdr_bh") adjusted p-value
pvals_fdr_map = {}
if len(valid_pvals) > 0:
    _, pvals_fdr, _, _ = multipletests(valid_pvals, alpha=0.05, method="fdr_bh")
    pvals_fdr_map = dict(zip(valid_metrics, pvals_fdr))

# -----------------------
# Plot UI
# -----------------------
out = widgets.Output()
save_btn = widgets.Button(description="Save PDF", button_style="success")
_last_fig = None  # most recently drawn figure, kept for the Save button

def run_plots():
    """Draw one bar + strip panel per metric (F vs M) into the ``out`` widget.

    Each panel is annotated with its FDR-adjusted p-value. Metrics missing
    from the data get a blank panel, and grid cells beyond the number of
    metrics are hidden. The figure is stored in the global ``_last_fig``.

    Raises
    ------
    ValueError
        If there are more metrics than panels in the 2 x 4 grid.
    """
    global _last_fig
    with out:
        clear_output()

        n_rows, n_cols = 2, 4
        if len(metrics) > n_rows * n_cols:
            raise ValueError(
                f"{len(metrics)} metrics do not fit into a {n_rows}x{n_cols} grid."
            )

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 6), constrained_layout=True)
        axes = axes.ravel()

        for i, metric in enumerate(metrics):
            ax = axes[i]
            if metric not in value_vars:
                ax.set_axis_off()
                continue

            dfm = long_df[long_df["metric"] == metric].dropna(subset=["value"])

            pal = {g: COLOR_MAP[g] for g in groups}

            # NOTE(review): `ci=None` and `palette` without `hue` are deprecated
            # spellings in recent seaborn releases; kept because the installed
            # seaborn version is not pinned.
            sns.barplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups, ci=None, alpha=0.6,
                palette=pal, ax=ax
            )
            sns.stripplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups,
                color="white", edgecolor="black",
                linewidth=1, size=6,
                alpha=0.35, jitter=True, ax=ax
            )

            p_fdr = pvals_fdr_map.get(metric, np.nan)
            label = (
                "FDR p<0.001" if np.isfinite(p_fdr) and p_fdr < 0.001
                else (f"FDR p={p_fdr:.3f}" if np.isfinite(p_fdr) else "FDR p=NA")
            )

            ax.set_title("")
            ax.set_xlabel("")
            ax.set_ylabel(metric, fontsize=12)
            ax.text(0.5, 1.02, label, transform=ax.transAxes, ha="center", va="bottom")

            sns.despine(ax=ax)

        # Hide grid cells that have no metric so they are not drawn as empty frames.
        for unused_ax in axes[len(metrics):]:
            unused_ax.set_axis_off()

        plt.show()
        _last_fig = fig

def save_plots(_=None):
    """Save the last drawn figure as a timestamped PDF and offer it for download on Colab.

    The single argument is whatever the button passes to its click handler and
    is ignored. Prints a notice instead when no figure has been drawn yet.
    """
    if _last_fig is None:
        with out:
            print("Nothing to save yet.")
        return
    # Unix-time suffix keeps successive saves from overwriting each other.
    fname = f"metrics_grid_{int(time.time())}.pdf"
    _last_fig.savefig(fname, dpi=300, bbox_inches="tight")
    with out:
        print(f"Saved {fname}")
    # colab_files is None when google.colab could not be imported.
    if colab_files is not None:
        colab_files.download(fname)

save_btn.on_click(save_plots)

display(save_btn)
display(out)

# Draw the figure once when the cell runs.
run_plots()


In [None]:
# @title Accuracy around switches

import os, re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# --- Config ---
# Half-width of the peri-switch window: passed as (-TRIALS, TRIALS) to f3b.reversal_peh below.
TRIALS = 11
# Colour per normalized sex label.
COLOR_MAP = {"F": "red", "M": "dodgerblue"}

# --- Preconditions ---
# This cell uses the uncropped `feds` sessions and the `metadata_df` table.
if 'feds' not in globals() or not isinstance(feds, (list, tuple)) or len(feds) == 0:
    raise RuntimeError("No FED3 sessions found in `feds`.")
if 'metadata_df' not in globals() or metadata_df is None or metadata_df.empty:
    raise RuntimeError("metadata_df is missing/empty. Build it from the Key first.")

# --- Helpers ---
def _basename(x): return os.path.basename(str(x))

def _norm_sex(x):
    s = str(x).strip().upper()
    if s in {"F", "FEMALE", "FEM"}: return "F"
    if s in {"M", "MALE"}: return "M"
    return "UNK"

def _find_col(df, name):
    lc = {str(c).strip().lower(): c for c in df.columns}
    return lc.get(name.lower(), None)

def _extract_mouse_id_from_sess_name(name):
    """
    Derive the mouse ID from a session file name.

    The directory part and a trailing ".csv" are removed, then a trailing
    "_<task>_<8 digits>" suffix where <task> is Bandit80, Bandit100, FR1 or
    PR1 (case-insensitive). For example:
      CDKL5_005_180_69_Bandit80_20251013.csv  -> CDKL5_005_180_69
      ChowHFD_670_Bandit80_20250728.csv       -> ChowHFD_670
    Names without such a suffix are returned with only the extension removed.
    """
    stem = re.sub(r"\.csv$", "", _basename(name), flags=re.IGNORECASE)
    # Strip trailing task/date: _Bandit80_YYYYMMDD, _FR1_YYYYMMDD, etc.
    return re.sub(r"_(Bandit80|Bandit100|FR1|PR1)_\d{8}$", "", stem, flags=re.IGNORECASE)

def _is_empty(x):
    if x is None: return True
    try:
        return len(x) == 0
    except Exception:
        try:
            return np.size(x) == 0
        except Exception:
            return True

# --- Build Sex lookup from metadata_df ---
# Two dictionaries are produced: `sex_lookup` (Mouse_ID -> F/M/UNK) and, when
# metadata_df has a `filename` column, `file_sex_lookup` (file basename ->
# F/M/UNK) as a fallback key.
meta = metadata_df.copy()
meta_sex = _find_col(meta, "Sex")
if meta_sex is None:
    raise RuntimeError("metadata_df does not contain a 'Sex' column (case-insensitive).")
meta_mouse = _find_col(meta, "Mouse_ID")

sex_lookup = {}

# Prefer Mouse_ID mapping if present
if meta_mouse is not None:
    tmp = meta[[meta_mouse, meta_sex]].copy()
    tmp.columns = ["Mouse_ID", "Sex"]
    # Drop missing IDs *before* the string cast: astype(str) turns NaN into
    # the literal "nan", after which dropna() can no longer remove the row.
    tmp = tmp.dropna(subset=["Mouse_ID"]).copy()
    tmp["Mouse_ID"] = tmp["Mouse_ID"].astype(str).str.strip()
    tmp["Sex"] = tmp["Sex"].apply(_norm_sex)
    # First occurrence of each ID wins.
    tmp = tmp.drop_duplicates("Mouse_ID")
    sex_lookup = dict(zip(tmp["Mouse_ID"], tmp["Sex"]))

# Fallback: if metadata_df has filename, build filename->sex map too
file_sex_lookup = {}
if "filename" in meta.columns:
    # Same ordering as above: discard missing filenames before stringifying.
    tmp2 = meta[["filename", meta_sex]].dropna(subset=["filename"]).copy()
    tmp2["file_base"] = tmp2["filename"].astype(str).apply(_basename)
    tmp2["Sex"] = tmp2[meta_sex].apply(_norm_sex)
    tmp2 = tmp2.drop_duplicates("file_base")
    file_sex_lookup = dict(zip(tmp2["file_base"], tmp2["Sex"]))

# --- Build rev_df ---
# One row per (session, switch, trial position): the peri-switch value, its
# position relative to the switch, and the animal's sex.
rows = []
for i, sess in enumerate(feds):
    sess_name = getattr(sess, "name", f"session_{i}")
    base = _basename(sess_name)

    # Get an ID from session filename and look up Sex
    mouse_id = _extract_mouse_id_from_sess_name(base)
    sex = sex_lookup.get(mouse_id, "UNK")

    # fallback: try direct filename basename lookup (only if metadata has filename)
    if sex == "UNK" and file_sex_lookup:
        sex = file_sex_lookup.get(base, "UNK")

    # compute peri-switch trials using your helper
    try:
        peh = f3b.reversal_peh(sess, (-TRIALS, TRIALS), return_avg=False)
    except Exception as e:
        print(f"[skip] {base}: reversal_peh failed: {e}")
        continue

    # Report why a session contributes nothing instead of dropping it silently.
    if _is_empty(peh):
        print(f"[skip] {base}: no peri-switch trials returned")
        continue
    if sex not in {"F", "M"}:
        print(f"[skip] {base}: no F/M sex found in metadata (mouse id '{mouse_id}')")
        continue

    for tr in list(peh):
        arr = np.asarray(tr).ravel()
        for t, v in enumerate(arr):
            rows.append({
                # Array index 0 maps to Timepoint -TRIALS + 1.
                "Timepoint": t - TRIALS + 1,
                "Value": float(v) if np.isfinite(v) else np.nan,
                "Sex": sex
            })

# With no rows the DataFrame would have no "Value" column and the filter
# below would raise a KeyError instead of the intended message.
if not rows:
    raise RuntimeError("No peri-switch data produced (after filtering to F/M).")

rev_df = pd.DataFrame(rows)
rev_df = rev_df[np.isfinite(rev_df["Value"])]
rev_df = rev_df[rev_df["Timepoint"] != 0]  # optional
if rev_df.empty:
    raise RuntimeError("No peri-switch data produced (after filtering to F/M).")

# --- Plot ---
# Mean +/- SE of Value at each Timepoint, one line per sex.
# Only the sexes actually present in rev_df are plotted, always in F, M order.
group_order = [g for g in ["F", "M"] if g in rev_df["Sex"].unique()]
palette = {g: COLOR_MAP[g] for g in group_order}

plt.figure(figsize=(5, 4))
ax = sns.lineplot(
    data=rev_df.sort_values(["Sex", "Timepoint"]),
    x="Timepoint",
    y="Value",
    hue="Sex",
    hue_order=group_order,
    palette=palette,
    estimator="mean",
    errorbar="se",
    n_boot=0,
    lw=2
)

# Dashed marker at x=0; the "Switch" label sits just right of it, 95% of the
# way up the y-range as it is at this point (read after the lines are drawn).
ax.axvline(x=0, color="darkgrey", linestyle="--", linewidth=1.25)
ymin, ymax = ax.get_ylim()
ax.text(0.5, ymin + 0.95*(ymax - ymin), "Switch", color="darkgrey",
        fontsize=12, ha="left", va="top")

ax.set_xlabel("Trials from switch")
# NOTE(review): Value is plotted exactly as returned by reversal_peh, with no
# rescaling in this cell — confirm the values are percentages, otherwise the
# "(%)" in this label is misleading.
ax.set_ylabel("Accuracy (%)")
ax.set_title("")
ax.legend(title="", frameon=False)
sns.despine()
plt.tight_layout()
plt.show()


In [None]:
# @title Plot Day vs. Night metrics
import os, time, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output

try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

# -----------------------
# Config
# -----------------------
# Column used to split animals into the two compared groups.
GROUP_COL = "Sex"
# Metric columns to plot, in panel order (filled row by row into a 2x3 grid
# by run_plots below).
metrics = [
    "PeakAccuracy_Night", "Win-stay_Night", "Lose-shift_Night",
    "PeakAccuracy_Day", "Win-stay_Day", "Lose-shift_Day",
]
# Bar colour for each normalized sex label.
COLOR_MAP = {"F": "red", "M": "dodgerblue"}

# -----------------------
# Preconditions
# -----------------------
# Bandit100_metrics is not defined in this cell; stop with a readable message
# if it is absent from the kernel or empty.
if "Bandit100_metrics" not in globals() or Bandit100_metrics is None or Bandit100_metrics.empty:
    raise RuntimeError("Bandit100_metrics is missing/empty. Run the metrics cell first.")

# Work on a copy so the merge/filter steps below leave Bandit100_metrics untouched.
bm = Bandit100_metrics.copy()

# -----------------------
# Helper: normalize Sex to F/M/UNK
# -----------------------
def _norm_sex(x):
    s = str(x).strip().upper()
    if s in {"F", "FEMALE", "FEM"}: return "F"
    if s in {"M", "MALE"}: return "M"
    return "UNK"

# -----------------------
# Ensure we have Sex (merge from metadata_df if needed)
# -----------------------
if GROUP_COL not in bm.columns or bm[GROUP_COL].isna().all():

    if "metadata_df" not in globals() or metadata_df is None or metadata_df.empty:
        raise RuntimeError("Sex not found in Bandit100_metrics and metadata_df is missing/empty.")

    meta = metadata_df.copy()

    # An existing but entirely-missing Sex column would collide with the one
    # brought in by the merge (pandas would emit Sex_x / Sex_y and the
    # merged["Sex"] lookups below would raise KeyError), so drop it first.
    if GROUP_COL in bm.columns:
        bm = bm.drop(columns=[GROUP_COL])

    # Normalize column names for robust lookup
    def _find_col(df, name):
        """Case- and whitespace-insensitive column lookup; None when absent."""
        lc = {str(c).strip().lower(): c for c in df.columns}
        return lc.get(name.lower(), None)

    meta_sex = _find_col(meta, "Sex")
    if meta_sex is None:
        raise RuntimeError("metadata_df does not contain a 'Sex' column (case-insensitive).")

    # Prefer merging by Mouse_ID if present in both
    meta_mouse = _find_col(meta, "Mouse_ID")
    bm_mouse   = _find_col(bm, "Mouse_ID")

    merged = None

    if meta_mouse is not None and bm_mouse is not None:
        meta_key = meta[[meta_mouse, meta_sex]].copy()
        meta_key.columns = ["Mouse_ID", "Sex"]
        # Drop missing IDs before the string cast; otherwise NaN becomes the
        # literal key "nan" and would match metric rows whose ID is also missing.
        meta_key = meta_key.dropna(subset=["Mouse_ID"]).copy()
        meta_key["Mouse_ID"] = meta_key["Mouse_ID"].astype(str).str.strip()
        meta_key = meta_key.drop_duplicates("Mouse_ID")

        bm["Mouse_ID"] = bm[bm_mouse].astype(str).str.strip()
        merged = bm.merge(meta_key, on="Mouse_ID", how="left")

    # Fallback: merge by filename basename if Mouse_ID isn’t available
    if merged is None or merged["Sex"].isna().all():
        # build/ensure filename columns
        if "filename" not in bm.columns:
            if "File" in bm.columns:
                bm["filename"] = bm["File"].astype(str)
            else:
                raise RuntimeError("Need either Mouse_ID or filename/File in Bandit100_metrics to merge Sex.")

        if "filename" not in meta.columns:
            # if metadata_df already has filename, great; otherwise cannot fallback
            raise RuntimeError("Cannot fallback merge: metadata_df has no filename column.")

        # Same reasoning as for Mouse_ID: remove missing filenames before they
        # are stringified into "nan".
        meta = meta.dropna(subset=["filename"]).copy()

        bm["file_base"]   = bm["filename"].astype(str).apply(lambda p: os.path.basename(p))
        meta["file_base"] = meta["filename"].astype(str).apply(lambda p: os.path.basename(p))

        meta_key = meta[["file_base", meta_sex]].copy()
        meta_key.columns = ["file_base", "Sex"]
        meta_key = meta_key.drop_duplicates("file_base")

        merged = bm.merge(meta_key, on="file_base", how="left").drop(columns=["file_base"])

    bm = merged

# Normalize Sex values
bm[GROUP_COL] = bm[GROUP_COL].apply(_norm_sex)

# Keep only F/M for plotting (drop UNK)
bm = bm[bm[GROUP_COL].isin(["F", "M"])].copy()
if bm.empty:
    raise RuntimeError("After merging/normalizing Sex, no rows with Sex in {F, M} were found.")

# -----------------------
# Long format
# -----------------------
# Only metrics that actually exist as columns are melted; run_plots blanks
# the panel of any metric that is missing here.
value_vars = [m for m in metrics if m in bm.columns]
if not value_vars:
    raise RuntimeError("None of the expected metric columns were found in Bandit100_metrics.")

# Identifier columns are carried along only when present.
id_vars = [c for c in ["filename", "Mouse_ID", "Strain", GROUP_COL] if c in bm.columns]
# One row per (animal/file, metric) with the metric name in "metric" and its
# value in "value".
long_df = bm.melt(id_vars=id_vars, value_vars=value_vars, var_name="metric", value_name="value")

# consistent order
groups = [g for g in ["F", "M"] if g in long_df[GROUP_COL].unique()]
# The comparison below indexes groups[0] and groups[1], so both must exist.
if len(groups) < 2:
    raise RuntimeError(f"Need both F and M present to compare; found: {groups}")

# -----------------------
# Stats
# -----------------------
def welch_p(a, b):
    """Unpaired two-sample t-test p-value from pingouin, ignoring NaNs.

    Both inputs are coerced to float Series and stripped of missing values.
    Returns NaN when either sample is left with fewer than two values.

    NOTE(review): `pg.ttest` is called with its default `correction`
    setting — confirm that this gives a Welch test for the group sizes used
    here, as the function name implies.
    """
    first, second = (pd.Series(vals, dtype=float).dropna() for vals in (a, b))
    if min(len(first), len(second)) < 2:
        return np.nan
    result = pg.ttest(first, second, paired=False)
    return float(result["p-val"].iat[0])

# -----------------------
# Plot UI
# -----------------------
# Output area that run_plots draws into and save_plots prints status to.
out = widgets.Output()
save_btn = widgets.Button(description="Save PDF", button_style="success")
# Most recent figure produced by run_plots; None until the first draw.
_last_fig = None

def run_plots():
    """Draw the 2x3 grid of per-metric bar + strip plots into `out`.

    One panel per entry of `metrics`, in order; a metric that is not in
    `value_vars` gets a blank panel. Each panel shows the group means as
    bars, the individual points on top, and the p-value from `welch_p`
    (groups[0] vs groups[1]) above the axes. The finished figure is stored
    in the module-level `_last_fig` so `save_plots` can write it out.
    """
    global _last_fig
    # The palette depends only on `groups`, so build it once for all panels.
    pal = {g: COLOR_MAP[g] for g in groups}

    with out:
        clear_output()

        fig, axes = plt.subplots(2, 3, figsize=(6, 6), constrained_layout=True)
        axes = axes.ravel()

        for i, metric in enumerate(metrics):
            ax = axes[i]
            if metric not in value_vars:
                ax.set_axis_off()
                continue

            dfm = long_df[long_df["metric"] == metric].dropna(subset=["value"])

            # errorbar=None is the current spelling of the deprecated ci=None
            # (no error bars); this matches the errorbar= usage elsewhere in
            # the notebook.
            sns.barplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups, errorbar=None, alpha=0.6,
                palette=pal, ax=ax
            )
            sns.stripplot(
                data=dfm, x=GROUP_COL, y="value",
                order=groups,
                color="white", edgecolor="black",
                linewidth=1, size=6,
                alpha=0.35, jitter=True, ax=ax
            )

            a = dfm.loc[dfm[GROUP_COL] == groups[0], "value"]
            b = dfm.loc[dfm[GROUP_COL] == groups[1], "value"]
            p = welch_p(a, b)
            # "p<0.001" for very small values, three decimals otherwise, and
            # "p=NA" when the test could not be computed.
            label = (
                "p<0.001" if np.isfinite(p) and p < 0.001
                else (f"p={p:.3f}" if np.isfinite(p) else "p=NA")
            )

            ax.set_title("")
            ax.set_xlabel("")
            ax.set_ylabel(metric, fontsize=12)
            ax.text(0.5, 1.02, label, transform=ax.transAxes, ha="center", va="bottom")

            sns.despine(ax=ax)

        plt.show()
        _last_fig = fig

def save_plots(_=None):
    """Save the last drawn grid as `metrics_grid_<unix time>.pdf`.

    Click handler for the save button; its argument is ignored. Messages go
    to the `out` widget, and when `colab_files` is not None the saved file
    is passed to its `download` function.
    """
    if _last_fig is None:
        message = "Nothing to save yet."
        target = None
    else:
        target = f"metrics_grid_{int(time.time())}.pdf"
        _last_fig.savefig(target, dpi=300, bbox_inches="tight")
        message = f"Saved {target}"

    with out:
        print(message)

    if target is not None and colab_files is not None:
        colab_files.download(target)

# Connect the save handler, then show the button and the output area that
# run_plots/save_plots write into.
save_btn.on_click(save_plots)

display(save_btn)
display(out)

# Draw the grid once so it is visible without any user interaction.
run_plots()


In [None]:
# @title RL fitting functions

def extract_context_action_reward(df):
    """Reduce an event log to per-trial context / action / reward arrays.

    A trial is any row whose 'Event' is 'Left' or 'Right'. For each trial:
      action  = 1 for 'Right', 0 for 'Left'
      reward  = 1 when the very next row's 'Event' is 'Pellet', else 0
      context = 1 when 'High_prob_poke' on the trial row is 'Right', else 0

    Returns a dict with keys 'context', 'action', 'reward', each a float
    numpy array with one entry per trial.
    """
    events = df['Event'].tolist()
    trial_rows = [row for row, ev in enumerate(events) if ev == 'Left' or ev == 'Right']
    n_trials = len(trial_rows)

    simple_data = {
        'context': np.zeros((n_trials)),
        'action': np.zeros((n_trials)),
        'reward': np.zeros((n_trials)),
    }

    # The high-probability column is only needed when there is at least one trial.
    high_prob = df['High_prob_poke'].tolist() if n_trials else []
    n_events = len(events)

    for trial_idx, row in enumerate(trial_rows):
        simple_data['action'][trial_idx] = 1 if events[row] == 'Right' else 0
        rewarded = row + 1 < n_events and events[row + 1] == 'Pellet'
        simple_data['reward'][trial_idx] = 1 if rewarded else 0
        simple_data['context'][trial_idx] = 1 if high_prob[row] == 'Right' else 0

    return simple_data


# get basic stats from [action, reward, context] sequence
def get_basic_stats(session_data):
    """Summarise one session's action / reward / context arrays.

    Returns a dict with:
      session_length    number of trials
      total_rewards     sum of the reward array
      reward_rate       total_rewards / session_length
      correct_port_rate fraction of trials where action == context
      action_bias       2 * (mean action - 0.5)
    For an empty session the length and reward count are 0 and the three
    rates are NaN.
    """
    actions, rewards, contexts = (session_data[key] for key in ('action', 'reward', 'context'))

    n_trials = len(actions)
    if n_trials == 0:
        return {
            'session_length': 0,
            'total_rewards': 0,
            'reward_rate': np.nan,
            'correct_port_rate': np.nan,
            'action_bias': np.nan,
        }

    total_rewards = np.sum(rewards)
    return {
        'session_length': n_trials,
        'total_rewards': total_rewards,
        'reward_rate': total_rewards/n_trials,
        'correct_port_rate': np.sum( actions==contexts )/n_trials,
        'action_bias': 2*( np.sum(actions)/n_trials - 0.5 ),
    }


  # Model fitting
# -------------------------
# Model & loss (JAX-native)
# -------------------------
def _one_step(carry, inp):
    qL, qR, alpha, beta, bias, lapse, c_preva, prev_action, eps = carry
    action, reward = inp
    signed_prev_action = 2 * (prev_action - 1/2)

    logit = beta * (qR - qL + c_preva * signed_prev_action + bias)
    p_right = 0.5 * lapse + (1.0 - lapse) * jax.nn.sigmoid(logit)
    p_right = jnp.clip(p_right, eps, 1.0 - eps)

    loss_t = -(action * jnp.log(p_right) + (1 - action) * jnp.log1p(-p_right))

    qR_new = qR + action * alpha * (reward - qR)
    qL_new = qL + (1 - action) * alpha * (reward - qL)
    return (qL_new, qR_new, alpha, beta, bias, lapse, c_preva, action, eps), (p_right, loss_t)


def calc_loss(alpha, beta, bias, lapse, c_preva, actions, rewards):
    """Total negative log-likelihood of a choice sequence under the model.

    Scans `_one_step` over the paired (actions, rewards) sequences starting
    from zero values for both sides, and returns the sum of the per-trial
    losses.
    """
    init_carry = (
        jnp.array(0.0),   # qL
        jnp.array(0.0),   # qR
        alpha, beta, bias, lapse, c_preva,
        0.5,              # prev_action: 0.5 makes the first signed previous-action term 0
        1e-8,             # eps handed to _one_step for probability clipping
    )
    _, (_, per_trial_loss) = lax.scan(_one_step, init_carry, (actions, rewards))
    return jnp.sum(per_trial_loss)


def loss_fn(params, actions, rewards):
    """Adapter from a parameter dict to `calc_loss`.

    `params` must provide 'alpha', 'beta', 'bias', 'lapse' and 'c_preva';
    they are forwarded positionally in that order.
    """
    ordered = [params[key] for key in ('alpha', 'beta', 'bias', 'lapse', 'c_preva')]
    return calc_loss(*ordered, actions, rewards)

# -----------------------------
# Training with Adam optimizer
# -----------------------------
def fit_with_adam(data,
                  lr=1e-2,
                  num_iterations=1000,
                  param_ranges=None,
                  seed=0,
                  RL_model='preva'): # Add RL_model parameter
    """
    Fit alpha, beta, bias, lapse and c_preva by minimising the session
    negative log-likelihood (`loss_fn`) with Adam.

    Parameters
    ----------
    data : dict with 'action' and 'reward' sequences (one entry per trial).
    lr : Adam learning rate.
    num_iterations : number of optimisation steps.
    param_ranges : optional dict {name: (low, high)}. Used both to draw the
        random initial values and as box bounds that parameters are clipped
        to after every step. Not modified by this function.
    seed : seed for the random initialisation.
    RL_model : 'preva' (default) or 'vanilla'; 'vanilla' pins c_preva to 0.

    Returns
    -------
    (learned, losses) : dict of fitted parameters as Python floats, and a
        numpy array with the loss evaluated at each iteration.
    """
    if param_ranges is None:
        bounds = {
            'alpha': (0.0, 1.0),
            'beta':  (0.0, 20.0),
            'bias':  (-1.0, 1.0),
            'lapse': (0.0, 1.0),
            'c_preva': (0.0, 1.0),
        }
    else:
        # Work on a copy: the 'vanilla' override below must not leak into
        # the dict owned by the caller.
        bounds = dict(param_ranges)

    # Adjust the range for c_preva if RL_model is 'vanilla'
    if RL_model == 'vanilla':
        bounds['c_preva'] = (0.0, 0.0) # Fix c_preva to 0

    actions = jnp.asarray(data['action'])
    rewards = jnp.asarray(data['reward'])
    # contexts = jnp.asarray(data['context'])  # unused in current model

    # init params: one uniform draw per parameter inside its bounds
    rng = np.random.default_rng(seed)
    params = {
        k: jnp.array(rng.uniform(low=lo, high=hi))
        for k, (lo, hi) in bounds.items()
    }

    # Ensure c_preva is 0 if RL_model is 'vanilla', even if random init was non-zero
    if RL_model == 'vanilla':
        params['c_preva'] = jnp.array(0.0)

    # optimizer
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # projection to bounds
    def project(p):
        return {k: jnp.clip(v, *bounds[k]) for k, v in p.items()}

    @jax.jit
    def step(params, opt_state, actions, rewards):
        loss, grads = jax.value_and_grad(loss_fn)(params, actions, rewards)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        params = project(params)  # keep within bounds
        return params, opt_state, loss

    losses = []
    p = params
    s = opt_state
    for _ in range(num_iterations): # iterative update of parameters
        p, s, loss = step(p, s, actions, rewards)
        # The recorded loss is the one evaluated before this step's update.
        losses.append(float(loss))

    # return Python floats for convenience
    learned = {k: float(v) for k, v in p.items()}
    return learned, np.asarray(losses)

## meta data incorporation

# Sessions to fit: the cropped set when it exists in the kernel, otherwise
# the full `feds` list.
files_list = feds_cropped if 'feds_cropped' in globals() else feds
assert isinstance(files_list, list) and len(files_list) > 0, "No FED3 files loaded."
assert 'Key_Df' in globals() and isinstance(Key_Df, pd.DataFrame), "Build/rematch Key_Df first."

# metadata_df = copy of Key_Df
metadata_df = Key_Df.copy().reset_index(drop=True)
# Reduce filenames to their basename so attach_meta_data can match them
# against session names.
if 'filename' in metadata_df.columns:
    metadata_df['filename'] = metadata_df['filename'].astype(str).map(os.path.basename)

def _coerce_numeric_col(df, col, clip_upper=None, na_map=None):
    if col not in df.columns:
        return
    s = df[col]
    if na_map:
        s = s.replace(na_map)
    s = pd.to_numeric(s, errors='coerce')
    if clip_upper is not None:
        s.loc[s > clip_upper] = np.nan
    df[col] = s

def attach_meta_data(file_index):
    """Return a cleaned copy of session `file_index` with its metadata attached.

    The session is copied from `files_list`, its original index is kept in
    'Original_Timestamp', and the columns of the matching `metadata_df` row
    (matched on file basename) are broadcast onto it. Existing columns are
    only filled when they are entirely missing. A 'timestamp' column is
    added and 'Poke_Time' / 'Retrieval_Time' are coerced to numeric.

    Returns the DataFrame, or None (after printing a notice) when the
    session has no rows.
    """
    source = files_list[file_index]
    # Read the name from the stored object rather than from the copy:
    # DataFrame.copy() only carries over attributes that a subclass registers
    # in `_metadata`, so an ad-hoc `.name` could be lost on the copy.
    full_name = getattr(source, 'name', f"File_{file_index}")
    file_basename = os.path.basename(str(full_name))
    df = source.copy()

    # Preserve original index once
    if "Original_Timestamp" not in df.columns:
        df["Original_Timestamp"] = df.index

    # Attach metadata by filename (matching already done upstream)
    meta_row = None
    if 'filename' in metadata_df.columns:
        mr = metadata_df.loc[metadata_df['filename'] == file_basename]
        if not mr.empty:
            # First matching row wins.
            meta_row = mr.iloc[0]
    if meta_row is not None:
        for col in meta_row.index:
            if col == 'filename':
                continue
            if col not in df.columns:
                df[col] = meta_row[col]
            else:
                # Never overwrite session data: fill only all-missing columns.
                if pd.isna(df[col]).all() and pd.notna(meta_row[col]):
                    df[col] = meta_row[col]

    # Time + cleanup
    try:
        df['timestamp'] = pd.to_datetime(df.index)
    except Exception:
        # Index is not date-like; fall back to a plain row counter.
        df['timestamp'] = np.arange(len(df))

    _coerce_numeric_col(df, 'Poke_Time', clip_upper=2)
    _coerce_numeric_col(df, 'Retrieval_Time', na_map={"Timed_out": np.nan})

    if len(df) == 0:
        print(f"[!] Empty dataframe for {file_basename}. Skipping plot.")
        return None

    return df

# Smoke-run on one session. Index 1 is kept when it exists; with a single
# loaded file the hard-coded index would raise IndexError, so use index 0.
df = attach_meta_data(1 if len(files_list) > 1 else 0)


In [None]:
#@title RL fitting (this may take ~30-60 minutes)
# Generator that infer_params draws a fresh fit seed from for every restart.
# NOTE(review): created without a seed, so the fits are not reproducible
# from run to run — confirm this is intended.
rng = np.random.default_rng()

# inferring the RL parameters (alpha, beta, bias, lapse, c_preva)
# alpha : learning rate parameter
# beta : inverse temperature
# bias : bias in the action
# lapse : lapse rate; with this probability the choice is modelled as a 50/50 guess
# c_preva : the coefficient of action history term
#
# vanilla: vanilla Rescorla-Wagner model
# preva : Rescorla-Wagner model with previous-action variable for capturing hysteresis factor
#
def infer_params( session_data, hy_params, if_plot=False ):
    """Fit one session several times from random starts and keep the best fit.

    Parameters
    ----------
    session_data : dict passed through to `fit_with_adam`.
    hy_params : dict with 'lr', 'num_iterations', 'fitting_reps' and
        'RL_model'.
    if_plot : when True, overlay the loss curve of every restart and show
        the figure at the end.

    Returns
    -------
    (best_params, best_loss) : the parameters and final loss of the restart
        with the lowest final loss.

    Raises
    ------
    ValueError : if 'fitting_reps' is smaller than 1.
    """
    lr = hy_params['lr']
    num_iterations = hy_params['num_iterations']
    fitting_reps = hy_params['fitting_reps']
    RL_model = hy_params['RL_model'] # Get RL_model from hy_params

    if fitting_reps < 1:
        raise ValueError("hy_params['fitting_reps'] must be at least 1.")

    best_params, best_loss = None, np.nan
    for i in range(fitting_reps):
        # A fresh seed per restart so each one starts from different values.
        seed = int(rng.integers(1000000))
        fitted_params, losses = fit_with_adam(session_data, lr, num_iterations, seed=seed, RL_model=RL_model ) # Pass RL_model
        final_loss = losses[-1]
        # A NaN incumbent must be replaceable: `x < nan` is always False, so
        # without the isnan test a failed first restart would win forever.
        if best_params is None or np.isnan(best_loss) or final_loss < best_loss:
            best_params, best_loss = fitted_params, final_loss
        print(i, final_loss)
        if if_plot:
            plt.plot(losses)
    if if_plot:
        plt.xlabel('fitting iterations')
        plt.ylabel('fitting error')
        plt.show()

    return best_params, best_loss

def run_fitting(fitting_repetition = 10):
  """Fit the RL model to every session in `files_list`.

  For each session the trial arrays are extracted, summary statistics are
  computed and `infer_params` is run with `fitting_repetition` restarts.
  One line per session is also written (without a header) to a CSV under
  'pdata/', whose name encodes the gene and task of the first session and
  the fitting hyper-parameters.

  Returns a DataFrame with one row per session holding the statistics, the
  fitted parameters, 'best_loss' (final loss divided by session length) and
  the Mouse_ID / Gene / Sex metadata.
  """
  hy_params = {
      'fitting_reps': fitting_repetition, # the number of fitting repetitions (from different initial values). More is better, but 10 seems to be enough
      'lr': 5e-3, # learning rate (of Adam optimization)
      'num_iterations': 1000, # the total number of iteration per fitting
      'RL_model': 'preva', # {'vanilla', 'preva'};  Model choice
      }

  # The first session supplies the task and gene used in the file name.
  session_df = attach_meta_data(0)
  hy_params['task'] = session_df['Session_Type'].iloc[0]

  gene_name_id = str( session_df['Gene'].iloc[0] ) + '_'
  festr = 'pdata/RW_fitting_gene_' + gene_name_id + '_task_' + hy_params['task'] + '_model_' + hy_params['RL_model']\
          + '_freps' + str(hy_params['fitting_reps']) + '_lr' + str(hy_params['lr']) + '_niters' + str(hy_params['num_iterations']) + '.csv'
  os.makedirs(os.path.dirname(festr), exist_ok=True)

  records = []
  # The context manager closes the file even if a fit raises part-way through.
  with open(festr,'w') as few:
    for sidx in tqdm(range( len(files_list) )): # do fitting for all files in the files_list
      session_df = attach_meta_data(sidx)
      session_data = extract_context_action_reward(session_df)
      s_stats = get_basic_stats(session_data)
      best_params, best_loss = infer_params( session_data, hy_params, if_plot=False )

      new_row = {**s_stats, **best_params}
      new_row["best_loss"] = best_loss / s_stats["session_length"]
      for key in ['Mouse_ID', 'Gene', 'Sex']: # add the basic meta data
        new_row[key] = session_df[key].iloc[0]
      records.append(new_row)

      # write down the best parameter into a file
      fields = [
          s_stats['session_length'], s_stats['reward_rate'], s_stats['correct_port_rate'],
          s_stats['action_bias'], best_params['alpha'], best_params['beta'],
          best_params['bias'], best_params['lapse'], best_params['c_preva'],
          best_loss/s_stats['session_length'],
      ]
      few.write(','.join(str(value) for value in fields) + '\n')

  # Build the table once from all rows instead of growing it row by row.
  RL_df = pd.DataFrame(records)
  return RL_df

# Fit every session in files_list (10 restarts each) and keep the per-session table.
RL_df = run_fitting(fitting_repetition=10)


In [None]:
#TODO - make this plot look like the others :)

# @title Plot the RL metrics
# One box plot per fitted quantity, split by Sex, laid out in a single row.
fig, axes = plt.subplots(1, 6, figsize=(12, 3))

rl_keys = ['alpha', 'beta', 'bias', 'lapse', 'c_preva', 'best_loss']
for ax, key in zip(axes, rl_keys):
  sns.boxplot(data=RL_df, x='Sex', y=key, ax=ax)
#  sns.swarmplot(data=RL_df, x='Genotype', y=key)
plt.show()

In [None]:
# @title Download the RL metrics


rl_df = RL_df.copy()

#rl_df["Session_type"] = rl_df["Session_type"] + "_RL"

# Write the fitted table to disk; the button below offers it for download.
output_file = "RL_modeling.csv"
RL_df.to_csv(output_file, index=False)

# google.colab is only importable inside Colab; elsewhere fall back to None
# and just report where the file was written.
try:
    from google.colab import files as gfiles
except Exception:
    gfiles = None

status = widgets.HTML()
btn = widgets.Button(description=f"Download {os.path.basename(output_file)}", icon="download")

def _dl(_):
    """Button handler: start a Colab download, or report the local path."""
    if gfiles is None:
        status.value = f"Saved locally at <code>{output_file}</code>."
        return
    status.value = f"Starting download: <code>{os.path.basename(output_file)}</code>…"
    gfiles.download(output_file)

display(btn, status)
btn.on_click(_dl)