# Basic pydmc Example: WALD Stop-Signal Model

This notebook demonstrates the basic usage of the pydmc package for fitting a WALD stop-signal model to response inhibition data.

## Overview

The WALD stop-signal model is a diffusion-based model of stop-signal task performance in which finishing times follow a Wald (inverse Gaussian) distribution. It estimates:
- Response thresholds (B)
- Non-decision time (t0)
- Drift rates for correct/incorrect responses (vT, vF)
- Go and trigger failure probabilities (gf, tf)
- Other parameters related to the stop process
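
To make the roles of the threshold, drift, and non-decision time concrete, here is a minimal sketch that simulates finishing times from a single Wald accumulator. It assumes a unit diffusion coefficient, in which case the first-passage time to threshold `B` with drift `v` is inverse Gaussian with mean `B / v` and shape `B**2`, and `t0` shifts the result. The helper `simulate_wald_rt` and the parameter values are illustrative and not part of pydmc, whose parameterization may differ.

```python
import numpy as np

def simulate_wald_rt(B, v, t0, size, rng):
    """Draw RTs from a shifted Wald (inverse Gaussian) distribution.

    Assumes a single accumulator with unit diffusion coefficient.
    """
    # First-passage time to threshold B with drift v: IG(mean=B/v, shape=B**2)
    decision_time = rng.wald(mean=B / v, scale=B ** 2, size=size)
    return t0 + decision_time

rng = np.random.default_rng(0)
rts = simulate_wald_rt(B=1.0, v=2.5, t0=0.2, size=20_000, rng=rng)
print(f"Simulated mean RT: {rts.mean():.3f} s (theoretical: {0.2 + 1.0 / 2.5:.3f} s)")
```

Raising `B` or lowering `v` slows responses and widens the distribution, while `t0` only shifts it.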

## 1. Setup and Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Import pydmc
from pydmc import WaldStopSignalModel

# Set random seed for reproducibility
np.random.seed(42)

# Set plot style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 2. Generate Synthetic Data

For this example, we'll generate synthetic stop-signal task data. In a real analysis, you would load your own data.

### Expected Data Format

Your data should be a pandas DataFrame with the following columns:
- `subject`: Subject identifier
- `stimulus`: Stimulus type (0=left, 1=right)
- `response`: Response made (0=no response, 1=left, 2=right)
- `rt`: Reaction time in seconds (NaN when no response was made)
- `ssd`: Stop-signal delay (NaN for go trials, value for stop trials)
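
Before fitting, it can help to confirm that a DataFrame matches this layout. The sketch below is one possible check. The helper `check_stop_signal_data` is hypothetical (it is not a pydmc function), and the rules it applies are simply the column descriptions above.

```python
import numpy as np
import pandas as pd

REQUIRED_COLUMNS = ["subject", "stimulus", "response", "rt", "ssd"]

def check_stop_signal_data(df):
    """Return a list of problems found in a stop-signal DataFrame (empty if none)."""
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        # The remaining checks need these columns, so stop early
        return [f"missing columns: {missing}"]
    problems = []
    if not df["stimulus"].isin([0, 1]).all():
        problems.append("stimulus must be 0 (left) or 1 (right)")
    if not df["response"].isin([0, 1, 2]).all():
        problems.append("response must be 0 (none), 1 (left), or 2 (right)")
    responded = df["response"] != 0
    if df.loc[responded, "rt"].isna().any():
        problems.append("trials with a response need an rt")
    if (df.loc[responded, "rt"] <= 0).any():
        problems.append("rt must be positive (seconds)")
    if (df["ssd"].dropna() < 0).any():
        problems.append("ssd must be non-negative")
    return problems

# Two go trials and one successful stop trial
example = pd.DataFrame({
    "subject": ["S01", "S01", "S01"],
    "stimulus": [0, 1, 1],
    "response": [1, 2, 0],
    "rt": [0.52, 0.61, np.nan],
    "ssd": [np.nan, np.nan, 0.25],
})
print(check_stop_signal_data(example))
```

An empty list means none of these checks failed. It does not guarantee that the data meet every requirement of the model.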

In [None]:
def generate_synthetic_data(n_subjects=3, n_trials_per_subject=200):
    """
    Generate synthetic stop-signal task data.
    
    This is a simplified data generator for demonstration purposes.
    Real data will have more complex patterns.
    """
    data = []
    
    for subj in range(1, n_subjects + 1):
        for trial in range(n_trials_per_subject):
            # Randomly determine if this is a stop trial (25% of trials)
            is_stop = np.random.rand() < 0.25
            
            # Random stimulus (0=left, 1=right)
            stimulus = np.random.choice([0, 1])
            
            if is_stop:
                # Stop trial
                ssd = np.random.uniform(0.1, 0.4)  # Stop signal delay
                
                # 50% chance of successful stop
                if np.random.rand() < 0.5:
                    response = 0  # No response (successfully stopped)
                    rt = np.nan  # No RT for successful stops
                else:
                    # Failed to stop - generate RT
                    rt = np.random.gamma(2, 0.15) + 0.3  # Faster on average than go RTs
                    # Response with some errors
                    if np.random.rand() < 0.9:  # 90% correct
                        response = stimulus + 1
                    else:
                        response = 2 - stimulus
            else:
                # Go trial
                ssd = np.nan
                
                # Generate a right-skewed RT (shifted gamma plus an exponential component)
                rt = np.random.gamma(3, 0.1) + 0.35 + np.random.exponential(0.05)
                
                # Determine response (95% correct on go trials)
                if np.random.rand() < 0.95:
                    response = stimulus + 1  # Correct response
                else:
                    response = 2 - stimulus  # Incorrect response
            
            data.append({
                'subject': f'S{subj:02d}',
                'trial': trial + 1,
                'stimulus': stimulus,
                'response': response,
                'rt': rt,
                'ssd': ssd
            })
    
    return pd.DataFrame(data)

# Generate data
data = generate_synthetic_data(n_subjects=3, n_trials_per_subject=200)

print(f"Generated {len(data)} trials from {data['subject'].nunique()} subjects")
print(f"Go trials: {data['ssd'].isna().sum()}")
print(f"Stop trials: {(~data['ssd'].isna()).sum()}")
print("\nFirst few rows:")
data.head(10)

## 3. Explore the Data

In [None]:
# Basic data summary
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# RT distribution for go trials
go_trials = data[data['ssd'].isna() & (data['rt'].notna())]
axes[0].hist(go_trials['rt'], bins=30, alpha=0.7, edgecolor='black')
axes[0].set_xlabel('Reaction Time (s)')
axes[0].set_ylabel('Frequency')
axes[0].set_title('RT Distribution (Go Trials)')
axes[0].axvline(go_trials['rt'].mean(), color='red', linestyle='--', label=f'Mean: {go_trials["rt"].mean():.3f}s')
axes[0].legend()

# Stop-signal performance by subject
stop_trials = data[~data['ssd'].isna()]
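# Proportion of stop trials with no response, per subject
# (a plain groupby mean avoids the pandas deprecation warning that groupby.apply can raise)
stop_success = (
    stop_trials['response'].eq(0)
    .groupby(stop_trials['subject'])
    .mean()
    .rename('p_stop_success')
    .reset_index()
)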

axes[1].bar(stop_success['subject'], stop_success['p_stop_success'], alpha=0.7, edgecolor='black')
axes[1].set_xlabel('Subject')
axes[1].set_ylabel('P(Successful Stop)')
axes[1].set_title('Stop Success Rate by Subject')
axes[1].set_ylim([0, 1])
axes[1].axhline(0.5, color='red', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

print("\nData Summary:")
print(f"Mean RT (go trials): {go_trials['rt'].mean():.3f} s")
print(f"SD RT (go trials): {go_trials['rt'].std():.3f} s")
print(f"Overall stop success rate: {(stop_trials['response'] == 0).mean():.2%}")
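
A common further summary for stop-signal data is the inhibition function: the probability of responding on a stop trial as a function of SSD. The sketch below computes it on a small hand-made set of stop trials, with arbitrarily chosen bin edges. In the synthetic data generated above, stop success does not depend on SSD, so the same calculation there would give a roughly flat curve.

```python
import numpy as np
import pandas as pd

# Small hand-made set of stop trials (illustrative values only)
stop_trials = pd.DataFrame({
    "ssd": [0.12, 0.15, 0.18, 0.22, 0.25, 0.28, 0.32, 0.35, 0.38],
    "response": [0, 0, 1, 0, 2, 0, 1, 2, 0],
})

# Bin SSDs into three equal-width bins between 0.1 and 0.4 s
bins = [0.1, 0.2, 0.3, 0.4]
ssd_bin = pd.cut(stop_trials["ssd"], bins=bins)

# P(respond | SSD bin): share of stop trials on which a response was made
p_respond = (
    stop_trials["response"].ne(0)
    .groupby(ssd_bin, observed=True)
    .mean()
)
print(p_respond)
```

In real data this probability typically rises with SSD, because a later stop signal leaves less time to cancel the response.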

## 4. Create and Fit the Model

Now we'll create a hierarchical WALD stop-signal model and fit it to the data.

**Note**: This may take several minutes depending on your computer. For a quick test, reduce the number of iterations.

In [None]:
# Create the model
model = WaldStopSignalModel(use_hierarchical=True)

print("Model created successfully!")
print(f"Using backend: {model.backend.backend_name}")

In [None]:
# Fit the model
# For a quick test, use fewer iterations
# For real analysis, use chains=4, iter=2000, warmup=1000

fit = model.fit(
    data,
    chains=2,           # Number of MCMC chains (use 4 for real analysis)
    iter=500,           # Total iterations per chain (use 2000 for real analysis)
    warmup=250,         # Warmup iterations (use 1000 for real analysis)
    cores=2,            # Number of parallel cores
    show_progress=True
)

print("\nModel fitted successfully!")
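
The sampler settings determine how many posterior draws are kept. The sketch below does the arithmetic, assuming that `iter` counts warmup iterations (as the "Total iterations per chain" comment suggests) and that warmup draws are discarded. The helper `post_warmup_draws` is illustrative and not part of pydmc.

```python
def post_warmup_draws(chains, n_iter, warmup):
    """Number of retained posterior draws, assuming n_iter includes warmup."""
    if warmup >= n_iter:
        raise ValueError("warmup must be smaller than the total iterations")
    return chains * (n_iter - warmup)

quick = post_warmup_draws(chains=2, n_iter=500, warmup=250)
full = post_warmup_draws(chains=4, n_iter=2000, warmup=1000)
print(f"Quick test: {quick} draws; recommended settings: {full} draws")
```

Under these assumptions the quick settings keep 500 draws and the recommended settings keep 4000, which is why the quick run is only suitable for checking that the pipeline works.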

## 5. Examine Model Results

In [None]:
# Print model summary
print("Model Summary:")
print("=" * 50)
model.summary()

In [None]:
# Get parameter estimates
estimates = model.get_parameter_estimates()

# Display group-level parameters (hierarchical means)
print("\nGroup-Level Parameter Estimates:")
print("=" * 70)

param_names = {
    'mu_params[1]': 'B (threshold)',
    'mu_params[2]': 't0 (non-decision time)',
    'mu_params[3]': 'gf (go failure)',
    'mu_params[8]': 'vT (drift true)',
    'mu_params[9]': 'vF (drift false)',
    'mu_params[7]': 'tf (trigger failure)'
}

for param_key, param_name in param_names.items():
    if param_key in estimates:
        est = estimates[param_key]
        print(f"{param_name:30s}: {est['mean']:8.3f} (SD: {est['std']:6.3f}, 95% CI: [{est['q2.5']:6.3f}, {est['q97.5']:6.3f}])")

## 6. Visualize Results

### 6.1 MCMC Trace Plots

Check convergence by examining trace plots for the group-level parameters. Well-mixed chains overlap one another and show no drift or long flat stretches.

In [None]:
# Plot MCMC traces
model.plot_traces(params=['mu_params', 'sigma_params'], figsize=(14, 10))
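
Visual inspection can be complemented by a numerical diagnostic. The sketch below computes the classic Gelman-Rubin R-hat with NumPy on simulated chains. It assumes the draws for one parameter are available as an array of shape `(n_chains, n_draws)`. How to extract such an array from a pydmc fit is not shown here, and dedicated libraries use more refined variants, such as split or rank-normalized R-hat.

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    # W: mean of the within-chain variances
    within = chains.var(axis=1, ddof=1).mean()
    # B: n_draws times the variance of the chain means
    between = n_draws * chains.mean(axis=1).var(ddof=1)
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(pooled / within)

rng = np.random.default_rng(1)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))            # chains sampling the same target
offset = mixed + np.array([[0.0], [0.0], [0.0], [2.0]])  # one chain shifted away from the rest
print(f"R-hat (well mixed): {gelman_rubin_rhat(mixed):.3f}")
print(f"R-hat (one chain offset): {gelman_rubin_rhat(offset):.3f}")
```

Values close to 1 indicate that the chains agree. Clearly larger values, as in the offset example, indicate that the chains have not converged to the same distribution.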

### 6.2 Posterior Predictive Check

Compare observed data to model predictions.

In [None]:
# Posterior predictive check
model.posterior_predictive_check(figsize=(14, 5))
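
The idea behind a posterior predictive check can be sketched independently of pydmc: compute a summary statistic on the observed data and see where it falls among the same statistic computed on data sets simulated from the fitted model. In the sketch below both the "observed" RTs and the replicated data sets are simulated stand-ins drawn from the same shifted Wald distribution, and `ppc_quantiles` is a hypothetical helper, not the method used by `posterior_predictive_check`.

```python
import numpy as np

rng = np.random.default_rng(2)

# Stand-ins: "observed" RTs and 200 replicated data sets of the same size.
# Both come from the same shifted Wald here, purely for illustration.
observed = 0.2 + rng.wald(mean=0.4, scale=1.0, size=300)
replicated = 0.2 + rng.wald(mean=0.4, scale=1.0, size=(200, 300))

def ppc_quantiles(observed, replicated, probs=(0.1, 0.5, 0.9)):
    """Compare observed RT quantiles with their predictive 95% range."""
    obs_q = np.quantile(observed, probs)
    rep_q = np.quantile(replicated, probs, axis=1)  # shape: (len(probs), n_rep)
    lower, upper = np.quantile(rep_q, [0.025, 0.975], axis=1)
    return obs_q, lower, upper

obs_q, lower, upper = ppc_quantiles(observed, replicated)
for p, o, lo, hi in zip((0.1, 0.5, 0.9), obs_q, lower, upper):
    print(f"q{int(p * 100):02d}: observed {o:.3f} s, predictive 95% interval [{lo:.3f}, {hi:.3f}]")
```

Observed quantiles that fall far outside the predictive intervals suggest that the model misses some feature of the RT distribution.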

## 7. Save Results

In [None]:
# Save results to JSON file
model.save_results('../models/example_model_results.json', include_samples=False)

print("Results saved successfully!")

## 8. Load Results (Optional)

You can load previously saved results:

In [None]:
# Load results
loaded_results = model.load_results('../models/example_model_results.json')

print("\nLoaded results summary:")
print(f"Model type: {loaded_results['model_type']}")
print(f"Number of subjects: {loaded_results['data_summary']['n_subjects']}")
print(f"Number of trials: {loaded_results['data_summary']['n_trials']}")
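
If the saved file is plain JSON, as the comment in the save step indicates, it can also be inspected with the standard library alone. The sketch below shows a round trip through a temporary file, using a placeholder dictionary that contains only the keys accessed above. The values are illustrative, and a real results file will likely contain more fields.

```python
import json
import tempfile
from pathlib import Path

# Placeholder results using only the keys accessed above; values are illustrative
results = {
    "model_type": "placeholder",
    "data_summary": {"n_subjects": 3, "n_trials": 600},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "example_model_results.json"
    path.write_text(json.dumps(results, indent=2))  # write the file
    reloaded = json.loads(path.read_text())         # read it back

print(reloaded["data_summary"])
```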

## Next Steps

Now that you've seen the basic workflow, you can:

1. **Load your own data**: Replace the synthetic data with your actual stop-signal task data
2. **Increase sampling**: Use more chains and iterations for more stable posterior estimates and more reliable convergence checks (e.g., chains=4, iter=2000, warmup=1000)
3. **Compare models**: Try both hierarchical and individual-level models
4. **Customize analysis**: Extract specific parameters and create custom visualizations
5. **Model validation**: Perform more extensive posterior predictive checks

For more information, see the pydmc documentation.