# In [None]:
import scipy.io
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy.stats

# Function to create DataFrame with task data, significance, and subset labels
def create_dataframe(task_predictive, task_nonpredictive, sig_predictive, sig_nonpredictive, label):
    """Stack predictive and nonpredictive results into one long-format table.

    Rows for the predictive subset come first, followed by the rows for the
    nonpredictive subset.

    Parameters
    ----------
    task_predictive, task_nonpredictive : array-like
        Values for the predictive / nonpredictive subsets; stored in the
        ``Accuracy`` column.
    sig_predictive, sig_nonpredictive : array-like
        Significance flags for the same rows (the combined length must match
        the combined task length); stored in the ``Sig`` column.
    label : str
        Comparison name used to build the category labels ``'P <label>'``
        and ``'N <label>'``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Accuracy``, ``Sig``, ``Label`` and ``Color``. ``Color`` is
        ``'red'`` for rows whose ``Sig`` equals 1 and ``'blue'`` otherwise.
    """
    def _significance_colour(flag):
        # Only an exact match with 1 is highlighted; every other value is blue.
        return 'red' if flag == 1 else 'blue'

    columns = {
        'Accuracy': np.concatenate([task_predictive, task_nonpredictive]),
        'Sig': np.concatenate([sig_predictive, sig_nonpredictive]),
        'Label': ([f'P {label}'] * len(task_predictive)
                  + [f'N {label}'] * len(task_nonpredictive)),
    }
    frame = pd.DataFrame(columns)
    frame['Color'] = frame['Sig'].apply(_significance_colour)
    return frame

# Seaborn palette for the 'Color' hue column: each hue value is drawn in the
# matplotlib colour of the same name.
color_mapping = dict(red='red', blue='blue')

# Spatial task: build one table per comparison from the '...Spa' entries of
# `data`. NOTE(review): `data` is assumed to be the dict returned by
# scipy.io.loadmat in an earlier cell, and the trailing [0] presumably selects
# the single row of a 1xN MATLAB array -- TODO confirm.
df_start_from_start_spa = create_dataframe(
    data['subsetStartFromStartSpa'][0], data['notSubsetStartFromStartSpa'][0],
    data['subsetStartFromStartSpaSig'][0], data['notSubsetStartFromStartSpaSig'][0], 'Current')
df_goal_from_start_spa = create_dataframe(
    data['subsetGoalFromStartSpa'][0], data['notSubsetGoalFromStartSpa'][0],
    data['subsetGoalFromStartSpaSig'][0], data['notSubsetGoalFromStartSpaSig'][0], 'Prospective')
df_start_from_goal_spa = create_dataframe(
    data['subsetStartFromGoalSpa'][0], data['notSubsetStartFromGoalSpa'][0],
    data['subsetStartFromGoalSpaSig'][0], data['notSubsetStartFromGoalSpaSig'][0], 'Retrospective')
df_goal_from_goal_spa = create_dataframe(
    data['subsetGoalFromGoalSpa'][0], data['notSubsetGoalFromGoalSpa'][0],
    data['subsetGoalFromGoalSpaSig'][0], data['notSubsetGoalFromGoalSpaSig'][0], 'Current')

# One table per panel: start-arm activity (top) and goal-arm activity (bottom).
df_start_arm_activity_spa = pd.concat([df_start_from_start_spa, df_goal_from_start_spa], ignore_index=True)
df_goal_arm_activity_spa = pd.concat([df_start_from_goal_spa, df_goal_from_goal_spa], ignore_index=True)

# Explicit category order for BOTH panels. Relying on data order (as the top
# panel used to) silently drops a category whose subset is empty, shifting the
# remaining P/N columns and the tick count.
start_order_labels = ['P Current', 'N Current', 'P Prospective', 'N Prospective']
order_labels = ['P Current', 'N Current', 'P Retrospective', 'N Retrospective']

# Custom legend: fill colour encodes the significance flag (red when Sig == 1);
# the letter markers explain the P / N category prefixes.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', label='p < 0.05', markerfacecolor='red', markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='p ≥ 0.05', markerfacecolor='blue', markersize=10),
    plt.Line2D([0], [0], marker='$P$', color='black', markerfacecolor='black', markersize=10, lw=0,
               label='Predictive\nSubset'),
    plt.Line2D([0], [0], marker='$N$', color='black', markerfacecolor='black', markersize=10, lw=0,
               label='Nonpredictive\nSubset'),
]

fig, axes = plt.subplots(2, 1, figsize=(8, 12))

# Panel 1: start-arm activity, points coloured by significance.
sns.swarmplot(x='Label', y='Accuracy', hue='Color', data=df_start_arm_activity_spa, ax=axes[0],
              palette=color_mapping, size=14, dodge=False, order=start_order_labels)
# Titles, axis labels and x tick labels are intentionally left blank
# (presumably added when the figure is assembled).
axes[0].set_xticks(range(len(start_order_labels)))
axes[0].set_xticklabels([])
axes[0].set_xlabel('')
axes[0].set_ylabel('')
axes[0].set_ylim(bottom=0.4)
axes[0].tick_params(axis='both', which='major', labelsize=24)
axes[0].get_legend().remove()  # drop seaborn's automatic hue legend

# Panel 2: goal-arm activity, points coloured by significance.
sns.swarmplot(x='Label', y='Accuracy', hue='Color', data=df_goal_arm_activity_spa, ax=axes[1],
              palette=color_mapping, size=14, dodge=False, order=order_labels)
axes[1].set_xticks(range(len(order_labels)))
axes[1].set_xticklabels([])
axes[1].set_xlabel('')
axes[1].set_ylabel('')
axes[1].set_yticklabels([])
axes[1].set_ylim(bottom=0.4)
axes[1].tick_params(axis='both', which='major', labelsize=14)
# Replaces seaborn's automatic hue legend with the custom handles above.
axes[1].legend(handles=legend_elements, fontsize=12, loc='upper right')

plt.tight_layout()
plt.show()

# Cue task: build one table per comparison from the '...Cue' entries of
# `data`. NOTE(review): `data` is assumed to be the dict returned by
# scipy.io.loadmat in an earlier cell, and the trailing [0] presumably selects
# the single row of a 1xN MATLAB array -- TODO confirm.
df_start_from_start_cue = create_dataframe(
    data['subsetStartFromStartCue'][0], data['notSubsetStartFromStartCue'][0],
    data['subsetStartFromStartCueSig'][0], data['notSubsetStartFromStartCueSig'][0], 'Current')
df_goal_from_start_cue = create_dataframe(
    data['subsetGoalFromStartCue'][0], data['notSubsetGoalFromStartCue'][0],
    data['subsetGoalFromStartCueSig'][0], data['notSubsetGoalFromStartCueSig'][0], 'Prospective')
df_start_from_goal_cue = create_dataframe(
    data['subsetStartFromGoalCue'][0], data['notSubsetStartFromGoalCue'][0],
    data['subsetStartFromGoalCueSig'][0], data['notSubsetStartFromGoalCueSig'][0], 'Retrospective')
df_goal_from_goal_cue = create_dataframe(
    data['subsetGoalFromGoalCue'][0], data['notSubsetGoalFromGoalCue'][0],
    data['subsetGoalFromGoalCueSig'][0], data['notSubsetGoalFromGoalCueSig'][0], 'Current')

# One table per panel: start-arm activity (top) and goal-arm activity (bottom).
df_start_arm_activity_cue = pd.concat([df_start_from_start_cue, df_goal_from_start_cue], ignore_index=True)
df_goal_arm_activity_cue = pd.concat([df_start_from_goal_cue, df_goal_from_goal_cue], ignore_index=True)

# Explicit category order for BOTH panels. Relying on data order (as the top
# panel used to) silently drops a category whose subset is empty, shifting the
# remaining P/N columns and the tick count.
start_order_labels = ['P Current', 'N Current', 'P Retrospective', 'N Retrospective'][:0] or \
    ['P Current', 'N Current', 'P Prospective', 'N Prospective']
order_labels = ['P Current', 'N Current', 'P Retrospective', 'N Retrospective']

# Significance legend handles (red when Sig == 1). Defined once (the original
# built the same list twice); neither panel of this figure draws a legend, so
# this is kept only for re-enabling one.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', label='p < 0.05', markerfacecolor='red', markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='p ≥ 0.05', markerfacecolor='blue', markersize=10),
]

fig, axes = plt.subplots(2, 1, figsize=(8, 12))

# Panel 1: start-arm activity, points coloured by significance.
sns.swarmplot(x='Label', y='Accuracy', hue='Color', data=df_start_arm_activity_cue, ax=axes[0],
              palette=color_mapping, size=14, dodge=False, order=start_order_labels)
# Titles, axis labels and x tick labels are intentionally left blank
# (presumably added when the figure is assembled).
axes[0].set_xticks(range(len(start_order_labels)))
axes[0].set_xticklabels([])
axes[0].set_xlabel('')
axes[0].set_ylabel('')
axes[0].set_ylim(bottom=0.4)
axes[0].tick_params(axis='both', which='major', labelsize=24)
axes[0].get_legend().remove()  # drop seaborn's automatic hue legend

# Panel 2: goal-arm activity, points coloured by significance.
sns.swarmplot(x='Label', y='Accuracy', hue='Color', data=df_goal_arm_activity_cue, ax=axes[1],
              palette=color_mapping, size=14, dodge=False, order=order_labels)
axes[1].set_xticks(range(len(order_labels)))
axes[1].set_xticklabels([])
axes[1].set_xlabel('')
axes[1].set_ylabel('')
axes[1].set_ylim(bottom=0.4)
axes[1].set_yticklabels([])
axes[1].tick_params(axis='both', which='major', labelsize=14)
axes[1].get_legend().remove()  # drop seaborn's automatic hue legend

plt.tight_layout()
plt.show()