# RDEX-ABCD Model Tutorial

This notebook demonstrates the **RDEX-ABCD model** for analyzing stop-signal task data from the ABCD study.

**Reference**: Weigard, A., Matzke, D., Tanis, C., & Heathcote, A. (2023). A cognitive process modeling framework for the ABCD study stop-signal task. *Developmental Cognitive Neuroscience, 59*, 101191.

## Overview

1. Generate synthetic ABCD stop-signal data with context independence violations
2. Fit the RDEX-ABCD model
3. Verify parameter recovery
4. Demonstrate that the model captures key patterns from Weigard et al. (2023):
   - Choice accuracy increases with SSD on stop trials
   - SSD-dependent drift rates account for stimulus replacement
   - Separation of processing speed, perceptual growth, and inhibition

## Setup

In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import arviz as az

# Import the RDEX-ABCD model
import sys
sys.path.insert(0, '..')  # Add parent directory to path
from pydmc import RDEXABCDModel

# Plotting settings
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
%matplotlib inline

# Random seed for reproducibility
np.random.seed(42)

## 1. The ABCD Task and Context Independence Violation

### Standard Stop-Signal Task
In a standard stop-signal task, the go stimulus remains visible when the stop signal appears. This supports the assumption of "context independence": the go process is identical on go and stop trials.

### ABCD Task Design
The ABCD study's stop-signal task has a critical difference: **the stop signal replaces the go stimulus**. This means:
- At short SSDs, participants have limited time to process the choice stimulus
- At SSD = 0, no choice information is available (chance accuracy expected)
- As SSD increases, more processing time is available before the stimulus disappears

### RDEX-ABCD Solution (Weigard et al., 2023)
The model accounts for this by implementing **SSD-dependent drift rates**:
- **v0**: Processing speed (base rate without discrimination)
- **g**: Perceptual growth rate (how discrimination improves with SSD)
- **v+, v-**: Asymptotic drift rates (reached at long SSDs)

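On stop trials, both rates start at `v0` and diverge linearly with SSD until they reach their asymptotes:
- Matching: `v_match(SSD) = v0 + min(g * SSD, v+ - v0)`
- Mismatching: `v_mismatch(SSD) = v0 - min(g * SSD, v0 - v-)` (assuming `v- < v0`)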
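As a worked illustration of the growth rule, the sketch below evaluates both rates at a few SSDs. The function name `ssd_drift_rates` and the parameter values are illustrative assumptions (the values sit close to the means used for the synthetic data later in this notebook), and the mismatching rate is assumed to fall from `v0` toward `v-`.

```python
import numpy as np

def ssd_drift_rates(ssd, v0, g, v_plus, v_minus):
    """Return (v_match, v_mismatch) for a scalar or array of SSDs in seconds.

    Hypothetical helper: assumes v_minus <= v0 <= v_plus.
    """
    ssd = np.asarray(ssd, dtype=float)
    growth = g * ssd
    # Matching rate rises from v0 and is capped at v_plus
    v_match = v0 + np.minimum(growth, v_plus - v0)
    # Mismatching rate falls from v0 and is floored at v_minus
    v_mismatch = v0 - np.minimum(growth, v0 - v_minus)
    return v_match, v_mismatch

ssds = np.array([0.0, 0.25, 1.0])
v_match, v_mismatch = ssd_drift_rates(ssds, v0=2.2, g=2.4, v_plus=3.5, v_minus=1.0)
for s, vm, vmm in zip(ssds, v_match, v_mismatch):
    print(f"SSD={s:.2f}s  v_match={vm:.2f}  v_mismatch={vmm:.2f}")
```

With these values the two rates coincide at SSD = 0, differ by 1.2 at SSD = 0.25 s (2.8 vs 1.6), and have both reached their asymptotes by SSD = 1 s.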

## 2. Generate Synthetic Data

Generate synthetic ABCD stop-signal data with context independence violations to test parameter recovery.

In [None]:
def simulate_abcd_stop_signal_data(n_subjects=10, n_go_per_subject=300, n_stop_per_subject=60, seed=42):
    """
    Simulate ABCD stop-signal data with context independence violations.
    
    Implements the RDEX-ABCD data-generating process with SSD-dependent drift rates.
    """
    np.random.seed(seed)
    
    all_data = []
    true_params = []
    
    MAX_RT = 1.5
    
    for subj_idx in range(n_subjects):
        # Individual-level parameters
        params = {
            'subject_id': f'sub-{subj_idx:03d}',
            'B': np.random.lognormal(-0.2, 0.15) + 0.5,
            't0': np.random.beta(3, 17) + 0.08,
            'v_plus': np.random.normal(3.5, 0.3),
            'v_minus': np.random.normal(1.0, 0.3),
            'v0': np.random.normal(2.2, 0.25),
            'g': np.random.gamma(4, 0.6),
            'ssrt_mu': np.random.normal(0.20, 0.03),
            'ssrt_sigma': np.random.gamma(2, 0.02),
            'ssrt_tau': np.random.gamma(3, 0.015),
            'ptf': np.random.beta(2, 20),
            'pgf': np.random.beta(2, 30),
        }
        true_params.append(params)
        
        # Simulate GO trials
        for trial_idx in range(n_go_per_subject):
            stimulus = np.random.randint(0, 2)
            
            if np.random.rand() < params['pgf']:
                all_data.append({
                    'subject': params['subject_id'],
                    'trial_type': 'go',
                    'stimulus': stimulus,
                    'response': 0,
                    'rt': np.nan,
                    'ssd': np.nan
                })
                continue
            
            max_attempts = 100
            for attempt in range(max_attempts):
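                # Wald first-passage time: mean B/v, shape B**2.
                # scipy's invgauss expects mu = mean/shape = 1/(B*v) with scale = shape.
                t_match = stats.invgauss.rvs(
                    mu=1.0 / (params['B'] * params['v_plus']),
                    scale=params['B']**2
                )
                t_mismatch = stats.invgauss.rvs(
                    mu=1.0 / (params['B'] * max(params['v_minus'], 0.3)),
                    scale=params['B']**2
                )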
                
                if t_match < t_mismatch:
                    response = stimulus + 1
                    rt = params['t0'] + t_match
                else:
                    response = 2 - stimulus
                    rt = params['t0'] + t_mismatch
                
                if rt <= MAX_RT:
                    break
            else:
                rt = params['t0'] + np.random.uniform(0.2, 0.5)
                response = stimulus + 1
            
            all_data.append({
                'subject': params['subject_id'],
                'trial_type': 'go',
                'stimulus': stimulus,
                'response': response,
                'rt': rt,
                'ssd': np.nan
            })
        
        # Simulate STOP trials
        for _ in range(n_stop_per_subject):
            stimulus = np.random.randint(0, 2)
            ssd = np.random.uniform(0.05, 0.45)
            trigger_failed = np.random.rand() < params['ptf']
            
            # SSD-dependent drift rates
            discrimination_growth = min(params['g'] * ssd, params['v_plus'] - params['v0'])
            v_plus_ssd = params['v0'] + discrimination_growth
            
            v_minus_discrimination = min(params['g'] * ssd, abs(params['v_minus'] - params['v0']))
            if params['v_minus'] < params['v0']:
                v_minus_ssd = params['v0'] - v_minus_discrimination
            else:
                v_minus_ssd = params['v0'] + v_minus_discrimination
            
            if np.random.rand() < params['pgf']:
                all_data.append({
                    'subject': params['subject_id'],
                    'trial_type': 'successful_stop',
                    'stimulus': stimulus,
                    'response': 0,
                    'rt': np.nan,
                    'ssd': ssd
                })
                continue
            
            max_attempts = 50
            for attempt in range(max_attempts):
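                # Wald first-passage time: mean B/v, shape B**2.
                # scipy's invgauss expects mu = mean/shape = 1/(B*v) with scale = shape.
                t_match = stats.invgauss.rvs(
                    mu=1.0 / (params['B'] * max(v_plus_ssd, 0.3)),
                    scale=params['B']**2
                )
                t_mismatch = stats.invgauss.rvs(
                    mu=1.0 / (params['B'] * max(v_minus_ssd, 0.3)),
                    scale=params['B']**2
                )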
                
                if t_match < t_mismatch:
                    go_response = stimulus + 1
                    go_rt = params['t0'] + t_match
                else:
                    go_response = 2 - stimulus
                    go_rt = params['t0'] + t_mismatch
                
                if go_rt <= MAX_RT:
                    break
            else:
                go_rt = params['t0'] + np.random.uniform(0.2, 0.5)
                go_response = stimulus + 1
            
            if not trigger_failed:
                normal_part = np.random.normal(params['ssrt_mu'], params['ssrt_sigma'])
                exp_part = np.random.exponential(params['ssrt_tau'])
                ssrt_sample = max(normal_part + exp_part, 0.05)
                stop_rt = ssd + ssrt_sample
            else:
                stop_rt = np.inf
            
            if go_rt < stop_rt:
                all_data.append({
                    'subject': params['subject_id'],
                    'trial_type': 'signal_respond',
                    'stimulus': stimulus,
                    'response': go_response,
                    'rt': go_rt,
                    'ssd': ssd
                })
            else:
                all_data.append({
                    'subject': params['subject_id'],
                    'trial_type': 'successful_stop',
                    'stimulus': stimulus,
                    'response': 0,
                    'rt': np.nan,
                    'ssd': ssd
                })
    
    df = pd.DataFrame(all_data)
    params_df = pd.DataFrame(true_params)
    
    return df, params_df

# Generate data
data, true_params = simulate_abcd_stop_signal_data(
    n_subjects=10,
    n_go_per_subject=300,
    n_stop_per_subject=60,
    seed=42
)

print(f"Generated {len(data)} trials from {data['subject'].nunique()} subjects")
print(f"  Go trials: {(data['trial_type'] == 'go').sum()}")
print(f"  Signal-respond trials: {(data['trial_type'] == 'signal_respond').sum()}")
print(f"  Successful stops: {(data['trial_type'] == 'successful_stop').sum()}")

data.head(10)
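The racing accumulators in the simulation are Wald (inverse Gaussian) first-passage times. For a unit-variance diffusion with threshold `B` and drift `v`, the finishing time has mean `B / v` and shape `B**2`. The sketch below checks these moments by sampling with NumPy's `Generator.wald`; the helper name `sample_wald_fpt` and the values of `B` and `v` are illustrative assumptions.

```python
import numpy as np

def sample_wald_fpt(B, v, size, rng):
    """Sample first-passage times of a unit-variance diffusion with drift v to threshold B.

    The distribution is inverse Gaussian with mean B / v and shape B**2.
    Equivalent scipy call: stats.invgauss.rvs(mu=1 / (B * v), scale=B**2, size=size).
    """
    return rng.wald(mean=B / v, scale=B**2, size=size)

rng = np.random.default_rng(0)
B, v = 1.2, 3.0  # illustrative values, not estimates
samples = sample_wald_fpt(B, v, size=200_000, rng=rng)

# Inverse Gaussian moments: mean = B / v, variance = mean**3 / shape = B / v**3
print(f"theoretical mean: {B / v:.3f}")
print(f"sample mean:      {samples.mean():.3f}")
print(f"theoretical var:  {B / v**3:.4f}")
print(f"sample var:       {samples.var():.4f}")
```

A check like this is a quick way to confirm that a sampler's parameterization matches the intended threshold and drift, because libraries differ in how they define the inverse Gaussian's parameters.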

## 3. Verify Context Independence Violation

Confirm that the synthetic data exhibits the key pattern from Weigard et al. (2023): choice accuracy on signal-respond trials increases with SSD.

In [None]:
# Compute choice accuracy for signal-respond trials by SSD
signal_respond = data[data['trial_type'] == 'signal_respond'].copy()

signal_respond['correct'] = (
    ((signal_respond['stimulus'] == 0) & (signal_respond['response'] == 1)) |
    ((signal_respond['stimulus'] == 1) & (signal_respond['response'] == 2))
).astype(int)

signal_respond['ssd_bin'] = pd.cut(signal_respond['ssd'], bins=5)

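# observed=False keeps empty bins and avoids the pandas FutureWarning for categorical groupers
accuracy_by_ssd = signal_respond.groupby('ssd_bin', observed=False)['correct'].agg(['mean', 'sem', 'count']).reset_index()
accuracy_by_ssd['ssd_midpoint'] = accuracy_by_ssd['ssd_bin'].apply(lambda x: x.mid).astype(float)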

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Choice accuracy by SSD
ax = axes[0]
ax.errorbar(accuracy_by_ssd['ssd_midpoint'], accuracy_by_ssd['mean'],
            yerr=accuracy_by_ssd['sem'], marker='o', markersize=8,
            capsize=5, capthick=2, linewidth=2)
ax.axhline(0.5, color='gray', linestyle='--', label='Chance')
ax.set_xlabel('Stop-Signal Delay (s)', fontsize=12)
ax.set_ylabel('Choice Accuracy', fontsize=12)
ax.set_title('Context Independence Violation\n(Accuracy increases with SSD)', fontsize=13, fontweight='bold')
ax.set_ylim([0.4, 1.0])
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Inhibition function
ax = axes[1]
stop_trials = data[data['trial_type'].isin(['signal_respond', 'successful_stop'])].copy()
stop_trials['responded'] = (stop_trials['response'] > 0).astype(int)
stop_trials['ssd_bin'] = pd.cut(stop_trials['ssd'], bins=5)

inhib_func = stop_trials.groupby('ssd_bin', observed=False)['responded'].agg(['mean', 'sem']).reset_index()
inhib_func['ssd_midpoint'] = inhib_func['ssd_bin'].apply(lambda x: x.mid).astype(float)

ax.errorbar(inhib_func['ssd_midpoint'], inhib_func['mean'],
            yerr=inhib_func['sem'], marker='s', markersize=8,
            capsize=5, capthick=2, linewidth=2, color='coral')
ax.set_xlabel('Stop-Signal Delay (s)', fontsize=12)
ax.set_ylabel('P(Respond | Stop Signal)', fontsize=12)
ax.set_title('Inhibition Function', fontsize=13, fontweight='bold')
ax.set_ylim([0, 1.0])
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Accuracy range: {accuracy_by_ssd['mean'].iloc[0]:.2f} (short SSD) to {accuracy_by_ssd['mean'].iloc[-1]:.2f} (long SSD)")
print(f"P(respond) range: {inhib_func['mean'].iloc[0]:.2f} to {inhib_func['mean'].iloc[-1]:.2f}")
print("\nLeft panel demonstrates the context independence violation that RDEX-ABCD models.")
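The binned plot shows the trend visually. One possible way to quantify it, sketched here under assumptions, is a least-squares slope of trial-level correctness on SSD with a bootstrap interval. The helper `accuracy_slope` and the toy data are hypothetical, and a linear probability slope is a simplification rather than the method used by Weigard et al. (2023).

```python
import numpy as np
import pandas as pd

def accuracy_slope(df, n_boot=1000, seed=0):
    """Least-squares slope of correct (0/1) on ssd, with a bootstrap 95% interval."""
    rng = np.random.default_rng(seed)
    ssd = df['ssd'].to_numpy(dtype=float)
    correct = df['correct'].to_numpy(dtype=float)
    slope = np.polyfit(ssd, correct, deg=1)[0]
    boot = np.empty(n_boot)
    for i in range(n_boot):
        # Resample trials with replacement and refit the slope
        idx = rng.integers(0, len(ssd), size=len(ssd))
        boot[i] = np.polyfit(ssd[idx], correct[idx], deg=1)[0]
    lo, hi = np.percentile(boot, [2.5, 97.5])
    return slope, lo, hi

# Toy data (assumption): P(correct) = 0.5 + 0.8 * SSD
rng = np.random.default_rng(1)
toy_ssd = rng.uniform(0.05, 0.45, size=1000)
toy = pd.DataFrame({
    'ssd': toy_ssd,
    'correct': (rng.random(1000) < 0.5 + 0.8 * toy_ssd).astype(int),
})
slope, lo, hi = accuracy_slope(toy)
print(f"slope = {slope:.2f} per second, 95% bootstrap interval [{lo:.2f}, {hi:.2f}]")
```

With the notebook's data, `signal_respond[['ssd', 'correct']]` could be passed to the same function; an interval that excludes zero would support the claim that accuracy rises with SSD.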

## 4. Fit the RDEX-ABCD Model

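Fit the hierarchical RDEX-ABCD model to the synthetic data. This cell uses a deliberately small number of samples (100 draws, 100 tuning steps, 2 chains) so that it runs quickly. With only 200 posterior draws in total, the convergence thresholds in the next section cannot all be met. For publication analyses, use at least `draws=1000`, `tune=1000`, and `chains=4`.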

In [None]:
# Create and fit model
model = RDEXABCDModel(use_hierarchical=True)

trace = model.fit(
    data,
    draws=100,        # Use 1000+ for real analysis
    tune=100,         # Use 1000+ for real analysis  
    chains=2,         # Use 4 for real analysis
    target_accept=0.9,
    return_inferencedata=True
)

print("Model fitting complete")

## 5. Convergence Diagnostics

In [None]:
# Summary statistics
summary = az.summary(
    trace,
    var_names=['mu_B', 'mu_t0', 'mu_v_plus', 'mu_v_minus', 'mu_v0', 'mu_g',
               'mu_stop_mu', 'mu_pgf_probit', 'mu_ptf_probit']
)

print("GROUP-LEVEL PARAMETER SUMMARY")
print("="*70)
print(summary)
print("\nFor convergence: Rhat < 1.01, ESS > 400")
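Reading the summary table by eye gets tedious with many parameters. The sketch below filters an ArviZ-style summary for rows that miss the thresholds above. It relies on the `r_hat`, `ess_bulk`, and `ess_tail` columns that `az.summary` produces; the helper name `flag_unconverged` and the numbers in the toy table are made up for illustration.

```python
import pandas as pd

def flag_unconverged(summary, rhat_max=1.01, ess_min=400):
    """Return rows of an ArviZ-style summary table that miss either threshold."""
    bad = (
        (summary['r_hat'] > rhat_max)
        | (summary['ess_bulk'] < ess_min)
        | (summary['ess_tail'] < ess_min)
    )
    return summary.loc[bad, ['r_hat', 'ess_bulk', 'ess_tail']]

# Toy table with made-up numbers, standing in for az.summary(trace)
toy_summary = pd.DataFrame(
    {
        'r_hat': [1.00, 1.03, 1.00],
        'ess_bulk': [950.0, 120.0, 800.0],
        'ess_tail': [900.0, 150.0, 310.0],
    },
    index=['mu_v0', 'mu_g', 'mu_B'],
)
print(flag_unconverged(toy_summary))
```

In the notebook, `flag_unconverged(summary)` would list the group-level parameters that need more draws or tuning.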

In [None]:
# Trace plots for key parameters
var_names = ['mu_v_plus', 'mu_v_minus', 'mu_v0', 'mu_g']
az.plot_trace(
    trace,
    var_names=var_names,
    compact=True,
    figsize=(12, 8)
)
plt.suptitle('MCMC Traces: Context Independence Parameters', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 6. Parameter Recovery

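Compare the group-level posterior means with the means of the subject-level parameters used to generate the data. Because the six parameters sit on different scales, the correlation reported below mostly reflects those scale differences; the per-parameter error table is the more informative check.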

In [None]:
# Extract posterior means
posterior = trace.posterior

estimates = {
    'B': posterior['mu_B'].mean().item(),
    't0': posterior['mu_t0'].mean().item(),
    'v_plus': posterior['mu_v_plus'].mean().item(),
    'v_minus': posterior['mu_v_minus'].mean().item(),
    'v0': posterior['mu_v0'].mean().item(),
    'g': posterior['mu_g'].mean().item(),
}

true_values = {
    'B': true_params['B'].mean(),
    't0': true_params['t0'].mean(),
    'v_plus': true_params['v_plus'].mean(),
    'v_minus': true_params['v_minus'].mean(),
    'v0': true_params['v0'].mean(),
    'g': true_params['g'].mean(),
}

# Create recovery plot
fig, ax = plt.subplots(figsize=(8, 8))

params = list(estimates.keys())
est_vals = [estimates[p] for p in params]
true_vals = [true_values[p] for p in params]

ax.scatter(true_vals, est_vals, s=100, alpha=0.7)
for i, param in enumerate(params):
    ax.annotate(param, (true_vals[i], est_vals[i]), 
                xytext=(5, 5), textcoords='offset points', fontsize=11)

lims = [min(true_vals + est_vals) * 0.9, max(true_vals + est_vals) * 1.1]
ax.plot(lims, lims, 'k--', alpha=0.5, linewidth=2, label='Perfect recovery')

r = np.corrcoef(true_vals, est_vals)[0, 1]

ax.set_xlabel('True Parameter Value', fontsize=12)
ax.set_ylabel('Estimated Parameter Value', fontsize=12)
ax.set_title(f'Parameter Recovery (r = {r:.3f})', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Print comparison table
print("PARAMETER RECOVERY")
print("="*70)
print(f"{'Parameter':<15} {'True Value':<15} {'Estimated':<15} {'Error %':<15}")
print("-"*70)
for param in params:
    true = true_values[param]
    est = estimates[param]
    error = abs(est - true) / true * 100
    print(f"{param:<15} {true:<15.3f} {est:<15.3f} {error:<15.1f}")
print("="*70)
print(f"\nCorrelation: {r:.3f}")
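Percent error on a posterior mean ignores posterior uncertainty. A complementary check, sketched here as an assumption rather than part of the `pydmc` API, is whether each true value falls inside an equal-tailed posterior interval. The helper `interval_covers` is hypothetical and the draws are simulated stand-ins for posterior samples.

```python
import numpy as np

def interval_covers(samples, true_value, prob=0.95):
    """Return (lower, upper, covered) for an equal-tailed posterior interval."""
    tail = (1.0 - prob) / 2.0
    lower, upper = np.quantile(samples, [tail, 1.0 - tail])
    return lower, upper, bool(lower <= true_value <= upper)

# Stand-in posterior draws (assumption): normal around 2.3 with sd 0.1
rng = np.random.default_rng(0)
fake_draws = rng.normal(2.3, 0.1, size=4000)
for true_value in (2.2, 2.8):
    lower, upper, covered = interval_covers(fake_draws, true_value)
    print(f"true={true_value:.1f}  interval=[{lower:.2f}, {upper:.2f}]  covered={covered}")
```

In the notebook this could be called as `interval_covers(posterior['mu_v0'].values.ravel(), true_values['v0'])`, and likewise for the other parameters.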

## 7. Parameter Interpretation

Extract and interpret the main parameters according to Weigard et al. (2023).

In [None]:
# Convert probit-scale failure probabilities
pgf_samples = stats.norm.cdf(posterior['mu_pgf_probit'].values.flatten())
ptf_samples = stats.norm.cdf(posterior['mu_ptf_probit'].values.flatten())

ssrt_samples = (posterior['mu_stop_mu'] + posterior.get('mu_stop_tau', 0)).values.flatten()

print("KEY PARAMETER ESTIMATES (Group Level)")
print("="*70)

print("\nGO PROCESS:")
print(f"  Threshold (B):                 {estimates['B']:.3f}")
print(f"  Non-decision time (t0):        {estimates['t0']:.3f} s")
print(f"  Matching drift rate (v+):      {estimates['v_plus']:.3f}")
print(f"  Mismatching drift rate (v-):   {estimates['v_minus']:.3f}")

print("\nCONTEXT INDEPENDENCE PARAMETERS (Weigard et al., 2023):")
print(f"  Processing speed (v0):         {estimates['v0']:.3f}")
print(f"  Perceptual growth rate (g):    {estimates['g']:.3f}")

print("\nSTOP PROCESS:")
print(f"  Stop-Signal RT (SSRT):         {np.mean(ssrt_samples):.3f} s")

print("\nFAILURE PROCESSES:")
print(f"  Go failure (pgf):              {np.mean(pgf_samples):.3f} ({np.mean(pgf_samples)*100:.1f}%)")
print(f"  Trigger failure (ptf):         {np.mean(ptf_samples):.3f} ({np.mean(ptf_samples)*100:.1f}%)")
print("="*70)
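The SSRT summary above adds `mu` and `tau` because the mean of an ex-Gaussian (a normal plus an independent exponential) is `mu + tau`, with variance `sigma**2 + tau**2`. The sketch below checks this by simulation; the helper name and the parameter values are illustrative assumptions, and the small floor applied to SSRT samples in the data-generating cell is ignored here.

```python
import numpy as np

def exgauss_moments(mu, sigma, tau):
    """Mean and standard deviation of an ex-Gaussian (normal plus independent exponential)."""
    return mu + tau, np.sqrt(sigma**2 + tau**2)

rng = np.random.default_rng(0)
mu, sigma, tau = 0.20, 0.04, 0.045  # illustrative values in seconds
draws = rng.normal(mu, sigma, size=200_000) + rng.exponential(tau, size=200_000)

mean_theory, sd_theory = exgauss_moments(mu, sigma, tau)
print(f"mean: theory {mean_theory:.3f}, simulated {draws.mean():.3f}")
print(f"sd:   theory {sd_theory:.3f}, simulated {draws.std():.3f}")
```

If the fitted model does not expose a `mu_stop_tau` variable, the fallback of 0 in the cell above means the reported value is the Gaussian component `mu` alone, not the full mean SSRT.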

## 8. Posterior Distributions

In [None]:
# Plot posteriors for key parameters
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

plot_params = [
    ('mu_v_plus', 'v_plus', 'Matching Rate (v+)'),
    ('mu_v_minus', 'v_minus', 'Mismatching Rate (v-)'),
    ('mu_v0', 'v0', 'Processing Speed (v0)'),
    ('mu_g', 'g', 'Perceptual Growth (g)'),
    ('mu_stop_mu', 'ssrt_mu', 'Stop Process μ'),
    ('mu_B', 'B', 'Threshold (B)'),
]

for idx, (post_name, true_name, title) in enumerate(plot_params):
    ax = axes[idx]
    
    samples = posterior[post_name].values.flatten()
    ax.hist(samples, bins=30, density=True, alpha=0.7, edgecolor='black')
    
    true_val = true_params[true_name].mean()
    ax.axvline(true_val, color='red', linewidth=3, linestyle='--', label=f'True: {true_val:.2f}')
    
    post_mean = samples.mean()
    ax.axvline(post_mean, color='blue', linewidth=2, linestyle='-', label=f'Est: {post_mean:.2f}')
    
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. SSD-Dependent Drift Rates (Weigard et al., 2023)

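Demonstrate the key innovation of the RDEX-ABCD model: on stop trials, the matching and mismatching drift rates both start at `v0` and diverge linearly with SSD until they reach their asymptotes.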

In [None]:
# Get posterior mean parameters
v_plus_est = estimates['v_plus']
v_minus_est = estimates['v_minus']
v0_est = estimates['v0']
g_est = estimates['g']

ssd_range = np.linspace(0, 0.5, 100)

def compute_rates(ssd, v_plus, v_minus, v0, g):
    match_disc = np.minimum(g * ssd, v_plus - v0)
    v_match = v0 + match_disc
    
    mismatch_disc = np.minimum(g * ssd, np.abs(v_minus - v0))
    if v_minus < v0:
        v_mismatch = v0 - mismatch_disc
    else:
        v_mismatch = v0 + mismatch_disc
    
    return v_match, v_mismatch

v_match_curve, v_mismatch_curve = compute_rates(ssd_range, v_plus_est, v_minus_est, v0_est, g_est)

# Plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Left: Drift rates by SSD
ax = axes[0]
ax.plot(ssd_range, v_match_curve, linewidth=3, label='Matching accumulator (v+)', color='blue')
ax.plot(ssd_range, v_mismatch_curve, linewidth=3, label='Mismatching accumulator (v-)', color='red')
ax.axhline(v0_est, linestyle='--', color='gray', linewidth=2, label=f'Processing speed (v0 = {v0_est:.2f})')
ax.axhline(v_plus_est, linestyle=':', color='blue', linewidth=2, alpha=0.5, label=f'Asymptote v+ = {v_plus_est:.2f}')
ax.axhline(v_minus_est, linestyle=':', color='red', linewidth=2, alpha=0.5, label=f'Asymptote v- = {v_minus_est:.2f}')
ax.set_xlabel('Stop-Signal Delay (s)', fontsize=12)
ax.set_ylabel('Drift Rate', fontsize=12)
ax.set_title(f'SSD-Dependent Drift Rates (g = {g_est:.2f})', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# Middle: Processing speed component
ax = axes[1]
ax.axhline(v0_est, linewidth=3, color='purple', label='Processing speed (v0)')
ax.fill_between([0, 0.5], [v0_est, v0_est], alpha=0.3, color='purple')
ax.set_xlabel('Stop-Signal Delay (s)', fontsize=12)
ax.set_ylabel('Rate', fontsize=12)
ax.set_title('Processing Speed Component', fontsize=13, fontweight='bold')
ax.set_ylim([0, max(v_plus_est, v0_est) * 1.2])
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Discrimination component
ax = axes[2]
discrimination_match = np.minimum(g_est * ssd_range, v_plus_est - v0_est)
discrimination_mismatch = np.minimum(g_est * ssd_range, np.abs(v_minus_est - v0_est))
ax.plot(ssd_range, discrimination_match, linewidth=3, color='blue', label='Match discrimination')
ax.plot(ssd_range, discrimination_mismatch, linewidth=3, color='red', label='Mismatch discrimination')
ax.set_xlabel('Stop-Signal Delay (s)', fontsize=12)
ax.set_ylabel('Discrimination', fontsize=12)
ax.set_title('Discrimination Component', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("SSD-DEPENDENT DRIFT RATE MODEL (Weigard et al., 2023)")
print("="*70)
print(f"At SSD=0: Both rates = v0 = {v0_est:.2f} (no discrimination)")
print(f"As SSD increases: Discrimination grows at rate g = {g_est:.2f}")
print(f"At long SSD: Rates asymptote to v+ = {v_plus_est:.2f}, v- = {v_minus_est:.2f}")
print("="*70)
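To connect the drift-rate curves to choice accuracy, the sketch below estimates by Monte Carlo how often the matching accumulator finishes first at a given SSD. It assumes unit-variance Wald accumulators with a shared threshold, uses illustrative parameter values, and ignores the stop process and non-decision time, so it approximates rather than reproduces accuracy on signal-respond trials. The function `predicted_accuracy` is hypothetical.

```python
import numpy as np

def predicted_accuracy(ssd, B, v0, g, v_plus, v_minus, n=20_000, seed=0):
    """Monte Carlo estimate of P(matching accumulator finishes first) at one SSD.

    Assumes v_minus <= v0 <= v_plus and v_minus > 0.
    """
    rng = np.random.default_rng(seed)
    v_match = v0 + min(g * ssd, v_plus - v0)
    v_mismatch = v0 - min(g * ssd, v0 - v_minus)
    # Wald finishing times: mean B / v, shape B**2
    t_match = rng.wald(B / v_match, B**2, size=n)
    t_mismatch = rng.wald(B / v_mismatch, B**2, size=n)
    return float(np.mean(t_match < t_mismatch))

for ssd in (0.0, 0.1, 0.2, 0.4):
    acc = predicted_accuracy(ssd, B=1.3, v0=2.2, g=2.4, v_plus=3.5, v_minus=1.0)
    print(f"SSD={ssd:.1f}s  predicted accuracy={acc:.3f}")
```

At SSD = 0 the two accumulators share the same rate, so the estimate should sit near 0.5; it rises as the rates diverge.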

## 10. Comparison to Standard Models

Explanation of why RDEX-ABCD is necessary for ABCD data (Weigard et al., 2023).

In [None]:
print("WHY THE RDEX-ABCD MODEL IS NECESSARY (Weigard et al., 2023)")
print("="*70)

print("\nStandard stop-signal models:")
print("  - Assume constant drift rates across SSDs")
print("  - Cannot explain increasing choice accuracy with SSD")
print("  - Can produce biased SSRT estimates when context independence is violated")
print("  - Confound processing speed with inhibition")

print("\nRDEX-ABCD model:")
print("  - Models drift rate growth with SSD via v0 and g parameters")
print("  - Captures the increase in choice accuracy with SSD")
print("  - Aims to reduce bias in SSRT estimates")
print("  - Separates processing speed, perceptual efficiency, and inhibition")
print("  - Distinguishes attention lapses (ptf) from inhibition deficits")

print("\nKey findings from this analysis:")
print(f"  - Processing speed (v0={v0_est:.2f}) drives responding at SSD=0, where choice accuracy is at chance")
print(f"  - Perceptual growth (g={g_est:.2f}) determines discrimination improvement rate")
print(f"  - SSRT ({np.mean(ssrt_samples):.3f}s) reflects inhibition after accounting for v0, g, ptf")
print(f"  - Trigger failures ({np.mean(ptf_samples)*100:.1f}%) represent attention lapses")
print("="*70)

## Summary

This tutorial demonstrated:

1. **Package functionality**: The `pydmc` implementation of RDEX-ABCD fits a hierarchical model to stop-signal data
2. **Parameter recovery**: Group-level posterior means are compared against the known generating values from synthetic data; judge recovery from a full-length run, not the shortened one used here
3. **Alignment with Weigard et al. (2023)**:
   - Captures context independence violations (choice accuracy increases with SSD)
   - Implements SSD-dependent drift rates via v0 and g parameters
   - Separates processing speed, perceptual growth, and inhibition
   - Accounts for trigger failures (attention lapses)

### For real ABCD data analyses:

```python
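from pydmc import RDEXABCDModel
import arviz as az
import pandas as pd
from scipy import stats

# Load data with the same columns as the synthetic data above:
# subject, trial_type, stimulus, response, rt, ssd
data = pd.read_csv('abcd_stop_signal.csv')

# Fit hierarchical model
model = RDEXABCDModel(use_hierarchical=True)
trace = model.fit(data, draws=1000, tune=1000, chains=4, return_inferencedata=True)

# Check convergence (Rhat < 1.01, ESS > 400)
summary = az.summary(trace)

# Extract group-level parameters (variable names as used earlier in this tutorial)
posterior = trace.posterior
ssrt_mu = posterior['mu_stop_mu']
ptf = stats.norm.cdf(posterior['mu_ptf_probit'].values)  # probit scale -> probability
v0 = posterior['mu_v0']
g = posterior['mu_g']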
```
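Before fitting real data, it can help to check that the input frame has the expected shape. The sketch below is a hypothetical pre-flight check, not part of `pydmc`: the required column set is taken from the comment in the example above, and the coding of a withheld response as `0` follows the synthetic data in this tutorial. Adjust both to your own files.

```python
import pandas as pd

REQUIRED_COLUMNS = ('subject', 'stimulus', 'response', 'rt', 'ssd')  # assumed column set

def check_stop_signal_frame(df, required=REQUIRED_COLUMNS):
    """Raise ValueError if columns are missing or if responded trials lack a valid RT."""
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    responded = df['response'] > 0
    if df.loc[responded, 'rt'].isna().any():
        raise ValueError("some trials have a response but no rt")
    if (df.loc[responded, 'rt'] <= 0).any():
        raise ValueError("rt must be positive on responded trials")
    return df

toy = pd.DataFrame({
    'subject': ['sub-000', 'sub-000'],
    'stimulus': [0, 1],
    'response': [1, 0],           # 0 codes a withheld response, as in the synthetic data
    'rt': [0.52, float('nan')],
    'ssd': [float('nan'), 0.25],  # NaN on go trials
})
check_stop_signal_frame(toy)
print("toy frame passed the checks")
```

Calling `check_stop_signal_frame(data)` right after `pd.read_csv` surfaces column or coding problems before a long sampling run starts.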

### Reference

Weigard, A., Matzke, D., Tanis, C., & Heathcote, A. (2023). A cognitive process modeling framework for the ABCD study stop-signal task. *Developmental Cognitive Neuroscience, 59*, 101191.