# Model-Based Model-Free (MBMF) Regression Analysis

This notebook analyzes participants' use of model-free and model-based learning strategies in a two-step decision task. It regresses first-stage stay/switch behavior on the previous trial's outcome (reward and transition type) and on participant age. The intended analysis is a mixed-effects logistic regression; the model fitted in Section 4 is currently a pooled (population-level) logistic regression.

## Research Questions:
1. **Model-Free Learning**: Do participants repeat choices more after rewarded vs unrewarded trials?
2. **Model-Based Learning**: Do participants show reward × transition interactions (repeating after rewarded-common or unrewarded-rare more than rewarded-rare or unrewarded-common)?
3. **Age Effects**: How do these learning strategies change with age?

## Expected Effects:
- **Main effect of reward**: Model-free learning (reward → repeat)
- **Reward × transition interaction**: Model-based learning 
- **Reward × transition × age interaction**: Age-related increase in model-based learning
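These expected stay patterns can be summarized with two difference scores computed from the four reward × transition stay probabilities. The sketch below uses one common descriptive convention; scaling differs across papers, and this is not the regression approach used later in the notebook. The function name and the dictionary layout are illustrative assumptions.

```python
def mbmf_indices(p_stay):
    """Hypothetical helper: MF and MB difference scores from stay probabilities.

    p_stay maps (previous_reward, previous_transition) -> P(stay), with
    previous_reward in {0, 1} and previous_transition in {"common", "rare"}.
    """
    rc, rr = p_stay[(1, "common")], p_stay[(1, "rare")]
    uc, ur = p_stay[(0, "common")], p_stay[(0, "rare")]
    # MF signature: more staying after reward, regardless of transition type
    mf_index = (rc + rr) - (uc + ur)
    # MB signature: the reward effect is larger after common than after rare transitions
    mb_index = (rc - rr) - (uc - ur)
    return mf_index, mb_index


# Idealized patterns for a purely model-free and a purely model-based agent
model_free = {(1, "common"): 0.8, (1, "rare"): 0.8, (0, "common"): 0.4, (0, "rare"): 0.4}
model_based = {(1, "common"): 0.8, (1, "rare"): 0.4, (0, "common"): 0.4, (0, "rare"): 0.8}
print(mbmf_indices(model_free))   # MF index about 0.8, MB index about 0
print(mbmf_indices(model_based))  # MF index about 0, MB index about 0.8
```

A purely model-free agent loads only on the first score and a purely model-based agent only on the second, which is the same separation the reward main effect and the reward × transition interaction are meant to capture.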

## 1. Import Required Libraries and Setup

In [1]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import zscore
import warnings
warnings.filterwarnings('ignore')

# Statistical modeling libraries
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

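# For mixed-effects models (if available).
# Note: mixedlm fits linear (Gaussian) mixed models, which do not suit a binary stay/switch
# outcome; BinomialBayesMixedGLM is statsmodels' mixed-effects model for binary outcomes.
try:
    from statsmodels.genmod.bayes_mixed_glm import BinomialBayesMixedGLM
    print("Mixed-effects models available via statsmodels")
except ImportError:
    print("Mixed-effects models not available - will use alternative approaches")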

print("Libraries imported successfully!")

Mixed-effects models available via statsmodels
Libraries imported successfully!


In [2]:
# =============================================================================
# APA-7 VISUAL CONFIGURATION (consistent with descriptives notebook)
# =============================================================================

# APA-7 Font Configuration
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12

# APA-7 Figure Settings
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['savefig.facecolor'] = 'white'
plt.rcParams['savefig.bbox'] = 'tight'
plt.rcParams['savefig.dpi'] = 300

# Remove top and right spines (APA-7 style)
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.left'] = True
plt.rcParams['axes.spines.bottom'] = True

# APA-7 Font Sizes
APA_TITLE_SIZE = 12
APA_LABEL_SIZE = 12
APA_TICK_SIZE = 10
APA_LEGEND_SIZE = 10

# APA-7 Colors (consistent with descriptives)
APA_BLUE = '#2E86AB'      # Professional blue
APA_ORANGE = '#FF8C00'    # Professional orange  
APA_GRAY = '#4A4A4A'      # Dark gray for text
APA_LIGHT_GRAY = '#E5E5E5'  # Light gray for backgrounds

# Colors for MBMF analysis (common/rare transitions)
COMMON_COLOR = APA_BLUE     # Blue for common transitions
RARE_COLOR = APA_ORANGE     # Orange for rare transitions

# Transparency level
APA_ALPHA = 0.7

# Set seaborn style
sns.set_style("white")
sns.set_context("paper", font_scale=1.0)

def apply_apa_style(ax, title=None, xlabel=None, ylabel=None):
    """Apply consistent APA-7 styling to matplotlib axes"""
    if title:
        ax.set_title(title, fontsize=APA_TITLE_SIZE, fontweight='normal', pad=20)
    if xlabel:
        ax.set_xlabel(xlabel, fontsize=APA_LABEL_SIZE)
    if ylabel:
        ax.set_ylabel(ylabel, fontsize=APA_LABEL_SIZE)
    
    ax.tick_params(labelsize=APA_TICK_SIZE)
    sns.despine(ax=ax)
    plt.tight_layout()
    return ax

print("APA-7 visual configuration loaded successfully!")
print(f"Colors: Common transitions = {COMMON_COLOR}, Rare transitions = {RARE_COLOR}")

APA-7 visual configuration loaded successfully!
Colors: Common transitions = #2E86AB, Rare transitions = #FF8C00


## 2. Data Preparation and Variable Creation

In [None]:
# Load the dataset
df = pd.read_csv('final_dataset.csv')

print("Dataset loaded successfully!")
print(f"Shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print(f"Subjects: {df['subject_id'].nunique()}")
print(f"Trials per subject: {df.groupby('subject_id').size().describe()}")

# Display first few rows to understand structure
print("\nFirst 5 rows:")
df.head()

Dataset loaded successfully!
Shape: (30200, 12)
Columns: ['subject_id', 'practice_trial', 'trial', 'transition', 'reward', 'choice_1', 'choice_2', 'rt_1', 'rt_2', 'state', 'gender', 'age']
Subjects: 151
Trials per subject: count    151.0
mean     200.0
std        0.0
min      200.0
25%      200.0
50%      200.0
75%      200.0
max      200.0
dtype: float64

First 5 rows:


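| | subject_id | practice_trial | trial | transition | reward | choice_1 | choice_2 | rt_1 | rt_2 | state | gender | age |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | sub1 | real | 1 | common | 1 | 1 | 1 | 1131.48 | 1457.62 | 2 | Female | 17.055556 |
| 1 | sub1 | real | 2 | common | 0 | 1 | 1 | 639.0 | 483.59 | 2 | Female | 17.055556 |
| 2 | sub1 | real | 3 | rare | 0 | 2 | 1 | 264.19 | 718.195 | 2 | Female | 17.055556 |
| 3 | sub1 | real | 4 | common | 1 | 2 | 2 | 302.265 | 1507.225 | 3 | Female | 17.055556 |
| 4 | sub1 | real | 5 | rare | 1 | 2 | 2 | 321.685 | 1751.195 | 2 | Female | 17.055556 |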
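The lagged variables created next assume that, within each subject, trial numbers are unique and consecutive, so that sorting by `subject_id` and `trial` places every trial directly after its predecessor. A quick check along these lines may be worth running first. The helper name is hypothetical; if `practice_trial` also contains practice rows that reuse trial numbers (only `real` rows appear in the preview), they would need to be filtered out before the check.

```python
import pandas as pd


def find_trial_sequence_problems(trials, subject_col="subject_id", trial_col="trial"):
    """Hypothetical check: subjects whose trial numbers are duplicated or skip values."""
    problems = []
    for subject, trial_numbers in trials.groupby(subject_col)[trial_col]:
        steps = trial_numbers.sort_values().diff().dropna()
        # A step of 0 means a duplicated trial number; a step above 1 means a gap
        if (steps != 1).any():
            problems.append(subject)
    return problems


# Toy example: s2 repeats trial 1, s3 skips trial 2
toy = pd.DataFrame({
    "subject_id": ["s1", "s1", "s1", "s2", "s2", "s2", "s3", "s3"],
    "trial":      [1, 2, 3, 1, 1, 2, 1, 3],
})
print(find_trial_sequence_problems(toy))  # ['s2', 's3']
```

On the loaded data, `find_trial_sequence_problems(df)` returning an empty list would support the sorting assumption.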


In [8]:
# =============================================================================
# CREATE VARIABLES FOR MBMF ANALYSIS
# =============================================================================

def create_mbmf_variables(df):
    """
    Create variables needed for model-based model-free analysis:
    - previous_reward: Reward outcome from previous trial
    - previous_transition: Transition type from previous trial (common/rare)
    - stay: Whether participant repeated their first-stage choice (1 = stay, 0 = switch)
    - age_group: Children (8-12), Adolescents (13-17), Adults (18-25), otherwise Other
    - age_z: Z-scored age for regression analysis
    Trials without a valid previous trial are dropped.
    """
    
    # Sort by subject and trial to ensure proper lagging
    df_sorted = df.sort_values(['subject_id', 'trial']).copy()
    
    # Create lagged variables (previous trial outcomes)
    df_sorted['previous_reward'] = df_sorted.groupby('subject_id')['reward'].shift(1)
    df_sorted['previous_transition'] = df_sorted.groupby('subject_id')['transition'].shift(1)
    df_sorted['previous_choice1'] = df_sorted.groupby('subject_id')['choice_1'].shift(1)
    
    # Create stay variable (1 if repeated first-stage choice, 0 if switched).
    # Leave it missing when either choice is missing, so a missed response is not counted as a switch.
    both_choices = df_sorted['choice_1'].notna() & df_sorted['previous_choice1'].notna()
    df_sorted['stay'] = (df_sorted['choice_1'] == df_sorted['previous_choice1']).astype(float).where(both_choices)
    
    # Add age group categorization
    def categorize_age(age):
        age_floor = int(age)
        if 8 <= age_floor <= 12:
            return 'Children'
        elif 13 <= age_floor <= 17:
            return 'Adolescents'
        elif 18 <= age_floor <= 25:
            return 'Adults'
        else:
            return 'Other'
    
    df_sorted['age_group'] = df_sorted['age'].apply(categorize_age)
    
    # Z-score age for regression analysis
    df_sorted['age_z'] = zscore(df_sorted['age'])
    
    # Convert categorical variables to appropriate formats
    df_sorted['previous_reward'] = df_sorted['previous_reward'].astype('Int64')  # Allow NaN
    df_sorted['previous_transition'] = df_sorted['previous_transition'].astype('category')
    df_sorted['age_group'] = df_sorted['age_group'].astype('category')
    
    # Remove trials without a valid previous trial (first trial per subject, or a previous
    # trial with missing reward/transition) and trials with a missing stay value
    df_analysis = df_sorted.dropna(subset=['previous_reward', 'previous_transition', 'stay']).copy()
    df_analysis['stay'] = df_analysis['stay'].astype(int)
    
    return df_analysis

# Create analysis dataset
df_mbmf = create_mbmf_variables(df)

print("MBMF variables created successfully!")
print(f"Analysis dataset shape: {df_mbmf.shape}")
# Includes each subject's first trial plus trials dropped for missing data
print(f"Trials excluded (first trial per subject or missing data): {len(df) - len(df_mbmf)}")

# Check variable distributions
print("\nVariable distributions:")
print(f"Stay behavior: {df_mbmf['stay'].value_counts()}")
print(f"Previous reward: {df_mbmf['previous_reward'].value_counts()}")
print(f"Previous transition: {df_mbmf['previous_transition'].value_counts()}")
print(f"Age groups: {df_mbmf['age_group'].value_counts()}")

# Check for missing values
print("\nMissing values:")
print(df_mbmf[['stay', 'previous_reward', 'previous_transition', 'age_z']].isnull().sum())

MBMF variables created successfully!
Analysis dataset shape: (29799, 18)
Trials excluded (first trial per subject or missing data): 401

Variable distributions:
Stay behavior: stay
1    21664
0     8135
Name: count, dtype: int64
Previous reward: previous_reward
1    15924
0    13875
Name: count, dtype: Int64
Previous transition: previous_transition
common    20915
rare       8884
Name: count, dtype: int64
Age groups: age_group
Adults         10053
Adolescents     9901
Children        9845
Name: count, dtype: int64

Missing values:
stay                   0
previous_reward        0
previous_transition    0
age_z                  0
dtype: int64
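The lagging logic can be checked on a toy dataset. The sketch below (hypothetical helper, made-up data) follows the same shift-within-subject approach and additionally leaves `stay` missing when a first-stage choice is missing, which is an assumption about how missed responses should be handled. It also shows why grouping by subject matters: s2's first choice equals s1's last choice, so an ungrouped `shift(1)` would wrongly record a stay on s2's first trial.

```python
import numpy as np
import pandas as pd


def add_lagged_stay(trials):
    """Hypothetical helper: lag outcomes within subject and code stay (1) vs switch (0)."""
    out = trials.sort_values(["subject_id", "trial"]).copy()
    by_subject = out.groupby("subject_id")
    # shift(1) within each subject: every subject's first trial gets NaN
    out["previous_reward"] = by_subject["reward"].shift(1)
    out["previous_transition"] = by_subject["transition"].shift(1)
    out["previous_choice1"] = by_subject["choice_1"].shift(1)
    # Stay is undefined (NaN) when either first-stage choice is missing
    both_choices = out["choice_1"].notna() & out["previous_choice1"].notna()
    out["stay"] = (out["choice_1"] == out["previous_choice1"]).astype(float).where(both_choices)
    return out.dropna(subset=["previous_reward", "previous_transition", "stay"])


toy = pd.DataFrame({
    "subject_id": ["s1", "s1", "s1", "s2", "s2", "s2"],
    "trial":      [1, 2, 3, 1, 2, 3],
    "reward":     [1, 0, 1, 0, 1, 0],
    "transition": ["common", "rare", "common", "common", "rare", "common"],
    "choice_1":   [1, 1, 2, 2, 2, np.nan],  # s2 missed the first-stage choice on trial 3
})
result = add_lagged_stay(toy)
print(result[["subject_id", "trial", "stay"]])
```

Only three rows survive: each subject's first trial has no previous trial, and s2's trial 3 has no first-stage choice.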


## 3. Exploratory Data Analysis

In [None]:
# =============================================================================
# BASIC STAY BEHAVIOR ANALYSIS
# =============================================================================

# Overall stay probability
overall_stay = df_mbmf['stay'].mean()
print(f"Overall stay probability: {overall_stay:.3f}")

# Stay probability by previous reward
stay_by_reward = df_mbmf.groupby('previous_reward')['stay'].agg(['mean', 'std', 'count'])
print(f"\nStay probability by previous reward:")
print(stay_by_reward)

# Stay probability by previous transition
stay_by_transition = df_mbmf.groupby('previous_transition')['stay'].agg(['mean', 'std', 'count'])
print(f"\nStay probability by previous transition:")
print(stay_by_transition)

# Stay probability by age group
stay_by_age = df_mbmf.groupby('age_group')['stay'].agg(['mean', 'std', 'count'])
print(f"\nStay probability by age group:")
print(stay_by_age)

# Cross-tabulation: reward × transition
stay_cross = df_mbmf.groupby(['previous_reward', 'previous_transition'])['stay'].agg(['mean', 'std', 'count'])
print(f"\nStay probability by reward × transition:")
print(stay_cross)

In [None]:
# =============================================================================
# EXPLORATORY VISUALIZATIONS
# =============================================================================

# Create figure with subplots for exploratory analysis
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 1. Stay probability by reward and transition
stay_summary = df_mbmf.groupby(['previous_reward', 'previous_transition'])['stay'].mean().reset_index()
stay_pivot = stay_summary.pivot(index='previous_reward', columns='previous_transition', values='stay')

sns.heatmap(stay_pivot, annot=True, cmap='Blues', ax=axes[0,0], 
            cbar_kws={'label': 'Stay Probability'})
axes[0,0].set_title('Stay Probability by Previous Reward × Transition')
axes[0,0].set_xlabel('Previous Transition')
axes[0,0].set_ylabel('Previous Reward')

# 2. Stay probability by age (continuous)
age_stay = df_mbmf.groupby('age')['stay'].mean().reset_index()
axes[0,1].scatter(age_stay['age'], age_stay['stay'], alpha=0.6, color=APA_BLUE)
axes[0,1].plot(np.unique(age_stay['age']), 
               np.poly1d(np.polyfit(age_stay['age'], age_stay['stay'], 1))(np.unique(age_stay['age'])),
               color=APA_ORANGE, linewidth=2)
axes[0,1].set_title('Stay Probability by Age')
axes[0,1].set_xlabel('Age (years)')
axes[0,1].set_ylabel('Stay Probability')

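# 3. Stay probability by age group and reward
stay_age_reward = df_mbmf.groupby(['age_group', 'previous_reward'], observed=True)['stay'].mean().reset_index()
# Map reward to text labels so seaborn treats it as a categorical hue
stay_age_reward['reward_label'] = stay_age_reward['previous_reward'].map({0: 'Unrewarded', 1: 'Rewarded'})
sns.barplot(data=stay_age_reward, x='age_group', y='stay', hue='reward_label',
            order=['Children', 'Adolescents', 'Adults'], hue_order=['Unrewarded', 'Rewarded'],
            palette=[APA_GRAY, APA_ORANGE], ax=axes[1,0])
axes[1,0].set_title('Stay Probability by Age Group × Reward')
axes[1,0].set_xlabel('Age Group')
axes[1,0].set_ylabel('Stay Probability')
axes[1,0].legend(title='Previous Reward')

# 4. Distribution of subject-level stay proportions by age group
# (a boxplot of the raw 0/1 stay variable carries no distributional information)
subject_stay = df_mbmf.groupby(['subject_id', 'age_group'], observed=True)['stay'].mean().reset_index()
sns.boxplot(data=subject_stay, x='age_group', y='stay', order=['Children', 'Adolescents', 'Adults'],
            color=APA_BLUE, ax=axes[1,1])
axes[1,1].set_title('Subject-Level Stay Proportion by Age Group')
axes[1,1].set_xlabel('Age Group')
axes[1,1].set_ylabel('Proportion of Stays per Subject')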

# Apply APA styling to all subplots
for ax in axes.flat:
    ax.tick_params(labelsize=APA_TICK_SIZE)
    sns.despine(ax=ax)

plt.tight_layout()
plt.show()

print("Exploratory analysis complete!")

## 4. Main Stay Logistic Regression Analysis

This is the core analysis examining how previous reward, previous transition, and age influence stay behavior. The key interaction of interest is **reward × transition × age**, which indicates age-related changes in model-based learning.
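Written out, the stay regression is a logistic model of the form below. This assumes effect coding of the lagged predictors: $r_{t-1} = +1$ after a rewarded and $-1$ after an unrewarded trial, $c_{t-1} = +1$ after a common and $-1$ after a rare transition, and $a$ is z-scored age. Under 0/1 dummy coding, the lower-order coefficients would instead be simple effects at the reference level (for example, the reward effect after common transitions only).

$$
\operatorname{logit} P(\text{stay}_t = 1) = \beta_0 + \beta_1 r_{t-1} + \beta_2 c_{t-1} + \beta_3 a + \beta_4 r_{t-1} c_{t-1} + \beta_5 r_{t-1} a + \beta_6 c_{t-1} a + \beta_7 r_{t-1} c_{t-1} a
$$

Holding $c_{t-1}$ and $a$ fixed, the log-odds difference between rewarded and unrewarded previous trials is

$$
2\left(\beta_1 + \beta_4 c_{t-1} + \beta_5 a + \beta_7 c_{t-1} a\right),
$$

so $\beta_1 > 0$ is the model-free signature (reward increases staying regardless of transition), $\beta_4 > 0$ is the model-based signature (the reward effect is larger after common than after rare transitions), and $\beta_7 > 0$ means the model-based signature grows with age.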

In [None]:
# =============================================================================
# STAY LOGISTIC REGRESSION MODEL (POOLED; SUBJECT-CLUSTERED STANDARD ERRORS)
# =============================================================================

# Prepare data for regression (encode categorical variables)
df_regression = df_mbmf.copy()

# Effect-code the lagged predictors (+1/-1). Under this coding the reward coefficient is a
# main effect (model-free signature) and a positive reward × transition coefficient is the
# model-based signature. With 0/1 dummy coding, the "reward" coefficient would instead be
# the reward effect after common transitions only.
df_regression['reward_c'] = np.where(df_regression['previous_reward'].astype(int) == 1, 1, -1)
df_regression['transition_c'] = np.where(df_regression['previous_transition'] == 'common', 1, -1)

print("Data preparation for regression:")
print("Previous reward coding: rewarded=+1, unrewarded=-1")
print("Previous transition coding: common=+1, rare=-1")
print(f"Age standardized (z-scored): mean={df_regression['age_z'].mean():.3f}, std={df_regression['age_z'].std():.3f}")

# Note: this is a standard (pooled) logistic regression, not a mixed-effects model.
# Each subject contributes ~200 trials, so standard errors are clustered by subject;
# a mixed-effects (random-effects) model remains the fuller alternative.
print("\n" + "="*80)
print("STAY BEHAVIOR LOGISTIC REGRESSION ANALYSIS")
print("="*80)

# Create interaction terms manually for clearer interpretation
df_regression['reward_x_transition'] = df_regression['reward_c'] * df_regression['transition_c']
df_regression['reward_x_age'] = df_regression['reward_c'] * df_regression['age_z']
df_regression['transition_x_age'] = df_regression['transition_c'] * df_regression['age_z']
df_regression['reward_x_transition_x_age'] = df_regression['reward_x_transition'] * df_regression['age_z']

# Predictor columns and display labels (population-level effects)
predictors = {
    'reward_c': 'Previous Reward (MF)',
    'transition_c': 'Previous Transition',
    'age_z': 'Age',
    'reward_x_transition': 'Reward × Transition (MB)',
    'reward_x_age': 'Reward × Age',
    'transition_x_age': 'Transition × Age',
    'reward_x_transition_x_age': 'Reward × Transition × Age',
}
X = sm.add_constant(df_regression[list(predictors)].astype(float))
y = df_regression['stay'].astype(int)

# Fit logistic regression; cluster standard errors by subject (integer group codes)
subject_groups = pd.factorize(df_regression['subject_id'])[0]
logit_model = sm.Logit(y, X).fit(cov_type='cluster', cov_kwds={'groups': subject_groups})

print("LOGISTIC REGRESSION RESULTS:")
print("="*50)
print(logit_model.summary())

# Extract key statistics
print("\n" + "="*50)
print("KEY EFFECTS:")
print("="*50)

labels = {'const': 'Intercept', **predictors}
conf_int = logit_model.conf_int()

for term, name in labels.items():
    # Index by term name rather than by position
    coef = logit_model.params[term]
    pval = logit_model.pvalues[term]
    ci_low, ci_high = conf_int.loc[term]
    
    significance = ""
    if pval < 0.001:
        significance = "***"
    elif pval < 0.01:
        significance = "**"
    elif pval < 0.05:
        significance = "*"
    elif pval < 0.1:
        significance = "."
    
    print(f"{name:25}: β = {coef:6.3f}, p = {pval:6.3f}{significance:3}, 95% CI [{ci_low:6.3f}, {ci_high:6.3f}]")

print("\nSignificance codes: *** p<0.001, ** p<0.01, * p<0.05, . p<0.1")
print("\nKey interpretations (effect coding):")
print("- Previous Reward (MF): Model-free learning effect (positive = more staying after reward)")
print("- Reward × Transition (MB): Model-based learning effect (positive = model-based pattern)")
print("- Reward × Transition × Age: Age-related change in model-based learning (positive = increases with age)")
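Log-odds coefficients, interactions in particular, are easier to read as predicted stay probabilities for the four reward × transition cells. The sketch below assumes an effect-coded model (reward +1/-1, transition common +1 / rare -1, z-scored age) and uses hypothetical term names; coefficients from `logit_model.params` would need to be renamed to match. The coefficients in the usage line are made up for illustration, not results from this dataset. Relatedly, `exp(β)` gives the odds ratio for a one-unit change, which under ±1 coding is half the contrast between levels, so the odds ratio between the two levels of a main effect is `exp(2β)`.

```python
import math

TERMS = ["Intercept", "reward", "transition", "age",
         "reward:transition", "reward:age", "transition:age", "reward:transition:age"]


def predicted_stay(coefs, age_z=0.0):
    """Hypothetical helper: predicted P(stay) for each previous reward x transition cell.

    coefs maps the assumed term names in TERMS to log-odds coefficients from an
    effect-coded model; terms that are not supplied count as 0.
    """
    b = {term: coefs.get(term, 0.0) for term in TERMS}
    cells = {}
    for reward_label, r in [("Rewarded", 1), ("Unrewarded", -1)]:
        for transition_label, c in [("Common", 1), ("Rare", -1)]:
            log_odds = (b["Intercept"] + b["reward"] * r + b["transition"] * c
                        + b["age"] * age_z + b["reward:transition"] * r * c
                        + b["reward:age"] * r * age_z + b["transition:age"] * c * age_z
                        + b["reward:transition:age"] * r * c * age_z)
            # Inverse logit converts log-odds to a probability
            cells[(reward_label, transition_label)] = 1.0 / (1.0 + math.exp(-log_odds))
    return cells


# Made-up coefficients describing a purely model-based pattern
example = predicted_stay({"Intercept": 0.5, "reward:transition": 1.0})
for cell, prob in example.items():
    print(cell, round(prob, 3))
```

Evaluating the helper at, say, `age_z=-1` and `age_z=1` shows how the age interactions would shift the predicted pattern for younger and older participants.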

## 5. Stay Probability Analysis by Age Group

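Next, we plot mean stay probabilities for each previous reward × transition condition within each age group. Stay proportions are first computed per subject (keeping only conditions with at least 3 trials) and then averaged across subjects, with error bars showing the SEM across subjects.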

In [None]:
# =============================================================================
# STAY PROBABILITY ANALYSIS BY AGE GROUP
# =============================================================================

# Calculate stay statistics by subject first, then aggregate
subject_stay_stats = df_mbmf.groupby(['subject_id', 'age_group', 'previous_reward', 'previous_transition']).agg({
    'stay': ['mean', 'count']
}).reset_index()

# Flatten column names
subject_stay_stats.columns = ['subject_id', 'age_group', 'previous_reward', 'previous_transition', 'mean_stay', 'n_trials']
subject_stay_stats = subject_stay_stats[subject_stay_stats['n_trials'] >= 3]  # Only include conditions with at least 3 trials

# Calculate group-level statistics
group_stay_stats = subject_stay_stats.groupby(['age_group', 'previous_reward', 'previous_transition']).agg({
    'mean_stay': ['mean', 'std', 'sem', 'count']
}).reset_index()

# Flatten column names
group_stay_stats.columns = ['age_group', 'previous_reward', 'previous_transition', 'stay_prop', 'stay_sd', 'stay_sem', 'n_subjects']

print("Stay probability by age group, reward, and transition:")
print(group_stay_stats.round(3))

# Create reward and transition labels for plotting
group_stay_stats['reward_label'] = group_stay_stats['previous_reward'].map({0: 'Unrewarded', 1: 'Rewarded'})
group_stay_stats['transition_label'] = group_stay_stats['previous_transition'].map({'common': 'Common', 'rare': 'Rare'})

# Create the main stay probability plot
fig, ax = plt.subplots(figsize=(10, 6))

# Set up position for grouped bars
age_groups = ['Children', 'Adolescents', 'Adults']
x_pos = np.arange(len(age_groups))
width = 0.2

# Define positions for each condition
pos_unrewarded_common = x_pos - 1.5*width
pos_unrewarded_rare = x_pos - 0.5*width  
pos_rewarded_common = x_pos + 0.5*width
pos_rewarded_rare = x_pos + 1.5*width

# Create bars for each condition
for i, age_group in enumerate(age_groups):
    group_data = group_stay_stats[group_stay_stats['age_group'] == age_group]
    
    # Unrewarded-Common
    uc_data = group_data[(group_data['previous_reward'] == 0) & (group_data['previous_transition'] == 'common')]
    if not uc_data.empty:
        ax.bar(pos_unrewarded_common[i], uc_data['stay_prop'].iloc[0], width, 
               yerr=uc_data['stay_sem'].iloc[0], color='lightblue', edgecolor='black',
               label='Unrewarded-Common' if i == 0 else "")
    
    # Unrewarded-Rare  
    ur_data = group_data[(group_data['previous_reward'] == 0) & (group_data['previous_transition'] == 'rare')]
    if not ur_data.empty:
        ax.bar(pos_unrewarded_rare[i], ur_data['stay_prop'].iloc[0], width,
               yerr=ur_data['stay_sem'].iloc[0], color='lightcoral', edgecolor='black',
               label='Unrewarded-Rare' if i == 0 else "")
    
    # Rewarded-Common
    rc_data = group_data[(group_data['previous_reward'] == 1) & (group_data['previous_transition'] == 'common')]
    if not rc_data.empty:
        ax.bar(pos_rewarded_common[i], rc_data['stay_prop'].iloc[0], width,
               yerr=rc_data['stay_sem'].iloc[0], color=COMMON_COLOR, edgecolor='black',
               label='Rewarded-Common' if i == 0 else "")
    
    # Rewarded-Rare
    rr_data = group_data[(group_data['previous_reward'] == 1) & (group_data['previous_transition'] == 'rare')]
    if not rr_data.empty:
        ax.bar(pos_rewarded_rare[i], rr_data['stay_prop'].iloc[0], width,
               yerr=rr_data['stay_sem'].iloc[0], color=RARE_COLOR, edgecolor='black',
               label='Rewarded-Rare' if i == 0 else "")

# Customize the plot
ax.set_xlabel('Age Group')
ax.set_ylabel('Proportion of First-Stage Stays')
ax.set_title('Stay Probability by Age Group, Reward, and Transition')
ax.set_xticks(x_pos)
ax.set_xticklabels(age_groups)
ax.set_ylim(0.4, 1.0)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Apply APA styling
apply_apa_style(ax)
plt.tight_layout()
plt.show()

print("Stay probability plot created successfully!")
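The cell above averages within each subject before averaging across subjects, so every participant carries equal weight and the SEM reflects between-subject variability. Every subject here has 200 trials, but the number of trials per reward × transition condition still varies across subjects. The toy sketch below (hypothetical helper, made-up data) shows why the order of aggregation matters: a pooled trial-level mean is pulled toward participants who contribute more trials.

```python
import pandas as pd


def pooled_and_subject_means(trials, min_trials=1):
    """Hypothetical helper: trial-pooled stay rate vs mean of per-subject stay rates.

    Subjects with fewer than min_trials trials are dropped from the subject-level
    summary, analogous to the minimum-trials filter used above.
    """
    pooled = trials["stay"].mean()
    per_subject = trials.groupby("subject_id")["stay"].agg(["mean", "count"])
    per_subject = per_subject[per_subject["count"] >= min_trials]
    return pooled, per_subject["mean"].mean(), per_subject["mean"].sem()


# Toy data: s1 contributes 8 trials (all stays), s2 only 2 trials (both switches)
toy = pd.DataFrame({"subject_id": ["s1"] * 8 + ["s2"] * 2,
                    "stay": [1] * 8 + [0] * 2})
pooled, subject_mean, subject_sem = pooled_and_subject_means(toy)
print(f"pooled = {pooled:.2f}, subject-level mean = {subject_mean:.2f} (SEM {subject_sem:.2f})")
```

The pooled rate reflects mostly s1, while the subject-level mean weights both participants equally.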