# In [None]:  (Jupyter cell prompt left over from the notebook export)
import scipy.io
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import scipy.stats 
import matplotlib.patches as mpatches

# Display names used on the figures, keyed by the raw names that appear in the
# variable names of `data`.
new_task_labels = dict(spa='Spatial', cue='Cue')
new_subset_labels = dict(Subset='Predictive', NotSubset='Nonpredictive')
new_grouping_labels = dict(
    StartDecStart='Current (Start Arm)',
    GoalDecGoal='Current (Goal Arm)',
    StartDecGoal='Prospective',
    GoalDecStart='Retrospective',
)

# Raw names used for data extraction. They are taken from the label dicts
# (dicts preserve insertion order) so the two can never drift apart.
tasks = list(new_task_labels)
subsets = list(new_subset_labels)
groupings = list(new_grouping_labels)

# Build a long-format DataFrame (one row per value) so seaborn can group the
# observations by task/subset and by decoding condition.
# NOTE(review): `data` is not defined in the visible part of this file --
# presumably it is loaded elsewhere (e.g. via scipy.io.loadmat); confirm.
frames = []
for task in tasks:
    for subset in subsets:
        for grouping in groupings:
            # Variables in `data` are named <task><subset><grouping>,
            # e.g. 'spaSubsetStartDecStart'.
            var_name = f'{task}{subset}{grouping}'
            frames.append(pd.DataFrame({
                # The scalar labels are broadcast to the length of the values
                'Task': f'{new_task_labels[task]} {new_subset_labels[subset]}',
                'Grouping': new_grouping_labels[grouping],
                'MeanBitsPerSpike': data[var_name].flatten(),
            }))

# Concatenate once at the end: growing `df` with pd.concat inside the loop
# re-copies every accumulated row on each iteration, and concatenating onto an
# empty DataFrame raises a FutureWarning on recent pandas versions.
df = pd.concat(frames, ignore_index=True)

# Left-to-right order of the task/subset categories on every x-axis
corrected_order = ['Spatial Predictive', 'Spatial Nonpredictive', 'Cue Predictive', 'Cue Nonpredictive']

# One panel per decoding condition, laid out 2x2 with a common y-axis
fig, axes = plt.subplots(2, 2, figsize=(15, 10), sharey=True)
axes_flat = axes.flatten()

for panel_idx, (ax, grouping) in enumerate(zip(axes_flat, new_grouping_labels.values())):
    # Rows belonging to this panel's decoding condition
    panel_data = df[df['Grouping'] == grouping]
    common_kwargs = dict(x='Task', y='MeanBitsPerSpike', data=panel_data,
                         order=corrected_order, ax=ax)

    # Box plot first (outlier markers hidden via fliersize=0), then the
    # individual points drawn on top of it.
    sns.boxplot(**common_kwargs, width=0.3, palette="Set2", fliersize=0)
    sns.swarmplot(**common_kwargs, size=1, color="k", alpha=0.5)

    ax.set_title(grouping, fontsize=36)
    ax.tick_params(axis='both', which='major', labelsize=14)

    # Only the bottom row carries x tick labels and an x-axis label
    in_bottom_row = panel_idx >= 2
    if in_bottom_row:
        ax.set_xlabel('Task,\nSubset', fontsize=30)
        ax.set_xticklabels(['Spatial,\nPredictive', 'Spatial,\nNonpredictive', 'Cue,\nPredictive', 'Cue,\nNonpredictive'])
    else:
        ax.set_xticklabels('')
        ax.set_xlabel('')

    # Only the left column carries a y-axis label
    in_left_column = panel_idx % 2 == 0
    if in_left_column:
        ax.set_ylabel('Mean Bits/Spike', fontsize=30)
    else:
        ax.set_ylabel('')

# Adjust layout to prevent overlap, then display the figure
plt.tight_layout()
plt.show()

def _mean_and_sem(values):
    """Return {'Mean': ..., 'SEM': ...} for the flattened contents of *values*.

    The SEM is computed with ``ddof=0`` (scipy.stats.sem defaults to
    ``ddof=1``), exactly as the original hand-written entries did.
    """
    flat = values.flatten()
    return {'Mean': np.mean(flat), 'SEM': scipy.stats.sem(flat, ddof=0)}


# Mean and SEM of every <task><subset><grouping> variable in `data`, keyed by
# the variable name. Generated instead of spelled out 16 times so that each
# array is flattened once and a key can never be paired with the wrong array.
# The loop nesting reproduces the insertion order of the original literal.
means_sems = {
    key: _mean_and_sem(data[key])
    for key in (
        f'{task}{subset}{grouping}'
        for task in ('spa', 'cue')
        for grouping in ('StartDecStart', 'StartDecGoal', 'GoalDecStart', 'GoalDecGoal')
        for subset in ('Subset', 'NotSubset')
    )
}

def check_sem_overlap_modified(means_sems, keys):
    """Return the adjacent key pairs whose mean +/- 1 SEM intervals overlap.

    Only the non-overlapping consecutive pairs (keys[0], keys[1]),
    (keys[2], keys[3]), ... are compared; a trailing unpaired key is ignored.
    Intervals that merely touch count as overlapping.

    Args:
        means_sems: mapping from key to a dict with 'Mean' and 'SEM' entries.
        keys: sequence of keys into ``means_sems``.

    Returns:
        A list of ``(key1, key2)`` tuples, in the order they were compared.
    """
    def sem_interval(key):
        # (lower, upper) bounds of mean +/- 1 SEM for this key
        entry = means_sems[key]
        centre, spread = entry['Mean'], entry['SEM']
        return centre - spread, centre + spread

    overlapping = []
    for first, second in zip(keys[0::2], keys[1::2]):
        low_a, high_a = sem_interval(first)
        low_b, high_b = sem_interval(second)
        if low_b <= high_a and low_a <= high_b:
            overlapping.append((first, second))
    return overlapping

# Keys of `means_sems` shown in each panel of the 2x2 figure, indexed by
# (row, column). Every panel holds one decoding condition, with its four
# task/subset variables in the order: spa Subset, spa NotSubset, cue Subset,
# cue NotSubset.
plot_layout_2x2 = {
    position: [
        f'{task}{subset}{grouping}'
        for task in ('spa', 'cue')
        for subset in ('Subset', 'NotSubset')
    ]
    for position, grouping in (
        ((0, 0), 'StartDecStart'),
        ((0, 1), 'GoalDecGoal'),
        ((1, 0), 'StartDecGoal'),
        ((1, 1), 'GoalDecStart'),
    )
}

# Create the 2x2 subplots
fig, axs = plt.subplots(2, 2, figsize=(14, 12))

# Legend entry for the red dashed connectors drawn between overlapping pairs
red_patch = mpatches.Patch(color='red', label='Overlapping Pairs')

# Plot each panel's mean +/- 1 SEM values on a fixed y-range
for (i, j), keys in plot_layout_2x2.items():
    ax = axs[i, j]
    for idx, key in enumerate(keys):
        mean, sem = means_sems[key]['Mean'], means_sems[key]['SEM']
        ax.errorbar(x=idx, y=mean, yerr=sem, fmt='o', color='blue')
        # Hand-drawn horizontal caps at mean + SEM and mean - SEM
        for cap_y in (mean + sem, mean - sem):
            ax.plot([idx - 0.2, idx + 0.2], [cap_y, cap_y], color='blue')

    # Comment out the two lines below for automatic y axes, which can make
    # overlaps easier to see.
    ax.set_ylim(0.75, 1.75)
    ax.set_yticks(np.arange(0.75, 2.00, 0.25))
    ax.grid(True, which='major', linestyle='--', linewidth=0.5, color='grey')
    ax.set_xticks(range(len(keys)))
    ax.set_xticklabels(keys, rotation=45, ha='right')

    # Title the panel with the decoding condition shared by its keys. The
    # keys contain no spaces, so splitting keys[0] on ' ' (as before) just
    # produced the first raw variable name. Falls back to that raw name if no
    # known grouping suffix matches.
    panel_title = next(
        (label for raw, label in new_grouping_labels.items() if keys[0].endswith(raw)),
        keys[0],
    )
    ax.set_title(panel_title)
    ax.set_ylabel('Mean Bits/Spike')

    # Join the means of each adjacent pair whose +/- 1 SEM intervals overlap
    for key1, key2 in check_sem_overlap_modified(means_sems, keys):
        ax.plot([keys.index(key1), keys.index(key2)],
                [means_sems[key1]['Mean'], means_sems[key2]['Mean']],
                linestyle='--', color='red')

# Figure-level title and legend: plt.legend() would attach the legend to the
# last-created axes (bottom-right panel) rather than to the whole figure.
fig.suptitle('Means and +/- 1 SEMs with Overlapping Indications')
fig.legend(handles=[red_patch], loc='upper right')
# Reserve the top strip of the figure for the suptitle and legend
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()