# Two-Armed Bandit Task: Descriptive Analyses

## 1. Setup & Data Loading

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import glob

# Paths: code_dir is the current working directory; data lives one level up
code_dir = Path('.')
data_dir = Path('../data')

# Optional subject subset: sublist.txt in code_dir lists one subject ID per line.
# Blank lines are ignored and surrounding whitespace is stripped. Without the
# file, subject_subset stays None and every subject is analysed.
sublist_file = code_dir / 'sublist.txt'
if not sublist_file.exists():
    subject_subset = None
    print('Running all subjects in data directory')
else:
    with sublist_file.open('r') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        subject_subset = [entry for entry in stripped_lines if entry]
    print(f'Running subset of {len(subject_subset)} subjects from sublist.txt')

In [None]:
# Load all bandit CSV files found (recursively) under data_dir.
# glob.glob returns paths in arbitrary, filesystem-dependent order; sort them so
# the concatenated trial order is reproducible across machines and runs.
all_files = sorted(glob.glob(str(data_dir / '**/*_task-bandit_*.csv'), recursive=True))

# sublist.txt entries are strings, but read_csv parses purely numeric IDs
# (e.g. 101) as integers, which never compare equal to '101'. Compare as strings.
# NOTE(review): IDs with leading zeros (e.g. '007') are still inferred as 7 by
# read_csv and will not match; read subject_id with dtype=str if that applies.
subset_ids = None if subject_subset is None else set(subject_subset)

dfs = []
for f in all_files:
    df_tmp = pd.read_csv(f)
    if df_tmp.empty:
        # Header-only file: there is no first row to read subject_id from
        print(f'Skipping file with no trials: {f}')
        continue
    # Filter by subset if specified; each file's subject is taken from its first row
    if subset_ids is None or str(df_tmp['subject_id'].iloc[0]) in subset_ids:
        dfs.append(df_tmp)

if not dfs:
    # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate"
    raise ValueError(
        f'No task-bandit trials loaded from {len(all_files)} file(s) found under {data_dir}'
        + ('' if subject_subset is None else ' after applying sublist.txt')
    )

df = pd.concat(dfs, ignore_index=True)
print(f'Loaded {len(df)} trials from {len(dfs)} files')

## 2. Participant Summary

In [None]:
# Design summary: counts are averaged over subjects (sessions), over
# subject-sessions (runs) and over subject-session-runs (trials).
per_subject = df.groupby('subject_id')
per_session = df.groupby(['subject_id', 'session'])
per_run = df.groupby(['subject_id', 'session', 'run'])

n_subjects = df['subject_id'].nunique()
n_sessions = per_subject['session'].nunique().mean()
n_runs = per_session['run'].nunique().mean()
trials_per_run = per_run.size().mean()

# A missed trial is any row without a recorded choice
missing_choice = df['choice'].isna()
missed_trials = missing_choice.sum()
missed_pct = 100 * missed_trials / len(df)

summary_lines = [
    f'Subjects: {n_subjects}',
    f'Sessions per subject: {n_sessions:.1f}',
    f'Runs per session: {n_runs:.1f}',
    f'Trials per run: {trials_per_run:.1f}',
    f'Missed trials: {missed_trials} ({missed_pct:.1f}%)',
]
for summary_line in summary_lines:
    print(summary_line)

## 3. Behavioral Performance

In [None]:
# Performance metrics use only trials on which a choice was recorded
# (rows with a missing `choice` are the missed trials).
made_choice = df['choice'].notna()
df_valid = df.loc[made_choice].copy()

# Means pooled over all valid trials (not averaged per subject first):
#   correct -> proportion of choices of the high-probability option
#   reward  -> reported as the win rate (assumes 0/1 coding — confirm)
overall_acc, overall_win = (df_valid[column].mean() for column in ('correct', 'reward'))

print(f'Overall accuracy (chose high-prob option): {overall_acc:.1%}')
print(f'Overall win rate: {overall_win:.1%}')

In [None]:
# Accuracy and win rate split by stimulation condition (pooled over valid trials)
by_condition = df_valid.groupby('stim_condition')
acc_by_cond = by_condition['correct'].mean()
win_by_cond = by_condition['reward'].mean()

print('Accuracy by condition:')
for cond, cond_acc in acc_by_cond.items():
    print(f'  {cond}: {cond_acc:.1%}')

print('\nWin rate by condition:')
for cond, cond_win in win_by_cond.items():
    print(f'  {cond}: {cond_win:.1%}')

## 4. Response Times

In [None]:
# Response-time summary over all valid trials.
# NOTE(review): labels assume `rt` is stored in milliseconds — confirm against the task logs.
valid_rt = df_valid['rt']
rt_summary = {
    'Mean': valid_rt.mean(),
    'Median': valid_rt.median(),
    'SD': valid_rt.std(),
}
for stat_name, stat_value in rt_summary.items():
    print(f'{stat_name} RT: {stat_value:.0f} ms')

In [None]:
# Per-condition RT statistics; columns are 'mean', 'median', 'std'
rt_by_cond = df_valid.groupby('stim_condition')['rt'].agg(
    mean='mean',
    median='median',
    std='std',
)

# Rounded to whole units for display only; rt_by_cond itself keeps full precision
rt_table = rt_by_cond.round(0).to_string()
print('RT by condition (ms):')
print(rt_table)

## 5. Learning Dynamics

In [None]:
# Accuracy by trial position within a contingency block. Values are pooled
# over all valid trials (subjects/sessions/blocks together), not averaged
# per subject first.
n_learning_trials = 20  # number of leading trial positions kept for display

acc_by_trial_in_cont = df_valid.groupby('trial_in_contingency')['correct'].mean()

# Keep the first `n_learning_trials` positions. Taking the first N sorted keys
# works whether trial_in_contingency counts from 0 or from 1; a fixed
# `index < 20` cutoff would keep only 19 positions for a 1-based counter.
acc_by_trial_in_cont = acc_by_trial_in_cont.sort_index().head(n_learning_trials)

print(f'Accuracy by trial within contingency block (first {n_learning_trials} trials):')
print(acc_by_trial_in_cont.to_string())

## 6. Visualizations

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Two bar panels per stimulation condition:
# (axis, data, colour, y-label, title, proportion scale?)
bar_panels = [
    (axes[0], acc_by_cond, 'steelblue', 'Accuracy',
     'Accuracy by Stimulation Condition', True),
    (axes[1], rt_by_cond['mean'], 'coral', 'RT (ms)',
     'Mean RT by Stimulation Condition', False),
]
for ax, panel_data, bar_colour, y_label, panel_title, is_proportion in bar_panels:
    panel_data.plot(kind='bar', ax=ax, color=bar_colour, edgecolor='black')
    ax.set_ylabel(y_label)
    ax.set_xlabel('Condition')
    ax.set_title(panel_title)
    if is_proportion:
        # 0-1 scale with a dashed reference at 0.5 (chance for two options)
        ax.set_ylim(0, 1)
        ax.axhline(0.5, color='gray', linestyle='--', linewidth=1)
    ax.tick_params(axis='x', rotation=0)

# Third panel: pooled learning curve across trial positions in a contingency block.
# NOTE(review): the 'Post-Reversal' title assumes every contingency block follows
# a reversal; if the initial (pre-reversal) block is counted too, revisit the title.
ax = axes[2]
curve_x = acc_by_trial_in_cont.index
curve_y = acc_by_trial_in_cont.values
ax.plot(curve_x, curve_y,
        marker='o', color='seagreen', linewidth=2, markersize=5)
ax.set_xlabel('Trial in Contingency Block')
ax.set_ylabel('Accuracy')
ax.set_title('Learning Curve (Post-Reversal)')
ax.set_ylim(0, 1)
ax.axhline(0.5, color='gray', linestyle='--', linewidth=1)

plt.tight_layout()
plt.show()