# Mouse Whisker-Stimulation Go/No-Go Behavior Analysis

Analysis of lick behavior aligned to stimulus onset (whisker stimulation / tone), with visualizations at five temporal scales: per-trial, per-session, per-day, per-week, and longitudinal (across all sessions).

## Part 0: Setup and Data Loading

In [None]:
import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
from pathlib import Path
from datetime import datetime

# Publication-quality theme: typography first, then axis/figure styling.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica', 'Arial', 'DejaVu Sans'],
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
})
plt.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 1.2,
    'figure.dpi': 100,
})

# Outcome colour palette shared by the figures below (keys are the OUTCOME_MAP labels).
OUTCOME_COLORS = {
    'Hit': '#2ecc71',
    'Miss': '#95a5a6',
    'False Alarm': '#e74c3c',
    'Correct Reject': '#3498db',
}
# Colours of the vertical markers: stimulus onset and reward-window bounds.
STIM_ONSET_COLOR = '#e74c3c'
REWARD_WINDOW_COLOR = '#27ae60'

In [None]:
def _read_lick_rows(lick_path):
    """Parse a tab-separated lick file into a list of ``(TrNum, [raw timestamp fields])``.

    Rows may have different numbers of fields (one per lick), which ``pd.read_csv``
    rejects as soon as a later row is longer than the first. Blank lines and rows whose
    first field is not a finite number are skipped; an empty file yields ``[]``.
    """
    parsed = []
    with open(lick_path, encoding='utf-8') as fh:
        for line in fh:
            fields = line.rstrip('\r\n').split('\t')
            try:
                trnum = int(float(fields[0]))
            except (ValueError, OverflowError):
                continue
            parsed.append((trnum, fields[1:]))
    return parsed


def _lick_times_rel_stim(raw_fields, t_start, stim_onset):
    """Convert raw lick timestamps to ms relative to stimulus onset.

    Unparseable fields, NaN and non-positive values (padding) are dropped, as are
    values above 1e7, which are artifact timestamps (e.g. 4.3e+09).
    """
    licks = []
    for raw in raw_fields:
        try:
            t = float(raw)
        except (TypeError, ValueError):
            continue
        if np.isnan(t) or t <= 0 or t > 1e7:
            continue
        licks.append(t - t_start - stim_onset)
    return licks


def load_all_data(base_dir):
    """Load trial tables and per-trial lick times for mice '01' and '02'.

    Expected layout: ``<base_dir>/<mouse>/<YYMMDD>/<YYMMDD>_S<n>_trials.csv``
    (tab-separated, with header) plus an optional sibling ``<YYMMDD>_S<n>_lick.csv``
    whose rows are ``TrNum`` followed by raw lick timestamps (tab-separated, no header,
    rows may differ in length).

    Returns
    -------
    trials : pd.DataFrame
        All trial rows plus ``mouse``, ``date``, ``session``, ``week`` (ISO week
        number) and ``session_id`` columns; empty if nothing was found.
    lick_df : pd.DataFrame
        One row per lick-file row with ``mouse``, ``date``, ``session``, ``TrNum``,
        ``week``, ``session_id`` and ``lick_times_rel_stim_ms`` (list of lick times in
        ms relative to stimulus onset); empty if nothing was found.

    ``session_id`` is ``"<mouse>_<YYYY-MM-DD>_S<n>"`` so sessions of different mice
    recorded on the same day never share an id.
    """
    base = Path(base_dir)
    trial_frames = []
    lick_records = []

    for mouse in ['01', '02']:
        mouse_path = base / mouse
        if not mouse_path.exists():
            continue
        for date_dir in sorted(mouse_path.iterdir()):
            if not date_dir.is_dir() or not re.match(r'^\d{6}$', date_dir.name):
                continue
            yy, mm, dd = date_dir.name[:2], date_dir.name[2:4], date_dir.name[4:6]
            try:
                date = datetime(2000 + int(yy), int(mm), int(dd)).date()
            except ValueError:
                continue  # six digits that are not a valid calendar date
            # ISO week number only (no year): weeks of different years share a value.
            week = date.isocalendar()[1]

            for f in sorted(date_dir.glob('*_trials.csv')):
                match = re.match(r'(\d{6})_S(\d+)_trials\.csv', f.name)
                if not match:
                    continue
                session = int(match.group(2))
                # Include the mouse so same-day sessions of different mice stay distinct.
                session_id = f"{mouse}_{date.isoformat()}_S{session}"
                trials_df = pd.read_csv(f, sep='\t')
                trials_df['mouse'] = mouse
                trials_df['date'] = date
                trials_df['session'] = session
                trials_df['week'] = week
                trials_df['session_id'] = session_id
                trial_frames.append(trials_df)

                lick_path = date_dir / f"{match.group(1)}_S{session}_lick.csv"
                if not lick_path.exists():
                    continue
                by_trial = trials_df.set_index('TrNum')
                tr_start = by_trial['TrStartTime'].to_dict()
                tr_stim = by_trial['StimOnsetTime'].to_dict()
                for trnum, raw_fields in _read_lick_rows(lick_path):
                    # The 0 ms start / 500 ms onset fallbacks only apply to lick rows with
                    # no matching trial; those relative times are a guess.
                    licks = _lick_times_rel_stim(raw_fields, tr_start.get(trnum, 0), tr_stim.get(trnum, 500))
                    lick_records.append({'mouse': mouse, 'date': date, 'session': session, 'TrNum': trnum,
                                         'week': week, 'session_id': session_id,
                                         'lick_times_rel_stim_ms': licks})

    trials = pd.concat(trial_frames, ignore_index=True) if trial_frames else pd.DataFrame()
    lick_df = pd.DataFrame(lick_records) if lick_records else pd.DataFrame()
    return trials, lick_df

In [None]:
# Load data. Set the TSC2_BEHAVIOR_DIR environment variable to use another data
# location; otherwise the original hard-coded path is used.
BASE_DIR = Path(os.environ.get('TSC2_BEHAVIOR_DIR', '/Users/jeremy/Downloads/TSC2_BehaviorData/Jeremy'))
trials, lick_df = load_all_data(BASE_DIR)
if trials.empty:
    # Stop with a clear message rather than a KeyError on 'TrOutcome' below.
    raise FileNotFoundError(f"No *_trials.csv sessions found under {BASE_DIR}")

# Map TrOutcome codes to labels: 1=Hit, 2=Correct Reject, 3=False Alarm, 4=Miss
OUTCOME_MAP = {1: 'Hit', 2: 'Correct Reject', 3: 'False Alarm', 4: 'Miss'}
trials['outcome_label'] = trials['TrOutcome'].map(OUTCOME_MAP)

# Reward-window start/end relative to stimulus onset (ms), used by the raster and PSTH plots
trials['RWStart_rel_stim'] = trials['RWStartTime'] - trials['TrStartTime'] - trials['StimOnsetTime']
trials['RWEnd_rel_stim'] = trials['RWEndTime'] - trials['TrStartTime'] - trials['StimOnsetTime']

print(f"Loaded {len(trials)} trials, {len(lick_df)} trial-lick rows")
print("Mice:", trials['mouse'].unique().tolist())
print("Date range:", trials['date'].min(), "to", trials['date'].max())

In [None]:
def _corrected_rates(go_hits, go_total, nogo_fas, nogo_total, correction):
    """Return (hit rate, false-alarm rate) after the additive correction, clipped into (0, 1)."""
    hr = (go_hits + correction) / (go_total + 2 * correction)
    far = (nogo_fas + correction) / (nogo_total + 2 * correction)
    # Clipping only matters when correction == 0; it keeps norm.ppf finite.
    return np.clip(hr, 1e-6, 1 - 1e-6), np.clip(far, 1e-6, 1 - 1e-6)


def compute_sdt_metrics(go_hits, go_total, nogo_fas, nogo_total, correction=0.5):
    """Compute d', hit rate and false-alarm rate for a block of trials.

    Rates use the log-linear correction: ``correction`` (default 0.5) is added to every
    count and ``2 * correction`` to every total, so rates never reach 0 or 1. Unlike the
    1/(2N) rule, this adjusts every rate, not only extreme ones.

    Returns
    -------
    (d_prime, hit_rate, fa_rate), or three NaNs when there are no Go or no No-Go trials.
    """
    if go_total == 0 or nogo_total == 0:
        return np.nan, np.nan, np.nan
    hr, far = _corrected_rates(go_hits, go_total, nogo_fas, nogo_total, correction)
    d_prime = norm.ppf(hr) - norm.ppf(far)
    return d_prime, hr, far


def compute_criterion(go_hits, go_total, nogo_fas, nogo_total, correction=0.5):
    """Signal-detection criterion c = -0.5 * (z(HR) + z(FAR)).

    Uses the same corrected rates as ``compute_sdt_metrics``. Positive c means a
    conservative (withholding) bias, negative c a liberal (licking) bias. Returns NaN
    when there are no Go or no No-Go trials.
    """
    if go_total == 0 or nogo_total == 0:
        return np.nan
    hr, far = _corrected_rates(go_hits, go_total, nogo_fas, nogo_total, correction)
    return -0.5 * (norm.ppf(hr) + norm.ppf(far))


def sdt_for_trials(trial_subset):
    """Compute (d', HR, FAR) for a DataFrame of trials.

    Go trials are TrType 1 (hit: TrOutcome 1); No-Go trials are TrType 0 (false alarm:
    TrOutcome 3). The DataFrame must have TrType and TrOutcome columns.
    """
    go = trial_subset[trial_subset['TrType'] == 1]
    nogo = trial_subset[trial_subset['TrType'] == 0]
    go_hits = (go['TrOutcome'] == 1).sum()
    nogo_fas = (nogo['TrOutcome'] == 3).sum()
    return compute_sdt_metrics(go_hits, len(go), nogo_fas, len(nogo))

## Part 1: Per-Trial Visualizations

Lick raster aligned to stimulus onset, trial outcome sequence, and rolling performance within a session.

In [None]:
def plot_lick_raster_one_session(trials_sub, lick_sub, session_id, time_win=(-500, 8000), max_trials=80):
    """Plot a lick raster for one session: Go trials (top panel), No-Go trials (bottom).

    Parameters
    ----------
    trials_sub : DataFrame with TrNum, TrType and outcome_label (optionally
        RWStart_rel_stim / RWEnd_rel_stim and session_id). When it has a session_id
        column it is filtered to ``session_id``, so trials of other sessions are never
        paired with this session's licks.
    lick_sub : DataFrame with TrNum and lick_times_rel_stim_ms (lists of ms relative
        to stimulus onset); filtered to ``session_id`` when it has that column.
    session_id : session to plot.
    time_win : (start, end) in ms relative to stimulus onset; licks outside are dropped.
    max_trials : keep the first ``max_trials`` trials by TrNum (before the Go/No-Go split).

    Each row is one trial; ticks are licks coloured by outcome, and rows with no licks in
    the window are shaded in the outcome colour. Vertical lines mark stimulus onset (red)
    and the reward-window bounds of the panel's first trial (green dashed; 5000/8000 ms
    when those columns are missing).
    """
    if 'session_id' in trials_sub.columns:
        trials_sub = trials_sub[trials_sub['session_id'] == session_id]
    trials_sub = trials_sub.sort_values('TrNum').head(max_trials)
    lick_merge = lick_sub[lick_sub['session_id'] == session_id] if 'session_id' in lick_sub.columns else lick_sub

    # TrNum -> lick list, built once instead of filtering lick_merge for every trial.
    # The first row wins for a duplicated TrNum (as .iloc[0] on a filter would).
    licks_by_trial = {}
    if not lick_merge.empty:
        for trnum, licks in zip(lick_merge['TrNum'], lick_merge['lick_times_rel_stim_ms']):
            licks_by_trial.setdefault(trnum, licks)

    fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    for ax, (tr_type, label) in zip(axes, [(1, 'Go'), (0, 'No-Go')]):
        sub = trials_sub[trials_sub['TrType'] == tr_type]
        if sub.empty:
            ax.set_visible(False)
            continue
        for row, (trnum, outcome) in enumerate(zip(sub['TrNum'], sub['outcome_label'])):
            licks = licks_by_trial.get(trnum)
            if isinstance(licks, (list, np.ndarray)):
                licks = [t for t in licks if time_win[0] <= t <= time_win[1]]
            else:
                licks = []  # no lick row (or unusable value) for this trial
            color = OUTCOME_COLORS.get(outcome, '#333')
            if licks:
                ax.vlines(licks, row - 0.4, row + 0.4, color=color, linewidth=1.5)
            else:
                ax.axhspan(row - 0.5, row + 0.5, alpha=0.15, color=color)
        ax.axvline(0, color=STIM_ONSET_COLOR, linewidth=2, label='Stimulus onset')
        rw_start = sub['RWStart_rel_stim'].iloc[0] if 'RWStart_rel_stim' in sub.columns else 5000
        rw_end = sub['RWEnd_rel_stim'].iloc[0] if 'RWEnd_rel_stim' in sub.columns else 8000
        ax.axvline(rw_start, color=REWARD_WINDOW_COLOR, linestyle='--', linewidth=1.5, label='Reward window')
        ax.axvline(rw_end, color=REWARD_WINDOW_COLOR, linestyle='--', linewidth=1.5)
        ax.set_ylabel(f'{label} trials')
        ax.set_ylim(-0.5, len(sub) - 0.5)
        ax.legend(loc='upper right', frameon=True)
    axes[1].set_xlabel('Time relative to stimulus onset (ms)')
    axes[0].set_title(f'Lick raster — {session_id}')
    plt.tight_layout()
    plt.show()

In [None]:
# Example session: the first session (in load order) with at least MIN_SAMPLE_TRIALS
# trials, falling back to the very first session when none is that long. The threshold
# matches the rolling-performance window used below.
MIN_SAMPLE_TRIALS = 20
session_sizes = trials.groupby('session_id', sort=False).size()
eligible_sessions = session_sizes[session_sizes >= MIN_SAMPLE_TRIALS]
sample_session = (eligible_sessions if not eligible_sessions.empty else session_sizes).index[0]
trials_one = trials[trials['session_id'] == sample_session].copy()
plot_lick_raster_one_session(trials_one, lick_df, sample_session, time_win=(-500, 9000))

In [None]:
def plot_trial_outcome_sequence(trials_sub, session_id, max_trials=100):
    """Plot one session's trials as two colour strips: outcome (top) and trial type (bottom).

    Trials are filtered to ``session_id``, sorted by TrNum and truncated to ``max_trials``;
    each column is one trial in that order. Outcome colours come from OUTCOME_COLORS
    ('#333' for unmapped outcomes) and are explained by a legend. The type strip is blue
    for No-Go (TrType 0) and orange otherwise (Go). Returns without plotting when the
    session has no trials.
    """
    trials_sub = trials_sub[trials_sub['session_id'] == session_id].sort_values('TrNum').head(max_trials)
    if trials_sub.empty:
        return
    positions = np.arange(len(trials_sub))
    fig, (ax_outcome, ax_type) = plt.subplots(2, 1, figsize=(12, 2), sharex=True)

    outcome_colors = [OUTCOME_COLORS.get(outcome, '#333') for outcome in trials_sub['outcome_label']]
    ax_outcome.bar(positions, 1, color=outcome_colors, width=1, edgecolor='none')
    # Zero-height proxy bars so the legend can name each outcome colour.
    for outcome, color in OUTCOME_COLORS.items():
        ax_outcome.bar(0, 0, color=color, label=outcome)
    ax_outcome.legend(loc='upper left', bbox_to_anchor=(1.01, 1.3), fontsize=7, frameon=False)
    ax_outcome.set_ylabel('Outcome')
    ax_outcome.set_yticks([])
    ax_outcome.set_ylim(0, 1.2)
    ax_outcome.set_title(f'Trial outcome sequence — {session_id}')

    type_colors = ['#3498db' if t == 0 else '#e67e22' for t in trials_sub['TrType']]
    ax_type.bar(positions, 1, color=type_colors, width=1, edgecolor='none')
    ax_type.set_ylabel('Stimulus')
    ax_type.set_yticks([0.5])
    ax_type.set_yticklabels(['No-Go (blue) / Go (orange)'])
    ax_type.set_ylim(0, 1.2)
    ax_type.set_xlabel('Trial index (sorted by TrNum)')
    plt.tight_layout()
    plt.show()


plot_trial_outcome_sequence(trials, sample_session, max_trials=80)

In [None]:
def plot_rolling_performance(trials_sub, session_id, window=20):
    """Plot rolling d', hit rate and false-alarm rate across one session's trials.

    The metric at trial i uses the trailing ``window`` trials, or fewer at the start of
    the session. Hit/false-alarm rates are bounded to [0, 1] and drawn on the left axis.
    d' is unbounded, so it gets its own right-hand axis instead of being cut off by the
    rate limits. Returns without plotting when the session has fewer than ``window`` trials.
    """
    sub = trials_sub[trials_sub['session_id'] == session_id].sort_values('TrNum').reset_index(drop=True)
    if sub.empty or len(sub) < window:
        return
    metrics = [sdt_for_trials(sub.iloc[max(0, i - window + 1):i + 1]) for i in range(len(sub))]
    d_vals, hr_vals, far_vals = (list(m) for m in zip(*metrics))
    x = np.arange(len(sub))

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(x, hr_vals, color='#2ecc71', label='Hit rate', linewidth=1.5)
    ax.plot(x, far_vals, color='#e74c3c', label='False alarm rate', linewidth=1.5)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel('Trial index (sorted by TrNum)')
    ax.set_ylabel('Rate')

    ax_d = ax.twinx()
    ax_d.plot(x, d_vals, color='#2c3e50', label="d'", linewidth=2)
    ax_d.axhline(0, color='gray', linestyle='-', linewidth=0.5)
    ax_d.set_ylabel("d'")
    # The theme hides right spines (axes.spines.right=False); show it for the d' axis.
    ax_d.spines['right'].set_visible(True)

    ax.set_title(f'Rolling performance (window={window} trials) — {session_id}')
    rate_handles, rate_labels = ax.get_legend_handles_labels()
    d_handles, d_labels = ax_d.get_legend_handles_labels()
    ax.legend(rate_handles + d_handles, rate_labels + d_labels, loc='upper right')
    plt.tight_layout()
    plt.show()


plot_rolling_performance(trials, sample_session, window=20)

## Part 2: Per-Session Visualizations

Session-level PSTH, summary bars (d', HR, FAR), and cumulative reward curve.

In [None]:
def plot_session_psth(trials_sub, lick_sub, session_id, bin_ms=50, time_win=(-500, 8000)):
    """Plot a lick PSTH for one session, Go vs No-Go, aligned to stimulus onset.

    The rate per bin is (licks in bin) / (number of trials of that type * bin width in s),
    i.e. the mean lick rate per trial in Hz. A trial type with no trials is skipped rather
    than drawn as a flat zero line. Vertical lines mark stimulus onset and the first
    trial's reward-window start/end (5000/8000 ms when those columns are missing).
    Returns without plotting if the session has no trials or no lick rows.
    """
    trials_s = trials_sub[trials_sub['session_id'] == session_id]
    lick_s = lick_sub[lick_sub['session_id'] == session_id]
    if trials_s.empty or lick_s.empty:
        return

    # TrNum -> lick list (first row wins for duplicates), built once per session.
    licks_by_trial = {}
    for trnum, licks in zip(lick_s['TrNum'], lick_s['lick_times_rel_stim_ms']):
        licks_by_trial.setdefault(trnum, licks)

    bins = np.arange(time_win[0], time_win[1] + bin_ms, bin_ms)
    centers = bins[:-1] + bin_ms / 2
    fig, ax = plt.subplots(figsize=(9, 4))
    for tr_type, label, color in [(1, 'Go', '#e67e22'), (0, 'No-Go', '#3498db')]:
        tr_nums = trials_s.loc[trials_s['TrType'] == tr_type, 'TrNum'].tolist()
        if not tr_nums:
            continue
        all_licks = []
        for trnum in tr_nums:
            licks = licks_by_trial.get(trnum)
            if isinstance(licks, (list, np.ndarray)):
                all_licks.extend(t for t in licks if time_win[0] <= t <= time_win[1])
        counts, _ = np.histogram(all_licks, bins=bins)
        rate = counts / (len(tr_nums) * (bin_ms / 1000))
        ax.plot(centers, rate, color=color, label=label, linewidth=2)
    ax.axvline(0, color=STIM_ONSET_COLOR, linewidth=2, linestyle='-', label='Stimulus onset')
    rw_start = trials_s['RWStart_rel_stim'].iloc[0] if 'RWStart_rel_stim' in trials_s.columns else 5000
    rw_end = trials_s['RWEnd_rel_stim'].iloc[0] if 'RWEnd_rel_stim' in trials_s.columns else 8000
    ax.axvline(rw_start, color=REWARD_WINDOW_COLOR, linewidth=1.5, linestyle='--', label='Reward window')
    ax.axvline(rw_end, color=REWARD_WINDOW_COLOR, linewidth=1.5, linestyle='--')
    ax.set_xlabel('Time relative to stimulus onset (ms)')
    ax.set_ylabel('Lick rate (Hz)')
    ax.set_title(f'Session PSTH — {session_id}')
    ax.legend()
    plt.tight_layout()
    plt.show()


plot_session_psth(trials, lick_df, sample_session, bin_ms=50)

In [None]:
# Session summary: one row per session (all mice) with d', hit rate and FA rate,
# in the order sessions first appear in `trials`.
session_metrics = []
for sid, sub in trials.groupby('session_id', sort=False):
    d, hr, far = sdt_for_trials(sub)
    session_metrics.append({
        'session_id': sid,
        'date': sub['date'].iloc[0],
        'mouse': sub['mouse'].iloc[0],
        'week': sub['week'].iloc[0],
        "d'": d,
        'Hit rate': hr,
        'False alarm rate': far,
    })
session_df = pd.DataFrame(session_metrics)

# Grouped bars: d', hit rate and FA rate side by side for every session.
fig, ax = plt.subplots(figsize=(12, 5))
x = np.arange(len(session_df))
w = 0.25
bar_specs = [
    ("d'", "d'", '#2c3e50', -w),
    ('Hit rate', 'Hit rate', '#2ecc71', 0),
    ('False alarm rate', 'False alarm rate', '#e74c3c', w),
]
for column, bar_label, bar_color, offset in bar_specs:
    ax.bar(x + offset, session_df[column], width=w, label=bar_label, color=bar_color)
ax.set_xticks(x)
ax.set_xticklabels(session_df['session_id'].str.replace('_', '\n', regex=False), rotation=45, ha='right')
ax.set_ylabel('Metric')
ax.set_title('Performance per session')
ax.legend(loc='upper right')
plt.tight_layout()
plt.show()

In [None]:
# Cumulative rewards against trial number, one line per session. Only the first 12
# sessions (order of first appearance) are drawn, to keep the legend readable.
fig, ax = plt.subplots(figsize=(10, 4))
if 'CumNRewards' in trials.columns:
    for sid in trials['session_id'].unique()[:12]:
        session_trials = trials.loc[trials['session_id'] == sid].sort_values('TrNum')
        ax.plot(session_trials['TrNum'], session_trials['CumNRewards'], label=sid, alpha=0.8)
ax.set_xlabel('Trial number')
ax.set_ylabel('Cumulative rewards')
ax.set_title('Cumulative reward curve per session')
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=7)
plt.tight_layout()
plt.show()

## Part 3: Per-Day Visualizations

Daily performance trends, lick heatmap by day, and outcome proportions.

In [None]:
# Daily performance per mouse.
# `daily_agg`: one SDT estimate per mouse and day, from all of that day's trials pooled.
# The plot shows the mean across that day's sessions with a ±SEM band (no band on
# single-session days, where the SEM is undefined).
daily = []
for (mouse, date), grp in trials.groupby(['mouse', 'date']):
    d, hr, far = sdt_for_trials(grp)
    daily.append({'mouse': mouse, 'date': date, "d'": d, 'Hit rate': hr, 'False alarm rate': far})
daily_agg = pd.DataFrame(daily)

# Session-level metrics aggregated per day (session_df['date'] is datetime64 from here on).
session_df['date'] = pd.to_datetime(session_df['date'])
daily_metric_cols = ["d'", 'Hit rate', 'False alarm rate']
daily_grouped = session_df.groupby(['mouse', 'date'])[daily_metric_cols]
daily_means = daily_grouped.mean().reset_index()
daily_sem = daily_grouped.sem().reset_index()

# (column, marker, legend suffix, line width, marker size, colour)
daily_specs = [
    ("d'", 'o', "d'", 2, 8, '#2c3e50'),
    ('Hit rate', 's', 'Hit rate', 1.5, None, '#2ecc71'),
    ('False alarm rate', '^', 'FA rate', 1.5, None, '#e74c3c'),
]
# Each metric keeps one colour across mice, so mice are distinguished by line style.
mouse_linestyles = ['-', '--', ':', '-.']

fig, ax = plt.subplots(figsize=(11, 5))
for i, mouse in enumerate(session_df['mouse'].unique()):
    linestyle = mouse_linestyles[i % len(mouse_linestyles)]
    sub = daily_means[daily_means['mouse'] == mouse]
    sem_sub = daily_sem[daily_sem['mouse'] == mouse]
    x = pd.to_datetime(sub['date'])
    for col, marker, suffix, lw, ms, color in daily_specs:
        mean = sub[col].to_numpy(dtype=float)
        sem = sem_sub[col].to_numpy(dtype=float)
        ax.plot(x, mean, marker=marker, linestyle=linestyle, color=color,
                label=f"Mouse {mouse} {suffix}", linewidth=lw, markersize=ms)
        ax.fill_between(x, mean - sem, mean + sem, color=color, alpha=0.15, linewidth=0)
ax.set_xlabel('Date')
ax.set_ylabel('Metric')
ax.set_title('Daily performance over time (mean ± SEM across sessions per day)')
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
def plot_daily_lick_heatmap(trials_sub, lick_sub, date_val, bin_ms=100, time_win=(-200, 7000), mouse=None):
    """Plot a heatmap of lick counts: rows = trials on ``date_val``, columns = time bins.

    Parameters
    ----------
    trials_sub, lick_sub : trial and lick tables (need date, mouse, session_id, TrNum;
        trials also need session, lick_sub needs lick_times_rel_stim_ms).
    date_val : day to plot (compared with the ``date`` column).
    bin_ms, time_win : bin width and (start, end) window in ms relative to stimulus onset.
    mouse : optional mouse id. When given, only that mouse's trials are shown; when None,
        all mice recorded that day are stacked, grouped by mouse.

    Rows are ordered by mouse, session, then TrNum. Licks are matched on
    (mouse, session_id, TrNum), so different mice never share lick rows. Returns without
    plotting when there are no trials that day.
    """
    day_trials = trials_sub[trials_sub['date'] == date_val]
    if mouse is not None:
        day_trials = day_trials[day_trials['mouse'] == mouse]
    day_trials = day_trials.sort_values(['mouse', 'session', 'TrNum'])
    if day_trials.empty:
        return

    lick_day = lick_sub[lick_sub['date'] == date_val]
    licks_by_key = {}
    for key, licks in zip(zip(lick_day['mouse'], lick_day['session_id'], lick_day['TrNum']),
                          lick_day['lick_times_rel_stim_ms']):
        licks_by_key.setdefault(key, licks)  # first row wins for duplicates

    bins = np.arange(time_win[0], time_win[1] + bin_ms, bin_ms)
    n_trials = len(day_trials)
    mat = np.zeros((n_trials, len(bins) - 1))
    for i, key in enumerate(zip(day_trials['mouse'], day_trials['session_id'], day_trials['TrNum'])):
        licks = licks_by_key.get(key)
        if isinstance(licks, (list, np.ndarray)):
            counts, _ = np.histogram([t for t in licks if time_win[0] <= t <= time_win[1]], bins=bins)
            mat[i, :] = counts

    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(mat, aspect='auto', cmap='YlOrRd', interpolation='nearest',
                   extent=[time_win[0], time_win[1], n_trials, 0])
    ax.axvline(0, color=STIM_ONSET_COLOR, linewidth=2, label='Stimulus onset')
    ax.legend(loc='upper right')
    ax.set_xlabel('Time relative to stimulus onset (ms)')
    ax.set_ylabel('Trial (concatenated across sessions)')
    title = f'Daily lick density — {date_val}'
    if mouse is not None:
        title += f' — mouse {mouse}'
    ax.set_title(title)
    plt.colorbar(im, ax=ax, label='Lick count')
    plt.tight_layout()
    plt.show()


# Example: first date in the dataset, for the mouse of the first trial
example_date = trials['date'].iloc[0]
example_mouse = trials['mouse'].iloc[0]
plot_daily_lick_heatmap(trials, lick_df, example_date, mouse=example_mouse)

In [None]:
# Stacked bars: percentage of each outcome per calendar date (all mice pooled).
outcome_order = ['Hit', 'Correct Reject', 'False Alarm', 'Miss']
counts_by_day = trials.groupby(['date', 'outcome_label']).size().unstack(fill_value=0)
present_outcomes = [name for name in outcome_order if name in counts_by_day.columns]
daily_outcome = counts_by_day.reindex(columns=present_outcomes, fill_value=0)
daily_outcome_pct = daily_outcome.div(daily_outcome.sum(axis=1), axis=0).mul(100)

n_days = len(daily_outcome_pct)
positions = range(n_days)
colors = [OUTCOME_COLORS.get(c, '#333') for c in daily_outcome_pct.columns]
fig, ax = plt.subplots(figsize=(11, 5))
bottom = np.zeros(n_days)
for col, color in zip(daily_outcome_pct.columns, colors):
    heights = daily_outcome_pct[col].values
    ax.bar(positions, heights, bottom=bottom, label=col, color=color)
    bottom = bottom + heights
ax.set_xticks(positions)
ax.set_xticklabels(list(map(str, daily_outcome_pct.index)), rotation=45, ha='right')
ax.set_ylabel('Proportion (%)')
ax.set_xlabel('Date')
ax.set_title('Daily outcome proportions (Hit, CR, FA, Miss)')
ax.legend(loc='upper right')
plt.tight_layout()
plt.show()

## Part 4: Per-Week Visualizations

Weekly performance summary (box plot of session-level d'), learning curve by week, and median first-lick latency by week.

In [None]:
# Weekly performance: box plot of session-level d' per ISO week, split by mouse.
session_df['week_label'] = 'Week ' + session_df['week'].astype(str)
# Boxes in numeric week order. Seaborn's default for string categories is order of first
# appearance, which depends on which mouse happened to be loaded first.
week_order = ['Week ' + str(w) for w in sorted(session_df['week'].unique())]
fig, ax = plt.subplots(figsize=(8, 5))
sns.boxplot(data=session_df, x='week_label', y="d'", hue='mouse', order=week_order, ax=ax,
            palette=['#3498db', '#e67e22'])
ax.axhline(0, color='gray', linestyle='--', linewidth=0.5)
ax.set_xlabel('Week')
ax.set_ylabel("d'")
ax.set_title('Weekly performance (session-level d\')')
ax.legend(title='Mouse')
plt.xticks(rotation=15)
plt.tight_layout()
plt.show()

In [None]:
# Weekly learning curve: per mouse and ISO week, mean ± SEM of session-level metrics.
weekly_metric_cols = ["d'", 'Hit rate', 'False alarm rate']
weekly_grouped = session_df.groupby(['mouse', 'week'])[weekly_metric_cols]
week_means = weekly_grouped.mean().reset_index()
week_sem = weekly_grouped.sem().reset_index()

# (column, marker, legend suffix, extra line keyword arguments)
curve_specs = [
    ("d'", 'o', "d'", {'linewidth': 2}),
    ('Hit rate', 's', 'HR', {'linestyle': '--'}),
    ('False alarm rate', '^', 'FAR', {'linestyle': ':'}),
]
fig, ax = plt.subplots(figsize=(8, 5))
for mouse in session_df['mouse'].unique():
    wm = week_means[week_means['mouse'] == mouse]
    ws = week_sem[week_sem['mouse'] == mouse]
    for col, marker, suffix, extra in curve_specs:
        ax.errorbar(wm['week'], wm[col], yerr=ws[col], marker=marker, capsize=4,
                    label=f"Mouse {mouse} {suffix}", **extra)
ax.set_xlabel('Week')
ax.set_ylabel('Metric')
ax.set_title('Weekly learning curve (mean ± SEM across sessions)')
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
ax.set_xticks(sorted(session_df['week'].unique()))
plt.tight_layout()
plt.show()

In [None]:
# Weekly lick timing: median first-lick latency on Go trials scored as Hit.
# "First lick" is the earliest lick at or after stimulus onset (t >= 0 ms); it is not
# restricted to the reward window.
lick_lookup = {}
for key, licks in zip(zip(lick_df['mouse'], lick_df['session_id'], lick_df['TrNum']),
                      lick_df['lick_times_rel_stim_ms']):
    # Keyed by mouse as well, in case session_id does not encode the mouse.
    lick_lookup.setdefault(key, licks)

go_hit_trials = trials[(trials['TrType'] == 1) & (trials['TrOutcome'] == 1)]
first_lick_latency = []
for mouse, sid, trnum, week in zip(go_hit_trials['mouse'], go_hit_trials['session_id'],
                                   go_hit_trials['TrNum'], go_hit_trials['week']):
    licks = lick_lookup.get((mouse, sid, trnum))
    if not isinstance(licks, (list, np.ndarray)):
        continue
    post_stim = [t for t in licks if t >= 0]
    if post_stim:
        first_lick_latency.append({'week': week, 'mouse': mouse, 'latency_ms': min(post_stim)})
fl_df = pd.DataFrame(first_lick_latency)
if not fl_df.empty:
    week_latency = fl_df.groupby(['mouse', 'week'])['latency_ms'].median().reset_index()
    fig, ax = plt.subplots(figsize=(8, 4))
    for mouse in fl_df['mouse'].unique():
        sub = week_latency[week_latency['mouse'] == mouse]
        ax.plot(sub['week'], sub['latency_ms'], 'o-', label=f'Mouse {mouse}', linewidth=2, markersize=8)
    ax.set_xlabel('Week')
    ax.set_ylabel('Median first-lick latency (ms, rel. stimulus onset)')
    ax.set_title('Weekly lick timing shift (Go-Hit trials)')
    ax.legend()
    ax.set_xticks(sorted(fl_df['week'].unique()))
    plt.tight_layout()
    plt.show()
else:
    print('No Go-Hit trials with licks found for first-lick latency.')

## Part 5: Longitudinal and Recommended Visualizations

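Full learning curve with day-boundary labels, criterion (c) over time, lick bout analysis (inter-lick intervals, early vs late training), reaction time distributions, and operating points in ROC space (hit rate vs false-alarm rate).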

In [None]:
# Full learning curve: d' for every session in chronological order (per mouse), with
# dotted lines at the first session of each day, labelled with the date.
# Sessions within a day are ordered by their numeric S<n> suffix, so S10 follows S9.
session_num_key = session_df['session_id'].str.extract(r'_S(\d+)$', expand=False).astype(float)
session_df_sorted = (
    session_df.assign(_session_num=session_num_key)
    .sort_values(['mouse', 'date', '_session_num', 'session_id'])
    .drop(columns='_session_num')
    .reset_index(drop=True)
)
session_df_sorted['session_ord'] = session_df_sorted.groupby('mouse').cumcount()
session_df_sorted['global_ord'] = np.arange(len(session_df_sorted))  # for early/late split elsewhere

fig, ax = plt.subplots(figsize=(14, 5))
labelled = set()  # (x, date label) pairs already annotated, to avoid overprinting
for mouse in session_df_sorted['mouse'].unique():
    sub = session_df_sorted[session_df_sorted['mouse'] == mouse].reset_index(drop=True)
    line, = ax.plot(sub['session_ord'], sub["d'"], 'o-', label=f'Mouse {mouse}', linewidth=2, markersize=6)
    # The first session of each date marks a day boundary (drawn in this mouse's colour).
    first_of_day = sub.drop_duplicates('date')
    for date_val, ord_val in zip(first_of_day['date'], first_of_day['session_ord']):
        x = ord_val - 0.5
        ax.axvline(x, color=line.get_color(), linestyle=':', linewidth=0.8, alpha=0.6)
        date_label = pd.Timestamp(date_val).strftime('%Y-%m-%d')
        if (x, date_label) in labelled:
            continue
        labelled.add((x, date_label))
        # x in data units, y in axes units: the label stays at the top whatever the d' range.
        ax.text(x, 0.98, date_label, transform=ax.get_xaxis_transform(), fontsize=7,
                rotation=90, va='top', color='gray')
ax.set_xlabel('Session (chronological, per mouse)')
ax.set_ylabel("d'")
ax.set_title('Full learning curve (vertical lines: day boundaries, labeled by date)')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Criterion (c) over time: add to session metrics and plot with d'
def criterion_for_trials(trial_subset):
    """Return the SDT criterion c for a DataFrame of trials (needs TrType and TrOutcome).

    Go trials are TrType 1 (hit: TrOutcome 1); No-Go trials are TrType 0 (false alarm:
    TrOutcome 3). The value comes from ``compute_criterion``, which returns NaN when
    either trial type is missing.
    """
    is_go = trial_subset['TrType'] == 1
    is_nogo = trial_subset['TrType'] == 0
    outcome = trial_subset['TrOutcome']
    n_hits = (is_go & (outcome == 1)).sum()
    n_false_alarms = (is_nogo & (outcome == 3)).sum()
    return compute_criterion(n_hits, int(is_go.sum()), n_false_alarms, int(is_nogo.sum()))

# Criterion per session (computed once per session via groupby), stored on session_df.
criterion_by_session = {sid: criterion_for_trials(grp) for sid, grp in trials.groupby('session_id', sort=False)}
session_df['criterion'] = session_df['session_id'].map(criterion_by_session)

fig, ax1 = plt.subplots(figsize=(12, 5))
ax2 = ax1.twinx()
x = np.arange(len(session_df))
d_vals = session_df["d'"].to_numpy(dtype=float)
c_vals = session_df['criterion'].to_numpy(dtype=float)
mouse_ids = session_df['mouse'].to_numpy()
mouse_linestyles = ['-', '--', ':', '-.']
# One line per mouse, so one mouse's last session is not joined to the next mouse's first.
for i, mouse in enumerate(pd.unique(mouse_ids)):
    idx = np.flatnonzero(mouse_ids == mouse)
    ls = mouse_linestyles[i % len(mouse_linestyles)]
    ax1.plot(x[idx], d_vals[idx], marker='o', linestyle=ls, color='#2c3e50',
             label=f"Mouse {mouse} d'", linewidth=2)
    ax2.plot(x[idx], c_vals[idx], marker='s', linestyle=ls, color='#9b59b6',
             label=f'Mouse {mouse} criterion c', linewidth=1.5)
ax1.set_ylabel("d'", color='#2c3e50')
ax1.tick_params(axis='y', labelcolor='#2c3e50')
ax2.set_ylabel('Criterion (c)', color='#9b59b6')
ax2.tick_params(axis='y', labelcolor='#9b59b6')
# The theme hides right spines (axes.spines.right=False); show it for the criterion axis.
ax2.spines['right'].set_visible(True)
ax2.axhline(0, color='gray', linestyle='--', linewidth=0.5)
ax1.set_xticks(x)
ax1.set_xticklabels([s.replace('_', '\n') for s in session_df['session_id']], rotation=45, ha='right')
ax1.set_xlabel('Session')
ax1.set_title("d' and response bias (criterion) over time")
fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=2)
plt.tight_layout()
plt.show()

In [None]:
# Lick bout analysis: inter-lick interval distribution (early vs late training)
def get_inter_lick_intervals(lick_sub, trials_sub):
    """Collect inter-lick intervals (ms) between consecutive licks within each trial.

    ``lick_sub`` needs a ``lick_times_rel_stim_ms`` column of per-trial lick lists.
    Licks are sorted within each trial, and only intervals strictly between 0 and
    2000 ms are kept. Trials with fewer than two licks contribute nothing.
    ``trials_sub`` is accepted but not used.
    """
    per_trial = lick_sub['lick_times_rel_stim_ms'] if len(lick_sub) else ()
    intervals = []
    for trial_licks in per_trial:
        if not isinstance(trial_licks, (list, np.ndarray)) or len(trial_licks) < 2:
            continue
        ordered = sorted(trial_licks)
        gaps = (later - earlier for earlier, later in zip(ordered, ordered[1:]))
        intervals.extend(gap for gap in gaps if 0 < gap < 2000)
    return intervals

# Early vs late training, split within each mouse. A mouse's sessions (chronological
# `session_ord`) up to its halfway point are "early", the rest "late". Splitting on the
# pooled `global_ord` would mostly compare one mouse against the other, because
# session_df_sorted lists all of one mouse's sessions before the next mouse's.
session_order = session_df_sorted.set_index('session_id')['global_ord'].to_dict()
trials['global_ord'] = trials['session_id'].map(session_order)  # kept for downstream use
last_ord = session_df_sorted.groupby('mouse')['session_ord'].transform('max')
is_early_session = session_df_sorted['session_ord'] <= last_ord / 2
early_keys = set(zip(session_df_sorted.loc[is_early_session, 'mouse'],
                     session_df_sorted.loc[is_early_session, 'session_id']))

trial_is_early = np.array([key in early_keys for key in zip(trials['mouse'], trials['session_id'])], dtype=bool)
lick_is_early = np.array([key in early_keys for key in zip(lick_df['mouse'], lick_df['session_id'])], dtype=bool)
early_trials = trials[trial_is_early]
late_trials = trials[~trial_is_early]
early_lick = lick_df[lick_is_early]
late_lick = lick_df[~lick_is_early]
ilis_early = get_inter_lick_intervals(early_lick, early_trials)
ilis_late = get_inter_lick_intervals(late_lick, late_trials)

fig, ax = plt.subplots(figsize=(8, 4))
if ilis_early:
    ax.hist(ilis_early, bins=50, range=(0, 500), alpha=0.6, label='Early training (first half, per mouse)',
            color='#3498db', density=True)
if ilis_late:
    ax.hist(ilis_late, bins=50, range=(0, 500), alpha=0.6, label='Late training (second half, per mouse)',
            color='#e67e22', density=True)
ax.set_xlabel('Inter-lick interval (ms)')
ax.set_ylabel('Density')
ax.set_title('Lick bout structure: inter-lick interval distribution')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Reaction time distribution: first post-stimulus lick latency (Go-Hit trials) by week — violin plot
if not fl_df.empty:
    fl_df['week_label'] = 'Week ' + fl_df['week'].astype(str)
    latency_week_order = ['Week ' + str(w) for w in sorted(fl_df['week'].unique())]
    # Reference line at the reward-window start taken from the data (median over Go
    # trials). Falls back to 5000 ms when that column is missing or entirely NaN.
    rw_start_ms = np.nan
    if 'RWStart_rel_stim' in trials.columns:
        rw_start_ms = trials.loc[trials['TrType'] == 1, 'RWStart_rel_stim'].median()
    if pd.isna(rw_start_ms):
        rw_start_ms = 5000
    fig, ax = plt.subplots(figsize=(9, 5))
    sns.violinplot(data=fl_df, x='week_label', y='latency_ms', hue='mouse', order=latency_week_order,
                   ax=ax, palette=['#3498db', '#e67e22'])
    ax.axhline(rw_start_ms, color=REWARD_WINDOW_COLOR, linestyle='--', linewidth=1,
               label=f'Reward window start (~{rw_start_ms / 1000:.1f} s)')
    ax.set_xlabel('Week')
    ax.set_ylabel('First-lick latency (ms, rel. stimulus onset)')
    ax.set_title('Reaction time distribution (Go-Hit trials) by week')
    ax.legend(title='Mouse')
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.show()

In [None]:
# ROC space: each session's operating point (false-alarm rate, hit rate) per mouse, plus
# weekly means as stars. Operating points should move toward the upper left with learning.
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Chance')
# One colour per mouse, shared by its session dots and its weekly-mean stars.
roc_mice = list(dict.fromkeys([*session_df['mouse'].unique(), *week_means['mouse'].unique()]))
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
mouse_colors = {m: cycle_colors[i % len(cycle_colors)] for i, m in enumerate(roc_mice)}
for mouse in session_df['mouse'].unique():
    sub = session_df[session_df['mouse'] == mouse]
    ax.scatter(sub['False alarm rate'], sub['Hit rate'], color=mouse_colors[mouse],
               label=f'Mouse {mouse} (sessions)', alpha=0.7, s=50)
for mouse in week_means['mouse'].unique():
    wm = week_means[week_means['mouse'] == mouse]
    ax.scatter(wm['False alarm rate'], wm['Hit rate'], marker='*', s=200, color=mouse_colors[mouse],
               edgecolors='black', linewidths=1, label=f'Mouse {mouse} (weekly mean)', zorder=5)
ax.set_xlabel('False alarm rate')
ax.set_ylabel('Hit rate')
ax.set_title('ROC space: operating points (sessions and weekly means)')
ax.legend(loc='lower right')
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
ax.set_aspect('equal')
plt.tight_layout()
plt.show()

In [None]:
# Unified extractor UI (auto-loaded)
from pathlib import Path
from behavior_data_extractor import show_extraction_widget

# Start in ./Jeremy when run from the repo root; otherwise use the current folder
# (e.g. when already running inside the Jeremy folder).
if Path('Jeremy').exists():
    default_folder = 'Jeremy'
else:
    default_folder = '.'
show_extraction_widget(default_folder)
