In [11]:
import pandas as pd
import numpy as np
from utils import *
import pytensor.tensor as pt  # Import PyTensor (Theano backend)

In [12]:
df, scaling_factors = read_dd_data('../data/PD_data', standardize=True, reduced_data=False)

df_train, df_test = split_train_test(df, test_size=0.3) # consistent across runs
df = df_test

participants = df["participant"].unique()

Train set size: 6833 trials
Test set size: 2916 trials
Test set percentage: 29.9%


In [13]:
df

Unnamed: 0,rcert,runcert,event_prob,choice,condition,rt,odds,participant,participant_idx
140,0.200650,0.034843,0.25,0.0,1.0,-2.309281,3.000000,bh348lli7,0
141,0.200650,0.017079,0.75,0.0,1.0,-2.599964,0.333333,bh348lli7,0
142,0.404075,0.176955,0.10,0.0,1.0,-2.359636,9.000000,bh348lli7,0
143,0.200650,0.017079,0.50,0.0,1.0,-2.956961,1.000000,bh348lli7,0
144,2.031478,0.114781,0.25,0.0,1.0,-3.594024,3.000000,bh348lli7,0
...,...,...,...,...,...,...,...,...,...
9795,-0.409626,-0.178325,0.10,0.0,2.0,-0.584245,9.000000,u3yyfob1i,48
9796,-0.409626,-0.071741,0.25,1.0,2.0,-0.233874,3.000000,u3yyfob1i,48
9797,-0.206201,-0.017561,0.10,1.0,2.0,-0.185090,9.000000,u3yyfob1i,48
9798,-2.037028,-0.178325,0.75,0.0,2.0,0.478903,0.333333,u3yyfob1i,48


In [14]:
trace_linear = az.from_netcdf("models/linear_model.nc")
trace_linear = trace_linear.posterior

trace_quadratic = az.from_netcdf("models/quadratic_model.nc")
trace_quadratic = trace_quadratic.posterior

In [15]:
def predict_choice_and_rt_per_participant(trace, df, config_rt: str = "default",
                                          scaling_factors: dict = None):
    """
    Predict choice probabilities and reaction times for each participant using posterior parameters.
    Compare predictions against ground truth data.
    
    Parameters:
    -----------
    trace : xarray.Dataset
        Dataset containing posterior samples
    df : pandas.DataFrame
        DataFrame containing the experimental data with columns:
        rcert, runcert, event_prob, choice, rt, participant_idx
        
    Returns:
    --------
    predictions : dict
        Dictionary containing predictions for each participant with keys:
        - choice_prob: Predicted choice probabilities
        - choice_accuracy: Accuracy of predicted choices vs actual choices
        - rt_pred: Mean predicted reaction times
        - rt_pred_standardized: Standardized predicted reaction times
        - rt_pred_full_value: Full distribution of predicted reaction times
        - rt_pred_full_std: Standard deviation of predicted reaction times
        - rt_true_unstandardized: Actual reaction times (unstandardized)
        - rt_true_standardized: Actual reaction times (standardized)
        - rt_mse: Mean squared error of reaction time predictions
        - rt_std: Standard deviation of predicted reaction times
    
    decision_difficulty_choice : dict
        Dictionary containing decision difficulty metrics for each participant with keys:
        - value_diff: Difference in subjective values between options
        - decision_difficulty: Absolute value difference between options
        
    posterior_rt_uncertainty : dict
        Dictionary containing uncertainty estimates for each participant with keys:
        - uncertainty_standardized: Standardized uncertainty in reaction time predictions
        - uncertainty_unstandardized: Unstandardized uncertainty in reaction time predictions
    """
    predictions = {}
    decision_difficulty_choice = {}
    posterior_rt_uncertainty = {}
    
    # For each unique participant
    for p_idx in df['participant_idx'].unique():
        # Get data for this participant
        p_data = df[df['participant_idx'] == p_idx]
        p_mask = df['participant_idx'] == p_idx
        
        # Extract posterior samples for this participant
        # extract median across chains and draws
        # TODO: use the whole values for uncertainty estimation
        k_samples = trace.k.median(dim=['chain', 'draw']).values
        k_samples = k_samples[p_idx]
        beta_samples = trace.beta.median(dim=['chain', 'draw']).values
        beta_samples = beta_samples[p_idx]
        beta0_samples = trace.beta0.median(dim=['chain', 'draw']).values
        beta0_samples = beta0_samples[p_idx]
        beta1_samples = trace.beta1.median(dim=['chain', 'draw']).values
        beta1_samples = beta1_samples[p_idx]
        sigma_rt_samples = trace.sigma_RT.median(dim=['chain', 'draw']).values
        sigma_rt_samples = sigma_rt_samples[p_idx]

        if config_rt == "quadratic":
            beta2_samples = trace.beta2.median(dim=['chain', 'draw']).values
            beta2_samples = beta2_samples[p_idx]

        if config_rt == "sigma-per-trial":
            sigma_rt_trial_samples = trace.sigma_RT_trial.median(dim=['chain', 'draw']).values[p_mask]
        else:
            sigma_rt_trial_samples = None

        lambda_samples = trace.loss_aversion.median(dim=['chain', 'draw']).values
        lambda_samples = lambda_samples[p_idx]

        gains_mask = p_data['condition'] == 1

        SV_certain = np.where(gains_mask,
                            p_data['rcert'].values,
                            -lambda_samples * p_data['rcert'].values)
        
        SV_uncertain = np.where(gains_mask,
                              p_data['runcert'].values / (1 + k_samples * p_data['event_prob'].values),
                              -lambda_samples * p_data['runcert'].values / (1 + k_samples * p_data['event_prob'].values))
        
        # Calculate value differences
        value_diff = SV_uncertain - SV_certain
        
        # Predict choice probabilities using logistic function
        exp = beta_samples * value_diff
        exp = np.clip(exp, -10, 10)  # Clip to avoid numerical issues
        choice_probs = 1 / (1 + np.exp(-exp))
        
        # Calculate decision difficulty for RT prediction
        decision_difficulty =  np.abs(value_diff)
        # Predict RTs
        if config_rt == 'quadratic':
            rt_mean = beta0_samples + beta1_samples * decision_difficulty + beta2_samples * decision_difficulty**2
        else:
            rt_mean = beta0_samples + beta1_samples * decision_difficulty

        if sigma_rt_trial_samples is not None:
            standardized_log_rts_samples = np.random.normal(loc=rt_mean[:,np.newaxis], scale=sigma_rt_trial_samples[:,np.newaxis], size=(rt_mean.shape[0], 50))
        else:
            standardized_log_rts_samples = np.random.normal(loc=rt_mean[:,np.newaxis], scale=sigma_rt_samples, size=(rt_mean.shape[0], 50))
        
        log_rt = (standardized_log_rts_samples * scaling_factors["rt_std"]) + scaling_factors["rt_mean"]
        rt_pred = np.exp(log_rt)

        rt_pred_mean = np.mean(rt_pred, axis=1)

        # calculate choices
        choices = (choice_probs > 0.5).astype(int)

        # calculate accuracy
        accuracy = np.mean(choices == p_data['choice'].values)

        # calculate rt mse
        rt_true = np.exp((p_data['rt'].values * scaling_factors["rt_std"]) + scaling_factors["rt_mean"])
        rt_mse = np.mean((rt_pred_mean - rt_true) ** 2)

        predictions[p_idx] = {
            'choice_prob': choice_probs,
            'choice_accuracy': accuracy,
            'rt_pred': rt_pred_mean,
            'rt_pred_standardized': np.mean(standardized_log_rts_samples,axis=1),
            'rt_pred_full_value': rt_pred,
            'rt_pred_full_std': sigma_rt_trial_samples,
            'rt_true_unstandardized': rt_true,
            'rt_true_standardized': p_data['rt'].values,
            'rt_mse': rt_mse,
            'rt_std': np.std(rt_pred)
        }

        decision_difficulty_choice[p_idx] = {
            'value_diff': value_diff,
            'decision_difficulty': decision_difficulty,
            
        }

        if sigma_rt_trial_samples is not None:
            log_not_standardized_sigma_rt_trial_samples = np.exp(sigma_rt_trial_samples * scaling_factors['rt_std'])
            original_space_sigma_rt_trial = normal_to_lognormal_std(rt_pred_mean, log_not_standardized_sigma_rt_trial_samples)

        if sigma_rt_trial_samples is not None:
            posterior_rt_uncertainty[p_idx] = {
                'uncertainty_standardized': sigma_rt_trial_samples,
                'uncertainty_unstandardized': original_space_sigma_rt_trial
                
            }

    return predictions, decision_difficulty_choice, posterior_rt_uncertainty


In [16]:

predictions_linear, decision_difficulty_choice_linear, posterior_rt_uncertainty_linear = predict_choice_and_rt_per_participant(trace_linear, df, scaling_factors=scaling_factors, config_rt="linear")
accuracies_linear = [pred['choice_accuracy'] for pred in predictions_linear.values()]
print(accuracies_linear)

predictions_quadratic, decision_difficulty_choice_quadratic, posterior_rt_uncertainty_quadratic = predict_choice_and_rt_per_participant(trace_quadratic, df, scaling_factors=scaling_factors, config_rt="quadratic")
accuracies_quadratic = [pred['choice_accuracy'] for pred in predictions_quadratic.values()]
print(accuracies_quadratic)

print("Predictions for linear model:")
for p_idx, pred in predictions_linear.items():
    print(f"Participant {p_idx}:")
    print(f"  Choice Accuracy: {pred['choice_accuracy']}")
    print(f"  RT MSE: {pred['rt_mse']}")
    print("---------------------")

print("Predictions for quadratic model:")
for p_idx, pred in predictions_quadratic.items():
    print(f"Participant {p_idx}:")
    print(f"  Choice Accuracy: {pred['choice_accuracy']}")
    print(f"  RT MSE: {pred['rt_mse']}")
    print("---------------------")

IndexError: index 30 is out of bounds for axis 0 with size 30

In [5]:
USING_SIGMA_PER_TRIAL = False
if USING_SIGMA_PER_TRIAL:
    # Create box plots of uncertainty estimates for 5 random participants
    plt.figure(figsize=(12, 6))

    # Get 5 random participant indices
    participant_indices = np.random.choice(list(posterior_rt_uncertainty_linear.keys()), size=10, replace=False)

    # Get uncertainty values for each participant
    value_diffs = [posterior_rt_uncertainty_linear[idx]['uncertainty_standardized'] for idx in participant_indices]

    # Create box plot
    plt.boxplot(value_diffs, labels=[f'P{idx}' for idx in participant_indices])

    plt.xlabel('Participant')
    plt.ylabel('RT Uncertainty')
    plt.title('Distribution of RT Uncertainty Estimates for random participants across all trials (standardized space)')

    plt.show()


In [None]:
import matplotlib.pyplot as plt

# Plot for linear model
plt.figure(figsize=(10, 6))
plt.hist(accuracies_linear, bins=20, edgecolor='black')
plt.xlabel('Choice Accuracy')
plt.ylabel('Number of Participants')
plt.title('Distribution of Choice Accuracy Across Participants (Linear Model)')
plt.axvline(np.mean(accuracies_linear), color='red', linestyle='dashed', linewidth=2, 
            label=f'Mean = {np.mean(accuracies_linear):.3f}')
plt.legend()
plt.savefig('plots/linear-model-all-data/choice_accuracy_dist_linear.png')
plt.close()

# Plot for quadratic model 
plt.figure(figsize=(10, 6))
plt.hist(accuracies_quadratic, bins=20, edgecolor='black')
plt.xlabel('Choice Accuracy')
plt.ylabel('Number of Participants')
plt.title('Distribution of Choice Accuracy Across Participants (Quadratic Model)')
plt.axvline(np.mean(accuracies_quadratic), color='red', linestyle='dashed', linewidth=2,
            label=f'Mean = {np.mean(accuracies_quadratic):.3f}')
plt.legend()
plt.savefig('plots/quadratic-model-all-data/choice_accuracy_dist_quadratic.png')
plt.close()


In [None]:
# Randomly select one participant
random_participant = np.random.choice(list(predictions_linear.keys()))

# Plot for linear model
plt.figure(figsize=(10, 6))

# Get predicted and true RTs for this participant
pred_rt = predictions_linear[random_participant]['rt_pred_standardized']
true_rt = predictions_linear[random_participant]['rt_true_standardized']
if USING_SIGMA_PER_TRIAL:
    pred_rt_std = posterior_rt_uncertainty_linear[random_participant]['uncertainty_standardized']

if USING_SIGMA_PER_TRIAL:
    # Plot error bars for each point
    plt.errorbar(true_rt, pred_rt, yerr=pred_rt_std, fmt='o', alpha=0.3, 
                markersize=4, elinewidth=1, capsize=2)
else:
    plt.scatter(true_rt, pred_rt, alpha=0.3)

# Add diagonal line representing perfect prediction
min_val = min(min(true_rt), min(pred_rt))
max_val = max(max(true_rt), max(pred_rt))
plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect Prediction')

plt.xlabel('True RT (standardized)')
plt.ylabel('Predicted RT (standardized)') 
plt.title(f'Predicted vs True RT for Participant {random_participant} (Linear Model)\nwith Standard Deviation')
plt.legend()

# Add correlation coefficient
correlation = np.corrcoef(true_rt, pred_rt)[0,1]
plt.text(0.05, 0.95, f'Correlation: {correlation:.3f}', 
         transform=plt.gca().transAxes, 
         bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('plots/linear-model-all-data/rt_predictions_linear.png')
plt.close()

# Plot for quadratic model
plt.figure(figsize=(10, 6))

# Get predicted and true RTs for this participant
pred_rt = predictions_quadratic[random_participant]['rt_pred_standardized']
true_rt = predictions_quadratic[random_participant]['rt_true_standardized']
if USING_SIGMA_PER_TRIAL:
    pred_rt_std = posterior_rt_uncertainty_quadratic[random_participant]['uncertainty_standardized']

if USING_SIGMA_PER_TRIAL:
    # Plot error bars for each point
    plt.errorbar(true_rt, pred_rt, yerr=pred_rt_std, fmt='o', alpha=0.3, 
                markersize=4, elinewidth=1, capsize=2)
else:
    plt.scatter(true_rt, pred_rt, alpha=0.3)

# Add diagonal line representing perfect prediction
min_val = min(min(true_rt), min(pred_rt))
max_val = max(max(true_rt), max(pred_rt))
plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect Prediction')

plt.xlabel('True RT (standardized)')
plt.ylabel('Predicted RT (standardized)') 
plt.title(f'Predicted vs True RT for Participant {random_participant} (Quadratic Model)\nwith Standard Deviation')
plt.legend()

# Add correlation coefficient
correlation = np.corrcoef(true_rt, pred_rt)[0,1]
plt.text(0.05, 0.95, f'Correlation: {correlation:.3f}', 
         transform=plt.gca().transAxes, 
         bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('plots/quadratic-model-all-data/rt_predictions_quadratic.png')
plt.close()


In [None]:
# Create box plot of choice accuracies for linear model
plt.figure(figsize=(10, 6))
plt.boxplot([pred['choice_accuracy'] for pred in predictions_linear.values()])
plt.ylabel('Choice Accuracy')
plt.title('Distribution of Choice Accuracy Across All Participants (Linear Model)')

# Add horizontal line at 0.5 for chance level
plt.axhline(y=0.5, color='r', linestyle='--', alpha=0.3, label='Chance Level')
plt.legend()

plt.tight_layout()
plt.savefig('plots/linear-model-all-data/choice_accuracy_distribution_linear.png')
plt.close()

# Create box plot of choice accuracies for quadratic model
plt.figure(figsize=(10, 6))
plt.boxplot([pred['choice_accuracy'] for pred in predictions_quadratic.values()])
plt.ylabel('Choice Accuracy')
plt.title('Distribution of Choice Accuracy Across All Participants (Quadratic Model)')

# Add horizontal line at 0.5 for chance level
plt.axhline(y=0.5, color='r', linestyle='--', alpha=0.3, label='Chance Level')
plt.legend()

plt.tight_layout()
plt.savefig('plots/quadratic-model-all-data/choice_accuracy_distribution_quadratic.png')
plt.close()


In [None]:
# # Extract value differences for each participant and randomly select 5
# import random
# participant_indices = random.sample(list(decision_difficulty_choice.keys()), 8)
# value_diffs = [decision_difficulty_choice[idx]['decision_difficulty'] for idx in participant_indices]

# # Create box plot
# plt.figure(figsize=(12, 6))
# plt.boxplot(value_diffs, labels=[f'P{idx}' for idx in participant_indices])
# plt.yscale('log')  # Set y-axis to log scale
# plt.xticks(rotation=45)
# plt.xlabel('Participants')
# plt.ylabel('Value Difference (Decision Difficulty)')
# plt.title('Distribution of Decision Difficulties for 5 Random Participants')

# # Add horizontal lines at y=0 and y=1 for reference
# plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
# plt.axhline(y=1, color='b', linestyle='--', alpha=0.3, label='Threshold at 1')

# # Adjust layout to prevent label cutoff
# plt.tight_layout()
# plt.legend()
# plt.show()


In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Flatten all decision difficulties into a single list
# all_difficulties = []
# for idx in decision_difficulty_choice.keys():
#     all_difficulties.extend(decision_difficulty_choice[idx]['decision_difficulty'])

# # Create box plot
# plt.figure(figsize=(10, 6))
# plt.boxplot(all_difficulties)
# plt.yscale('log')  # Set y-axis to log scale since values can vary widely
# plt.ylabel('Decision Difficulty')
# plt.title('Distribution of All Decision Difficulties')

# # Add horizontal line at y=1 for reference
# plt.axhline(y=1, color='r', linestyle='--', alpha=0.3, label='Threshold at 1')

# plt.legend()
# plt.tight_layout()
# plt.show()

# # Calculate summary statistics for all decision difficulties
# mean_difficulty = np.mean(all_difficulties)
# median_difficulty = np.median(all_difficulties)
# std_difficulty = np.std(all_difficulties)
# var_difficulty = np.var(all_difficulties)

# print(f"Summary Statistics for Decision Difficulties:")
# print(f"Mean: {mean_difficulty:.3f}")
# print(f"Median: {median_difficulty:.3f}") 
# print(f"Standard Deviation: {std_difficulty:.3f}")
# print(f"Variance: {var_difficulty:.3f}")

