# Figure 2

In [None]:
import copy
import itertools
import json
import os
import pickle
from IPython import display as ipd
from datetime import datetime

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from scipy import stats
from statannot import add_stat_annotation

from sylseq_paper.file_utils import loadmat
from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors, fancy_location_colors
from sylseq_paper.statistics import fdr_omitnans, p_value_calc

default_plot_settings(font='Helvetica', fontsize=12)

%load_ext autoreload
%autoreload 2

# Path setup

In [None]:
def _sibling_dir(name):
    """Absolute path of directory `name` located next to the current working directory."""
    return os.path.abspath(os.path.join(os.getcwd(), '..', name))


data_dir = _sibling_dir('data')
img_dir = _sibling_dir('imaging')

# Line width of the significance brackets drawn with statannot.
annot_lw = 0.5

# [upper p-value bound, label] pairs handed to statannot's `pvalue_thresholds`,
# ordered from most to least stringent; the bound of 10 is a catch-all for "ns".
pvalue_thresholds = [
    [0.001, "***"],
    [0.01, "**"],
    [0.05, "*"],
    [10, "ns"],
]

In [None]:
# Per-alignment epoch definitions:
#   window         -- (start, stop) in seconds around the alignment event
#   num_timepoints -- how many columns of the decoding traces belong to this window
#                     (the traces of all alignments are concatenated in this order)
#   phase          -- task event the window is locked to
alignments = dict(
    target_presentation=dict(window=[0.0, 2.5 + 0.75], num_timepoints=32, phase='target_presentation'),
    go_cue=dict(window=[-0.5, 0.75], num_timepoints=13, phase='go_cue'),
    speech=dict(window=[-0.75, 0.75], num_timepoints=16, phase='first_syllable'),
)

# Display names for task phases.
alignment_labels = dict(
    target_presentation='Encoding',
    delay='Delay',
    pre_exec='Pre-speech',
    speech='Speech production',
)

# NOTE(review): `sr` is presumably the sampling rate in Hz; it is not referenced in the visible notebook.
sr = 100
# Significance level used for the FDR-corrected growth-curve comparisons.
sig_thresh = 0.05

# Subjects whose electrodes are included in the figures.
plot_subjects = [
    'EC217', 'EC219', 'EC223', 'EC237', 'EC240',
    'EC241', 'EC253', 'EC254', 'EC260', 'EC263',
]

# Filename templates, filled as .format(hemi, view) and .format(hemi, view, area) respectively.
mni_img_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D.png')
mni_mask_path = os.path.join(img_dir, 'masks_and_rois', 'MNI_{}_{}_brain_2D_{}_mask.png')

In [None]:
# Renaming `SMA` label to `medial SFG`
# {label column: {old label: new label}}
renamed_areas = {
    'location': {
        'supplementarymotor': 'medialsuperiorfrontal'
    },
    'fancy_location': {
        'SMA': 'medial SFG'
    }
}

def rename_df_areas(df, renames=None):
    """Return a copy of `df` with anatomical area labels renamed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every column named in `renames` (by default the
        ``location`` and ``fancy_location`` columns).
    renames : dict, optional
        ``{column: {old_label: new_label}}``. Defaults to ``renamed_areas``.

    Returns
    -------
    pandas.DataFrame
        A new frame with the labels replaced. The input frame is not modified,
        so re-running the calling cell is idempotent.
    """
    if renames is None:
        renames = renamed_areas
    out = df.copy()
    for label_type, mapping in renames.items():
        for old_label, new_label in mapping.items():
            values = out[label_type].values
            out[label_type] = np.where(values == old_label, new_label, values)
    return out

In [None]:
# Electrode anatomy table: area labels renamed, then restricted to the plotted subjects.
edf = rename_df_areas(pd.read_hdf(os.path.join(data_dir, 'all_anatomical_info.h5')))
edf = edf.loc[edf.subject.isin(plot_subjects)]
# Every plotted subject must have at least one electrode in the table.
assert np.isin(plot_subjects, edf.subject.unique()).all()

# Bad-channel lists (pickle executes code on load -- only use with trusted files).
# NOTE(review): `bad_channels` is not referenced in the visible notebook.
with open(os.path.join(data_dir, 'bad_channels.pkl'), 'rb') as fh:
    bad_channels = pickle.load(fh)

# Load data

## Cluster data

In [None]:
# Sustained electrodes, as `subject_electrode` identifiers (matched against edf.subject_electrode later).
sustained_se_path = os.path.join(data_dir, 'fig1_ndf_sustained_se.pkl')
with open(sustained_se_path, 'rb') as fh:
    sustained_se = pickle.load(fh)['sustained_se']

## Single electrode statistics

In [None]:
def _collect_single_elec_stats():
    """Gather per-electrode task-phase statistics for both electrode groups.

    Reads the sustained and the non-sustained `.mat` results (in that order)
    and returns one row per electrode with its group label, chi-square
    statistic and two FDR significance flags.
    """
    columns = {key: [] for key in ['sustained', 'chi_square', 'chisquare_significant', 'allpair_significant']}
    sources = [
        ('SylSeq_TaskPhase_SingleElec_Sustained.mat', 'Sustained'),
        ('SylSeq_TaskPhase_SingleElec_NonSustained.mat', 'Non-sustained'),
    ]
    for mat_name, group_label in sources:
        group_stats = loadmat(os.path.join(data_dir, mat_name))
        chi_square = group_stats['chi_square']
        columns['sustained'] += [group_label] * len(chi_square)
        columns['chi_square'] += list(chi_square)
        # NOTE(review): the "chi-square significant" flag is taken from `sig_any_pair_FDR`.
        columns['chisquare_significant'] += list(np.where(group_stats['sig_any_pair_FDR'], True, False))
        columns['allpair_significant'] += list(np.where(group_stats['sig_all_pair_FDR'], True, False))
    return pd.DataFrame(data=columns)


single_elec_stats = _collect_single_elec_stats()

In [None]:
def friedman_test(ax=None, legend_fs=10, fs=12, s=3, ms=7, data=None):
    """Plot per-electrode chi-square statistics for sustained vs. non-sustained electrodes.

    Each group gets a hollow box plot (no fliers) overlaid with three
    strip-plot layers:

    * not significant (``chisquare_significant`` False): black circles,
    * significant but not on all post-hoc pairs: ``colors[0]`` circles,
    * significant on all post-hoc pairs: ``colors[1]`` triangles.

    The two groups are compared with a Wilcoxon rank-sum test
    (``scipy.stats.ranksums``). The statistic and p-value are printed, and the
    p-value is drawn as a significance bracket via ``add_stat_annotation``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 5 x 4 in figure is created when omitted.
    legend_fs : float
        Legend font size.
    fs : float
        Font size of the x tick labels.
    s : float
        Strip-plot marker size.
    ms : float
        Marker size of the hand-built legend handles.
    data : pandas.DataFrame, optional
        Table with ``sustained`` ('Sustained' / 'Non-sustained'),
        ``chi_square``, ``chisquare_significant`` and ``allpair_significant``
        columns. Defaults to the notebook-level ``single_elec_stats``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.
    """
    if data is None:
        data = single_elec_stats
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))

    group_order = ['Sustained', 'Non-sustained']

    sns.despine(ax=ax, offset=dict(left=2.5, bottom=2.5))
    ax.spines['bottom'].set_visible(False)

    box_width = 0.5
    jitter = box_width / 2  # spread strip points across the full box width

    # Unclipped box-plot artists; hollow boxes so the strip points remain visible.
    boxplot_kwargs = {f'{key}props': dict(clip_on=False)
                      for key in ['cap', 'whisker', 'flier', 'median', 'mean']}
    boxplot_kwargs['boxprops'] = dict(alpha=1, clip_on=False, facecolor='none')

    ax = sns.boxplot(data=data, x='sustained', y='chi_square',
                     color='k', width=box_width, zorder=1,
                     order=group_order, showfliers=False, ax=ax, **boxplot_kwargs)

    chi_sig = data.chisquare_significant
    all_sig = data.allpair_significant

    # Non-significant chi square
    sns.stripplot(data=data.loc[~chi_sig], x='sustained', y='chi_square',
                  zorder=2, order=group_order, color='k', jitter=jitter, label='Not sig.',
                  ax=ax, size=s, clip_on=False, alpha=0.3)

    # Significant chi square but not all sig posthocs
    sns.stripplot(data=data.loc[chi_sig & ~all_sig],
                  x='sustained', y='chi_square', jitter=jitter, label='Sig. chi-square',
                  zorder=2, order=group_order, color=colors[0],
                  ax=ax, size=s, clip_on=False, alpha=0.4)

    # Significant chi square and all sig posthocs
    sns.stripplot(data=data.loc[chi_sig & all_sig],
                  x='sustained', y='chi_square', jitter=jitter, label='Sig. chi-square &\npost-hoc tests',
                  zorder=3, order=group_order, color=colors[1], marker='^',
                  ax=ax, size=s, clip_on=False, alpha=0.8)

    # Between-group comparison of the chi-square statistics.
    sustained_chi = data.loc[data.sustained == 'Sustained'].chi_square.values
    nonsustained_chi = data.loc[data.sustained == 'Non-sustained'].chi_square.values
    zval, pval = stats.ranksums(sustained_chi, nonsustained_chi)
    print(f'Rank-sum test (Sustained vs. Non-sustained): z = {zval:.3f}, p = {pval:.3g}')

    add_stat_annotation(ax, data=data, x='sustained', y='chi_square',
                        perform_stat_test=False, text_format='star', box_pairs=[('Sustained', 'Non-sustained')],
                        loc='outside', verbose=1, pvalues=[pval], linewidth=annot_lw, pvalue_thresholds=pvalue_thresholds)

    ax.axes.set(ylim=(0.0, 250), ylabel='Chi-square statistic', xlabel='')
    ax.xaxis.set_tick_params(length=0)

    # The strip plots' own legend entries are replaced by these proxy handles.
    s1 = mpl.lines.Line2D([], [], color='k', marker='o', linestyle='None',
                          markersize=ms, label='Not sig.', alpha=0.5)
    s2 = mpl.lines.Line2D([], [], color=colors[0], marker='o', linestyle='None',
                          markersize=ms, label='Any pair sig.', alpha=0.5)
    s3 = mpl.lines.Line2D([], [], color=colors[1], marker='^', linestyle='None',
                          markersize=ms, label='All pairs sig.', alpha=0.5)

    ax.legend(handles=[s1, s2, s3], frameon=False, bbox_to_anchor=(0.85, 0),
              loc='lower left', fontsize=legend_fs, handletextpad=0.5)

    ax.set_xticklabels(['Sustained', 'Non-\nsustained'], fontsize=fs)

    return ax

In [None]:
# Stand-alone preview of the single-electrode panel at default sizes.
friedman_test(ax=None);

## Predicted task-phase probability traces

In [None]:
num_phases = 4

# (phase name, plot colour) in phase-index order: index i selects row i of the
# decoding probability traces.
_phase_table = [
    ('encoding', '#aea2cd'),  # fixed hex used instead of colors[4]
    ('delay', colors[0]),
    ('pre-speech', colors[3]),
    ('speech prod.', colors[2]),
]
phase_colors = {name: color for name, color in _phase_table}
phase_names_to_idx = {name: idx for idx, (name, _) in enumerate(_phase_table)}
# Reverse lookup keyed by the index as a string.
phase_idx_to_names = {str(idx): name for name, idx in phase_names_to_idx.items()}
del _phase_table

# Decoding results for the sustained electrode group.
with open(os.path.join(data_dir, 'fig2b_decoding.pkl'), 'rb') as fh:
    sustained_decoding = pickle.load(fh)

In [None]:
def phase_probability(axs=None, decoding_results=None, legend_fs=10, bbox_height=-0.5):
    """Plot decoded task-phase probabilities over time, one panel per alignment.

    The panels follow the order of ``alignments`` (target presentation, go cue,
    speech onset). In each panel, every phase's mean probability trace is drawn
    with a +/- ``tp_prob_ts_std`` band in its ``phase_colors`` colour.

    Parameters
    ----------
    axs : array-like of matplotlib Axes, optional
        One Axes per alignment. Any shape is accepted; it is flattened in
        C order. A 1 x 3 figure is created when omitted.
    decoding_results : dict
        Must provide ``tp_prob_ts_mean`` and ``tp_prob_ts_std``, indexed as
        ``[phase, timepoint]``, with the timepoints of all alignments
        concatenated in ``alignments`` order.
    legend_fs : float
        Legend font size.
    bbox_height : float
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    The `axs` that was drawn on (as passed in, or the newly created array).

    Raises
    ------
    ValueError
        If `decoding_results` is not given.
    """
    if decoding_results is None:
        raise ValueError('phase_probability requires `decoding_results`')

    if axs is None:
        _, axs = plt.subplots(1, 3, figsize=(12, 3))
    # Flatten once so every loop and index below works for 1-D and 2-D axes arrays alike.
    axes_flat = np.asarray(axs).ravel()

    prev_timepoint = 0
    time_x = []
    for ax, align_key in zip(axes_flat, alignments.keys()):
        window = alignments[align_key]['window']
        num_timepoints = alignments[align_key]['num_timepoints']
        # Columns of the concatenated traces that belong to this alignment.
        segment = slice(prev_timepoint, prev_timepoint + num_timepoints)
        y = decoding_results['tp_prob_ts_mean'][:, segment]
        err = decoding_results['tp_prob_ts_std'][:, segment]
        x = np.linspace(*window, num_timepoints)
        time_x.append(x)

        for cur_phase, phase_label in phase_idx_to_names.items():
            row = int(cur_phase)
            ax.plot(x, y[row, :], color=phase_colors[phase_label],
                    label=phase_label.capitalize(), zorder=1, clip_on=False)
            ax.fill_between(x, y[row, :] - err[row, :], y[row, :] + err[row, :],
                            color=phase_colors[phase_label],
                            zorder=1, clip_on=False, alpha=0.3)

        prev_timepoint += num_timepoints
        ax.axes.set(xlim=window)

    axes_flat[-1].legend(loc='upper left', bbox_to_anchor=(1, 1), frameon=False, fontsize=legend_fs)

    ylim = [0, 1]
    for i, (a, t) in enumerate(zip(axes_flat, time_x)):

        a.axes.set(xticks=np.arange(t[0], t[-1]+0.03, 0.25), ylim=ylim, xlabel='Time (s)')
        a.axhline(y=0, color='k', alpha=0.75, linestyle='--', zorder=0)
        # Event marker at t = 0 (target onset, go cue, speech onset).
        a.axvline(x=0, color='k', alpha=0.75, linestyle='--', zorder=0, clip_on=False)

        if i == 0:
            # The encoding panel also marks the fixation cross at 2.5 s.
            a.axvline(x=2.5, color='k', alpha=0.75, linestyle='--', zorder=0, clip_on=False)
            sns.despine(ax=a, offset=dict(left=5, bottom=5))
            # Major ticks every 1 s, labelled as integers.
            a.xaxis.set_major_locator(MultipleLocator(1))
            a.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        else:
            # Later panels use the same [0, 1] y range, so their y axis is hidden.
            a.spines['left'].set_visible(False)
            a.axes.set(yticks=[])
            sns.despine(ax=a, left=True, offset=dict(bottom=5, left=5))
            a.xaxis.set_major_locator(MultipleLocator(0.5))

        a.xaxis.set_minor_locator(MultipleLocator(0.25))

    axes_flat[0].annotate('Target pres.', (-0.1, 1), ha='left', va='bottom', annotation_clip=False)
    axes_flat[0].annotate('Fix. cross', (2.5, 1), ha='center', va='bottom', annotation_clip=False)
    axes_flat[1].annotate('Go-cue', (0, 1), ha='center', va='bottom', annotation_clip=False)
    axes_flat[2].annotate('Speech onset', (0, 1), ha='center', va='bottom', annotation_clip=False)
    axes_flat[0].axes.set(ylabel='Probability', yticks=[0, 0.5, 1])

    return axs

In [None]:
# Preview of the phase-probability traces for the sustained electrodes.
phase_probability(axs=None, decoding_results=sustained_decoding);

## Growth curve

In [None]:
# Growth-curve results: decoding accuracy as a function of the number of electrodes.
gc_path = os.path.join(data_dir, 'fig2_growth_curve_data.pkl')
with open(gc_path, 'rb') as fh:
    gc_data = pickle.load(fh)

In [None]:
def growth_curve(ax=None, ms=5, fs=10, sig_fs=15, legend_fs=10, all_subgroups=False, return_fig=False,
                 return_pvals=False, num_comparisons=9):
    """Plot decoding accuracy against the number of electrodes (log x axis).

    One marker line with error bars (``c_rate_iter_mean`` +/- ``c_rate_iter_std``)
    is drawn per electrode group.

    In the default mode, the sustained and non-sustained groups are compared at
    each of the first `num_comparisons` electrode counts:

    1. column-wise differences of ``c_rate_iter_draw_mean_sus`` and
       ``c_rate_iter_draw_mean_nonsus`` are turned into p-values with ``p_value_calc``;
    2. the p-values are FDR-corrected with ``fdr_omitnans``;
    3. an asterisk is drawn above each count whose corrected value is below ``sig_thresh``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 5 x 4 in figure is created when omitted.
    ms : float
        Marker size.
    fs : float
        Currently unused; kept for call compatibility.
    sig_fs : float
        Font size of the significance asterisks.
    legend_fs : float
        Font size of the legend and tick labels.
    all_subgroups : bool
        If True, also plot the vPrCG, pSTG, SMG and medial SFG subgroups and
        skip the significance test.
    return_fig : bool
        Return ``(fig, ax)``. Takes precedence over `return_pvals`. Works
        whether or not `ax` was supplied.
    return_pvals : bool
        Return ``(ax, fdr_pvals)``. `fdr_pvals` is None when `all_subgroups`
        is True, because no test is run in that mode.
    num_comparisons : int
        Number of leading columns of the per-draw accuracy arrays to test.
        Defaults to 9, the value previously hard-coded.

    Returns
    -------
    ``(fig, ax)``, ``(ax, fdr_pvals)`` or ``ax``, depending on the flags above.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 4))
    else:
        # Needed so `return_fig` also works when the caller supplies the axes.
        fig = ax.figure

    group_order = ['Sustained', 'Non-sustained', 'Sustained (mPrCG)', 'Sustained (IFG)']
    if all_subgroups:
        group_order += ['Sustained (vPrCG)', 'Sustained (pSTG)', 'Sustained (SMG)', 'Sustained (medial SFG)']

    markers = {
        'Sustained': 'o', 'Non-sustained': '^', 'Sustained (mPrCG)': 's', 'Sustained (IFG)': 'v',
        'Sustained (vPrCG)': '>', 'Sustained (pSTG)': 'x', 'Sustained (SMG)': 'p', 'Sustained (medial SFG)': 'P',
    }
    # Subgroup -> key into `fancy_location_colors`; only looked up for groups that are plotted.
    location_color_keys = {
        'Sustained (mPrCG)': 'middle PrCG',
        'Sustained (IFG)': 'Pars operc.',
        'Sustained (vPrCG)': 'ventral PrCG',
        'Sustained (pSTG)': 'pSTG',
        'Sustained (SMG)': 'SMG',
        'Sustained (medial SFG)': 'medial SFG',
    }
    # gc_data still stores the medial SFG subgroup under its old SMA name.
    data_labels = {'Sustained (medial SFG)': 'Sustained (SMA)'}

    for group_label in group_order:
        i = np.where(gc_data['legend_str'] == data_labels.get(group_label, group_label))[0][0]

        x = gc_data['N_ch_actual'][i]
        y = gc_data['c_rate_iter_mean'][i]
        err = gc_data['c_rate_iter_std'][i]

        if group_label == 'Sustained':
            color = colors[5]
        elif group_label == 'Non-sustained':
            color = colors[7]
        else:
            color = fancy_location_colors[location_color_keys[group_label]]

        ax.semilogx(x, y, marker=markers[group_label], color=color, ms=ms,
                    label=group_label, clip_on=False, alpha=0.7)
        ax.errorbar(x, y, err, color=color, clip_on=False, alpha=0.7)

    fdr_pvals = None
    if not all_subgroups:
        all_pvals = []
        for col in range(num_comparisons):
            diff = gc_data['c_rate_iter_draw_mean_sus'][:, col] - gc_data['c_rate_iter_draw_mean_nonsus'][:, col]
            all_pvals.append(p_value_calc(diff, test_statistic=0))

        fdr_pvals = fdr_omitnans(np.array(all_pvals),
                                 alpha=sig_thresh,
                                 method='poscorr',
                                 is_sorted=False)

        # Star each electrode count whose corrected value passes the threshold.
        for x_pos, p_corr in zip(gc_data['N_ch_actual'][0], fdr_pvals):
            if p_corr < sig_thresh:
                ax.annotate('*', (x_pos, 1), ha='center', va='bottom', fontsize=sig_fs)

    sns.despine(ax=ax, offset=dict(left=5, bottom=5))

    xticks = [1, 2, 5, 10, 20, 40, 80, 160, 320, 901]
    yticks = [0.25, 0.50, 0.75, 1.00]
    ax.axes.set(xticks=xticks, xlim=(1, 901), ylim=(0.25, 1.0),
                yticks=yticks,
                ylabel='Accuracy', xlabel='Number of electrodes')
    ax.set_xticklabels(xticks, fontsize=legend_fs)
    ax.set_yticklabels([f'{l:0.2f}' for l in yticks], fontsize=legend_fs)

    ax.legend(loc='lower right', bbox_to_anchor=(1.03, -0.05), frameon=False, fontsize=legend_fs)

    if return_fig:
        return fig, ax
    if return_pvals:
        return ax, fdr_pvals
    return ax

In [None]:
# Preview of the growth curve; the FDR-corrected values are kept in `pvals` for inspection.
_, pvals = growth_curve(all_subgroups=False, legend_fs=10, return_pvals=True);

## State space

In [None]:
def _load_fig2_pickle(fname):
    """Unpickle `fname` from `data_dir` (pickle runs code on load; trusted files only)."""
    with open(os.path.join(data_dir, fname), 'rb') as fh:
        return pickle.load(fh)


# State-space (PCA) results for the electrode groups and regions.
stsp = _load_fig2_pickle('fig2_state_space.pkl')
# Region sub-projections of the sustained state space.
stsp_proj = _load_fig2_pickle('fig2_state_space_projections.pkl')

In [None]:
# Per-group state-space summary: `p_angle`, then adjusted R^2 entries 0 and 3. Presumably these
# are the encoding and speech-production plane fits, matching `p_coeff` rows 0/3 used in
# state_space -- confirm.
# Rows 0/1 are the sustained/non-sustained groups; the regions follow.
group_names = ['sustained', 'nonsustained', *stsp['regions']]
for idx, group in enumerate(group_names):
    print(group, stsp['p_angle'][idx], stsp['r_sq_adj'][0][idx], stsp['r_sq_adj'][3][idx])

In [None]:
# Variance explained by the first three PCs, summed, for entry 0 of `var_exp`
# (presumably the sustained group, following the row order used above -- confirm).
stsp['var_exp'][0][0:3].sum()

In [None]:
# Same summary for the sub-projection analysis (`stsp_proj`).
proj_group_names = ['sustained', 'nonsustained', *stsp_proj['regions']]
for row, group in enumerate(proj_group_names):
    print(group, stsp_proj['p_angle'][row], stsp_proj['r_sq_adj'][0][row], stsp_proj['r_sq_adj'][3][row])

In [None]:
# Hand-set 3-D camera angles (degrees) per state-space group:
#                 (azimuth, elevation, plane azimuth, plane elevation)
_view_angles = {
    'sustained':    (-150,  55,  -14, 25),
    'nonsustained': (-120,   0,  -16, 29),
    'mprcg':        (-142, 109,  152, 26),
}
# Trajectory views (used by state_space when planes are not drawn).
azimuths = {group: angles[0] for group, angles in _view_angles.items()}
elevations = {group: angles[1] for group, angles in _view_angles.items()}
# NOTE(review): the plane views are not referenced in the visible notebook;
# state_space reads stsp['view_angles_plane'] instead.
plane_azimuths = {group: angles[2] for group, angles in _view_angles.items()}
plane_elevations = {group: angles[3] for group, angles in _view_angles.items()}
del _view_angles

In [None]:
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [None]:
def state_space(region=None, ax=None, plot_title=True, fs=10, plot_planes=False, corner_axis=False, up_to_frame=None, 
                return_fig=False, plot_subprojection=False, alpha=1.0, s=10, trajectory_color=None, zorder=1,
                label=None, plot_legend=False):
    """Scatter a 3-D PCA state-space trajectory, coloured by task phase.

    If plotting a subprojection, the axis must be provided. Plotting
    subprojections and planes does not work together.

    Parameters
    ----------
    region : {'sustained', 'nonsustained', 'mprcg'}
        Group to plot. Only 'mprcg' is supported with `plot_subprojection`.
    ax : mpl_toolkits.mplot3d.Axes3D, optional
        3-D axes to draw on. A new 4 x 4 in figure with
        ``computed_zorder=False`` is created when omitted.
    plot_title : bool
        Set the group name as the axes title (font size `fs`).
    fs : float
        Title and legend font size.
    plot_planes : bool
        Draw the encoding and speech-production plane fits (``p_coeff`` rows 0
        and 3) and view from ``stsp['view_angles_plane']``. Otherwise the view
        comes from the ``elevations``/``azimuths`` tables.
    corner_axis : bool
        Currently unused.
    up_to_frame : int, optional
        Plot only the first `up_to_frame` points. Axis limits are still taken
        from the full trajectory.
    return_fig : bool
        Return ``(fig, ax)`` instead of ``ax``. Works whether or not `ax` was supplied.
    plot_subprojection : bool
        Plot the mPrCG sub-projection from ``stsp_proj``, scaled by 3.25.
        Limits, panes and view are left as set by a previous call.
    alpha, s, zorder, label :
        Passed to ``ax.scatter``.
    trajectory_color : colour, optional
        Single colour for all points instead of per-phase colours.
    plot_legend : bool
        Draw a two-column legend.

    Returns
    -------
    ``ax``, or ``(fig, ax)`` if `return_fig`.

    Raises
    ------
    ValueError
        If `region` is not supported for the requested mode.
    """
    supported = ('mprcg',) if plot_subprojection else ('sustained', 'nonsustained', 'mprcg')
    if region not in supported:
        raise ValueError(f'region must be one of {supported} (plot_subprojection={plot_subprojection}), '
                         f'got {region!r}')

    if ax is None:
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot(111, projection='3d', computed_zorder=False)
    else:
        # Needed so `return_fig` also works when the caller supplies the axes.
        fig = ax.figure

    if plot_subprojection:
        # Region rows follow the two group rows (sustained, non-sustained), hence +2.
        region_idx = np.where(stsp_proj['regions'] == 'middleprecentral')[0][0] + 2
        title = 'mPrCG'
        # Fixed scale factor applied to the projected coordinates.
        scale_adjustment = 3.25
        pcs = stsp_proj['X_pca_all'][region_idx]
        x, y, z = (scale_adjustment * pcs[:, k] for k in range(3))
        phase_labels = stsp_proj['phase_label'][region_idx]

    else:
        ax.grid(False)
        # Transparent panes.
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))

        if region == 'sustained':
            region_idx, title = 0, 'Sustained'
        elif region == 'nonsustained':
            region_idx, title = 1, 'Non-sustained'
        else:  # 'mprcg'
            region_idx = np.where(stsp['regions'] == 'middleprecentral')[0][0] + 2
            title = 'mPrCG'

        pcs = stsp['X_pca_all'][region_idx]
        x, y, z = pcs[:, 0], pcs[:, 1], pcs[:, 2]
        # Limits come from the full trajectory, before any `up_to_frame` truncation.
        xlim = (min(x), max(x))
        ylim = (min(y), max(y))
        zlim = (min(z), max(z))

        phase_labels = stsp['phase_label'][region_idx]

    if trajectory_color is None:
        # Phase labels are offset by one from the 0-based keys of `phase_idx_to_names`.
        pcolors = [phase_colors[phase_idx_to_names[str(p - 1)]] for p in phase_labels]
    else:
        pcolors = [trajectory_color] * len(phase_labels)

    if up_to_frame is not None:
        x = x[:up_to_frame]
        y = y[:up_to_frame]
        z = z[:up_to_frame]
        pcolors = pcolors[:up_to_frame]

    ax.scatter(x, y, z, c=pcolors, clip_on=False, zorder=zorder, alpha=alpha, s=s, label=label)

    if not plot_subprojection:
        ax.axes.set(xlim=xlim, ylim=ylim, zlim=zlim)

    if plot_planes:
        ax.view_init(elev=stsp['view_angles_plane'][region_idx, 1], azim=stsp['view_angles_plane'][region_idx, 0]-90)

        X, Y = np.meshgrid(np.arange(min(x), max(x), 0.1), np.arange(min(y), max(y), 0.1))
        # Plane z = a*x + b*y + c for the encoding (row 0) and speech-production (row 3) fits.
        for phase_row, phase_name in ((0, 'encoding'), (3, 'speech prod.')):
            coeff = stsp['p_coeff'][phase_row, region_idx]
            plane_Z = coeff.a * X + coeff.b * Y + coeff.c
            ax.plot_surface(X, Y, plane_Z, color=phase_colors[phase_name], alpha=0.5, zorder=1)

    elif not plot_subprojection:
        ax.view_init(elev=elevations[region], azim=azimuths[region])

    ax.axes.set(xticks=[], yticks=[], zticks=[])
    ax.set_xlabel('PC1', labelpad=-15)
    ax.set_ylabel('PC2', labelpad=-15)
    ax.set_zlabel('PC3', labelpad=-15)

    if plot_title:
        ax.set_title(title, fontsize=fs)

    if plot_legend:
        ax.legend(ncol=2, fontsize=fs)

    if return_fig:
        return fig, ax
    return ax

In [None]:
# Preview: sustained trajectory with its encoding and speech-production planes.
ax = state_space(region='sustained', plot_title=False, plot_planes=True);

In [None]:
# Grey sustained trajectory, then the mPrCG sub-projection drawn on top in the same axes.
ax = state_space(region='sustained', plot_planes=False, plot_title=False,
                 trajectory_color='grey', alpha=0.3, s=5, zorder=1);
ax = state_space(region='mprcg', ax=ax, plot_subprojection=True, plot_planes=False, plot_title=False,
                 alpha=0.7, s=15, zorder=4);

In [None]:
def proportions(region=None, ax=None, plot_title=True, fs=10, title=None, plot_xticklabels=False):
    """Bar plot of per-phase proportional state-space distance for one group.

    Bar heights are column `region_idx` of ``stsp['ss_dist_prop_show']``. Bars
    are labelled and coloured in ``phase_colors`` order (this assumes the rows
    follow that order -- confirm).

    Parameters
    ----------
    region : {'sustained', 'nonsustained', 'mprcg'}
        Group to plot.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 1 x 2 in figure is created when omitted.
    plot_title : bool
        Draw a title.
    fs : float
        Font size of the title, y tick labels and y label.
    title : str, optional
        Title text. Defaults to the group's display name.
    plot_xticklabels : bool
        Label bars with capitalised phase names; otherwise the x ticks are hidden.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If `region` is not one of the supported groups.
    """
    default_titles = {'sustained': 'Sustained', 'nonsustained': 'Non-sustained', 'mprcg': 'mPrCG'}
    if region not in default_titles:
        raise ValueError(f'region must be one of {list(default_titles)}, got {region!r}')

    if ax is None:
        _, ax = plt.subplots(figsize=(1, 2))

    if region == 'sustained':
        region_idx = 0
    elif region == 'nonsustained':
        region_idx = 1
    else:
        # Region columns follow the two group columns, hence the +2 offset.
        region_idx = np.where(stsp['regions'] == 'middleprecentral')[0][0] + 2

    if title is None:
        title = default_titles[region]

    phase_names = list(phase_colors.keys())
    ax.bar(phase_names, stsp['ss_dist_prop_show'][:, region_idx],
           color=list(phase_colors.values()), clip_on=False)

    if plot_title:
        ax.set_title(title, fontsize=fs)

    if plot_xticklabels:
        ax.set_xticklabels([name.capitalize() for name in phase_names])
    else:
        ax.axes.set(xticks=[], xticklabels=[])

    ax.axes.set(ylim=(0, 0.5), yticks=[0, 0.5])
    ax.set_yticklabels([0.0, 0.5], fontsize=fs)
    ax.set_ylabel('Proportional distance', fontsize=fs)

    return ax

In [None]:
# Preview the proportional-distance bars for each group.
for preview_region in ('sustained', 'nonsustained', 'mprcg'):
    proportions(region=preview_region)

## Sustained-area accuracy and projection-angle similarity on brain reconstructions

In [None]:
def metric_compare_recon(hemi='lh', view='lateral', plot_colorbar=True, ax=None, metric='accuracy', s=10):
    """Paint a per-area metric onto the MNI brain image and overlay sustained electrodes.

    Each area's mask (from ``mni_mask_path``) is filled with a single value of
    the chosen metric (viridis colormap, alpha 0.5):

    * ``'accuracy'``: ``gc_data['accuracy_ratio']`` of the area's sustained
      subgroup, colour range [0.75, 1].
    * ``'projection_angle'``: negated mean of ``rp_angle1`` and ``rp_angle4``
      from ``stsp_proj``, colour range [-50, -10]. Smaller angles therefore map
      to higher colour values; the colour-bar ticks are re-negated for display.

    The sustained electrodes in each area (``warp_x``/``warp_y``) are then
    scattered in black.

    Parameters
    ----------
    hemi : str
        Hemisphere used in the image and mask filenames (e.g. 'lh').
    view : str
        'lateral' plots mPrCG, vPrCG, SMG, IFG and pSTG; any other view plots medial SFG.
    plot_colorbar : bool
        Add a labelled colour bar.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 8 x 8 in figure is created when omitted.
    metric : {'accuracy', 'projection_angle'}
        Metric to paint.
    s : float
        Electrode marker size.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If `metric` is not supported.
    """
    if metric not in ('accuracy', 'projection_angle'):
        raise ValueError(f"metric must be 'accuracy' or 'projection_angle', got {metric!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    areas = ['mPrCG', 'vPrCG', 'SMG', 'IFG', 'pSTG'] if view == 'lateral' else ['medial SFG']
    vlim = [0.75, 1] if metric == 'accuracy' else [-50, -10]

    # Display area -> region name in the state-space projection tables.
    stsp_region_names = {
        'mPrCG': 'middleprecentral',
        'vPrCG': 'ventralprecentral',
        'SMG': 'supramarginal',
        'IFG': 'pars',
        'pSTG': 'posteriorsuperiortemporal',
        'medial SFG': 'supplementarymotor',
    }
    # Display area -> `fancy_location` labels of its electrodes (default: the area name itself).
    electrode_locations = {
        'mPrCG': ['middle PrCG'],
        'vPrCG': ['ventral PrCG'],
        'IFG': ['Pars operc.', 'Pars triang.'],
    }

    img = plt.imread(mni_img_path.format(hemi, view))
    ax.imshow(img, alpha=0.5)
    ax.axis('off')

    for area in areas:

        if metric == 'accuracy':
            # gc_data still stores medial SFG under its old SMA name.
            legend_label = 'Sustained (SMA)' if area == 'medial SFG' else f'Sustained ({area})'
            i = np.where(gc_data['legend_str'] == legend_label)[0][0]
            area_metric = gc_data['accuracy_ratio'][i]
        else:
            # Region rows follow the two group rows, hence +2.
            i = np.where(stsp_proj['regions'] == stsp_region_names[area])[0][0] + 2
            area_metric = -np.mean([stsp_proj['rp_angle1'][i], stsp_proj['rp_angle4'][i]])

        mask_name = 'sma' if area == 'medial SFG' else area.lower()
        mask = plt.imread(mni_mask_path.format(hemi, view, mask_name))[:, :, 0]
        # Pixels equal to 1 in the first channel stay transparent (NaN); the rest take the area's value.
        painted = np.where(mask == 1, np.nan, area_metric)
        im = ax.pcolormesh(painted, cmap='viridis', alpha=0.5, vmin=vlim[0], vmax=vlim[1])

        loc = electrode_locations.get(area, [area])
        cur_df = edf.loc[(edf.view == view) &
                         (edf.hemisphere == hemi) &
                         (edf.subject_electrode.isin(sustained_se)) &
                         (edf.fancy_location.isin(loc)) &
                         (edf.subject.isin(plot_subjects))]

        ax.scatter(cur_df.warp_x, cur_df.warp_y, color='k', s=s)

    if plot_colorbar:
        cbar = plt.colorbar(im, ax=ax, fraction=0.032, ticks=vlim)
        if metric == 'accuracy':
            cbar.set_label('Relative accuracy', labelpad=-10)
        else:
            cbar.set_label('Projection similarity\nto sustained', labelpad=-10)
            # Undo the negation so the ticks read as positive angles; raw strings keep `\circ` literal.
            cbar.ax.set_yticklabels([rf'{-vlim[0]}$^\circ$', rf'{-vlim[1]}$^\circ$'])

    return ax

In [None]:
# Preview: projection-angle similarity on the medial view, with colour bar.
metric_compare_recon(metric='projection_angle', view='medial', plot_colorbar=True);

# Overall figure

In [None]:
##### Figure setup #####
# Main-figure styling: small fonts and thin lines for the 180 x 140 mm layout below.
default_plot_settings(font='Helvetica', fontsize=5, linewidth=0.5, ticklength=2)
label_fontsize = 5

# Vector-output fonts: SVG text is kept as text; PDF/PS embed TrueType (Type 42) fonts.
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# Tight paddings and thin patch/hatch/marker-edge widths to match the small font size.
plt.rcParams['axes.labelpad'] = 1
plt.rcParams['patch.linewidth'] = 0.25
plt.rcParams['hatch.linewidth'] = 0.25
plt.rcParams['lines.markeredgewidth'] = 0.25
plt.rcParams['axes.titlepad'] = 2
plt.rcParams['xtick.major.pad'] = 2
plt.rcParams['ytick.major.pad'] = 2
plt.rcParams['legend.handletextpad'] = 1

all_axs = {}
mm = 1/25.4  # inches per millimetre (figsize is given in inches)
fig = plt.figure(figsize=(180*mm, 140*mm))
# 100 x 100 layout grid; every panel below occupies a rectangular slice of it.
gs = mpl.gridspec.GridSpec(100, 100, figure=fig)

# 1 x 3 sub-grid for the phase-probability traces; the middle (go-cue) panel is narrower.
timecourse_gs = mpl.gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=gs[:20, 31:90], 
                                                     width_ratios=[0.75, 0.5, 0.75], wspace=0.1)

########
# 3-D state-space panels.
# Full sustained trajectory, coloured by phase, with its label drawn above the axes.
all_axs['pca_sustained'] = fig.add_subplot(gs[20:85, 73:100], projection='3d')
all_axs['pca_sustained'] = state_space(region='sustained', plot_planes=False, plot_title=False, ax=all_axs['pca_sustained'], s=2);
all_axs['pca_sustained'].annotate('Sustained', (0.05, 1.12), xycoords='axes fraction', fontsize=6, annotation_clip=False)

# mPrCG sub-projection over the grey sustained trajectory. computed_zorder=False lets the
# explicit zorder values (1 for the trajectory, 3 for the projection) decide the draw order.
all_axs['pca_projection'] = fig.add_subplot(gs[60:94, :25], projection='3d', computed_zorder=False)
all_axs['pca_projection'] = state_space(region='sustained', plot_planes=False, plot_title=False, alpha=0.3, s=2, ax=all_axs['pca_projection'],
                                       trajectory_color='grey', zorder=1);
all_axs['pca_projection'] = state_space(region='mprcg', plot_planes=False, plot_title=False, ax=all_axs['pca_projection'], plot_subprojection=True, 
                                        alpha=0.7, s=2, zorder=3);

# Inset overlapping the sustained panel: encoding and speech-production plane fits.
all_axs['pca_sustained_planes'] = fig.add_subplot(gs[20:52, 85:97], projection='3d')
all_axs['pca_sustained_planes'].axis('off');
all_axs['pca_sustained_planes'] = state_space(region='sustained', plot_planes=True, plot_title=False, 
                                              ax=all_axs['pca_sustained_planes'], s=2);


########
# 2-D panels: single-electrode statistics, phase-probability traces, growth curve.
all_axs['single_elec'] = fig.add_subplot(gs[0:22, 2:12])
all_axs['single_elec'] = friedman_test(ax=all_axs['single_elec'], legend_fs=label_fontsize, fs=label_fontsize, s=2, ms=3)

# NOTE(review): phase_probability does not use `bbox_height`, so the value passed here has no effect.
all_axs['phase_probability'] = np.array([fig.add_subplot(timecourse_gs[i]) for i in range(3)])
all_axs['phase_probability'] = phase_probability(decoding_results=sustained_decoding, axs=all_axs['phase_probability'], 
                                                 bbox_height=-0.3, legend_fs=label_fontsize)

all_axs['growth_curve'] = fig.add_subplot(gs[30:55, 2:23])
all_axs['growth_curve'] = growth_curve(ax=all_axs['growth_curve'], ms=4, fs=label_fontsize, sig_fs=7, legend_fs=label_fontsize)


# Brain reconstructions: relative accuracy (upper pair) and projection similarity (lower pair).
# The medial view sits left of the lateral view; only the lateral views get a colour bar.
all_axs['accuracy_recon_medial'] = fig.add_subplot(gs[30:42, 26:39])
all_axs['accuracy_recon_medial'] = metric_compare_recon(view='medial', plot_colorbar=False, metric='accuracy', 
                                                        s=1, ax=all_axs['accuracy_recon_medial']);

all_axs['accuracy_recon_lateral'] = fig.add_subplot(gs[30:60, 38:63])
all_axs['accuracy_recon_lateral'] = metric_compare_recon(view='lateral', plot_colorbar=True, metric='accuracy', 
                                                        s=1, ax=all_axs['accuracy_recon_lateral']);

all_axs['angle_recon_medial'] = fig.add_subplot(gs[64:76, 26:39])
all_axs['angle_recon_medial'] = metric_compare_recon(view='medial', plot_colorbar=False, metric='projection_angle', 
                                                        s=1, ax=all_axs['angle_recon_medial']);

all_axs['angle_recon_lateral'] = fig.add_subplot(gs[64:94, 38:63])
all_axs['angle_recon_lateral'] = metric_compare_recon(view='lateral', plot_colorbar=True, metric='projection_angle', 
                                                        s=1, ax=all_axs['angle_recon_lateral']);

########
# Panel letters. Each letter sits on an invisible anchor axes, drawn 15 pt left of and
# 35 pt above the anchor's lower-left corner ('axes points' coordinates).
# NOTE(review): the assembled figure is not saved in this cell.
panel_axs = {}
panel_axs['a'] = fig.add_subplot(gs[3:8, :10])
panel_axs['b'] = fig.add_subplot(gs[3:8, 28:38])
panel_axs['c'] = fig.add_subplot(gs[35:40, :10])
panel_axs['d'] = fig.add_subplot(gs[35:40, 30:40])
panel_axs['e'] = fig.add_subplot(gs[35:40, 75:85])
panel_axs['f'] = fig.add_subplot(gs[71:76, :10])
panel_axs['g'] = fig.add_subplot(gs[71:76, 30:40])

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (-15, 35), xycoords='axes points', ha='right', fontsize=7, weight='bold');

# Supplementary figures

## Fig. S8 - Growth curve, all subgroups

In [None]:
# Fig. S8: growth curves for every sustained subgroup (no significance test in this mode).
fig, _ = growth_curve(all_subgroups=True, return_fig=True, legend_fs=10);

## Fig. S9 - State space, non-sustained and individual mPrCG

In [None]:
# Supplementary-figure styling: larger fonts and lines than the main figure.
# NOTE(review): rcParams set in the main-figure cell (paddings, font types) are not reset
# explicitly here; confirm whether default_plot_settings resets them.
default_plot_settings(font='Helvetica', fontsize=12, linewidth=1.5, ticklength=4)

all_axs = {}
fig = plt.figure(figsize=(10, 10))
# 100 x 100 layout grid; each panel below is a rectangular slice of it.
gs = mpl.gridspec.GridSpec(100, 100, figure=fig)

########
# Left column: mPrCG trajectory and, below it, its encoding/speech-production plane fits.
all_axs['pca_mprcg'] = fig.add_subplot(gs[7:47, :55], projection='3d')
all_axs['pca_mprcg'] = state_space(region='mprcg', plot_planes=False, plot_title=False, ax=all_axs['pca_mprcg']);

all_axs['pca_mprcg_planes'] = fig.add_subplot(gs[40:70, :55], projection='3d')
all_axs['pca_mprcg_planes'].axis('off');
all_axs['pca_mprcg_planes'] = state_space(region='mprcg', plot_planes=True, plot_title=False, 
                                          ax=all_axs['pca_mprcg_planes']);

########
# Right column: non-sustained trajectory and its plane fits.
all_axs['pca_nonsustained'] = fig.add_subplot(gs[5:45, 50:], projection='3d')
all_axs['pca_nonsustained'] = state_space(region='nonsustained', plot_planes=False, plot_title=False, ax=all_axs['pca_nonsustained']);

all_axs['pca_nonsustained_planes'] = fig.add_subplot(gs[40:70, 50:], projection='3d')
all_axs['pca_nonsustained_planes'].axis('off');
all_axs['pca_nonsustained_planes'] = state_space(region='nonsustained', plot_planes=True, plot_title=False, 
                                              ax=all_axs['pca_nonsustained_planes']);

########
# Bottom row: proportional-distance bars for the sustained, mPrCG and non-sustained groups.
all_axs['pca_sustained_proportion'] = fig.add_subplot(gs[75:95, 25:35])
all_axs['pca_sustained_proportion'] = proportions(region='sustained', plot_title=True, ax=all_axs['pca_sustained_proportion']);

all_axs['pca_mprcg_proportion'] = fig.add_subplot(gs[75:95, 47:57])
all_axs['pca_mprcg_proportion'] = proportions(region='mprcg', plot_title=True, ax=all_axs['pca_mprcg_proportion'], title='mPrCG\n(sustained)');

all_axs['pca_nonsustained_proportion'] = fig.add_subplot(gs[75:95, 69:79])
all_axs['pca_nonsustained_proportion'] = proportions(region='nonsustained', plot_title=True, ax=all_axs['pca_nonsustained_proportion']);

########
# Group names for the 3-D panels, drawn as annotations (plot_title=False above).
all_axs['pca_mprcg'].annotate('mPrCG\n(sustained)', (0.1, 0.96), xycoords='axes fraction', fontsize=12);
all_axs['pca_nonsustained'].annotate('Non-sustained', (0.15, 0.79), xycoords='axes fraction', fontsize=12);

########
# Panel letters on invisible anchor axes, drawn 65 pt above each anchor's lower-left corner.
panel_axs = {}
panel_axs['a'] = fig.add_subplot(gs[12:17, 6:16])
panel_axs['b'] = fig.add_subplot(gs[12:17, 55:65])
panel_axs['c'] = fig.add_subplot(gs[48:53, 6:16])
panel_axs['d'] = fig.add_subplot(gs[48:53, 55:65])
panel_axs['e'] = fig.add_subplot(gs[79:84, 20:30])
panel_axs['f'] = fig.add_subplot(gs[79:84, 42:52])
panel_axs['g'] = fig.add_subplot(gs[79:84, 64:74])

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (0, 65), xycoords='axes points', ha='right', fontsize=16, weight='bold');