In [1]:
# =============================================================================
# 1. IMPORTS
# =============================================================================
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from scipy.stats import ttest_ind, shapiro, levene, mannwhitneyu

# =============================================================================
# 2. UTILITY FUNCTIONS
# =============================================================================

def read_csv_files(directory):
    """
    Read CSV files from a directory that start with "sub" and end with ".csv".

    Entries are visited in sorted filename order so that the returned list
    (and the progress messages) are reproducible between runs; os.listdir
    on its own yields names in arbitrary order.

    Parameters:
        directory (str): Path to the folder containing CSV files.

    Returns:
        list of pd.DataFrame: DataFrames read from the matching CSV files,
        ordered by filename.

    Raises:
        OSError: Propagated from os.listdir when `directory` cannot be listed
        (e.g. FileNotFoundError if it does not exist).
    """
    csv_files = []
    print(f"Checking files in directory: {directory}")
    for filename in sorted(os.listdir(directory)):
        print(f"Found file: {filename}")
        if not (filename.startswith("sub") and filename.endswith(".csv")):
            print(f"Skipping file: {filename}")
            continue
        print(f"Reading CSV file: {filename}")
        csv_files.append(pd.read_csv(os.path.join(directory, filename)))
    return csv_files

def summarySE(data, measurevar, groupvars, na_rm=True):
    """
    Compute summary statistics (mean, count, std, standard error) for a measure variable.

    Parameters:
        data (pd.DataFrame): Data source.
        measurevar (str): Column name of the measure variable.
        groupvars (list or str): Column name(s) to group by.
        na_rm (bool): If True, drop rows whose measurevar is NA before grouping.

    Returns:
        pd.DataFrame: Summary statistics with columns: groupvars, 'mean', 'count', 'std', 'stderr'.
        'count' is the number of non-NA observations per group, and
        'stderr' is 'std' divided by the square root of 'count'.
    """
    if na_rm:
        data = data.dropna(subset=[measurevar])

    # 'count' (non-NA observations) rather than 'size' (all rows): pandas'
    # mean/std already skip NA, so the standard error must be based on the
    # same n even when na_rm is False.
    summary = data.groupby(groupvars).agg(
        mean=(measurevar, 'mean'),
        count=(measurevar, 'count'),
        std=(measurevar, 'std')
    ).reset_index()

    summary['stderr'] = summary['std'] / np.sqrt(summary['count'])
    return summary

def draw_significance_bracket(ax, x1, x2, y, text, height_percent=0.02):
    """
    Draw a significance bracket with a text label on an axis.

    Parameters:
        ax (matplotlib.axes.Axes): Axis on which to draw.
        x1, x2 (float): x-positions of the bracket's left and right legs.
        y (float): y-position where both legs start.
        text (str): Label (e.g. significance stars) centred above the bracket.
        height_percent (float): Leg height expressed as a fraction of y.
    """
    rise = y * height_percent
    top = y + rise

    # Up the left leg, across the top, down the right leg.
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, lw=1.5, color='black')

    # Label sits half a leg-height above the bracket's top edge.
    ax.text(
        (x1 + x2) * 0.5,
        y + rise * 1.5,
        text,
        ha='center',
        va='bottom',
        color='black',
        fontsize=16,
    )

def get_significance_text(p_value):
    """
    Map a p-value to its significance label.

    Returns '***' below 0.001, '**' below 0.01, '*' below 0.05 and
    'n.s.' otherwise.
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p_value < cutoff:
            return stars
    return 'n.s.'

# =============================================================================
# 3. DATA LOADING & PREPARATION
# =============================================================================

# (a) Load every "sub*.csv" file found in the "data" directory and preview each one
data_directory = 'data'
dataframes = read_csv_files(data_directory)
for position, frame in enumerate(dataframes, start=1):
    print(f"\nFirst few rows of DataFrame {position}:")
    print(frame.head())

# (b) Switch into the data folder and load the main CSV files from there.
# NOTE: os.chdir changes the process-wide working directory, so every relative
# path used after this point is resolved inside "data".
os.chdir("data")

# Relative-shift tables are ordered by participant, then by trial relative to the shift.
_SHIFT_SORT_KEYS = ['id', 'nTrial_rel']

# Behavioural tables (adjust file names as needed)
CrypticCreatures = pd.read_csv("Table_CrypticCreatures_YaleCohort.csv")
CrypticCreature_relativeShift = pd.read_csv("Table_CrypticCreaturesShiftRelative_YaleCohort.csv")
CrypticCreatures_patients_relativeShift = (
    pd.read_csv("Table_CrypticCreaturesShiftRelative_patients_YaleCohort.csv")
    .sort_values(by=_SHIFT_SORT_KEYS)
)
CrypticCreatures_controls_relativeShift = (
    pd.read_csv("Table_CrypticCreaturesShiftRelative_controls_YaleCohort.csv")
    .sort_values(by=_SHIFT_SORT_KEYS)
)

# Bayesian learner tables
CrypticCreatures_BayesianLearner = pd.read_csv("CrypticCreatures_BayesianLearner.csv")
CrypticCreatures_BayesianLearner_patients_relativeShift = (
    pd.read_csv("CrypticCreaturesBayesianLearner_relativeShift_OCD.csv")
    .sort_values(by=_SHIFT_SORT_KEYS)
)
CrypticCreatures_BayesianLearner_controls_relativeShift = (
    pd.read_csv("CrypticCreaturesBayesianLearner_relativeShift_controls.csv")
    .sort_values(by=_SHIFT_SORT_KEYS)
)

# (c) Per-trial summary statistics (patients and controls), all grouped by 'nTrial_rel'

def _summarise_by_trial(frame, measure):
    """Return summarySE of `measure` in `frame`, grouped by 'nTrial_rel'."""
    return summarySE(frame, measure, ['nTrial_rel'])

# Short aliases for the four source tables
_patients_shift = CrypticCreatures_patients_relativeShift
_controls_shift = CrypticCreatures_controls_relativeShift
_patients_bayes = CrypticCreatures_BayesianLearner_patients_relativeShift
_controls_bayes = CrypticCreatures_BayesianLearner_controls_relativeShift

# Patients: accuracy and confidence by trial shift
Cryptic_mean_acc_ID_patients = _summarise_by_trial(_patients_shift, 'mean_accuracy_id')
Cryptic_mean_acc_ED_patients = _summarise_by_trial(_patients_shift, 'mean_accuracy_ed')
Cryptic_mean_acc_patients = _summarise_by_trial(_patients_shift, 'mean_accuracy')
Cryptic_mean_conf_ID_patients = _summarise_by_trial(_patients_shift, 'mean_confidence_id')
Cryptic_mean_conf_ED_patients = _summarise_by_trial(_patients_shift, 'mean_confidence_ed')
Cryptic_mean_conf_patients = _summarise_by_trial(_patients_shift, 'mean_confidence')

# Controls: accuracy and confidence by trial shift
Cryptic_mean_acc_ID_controls = _summarise_by_trial(_controls_shift, 'mean_accuracy_id')
Cryptic_mean_acc_ED_controls = _summarise_by_trial(_controls_shift, 'mean_accuracy_ed')
Cryptic_mean_acc_controls = _summarise_by_trial(_controls_shift, 'mean_accuracy')
Cryptic_mean_conf_ID_controls = _summarise_by_trial(_controls_shift, 'mean_confidence_id')
Cryptic_mean_conf_ED_controls = _summarise_by_trial(_controls_shift, 'mean_confidence_ed')
Cryptic_mean_conf_controls = _summarise_by_trial(_controls_shift, 'mean_confidence')

# Bayesian learner tables: entropy, sum_prior, BLR_confidence, and deviations
Cryptic_mean_entr_patients = _summarise_by_trial(_patients_bayes, 'entropy')
Cryptic_mean_entr_controls = _summarise_by_trial(_controls_bayes, 'entropy')
Cryptic_mean_sumprior_patients = _summarise_by_trial(_patients_bayes, 'sum_prior_chosen_features')
Cryptic_mean_sumprior_controls = _summarise_by_trial(_controls_bayes, 'sum_prior_chosen_features')
Cryptic_mean_BLR_confidence_patients = _summarise_by_trial(_patients_bayes, 'BLR_confidence')
Cryptic_mean_BLR_confidence_controls = _summarise_by_trial(_controls_bayes, 'BLR_confidence')
Cryptic_mean_signed_confidence_deviation_patients = _summarise_by_trial(_patients_bayes, 'signed_confidence_deviation')
Cryptic_mean_signed_confidence_deviation_controls = _summarise_by_trial(_controls_bayes, 'signed_confidence_deviation')
Cryptic_mean_signed_prior_deviation_patients = _summarise_by_trial(_patients_bayes, 'signed_prior_deviation')
Cryptic_mean_signed_prior_deviation_controls = _summarise_by_trial(_controls_bayes, 'signed_prior_deviation')

# (d) Create feedback and lagged variables in the main dataset
# feedback is 1 where chosen_outcome equals the string 'correct', 0 otherwise.
CrypticCreatures['feedback'] = CrypticCreatures['chosen_outcome'].eq('correct').astype(int)
# Lag within each participant ('id') so that a participant's first trial does
# not inherit the last trial of whichever participant precedes it in the table.
# NOTE(review): assumes rows are already in trial order within each id — confirm.
CrypticCreatures['prev_feedback'] = CrypticCreatures.groupby('id')['feedback'].shift(1)
CrypticCreatures['prev_confidence'] = CrypticCreatures.groupby('id')['confidence'].shift(1)
# Remove rows with no previous-trial data (each participant's first trial).
# .copy() makes the result an independent frame, so later column assignments
# do not operate on a slice of the original table.
CrypticCreatures = CrypticCreatures.dropna(subset=['prev_feedback', 'prev_confidence']).copy()

# (e) Join each Bayesian learner table to its behavioural relative-shift table.
# Rows are matched on trial-relative-to-shift and participant id; the inner
# join keeps only (nTrial_rel, id) pairs present in both tables.
controls_df = CrypticCreatures_controls_relativeShift
patients_df = CrypticCreatures_patients_relativeShift

_MERGE_KEYS = ['nTrial_rel', 'id']

controls_merged_df = CrypticCreatures_BayesianLearner_controls_relativeShift.merge(
    controls_df, on=_MERGE_KEYS, how='inner'
)
patients_merged_df = CrypticCreatures_BayesianLearner_patients_relativeShift.merge(
    patients_df, on=_MERGE_KEYS, how='inner'
)

# (f) Summarise each measure of interest in the merged datasets, per relative trial.
measure_vars = [
    'change_in_mean_accuracy',
    'change_in_mean_accuracy_abs',
    'change_in_mean_accuracy_ed',
    'change_in_mean_accuracy_abs_ed',
    'change_in_mean_accuracy_id',
    'change_in_mean_accuracy_abs_id',
    'change_in_mean_confidence',
    'change_in_mean_confidence_abs',
    'change_in_mean_confidence_ed',
    'change_in_mean_confidence_abs_ed',
    'change_in_mean_confidence_id',
    'change_in_mean_confidence_abs_id',
    'signed_confidence_deviation',
    'signed_prior_deviation',
    'mean_confidence',
    'BLR_confidence',
    'mean_accuracy',
]

# One summary table per measure, keyed by the measure's column name.
summary_controls_rel = {}
summary_patients_rel = {}
for _measure in measure_vars:
    summary_controls_rel[_measure] = summarySE(controls_merged_df, _measure, 'nTrial_rel')
    summary_patients_rel[_measure] = summarySE(patients_merged_df, _measure, 'nTrial_rel')

# Print the merged-data summary next to the one computed from the unmerged
# controls table so the two can be compared.
print("Summary for controls (mean_confidence):")
print(summary_controls_rel['mean_confidence'])
print("\nSummary for controls (from Cryptic_mean_conf_controls):")
print(Cryptic_mean_conf_controls)

# =============================================================================
# 4. STATISTICAL TESTS & AVERAGE-LEVEL DATA
# =============================================================================

# Compute each individual's average accuracy and confidence.
# Accuracy is the per-participant mean of the 0/1 'feedback' column:
# 'chosen_outcome' itself holds string labels (it is compared against
# 'correct' when 'feedback' is built), so it cannot be averaged directly.
CrypticCreatures['average_accuracy'] = CrypticCreatures.groupby('id')['feedback'].transform('mean')
CrypticCreatures['average_confidence'] = CrypticCreatures.groupby('id')['confidence'].transform('mean')
# Collapse the trial-level rows to the distinct per-participant combinations.
average_data = CrypticCreatures[['id', 'average_accuracy', 'average_confidence', 'patientstatus']].drop_duplicates()

# Separate groups (patientstatus 0 = controls, 1 = patients)
_is_control = average_data['patientstatus'] == 0
_is_patient = average_data['patientstatus'] == 1
controls_acc = average_data[_is_control]['average_accuracy']
patients_acc = average_data[_is_patient]['average_accuracy']
controls_conf = average_data[_is_control]['average_confidence']
patients_conf = average_data[_is_patient]['average_confidence']

def perform_tests(controls, patients, measure_name):
    """
    Run assumption checks and a two-group comparison, printing each result.

    Shapiro-Wilk is run on each group and Levene's test on both. If all
    three p-values exceed 0.05 an independent-samples t-test is used;
    otherwise a Mann-Whitney U test is used.

    Parameters:
        controls, patients: Samples for the two groups.
        measure_name (str): Label inserted into the printed messages.

    Returns:
        The p-value of whichever group-difference test was run.
    """
    normality_controls = shapiro(controls)
    normality_patients = shapiro(patients)
    print(f'\nShapiro-Wilk Test for Controls {measure_name}: {normality_controls}')
    print(f'Shapiro-Wilk Test for Patients {measure_name}: {normality_patients}')

    variance_check = levene(controls, patients)
    print(f"Levene's Test for {measure_name}: {variance_check}")

    assumption_pvalues = (
        normality_controls.pvalue,
        normality_patients.pvalue,
        variance_check.pvalue,
    )

    # Any failed assumption sends us to the non-parametric test.
    if not all(p > 0.05 for p in assumption_pvalues):
        statistic, p_value = mannwhitneyu(controls, patients)
        print(f'Mann-Whitney U Test for {measure_name}: U-statistic = {statistic}, P-value = {p_value}')
        return p_value

    statistic, p_value = ttest_ind(controls, patients)
    print(f'T-test for {measure_name}: T-statistic = {statistic}, P-value = {p_value}')
    return p_value

# Group comparison p-values, reused for the significance brackets in the boxplots.
p_value_acc = perform_tests(controls=controls_acc, patients=patients_acc, measure_name='Accuracy')
p_value_conf = perform_tests(controls=controls_conf, patients=patients_conf, measure_name='Confidence')

# =============================================================================
# 5. PLOTTING FUNCTIONS (AX-BASED)
# =============================================================================

def plot_accuracy_boxplot_ax(ax, average_data, p_value_acc):
    """
    Draw the per-group average-accuracy boxplot on `ax`.

    Individual values are overlaid as a swarm plot, and a significance
    bracket labelled from `p_value_acc` spans the two groups.
    """
    shared = dict(x='patientstatus', y='average_accuracy', data=average_data, ax=ax)
    sns.boxplot(palette={0: 'olive', 1: 'darkblue'}, width=0.3, **shared)
    sns.swarmplot(color='black', dodge=True, marker='o', size=5, **shared)

    # Leave headroom above the highest value; the bracket is drawn inside it.
    upper_limit = average_data['average_accuracy'].max() + 0.1
    draw_significance_bracket(ax, 0, 1, upper_limit - 0.07,
                              get_significance_text(p_value_acc))

    # Axis cosmetics
    ax.set_ylim(0.4, upper_limit)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Controls', 'Patients'])
    ax.set_yticks(np.arange(0.5, 0.9, 0.1))
    ax.set_xlabel('')
    ax.set_ylabel('Average Accuracy', fontweight='bold')
    sns.despine(ax=ax)

def plot_confidence_boxplot_ax(ax, average_data, p_value_conf):
    """
    Draw the per-group average-confidence boxplot on `ax`.

    Individual values are overlaid as a swarm plot, and a significance
    bracket labelled from `p_value_conf` spans the two groups.
    """
    shared = dict(x='patientstatus', y='average_confidence', data=average_data, ax=ax)
    sns.boxplot(palette={0: 'olive', 1: 'darkblue'}, width=0.3, **shared)
    sns.swarmplot(color='black', dodge=True, marker='o', size=5, **shared)

    # Leave headroom above the highest value; the bracket is drawn inside it.
    upper_limit = average_data['average_confidence'].max() + 10
    draw_significance_bracket(ax, 0, 1, upper_limit - 7,
                              get_significance_text(p_value_conf))

    # Axis cosmetics
    ax.set_ylim(0, upper_limit)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Controls', 'Patients'])
    ax.set_yticks(np.arange(0, 121, 20))
    ax.set_xlabel('')
    ax.set_ylabel('Average Confidence', fontweight='bold')
    sns.despine(ax=ax)

def plot_mean_var_ax(ax, summary_controls, summary_patients, ylabel, ylim, yticks, colors):
    """
    Plot group means with standard-error bars against relative trial on `ax`.

    Parameters:
        ax (matplotlib.axes.Axes): Axis to draw on.
        summary_controls, summary_patients: Summary tables providing
            'nTrial_rel', 'mean' and 'stderr' columns.
        ylabel (str): y-axis label.
        ylim: y-axis limits passed to ax.set_ylim.
        yticks: y tick positions.
        colors (dict): Colour per group key; key 0 is drawn as 'Controls',
            any other key as 'Patients'.
    """
    # Tag each summary with its group name and stack them into one table.
    tagged_controls = pd.DataFrame(summary_controls).assign(Group='Controls')
    tagged_patients = pd.DataFrame(summary_patients).assign(Group='Patients')
    df_combined = pd.concat([tagged_controls, tagged_patients], ignore_index=True)

    for group, color in colors.items():
        if group == 0:
            label = 'Controls'
        else:
            label = 'Patients'
        rows = df_combined[df_combined['Group'] == label]

        # Line, then markers (which carry the legend entry), then error bars.
        sns.lineplot(data=rows, x='nTrial_rel', y='mean', color=color, ax=ax)
        sns.scatterplot(data=rows, x='nTrial_rel', y='mean', color=color,
                        edgecolor='white', s=100, label=label, ax=ax)
        ax.errorbar(rows['nTrial_rel'], rows['mean'], yerr=rows['stderr'],
                    fmt='none', color=color, capsize=5)

    # Dashed vertical line at relative trial 0
    ax.axvline(0, color='black', linewidth=1, linestyle='--')

    ax.set_ylim(ylim)
    ax.set_yticks(yticks)
    trial_ticks = sorted(df_combined['nTrial_rel'].unique())
    ax.set_xticks(trial_ticks)
    ax.set_xlabel('Trial (0 = Shifts)', fontsize=14, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=14)
    ax.legend(title='Group', fontsize=12, title_fontsize=12, loc='lower left')
    sns.despine(ax=ax)

# =============================================================================
# 6. COMBINED FIGURE: BOXPLOTS & MEAN-OVER-TRIAL PLOTS
# =============================================================================

# 2x2 grid: boxplots in the left column, mean-over-trial curves in the right column.
fig, axes = plt.subplots(2, 2, figsize=(12, 12), gridspec_kw={'wspace': 0.3, 'hspace': 0.3})

# Top row: Accuracy
plot_accuracy_boxplot_ax(axes[0, 0], average_data, p_value_acc)
# The right-hand panel shows the mean itself ('mean_accuracy'), matching its
# label and its 0-1 y-range; the 'change_in_mean_*' deltas can be negative and
# are plotted separately with their own limits.
plot_mean_var_ax(axes[0, 1],
                 summary_controls_rel['mean_accuracy'],
                 summary_patients_rel['mean_accuracy'],
                 ylabel='Mean Accuracy',
                 ylim=(0, 1),
                 yticks=np.arange(0, 1.1, 0.2),
                 colors={0: 'olive', 1: 'darkblue'})

# Bottom row: Confidence
plot_confidence_boxplot_ax(axes[1, 0], average_data, p_value_conf)
plot_mean_var_ax(axes[1, 1],
                 summary_controls_rel['mean_confidence'],
                 summary_patients_rel['mean_confidence'],
                 ylabel='Mean Confidence',
                 ylim=(0, 115),
                 yticks=np.arange(0, 121, 20),
                 colors={0: 'olive', 1: 'darkblue'})

plt.tight_layout()
plt.show()

# =============================================================================
# 7. ADDITIONAL PLOTS: CHANGE IN MEAN VARIABLES (DELTA PLOTS)
# =============================================================================

def plot_change_in_mean_var(summary_controls, summary_patients, title, ydesc, colors, ylims):
    """
    Plot a variable's change (delta) per relative trial for controls and patients.

    A new 12x6 figure is created and shown. Each group gets a line, markers
    and standard-error bars; dashed reference lines mark y = 0 and x = 0.

    Parameters:
        summary_controls, summary_patients: Summary tables providing
            'nTrial_rel', 'mean' and 'stderr' columns.
        title (str): Figure title.
        ydesc (str): y-axis label.
        colors (dict): Colour for group 0 (controls) and group 1 (patients).
        ylims: y-axis limits passed to plt.ylim.
    """
    # Tag each summary with its group code and stack them into one table.
    tagged_controls = pd.DataFrame(summary_controls).assign(patientstatus=0)
    tagged_patients = pd.DataFrame(summary_patients).assign(patientstatus=1)
    df_combined = pd.concat([tagged_controls, tagged_patients], ignore_index=True)

    plt.figure(figsize=(12, 6))
    for group, legend_label in ((0, 'Controls'), (1, 'Patients')):
        group_color = colors[group]
        rows = df_combined[df_combined['patientstatus'] == group]

        sns.lineplot(data=rows, x='nTrial_rel', y='mean', color=group_color)
        sns.scatterplot(data=rows, x='nTrial_rel', y='mean', color=group_color,
                        edgecolor=group_color, s=100, label=legend_label)
        plt.errorbar(rows['nTrial_rel'], rows['mean'], yerr=rows['stderr'], fmt='o',
                     color=group_color, capsize=5)

    # Reference lines: zero change (horizontal) and relative trial 0 (vertical)
    plt.axhline(0, color='black', linewidth=1.2, linestyle='--')
    plt.axvline(0, color='black', linewidth=0.5, linestyle='--')

    plt.ylim(ylims)
    plt.title(title)
    plt.xlabel('Trial (0 = Shifts)')
    plt.ylabel(ydesc)
    plt.legend(title='Group')
    plt.tight_layout()
    plt.show()

# Delta plots, drawn in order: accuracy first, then confidence.
# Each entry: (measure column, figure title, y-axis label, group colours, y-limits)
_delta_plot_specs = (
    ('change_in_mean_accuracy', 'Change in Mean Accuracy Across Shifts',
     'Accuracy Delta', {0: 'olive', 1: 'darkblue'}, (-0.7, 0.4)),
    ('change_in_mean_confidence', 'Change in Mean Confidence Across Shifts',
     'Confidence Delta', {0: '#16463F', 1: '#2D80A7'}, (-19, 6)),
)

for _measure, _title, _ydesc, _colors, _ylims in _delta_plot_specs:
    plot_change_in_mean_var(
        summary_controls_rel[_measure],
        summary_patients_rel[_measure],
        title=_title,
        ydesc=_ydesc,
        colors=_colors,
        ylims=_ylims,
    )


Checking files in directory: data


FileNotFoundError: [Errno 2] No such file or directory: 'data'