# Imports

In [None]:
import sys
import os
from os.path import join, dirname, realpath, exists
import json
import string
import gc
import glob
import pickle
import inspect
import time
import random
import statistics
from itertools import combinations
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from copy import copy, deepcopy
import multiprocessing as mp

import cebra
import numpy as np
import xarray as xr
import pandas as pd
import scipy as scp
from tqdm.notebook import trange, tqdm
from scipy import signal, stats, interpolate
from scipy.io import loadmat
from scipy.stats import norm
from scipy.special import lambertw
import sklearn as skl
from sklearn.decomposition import PCA
from sklearn import preprocessing
np.random.seed(42)

%matplotlib qt
from PIL import Image
from io import BytesIO
import seaborn as sns
import ptitprince as pt
import pylab as pl
import plotly.io as pio
import plotly.express as px
import plotly.graph_objs as go
from colorcet.plotting import swatches, sine_combs
import colorcet as cc
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
from matplotlib.ticker import MultipleLocator, FuncFormatter
import matplotlib.gridspec as gridspec
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections import register_projection
from matplotlib.projections.polar import PolarAxes
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colors, cm
from matplotlib.colors import Normalize
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from matplotlib_scalebar.scalebar import ScaleBar
from statannotations.Annotator import Annotator
from matplotlib import rc
from matplotlib.font_manager import get_font_names
from IPython.display import display, Math, Latex, HTML
from vtk.util import numpy_support
rc('font',**{'family':'sans-serif','sans-serif':['FreeSans']})
plt.rcParams['svg.fonttype'] = 'none'

# MNE configuration
import mne
from mne.preprocessing import (ICA, corrmap, create_ecg_epochs,create_eog_epochs, annotate_muscle_zscore)
from mne.datasets import fetch_fsaverage
import mne_connectivity
mne.utils.set_config('MNE_USE_CUDA', 'true')
mne.set_log_level('error')  # reduce extraneous MNE output
mne.viz.set_browser_backend('matplotlib')
mne.viz.set_3d_options(
    antialias=False, depth_peeling=False, smooth_shading=False, multi_samples=1,
)

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter(action="ignore", category=FutureWarning)
np.seterr(all = 'ignore')

# Example_dir = dirname(realpath(__file__)) # directory of this file
modules_dir = '' # directory with all TMSI modules
sp_dir = '' # directory with all measurements
sys.path.append(modules_dir)

import numba
from numba import jit
import numpyro as npr
import numpyro.infer
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp
from scipy.integrate import solve_ivp, quad
import arviz as az

from jax import grad, vmap, lax
from jax import random as jrandom
from functools import partial
from jax.experimental.ode import odeint

import torch
import sbi 
import sbi.inference
from sbi.inference.base import infer
from sbi.inference import SNPE, SNLE, SNRE, prepare_for_sbi ,simulate_for_sbi
from sbi.inference import likelihood_estimator_based_potential, DirectPosterior, MCMCPosterior, VIPosterior
from sbi.analysis import ActiveSubspace, pairplot
import sbi.utils as utils
from ssm.plots import plot_dynamics_2d

import nest 
import nest.raster_plot
nest.set_verbosity(100)

In [None]:
font = "sans-serif"  # Replace with your desired font
fontsize = 10  # Replace with your desired size

# Install the chosen family and size as the global matplotlib font defaults.
plt.rcParams.update({'font.family': font, 'font.size': fontsize})

In [None]:
def load_SC_88(connectome_dir):
    """Load a structural-connectivity matrix and region centres from a directory.

    Reads ``weights.txt`` and ``centers.txt`` by plain string concatenation, so
    ``connectome_dir`` must already end with a path separator.

    Returns:
        (SC, xyz): the weight matrix with its diagonal zeroed, divided by its
        maximum and made non-negative with ``np.abs``; and the x/y/z columns of
        ``centers.txt`` (loaded with ``dtype=object``).
    """
    weights = np.loadtxt(connectome_dir + 'weights.txt')
    raw_centers = np.loadtxt(connectome_dir + 'centers.txt', dtype=object)
    centers_df = pd.DataFrame(raw_centers, columns=('label', 'x', 'y', 'z'))
    coords = centers_df[['x', 'y', 'z']].values

    np.fill_diagonal(weights, 0.0)  # drop self-connections
    normalised = weights / np.max(weights)
    return np.abs(normalised), coords

In [None]:
def get_default_args(func):
    """Return ``{name: default}`` for every parameter of *func* that declares a default."""
    parameters = inspect.signature(func).parameters
    defaults = {}
    for name, param in parameters.items():
        if param.default is param.empty:
            continue  # required parameter, nothing to record
        defaults[name] = param.default
    return defaults

def find_nearest(array, values):
    """Find the element(s) of a sorted array closest to each query value.

    Args:
        array: 1-D array-like, assumed sorted ascending (``np.searchsorted``
            requires this).
        values: a scalar or array-like of query values.

    Returns:
        (nearest, idxs): the closest elements of ``array`` and their indices.
        On an exact tie between two neighbours the higher index wins. A scalar
        ``values`` yields scalar results.

    Raises:
        ValueError: if ``array`` is empty.
    """
    array = np.asarray(array)
    values = np.asarray(values)
    if array.size == 0:
        raise ValueError("find_nearest requires a non-empty array")

    # Insertion points: array[idxs - 1] <= value <= array[idxs].
    idxs = np.searchsorted(array, values, side="left")
    n = len(array)

    left = array[np.maximum(idxs - 1, 0)]
    right = array[np.minimum(idxs, n - 1)]
    # Step back when past the end or when the left neighbour is strictly closer.
    use_left = (idxs == n) | (np.fabs(values - left) < np.fabs(values - right))

    # np.where (not in-place masking) so scalar queries work too.
    idxs = np.where(use_left, idxs - 1, idxs)
    return array[idxs], idxs

def formatData(data, time, timeLocks, binSize, maxLead, maxLag, sr = 120, isComplex = False):
    """Cut fixed-length windows of ``data`` around each time-lock into a 3-D array.

    The window for a lock at sample ``k`` spans samples
    ``[k + round(maxLead*sr), k + round(maxLag*sr))``. Pass a negative
    ``maxLead`` to include samples before the lock.

    Args:
        data: 2-D array, rows aligned with ``time`` (samples x channels).
        time: 1-D sorted sample times, same length as ``data``.
        timeLocks: times to align on; each snaps to the nearest entry of ``time``.
        binSize: unused; kept for API compatibility.
        maxLead, maxLag: window start/end offsets, in units such that
            ``offset * sr`` is a sample count.
        sr: sampling rate used to convert offsets to samples.
        isComplex: allocate a complex output array.

    Returns:
        Array of shape (window_samples, channels, len(timeLocks)). Windows that
        run past the end of ``data`` are zero-padded at the end.
        NOTE(review): a window starting before sample 0 yields a negative slice
        start (Python wrap-around) — callers should avoid locks too close to
        the start.

    Raises:
        ValueError (an ``Exception`` subclass): if ``data`` and ``time`` differ in length.
    """
    if len(data) != len(time):
        raise ValueError('Data and time must have equal length')

    num_trials = len(timeLocks)

    # Round first, then cast: casting directly to int64 truncates (e.g. 2.9 -> 2).
    rangeInds = np.round(np.array([-maxLead * sr, maxLag * sr])).astype(np.int64)

    _, nearestInds = find_nearest(time, timeLocks)
    beginInds = nearestInds - rangeInds[0]
    endInds = nearestInds + rangeInds[-1]

    n_samples = int(np.sum(rangeInds))
    dtype = complex if isComplex else float
    fData = np.zeros((n_samples, data.shape[1], num_trials), dtype=dtype)

    for trialInd in range(num_trials):
        trial_data = data[beginInds[trialInd]:endInds[trialInd], :]
        # Shorter slices (hitting the end of data) leave trailing zeros.
        n_copy = min(trial_data.shape[0], n_samples)
        fData[:n_copy, :, trialInd] = trial_data[:n_copy]

    return fData

In [None]:
def calculate_events(raw, pong_results, subject = '', num_trials = 160, filter_events = True):
    """Build MNE-style event arrays for trial-start and feedback events of one subject.

    Args:
        raw: MNE Raw object; its stim events are read with ``mne.find_events``
            and ``raw.info['sfreq']`` converts behavioural time offsets to samples.
        pong_results: labelled array (xarray-like, uses ``.sel``/``.interpolate_na``)
            indexed by ``variable``, ``subject`` and ``trial``. Reads the variables
            'condition', 'result', 'feedbackTime', 'threshTime' and 'startTrig0'.
        subject: subject label to select. Empty only prints a warning.
        num_trials: number of trailing trials used to build the returned events.
        filter_events: if True, split events by condition; otherwise return them unsplit.

    Returns:
        If ``filter_events``: (pBMEvs, aBMEvs, pFBEvs, aFBEvs). These are the
        trial-start ("BM") and feedback ("FB") events for trials with
        condition == 1 ("p") and condition == 0 ("a").
        Otherwise: (bmEvents, fbEvents, conditions, intercepts).
        In every event array, column 2 is overwritten with the trial's 'result'.
    """

    # NOTE(review): execution continues with an empty subject; the .sel calls
    # below will then likely fail.
    if subject == '':
        print('No input subject!')
    
    events = mne.find_events(raw, output = 'onset')
    
    # Magic count: any recording without exactly 507 events has its first event
    # dropped — presumably a spurious leading trigger. TODO confirm.
    if events.shape[0] != 507:
        events = events[1:,:]
    
    # Column 0 holds event sample positions (feedback offsets in samples are added to it below).
    trigs = events[:,0]
    
    conditions = pong_results.sel(variable = 'condition', subject = subject)
    intercepts = pong_results.sel(variable = 'result', subject = subject)
    
    pcond = conditions == 1
    acond = conditions == 0
    
    # Missing trial values are filled by spline interpolation along 'trial'.
    feedback_times = pong_results.sel(variable = 'feedbackTime', subject = subject).interpolate_na('trial', limit = None, method = 'spline')
    thresh_times = pong_results.sel(variable = 'threshTime', subject = subject).interpolate_na('trial', limit = None, method = 'spline')
    # NOTE(review): ball_starts is computed but never used.
    ball_starts = pong_results.sel(variable = 'startTrig0', subject = subject).interpolate_na('trial', limit = None, method = 'spline')
    res_array = pong_results.sel(variable = 'result', subject = subject)

    # A trial start is detected where two consecutive triggers are <= 25 samples
    # apart. The event two positions later is taken as that trial's end trigger.
    tsDiffs = np.diff(trigs)
    start_trigs = np.where(tsDiffs <= 25)[0]
    end_trigs = start_trigs + 2
    
    # Relabel event ids in place: 1 = start, 2 = the paired trigger, 3 = end.
    events[start_trigs,2] = 1
    events[start_trigs+1,2] = 2
    events[end_trigs,2] = 3
    sEvents = events[start_trigs]
    eEvents = events[end_trigs]
    
    # Feedback sample = end-trigger sample + (feedbackTime - threshTime) in samples.
    # Presumably the end trigger marks the threshold time — TODO confirm.
    fb_to_thresh = np.round((feedback_times - thresh_times)* raw.info['sfreq'])
    feedbackTimestamps = (fb_to_thresh + events[end_trigs,0][-num_trials:]).to_numpy()
    
    fbEvents = eEvents[-num_trials:].copy()
    fbEvents[:,0] = feedbackTimestamps
    fbEvents[-num_trials:,2] = res_array
    
    bmEvents = sEvents[-num_trials:].copy()
    bmEvents[-num_trials:,2] = res_array
    
    # Boolean condition masks select rows (one per trial).
    pBMEvs = bmEvents[pcond]
    aBMEvs = bmEvents[acond]
    
    pFBEvs = fbEvents[pcond]
    aFBEvs = fbEvents[acond]

    if filter_events:
        return pBMEvs, aBMEvs, pFBEvs, aFBEvs
    else:
        return bmEvents, fbEvents, conditions, intercepts

In [None]:
def erp_default_config(tend=100, dt=0.1, t0=0, ns=9, constants=None):
    """Return the default time grid, initial state and constants for the ERP model.

    Args:
        tend: final time of the grid (``tend + dt`` is used as the arange stop so
            ``tend`` is normally included; floating-point steps may vary by one sample).
        dt: time step.
        t0: start time.
        ns: number of state variables (initial state is all zeros).
        constants: model constants array; defaults to a fresh ``np.array([-0.56])``
            on every call, so callers may mutate the returned array safely.

    Returns:
        (ts, nt, x_init, constants) where ``nt == len(ts)``.
    """
    if constants is None:
        # Build per call: a module-level default array would be shared and mutable.
        constants = np.array([-0.56])

    ts = np.arange(t0, tend + dt, dt)
    nt = ts.shape[0]
    x_init = np.zeros(ns)

    return ts, nt, x_init, constants

def DCM_default_params():
    """Fixed parameters of the DCM ERP neural-mass model.

    Returns:
        (delta, tau_i, h_i, tau_e, h_e, u): sigmoid scaling of its second
        argument, inhibitory time constant and gain, excitatory time constant
        and gain, and the constant input added to the first excitatory drive.
    """
    defaults = {
        "delta": 8.41,
        "tau_i": 10.38,
        "h_i": 25.0,
        "tau_e": 6.1,
        "h_e": 1.48,
        "u": 0.03,
    }
    return tuple(defaults.values())

def odeint_euler(f, y0, t, *args):
    """Integrate ``dy/dt = f(y, t, *args)`` with explicit forward Euler via ``lax.scan``.

    Args:
        f: vector field called as ``f(y, t, *args)``.
        y0: initial state.
        t: 1-D array of time points (step sizes are taken from successive differences).
        *args: extra arguments forwarded to ``f``.

    Returns:
        Array with one state per entry of ``t``. The first row equals ``y0``
        because the first step has zero step size.
    """
    def _euler_step(carry, t_next):
        y_curr, t_curr = carry
        y_next = y_curr + (t_next - t_curr) * f(y_curr, t_curr, *args)
        return (y_next, t_next), y_next

    initial_carry = (y0, t[0])
    _, trajectory = lax.scan(_euler_step, initial_carry, t)
    return trajectory

In [None]:
@jax.jit
def Sigmodal(x1, x2, delta, alpha):
    """Zero-centred logistic: ``1 / (1 + exp(alpha * (x1 - delta * x2))) - 0.5``.

    Output lies in (-0.5, 0.5) and is 0 when ``x1 == delta * x2``.
    """
    exponent = alpha * (x1 - (delta * x2))
    logistic = 1. / (1. + jnp.exp(exponent))
    return logistic - 0.5

@jax.jit
def DCM_NMM_ERP_vector_field(state, t, constants, params):
    """Right-hand side of the 9-state DCM ERP neural-mass model.

    Args:
        state: 9 state variables, unpacked as x0..x8.
        t: current time; not used by the equations. It is present because
            ``odeint_euler`` calls ``f(y, t, *args)``.
        constants: ``constants[0]`` is the sigmoid slope ``alpha``.
        params: four gains (g_1, g_2, g_3, g_4).

    Returns:
        jnp.ndarray of the 9 time derivatives, in the same order as ``state``.

    Each pair (x0, x3), (x1, x4), (x2, x5), (x6, x7) is a second-order system
        x_a' = x_b,   x_b' = (h / tau) * input - x_a / tau**2 - 2 * x_b / tau
    using the excitatory (h_e, tau_e) or inhibitory (h_i, tau_i) constants from
    ``DCM_default_params``.
    """
    
    x0, x1, x2, x3, x4, x5, x6, x7, x8 = state
    g_1, g_2, g_3, g_4 = params
    delta, tau_i, h_i, tau_e, h_e, u = DCM_default_params()
    
    alpha = constants[0]

    dx0 = x3
    dx1 = x4
    dx2 = x5
    # x0: excitatory, driven by g_1 * S(x8, x4 - x5) plus constant input u.
    dx3 = (1./tau_e) * (h_e * (g_1 * (Sigmodal(x8, x4 - x5, delta, alpha)) + u) - (x0 / tau_e) - 2 * x3)
    # x1: excitatory, driven by g_2 * S(x0, x3).
    dx4 = (1./tau_e) * (h_e * (g_2 * (Sigmodal(x0, x3, delta, alpha))) - (x1 / tau_e) - 2 * x4)
    # x2: inhibitory constants, driven by g_4 * S(x6, x7).
    dx5 = (1./tau_i) * (h_i * (g_4 * (Sigmodal(x6, x7, delta, alpha))) - (x2 / tau_i) - 2 * x5)
    dx6 = x7
    # x6: excitatory constants, driven by g_3 * S(x8, x4 - x5).
    dx7 = (1. / tau_e) * (h_e * (g_3 * (Sigmodal(x8, x4 - x5, delta, alpha))) - (x6 / tau_e) - 2 * x7)
    # x8' = x4 - x5 = d/dt (x1 - x2), so x8 tracks x1 - x2 up to a constant.
    # Other functions in this file read x8 out as the "py" signal.
    dx8 = x4 - x5

    return jnp.array([dx0, dx1, dx2, dx3, dx4, dx5, dx6, dx7, dx8])

@jax.jit
def DCM_NMM_ERP_ODE_JAXOdeint(params, constants, x_init, ts):
    """Integrate the DCM ERP model with forward Euler and return state x8 over ``ts``."""
    # Explicit Euler (odeint_euler), despite "odeint" in the name.
    states = odeint_euler(DCM_NMM_ERP_vector_field, x_init, ts, constants, params)
    return states[:, 8]

@jax.jit
def DCM_NMM_ERP_ODE_JAXOdeint_full(params, constants, x_init, ts):
    """Integrate the DCM ERP model with forward Euler; return all states (len(ts) x 9)."""
    return odeint_euler(DCM_NMM_ERP_vector_field, x_init, ts, constants, params)

In [None]:
def simulate_from_posterior(params, tend=100):
    """Simulate the DCM ERP model for ``params`` on the default grid up to ``tend``.

    Returns:
        (xsc, xin, xpy): state columns 0, 6 and 8 of the full trajectory.
    """
    ts, _, x_init, constants = erp_default_config(tend=tend)
    trajectory = DCM_NMM_ERP_ODE_JAXOdeint_full(params, constants, x_init, ts)
    xsc, xin, xpy = (trajectory[:, col] for col in (0, 6, 8))
    return xsc, xin, xpy

In [None]:
#####################################################
def LSE(x1, x2):
    """Sum of squared element-wise differences between x1 and x2."""
    squared = (x1 - x2) ** 2
    return np.sum(squared)
#####################################################
def Err(x1, x2):
    """Sum of absolute element-wise differences between x1 and x2."""
    abs_diff = np.abs(x1 - x2)
    return np.sum(abs_diff)
#####################################################    
def RMSE(x1, x2):
    """Root-mean-square error between x1 and x2 (inputs must support ``.mean()``)."""
    mean_sq = ((x1 - x2) ** 2).mean()
    return np.sqrt(mean_sq)
#####################################################
def LSE_obs(Obs, Obs_lo, Obs_hi):
    """Mean of the sum-of-squares errors of Obs against a lower and an upper bound."""
    errors = [np.sum((Obs - bound) ** 2) for bound in (Obs_lo, Obs_hi)]
    return np.average(errors)
#####################################################
def z_score(true_mean, post_mean, post_std):
    """Absolute posterior z-score: ``|post_mean - true_mean| / |post_std|``."""
    standardized = (post_mean - true_mean) / post_std
    return np.abs(standardized)
#####################################################
def shrinkage(prior_std, post_std):
    """Posterior shrinkage: ``1 - (post_std / prior_std)**2`` (1 = fully informed)."""
    variance_ratio = (post_std / prior_std) ** 2
    return 1 - variance_ratio
#####################################################

from scipy.stats import gaussian_kde
from sbi.analysis.plot import _get_default_opts, _update, ensure_numpy


def _get_limits(samples, limits=None):
    """Compute per-dimension [min, max] plotting limits for one or more sample sets.

    Args:
        samples: a (n, dim) sample array/tensor, or a list of them. A list
            argument is not modified; converted copies are used internally.
        limits: None or [] to infer the limits from the data; a single [lo, hi]
            pair (wrapped in a list) to use for every dimension; or one pair
            per dimension.

    Returns:
        torch.Tensor of shape (dim, 2).
    """
    if isinstance(samples, list):
        # Build a new list instead of overwriting the caller's elements.
        samples = [ensure_numpy(s) for s in samples]
    else:
        samples = [ensure_numpy(samples)]

    # Dimensionality of the problem.
    dim = samples[0].shape[1]

    if limits is None or (isinstance(limits, list) and len(limits) == 0):
        limits = []
        for d in range(dim):
            lo, hi = np.inf, -np.inf
            for sample in samples:
                column = sample[:, d]
                lo = min(lo, column.min())
                hi = max(hi, column.max())
            limits.append([lo, hi])
    elif len(limits) == 1:
        limits = [limits[0] for _ in range(dim)]

    return torch.as_tensor(limits)

def posterior_peaks(samples, return_dict=False, **kwargs):
    '''
    Find the mode of each marginal of a posterior sample set via a Gaussian KDE.

    Args:
        samples: (n, dim) samples from the posterior, as a numpy array or an
            object with ``.numpy()`` (e.g. a torch tensor).
        return_dict: return ``{label: peak}`` instead of a list.
        **kwargs: sbi plotting options merged into the defaults. ``labels``
            names the dimensions; ``kde_diag['bw_method']`` and
            ``kde_diag['bins']`` control the KDE and its evaluation grid.

    Returns:
        list of peak locations (one per dimension, in order), or a dict keyed
        by label when ``return_dict`` is True.
    '''
    opts = _update(_get_default_opts(), kwargs)

    if not isinstance(samples, np.ndarray):
        samples = samples.numpy()

    limits = _get_limits(samples)
    _, dim = samples.shape

    try:
        labels = opts['labels']
    except KeyError:
        labels = None
    if labels is None:
        labels = range(dim)

    bw_method = opts["kde_diag"]["bw_method"]
    n_grid = opts["kde_diag"]["bins"]

    peaks = {}
    for row in range(dim):
        density = gaussian_kde(samples[:, row], bw_method=bw_method)
        grid = np.linspace(limits[row, 0], limits[row, 1], n_grid)
        peaks[labels[row]] = grid[density(grid).argmax()]

    if return_dict:
        return peaks
    return list(peaks.values())
    
def plot_posterior(samples,
                   ax,
                   prob=[0.025, 0.975],
                   labels=None,
                   xlim=None,
                   ylim=None,
                   xticks=None,
                   yticks=None,
                   xlabel=None,
                   ylabel=None):
    """Plot per-dimension marginal histograms with a shaded central interval.

    Args:
        samples: (n_samples, dim) samples; an object with ``.numpy()`` (e.g. a
            torch tensor) or any array-like.
        ax: a single Axes when ``dim == 1``, otherwise an indexable collection
            of at least ``dim`` Axes.
        prob: lower/upper quantile probabilities of the shaded interval (not mutated).
        labels: optional x-axis labels; must have length ``dim``.
        xlim, ylim, xticks, yticks, xlabel, ylabel: accepted for API
            compatibility; currently unused.

    Returns:
        list of float: for each dimension, the left edge of the tallest histogram bin.

    Raises:
        ValueError: if ``ax`` is None.
    """
    # mquantiles is not among the notebook's top-level imports.
    from scipy.stats.mstats import mquantiles

    if ax is None:
        # Raise rather than exit(), which would terminate the interpreter/kernel.
        raise ValueError('pass axis!')

    samples = samples.numpy() if hasattr(samples, 'numpy') else np.asarray(samples)
    dim = samples.shape[1]
    if labels is not None:
        assert (len(labels) == dim)

    hist_diag = {"alpha": 1.0, "bins": 50, "density": True, "histtype": "step"}

    max_values = np.zeros(dim)

    for i in range(dim):
        ax_i = ax if (dim == 1) else ax[i]
        counts, edges, _ = ax_i.hist(samples[:, i], **hist_diag)

        if labels is not None:
            ax_i.set_xlabel(labels[i], fontsize=13)
        ax_i.tick_params(labelsize=12)

        lo, hi = mquantiles(samples[:, i], prob)

        max_value = edges[np.argmax(counts)]
        max_values[i] = max_value
        ax_i.axvline(x=max_value, ls='--', color='gray', lw=2)

        # Select bins by their left edge; `counts` has one fewer entry than
        # `edges`, so masking with edges[:-1] keeps the indices in range.
        left_edges = edges[:-1]
        y = counts[(left_edges > lo) & (left_edges < hi)]

        ax_i.fill_between(np.linspace(lo, hi, len(y)), y,
                          color="gray",
                          alpha=0.2)

        ax_i.set_title("{:g}".format(max_value))

    plt.tight_layout()

    return list(max_values)


###############################################################################
# Helper functions. First, define the `Lambert W` function (the W_{-1} branch
# for negative arguments, as in SLI), here computed with SciPy. The second
# function computes the maximum of the postsynaptic potential for a synaptic
# input current of unit amplitude (1 pA) using the `Lambert W` function. This
# function is later used to calibrate the synaptic weights.

def LambertWm1(x):
    """Real Lambert W, mimicking ``gsl_sf_lambert_Wm1``.

    Uses branch -1 for negative ``x`` and the principal branch otherwise.
    """
    if x < 0:
        branch = -1
    else:
        branch = 0
    w = lambertw(x, k=branch)
    return w.real


def ComputePSPnorm(tauMem, CMem, tauSyn):
    """Peak PSP amplitude for a unit-amplitude (1 pA) synaptic current.

    The time of the peak is found in closed form with the W_{-1} Lambert branch.
    """
    ratio = tauMem / tauSyn
    rate_diff = 1.0 / tauSyn - 1.0 / tauMem

    # Time at which the PSP reaches its maximum.
    t_peak = 1.0 / rate_diff * (-LambertWm1(-np.exp(-1.0 / ratio) / ratio) - 1.0 / ratio)

    decay_mem = np.exp(-t_peak / tauMem)
    decay_syn = np.exp(-t_peak / tauSyn)
    prefactor = np.exp(1.0) / (tauSyn * CMem * rate_diff)
    return prefactor * ((decay_mem - decay_syn) / rate_diff - t_peak * decay_syn)

In [None]:
def spike_to_rate(spiketimes, nbins = 100, remove_tails = False, axis = -1, time_range = (0, 1000)):
    """Histogram spike times into ``nbins`` bins and smooth them into a rate trace.

    Args:
        spiketimes: array-like of spike times; ``None`` entries are dropped.
            Lists are accepted as well as arrays.
        nbins: number of histogram bins. Also passed to ``smooth_rates``, which
            uses it as the filter sampling frequency.
        remove_tails, axis: forwarded to ``smooth_rates``.
        time_range: (start, stop) histogram range; defaults to (0, 1000).

    Returns:
        Smoothed binned spike counts of length ``nbins``.
    """
    spikes = np.asarray(spiketimes)
    if spikes.dtype == object:
        # Only object arrays can hold None; numeric arrays pass through unchanged.
        spikes = np.array([s for s in spikes.ravel() if s is not None], dtype=float)

    binned_spikes, _ = np.histogram(spikes, bins = nbins, range = time_range)

    return smooth_rates(binned_spikes, nbins = nbins, remove_tails = remove_tails, axis = axis)

def smooth_rates(firing_rate, nbins = 100, remove_tails = False, axis = -1, order = 5, lp_savgol = 3, lp_filtfilt = 2):
    """Low-pass a binned rate with a zero-phase Butterworth filter.

    When ``remove_tails`` is True, a Savitzky-Golay pre-smoothing pass
    (window 70, polynomial order ``order``, mirror mode) is applied first and
    the cutoff is ``lp_savgol``. Otherwise the cutoff is ``lp_filtfilt``.
    ``nbins`` is the filter sampling frequency and ``order`` the Butterworth order.
    """
    window = 70
    cutoff = lp_savgol if remove_tails else lp_filtfilt
    sos = signal.butter(order, cutoff, 'lp', fs=nbins, output='sos')

    if remove_tails:
        firing_rate = signal.savgol_filter(firing_rate, window, order, mode = 'mirror', axis = axis)

    return signal.sosfiltfilt(sos, firing_rate, axis = axis)

def subpopulationCurrent(t, h, tau):
    """Alpha-function kernel ``(h / tau) * t * exp(-t / tau)``.

    NOTE: an identical function with this name is defined again later in the
    notebook; once both cells run, the later definition is the one in effect.
    """
    scale = h / tau
    return scale * t * np.exp(-t / tau)

In [None]:
def modify_axis_spines(ax, which=None, base=1.0, xticks=None, yticks=None, yaxis_left=True, xaxis_bot=True):
    """Hide unused spines and trim the kept x/y spines to their outermost ticks.

    Args:
        ax: matplotlib Axes, modified in place.
        which: string naming the axes to keep, e.g. 'x', 'y' or 'xy'. ``None``
            (the default) is treated as 'xy'.
        base: interval for a ``MultipleLocator`` installed on the x axis when no
            x ticks are given. NOTE(review): ``set_xticks`` installs a fixed
            locator right afterwards, so this does not appear to change the
            final ticks — confirm intent.
        xticks, yticks: explicit tick positions; ``None`` or empty keeps the current ticks.
        yaxis_left: keep the left spine for y (otherwise the right one).
        xaxis_bot: keep the bottom spine for x (otherwise the top one).
    """
    if which is None:
        which = 'xy'
    if xticks is None:
        xticks = []
    if yticks is None:
        yticks = []

    tick_locator = plticker.MultipleLocator(base=base)

    if yaxis_left:
        ax.spines.right.set(visible=False)
        yspine = ax.spines.left
    else:
        ax.spines.left.set(visible=False)
        yspine = ax.spines.right

    if xaxis_bot:
        ax.spines.top.set(visible=False)
        xspine = ax.spines.bottom
    else:
        ax.spines.bottom.set(visible=False)
        xspine = ax.spines.top

    if 'x' in which:
        if len(xticks) == 0:
            xticks = ax.get_xticks()
            ax.xaxis.set_major_locator(tick_locator)
        ax.set_xticks(xticks)
        xspine.set_bounds(ax.get_xticks()[0], ax.get_xticks()[-1])
    else:
        # Hide whichever x spine was kept above (bottom or top).
        xspine.set(visible=False)

    if 'y' in which:
        if len(yticks) == 0:
            yticks = ax.get_yticks()
        ax.set_yticks(yticks)
        yspine.set_bounds(ax.get_yticks()[0], ax.get_yticks()[-1])
    else:
        # Hide whichever y spine was kept above (left or right).
        yspine.set(visible=False)

def fmt_plot_text(text):
    """Format a number with two decimal places for plot annotations."""
    return '{:.2f}'.format(text)

def get_source_time_label(time_value):
    """Label text for a time in seconds, shown rounded to whole milliseconds."""
    millis = np.round(time_value * 1000)
    return f'Time from ball movement: {millis} ms'

In [None]:
def subpopulationCurrent(t, h, tau):
    """Alpha-function synaptic kernel ``(h / tau) * t * exp(-t / tau)``."""
    decay = np.exp(-t / tau)
    return (h / tau) * t * decay

def compute_syn_current(cond_thetas_max):
    """Excitatory/inhibitory synaptic currents and their E/I area ratio per condition.

    Args:
        cond_thetas_max: array-like of shape (>=4, n_conditions). Rows 0-2 are
            summed as excitatory gains (then scaled so the largest condition is 1);
            row 3 is used as the inhibitory gain.

    Returns:
        (I_e, I_i, EIRatio, time_vec): the currents of shape
        (len(time_vec), n_conditions), obtained as alpha kernels built from
        ``DCM_default_params`` and scaled by the gains; the ratio of
        trapezoidal areas auc_e / auc_i per condition; and the time grid
        0..99.9 in steps of 0.1.
    """
    thetas = np.asarray(cond_thetas_max, dtype=float)

    ex_g = np.sum(thetas[0:3, :], axis=0)
    ex_g_norm = ex_g / ex_g.max()  # not in-place: safe for integer inputs
    in_g_norm = thetas[3, :]

    time_vec = np.arange(0, 100, 0.1)

    _, tau_i, h_i, tau_e, h_e, _ = DCM_default_params()

    # One kernel per population, broadcast across all conditions.
    I_i = subpopulationCurrent(time_vec, h_i, tau_i)[:, None] * in_g_norm[None, :]
    I_e = subpopulationCurrent(time_vec, h_e, tau_e)[:, None] * ex_g_norm[None, :]

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapz = getattr(np, "trapezoid", None) or np.trapz
    auc_i = trapz(I_i, x=time_vec, axis=0)
    auc_e = trapz(I_e, x=time_vec, axis=0)

    EIRatio = auc_e / auc_i

    return I_e, I_i, EIRatio, time_vec

In [None]:
def genPlotConfig(ylim = None, colors = None, alpha = 0.5,
                  title = '', markersize = 5, labels = None,
                  linewidths = None, ticksize = 14, titlesize = 18, labelsize = 14,
                  zorders = None, frameon = False, fill = True):
    """Build the plotting-config dict consumed by ``plot_ERP_fit``.

    List-valued options default to fresh lists on every call, so mutating
    one returned config never changes later defaults:
        ylim=[-4, 8], colors=['k', 'royalblue', 'royalblue'],
        labels=['Empirical', 'C-I', 'Fitted'], linewidths=[1, 1, 2],
        zorders=[1, 2, 3].
    Index 0 styles the empirical trace, 1 the interval band and 2 the fit.

    Returns:
        dict with keys 'ylim', 'color', 'alpha', 'title', 'labels', 'linewidths',
        'ticksize', 'titlesize', 'zorders', 'frameon', 'labelsize',
        'markersize', 'fill'.
    """
    if ylim is None:
        ylim = [-4, 8]
    if colors is None:
        colors = ['k', 'royalblue', 'royalblue']
    if labels is None:
        labels = ['Empirical', 'C-I', 'Fitted']
    if linewidths is None:
        linewidths = [1, 1, 2]
    if zorders is None:
        zorders = [1, 2, 3]

    pDict = {'ylim': ylim,
             'color': colors,
             'alpha': alpha,
             'title': title,
             'labels': labels,
             'linewidths': linewidths,
             'ticksize': ticksize,
             'titlesize': titlesize,
             'zorders': zorders,
             'frameon': frameon,
             'labelsize': labelsize,
             'markersize': markersize,
             'fill': fill
             }

    return pDict

In [None]:
def linear_function(x, a, b):
    """Affine map ``a * x + b``."""
    slope_term = a * x
    return slope_term + b

def nonlinear_function(x, a, b, c, d):
    """Cubic polynomial ``a + b*x + c*x**2 + d*x**3``, summed left to right."""
    linear_part = a + b * x
    return linear_part + c * x ** 2 + d * x ** 3

def linreg_system(N, y, x=None):
    """NumPyro model: ``y ~ Normal(a * x + b, sigma)`` over a plate of size ``N``.

    Priors: a ~ Normal(0, 10), b ~ Normal(50, 100), sigma ~ HalfNormal(100).
    The noiseless prediction ``linear_function(x, a, b)`` is recorded as the
    deterministic site 'xdot'.

    Args:
        N: plate size (number of observations).
        y: values observed at site 'obs'.
        x: predictor values.
    """
    a = npr.sample('a', dist.Normal(0, 10))
    b = npr.sample('b', dist.Normal(50, 100))
    sigma= npr.sample('sigma', dist.HalfNormal(100))
    xdot = npr.deterministic('xdot', linear_function(x=x, a=a, b=b))

    with npr.plate('N', N):
        npr.sample('obs', dist.Normal(xdot, sigma), obs=y)

def nonlinear_system(N, y, x=None):
    """NumPyro model: ``y ~ Normal(cubic(x), sigma)`` over a plate of size ``N``.

    ``cubic(x) = a + b*x + c*x**2 + d*x**3`` via ``nonlinear_function``, with
    priors a, b, c, d ~ Normal(0, 10) and sigma ~ HalfNormal(10). The
    noiseless prediction is recorded as the deterministic site 'xdot'.

    Args:
        N: plate size (number of observations).
        y: values observed at site 'obs'.
        x: predictor values.
    """
    a = npr.sample('a', dist.Normal(0, 10))
    b = npr.sample('b', dist.Normal(0, 10))
    c = npr.sample('c', dist.Normal(0, 10))
    d = npr.sample('d', dist.Normal(0, 10))
    sigma= npr.sample('sigma', dist.HalfNormal(10))
    xdot = npr.deterministic('xdot', nonlinear_function(x=x, a=a, b=b, c=c, d=d))

    with npr.plate('N', N):
        npr.sample('obs', dist.Normal(xdot, sigma), obs=y)


def run_mcmc_from_system(target_system, x, y, num_warmup = 1000, num_samples = 2000, rng_seed = 0):
    """Fit a NumPyro model ``target_system(N=..., y=..., x=...)`` with a single NUTS chain.

    Args:
        target_system: NumPyro model function taking keyword args ``N``, ``y``, ``x``.
        x: predictor values (numpy array or xarray DataArray); ``N = x.size``.
        y: observations (numpy array or xarray DataArray).
        num_warmup, num_samples: NUTS warm-up and retained sample counts.
        rng_seed: seed for ``jax.random.PRNGKey`` (default 0, as before).

    Returns:
        The ``MCMC`` object after ``run``.
    """
    N = x.size

    # DataArrays are unwrapped to plain arrays before being handed to the model.
    if isinstance(x, xr.DataArray):
        x = x.to_numpy()
    if isinstance(y, xr.DataArray):
        y = y.to_numpy()

    nuts_kernel = NUTS(target_system, adapt_step_size=True)
    mcmc = MCMC(nuts_kernel, num_chains=1, num_warmup=num_warmup, num_samples=num_samples)
    rng_key = jax.random.PRNGKey(rng_seed)
    mcmc.run(rng_key, N=N, y=y, x=x)

    return mcmc

In [None]:
def plot_ERP_fit(time, data, x_fit, x_ppc_lo, x_ppc_hi, fig, axe, plotConfig, plot_fit = True, plot_fill = True, setLabels = True):
    """Plot an empirical ERP trace with an optional fitted trace and interval band.

    Args:
        time: x values shared by all traces.
        data: empirical trace (style index 0).
        x_fit: fitted trace (style index 2), drawn when ``plot_fit``.
        x_ppc_lo, x_ppc_hi: lower/upper band (style index 1), filled when ``plot_fill``.
        fig: unused; kept for API compatibility.
        axe: Axes to draw on.
        plotConfig: dict as built by ``genPlotConfig``. Reads 'color', 'title',
            'labels', 'linewidths', 'ticksize', 'titlesize', 'labelsize',
            'zorders', 'markersize', 'frameon' and, if present, 'alpha'
            (default 0.5).
        setLabels: add 'Time (ms)' / 'Voltage (mv)' axis labels.
    """
    colors = plotConfig['color']
    # Honour the configured transparency (it was previously overridden by 0.5).
    alpha = plotConfig.get('alpha', 0.5)
    title = plotConfig['title']
    labels = plotConfig['labels']
    linewidths = plotConfig['linewidths']
    ticksize = plotConfig['ticksize']
    titlesize = plotConfig['titlesize']
    labelsize = plotConfig['labelsize']
    zorders = plotConfig['zorders']
    markersize = plotConfig['markersize']

    axe.plot(time, data, lw = linewidths[0], color = colors[0], zorder = zorders[0], label = labels[0])

    if plot_fit:
        axe.plot(time, x_fit, lw = linewidths[2], marker = 'o', markersize = markersize, alpha = alpha, ls = '-', color=colors[2], label=labels[2], markevery = 70)

    axe.tick_params(length = 0)
    axe.tick_params(axis='both', which='major', labelsize=ticksize)
    axe.set_title(title, fontsize=titlesize)
    axe.set_xticks([])
    axe.set_yticks([])

    # `not`, not `~`: on a bool, ~False == -1 and ~True == -2 are both truthy.
    if not plotConfig['frameon']:
        axe.set_frame_on(False)

    if plot_fill:
        axe.fill_between(time, x_ppc_lo, x_ppc_hi, linewidth=linewidths[1], alpha=alpha, facecolor=colors[1], edgecolor=colors[1], zorder=zorders[1], label=labels[1])

    if setLabels:
        axe.set_xlabel('Time (ms)', size = labelsize)
        axe.set_ylabel('Voltage (mv)', size = labelsize)
    

In [None]:
def generate_random_colors(length, color_range="rgb", alpha=None):
    """
    Generate a list of random hex colors, optionally with an alpha channel.

    Args:
        length (int): Number of colors to generate.
        color_range (str, optional): "rgb" draws each 8-bit channel uniformly
            from 0-255. "hsv" draws a random hue at full saturation and value
            and converts it to RGB. Defaults to "rgb".
        alpha (float, optional): Opacity in [0.0, 1.0]. When given it is
            appended as the fourth byte ("#RRGGBBAA", the order matplotlib
            accepts). If None, colors are "#RRGGBB".

    Returns:
        list: ``length`` lowercase hex color strings.

    Raises:
        ValueError: If ``color_range`` is not "rgb" or "hsv".

    Draws use NumPy's global RNG, so results follow ``np.random.seed``.
    """
    import colorsys

    if color_range not in ("rgb", "hsv"):
        raise ValueError("Unsupported color range. Choose 'rgb' or 'hsv'.")

    alpha_hex = "" if alpha is None else f"{int(alpha * 255):02x}"

    def _to_hex(r, g, b):
        return f"#{r:02x}{g:02x}{b:02x}{alpha_hex}"

    colors = []
    for _ in range(length):
        if color_range == "rgb":
            # randint's upper bound is exclusive; 256 makes 255 reachable.
            r, g, b = (int(np.random.randint(0, 256)) for _ in range(3))
        else:
            hue = int(np.random.random() * 360)
            # Full saturation and value (the old s = v = 100 %), converted to RGB.
            rgb = colorsys.hsv_to_rgb(hue / 360.0, 1.0, 1.0)
            r, g, b = (int(round(c * 255)) for c in rgb)
        colors.append(_to_hex(r, g, b))
    return colors

In [None]:
def radar_factory(num_vars, frame='circle'):
    """
    Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it under the
    name 'radar', so it can be requested with ``projection='radar'``.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle', 'polygon'}
        Shape of frame surrounding axes.

    Returns
    -------
    theta : np.ndarray
        ``num_vars`` evenly spaced axis angles in radians, starting at 0.
    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarTransform(PolarAxes.PolarTransform):

        def transform_path_non_affine(self, path):
            # Paths with non-unit interpolation steps correspond to gridlines,
            # in which case we force interpolation (to defeat PolarTransform's
            # autoconversion to circular arcs).
            if path._interpolation_steps > 1:
                path = path.interpolated(num_vars)
            return Path(self.transform(path.vertices), path.codes)

    class RadarAxes(PolarAxes):

        name = 'radar'
        PolarTransform = RadarTransform

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default.

            Returns the list of Line2D objects, like ``Axes.plot``.
            """
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.append(x, x[0])
                y = np.append(y, y[0])
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta

In [None]:
def compute_FC(data, fc_only=True):
    """Functional connectivity as the Pearson correlation matrix of ``data``'s rows.

    Args:
        data: 2-D array with one variable (e.g. region) per row and observations in columns.
        fc_only: if True return only the correlation matrix; otherwise also the covariance.

    Returns:
        ``fc``, or ``(fc, cov)`` when ``fc_only`` is False.
    """
    fc = np.corrcoef(data)
    if fc_only:
        return fc
    # Covariance is only computed when the caller asks for it.
    return fc, np.cov(data)

In [None]:
def calculate_summary_statistics_brunel(x, features):
    """Calculate summary statistics of a simulated trace.

    The vector always starts with ``[max(x), argmax(x)]``. Feature groups are
    then appended in the order they appear in ``features``:
      * 'higher_moments': central moments of orders 2..8 (``scipy.stats.moment``);
      * 'signal_power': trapezoidal area under ``x`` with spacing 0.1.

    Parameters
    ----------
    x : np.ndarray
        Output of the simulator (treated as 1-D — TODO confirm for callers).
    features : iterable of str
        Names of the feature groups to append; unknown names are ignored.

    Returns
    -------
    sum_stats_vec : np.ndarray
        Concatenated summary statistics.
    featureRanges : dict
        Maps 'stats' and each appended feature name to ``[start, stop)``
        indices into ``sum_stats_vec``.

    Notes
    -----
    This function is no longer wrapped with ``numba.jit``. Its dict/str logic
    and SciPy calls cannot be compiled in nopython mode, so the wrapper gave no
    speed-up at best.
    """
    featureRanges = {}

    maxInd = np.argmax(x)
    xMax = x[maxInd]

    sum_stats_vec = np.concatenate((np.array([xMax]),
                                    np.array([maxInd])))
    featureRanges['stats'] = [0, sum_stats_vec.shape[0]]

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapz = getattr(np, "trapezoid", None) or np.trapz

    for item in features:

        if item == 'higher_moments':
            start = sum_stats_vec.shape[0]
            # Order passed positionally: the keyword name differs across SciPy versions.
            moments = np.array([stats.moment(x, order) for order in range(2, 9)])
            sum_stats_vec = np.concatenate((sum_stats_vec, moments))
            featureRanges[item] = [start, sum_stats_vec.shape[0]]

        if item == 'signal_power':
            start = sum_stats_vec.shape[0]
            # NOTE: despite the name, this is the area under x, not its power.
            x_area = trapz(x, dx=0.1)
            sum_stats_vec = np.concatenate((sum_stats_vec, np.array([x_area])))
            featureRanges[item] = [start, sum_stats_vec.shape[0]]

    return sum_stats_vec, featureRanges

In [None]:
def get_ax_size(ax, fig=None):
    """Return the (width, height) of ``ax`` in display pixels.

    Args:
        ax: matplotlib Axes.
        fig: figure whose DPI is used; defaults to ``ax.figure``. The code
            no longer depends on a global named ``fig``.

    Returns:
        (width, height) in pixels: the bounding box converted to inches and
        multiplied by ``fig.dpi``.
    """
    if fig is None:
        fig = ax.figure
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    return bbox.width * fig.dpi, bbox.height * fig.dpi
    
def plot_raster(raster, ax, axis=-1, offset=0, marker='.', color='b', alpha=1, markersize=1):
    """Scatter-plot a 2-D spike raster, one row of markers per unit.

    Args:
        raster: 2-D array-like of spike counts; values above 1 are clipped to 1.
            It is copied, so the caller's data is not modified.
        ax: Axes providing ``scatter``.
        axis: the axis that indexes units (the other axis is time/bin index).
        offset: vertical offset added to every unit's row position.
        marker, color, alpha, markersize: forwarded to ``ax.scatter``
            (``markersize`` as ``s``).
    """
    raster = np.array(raster)
    raster[raster > 1] = 1

    # Put the unit axis last so that raster[:, unit] is always one unit's train;
    # previously `axis` set only the loop length, not the indexing.
    raster = np.moveaxis(raster, axis, -1)

    for rInd in range(raster.shape[-1]):
        t_spikes = raster[:, rInd]
        sinds = np.where(t_spikes != 0)[0]
        ax.scatter(sinds, t_spikes[sinds]*rInd + offset, marker=marker, color=color, alpha=alpha, s=markersize)

In [None]:
def simulate_brunel(g, save_name = '', save_path = '', save_results = False):
    """Simulate a Brunel-style E/I network of ``iaf_psc_alpha`` neurons in NEST.

    The network has 10,000 excitatory and 2,500 inhibitory neurons with
    fixed in-degree connectivity. It is driven by a Poisson generator and
    additionally receives a noisy step current between 350 and 900 ms. Spikes
    of the first ``N_rec`` neurons of each population are recorded to memory.

    Parameters
    ----------
    g : float or array-like
        Relative inhibitory strength (``J_in = -g * J_ex``). If array-like,
        its first element is used.
    save_name : str
        File stem of the ``.npz`` written when ``save_results`` is True.
    save_path : str
        Directory of that file. An empty string means the current working
        directory.
    save_results : bool
        If True, save ``g`` together with the excitatory/inhibitory senders and
        spike times to ``<save_path>/<save_name>.npz`` and return None.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(ex_senders, ex_times)`` when ``save_results`` is False, else None.
    """
    ###############################################################################
    # Simulation parameters

    g = np.array(g).ravel()[0]  # accept scalars as well as 1-element arrays

    dt = 0.1          # the resolution in ms
    simtime = 1000.0  # simulation time in ms
    num_workers = 8   # number of threads

    nest.ResetKernel()
    nest.SetKernelStatus({'local_num_threads': num_workers, "print_time": False, "overwrite_files": True})
    nest.resolution = dt

    ###############################################################################
    # Neural parameters

    ctime_start = 350  # start of external current (ms)
    ctime_end = 900    # end of external current (ms)
    meanCurr = 150     # mean of the noisy step current
    stdCurr = 1        # standard deviation of the noisy step current

    delay = 1.5    # synaptic delay in ms
    eta = 1        # external rate relative to threshold rate
    tauSyn = 0.5   # synaptic time constant in ms
    tauMem = 20.0  # time constant of membrane potential in ms
    CMem = 250.0   # capacitance of membrane in pF
    theta = 20.0   # membrane threshold potential in mV

    neuron_params = {"C_m": CMem,
                     "tau_m": tauMem,
                     "tau_syn_ex": tauSyn,
                     "tau_syn_in": tauSyn,
                     "t_ref": 2.0,
                     "E_L": 0.0,
                     "V_reset": 10.0,
                     "V_m": 0.0,
                     "V_th": theta}

    J = 0.1  # postsynaptic amplitude in mV
    J_unit = ComputePSPnorm(tauMem, CMem, tauSyn)
    J_ex = J / J_unit   # amplitude of excitatory postsynaptic current
    J_in = -g * J_ex    # amplitude of inhibitory postsynaptic current

    ###############################################################################
    # Network parameters

    order = 2500
    NE = 4 * order  # number of excitatory neurons
    NI = 1 * order  # number of inhibitory neurons
    epsilon = 0.1   # connection probability
    N_rec = 100     # number of neurons recorded per population

    CE = int(epsilon * NE)  # number of excitatory synapses per neuron
    CI = int(epsilon * NI)  # number of inhibitory synapses per neuron

    conn_params_ex = {'rule': 'fixed_indegree', 'indegree': CE}
    conn_params_in = {'rule': 'fixed_indegree', 'indegree': CI}

    ###############################################################################
    # Rate of external Poisson input

    nu_th = (theta * CMem) / (J_ex * CE * np.exp(1) * tauMem * tauSyn)
    nu_ex = eta * nu_th
    p_rate = 1000.0 * nu_ex * CE

    ###############################################################################
    # Create neurons, input devices and recorders

    nodes_ex = nest.Create("iaf_psc_alpha", NE, params=neuron_params)
    nodes_in = nest.Create("iaf_psc_alpha", NI, params=neuron_params)
    noise = nest.Create("poisson_generator", params={"rate": p_rate})
    espikes = nest.Create("spike_recorder")
    ispikes = nest.Create("spike_recorder")
    pop_stepcurrent = nest.Create('noise_generator', n = 1)

    ###############################################################################
    # Configure the step current, recorders and synapse models

    pop_stepcurrent.set(start = ctime_start, stop = ctime_end, mean = meanCurr, std = stdCurr)
    espikes.set(label="brunel-ex", record_to="memory")
    ispikes.set(label="brunel-in", record_to="memory")

    nest.CopyModel("static_synapse", "excitatory",
                   {"weight": J_ex, "delay": delay})
    nest.CopyModel("static_synapse", "inhibitory",
                   {"weight": J_in, "delay": delay})

    ###############################################################################
    # Connect network elements

    nest.Connect(pop_stepcurrent, nodes_ex, syn_spec={'weight': 1.})
    nest.Connect(pop_stepcurrent, nodes_in, syn_spec={'weight': 1.})

    nest.Connect(noise, nodes_ex, syn_spec="excitatory")
    nest.Connect(noise, nodes_in, syn_spec="excitatory")

    nest.Connect(nodes_ex[:N_rec], espikes, syn_spec="excitatory")
    nest.Connect(nodes_in[:N_rec], ispikes, syn_spec="excitatory")

    nest.Connect(nodes_ex, nodes_ex + nodes_in, conn_params_ex, "excitatory")
    nest.Connect(nodes_in, nodes_ex + nodes_in, conn_params_in, "inhibitory")

    nest.Simulate(simtime)

    ###############################################################################
    # Collect recorded spikes and optionally save them

    ex_senders = espikes.events['senders']
    in_senders = ispikes.events['senders']
    ex_times = espikes.events['times']
    in_times = ispikes.events['times']

    if save_results:
        # os.path.join keeps an empty save_path relative; appending '/' to ''
        # would have written the file to the filesystem root.
        out_file = os.path.join(save_path, save_name + '.npz')
        np.savez_compressed(file = out_file, g = g, ex_senders = ex_senders, in_senders = in_senders, ex_times = ex_times, in_times = in_times)
        return None
    return ex_senders, ex_times

In [None]:
def plot_sample_distribution(fig, ax):
    """Draw three overlapping, shaded Gaussian densities on a bare axis.

    The densities have means 0, 5 and 12 and standard deviations 5, 3 and 4.
    They are filled in royal blue, magenta and red (alpha 0.2) over
    x in [-27, 27) and shown with x limits [-17, 26]. Frame and ticks are
    removed. *fig* is unused; it is kept for a uniform plotting signature.
    """
    alpha = 0.2
    fill_colors = ['royalblue', 'm', 'r']

    x_limits = [-27, 27]
    x_axis = np.arange(x_limits[0], x_limits[1], 0.01)

    sdvs = [5, 3, 4]
    avgs = [0, 5, 12]

    for fill_color, avg, sdv in zip(fill_colors, avgs, sdvs):
        density = norm.pdf(x_axis, avg, sdv)
        ax.fill_between(x_axis, density, alpha=alpha, color=fill_color, linewidth=0)

    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim([-17, 26])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull
from numpy.linalg import inv

def ls_ellipsoid(xx, yy, zz):
    """Least-squares fit of a quadric surface to 3-D points.

    Fits the coefficients ``A..I`` of
    ``A x^2 + B y^2 + C z^2 + D xy + E xz + F yz + G x + H y + I z = 1``
    and returns them followed by ``-1``, i.e. the 10-vector of
    ``... + J = 0`` with ``J = -1`` as consumed by ``polyToParams3D``.

    Parameters
    ----------
    xx, yy, zz : array-like, 1-D
        Coordinates of the sample points (equal lengths).

    Returns
    -------
    numpy.ndarray, shape (10,)

    Raises
    ------
    numpy.linalg.LinAlgError
        If the points do not determine a unique quadric (rank-deficient
        design matrix).
    """
    x = np.asarray(xx)[:, np.newaxis]
    y = np.asarray(yy)[:, np.newaxis]
    z = np.asarray(zz)[:, np.newaxis]

    J = np.hstack((x*x, y*y, z*z, x*y, x*z, y*z, x, y, z))
    K = np.ones_like(x)

    # Solve J @ coeffs ~= K directly; forming inv(J.T @ J) squares the
    # condition number and loses precision on nearly-degenerate clouds.
    coeffs, _, rank, _ = np.linalg.lstsq(J, K, rcond=None)
    if rank < J.shape[1]:
        raise np.linalg.LinAlgError(
            "Singular design matrix: points do not determine a unique quadric")

    return np.append(coeffs, -1)

def polyToParams3D(vec, printMe):
    """Convert quadric coefficients into ellipsoid center, semi-axes and rotation.

    Parameters
    ----------
    vec : numpy.ndarray, shape (10,)
        Coefficients ``[A, B, C, D, E, F, G, H, I, J]`` of
        ``A x^2 + B y^2 + C z^2 + D xy + E xz + F yz + G x + H y + I z + J = 0``
        (the layout returned by ``ls_ellipsoid``).
    printMe : bool
        Accepted for compatibility; not used.

    Returns
    -------
    center : numpy.ndarray, shape (3,)
        Ellipsoid center.
    axes : numpy.ndarray, shape (3,)
        Semi-axis lengths, in the eigenvalue order returned by
        ``np.linalg.eig``.
    inve : numpy.ndarray, shape (3, 3)
        Inverse of the eigenvector matrix. For a symmetric quadric this is
        its transpose, so its rows are the axis directions.
    """
    # Off-diagonal entries of the symmetric 4x4 quadric matrix are half the
    # cross / linear coefficients.
    d2, e2, f2, g2, h2, i2 = vec[3:9] / 2.0
    quadric = np.array([
        [vec[0], d2, e2, g2],
        [d2, vec[1], f2, h2],
        [e2, f2, vec[2], i2],
        [g2, h2, i2, vec[9]],
    ])

    # Center solves A3 @ c = -(linear half-coefficients).
    linear_half = np.array([g2, h2, i2])
    center = -np.dot(inv(quadric[:3, :3]), linear_half)

    # Translate the quadric to its center, then normalise by the constant term.
    translate = np.eye(4)
    translate[3, :3] = center
    centered = translate.dot(quadric.dot(translate.T))
    scale = -centered[3, 3]
    eigvals, eigvecs = np.linalg.eig(centered[:3, :3] / scale)

    semi_axes = np.sqrt(1.0 / np.abs(eigvals))
    return (center, semi_axes, inv(eigvecs))

In [None]:
def rotate_points(points, axis, angle_degrees):
    """Rotate 3-D points about an axis through the origin.

    Parameters
    ----------
    points : array-like, shape (3,) or (N, 3)
        Points to rotate.
    axis : array-like, shape (3,)
        Rotation axis. It is normalised internally, so only its direction
        matters.
    angle_degrees : float
        Right-handed rotation angle in degrees.

    Returns
    -------
    numpy.ndarray
        Rotated points with the same shape as *points*.

    Raises
    ------
    ValueError
        If *axis* has zero length.
    """
    from scipy.spatial.transform import Rotation

    axis = np.asarray(axis, dtype=float)
    axis_length = np.linalg.norm(axis)
    if axis_length == 0:
        raise ValueError("rotation axis must be non-zero")

    # from_rotvec takes the vector's length as the rotation angle, so the
    # axis must be a unit vector or the rotation angle is silently scaled.
    rotvec = axis / axis_length * np.deg2rad(angle_degrees)
    return Rotation.from_rotvec(rotvec).apply(points)

In [None]:
def jitter_raster(binary_array, jitter_amount=1):
    """Randomly shift every 'one' in each row of a binary array.

    Each 1 moves by an independent integer offset drawn uniformly from
    ``[-jitter_amount, jitter_amount]`` using the global ``np.random`` state.
    Positions that fall off either end of a row wrap around (circular
    boundary). Ones that land on the same position merge, so a row can
    end up with fewer ones than it started with.

    Args:
        binary_array: A 2D numpy array with zeros and ones.
        jitter_amount: The maximum number of positions to jitter by (default: 1).

    Returns:
        A new array (same shape and dtype as the input) with jittered 'one' values.
    """
    rows, cols = binary_array.shape
    jittered_array = np.zeros_like(binary_array)

    for i in range(rows):
        one_indices = np.where(binary_array[i] == 1)[0]
        jitter = np.random.randint(-jitter_amount, jitter_amount + 1, size=len(one_indices))
        # Do not clip the offsets: clipping them to [0, cols - 1] would discard
        # every negative shift. Out-of-range positions wrap around instead.
        jittered_indices = (one_indices + jitter) % cols
        jittered_array[i, jittered_indices] = 1

    return jittered_array

In [None]:
def letter_annotation(ax, xoffset, yoffset, letter, fontsize=None, fontweight='bold', aux_args=None):
    """Write a panel letter at axes-fraction coordinates ``(xoffset, yoffset)``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to annotate.
    xoffset, yoffset : float
        Position in axes coordinates (``ax.transAxes``).
    letter : str
        Text to draw.
    fontsize : float, optional
        Defaults to ``plt.rcParams['font.size']`` read at call time.
    fontweight : str, default 'bold'
    aux_args : dict, optional
        Extra keyword arguments forwarded to ``ax.text``.
    """
    if fontsize is None:
        # Resolved per call; a signature default would be frozen at definition
        # time and ignore later rcParams changes.
        fontsize = plt.rcParams['font.size']
    ax.text(x=xoffset, y=yoffset, s=letter, transform=ax.transAxes,
            fontsize=fontsize, weight=fontweight, **(aux_args or {}))

def get_alphabet_list(length=26):
    """Return the first *length* upper-case ASCII letters as a list of 1-char strings."""
    return list(string.ascii_uppercase)[:length]

In [None]:
from ptitprince import *

def plot_RainClouds(x = None, y = None, hue = None, data = None, num_samples=1000, swarm_sample=False,
              order = None, hue_order = None, plot_legend=False,
              orient = "v", width_viol = .7, width_box = .15, clouds=True,
              palette = "Set2", bw = .2, linewidth = 1, cut = 0.,
              scale = "area", jitter = 1, move = 0., offset = None,
              point_size = 3, ax = None, pointplot = False,
              alpha = None, dodge_violin=False, dodge_boxes=True, linecolor = 'red', **kwargs):
    """Draw a raincloud plot (half-violin + boxplot + strip plot) on *ax*.

    Uses ``half_violinplot`` and ``stripplot`` from the ``ptitprince`` star
    import plus seaborn's ``boxplot``/``pointplot``.

    Parameters not forwarded unchanged to those functions
    -----------------------------------------------------
    num_samples : int
        Rows drawn for the strip ("rain") layer when ``swarm_sample`` is
        True. Capped at ``len(data)``.
    swarm_sample : bool
        Subsample ``data`` for the strip layer only. Clouds, boxes and points
        always use all rows.
    plot_legend : bool
        Draw a pruned legend on *ax* when ``hue`` is given.
    clouds : bool
        Whether to draw the half-violins.
    dodge_violin, dodge_boxes : bool
        ``dodge`` for the violins, and for the boxes/strips respectively.
    alpha : float, optional
        If given, applied to all collections/artists drawn before the strips.
    linecolor : color
        Point-plot colour when ``pointplot`` is True and ``hue`` is None.
    **kwargs
        Routed by key prefix: ``cloud_*`` → half-violins, ``box_*`` →
        boxplot, ``rain_*`` → strip plot, ``point_*`` → point plot.
        Unprefixed keys go to the half-violins.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by the strip-plot call.
    """
    if orient == 'h':  # swap x and y
        x, y = y, x
    if ax is None:
        ax = plt.gca()

    if offset is None:
        offset = max(width_box/1.8, .15) + .05
    n_plots = 3
    split = False
    boxcolor = "black"
    # Transparent box faces drawn above the clouds.
    boxprops = {'facecolor': 'none', "zorder": 10}

    kwcloud = dict()
    kwbox   = dict(saturation = 1, whiskerprops = {'linewidth': 2, "zorder": 10})
    kwrain  = dict(zorder = 0, edgecolor = None)
    kwpoint = dict(capsize = 0., errwidth = 0., zorder = 20)
    # Route each prefixed kwarg to its layer. Only a leading prefix counts, so
    # a key that merely contains e.g. "box_" elsewhere is not misrouted.
    layer_kwargs = (("cloud_", kwcloud), ("box_", kwbox), ("rain_", kwrain), ("point_", kwpoint))
    for key, value in kwargs.items():
        for prefix, target in layer_kwargs:
            if key.startswith(prefix):
                target[key[len(prefix):]] = value
                break
        else:
            kwcloud[key] = value

    if swarm_sample:
        # DataFrame.sample raises if asked for more rows than exist.
        swarm_df = data.sample(min(num_samples, len(data)))
    else:
        swarm_df = data

    # Draw cloud/half-violin
    if clouds:
        half_violinplot(x = x, y = y, hue = hue, data = data,
                        order = order, hue_order = hue_order,
                        orient = orient, width = width_viol, dodge=dodge_violin,
                        inner = None, palette = palette, bw = bw,  linewidth = linewidth,
                        cut = cut, scale = scale, split = split, offset = offset, ax = ax, **kwcloud)

    # Draw umbrella/boxplot
    sns.boxplot(x = x, y = y, hue = hue, data = data, orient = orient, width = width_box,
                order = order, hue_order = hue_order, dodge=dodge_boxes,
                color = boxcolor, showcaps = True, boxprops = boxprops,
                palette = palette, ax = ax, **kwbox)

    # Set alpha of the clouds and boxes drawn so far
    if alpha is not None:
        _ = plt.setp(ax.collections + ax.artists, alpha = alpha)

    # Draw rain/stripplot
    ax = stripplot(x = x, y = y, hue = hue, data = swarm_df, orient = orient, dodge=dodge_boxes,
                   order = order, hue_order = hue_order, palette = palette,
                   move = move, size = point_size, jitter = jitter,
                   width = width_box, ax = ax, **kwrain)

    # Add pointplot
    if pointplot:
        n_plots = 4
        if hue is not None:
            sns.pointplot(x = x, y = y, hue = hue, data = data,
                          orient = orient, order = order, hue_order = hue_order,
                          dodge = width_box/2., palette = palette, ax = ax, **kwpoint)
        else:
            sns.pointplot(x = x, y = y, hue = hue, data = data, color = linecolor,
                          orient = orient, order = order, hue_order = hue_order,
                          dodge = width_box/2., ax = ax, **kwpoint)

    # Prune the legend (each hue level appears once per layer), add legend title
    if hue is not None and plot_legend:
        handles, labels = ax.get_legend_handles_labels()
        n_keep = len(labels) // n_plots
        # Draw on `ax` explicitly: plt.legend targets the *current* axes, which
        # is not necessarily the axes passed in.
        _ = ax.legend(handles[0:n_keep], labels[0:n_keep],
                      bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
                      title = str(hue))

    # Trim the far limit of the category axis to fit (if needed)
    if orient == "h":
        ylim = list(ax.get_ylim())
        ylim[-1]  -= (width_box + width_viol)/4.
        _ = ax.set_ylim(ylim)
    elif orient == "v":
        xlim = list(ax.get_xlim())
        xlim[-1]  -= (width_box + width_viol)/4.
        _ = ax.set_xlim(xlim)

    return ax

# Loading directories

In [None]:
fig_save_loc = ''           # path prefix prepended to saved figure file names
parent_preprocess_dir = ''  # path prefix prepended to preprocessed input file names

# Plotting parameters
absence_color, presence_color = 'crimson', 'dodgerblue'
all_letters = get_alphabet_list()  # ['A', ..., 'Z'], presumably for panel labels

# Spikes

In [None]:
# Sub-folder (appended to fig_save_loc) for the spike-analysis figures.
save_string = 'SPK/'

# NOTE(review): this rebinds `colors`, shadowing the `matplotlib.colors` module
# imported at the top of the file; any later cell expecting the module breaks.
colors=['dodgerblue', 'crimson']
presence_color, absence_color = colors

In [None]:
# Correlation table read by plot_spk_correlations ('Measure', 'Correlation',
# 'Condition' and 'Significance' columns are used there).
spk_performance = pd.read_excel('./Correlations_Paper2017.xlsx', sheet_name='Social-Asocial')

# Normalized empirical firing rates; their `neuron` coordinate selects the
# neurons used for `rate_neurons` below.
firing_rates_emp = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_SPK_rates_norm.nc')
neuron_names = firing_rates_emp.neuron.to_numpy()

# Presumably (from the file names) aggregated spike rasters and per-neuron properties.
# NOTE: xr.load_dataarray already returns an in-memory array, so `.load()` is a no-op.
raster_array = xr.load_dataarray(parent_preprocess_dir + 'aggSpikes_2014.nc').load()
raster_prop = xr.load_dataarray(parent_preprocess_dir + 'neuronProps_2014.nc').load()

spike_array = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_SPKs.nc')

# Peak over time of the trial-averaged activity of the selected neurons, as a
# long DataFrame with a 'value' column (one row per remaining coordinate combination).
rate_neurons = spike_array.sel(neuron=neuron_names).copy().mean('trial').max('time').to_dataframe(name = 'value').reset_index()

# Per-condition LDS fit results, unpacked by plot_spk_lds_dyn.
# SECURITY: pickle.load can execute arbitrary code — only load trusted files.
with open(parent_preprocess_dir + 'spk_fr_lds_posterior.pkl', 'rb') as file:
    spk_lds_dict = pickle.load(file)

In [None]:
# Simulation bank (presumably the 10,000 Brunel simulations named in the path):
# sampled parameters `theta`, simulated firing rates and extracted features.
spk_dir = './DCM_SPK_SBI/input/brunel_delta_10000_sims/'
spk_input_data = np.load(spk_dir + 'input_data.npz')
theta_spk = spk_input_data['theta']
firing_rates_sim = spk_input_data['firing_rates']
fr_sim_maxs = xr.load_dataarray(spk_dir + 'features.nc')

# One example network with inhibition ratio g = 5.5; returns the excitatory
# senders and spike times because save_results is False.
# NOTE(review): this runs a full NEST network simulation every time the cell executes.
ex_senders, ex_times = simulate_brunel(5.5)

In [None]:
# Posterior results for the fit to neuron-averaged rates and for the fits to all neurons.
posterior_avg = xr.load_dataset(parent_preprocess_dir + 'Demolliens_SPK_AvgNeurons_Posterior.nc')
posterior_all = xr.load_dataset(parent_preprocess_dir + 'Demolliens_SPK_AllNeurons_Posterior.nc')

# Averaged fit: posterior-check rates (plotted against the empirical averages),
# posterior samples whose `sample` coordinate holds the condition labels, and empirical rates.
posterior_checks = posterior_avg.posterior_checks
posterior_samples_avg = posterior_avg.posterior_array
avg_conditions = posterior_samples_avg.sample.to_numpy()
emp_avgs = posterior_avg.emp_rates

# All-neuron fits: the `sample_1` coordinate holds the condition labels.
posterior_samples_all = posterior_all.posterior_array
all_conditions = posterior_samples_all.sample_1.to_numpy()
posterior_means_all = posterior_all.posterior_means

posterior_check_spikes = posterior_avg.posterior_check_spikes

In [None]:
# Long-format (condition, value) tables of posterior samples, one row per sample,
# for the box / raincloud plots.
df_samples_avg = pd.DataFrame(dict(condition=avg_conditions, value=posterior_samples_avg.to_numpy()))
df_samples_all = pd.DataFrame(dict(condition=all_conditions, value=posterior_samples_all.to_numpy()))

## Empirical

In [None]:
def plot_raster_and_rate_two(fig, axes, neurons, lw=1, markersize=1, marker='o', colors=('dodgerblue', 'crimson'), raster_offset=80, ylabelpads=(5, 0), plot_labels_2=True):
    """Plot smoothed rates (top row) and rasters (bottom row) for two example neurons.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    axes : numpy.ndarray of Axes, shape (2, 2)
        ``axes[0, i]`` receives the rate traces and ``axes[1, i]`` the raster
        of ``neurons[i]``.
    neurons : sequence of xarray.DataArray
        One raster per neuron with a ``cond`` dimension containing
        'Presence' and 'Absence'. Titled 'Social Neuron' and
        'Asocial Neuron', in that order.
    lw : float
        Line width of the rate traces.
    markersize, marker
        Raster marker style.
    colors : (presence_color, absence_color)
    raster_offset : float
        Vertical offset of the Presence raster above the Absence raster.
        Also sets the raster y ticks (0, offset, 2 * offset).
    ylabelpads : (rate_labelpad, trial_labelpad)
    plot_labels_2 : bool
        Also label the y axes of the second column.
    """
    presence_color, absence_color = colors
    neuron_titles = ['Social Neuron', 'Asocial Neuron']
    title_fontsize = plt.rcParams['font.size']
    # Absence raster starts at 0 and Presence is stacked raster_offset above it.
    raster_yticks = np.array([0, raster_offset, 2 * raster_offset])

    for n_ind, neuron in enumerate(neurons):
        rate_ax, raster_ax = axes[0, n_ind], axes[1, n_ind]

        p_raster = neuron.sel(cond = 'Presence').copy()
        a_raster = neuron.sel(cond = 'Absence').copy()

        # NOTE(review): the x100 factor presumably converts smoothed per-bin
        # values to Hz — confirm against smooth_rates.
        p_rate = np.nanmean(smooth_rates(p_raster, remove_tails=True, axis=0)*100, 1)
        a_rate = np.nanmean(smooth_rates(a_raster, remove_tails=True, axis=0)*100, 1)

        plot_raster(a_raster, raster_ax, color=absence_color, markersize=markersize, marker=marker)
        plot_raster(p_raster, raster_ax, offset=raster_offset, color=presence_color, markersize=markersize, marker=marker)

        rate_ax.plot(p_rate, label = 'Presence', c=presence_color, lw=lw)
        rate_ax.plot(a_rate, label = 'Absence', c=absence_color, lw=lw)

        raster_ax.set_xlabel('Time (ms)')

        modify_axis_spines(raster_ax, which = ['x', 'y'], xticks = np.arange(0, 101, 50), yticks = raster_yticks)
        modify_axis_spines(rate_ax, which = ['y'], yticks=[5, 30])
        rate_ax.set_xticks([])
        rate_ax.set_title(neuron_titles[n_ind], fontsize=title_fontsize)

    axes[0, 0].set_ylabel('Rate (Hz)', labelpad=ylabelpads[0])
    axes[1, 0].set_ylabel('Trial', labelpad=ylabelpads[1])

    if plot_labels_2:
        axes[0, 1].set_ylabel('Rate (Hz)', labelpad=ylabelpads[0])
        axes[1, 1].set_ylabel('Trial', labelpad=ylabelpads[1])

    for ax in axes.ravel():
        ax.patch.set_alpha(0.0)


In [None]:
def plot_raster_and_rate_scalebar(fig, axes, neurons, lw=1, markersize=1, marker='o', colors=('dodgerblue', 'crimson'), raster_offset=80, scalebar_fontsize=0, fixed_scale=5):
    """Plot rates (top row, with a vertical Hz scale bar instead of axes) and rasters (bottom row).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    axes : numpy.ndarray of Axes, shape (2, n_neurons)
        ``axes[0, i]`` gets the rates of ``neurons[i]``, ``axes[1, i]`` its raster.
    neurons : sequence of xarray.DataArray
        Rasters with a ``cond`` dimension containing 'Presence' and
        'Absence'. At most two (titles: 'Social Neuron', 'Asocial Neuron').
    lw : float
        Line width of the rate traces.
    markersize, marker
        Raster marker style.
    colors : (presence_color, absence_color)
    raster_offset : float
        Vertical offset of the Presence raster above the Absence raster.
        Also sets the raster y ticks (0, offset, 2 * offset).
    scalebar_fontsize : float
        0 means "use the current rcParams font size".
    fixed_scale : float
        Length of the scale bar, labelled in Hz.
    """
    if scalebar_fontsize == 0:
        scalebar_fontsize = plt.rcParams['font.size']

    presence_color, absence_color = colors
    neuron_titles = ['Social Neuron', 'Asocial Neuron']
    title_fontsize = plt.rcParams['font.size']
    # Absence raster starts at 0 and Presence is stacked raster_offset above it.
    raster_yticks = np.array([0, raster_offset, 2 * raster_offset])

    for n_ind, neuron in enumerate(neurons):
        rate_ax, raster_ax = axes[0, n_ind], axes[1, n_ind]

        p_raster = neuron.sel(cond = 'Presence').copy()
        a_raster = neuron.sel(cond = 'Absence').copy()

        # NOTE(review): the x100 factor presumably converts smoothed per-bin
        # values to Hz — confirm against smooth_rates.
        p_rate = np.nanmean(smooth_rates(p_raster, remove_tails=True, axis=0)*100, 1)
        a_rate = np.nanmean(smooth_rates(a_raster, remove_tails=True, axis=0)*100, 1)

        plot_raster(a_raster, raster_ax, color=absence_color, markersize=markersize, marker=marker)
        plot_raster(p_raster, raster_ax, offset=raster_offset, color=presence_color, markersize=markersize, marker=marker)

        rate_ax.plot(p_rate, label = 'Presence', c=presence_color, lw=lw)
        rate_ax.plot(a_rate, label = 'Absence', c=absence_color, lw=lw)

        scalebar_y = ScaleBar(dx=1, rotation='vertical', length_fraction=0.3, fixed_value=fixed_scale, frameon=False, scale_formatter=lambda value,unit: f"{value} Hz",
                              location='center left', scale_loc='left', font_properties=dict(size=scalebar_fontsize))

        raster_ax.set_xlabel('Time (ms)')
        rate_ax.add_artist(scalebar_y)
        rate_ax.axis('off')

        modify_axis_spines(raster_ax, which = ['x', 'y'], xticks = np.arange(0, 101, 50), yticks = raster_yticks)

        rate_ax.set_title(neuron_titles[n_ind], fontsize=title_fontsize)

    axes[1, 0].set_ylabel('Trial')

    for ax in axes.ravel():
        ax.patch.set_alpha(0.0)


In [None]:
def plot_raster_and_rate(fig, axes, neuron, lw=1, markersize=1, marker='o', colors=('dodgerblue', 'crimson'), raster_offset=80, scalebar_fontsize=0, fixed_scale=5):
    """Plot one neuron's rates (``axes[0]``, with a Hz scale bar) and raster (``axes[1]``).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    axes : sequence of two Axes
        ``axes[0]`` receives the rate traces, ``axes[1]`` the raster.
    neuron : xarray.DataArray
        Raster with a ``cond`` dimension containing 'Presence' and 'Absence'.
    lw : float
        Line width of the rate traces.
    markersize, marker
        Raster marker style.
    colors : (presence_color, absence_color)
    raster_offset : float
        Vertical offset of the Presence raster above the Absence raster.
        Also sets the raster y ticks (0, offset, 2 * offset).
    scalebar_fontsize : float
        0 means "use the current rcParams font size".
    fixed_scale : float
        Length of the scale bar, labelled in Hz.
    """
    if scalebar_fontsize == 0:
        scalebar_fontsize = plt.rcParams['font.size']

    presence_color, absence_color = colors
    rate_ax, raster_ax = axes[0], axes[1]

    p_raster = neuron.sel(cond = 'Presence').copy()
    a_raster = neuron.sel(cond = 'Absence').copy()

    # NOTE(review): the x100 factor presumably converts smoothed per-bin
    # values to Hz — confirm against smooth_rates.
    p_rate = np.nanmean(smooth_rates(p_raster, remove_tails=True, axis=0)*100, 1)
    a_rate = np.nanmean(smooth_rates(a_raster, remove_tails=True, axis=0)*100, 1)

    plot_raster(a_raster, raster_ax, color=absence_color, markersize=markersize, marker=marker)
    plot_raster(p_raster, raster_ax, offset=raster_offset, color=presence_color, markersize=markersize, marker=marker)

    raster_ax.set_ylabel('Trial')
    raster_ax.set_xlabel('Time (ms)')

    # Absence raster starts at 0 and Presence is stacked raster_offset above it.
    raster_yticks = np.array([0, raster_offset, 2 * raster_offset])
    modify_axis_spines(raster_ax, which = ['x', 'y'], xticks = np.arange(0, 101, 50), yticks = raster_yticks)

    rate_ax.plot(p_rate, label = 'Presence', c=presence_color, lw=lw)
    rate_ax.plot(a_rate, label = 'Absence', c=absence_color, lw=lw)

    scalebar_y = ScaleBar(dx=1, rotation='vertical', length_fraction=0.3, fixed_value=fixed_scale, frameon=False, scale_formatter=lambda value,unit: f"{value} Hz",
                          location='center left', scale_loc='left', font_properties=dict(size=scalebar_fontsize))

    rate_ax.add_artist(scalebar_y)
    rate_ax.axis('off')

    for ax in axes:
        ax.patch.set_alpha(0.0)


In [None]:
def plot_spk_correlations(fig, ax, bar_width=0.35, palette='bone', lgd_loc=(0.85, 0.9), offsets=(0.02, 0.03), vas=('top', 'bottom'), ncol=1,
                          linewidth=0, orient='v', xticklabels=None, markerscale=1, legend_spacing=1.5, handletextpad=0.5, handlelength=1, lgd_labels=None):
    """Bar plot of the module-level ``spk_performance`` correlations with significance labels.

    Bars are grouped by 'Measure' and coloured by 'Condition'
    ('Preferred' / 'Non-preferred'). Each bar is narrowed to *bar_width*
    (kept centred) and annotated with the 'Significance' entry at the same
    position in ``spk_performance``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    ax : matplotlib.axes.Axes
    offsets : (positive_offset, negative_offset)
        Distance of the label from the bar end, for positive / negative bars.
    vas : (positive_va, negative_va)
        Vertical alignment of the label for positive / negative bars.
    xticklabels : sequence of str, optional
        Replacement x tick labels. None or empty keeps seaborn's labels.
    lgd_labels : sequence of str, optional
        Legend labels. Defaults to ['Preferred', 'Non-Preferred'].
    Remaining parameters style the bars and the legend.
    """
    if lgd_labels is None or len(lgd_labels) == 0:
        lgd_labels = ['Preferred', 'Non-Preferred']

    bars = sns.barplot(x='Measure', y='Correlation', hue='Condition', data=spk_performance, hue_order=['Preferred', 'Non-preferred'],
                       palette=palette, ax=ax, orient=orient)

    bars.legend_.remove()
    ax.set_xlabel('')
    ax.tick_params(axis='x', length=0)
    ax.set_ylim([-0.38, 0.38])
    modify_axis_spines(ax, which=['y'], yticks=[-0.3, 0, 0.3])

    # Narrow every bar to bar_width while keeping it centred on its slot.
    for patch in ax.patches:
        diff = patch.get_width() - bar_width
        patch.set_width(bar_width)
        patch.set_x(patch.get_x() + diff * 0.5)

    offset_by_sign = {1: offsets[0], -1: offsets[1]}
    va_by_sign = {1: vas[0], -1: vas[1]}
    ax_fontsize = plt.rcParams['font.size']
    significance_labels = spk_performance['Significance']

    for i, bar in enumerate(ax.patches):
        bar.set_edgecolor('k')
        bar.set_linewidth(linewidth)
        bar_height = bar.get_height()
        if not np.isfinite(bar_height):
            # A bar without a finite height cannot be annotated.
            continue
        # Zero-height bars count as positive (np.sign(0) == 0 is not a key).
        bar_sign = 1 if bar_height >= 0 else -1
        offset = offset_by_sign[bar_sign] * bar_sign

        # Positional lookup: the i-th patch is labelled with the i-th row.
        ax.annotate(significance_labels.iloc[i], xy=(bar.get_x() + bar.get_width()/2, bar_height+offset),
                    ha='center', va=va_by_sign[bar_sign], fontsize=ax_fontsize)

    if xticklabels is not None and len(xticklabels) > 0:
        ax.set_xticklabels(xticklabels)

    handles, _ = ax.get_legend_handles_labels()

    ax.legend(handles=handles, labels=lgd_labels, frameon=False, loc=lgd_loc, ncol=ncol, markerscale=markerscale, columnspacing=legend_spacing,
              handletextpad=handletextpad, handlelength=handlelength)
    ax.patch.set_alpha(0.0)


In [None]:
def plot_spk_avg_rates(fig, ax, firing_rates, shade_colors=('cornflowerblue', 'lightcoral'), alpha=0.5, lgd_loc=(0.5,0.9), ncol=1,
                      markerscale=1, legend_spacing=1.5, handletextpad=0.5, handlelength=1):
    """Plot the neuron-averaged firing rate (mean ± SEM band) for Presence and Absence.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    ax : matplotlib.axes.Axes
    firing_rates : xarray.DataArray
        Needs a ``time`` coordinate, a ``neuron`` dimension and a
        ``condition`` coordinate with 'Presence' and 'Absence'.
    shade_colors : (presence_shade, absence_shade)
        Fill colours of the SEM bands.
    alpha : float
        Opacity of the SEM bands.
    Remaining parameters style the legend. Line colours come from the
    module-level ``presence_color`` / ``absence_color``.
    """
    time_vec = firing_rates.time.to_numpy()

    conditions = (
        ('Presence', 'Pr', presence_color, shade_colors[0]),
        ('Absence', 'Ab', absence_color, shade_colors[1]),
    )
    for condition, label, line_color, shade_color in conditions:
        rates = firing_rates.sel(condition=condition)
        rate_avg = rates.mean('neuron')
        # SEM across neurons via the named dimension, so the result does not
        # depend on the DataArray's axis order.
        rate_err = rates.reduce(stats.sem, dim='neuron')

        # Lines and bands share the same x values (the time coordinate).
        ax.plot(time_vec, rate_avg, label=label, c=line_color)
        ax.fill_between(time_vec, rate_avg - rate_err, rate_avg + rate_err, color=shade_color, alpha=alpha, linewidth=0)

    ax.set_ylabel('Population firing rate')
    ax.set_xlabel('Time (ms)')

    modify_axis_spines(ax, which = ['x', 'y'], xticks = np.arange(0, 101, 50), yticks=[0.35, 0.65])
    ax.set_xticklabels(['-500', 'FB', '500'])

    ax.legend(frameon=False, loc=lgd_loc, ncol=ncol, markerscale=markerscale, columnspacing=legend_spacing,
              handletextpad=handletextpad, handlelength=handlelength)
    ax.patch.set_alpha(0.0)

# NOTE(review): this statement runs at cell level, not inside plot_spk_avg_rates,
# so it saves whatever the global `fig` is when this cell executes — confirm intended.
fig.savefig(fig_save_loc + save_string + 'AllNeurons_averageFiringRate_STE_nFB.svg', transparent = True, bbox_inches = 'tight')

In [None]:
def plot_spk_maxs(fig, ax, rate_neurons, linewidth=1, point_size=5, labelpad=10, tickpad=10, saturation=0.8, alpha=1, width_viol=0.4, width_box=0.4, bw=0.4,
                  pointplot=False, orient='v', linecolor='darkslategray', move=0, offset=0, legend=False, hide_ticks=False):
    """Raincloud plot of peak firing rates per condition with a Mann-Whitney annotation.

    A uniform random jitter in [0, 1) (global ``np.random`` state) is added
    to the 'value' column of a copy of *rate_neurons* to separate
    overlapping points. The input is not modified.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    ax : matplotlib.axes.Axes
    rate_neurons : pandas.DataFrame
        Long table with 'condition' ('Presence' / 'Absence') and 'value' columns.
    orient : {'v', 'h'}
        'v' puts the rate on the y axis, 'h' on the x axis.
    move, offset : unused.
    legend : bool
        Add a Presence / Absence legend.
    hide_ticks : bool
        Remove the ticks on the category axis.
    Remaining parameters style the raincloud (``pt.RainCloud``) and the
    significance annotation.
    """
    jittered = rate_neurons.copy()
    jittered['value'] += np.random.random(jittered.value.shape)

    # Shared by the raincloud and the annotator so both see identical positions.
    shared = {
        'x': 'condition', 'y': 'value', 'data': jittered, 'dodge': True,
        'palette': {'Presence': presence_color, 'Absence': absence_color},
        'linecolor': linecolor, 'point_size': point_size, 'linewidth': linewidth,
        'box_linewidth': linewidth, 'saturation': saturation, 'alpha': alpha,
        'width_viol': width_viol, 'width_box': width_box, 'bw': bw,
    }

    cloud_ax = pt.RainCloud(ax=ax, orient=orient, **shared, cut=0, pointplot=pointplot,
                            box_fliersize=0, box_whiskerprops=dict(linewidth=linewidth))

    if legend:
        legend_handles = cloud_ax.get_legend_handles_labels()[0]
        ax.legend(legend_handles, labels=['Presence', 'Absence'], frameon=False)

    (Annotator(ax=ax, pairs=[('Presence', 'Absence')], **shared, plot='boxplot', verbose=False)
        .configure(test='Mann-Whitney', text_format='star', loc='outside', line_width=linewidth)
        .apply_test()
        .annotate())

    # The category axis loses its label and tick marks; the value axis gets
    # the rate label and fixed ticks.
    cat_axis, val_axis = ('y', 'x') if orient == 'h' else ('x', 'y')
    getattr(ax, f'set_{cat_axis}label')('')
    getattr(ax, f'set_{val_axis}label')('Peak Firing Rate (Hz)', labelpad=labelpad)
    ax.tick_params(axis=cat_axis, length=0, pad=tickpad)
    ax.tick_params(axis=val_axis)
    modify_axis_spines(ax, which=[val_axis], **{f'{val_axis}ticks': [0, 100, 200]})
    if hide_ticks:
        getattr(ax, f'set_{cat_axis}ticks')([])
    ax.patch.set_alpha(0.0)


In [None]:
def plot_lds_dynamics(fig, ax, state_x, state_y, lds_A, lds_b, lds_min, lds_max, traj_color='dodgerblue', quiver_color='k', traj_lw=2, traj_alpha=1,
                     dynamics_kwargs=None):
    """Draw the flow field of a 2-D linear dynamical system and overlay one trajectory.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    ax : matplotlib.axes.Axes
    state_x, state_y : array-like
        Trajectory coordinates to draw as a line.
    lds_A, lds_b, lds_min, lds_max
        Passed to ``plot_dynamics_2d`` as the dynamics parameters and the
        ``mins`` / ``maxs`` of the plotted region.
    traj_color, quiver_color, traj_lw, traj_alpha
        Trajectory / flow-field styling.
    dynamics_kwargs : dict, optional
        Extra keyword arguments for ``plot_dynamics_2d``.
    """
    plot_dynamics_2d(lds_A, lds_b, mins=lds_min, maxs=lds_max, color=quiver_color, axis=ax,
                     **(dynamics_kwargs or {}))

    ax.plot(state_x, state_y, color=traj_color, lw=traj_lw, alpha=traj_alpha)

In [None]:
def simulate_lds_from_params(A, b, time_bins):
    """Iterate the noiseless linear map ``x[t] = A @ x[t-1]`` starting from ``x[0] = b``.

    Parameters
    ----------
    A : array-like, shape (d, d)
        Dynamics matrix.
    b : array-like, shape (d,)
        Used as the initial state.
        NOTE(review): in an LDS ``x[t] = A x[t-1] + b`` the vector *b* is a
        bias that is added every step; here it only seeds the trajectory —
        confirm this is intended.
    time_bins : int
        Number of update steps.

    Returns
    -------
    numpy.ndarray, shape (time_bins + 1, d)
        Trajectory including the initial state. ``d`` comes from ``A`` (or
        from ``b`` when ``A`` is not a matrix) and falls back to 2.
    """
    A_arr = np.asarray(A)
    b_arr = np.asarray(b)
    if A_arr.ndim == 2:
        state_dim = A_arr.shape[0]
    elif b_arr.ndim == 1:
        state_dim = b_arr.shape[0]
    else:
        state_dim = 2  # scalar A and b: keep the original 2-D behaviour

    x = np.zeros((time_bins + 1, state_dim))
    x[0, :] = b

    for t in range(1, time_bins + 1):
        x[t, :] = np.dot(A, x[t - 1, :])

    return x

In [None]:
def plot_spk_lds_dyn(fig, ax, quiver_colors=['dodgerblue', 'crimson'], traj_colors=['dodgerblue', 'crimson'], traj_lw=2, traj_alpha=1, quiver_scale=4, quiver_alpha=0.85,
                     quiver_width=0.005, title=True, custom_extrema=True, traj_source='posterior', state_mins=[-2, -2], state_maxs=[2, 2], nbins=1000,
                     quiver_num=7, titlepad=0):
    """Overlay the fitted latent LDS flow fields and trajectories for Presence and Absence.

    Reads the per-condition fit results from the module-level ``spk_lds_dict``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Passed through to ``plot_lds_dynamics``.
    ax : matplotlib.axes.Axes
    quiver_colors, traj_colors : sequence of 2 colors
        Per-condition colours (Presence first).
    traj_lw, traj_alpha
        Trajectory line styling.
    quiver_scale, quiver_alpha, quiver_width, quiver_num
        Forwarded to the flow-field plot as ``scale``, ``alpha``, ``width``
        and ``npts``.
    title : bool
        Add the 'Latent dynamics' title.
    custom_extrema : bool
        If True use *state_mins* / *state_maxs*. Otherwise derive them from
        each condition's state means; the last condition's values then set
        the ticks.
    traj_source : {'fit', 'sample'}
        Trajectory to draw: state means or sampled states.
    state_mins, state_maxs : sequence of 2 floats
        Plot extent. The maxima are rounded to one decimal for the ticks.
    nbins : int
        Unused.
    titlepad : float
        Title padding.
    """
    task_conditions = ['Presence', 'Absence']

    dynamics_kwargs = {'scale':quiver_scale, 'width':quiver_width, 'alpha': quiver_alpha, 'npts': quiver_num}

    for c_ind, condition in enumerate(task_conditions):

        # NOTE(review): relies on the insertion order of the stored dict's values — confirm.
        params_A, params_b, elbos, sampled_states, state_means, lds = spk_lds_dict[condition].values()

        if not custom_extrema:
            state_mins = state_means.min(axis=0)
            state_maxs = state_means.max(axis=0)

        if traj_source == 'fit':
            trajectory = state_means
        elif traj_source == 'sample':
            trajectory = sampled_states
        else:
            # NOTE(review): any other value, including the default 'posterior',
            # draws nothing for this condition — confirm intended.
            continue

        plot_lds_dynamics(fig, ax, trajectory[:, 0], trajectory[:, 1], params_A, params_b, state_mins, state_maxs,
                          traj_lw=traj_lw, traj_alpha=traj_alpha, traj_color=traj_colors[c_ind], quiver_color=quiver_colors[c_ind], dynamics_kwargs=dynamics_kwargs)

    ax.set_ylabel("Dimension 2")
    ax.set_xlabel("Dimension 1")
    ax.patch.set_alpha(0)
    if title:
        ax.set_title('Latent dynamics', pad=titlepad, fontsize=plt.rcParams['font.size'])

    # np.round works for both numpy scalars and plain Python numbers; the
    # `.round` method does not exist on Python ints (e.g. the default [2, 2]).
    modify_axis_spines(ax, which=['x', 'y'], xticks=[state_mins[0], np.round(state_maxs[0], 1)],
                       yticks=[state_mins[1], np.round(state_maxs[1], 1)])
    

## Synthetic

In [None]:
def remove_extreme_outliers(df, columns, low=0.001, high=0.999):
    """Drop rows whose value in any of *columns* is not strictly inside the (low, high) quantiles.

    Columns are filtered in order, so a later column's quantiles are computed
    on the rows that survived the earlier filters. Values equal to a quantile
    bound are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
    columns : iterable of str
        Column names to filter on.
    low, high : float
        Lower and upper quantile levels in [0, 1].

    Returns
    -------
    pandas.DataFrame
        A new, filtered DataFrame; *df* itself is not modified.
    """
    for name in columns:
        lower, upper = df[name].quantile([low, high])
        df = df[(df[name] > lower) & (df[name] < upper)]
    return df

In [None]:
def plot_sim_spk_rates(fig, ax, firing_rates, labelpad=5, lw=1, alpha=0.5, color='k', step_size=100, time_start=0, time_end=1000):
    """Plot every ``step_size``-th simulated firing-rate trace and save the figure as SVG.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is saved at the end.
    ax : matplotlib.axes.Axes
    firing_rates : numpy.ndarray, shape (n_time, n_traces)
        Each column is one trace. Rows are spread evenly over
        [time_start, time_end].
    labelpad, lw, alpha, color
        Axis-label padding and line styling.
    step_size : int, default 100
        Plot every step_size-th column.
    time_start, time_end : float, default 0 and 1000
        Time span (ms) of the x axis and its ticks.
    """
    tv = np.linspace(time_start, time_end, firing_rates.shape[0])
    ax.plot(tv, firing_rates[:, ::step_size], color=color, alpha=alpha, lw=lw)

    modify_axis_spines(ax, which=['x', 'y'], yticks=[0, firing_rates.max()], xticks=[time_start, time_end])

    ax.set_xlabel('Time (ms)', labelpad=labelpad)
    ax.set_ylabel('Normalized firing rate', labelpad=labelpad)
    ax.patch.set_alpha(0.0)

    # NOTE(review): `saveString` is not defined anywhere in the visible code (the
    # module-level name is `save_string`); this raises NameError unless it is
    # defined elsewhere.
    fig.savefig(fig_save_loc + save_string + 'SBI_FR_Samples' + saveString  + '.svg', transparent = True, bbox_inches = 'tight')

In [None]:
def plot_spk_rates_fit(fig, ax, emp_lw=5, fit_lw=12, emp_colors=['blue', 'red'], fit_colors=['cornflowerblue', 'salmon'], lgd_loc=(0.3, 1.05),
                       emp_ls=':', fit_ls='-', emp_alpha=1, fit_alpha=0.5, labelpad=5, ncol=2, markerscale=1, legend_spacing=0.5, handletextpad=0.5, handlelength=0.5):
    """Overlay empirical population rates and the SBI posterior-check rates.

    Uses the module-level ``emp_avgs`` and ``posterior_checks``. Column 0 is
    plotted as Presence ('Pr') and column 1 as Absence ('Ab'). Both series
    are placed on an evenly spaced time axis from -500 to 500 ms.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Unused; kept for a uniform plotting signature.
    ax : matplotlib.axes.Axes
    emp_* / fit_*
        Styling of the empirical (dotted by default) and fitted traces.
    Remaining parameters style the axis labels and the legend.
    """
    lead_s, lag_s = 0.5, 0.5
    emp_time = np.linspace(-lead_s*1000, lag_s*1000, emp_avgs.shape[0])
    fit_time = np.linspace(-lead_s*1000, lag_s*1000, posterior_checks.shape[0])

    for col, label in enumerate(('Emp-Pr', 'Emp-Ab')):
        ax.plot(emp_time, emp_avgs[:, col], label=label, c=emp_colors[col], lw=emp_lw, ls=emp_ls, alpha=emp_alpha)

    for col, label in enumerate(('SBI-Pr', 'SBI-Ab')):
        fitted = posterior_checks[:, col].to_numpy().ravel()
        ax.plot(fit_time, fitted, label=label, lw=fit_lw, c=fit_colors[col], ls=fit_ls, alpha=fit_alpha)

    modify_axis_spines(ax, which=['x', 'y'], yticks=[0.35, 0.65], xticks=[-500, 0, 500])

    ax.set_xticklabels(['-500', 'FB', '500'])
    ax.set_xlabel('Time (ms)', labelpad=labelpad)
    ax.set_ylabel('Population firing rate', labelpad=labelpad)
    ax.legend(frameon=False, loc=lgd_loc, ncol=ncol, markerscale=markerscale, columnspacing=legend_spacing,
              handletextpad=handletextpad, handlelength=handlelength)
    ax.patch.set_alpha(0.0)


In [None]:
def plot_spk_post_boxes(fig, ax, df=None, lw=1, cap_lw=0, saturation=0.8, labelpad=10, scale='width', data_lim=[5, 6.2], stat_loc='outside', hide_ticks=False,
                        ylabel='Posterior ''$g$'+' (pooled)'):
    """Box plot of pooled posterior samples per condition with a Mann-Whitney star annotation.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes
        Target axes.
    df : DataFrame
        Long-form data with 'condition' ('Presence'/'Absence') and 'value' columns.
    lw, cap_lw : float
        Box/annotation line width and whisker-cap line width.
    saturation : float
        Forwarded to ``sns.boxplot``.
    labelpad : float
        Padding of the y label.
    scale : str
        Unused.
    data_lim : sequence of two floats
        Y ticks and y limits.
    stat_loc : str
        ``loc`` passed to the statannotations ``Annotator``.
    hide_ticks : bool
        Remove the categorical x ticks and tick labels.
    ylabel : str
        Y-axis label.
    """
    condition_colors = {'Presence': 'dodgerblue', 'Absence': 'crimson'}

    sns.boxplot(
        data=df, x='condition', y='value', ax=ax,
        palette=condition_colors, saturation=saturation, dodge=True, linewidth=lw,
        fliersize=0, showfliers=False, capprops={"linewidth": cap_lw},
    )

    stat_annotator = Annotator(x='condition', y='value', ax=ax, pairs=[('Absence', 'Presence')],
                               data=df, verbose=False)
    stat_annotator.configure(test='Mann-Whitney', text_format='star', loc=stat_loc, line_width=lw)
    stat_annotator.apply_and_annotate()

    modify_axis_spines(ax, which=['y'], yticks=data_lim)
    ax.set_ylim(data_lim)

    ax.tick_params(axis='both', which='major')
    ax.tick_params(axis='x', length=0)  # categories are identified by colour, no tick marks
    ax.set_xlabel('')
    ax.set_ylabel(ylabel, labelpad=labelpad)

    if hide_ticks:
        ax.set_xticks([])
        ax.set_xticklabels([])
    ax.patch.set_alpha(0.0)


In [None]:
def plot_spk_post_boxes_avg(fig, ax, df=None, lw=1, cap_lw=0, labelpad=10, tickpad=5, scale='count', data_lim=[5, 6.2], stat_loc='inside', hide_ticks=False,
                            linecolor='k', point_size=5, saturation=0.8, alpha=1, width_viol=0.5, width_box=0.3, bw=0.3, orient='v', legend=False,
                            ylabel='Posterior ''$g$'+' (average)'):
    """Raincloud plot of averaged posterior ``g`` per condition with a Mann-Whitney annotation.

    Extreme outliers in ``df['value']`` are first removed with
    ``remove_extreme_outliers`` (defined elsewhere in the notebook).

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes
        Target axes.
    df : DataFrame
        Long-form data with 'condition' ('Presence'/'Absence') and 'value' columns.
    lw, cap_lw : float
        Box/whisker/annotation line width and cap line width.
    labelpad, tickpad : float
        Axis-label and tick-label padding.
    scale : str
        Unused.
    data_lim : sequence of two floats
        Ticks and limits of the value axis.
    stat_loc : str
        ``loc`` for the statannotations ``Annotator``.
    hide_ticks : bool
        Remove the categorical-axis ticks.
    linecolor, point_size, saturation, alpha, width_viol, width_box, bw, orient :
        Forwarded to ``pt.RainCloud`` (and to ``Annotator``).
    legend : bool
        Add a Presence/Absence legend.
    ylabel : str
        Value-axis label when ``orient`` is not 'h'.
    """
    cleaned = remove_extreme_outliers(df.copy(), ['value'])

    # NOTE(review): x/y are always ('condition', 'value') regardless of `orient`
    # (the previous orientation-dependent choice was immediately overwritten).
    # Presumably pt.RainCloud handles the swap for orient='h' -- confirm.
    shared_args = {
        'x': 'condition', 'y': 'value', 'data': cleaned, 'dodge': True,
        'palette': {'Presence': 'dodgerblue', 'Absence': 'crimson'},
        'linecolor': linecolor, 'point_size': point_size, 'box_linewidth': lw,
        'saturation': saturation, 'alpha': alpha, 'width_viol': width_viol,
        'width_box': width_box, 'bw': bw, 'orient': orient,
    }

    rain_ax = pt.RainCloud(ax=ax, **shared_args, cut=0, pointplot=False, box_fliersize=0, linewidth=lw,
                           box_capprops={"linewidth": cap_lw}, box_whiskerprops={'linewidth': lw},
                           box_showfliers=False)

    stat_annotator = Annotator(ax=ax, pairs=[('Presence', 'Absence')], **shared_args, plot='boxplot', verbose=False)
    stat_annotator.configure(test='Mann-Whitney', text_format='star', loc=stat_loc, line_width=lw).apply_test().annotate()

    if legend:
        legend_handles, _ = rain_ax.get_legend_handles_labels()
        ax.legend(legend_handles, labels=['Presence', 'Absence'], frameon=False)  # Adjust labels as needed

    ax.patch.set_alpha(0.0)

    if orient != 'h':
        ax.set_ylim(data_lim)
        ax.set_xlabel('')
        ax.set_ylabel(ylabel, labelpad=labelpad)
        ax.tick_params(axis='x', length=0, pad=tickpad)
        ax.tick_params(axis='y')
        modify_axis_spines(ax, which=['y'], yticks=data_lim)
        if hide_ticks:
            ax.set_xticks([])
        return

    ax.set_ylabel('')
    ax.set_xlabel(r'$g$', labelpad=labelpad)
    ax.tick_params(axis='y', length=0, pad=tickpad)
    ax.tick_params(axis='x',)
    modify_axis_spines(ax, which=['x'], xticks=data_lim)
    ax.set_xlim(data_lim)
    if hide_ticks:
        ax.set_yticks([])


In [None]:
def plot_brunel_features(fig, ax, markersize=0.5, labelpad=20, alpha=0.5, color='slategray', weight='normal', wspace=-0.1, hspace=-0.1, left=-0.1):
    """Scatter simulated peak firing rate against ``g`` and save the figure as SVG.

    Uses the notebook globals ``theta_spk`` (x, labelled ``g``) and
    ``fr_sim_maxs`` (y). The whole ``fig`` is written to
    ``fig_save_loc + save_string + 'Int-Seg_Sim_shortened_EqualScale.svg'``.

    ``labelpad``, ``weight``, ``wspace``, ``hspace`` and ``left`` are accepted
    but unused.
    """
    ax.scatter(theta_spk, fr_sim_maxs, s=markersize, color=color, alpha=alpha)

    g_ticks, rate_ticks = [5, 8], [0, 1]
    modify_axis_spines(ax, which=['x', 'y'], yticks=rate_ticks, xticks=g_ticks)

    ax.set_xlabel(r'$g$', y=0.01)
    ax.set_ylabel('Max firing rate', x=-0.005)
    ax.patch.set_alpha(0.0)

    out_path = fig_save_loc + save_string + 'Int-Seg_Sim_shortened_EqualScale.svg'
    fig.savefig(out_path, transparent=True, bbox_inches='tight')

In [None]:
def plot_spk_fr_hist(fig, ax, nbins=50, color='darkslategray', histtype='bar', bottom=4, rwidth=1, stacked=False, orientation='vertical'):
    """Histogram of simulated peak firing rates (global ``fr_sim_maxs``).

    With ``orientation='horizontal'`` the count axis becomes x and the rate
    axis becomes y; ticks and labels are swapped accordingly. The count-axis
    upper tick is the tallest bin height rounded to the nearest hundred.
    """
    counts = ax.hist(fr_sim_maxs, bins=nbins, color=color, histtype=histtype, bottom=bottom,
                     rwidth=rwidth, orientation=orientation, stacked=stacked)[0]

    rate_ticks = [0, 1]
    count_ticks = [0, counts.max().round(-2)]

    if orientation == 'horizontal':
        xticks, yticks = count_ticks, rate_ticks
        xlabel, ylabel = 'Count', 'Max firing rate'
    else:
        xticks, yticks = rate_ticks, count_ticks
        xlabel, ylabel = 'Max firing rate', 'Count'

    modify_axis_spines(ax, which=['x', 'y'], yticks=yticks, xticks=xticks)

    ax.set_xlabel(xlabel, y=0.01)
    ax.set_ylabel(ylabel, x=-0.005)
    ax.patch.set_alpha(0)

In [None]:
def plot_brunel_features_all(fig, axes, markersize=0.5, labelpad=20, alpha=0.5, color='slategray', weight='normal', wspace=-0.1, hspace=-0.1, left=-0.1,
                             nbins=50, histtype='bar', bottom=4, rwidth=1, stacked=False, orientation='horizontal', hist_ylabel=0, xlabel_y=0, ylabel_x=0,
                             hide_hist_y=True):
    """Scatter of peak firing rate vs. ``g`` plus a marginal histogram of peak rates.

    Uses the notebook globals ``theta_spk`` and ``fr_sim_maxs``.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    axes : sequence of two Axes
        ``axes[0]`` receives the scatter, ``axes[1]`` the histogram.
    markersize, alpha, color :
        Scatter styling (``color`` is also the histogram colour).
    nbins, histtype, bottom, rwidth, stacked, orientation :
        Forwarded to ``ax.hist``. With ``orientation='horizontal'`` the count
        axis is x and its label is 'Count'.
    hist_ylabel : float
        ``y`` position passed to the histogram's x label, mirroring
        ``xlabel_y``/``ylabel_x`` for the scatter. Previously this parameter
        was shadowed by a local label string, and that string was passed as
        the coordinate.
    xlabel_y, ylabel_x : float
        ``y``/``x`` positions passed to the scatter axis labels.
    hide_hist_y : bool
        Hide the histogram's left spine and y ticks (previously always on).
    labelpad, weight, wspace, hspace, left :
        Accepted but unused.
    """
    ax_sctr, ax_hist = axes[0], axes[1]

    # Scatter: g on x, peak simulated firing rate on y.
    ax_sctr.scatter(theta_spk, fr_sim_maxs, s=markersize, color=color, alpha=alpha)
    modify_axis_spines(ax_sctr, which=['x', 'y'], yticks=[0, 1], xticks=[5, 8])
    ax_sctr.set_xlabel(r'$g$', y=xlabel_y)
    ax_sctr.set_ylabel('Max firing rate', x=ylabel_x)

    counts = ax_hist.hist(fr_sim_maxs, bins=nbins, color=color, histtype=histtype, bottom=bottom,
                          rwidth=rwidth, orientation=orientation, stacked=stacked)[0]

    rate_ticks = [0, 1]
    count_ticks = [0, counts.max().round(-2)]
    if orientation == 'horizontal':
        # Bars run along x, so the count axis (and its label) is x.
        hist_xticks, hist_yticks = count_ticks, rate_ticks
        hist_xlabel_text = 'Count'
    else:
        hist_xticks, hist_yticks = rate_ticks, count_ticks
        hist_xlabel_text = ''

    modify_axis_spines(ax_hist, which=['x', 'y'], yticks=hist_yticks, xticks=hist_xticks)

    if hide_hist_y:
        ax_hist.spines['left'].set_visible(False)
        ax_hist.set_yticks([])
    ax_hist.set_xlabel(hist_xlabel_text, y=hist_ylabel)

    for ax in axes:
        ax.patch.set_alpha(0.0)

In [None]:
def plot_brunel_features_horiz(fig, axes, markersize=0.5, labelpad=20, alpha=0.5, color='slategray', weight='normal', wspace=-0.1, hspace=-0.1, left=-0.1,
                             nbins=50, histtype='bar', bottom=4, rwidth=1, stacked=False, xlabelpad=0, ylabelpad=0, titlepad=10):
    """Scatter of ``g`` vs. peak firing rate with a histogram of peak rates.

    Uses the notebook globals ``fr_sim_maxs`` (x of the scatter, data of the
    histogram) and ``theta_spk`` (y of the scatter, labelled ``g``).

    Parameters
    ----------
    axes : sequence of two Axes
        ``axes[0]`` receives the histogram (titled 'Simulation features'),
        ``axes[1]`` the scatter.
    xlabelpad, ylabelpad : float
        Label padding of the scatter axes.
    titlepad : float
        Padding of the histogram title.
    nbins, histtype, bottom, rwidth, stacked :
        Forwarded to ``ax.hist``.
    labelpad, weight, wspace, hspace, left :
        Accepted but unused.
    """
    hist_ax, scatter_ax = axes[0], axes[1]

    scatter_ax.scatter(fr_sim_maxs, theta_spk, s=markersize, color=color, alpha=alpha)
    modify_axis_spines(scatter_ax, which=['x', 'y'], yticks=[5, 8], xticks=[0, 1])
    scatter_ax.set_ylabel(r'$g$', labelpad=ylabelpad)
    scatter_ax.set_xlabel('Max firing rate', labelpad=xlabelpad)

    counts = hist_ax.hist(fr_sim_maxs, bins=nbins, color=color, histtype=histtype, bottom=bottom,
                          rwidth=rwidth, stacked=stacked)[0]
    modify_axis_spines(hist_ax, which=['x', 'y'], yticks=[0, counts.max().round(-2)], xticks=[0, 1])

    # The rate axis is already labelled on the scatter, so hide it here.
    hist_ax.spines['bottom'].set_visible(False)
    hist_ax.set_xticks([])
    hist_ax.set_ylabel('Count')
    hist_ax.set_title('Simulation features', pad=titlepad, fontsize=plt.rcParams['font.size'])

    for panel in axes:
        panel.patch.set_alpha(0.0)

In [None]:
def _select_raster_trials(raster, time_range, thresh_lo, thresh_hi, onset_len=25, onset_max=5):
    """Return indices of trials (columns of ``raster``) that pass the spike-count filters.

    A trial is kept when it has at most ``onset_max`` spikes in its first
    ``onset_len`` rows and more than ``thresh_lo`` but at most ``thresh_hi``
    spikes in ``raster[time_range[0]:time_range[1]]``.
    """
    onset_quiet = raster[:onset_len, :].sum(axis=0) <= onset_max
    window_counts = raster[time_range[0]:time_range[1], :].sum(axis=0)
    in_band = (window_counts > thresh_lo) & (window_counts <= thresh_hi)
    return np.where(in_band & onset_quiet)[0]


def plot_sim_fit_rasters(fig, ax, plotstep=10, sample_len=100, trial_spike_thresh=25, marker='o', markersize=4, markerscale=2, alpha=1, chosen_trials=[],
                         colors=['dodgerblue', 'crimson'], labels=['Presence', 'Absence'], time_range=[0,100], thresh_lo=5, thresh_hi=25):
    """Plot synthetic (SBI posterior-check) vs. empirical spike rasters per condition and save as SVG.

    Uses the notebook globals ``spike_array`` (empirical spikes with
    condition/trial/neuron coordinates) and ``posterior_check_spikes``
    (synthetic spikes; ``source=0`` holds spike times, ``source=1`` neuron
    indices).

    Parameters
    ----------
    fig : Figure
        Figure owning the grid; receives the shared labels and legend, and is saved.
    ax : 2x2 array of Axes
        Row 0 = synthetic spikes, row 1 = empirical trial rasters;
        column 0 = Presence, column 1 = Absence. (Previously the body ignored
        this argument and used a global named ``axes``.)
    sample_len : int
        Number of empirical trials drawn at random per condition. It is capped
        at the number of available trials instead of raising ``ValueError``.
    chosen_trials : sequence of int
        Trial (column) indices to use for both conditions. When empty, trials
        are selected independently for each condition with the spike-count
        filters below. Previously the Presence selection leaked into Absence.
    time_range, thresh_lo, thresh_hi :
        A trial is kept when its spike count in ``time_range`` lies in
        ``(thresh_lo, thresh_hi]``. It must also have at most 5 spikes in its
        first 25 bins.
    marker, markersize, markerscale, alpha, colors, labels :
        Scatter styling; ``labels`` feed the figure legend.
    plotstep, trial_spike_thresh :
        Accepted but unused.

    Returns
    -------
    list of ndarray
        The automatically selected trial indices for each condition (empty
        list when ``chosen_trials`` was supplied).
    """
    axes = np.asarray(ax)

    # Empirical spikes per condition, trials x neurons flattened into columns,
    # binarised so each time bin holds at most one spike.
    sel_neurons = spike_array.neuron.to_numpy()
    emp_rasters = []
    for cond in ('Presence', 'Absence'):
        spikes = (spike_array.sel(condition=cond, neuron=sel_neurons)
                  .stack(aggTrials=('trial', 'neuron'))
                  .dropna('aggTrials')
                  .to_numpy())
        spikes[spikes > 1] = 1
        emp_rasters.append(spikes)

    # len() works for lists and numpy arrays alike (`array == []` does not).
    use_given_trials = len(chosen_trials) > 0
    selected_trials = []
    point_size = markersize * markerscale

    for c_ind, emp_raster in enumerate(emp_rasters):
        # Synthetic spikes; the /10 rescaling of times is kept from the original
        # (presumably to match the empirical bin width -- confirm).
        spike_times = posterior_check_spikes.isel(condition=c_ind, source=0).to_numpy() / 10
        neuron_inds = posterior_check_spikes.isel(condition=c_ind, source=1).to_numpy()
        valid = ~np.isnan(spike_times)  # NaN entries are padding, not spikes
        axes[0, c_ind].scatter(spike_times[valid], neuron_inds[valid], marker=marker, s=point_size, alpha=alpha,
                               color=colors[c_ind], label=labels[c_ind])

        if use_given_trials:
            trials = chosen_trials
        else:
            trials = _select_raster_trials(emp_raster, time_range, thresh_lo, thresh_hi)
            selected_trials.append(trials)

        trial_raster = emp_raster[:, trials]
        n_available = trial_raster.shape[1]
        # Sample from all columns (trial 0 was previously excluded by range(1, ...)).
        sampled = random.sample(range(n_available), min(sample_len, n_available))
        for row, trial in enumerate(sampled):
            tspikes = trial_raster[:, trial]
            sinds = np.nonzero(tspikes)[0]
            # Each sampled trial is drawn on its own raster row `row`.
            axes[1, c_ind].scatter(sinds, tspikes[sinds] * row, marker=marker, color=colors[c_ind], alpha=alpha, s=point_size)

    for panel in axes.ravel():
        modify_axis_spines(panel, which=['x', 'y'], yticks=np.arange(0, 110, 100), xticks=[0, 50, 100], yaxis_left=False)
        panel.yaxis.tick_right()
        panel.set_xticklabels(['-500', 'FB', '500'])
        panel.tick_params(right=True)
        panel.patch.set_alpha(0.0)

    axes[0, 0].spines[:].set_visible(False)
    axes[0, 1].spines[['bottom', 'right']].set_visible(False)
    axes[1, 0].spines.right.set(visible=False)
    axes[1, 1].spines.bottom.set(visible=False)

    axes[0, 0].set_xticks([])
    axes[0, 1].set_xticks([])
    axes[0, 1].set_yticks([])
    axes[0, 0].set_yticks([])
    axes[1, 0].set_yticks([])
    axes[1, 1].set_xticks([])
    axes[0, 0].set_ylabel('Synthetic')
    axes[1, 0].set_ylabel('Empirical')

    fig.supxlabel('Time', y=0.02)
    fig.supylabel('Trial', x=0.01)
    fig.legend(frameon=False, ncols=2, bbox_to_anchor=(0.7, 1.05), markerscale=markerscale)

    fig.tight_layout(h_pad=0, w_pad=0)
    fig.subplots_adjust(hspace=-0.01, wspace=-0.01)

    fig.savefig(fig_save_loc + save_string + 'SBI_SPK_RasterFits_' + str(sample_len) + '_' + saveString + '.svg', bbox_inches='tight', transparent=True)
    return selected_trials

In [None]:
def plot_brunel_sample_raster(fig, axes, ex_senders, ex_times, hist_color='slategray', nbins=50, markersize=10, marker='.', raster_color='dimgray',
                              labelpad_1=7, labelpad_2=0, titlepad=10):
    """Plot one example simulation: a spike-time histogram above a spike raster.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    axes : sequence of two Axes
        ``axes[0]`` gets the histogram of ``ex_times`` and ``axes[1]`` gets
        the raster (``ex_times`` on x, ``ex_senders`` on y).
    ex_senders, ex_times : array-like
        Sender (neuron) id and time of every spike.
    nbins, hist_color : histogram binning and colour.
    markersize, marker, raster_color : raster styling.
    labelpad_1, labelpad_2 : y-label padding for the histogram and raster.
    titlepad : float
        Padding of the histogram title.
    """
    rate_ax, raster_ax = axes[0], axes[1]

    bin_counts = rate_ax.hist(ex_times, bins=nbins, color=hist_color)[0]
    raster_ax.scatter(ex_times, ex_senders, c=raster_color, s=markersize, marker=marker)

    rate_ax.set_title('Sample simulation', pad=titlepad, fontsize=plt.rcParams['font.size'])
    raster_ax.set_xlabel('Time (ms)')
    # NOTE(review): the histogram shows raw spike counts per bin; the 'Rate (Hz)'
    # label is kept as-is -- confirm whether a normalisation is intended.
    rate_ax.set_ylabel('Rate (Hz)', labelpad=labelpad_1)
    raster_ax.set_ylabel('Neuron ID', labelpad=labelpad_2)

    modify_axis_spines(raster_ax, which=['x', 'y'], xticks=[0, 1001], yticks=[0, 100])
    modify_axis_spines(rate_ax, which=['y'], yticks=[0, bin_counts.max().round(-1)])

    rate_ax.set_xticks([])  # shares the time axis with the raster below

    for panel in axes:
        panel.patch.set_alpha(0.0)

## Multi-Panel

In [None]:
# Assemble Figure 1: four stacked sub-figures (rows) in one 9x9-inch figure.
# Every row uses a 3-column GridSpec with the shared margins/spacing below.
height_ratios = np.array([1, 1, 1, 1])
right_space = 0.95
left_space = 0.09
wspace = 0.6
hspace = 0.3

fig = plt.figure(figsize=(9,9), frameon=False)
(fig_r1, fig_r2, fig_r3, fig_r4) = fig.subfigures(4, 1, height_ratios=height_ratios, hspace=hspace, frameon=False)

####################### 1st Row sub-figure #########################################
# Row 1: a 2x2 block of empirical raster/rate axes (grid columns 0-1) and one axes
# spanning both grid rows in column 2 for plot_spk_correlations.
fig_r1_wr = [1, 1, 1]
grid_r1 = gridspec.GridSpec(2, 3, width_ratios=fig_r1_wr, figure=fig_r1, wspace=wspace, left=left_space, right=right_space, top=0.80, bottom=0.05)

ax_emp_rasters = np.zeros((2,2), dtype=object)
for r_ind in range(2):
    for c_ind in range(2):
        ax_emp_rasters[r_ind, c_ind] = fig_r1.add_subplot(grid_r1[r_ind, c_ind])

ax_beh_correl = fig_r1.add_subplot(grid_r1[:, 2])

# Two example neurons, positive-feedback trials, time samples 50..149.
n_names = ['310314E1N3', '150414E1N2']
tslice = np.arange(50,150)
neurons = [raster_array.sel(neuron=n_name, feedback='positive', time=tslice) for n_name in n_names]

plot_raster_and_rate_two(fig_r1, ax_emp_rasters, neurons, ylabelpads=[15, 10], markersize=0.2)

corr_palette = {'Preferred': 'lightsteelblue', 'Non-preferred': 'slategray'}
plot_spk_correlations(fig_r1, ax_beh_correl, ncol=2, lgd_loc=(0.1, 1.0), bar_width=0.35, linewidth=0.1, palette=corr_palette, offsets=[0.035, 0.08],
                     xticklabels=['LeS', 'Acc'], lgd_labels=['Pref', 'Non-Pref'], legend_spacing=0.6, handlelength=0.8)



####################### 2nd Row sub-figure #########################################
# Row 2: empirical max firing rates, average firing rates, and LDS dynamics.
fig_r2_wr = [1, 1, 1]
grid_r2 = gridspec.GridSpec(1, 3, width_ratios=fig_r2_wr, figure=fig_r2, wspace=wspace, left=left_space, right=right_space)

ax_max_frates = fig_r2.add_subplot(grid_r2[:, 0])
ax_avg_frates = fig_r2.add_subplot(grid_r2[:, 1])
ax_spk_lds = fig_r2.add_subplot(grid_r2[:,2])

plot_spk_maxs(fig_r2, ax_max_frates, rate_neurons, orient='v', point_size=2, hide_ticks=True)
plot_spk_avg_rates(fig_r2, ax_avg_frates, firing_rates_emp, ncol=1, lgd_loc=(0.1, 0.6), handlelength=0.5)
plot_spk_lds_dyn(fig_r2, ax_spk_lds, traj_source='sample', state_mins=np.array([-0.25, -0.25]), state_maxs=np.array([0.1, 0.1]),
                 quiver_scale=0.05, quiver_width=0.005, quiver_num=8, titlepad=10, quiver_colors=['royalblue', 'firebrick'])

# ####################### 3rd Row sub-figure #########################################
# Row 3: one example Brunel-network simulation (histogram over raster, column 0),
# simulated firing rates scaled by their global max (column 1), and simulation
# features (column 2).
fig_r3_wr = [1, 1, 1]
grid_r3 = gridspec.GridSpec(2, 3, width_ratios=fig_r3_wr, figure=fig_r3, wspace=wspace, left=left_space-0.01, right=right_space)

ax_sample_hist = fig_r3.add_subplot(grid_r3[0,0])
ax_sample_rast = fig_r3.add_subplot(grid_r3[1,0])
ax_sim_sample = [ax_sample_hist, ax_sample_rast]

ax_sim_frates = fig_r3.add_subplot(grid_r3[:,1])
ax_sim_feats_sctr = fig_r3.add_subplot(grid_r3[0,2])
ax_sim_feats_hist = fig_r3.add_subplot(grid_r3[1,2])

# ex_senders, ex_times = simulate_brunel(g=5.5)
plot_brunel_sample_raster(fig_r3, ax_sim_sample, ex_senders, ex_times, nbins=40, raster_color='slategray', hist_color='slategray', markersize=2)
plot_sim_spk_rates(fig_r3, ax_sim_frates, firing_rates_sim/firing_rates_sim.max(), alpha=0.3, color='slategray')
# NOTE(review): plot_brunel_features_horiz draws the histogram into its first axes
# and the scatter into its second, i.e. the histogram goes into the upper axes
# (ax_sim_feats_sctr, grid [0,2]) and the scatter into the lower one (ax_sim_feats_hist).
plot_brunel_features_horiz(fig_r3, [ax_sim_feats_sctr, ax_sim_feats_hist], alpha=0.7, markersize=0.5, nbins=20, rwidth=0.8, stacked=True, ylabelpad=20)


# ####################### 4th Row sub-figure #########################################
# Row 4: SBI fits of population rates and posterior g per condition (pooled, averaged).
fig_r4_wr = [1, 1, 1]
grid_r4 = gridspec.GridSpec(1, 3, width_ratios=fig_r4_wr, figure=fig_r4, wspace=wspace, left=left_space, right=right_space, bottom=0.23)

ax_sim_fits = fig_r4.add_subplot(grid_r4[0])
ax_post_boxs_all = fig_r4.add_subplot(grid_r4[1])
ax_post_boxs_avg = fig_r4.add_subplot(grid_r4[2])

plot_spk_rates_fit(fig_r4, ax_sim_fits, emp_lw=3, fit_lw=4, lgd_loc=(0.04, 0.65), ncol=1, legend_spacing=4)
plot_spk_post_boxes(fig_r4, ax_post_boxs_all, df_samples_all, data_lim=[5, 6.5], stat_loc='inside', labelpad=5, hide_ticks=True)
plot_spk_post_boxes_avg(fig_r4, ax_post_boxs_avg, df_samples_avg, data_lim=[5.5, 6.5], tickpad=5, stat_loc='inside', hide_ticks=True, point_size=1)

# ####################### Parent Figure Configuration #########################################
# Panel letters: one letter per axes below, taken from all_letters starting at index 2
# (presumably 'C', with earlier letters used elsewhere -- confirm).

annotated_axes = [ax_emp_rasters[0,0], ax_beh_correl, ax_max_frates, ax_avg_frates, ax_spk_lds, ax_sample_hist, ax_sim_frates,
                  ax_sim_feats_sctr ,ax_sim_fits, ax_post_boxs_all, ax_post_boxs_avg]
ax_letters = all_letters[2:len(annotated_axes)+2]
# ax_letters = [ltr+' )' for ltr in ax_letters]

for ax_ind, ax in enumerate(annotated_axes):

    # Default offset (axes-relative) for every letter; 'C' and 'H' sit slightly lower.
    x_offset, y_offset = -0.3, 1.2
    ax_annot = ax_letters[ax_ind]
    if ax_annot in ['C', 'H']:
        y_offset = 1.1
    letter_annotation(ax, x_offset, y_offset, ax_letters[ax_ind], fontsize=14)

fig.show()

In [None]:
# Export the global `fig` (presumably Figure 1 from the multi-panel cell, if cells run
# in order) to the 'Article/' sub-folder as an 800-dpi PNG and as an SVG.
fig.savefig(fig_save_loc + 'Article/' + 'Fig1.png', dpi=800, transparent=True)
fig.savefig(fig_save_loc + 'Article/' + 'Fig1.svg', transparent=True)

# ERP

## Empirical

In [None]:
# Path prefix (relative to fig_save_loc) used when the ERP-section plotting functions save SVGs.
save_string = 'ERP/'

In [None]:
# Load the preprocessed task ERPs, the ERP simulation (prior) dataset and SBI
# condition statistics from parent_preprocess_dir.
task_ERPs = xr.load_dataarray(parent_preprocess_dir + 'task_ERPs_AllRegions.nc')
erp_prior_dataset = xr.load_dataarray(parent_preprocess_dir + 'erp_sim_array.nc')

erp_con_stats = xr.load_dataarray(parent_preprocess_dir + 'erp_sbi_con_stats.nc')

# Session id of each electrode = electrode name without its last two characters.
task_elecs = task_ERPs.electrode.to_numpy()
task_elecs_sessions = np.array([e_name[:-2] for e_name in task_elecs])
task_sessions = np.unique(task_elecs_sessions)

# Per-session events, condition codes and behaviour, restricted to sessions that have task ERPs.
agg_event = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggEvents_All.nc')
agg_cond = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggConds_All.nc')
agg_beh = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggBehaviors_All.nc')
agg_event = agg_event.sel(session=task_sessions)
agg_cond = agg_cond.sel(session=task_sessions)
agg_beh = agg_beh.sel(session=task_sessions)

In [None]:
# Average ERPs over all electrodes of the same session: electrode labels are
# replaced by their session ids, then grouped and averaged along 'electrode'.
task_ERPs_grouped = (
    task_ERPs
    .assign_coords(electrode=task_elecs_sessions)
    .groupby('electrode')
    .mean('electrode')
)

In [None]:
# Peak of each ERP over axis 0 (presumably time, cf. task_ERPs.max('time') below).
sessERP_Peaks = np.max(task_ERPs, axis=0)
# NOTE(review): the divisor reduces axis 1 and xarray broadcasts by dimension name,
# so this presumably divides each electrode's peaks by its max across conditions -- confirm dim order.
sessERP_Peaks = sessERP_Peaks/sessERP_Peaks.max(axis=1)

# Indices where the condition-0 peak is >= the condition-1 peak, and the rest.
# Presumably condition 0 = Presence (cf. 'Presence dominant' in plot_session_ratios) -- confirm.
pDomSessInds = np.where(sessERP_Peaks[:, 0]>=sessERP_Peaks[:, 1])[0]
aDomSessInds = np.where(sessERP_Peaks[:, 0]<sessERP_Peaks[:, 1])[0]

# Per-session mean of the 'feedback' behaviour value over trials of each condition.
# This is an accuracy if the values are 0/1. Masked-out trials become NaN and are
# skipped by sum(). Codes 3 / 1 are assumed to be Presence / Absence (from the p_/a_
# names) -- confirm against the task coding.
mon_fb = agg_beh.sel(event='feedback')
p_mask = agg_cond == 3
a_mask = agg_cond == 1
p_acc = mon_fb.where(p_mask).sum('trial')/p_mask.sum('trial')
a_acc = mon_fb.where(a_mask).sum('trial')/a_mask.sum('trial')

# NOTE(review): despite its name, max_erps_mean is the per-electrode peak over time.
max_erps_mean = task_ERPs.max('time')
max_erps_all = task_ERPs_grouped.max('time')
# Long-form table (electrode, condition, value) of per-electrode peaks, used by plot_cond_maxs.
max_erps_df = task_ERPs.max('time').stack(aggSession=('electrode','condition')).to_pandas().reset_index().rename(columns={0:'value'})

# Presence share (%) of the session-averaged ERP peak and of behavioural performance.
n_ratio = max_erps_all.sel(condition='Presence')/(max_erps_all.sum('condition')) * 100 
b_ratio = p_acc/(p_acc+a_acc)*100

n_ratio = n_ratio.to_numpy()
b_ratio = b_ratio.to_numpy()

erp_beh_cor = stats.pearsonr(n_ratio, b_ratio)
erp_beh_cor_nl = stats.spearmanr(n_ratio, b_ratio)

# Sort (x, y) pairs by the ERP ratio before fitting the Bayesian linear regression.
x_mcmc = np.sort(n_ratio)
x_sorted_inds = np.argsort(n_ratio)
y_mcmc = b_ratio[x_sorted_inds]

# run_mcmc_from_system / linreg_system are defined elsewhere; the 5th/95th posterior
# quantiles of `xdot` form the band drawn in plot_erp_beh_corr.
mcmc_linear = run_mcmc_from_system(linreg_system, x=x_mcmc, y=y_mcmc)
samples_linear = az.from_numpyro(mcmc_linear)
xdot_quantiles_linear = np.quantile(samples_linear.posterior.xdot.squeeze(),[0.05,0.95],axis=0)

In [None]:
# Latent ERP trajectories (relabelled with task_ERPs' electrode names), CEBRA
# embeddings and the posterior-predictive results of the embedding classifier.
erp_trajectories = xr.load_dataarray(parent_preprocess_dir + 'erp_latent_traj.nc')
erp_trajectories.coords['electrode'] = task_ERPs.electrode
pn_embeddings = xr.load_dataarray(parent_preprocess_dir + 'sbi_pn_traj_cebra.nc')
# NOTE(review): unlike the other paths this one has a leading '/'; harmless if
# parent_preprocess_dir ends with '/' on POSIX, but inconsistent.
embedArray = xr.load_dataarray(parent_preprocess_dir + '/ERP_Embeddings.nc')
erp_ppc = az.from_netcdf(parent_preprocess_dir + 'ERP_CEBRA_Embeddings_PPC.nc')

# PN trajectories for time 0..1000, one column per (electrode, condition); label 1 = Presence.
pn_traj_stacked = erp_trajectories.sel(subpopulation='PN', time=np.arange(0,1001)).stack(aggCondition=('electrode', 'condition'))
pn_conditions = (pn_traj_stacked.condition.to_numpy() == 'Presence').astype(int)

# Use the first embedding method; scale features to [0, 1].
methodInd = 0
methodName = embedArray['embedding'].to_numpy()[methodInd]
scaler = preprocessing.MinMaxScaler()

X = embedArray[methodInd, :, :].to_numpy().copy()
X = scaler.fit_transform(X)
Y = embedArray['condition'].to_numpy().copy()

# 50/50 split; no random_state, so it depends on numpy's global RNG state.
X_train, X_test, Y_train, Y_test = skl.model_selection.train_test_split(X, Y, test_size=.5)

# Posterior-predictive mean of 'out' per point, thresholded at 0.5. It is also
# reshaped into a cube whose side is the number of observations; this assumes
# y_ppc_mean has exactly ppc_scale**3 entries -- confirm.
y_ppc = erp_ppc.posterior_predictive["out"]
y_ppc_mean = y_ppc.mean(('chain', 'draw')).to_numpy()
y_pred = y_ppc_mean > 0.5
ppc_scale = erp_ppc.observed_data['out'].shape[0]
y_ppc_grid = y_ppc_mean.reshape(ppc_scale, ppc_scale, ppc_scale)

## Synthetic

In [None]:
# SBI posterior samples of the ERP model parameters, and the per-electrode maxima file.
erp_thetas = xr.load_dataarray(parent_preprocess_dir + 'erp_sbi_thetas.nc')
erp_thetas_maxs = xr.load_dataarray(parent_preprocess_dir + 'erp_sbi_thetas_maxs.nc')

In [None]:
# Median over electrodes of the per-electrode theta maxima.
cond_thetas_max = erp_thetas_maxs.median('electrode')
# Long-form (electrode, condition, value) table of each PN trajectory's peak over time.
max_pn_df = erp_trajectories.sel(subpopulation='PN').max('time').stack(aggSession=('electrode','condition')).to_pandas().reset_index().rename(columns={0:'value'})

# Flatten all theta samples into one 'value' column. The stacked MultiIndex is kept
# in dfInds and the frame is re-indexed 0..N-1.
erp_theta_df = erp_thetas.stack(aggSample=('theta', 'sample',  'electrode', 'condition'))
erp_theta_df = erp_theta_df.to_dataframe(name = 'value')
dfInds = erp_theta_df.index
erp_theta_df.index = np.arange(len(dfInds))

## Multi-Panel

In [None]:
def plot_sc_trajectories(fig, ax):
    """Plot the SC latent component over time for every electrode and condition.

    Reads the global ``erp_trajectories`` and indexes its ``electrode`` and
    ``condition`` dimensions. The remaining array is unpacked along its first
    axis into SC, IN and PN components. Trajectories with PN max >= 20,
    IN max >= 10, IN min <= -5 or SC max >= 0.5 are skipped.

    The previous version iterated a hard-coded ``range(33)`` over a
    ``session`` dimension, which ``erp_trajectories`` does not have (it is
    indexed by ``electrode`` elsewhere). It also referenced an undefined
    ``data``.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes
        Target axes (condition 0 in 'dodgerblue', condition 1 in 'crimson').
    """
    colors = ['dodgerblue', 'crimson']
    n_time = erp_trajectories.sizes['time']

    for e_ind in range(erp_trajectories.sizes['electrode']):
        for c_ind in range(2):
            xsc, xin, xpy = erp_trajectories.isel(electrode=e_ind, condition=c_ind)
            diverged = (xpy.max() >= 20) or (xin.max() >= 10) or (xin.min() <= -5) or (xsc.max() >= 0.5)
            if diverged:
                continue
            ax.plot(xsc, color=colors[c_ind])

    modify_axis_spines(ax, which=['x', 'y'], yticks=[0, 0.4], xticks=[0, n_time])

    # Tick label divides the step count by 10 (unit conversion kept from the original).
    ax.set_xticklabels([0, n_time // 10])
    ax.set_xlabel('Time')
    ax.set_ylabel('SC output')
    ax.patch.set_alpha(0.0)
 

In [None]:
def modify_axis_spines_3d(ax, which=None, base=1.0, xticks=[], yticks=[], zticks=[]):
    """Reduce selected 3-D axes to two end ticks and clip their limits to those ticks.

    Parameters
    ----------
    ax : mpl_toolkits.mplot3d Axes3D
        Target 3-D axes.
    which : iterable of {'x', 'y', 'z'} or None
        Axes to process. ``None`` (the default, which previously raised
        ``TypeError``) selects no axis.
    base : float
        Accepted for backward compatibility; has no effect. The previous
        implementation installed a ``MultipleLocator(base)``, but that locator
        was either replaced straight away by ``set_*ticks`` or never reached.
    xticks, yticks, zticks : sequence of float
        Only the first and last entries are used. When empty, the axis's
        current ticks are used instead.

    For 'x' / 'y' not in *which*, the bottom / left spine is hidden; the z
    axis has no corresponding spine handling.
    """
    selected = () if which is None else which

    def _keep_end_ticks(get_ticks, set_ticks, set_lim, ticks):
        # Fall back to the current ticks, keep only the outer two, then fit the limits to them.
        if len(ticks) == 0:
            ticks = get_ticks()
        set_ticks([ticks[0], ticks[-1]])
        current = get_ticks()
        set_lim(current[0], current[-1])

    if 'x' in selected:
        _keep_end_ticks(ax.get_xticks, ax.set_xticks, ax.set_xlim, xticks)
    else:
        ax.spines.bottom.set(visible=False)

    if 'y' in selected:
        _keep_end_ticks(ax.get_yticks, ax.set_yticks, ax.set_ylim, yticks)
    else:
        ax.spines.left.set(visible=False)

    if 'z' in selected:
        _keep_end_ticks(ax.get_zticks, ax.set_zticks, ax.set_zlim, zticks)

In [None]:
def plot_erp_trajectories(fig, ax, labelpad=-5, bg_color=(0.5, 0.5, 0.5, 0.2), traj_lw=0.5, alpha=0.5, axis_lw=1, azim=168, elev=19.5, titlepad=10,
                         xlim=[5,30], ylim=[-0.05, 0.16], zlim=[-0.5, 3]):
    """Draw SC/IN/PN latent trajectories of selected electrodes on a 3-D axes.

    Reads the global ``erp_trajectories``, indexed positionally as
    ``[subpopulation, time, electrode, condition]``. An electrode is drawn
    only when the time-maximum of subpopulation 2 (PN) decreases from the
    first to the second condition (negative ``diff`` along ``condition``).
    For such electrodes every condition is plotted in its own colour.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes3D
        Target 3-D axes.
    labelpad, titlepad : float
        Axis-label and title padding.
    bg_color : colour
        Pane colour of all three axes.
    traj_lw, alpha : float
        Trajectory line width and transparency.
    axis_lw : float
        Line width of the three axis lines.
    azim, elev : float
        Camera angles for ``view_init``.
    xlim, ylim, zlim : sequence of two floats
        Axis limits.
    """
    colors = ['dodgerblue', 'crimson']
    n_electrodes = erp_trajectories.electrode.shape[0]
    n_conditions = erp_trajectories.condition.shape[0]

    for e_ind in range(n_electrodes):
        # Peak over time of subpopulation 2 (unpacked below as PN), one value per condition.
        cond_maxs = erp_trajectories[2, :, e_ind, :].max('time')
        if cond_maxs.diff('condition') < 0:
            for c_ind in range(n_conditions):
                xsc, xin, xpy = erp_trajectories[:, :, e_ind, c_ind]
                ax.plot(xsc, xin, xpy, colors[c_ind], zorder=2, lw=traj_lw, alpha=alpha)

    ax.set_xlabel('SC', labelpad=labelpad)
    ax.set_ylabel('IN', labelpad=labelpad)
    ax.set_zlabel('PN', labelpad=labelpad)
    ax.set_title('Latent space', pad=titlepad, fontsize=plt.rcParams['font.size'])

    # Axes3D.w_xaxis/w_yaxis/w_zaxis were deprecated in Matplotlib 3.1 and removed in
    # 3.8; xaxis/yaxis/zaxis are the same Axis3D objects.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.line.set_linewidth(axis_lw)
        axis.set_ticks([])
        axis.set_pane_color(bg_color)
    ax.view_init(azim=azim, elev=elev)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_zlim(zlim)

    ax.patch.set_alpha(0.0)
    

In [None]:
def plot_emp_ERPs(fig, ax, data, color='k', lw=1, alpha=0.5, labelpad=0, titlepad=10):
    """Plot all empirical ERP traces of ``data`` on a millisecond time axis.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes
        Target axes.
    data : xarray.DataArray
        Time along axis 0; every other axis is flattened into separate traces.
        ``data.attrs['MaxLead']`` / ``['MaxLag']`` (seconds) set the time span
        ``[-MaxLead, MaxLag]`` in ms.
    color, lw, alpha : trace styling.
    labelpad, titlepad : y-label and title padding.
    """
    t_start = data.attrs['MaxLead']*1000*-1
    t_end = data.attrs['MaxLag']*1000
    n_samples = data.shape[0]

    times = np.linspace(t_start, t_end, n_samples)
    traces = data.to_numpy().reshape(n_samples, -1)
    ax.plot(times, traces, c=color, lw=lw, alpha=alpha)

    y_top = data.max().round(1)
    modify_axis_spines(ax, which=['x', 'y'], yticks=[-3.5, 0, y_top], xticks=[t_start, t_end])

    # NOTE(review): the 'FB' label sits on the last tick (t_end); this is only
    # feedback onset if MaxLag == 0 -- confirm.
    ax.set_xticklabels([str(round(t_start)), 'FB'])
    ax.set_ylabel(r'Amplitude (mV)', labelpad=labelpad)
    ax.set_xlabel('Time (ms)')
    ax.set_title('Session ERPs', pad=titlepad, fontsize=plt.rcParams['font.size'])
    ax.patch.set_alpha(0.0)
    

In [None]:
def plot_sim_ERPs(fig, ax, data, sim_step=50, color='k', lw=1, alpha=0.5, titlepad=10):
    """Plot every ``sim_step``-th simulated ERP from ``data`` against the timestep index.

    ``data`` must have a ``time`` coordinate; after transposing, ERPs are taken
    from the columns (``data.T[:, ::sim_step]``). Y ticks are the rounded
    minimum, 0 and the rounded maximum of ``data``.
    """
    steps = np.arange(data.time.shape[0])
    subset = data.T[:, ::sim_step]
    ax.plot(steps, subset, c=color, lw=lw, alpha=alpha)

    y_low = data.min().round(0)
    y_high = data.max().round(0)
    modify_axis_spines(ax, which=['x', 'y'], yticks=[y_low, 0, y_high], xticks=[0, steps[-1]])

    ax.set_ylabel(r'Amplitude (mV)')
    ax.set_xlabel('Timestep')
    ax.set_title('Sample simulations', pad=titlepad, fontsize=plt.rcParams['font.size'])
    ax.patch.set_alpha(0.0)


In [None]:
def plot_session_ratios(fig, ax, titlepad=20, alpha=1, labeldistance=1.2, pctdistance=0.75, width=0.5, fontoffset=2, startangle=313):
    """Donut chart of Presence-dominant vs. other entries, saved as SVG.

    Counts come from the globals ``pDomSessInds`` and ``aDomSessInds``. The
    wedge labels use the axes' current title font size, and the title font is
    then enlarged by ``fontoffset`` points. The whole ``fig`` is saved to
    ``fig_save_loc + save_string + 'Session_Dom_Ratio.svg'``.

    ``titlepad`` is accepted but unused (no title text is set here).
    """
    group_counts = {'Presence dominant': len(pDomSessInds), 'Other': len(aDomSessInds)}

    # Read the size before changing the title font: wedge labels use the original size.
    base_size = ax.title.get_font_properties().get_size_in_points()
    ax.title.set_font_properties(dict(size=base_size + fontoffset))

    ax.pie(
        list(group_counts.values()),
        labels=list(group_counts.keys()),
        colors=[presence_color, 'crimson'],
        wedgeprops=dict(width=width, alpha=alpha),
        textprops={'size': base_size, 'ha': 'center'},
        autopct='%2.f%%',
        startangle=startangle,
        labeldistance=labeldistance,
        pctdistance=pctdistance,
        counterclock=True,
    )
    ax.patch.set_alpha(0.0)
    ax.set_aspect('auto', adjustable='box')
    fig.savefig(fig_save_loc + save_string + 'Session_Dom_Ratio.svg', transparent=True)

In [None]:
def plot_erp_beh_corr(fig, ax, color='slategray', marker='o', edgecolor='k', alpha=0.65, markersize=50, lw=2, rho_x=20, rho_y=57, rho_offset=1):
    """Scatter ERP ratio vs. behavioural ratio with the regression band, then save as SVG.

    Uses the globals ``n_ratio``, ``b_ratio`` (per-session percentages),
    ``xdot_quantiles_linear`` (5th/95th percentiles of the regression
    posterior, ordered like ``np.sort(n_ratio)``) and ``erp_beh_cor`` (Pearson
    result). The whole ``fig`` is saved to
    ``fig_save_loc + save_string + 'ERP_Beh_Correlation.svg'``.

    The save previously sat at module level, after the function. Running the
    cell therefore wrote whatever global ``fig`` existed (e.g. Figure 1) under
    this file name. It now runs inside the function, like the other plotting
    helpers.

    ``rho_x``/``rho_y`` place the rho annotation in data coordinates;
    ``rho_offset`` is accepted but unused.
    """
    x_sorted = np.sort(n_ratio)

    ax.scatter(n_ratio, b_ratio, c=color, marker=marker, edgecolors=edgecolor, alpha=alpha, s=markersize, zorder=3, linewidth=lw)
    ax.fill_between(x_sorted, xdot_quantiles_linear[0, :], xdot_quantiles_linear[1, :], color='grey', alpha=0.5, linewidth=0, zorder=1)
    # NOTE(review): the "fit" line is the mean of the two quantile bounds.
    ax.plot(x_sorted, xdot_quantiles_linear.mean(0), c='k', linewidth=lw, alpha=0.8, zorder=2)

    ax.text(s=r'$\rho_p$ = ' + str(np.round(erp_beh_cor.statistic, 2)), y=rho_y, x=rho_x, fontweight='normal')

    ax.set_xlabel('Peak ERP ratio')
    ax.set_ylabel('Behavioral performance ratio')

    modify_axis_spines(ax, which=['x', 'y'], yticks=[40, 60], xticks=[25, 85])
    ax.patch.set_alpha(0.0)

    fig.savefig(fig_save_loc + save_string + 'ERP_Beh_Correlation.svg', transparent=True, bbox_inches='tight')

In [None]:
def plot_cond_maxs(fig, ax, data=max_erps_df, linewidth=1, point_size=5, labelpad=10, tickpad=10, saturation=0.8, alpha=1, width_viol=0.4, width_box=0.4, bw=0.4,
                  pointplot=False, orient='v', linecolor='darkslategray', move=0, offset=0, legend=False, hide_ticks=True, label_text='Amplitude', title_text='Empirical',
                  titlepad=10, stat_loc='inside'):
    """Box/raincloud comparison of per-electrode peak values between Presence and Absence.

    Note that the default ``data`` is bound to the global ``max_erps_df`` when
    this cell is executed.

    Parameters
    ----------
    fig : Figure
        Unused; kept for a uniform plotting-function signature.
    ax : Axes
        Target axes.
    data : DataFrame
        Long-form table with 'condition' and 'value' columns.
    orient : {'v', 'h'}
        Vertical (condition on x) or horizontal (value on x) layout.
    linewidth, point_size, saturation, alpha, width_viol, width_box, bw, pointplot, linecolor :
        Forwarded to ``plot_RainClouds`` (clouds disabled) and ``Annotator``.
    legend : bool
        Add a Presence/Absence legend.
    hide_ticks : bool
        Remove the categorical-axis ticks.
    label_text, title_text : value-axis label and title.
    labelpad, tickpad, titlepad : paddings.
    stat_loc : str
        ``loc`` for the Mann-Whitney annotation.
    move, offset :
        Accepted but unused.
    """
    horizontal = orient == 'h'
    x_col, y_col = ('value', 'condition') if horizontal else ('condition', 'value')

    plot_kwargs = {
        'x': x_col, 'y': y_col, 'data': data, 'dodge': True,
        'palette': {'Presence': presence_color, 'Absence': absence_color},
        'linecolor': linecolor, 'point_size': point_size, 'linewidth': linewidth,
        'box_linewidth': linewidth, 'saturation': saturation, 'alpha': alpha,
        'width_viol': width_viol, 'width_box': width_box, 'bw': bw,
        'box_capprops': {'linewidth': 0},
    }

    clouds_ax = plot_RainClouds(ax=ax, orient=orient, **plot_kwargs, cut=0, pointplot=pointplot, box_fliersize=0,
                                box_whiskerprops=dict(linewidth=linewidth), clouds=False)

    if legend:
        legend_handles, _ = clouds_ax.get_legend_handles_labels()
        ax.legend(legend_handles, labels=['Presence', 'Absence'], frameon=False)  # Adjust labels as needed

    stat_annotator = Annotator(ax=ax, pairs=[('Presence', 'Absence')], **plot_kwargs, plot='boxplot', verbose=False)
    stat_annotator.configure(test='Mann-Whitney', text_format='star', loc=stat_loc, line_width=linewidth).apply_test().annotate()

    # Rounded extremes of the first numeric column (the 'value' column in the tables used here).
    value_max = data.max(numeric_only=True).to_numpy().round()[0]
    value_min = data.min(numeric_only=True).to_numpy().round()[0]

    if horizontal:
        ax.set_ylabel('')
        ax.set_xlabel(label_text, labelpad=labelpad)
        ax.tick_params(axis='y', length=0, pad=tickpad)
        ax.tick_params(axis='x',)
        # NOTE(review): np.arange excludes value_max, unlike the vertical branch's
        # [min, max] ticks -- confirm this asymmetry is intended.
        modify_axis_spines(ax, which=['x'], xticks=np.arange(value_min, value_max, 1))
        if hide_ticks:
            ax.set_yticks([])
    else:
        ax.set_xlabel('')
        ax.set_ylabel(label_text, labelpad=labelpad)
        ax.tick_params(axis='x', length=0, pad=tickpad)
        ax.tick_params(axis='y')
        modify_axis_spines(ax, which=['y'], yticks=[value_min, value_max])
        if hide_ticks:
            ax.set_xticks([])

    ax.set_title(title_text, fontsize=plt.rcParams['font.size'], pad=titlepad)
    ax.patch.set_alpha(0.0)


In [None]:
def plot_erp_cebra(fig, ax, points, labels, parent_figure=False, axes_offset=0.01, markersize=50, manifold_alpha=0.5, cbar_title_rotation=0,
                   titlepad=10, cbar_width=5,  cbar_pad=2, cbar_titlepad=15, cax_rect=[0.86, 0.20, 0.02, 0.6], orientation='vertical'):
    """Plot 3-D ERP embedding points over a fitted, colour-shaded ellipsoid with a colorbar.

    An ellipsoid is least-squares fitted to ``points`` with ``ls_ellipsoid`` /
    ``polyToParams3D`` (defined elsewhere). Its semi-axes are shrunk by
    ``axes_offset``, and it is shaded with a diverging colormap built from the
    global ``y_ppc_grid``. Points with label 0 are drawn as 'Absence'
    (crimson) and label 1 as 'Presence' (dodgerblue).

    Parameters
    ----------
    fig : Figure
        Figure that receives the colorbar axes.
    ax : Axes3D
        Target 3-D axes.
    points : ndarray, shape (n, 3)
        Embedding coordinates.
    labels : ndarray, shape (n,)
        Condition labels (0 = Absence, 1 = Presence).
    parent_figure : bool
        If True, the colorbar axes is appended below ``ax`` via
        ``make_axes_locatable``; otherwise it is created at ``cax_rect``.
    orientation : {'vertical', 'horizontal'}
        Colorbar orientation. Ticks 0 / 0.5 / 1 are now placed on the
        colorbar's long axis in both orientations; previously they always went
        on the x axis, which is the short axis of a vertical colorbar.
    cbar_width, cbar_pad, cbar_titlepad : colorbar size/padding (parent-figure mode).
    markersize, manifold_alpha, titlepad : styling.
    cbar_title_rotation : accepted but unused.
    """
    # Fit ellipsoid; work on a shrunk copy of the semi-axes instead of mutating in place.
    eansa = ls_ellipsoid(points[:, 0], points[:, 1], points[:, 2])
    center, fitted_axes, inve = polyToParams3D(eansa, False)
    semi_axes = np.asarray(fitted_axes) - axes_offset

    # Parametric ellipsoid surface.
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x = center[0] + semi_axes[0] * np.outer(np.cos(u), np.sin(v))
    y = center[1] + semi_axes[1] * np.outer(np.sin(u), np.sin(v))
    z = center[2] + semi_axes[2] * np.outer(np.ones_like(u), np.cos(v))

    cmap_manifold = sns.diverging_palette(12, 250, s=250, l=20, as_cmap=True).reversed()

    # Decorative shading pattern from the surface coordinates and the PPC grid
    # (assumes y_ppc_grid.mean(2) broadcasts against the 100x100 surface -- confirm).
    x_data, y_data, z_data = x, y, y_ppc_grid.mean(2)
    ellipsoid_color_data = np.sin(np.pi * x_data) * np.cos(np.pi * y_data) * np.sin(np.pi * z_data)

    # Local name avoids shadowing the module-level scipy.stats `norm`.
    color_norm = plt.Normalize(ellipsoid_color_data.min(), ellipsoid_color_data.max())
    ellipsoid_color_data = color_norm(ellipsoid_color_data)

    surface = ax.plot_surface(x, y, z, alpha=manifold_alpha, cmap=cmap_manifold.reversed(), facecolors=cmap_manifold(ellipsoid_color_data), linewidth=0)

    ax.scatter(points[labels == 0, 0], points[labels == 0, 1], points[labels == 0, 2], color='crimson', label='Absence', s=markersize)
    ax.scatter(points[labels == 1, 0], points[labels == 1, 1], points[labels == 1, 2], color='dodgerblue', label='Presence', s=markersize)

    if parent_figure:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('bottom', size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        cax = fig.add_axes(rect=cax_rect)

    cbar = fig.colorbar(surface, extendrect=False, cax=cax, orientation=orientation)

    if orientation == 'vertical':
        cbar.ax.set_ylabel('Presence Probability', labelpad=cbar_titlepad)
    else:
        cbar.ax.set_xlabel('Presence Probability', labelpad=cbar_titlepad)

    # Colorbar.set_ticks targets the long axis for either orientation.
    cbar.set_ticks([0, 0.5, 1])
    cbar.ax.tick_params(size=0)
    cbar.outline.set_visible(False)

    ax.set_xlim([0.15, 0.85])
    ax.set_ylim([0.15, 0.85])
    ax.set_zlim([0.15, 0.85])

    ax.set_title('ERP embedding manifold\n(empirical)', pad=titlepad, fontsize=plt.rcParams['font.size'])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.axis('off')
    ax.patch.set_alpha(0)
    

In [None]:
def plot_erp_fits(fig, ax, electrodes=[0,1], chains=0, percentiles=[0.05, 0.95], ds_sim=5, erp_colors=['dodgerblue', 'crimson'], xhat_colors=['skyblue', 'salmon'],
                  erp_alpha=1, xhat_alpha=1, erp_lw=1, xhat_lw=0, erp_ls='', erp_marker='.', erp_ms=1, handletextpad=1, markerscale=5, legend_spacing=2, lax_rect=[0.9, 0.25, 0.5, 1.7], conds_labels=['Pr', 'Ab'],
                  lbar_pad=0.2, lbar_width=1, handlelength=2, ncol=2, parent_figure=True):
    """Overlay empirical ERPs on HMC prediction bands for the selected electrodes.

    Reads the notebook-level ``task_ERPs`` (empirical ERPs with ``electrode``,
    ``condition`` and ``time``) and ``hmc_xhats`` (HMC predictions with ``chain``,
    ``electrode`` and ``condition``). For each electrode and each of the two conditions,
    it draws two things:

    * the ``percentiles[0]``..``percentiles[1]`` quantile band of the predictions,
      taken over their first remaining axis (presumably samples; verify);
    * the empirical ERP, downsampled by ``ds_sim``.

    Parameters
    ----------
    fig : figure used to place the legend axis when ``parent_figure`` is False.
    ax : axes to draw into.
    electrodes : electrode positions (``isel``) to plot.
    chains : chain selection applied to ``hmc_xhats``.
    percentiles : lower/upper quantiles in [0, 1] for the prediction band.
    ds_sim : stride used for the prediction time axis and the empirical samples.
    conds_labels : short condition names used in the legend ('HMC-…', 'Emp-…').
    parent_figure : if True, append the legend axis above ``ax``; otherwise create it
        at ``lax_rect`` in ``fig``.
    """
    emp_time = task_ERPs.time
    xhat_time = np.arange(0, emp_time.shape[0], ds_sim)

    for plot_ind, e_ind in enumerate(electrodes):

        emp_erps = task_ERPs.isel(electrode=e_ind)
        xhat_erps = hmc_xhats.isel(chain=chains, electrode=e_ind).squeeze()

        for c_ind in range(2):

            # Label only the first plotted electrode, so each legend entry appears exactly
            # once whichever electrodes are selected.
            if plot_ind == 0:
                xhat_label = 'HMC-' + conds_labels[c_ind]
                emp_label = 'Emp-' + conds_labels[c_ind]
            else:
                xhat_label, emp_label = '', ''

            xhats_lo, xhats_hi = np.quantile(xhat_erps.isel(condition=c_ind), [percentiles[0], percentiles[1]], axis=0)

            ax.fill_between(x=xhat_time, y1=xhats_lo, y2=xhats_hi, color=xhat_colors[c_ind], alpha=xhat_alpha, lw=xhat_lw, label=xhat_label)
            ax.plot(emp_time[::ds_sim], emp_erps.isel(condition=c_ind)[::ds_sim], color=erp_colors[c_ind], alpha=erp_alpha, lw=erp_lw, label=emp_label,
                   ls=erp_ls, marker=erp_marker, markersize=erp_ms)

    if parent_figure:
        divider = make_axes_locatable(ax)
        lax = divider.append_axes('top', size=str(lbar_width) + '%', pad=lbar_pad)
    else:
        lax = fig.add_axes(rect=lax_rect)

    handles, labels = ax.get_legend_handles_labels()

    lax.legend(handles, labels, loc="center", ncol=ncol, frameon=False, markerscale=markerscale,
               columnspacing=legend_spacing, prop={'weight':'normal'}, handletextpad=handletextpad, handlelength=handlelength)
    lax.set_frame_on(False)
    lax.axis(False)

    # Y ticks must cover every plotted electrode, not only the last one drawn.
    plotted_erps = task_ERPs.isel(electrode=list(electrodes))
    ymin, ymax = plotted_erps.min().round(), plotted_erps.max().round()
    modify_axis_spines(ax, which = ['x', 'y'], yticks=[ymin, 0, ymax], xticks=[0, emp_time[-1]])

    ax.set_ylabel(r'Amplitude (mV)')
    ax.set_xlabel('Timestep')
    ax.patch.set_alpha(0.0)


In [None]:
def pval_to_string(pvals, format='*'):
    r"""Convert p-values into significance annotation strings.

    Thresholds are checked from most to least stringent: '****' (p <= 1e-5),
    '***' (p <= 0.001), '**' (p <= 0.01) and '*' (p <= 0.05).

    Parameters
    ----------
    pvals : iterable of float
    format : '*' returns the star code, or 'ns' when no threshold is met. Any other
        value returns matplotlib mathtext strings: ``'$p \leq $<threshold>'`` for
        significant values and ``'$p = $<value>'`` otherwise.

    Returns
    -------
    list of str, one per p-value.
    """
    # Ordered from most to least stringent; the first threshold satisfied wins.
    p_dict = {'****': 0.00001, '***': 0.001, '**': 0.01, '*': 0.05}

    pstrings = []
    for pval in pvals:
        # Most stringent star code whose threshold this p-value meets (None = not significant).
        star = next((code for code, threshold in p_dict.items() if pval <= threshold), None)

        if format == '*':
            pstrings.append('ns' if star is None else star)
        elif star is None:
            pstrings.append(r'$p = $' + str(pval))
        else:
            # Report the bound that was met, not the p-value itself, and keep the text
            # inside $...$ so mathtext renders \leq.
            pstrings.append(r'$p \leq $' + str(p_dict[star]))

    return pstrings

In [None]:
def plot_erp_posterior_boxs(fig, ax, param_df, m_ind=0, method_prefix='KS', saturation=1, lw=1, labelpad=10, lgd_loc=(0.8, .75)):
    """Draw per-parameter posterior box plots split by condition, with statistic brackets.

    ``param_df`` is long-format, with columns ``theta`` (parameter name such as
    ``'g_1'``), ``value`` and ``condition`` (``'Presence'``/``'Absence'``). Each
    Presence/Absence pair of ``g_1``..``g_4`` is bracketed by statannotations. The
    bracket text is ``'<method_prefix>: <statistic>'``, where the statistic is read from
    the notebook-level ``erp_con_stats[m_ind, 0, :]`` and rounded to 2 decimals.
    ``fig`` is unused.
    """
    condition_palette = {'Presence': 'dodgerblue', 'Absence': 'crimson'}

    # Slice 0 of the chosen method feeds the annotations; slice 1 is extracted but not used.
    test_rvals = erp_con_stats[m_ind, 0, :].to_numpy()
    test_pvals = erp_con_stats[m_ind, 1, :].to_numpy()
    n_thetas = len(erp_thetas.theta.to_numpy())
    sig_annotations = [method_prefix + ': ' + str(test_rvals[i].round(2)) for i in range(n_thetas)]

    param_names = ['g_1', 'g_2', 'g_3', 'g_4']
    tick_labels = ['$' + name + '$' for name in param_names]
    sig_pairs = [((name, 'Presence'), (name, 'Absence')) for name in param_names]

    boxes = sns.boxplot(
        x="theta", y="value", data=param_df, palette=condition_palette, showfliers=False,
        hue='condition', saturation=saturation, dodge=True, ax=ax, linewidth=lw,
        capprops={'linewidth': 0},
    )

    annotator = Annotator(x='theta', y='value', hue='condition', pairs=sig_pairs,
                          data=param_df, ax=boxes, verbose=False)
    annotator.set_custom_annotations(sig_annotations)
    annotator.configure(line_width=lw)
    annotator.annotate()

    modify_axis_spines(ax, which=['y'], yticks=[0, 0.5, 1])

    ax.set_xticklabels(tick_labels)
    ax.tick_params(axis='both', which='major')
    ax.tick_params(axis='x', length=0)
    ax.legend(ncol=1, loc=lgd_loc, frameon=False)
    ax.set_ylabel('Pooled posterior value', labelpad=labelpad)
    ax.set_xlabel('')
    ax.patch.set_alpha(0.0)


In [None]:
def plot_pn_cebra(fig, ax, points, labels, parent_figure=False, axes_offset=0.01, markersize=50, manifold_alpha=0.2, manifold_color='darkgray', cbar_title_rotation=0,
                  titlepad=10, cbar_width=5,  cbar_pad=2, cbar_titlepad=15, cax_rect=[0.86, 0.20, 0.02, 0.6], orientation='vertical', range_offset=0.1,
                  elev=0, roll=180, azim=90, data_lims=(-0.6, 0.6)):
    """Plot the predicted (PN) 3-D embedding together with a fitted ellipsoid surface.

    An ellipsoid is fitted to ``points`` with ``ls_ellipsoid``/``polyToParams3D``. It is
    drawn as a translucent, single-colour surface, and the points are scattered on top:
    label 0 as 'Absence' (crimson), label 1 as 'Presence' (dodgerblue).

    Parameters
    ----------
    fig : unused; kept for signature parity with the other plotting helpers.
    ax : 3-D axes to draw into.
    points : array indexed as ``points[:, 0..2]`` (x, y, z coordinates).
    labels : array of 0/1 labels, one per row of ``points``.
    axes_offset : amount subtracted from every fitted semi-axis before drawing.
    range_offset : padding added on both sides of the axis limits.
    data_lims : ``(min, max)`` shared by all three axes before padding. ``None`` uses
        ``points.min()``/``points.max()``.
    parent_figure, cbar_title_rotation, cbar_width, cbar_pad, cbar_titlepad, cax_rect,
    orientation : accepted for compatibility; unused (this plot has no colorbar).
    """
    # Fit the ellipsoid and shrink its semi-axes by `axes_offset`.
    eansa = ls_ellipsoid(points[:, 0], points[:, 1], points[:, 2])
    center, axes, _ = polyToParams3D(eansa, False)
    axes -= axes_offset

    # Parametric ellipsoid surface on a 100 x 100 (u: azimuth, v: polar angle) grid.
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x = center[0] + axes[0] * np.outer(np.cos(u), np.sin(v))
    y = center[1] + axes[1] * np.outer(np.sin(u), np.sin(v))
    z = center[2] + axes[2] * np.outer(np.ones_like(u), np.cos(v))

    ax.plot_surface(x, y, z, alpha=manifold_alpha, color=manifold_color, linewidth=0)

    # Plot the data points
    ax.scatter(points[labels==0, 0], points[labels==0, 1], points[labels==0, 2], color='crimson', label = 'Absence', s=markersize)
    ax.scatter(points[labels==1, 0], points[labels==1, 1], points[labels==1, 2], color='dodgerblue', label = 'Presence', s=markersize)

    ax.set_title('PN embedding manifold\n(predicted)', pad=titlepad, fontsize=plt.rcParams['font.size'])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    if data_lims is None:
        data_min, data_max = points.min(), points.max()
    else:
        data_min, data_max = data_lims

    ax.set_xlim([data_min-range_offset, data_max+range_offset])
    ax.set_ylim([data_min-range_offset, data_max+range_offset])
    ax.set_zlim([data_min-range_offset, data_max+range_offset])
    ax.patch.set_alpha(0)
    ax.axis('off')
    ax.view_init(elev=elev, roll=roll, azim=azim)


In [None]:
def plot_syn_currents(fig, axes, data_maxs=cond_thetas_max.to_numpy(), x_loc=0.0, y_loc=-0.8, x_offset=0, cond_labels=['Presence', 'Absence'], cond_colors=['royalblue', 'firebrick'], auc_alpha=0.9,
                     lgd_bbox=(1., 0.9), text_x=58, text_y=0.03, suptitle_fontsize=None, labelpad=0, linewidth=0.7, titleweight='normal'):
    """Plot normalized excitatory and inhibitory synaptic currents, one panel per condition.

    ``compute_syn_current(data_maxs)`` returns the excitatory (``I_e``) and inhibitory
    (``I_i``) traces, the E/I ratio and the time vector. Each quantity is divided by its
    overall maximum. Excitation is filled above zero and inhibition is mirrored below it.

    Parameters
    ----------
    fig : unused; kept for signature parity with the other plotting helpers.
    axes : sequence of axes. Panel ``i`` shows column ``i`` of the currents, is titled
        ``cond_labels[i]`` and is annotated with ``EIRatio[i]``.
    data_maxs : input for ``compute_syn_current``. NOTE: the default is evaluated once,
        when this function is defined, from the notebook-level ``cond_thetas_max``.
    cond_colors : fill colours for the excitatory and inhibitory areas, in that order.
    text_x, text_y : data coordinates of the 'E/I Ratio' annotation.
    lgd_bbox : legend anchor; the legend is drawn on the second panel only.
    suptitle_fontsize : panel-title font size; ``None`` uses ``rcParams['font.size']``.
    x_loc : forwarded as ``x`` to the y-label of the first panel.
    y_loc, x_offset : accepted for compatibility; unused.
    """
    I_e, I_i, EIRatio, time_vec = compute_syn_current(data_maxs)
    I_e /= I_e.max()
    I_i /= I_i.max()
    EIRatio /= EIRatio.max()

    if suptitle_fontsize is None:
        suptitle_fontsize = plt.rcParams['font.size']

    for ax_ind, ax in enumerate(axes):

        ax.fill_between(time_vec, I_e[:,ax_ind], color=cond_colors[0], alpha=auc_alpha, edgecolor='k', label='Excitatory', lw=linewidth)
        ax.fill_between(time_vec, -I_i[:,ax_ind], color=cond_colors[1], alpha=auc_alpha, edgecolor='k', label='Inhibitory', lw=linewidth)

        ax.text(text_x, text_y, s = 'E/I Ratio: ' + str(round(EIRatio[ax_ind], 2)))

        # A single legend (on the second panel) is enough; both panels share the entries.
        if ax_ind == 1:
            ax.legend(frameon=False, bbox_to_anchor=lgd_bbox)

        ax.tick_params(axis='both', which='major')
        ax.set_title(cond_labels[ax_ind], weight=titleweight, fontsize=suptitle_fontsize)
        modify_axis_spines(ax, which=['x', 'y'], xticks=[0, time_vec[-1].round(-2)], yticks=[-1, 0, 1])
        ax.set_xlabel('Timestep', labelpad=labelpad)
        ax.patch.set_alpha(0.0)

    axes[0].set_ylabel('Normalized PSP', x=x_loc)

In [None]:
def plot_epsps(fig, ax, x_loc=0.0, y_loc=-0.8, x_offset=0, cond_labels=['Presence', 'Absence'], cond_colors=['royalblue', 'firebrick'], auc_alpha=0.5,
                     lgd_bbox=(1., 0.9), text_x=58, text_y=0.03, suptitle_fontsize=None, labelpad=0, linewidth=0.7, titleweight='normal', data_maxs=None):
    """Overlay the normalized excitatory PSP of both conditions on one axes.

    ``compute_syn_current(data_maxs)`` supplies the excitatory trace ``I_e``. It is
    divided by its overall maximum, and column ``c`` is filled in ``cond_colors[c]``
    with legend label ``cond_labels[c]``, for the first two conditions.

    Parameters
    ----------
    fig : unused; kept for signature parity with the other plotting helpers.
    ax : axes to draw into.
    suptitle_fontsize : title font size; ``None`` uses ``rcParams['font.size']``.
    x_loc : forwarded as ``x`` to the y-label.
    data_maxs : input for ``compute_syn_current``. ``None`` (default) uses the
        notebook-level ``cond_thetas_max``.
    y_loc, x_offset, lgd_bbox, text_x, text_y : accepted for compatibility; unused.
    """
    num_cond = 2

    if data_maxs is None:
        data_maxs = cond_thetas_max

    # Only the excitatory component is shown here.
    I_e, _, _, time_vec = compute_syn_current(data_maxs)
    I_e /= I_e.max()

    if suptitle_fontsize is None:
        suptitle_fontsize = plt.rcParams['font.size']

    for c_ind in range(num_cond):
        ax.fill_between(time_vec, I_e[:,c_ind], color=cond_colors[c_ind], alpha=auc_alpha, edgecolor='k', label=cond_labels[c_ind], lw=linewidth)

    ax.tick_params(axis='both', which='major')
    ax.set_title('Condition EPSP', weight=titleweight, fontsize=suptitle_fontsize)
    modify_axis_spines(ax, which=['x', 'y'], xticks=[0, time_vec[-1].round(-2)], yticks=[0, 0.6])
    ax.set_xlabel('Timestep', labelpad=labelpad)
    ax.patch.set_alpha(0.0)

    ax.set_ylabel('Normalized PSP', x=x_loc)

### Run

In [None]:
# Assemble Fig. 2: three stacked sub-figures, each holding a row of four panels.
height_ratios = np.array([1, 1, 1])

right_space = 0.95
left_space = 0.09
wspace = 0.6
hspace = 0.25

fig = plt.figure(figsize=(9,7.5), frameon=False)
(fig_r1, fig_r2, fig_r3) = fig.subfigures(3, 1, height_ratios=height_ratios, hspace=hspace, frameon=False)


####################### 1st Row sub-figure #########################################


# Define the gridspec in which this row's subplots are placed
fig_r1_wr = [1, 1, 1, 1]
grid_r1 = gridspec.GridSpec(1, 4, width_ratios=fig_r1_wr, figure=fig_r1, wspace=wspace, left=left_space, right=right_space)

ax_all_erps = fig_r1.add_subplot(grid_r1[0])
ax_peak_stats = fig_r1.add_subplot(grid_r1[1])
ax_beh_corr = fig_r1.add_subplot(grid_r1[2])
ax_emp_manifold = fig_r1.add_subplot(grid_r1[3], projection='3d')

plot_emp_ERPs(fig_r1, ax_all_erps, task_ERPs, alpha=0.2, color='slategray')
plot_cond_maxs(fig_r1, ax_peak_stats, point_size=3, labelpad=10, tickpad=22)
plot_erp_beh_corr(fig_r1, ax_beh_corr, color='lightsteelblue', markersize=10, rho_offset=2,rho_y=43.5, rho_x=54, lw=1)
plot_erp_cebra(fig_r1, ax_emp_manifold, X.copy(), Y.copy(), orientation='horizontal', markersize=10,
               cax_rect=[0.8, 0.1 ,0.15, 0.05], cbar_title_rotation=0, cbar_titlepad=7, titlepad=8)


####################### 2nd Row sub-figure #########################################
fig_r2_wr = [1, 1, 1, 1]
grid_r2 = gridspec.GridSpec(1, 4, width_ratios=fig_r2_wr, figure=fig_r2, wspace=wspace, left=left_space, right=right_space)

ax_sample_sim = fig_r2.add_subplot(grid_r2[0])
ax_pn_maxs = fig_r2.add_subplot(grid_r2[1])
ax_latent_traj = fig_r2.add_subplot(grid_r2[2], projection='3d')
ax_pn_manifold = fig_r2.add_subplot(grid_r2[3], projection='3d')

plot_cond_maxs(fig_r2, ax_pn_maxs, data=max_pn_df, title_text='Predicted', point_size=3, labelpad=10, tickpad=22)
plot_sim_ERPs(fig_r2, ax_sample_sim, erp_prior_dataset, alpha=0.2, color='slategray')
plot_erp_trajectories(fig_r2, ax_latent_traj, labelpad=-12, alpha=0.5, traj_lw=2, titlepad=23)

# ax_pn_manifold belongs to fig_r2, so pass that sub-figure.
plot_pn_cebra(fig_r2, ax_pn_manifold, pn_embeddings.to_numpy(), pn_conditions, manifold_alpha=0.2, manifold_color='gray', markersize=10)

####################### 3rd Row sub-figure #########################################
fig_r3_wr = [1, 1, 1, 1]

# NOTE(review): grid_r3 is built here, but the row-3 axes below are placed on grid_r2.
# That uses grid_r2's left/right/wspace and no top/bottom override. The published layout
# was tuned with grid_r2; switching to grid_r3 (top=1.5) would move these axes, so
# confirm which gridspec is intended before changing it.
grid_r3 = gridspec.GridSpec(1, 4, width_ratios=fig_r3_wr, figure=fig_r3, wspace=wspace, left=left_space, right=right_space, top=1.5, bottom=0.01)

ax_post_boxs = fig_r3.add_subplot(grid_r2[:2])
ax_EI_1 = fig_r3.add_subplot(grid_r2[2])
ax_EI_2 = fig_r3.add_subplot(grid_r2[3])
ax_EI_curr= [ax_EI_1, ax_EI_2]

plot_erp_posterior_boxs(fig_r3, ax_post_boxs, erp_theta_df, lgd_loc=(0.7, 0.75))
plot_syn_currents(fig_r3, ax_EI_curr, lgd_bbox=(1.2, 0.35), text_x=40, y_loc=-0.8, x_offset=10, labelpad=-8)


# ####################### Parent Figure Configuration #########################################
# Panel letters: four columns in rows 1-2, two columns in row 3.
c4_x = np.linspace(0.02, 0.75, 4)
c2_x = [0.02, 0.5]
cols_x = [c4_x, c4_x, c2_x]
rows_y = [0.98, 0.65, 0.30]

fig_coords = [[cols_x[r_ind][c_ind], rows_y[r_ind]] for r_ind in range(3) for c_ind in range(len(cols_x[r_ind]))]
fig_letters = all_letters[:len(fig_coords)]
text_kwargs = dict(fontsize=14, fontweight='bold', ha='right', va='center')

for f_ind, f_coords in enumerate(fig_coords):
    fig.text(x=f_coords[0], y=f_coords[1], s=fig_letters[f_ind], **text_kwargs)

fig.show()

In [None]:
# Export the assembled figure twice with a transparent background:
# a high-resolution PNG raster and an SVG vector version.
fig.savefig(
    fig_save_loc + 'Article/' + 'Fig2.png',
    transparent=True,
    dpi=800,
)
fig.savefig(
    fig_save_loc + 'Article/' + 'Fig2.svg',
    transparent=True,
)

# EEG

## Empirical

In [None]:
sp_dir = ''

# Recursively collect every recording-related file type below the study directory.
(poly5_dirs,
 raw_clean_dirs,
 source_localized_dirs,
 inv_operator_dirs,
 ica_dirs) = (glob.glob(sp_dir + pattern, recursive=True)
              for pattern in ('**/*.Poly5', '**/*_raw_clean.fif', '**/*_stc.fif',
                              '**/*_inv.fif', '**/*_ica.fif'))

# Structural connectivity for the Schaefer 400-parcel / 7-network atlas.
schaefer_sc_dir = 's'
SC_dirs = glob.glob(schaefer_sc_dir+'**/**/weights_Schaefer2018_400Parcels_7Networks_15M.txt', recursive=True)
SC = np.load(parent_preprocess_dir + 'Schaefer2018_SC.npy')
num_node = len(SC)

# Subject IDs are the text after the last underscore of each subject folder name.
# NOTE(review): glob.glob returns entries in an unsorted, filesystem-dependent order, and
# `subjects` inherits it. Per-subject masks built later (e.g. from `pong_results`) follow
# that array's own subject order. Confirm the two orders agree before mixing them by position.
subject_folders = glob.glob(sp_dir + 'pongFac23*')
subjects = np.array([folder.rsplit('_', 1)[-1] for folder in subject_folders])

eeg_sweep_dir = ''
eeg_sbi_input = ''
eeg_sbi_output = ''

eeg_preprocess_dir = parent_preprocess_dir + 'SP_Facilitation/'

# ICA templates for artifact removal
template_array = xr.load_dataarray(eeg_preprocess_dir + 'ica_component_templates_extended.nc').load()
blink_template, saccade_template = (template_array.sel(template=name).to_numpy()
                                    for name in ('blink', 'saccade'))
source_epochs = xr.open_dataarray(parent_preprocess_dir + 'source_epochs_7Networks_bm_full.nc')
sample_source_epochs = xr.load_dataarray(parent_preprocess_dir + 'sample_source_epochs.nc').load()

### Behavior

In [None]:
save_string = 'EEG/'

event_label = 'threshTime'

# Numeric code -> readable label. Only the gender codes (1, 2) are looked up below.
# NOTE: 1 and 1.0 compare and hash equal, so a dict keeps only ONE entry for them. Listing
# 1 -> 'p', 1.0 -> 'p-a' and 1 -> 'Male' silently left just 'Male'. The literal states that
# resulting dict explicitly; feedback labels for code 1 need a separate dict if required.
mapped_dict = {-1: 'n', 1: 'Male', 0.1: 'a-p', 2: 'Female'}

pong_results = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_results.nc').load()
pong_movement_raw = xr.load_dataarray(parent_preprocess_dir + 'agg_pong_movement_' + event_label + '_lock.nc').load()

# Zero-phase (forward-backward) 2nd-order Butterworth low-pass at `low_cut` Hz on the
# movement trace. NOTE(review): fs=120 is hard-coded — confirm it is the movement sampling rate.
f_order = 2
low_cut = 12
lowpass = signal.butter(f_order, low_cut, fs = 120, btype = 'lp', output = 'sos')
pong_movement_data = signal.sosfiltfilt(lowpass, pong_movement_raw.sel(source='movement'), axis = 0)
pong_movement = pong_movement_raw.copy()
# Write the filtered trace back by label. The positional slot [:, :, 0, :] silently assumed
# `source` is the third dimension and 'movement' its first entry.
pong_movement.loc[dict(source='movement')] = pong_movement_data

conditions = pong_results.sel(variable = 'cond')
intercepts = pong_results.sel(variable = 'result')

# Trial masks by condition code (1 / 0) and by result code (-1 / 1).
pcond = conditions == 1
acond = conditions == 0

negfb = intercepts == -1
posfb = intercepts == 1

# Gender read from trial 0 only (assumes it is constant across a subject's trials).
subject_gens = pong_results.sel(variable = 'gender', trial = 0).to_numpy()
gen_list = [mapped_dict[gen] for gen in subject_gens]

num_subjects = len(subjects)
male_subjects = subject_gens == 1
female_subjects = subject_gens == 2

subject_groups = [female_subjects, male_subjects]
subject_group_names = ['Female', 'Male']
num_subject_groups = len(subject_groups)
subject_group_dict = {subject_group_names[ind]: subject_groups[ind] for ind in range(num_subject_groups)}
subj_per_group = female_subjects.sum()

In [None]:
# Per-subject mean paddle speed over stable trials, separately for Presence and Absence.
p_beh = np.zeros(len(subjects))
a_beh = np.zeros(len(subjects))

for sInd, subj in enumerate(subjects):

    # `.sel` with scalar labels may return a view of `pong_results`. Copy before replacing
    # zeros so the shared results array is not modified as a side effect.
    sub_bap = pong_results.sel(subject = subj, variable = 'BAP_new').copy()
    sub_bap[sub_bap == 0] = 1  # avoid dividing by zero below

    sub_movement = np.abs(pong_movement.sel(subject = subj, source = 'movement')/sub_bap)

    # Speed = gradient along axis 0 (presumably 'time', which is averaged over below).
    sub_speed = np.gradient(sub_movement, axis = 0)
    sub_speed = xr.DataArray(sub_speed, coords = sub_movement.coords, dims = sub_movement.dims)

    # Keep trials whose first sample is smaller than their last sample.
    stable_trials = ~(sub_movement[0,:] >= sub_movement[-1,:])

    p_tr = pcond.sel(subject = subj) & stable_trials
    a_tr = acond.sel(subject = subj) & stable_trials

    p_beh[sInd] = sub_speed.sel(trial = p_tr).mean('time').mean('trial')
    a_beh[sInd] = sub_speed.sel(trial = a_tr).mean('time').mean('trial')

In [None]:
# Min-max scale both conditions jointly so Presence and Absence share one [0, 1] scale.
beh_norm_concat = preprocessing.MinMaxScaler().fit_transform(np.concatenate((p_beh, a_beh)).reshape(-1,1)).ravel()
p_beh = beh_norm_concat[:num_subjects]
a_beh = beh_norm_concat[num_subjects:]

# Condition contrast: Presence share of the summed values (%), or a plain difference.
cdiff_method = 'percentage'

if cdiff_method == 'percentage':
    beh_ratio = p_beh/(p_beh+a_beh) * 100
else:
    beh_ratio = p_beh-a_beh

# One row per group (Female, Male), zero-padded to the size of the larger group. The width
# comes from the data rather than a fixed 14, so unequal group sizes no longer crash.
# Readers must slice each row back to its group size.
beh_ratio_groups = np.zeros((2, max(female_subjects.sum(), male_subjects.sum())))
beh_ratio_groups[0, :female_subjects.sum()] = beh_ratio[female_subjects]
beh_ratio_groups[1, :male_subjects.sum()] = beh_ratio[male_subjects]
beh_ratio_array = xr.DataArray(beh_ratio_groups, dims = ('group', 'subject'), coords = {'group': ['Female', 'Male']})

pong_beh_df = pd.DataFrame({'Presence': p_beh, 'Absence': a_beh, 'gender' : gen_list}).melt(id_vars = 'gender', value_vars = ['Presence', 'Absence'], var_name = 'condition')

### Loading preprocessed data

In [None]:
parc = 'Schaefer2018_400Parcels_7Networks_order'
yeo_networks = np.array(['DorsAttn', 'SalVentAttn', 'SomMot', 'Vis', 'Cont' ,'Default', 'Limbic'])
yeo_networks_shortened = ['DAN', 'VAN', 'SMN', 'VIS', 'FPN', 'DMN', 'LIM']
# Full Yeo network name -> abbreviation, paired by position.
yeo_networks_shortened_dict = dict(zip(yeo_networks, yeo_networks_shortened))

fs_dir = ''
fs_label_dir = ''

# Every label of the 7-network Schaefer parcellation on fsaverage.
schaefer_labels = mne.read_labels_from_annot('fsaverage', parc=parc, regexp='7Network', subjects_dir=fs_label_dir)
schaefer_label_names = source_epochs.label.to_numpy()

# Label names belonging to each network (matched via the `regexp` argument).
network_label_dict = {
    net: np.array(
        [lbl.name for lbl in mne.read_labels_from_annot('fsaverage', parc=parc, regexp=net, subjects_dir=fs_label_dir)],
        dtype=object,
    )
    for net in yeo_networks
}

network_dimensions = {net: len(names) for net, names in network_label_dict.items()}
network_names = list(network_label_dict)
# Positions of each network's labels along the `label` axis of `source_epochs`.
network_inds = {net: np.flatnonzero(np.isin(schaefer_label_names, names)) for net, names in network_label_dict.items()}

In [None]:
# Display colour (hex) for each abbreviated Yeo network.
network_color_dict = dict(
    DAN='#027710',
    VAN='#C639FC',
    SMN='#4882B5',
    VIS='#7A1287',
    FPN='#E89424',
    DMN='#CE3F50',
    LIM='#879D5D',
)

In [None]:
# Colour attribute of every Schaefer label, stacked into one array.
label_colors = np.array([schaefer_label.color for schaefer_label in schaefer_labels])
# Network-level palette: an alias of (not a copy of) network_color_dict.
network_label_colors = network_color_dict

In [None]:
# For every Yeo network, gather its left-hemisphere labels followed by its right-hemisphere
# labels. Then concatenate all networks, in `yeo_networks` order, into one flat name array.
network_label_dict_sorted = {}
sorted_label_chunks = []
for n_name in yeo_networks:

    net_lh, net_rh = (
        np.array([label.name for label in mne.read_labels_from_annot('fsaverage', parc=parc, regexp=n_name, hemi=hemi, subjects_dir=fs_label_dir)], dtype=object)
        for hemi in ('lh', 'rh')
    )

    net_labels = np.concatenate((net_lh, net_rh))

    network_label_dict_sorted[n_name] = net_labels
    sorted_label_chunks.append(net_labels)

label_names_sorted = np.concatenate(sorted_label_chunks) if sorted_label_chunks else np.array([])

# Positions of each network's labels inside label_names_sorted, keyed by short network name.
network_inds_sorted = {yeo_networks_shortened_dict[name]: np.flatnonzero(np.isin(label_names_sorted, names))
                       for name, names in network_label_dict_sorted.items()}

In [None]:
sbi_eeg_dir = ''

num_sim = 20000
num_sim_networks = 7
prior_type = 'Informative'

# Run identifier shared by the SBI input and output folders.
sim_name = f'JR_SDE_SBI_C1_{num_sim_networks}Networks_{prior_type}_'

sbi_eeg_input = f'{sbi_eeg_dir}input/{sim_name}{num_sim}'
sbi_eeg_output = f'{sbi_eeg_dir}output/{sim_name}{num_sim}'

In [None]:
fc_sum_sim = xr.load_dataarray(sbi_eeg_input + '/fc_sum_sim.nc')
fc_eig_sim = xr.load_dataarray(sbi_eeg_input + '/fc_eig_sim.nc')
# NOTE: torch.load unpickles data; only use it on trusted, locally produced files.
theta_eeg = torch.load(sbi_eeg_input+'/theta.pt')

feature_order = ('network', 'summary')
sel_feature_names = ['sum']

fc_summary_sim = xr.concat((fc_sum_sim, fc_eig_sim), dim='summary')

# Keep the selected summaries, relabel networks with their short names, and flatten
# (network, summary) into a single `feature` dimension.
x_feature_array = (
    fc_summary_sim.sel(summary=sel_feature_names)
    .assign_coords(network=yeo_networks_shortened)
    .stack({'feature': feature_order})
    .squeeze()
)

num_features = len(x_feature_array.feature)

In [None]:
sim_array_sbi = xr.open_dataarray(sbi_eeg_input + '/simulations_sample.nc')
# Cortical labels are those whose names start with '7Networks'.
ctx_labels = [lname for lname in sim_array_sbi.label.to_numpy() if lname.startswith('7Networks')]

fc_sum_emp = xr.load_dataarray(parent_preprocess_dir + 'fc_sum_emp_7Networks_norm_full.nc')
fc_sum_emp_netnorm = xr.load_dataarray(parent_preprocess_dir + 'fc_sum_emp_netnorm_7Networks_norm_full.nc')

# Gender-averaged empirical FC summary, min-max scaled to [0, 1] over all entries.
fc_sum_groups = fc_sum_emp_netnorm.sel(source='epochs').squeeze().rename(subject='group')

fc_sum_groups.coords['group'] = gen_list
fc_sum_groups.coords['network'] = yeo_networks_shortened
fc_sum_groups = fc_sum_groups.groupby('group').mean()

fc_sum_groups -= fc_sum_groups.min()
fc_sum_groups /= fc_sum_groups.max()

fc_sum_df = fc_sum_groups.reset_coords(names=['summary', 'source'], drop=True).to_dataframe(name='value').reset_index()
fc_sum_df_reindexed = fc_sum_df.set_index('group')
fc_sum_polar = {group: fc_sum_groups.sel(group=group).to_numpy() for group in fc_sum_groups.group.to_numpy()}

# Example FC matrix from the last subject in `subjects`, Presence condition, labels sorted by
# network. The subject index is explicit instead of relying on a loop variable leaked
# from an earlier cell.
sample_fc = compute_FC(source_epochs.sel(subject=subjects[-1], condition='Presence', label=label_names_sorted).mean('trial').T)
net_comm = list(network_inds_sorted.values())

emp_posterior = xr.load_dataarray(sbi_eeg_output + '/JR_7Networks_empirical_posterior.nc')
posterior_means = xr.load_dataarray(sbi_eeg_output + '/JR_7Networks_empirical_posterior_means.nc')
posterior_means_norm = xr.load_dataarray(sbi_eeg_output + '/JR_7Networks_empirical_posterior_means_norm.nc')
# parent_preprocess_dir already ends with a separator (see the other paths built from it).
group_posteriors = xr.load_dataarray(parent_preprocess_dir + 'group_posteriors.nc')
group_posteriors.coords['network'] = yeo_networks_shortened

attn_networks = ['DAN', 'VAN', 'FPN']
group_posteriors_attn = group_posteriors.sel(network=attn_networks).stack({'aggSample': ('group', 'condition', 'sample', 'network')})
group_posteriors_attn = group_posteriors_attn.to_pandas().T.reset_index().rename(columns={0:'value'})

posterior_means_group = posterior_means.reindex_like(posterior_means_norm).rename(subject='group')
posterior_means_group.coords['group'] = posterior_means_norm.coords['group']
posterior_means_group = posterior_means_group.groupby('group').mean()
posterior_difference = posterior_means_group.sel(condition='Presence') - posterior_means_group.sel(condition='Absence')
eeg_posterior_stats = xr.load_dataarray(parent_preprocess_dir + 'eeg_posterior_stats.nc')

fc_stats_sbi = xr.load_dataarray(parent_preprocess_dir + 'fc_stats_sbi.nc')
fc_stats_nl_sbi = xr.load_dataarray(parent_preprocess_dir + 'fc_stats_nl_sbi.nc')
fc_linreg_sbi = xr.load_dataarray(parent_preprocess_dir + 'fc_linreg_sbi.nc')
neu_ratio_sbi = xr.load_dataarray(parent_preprocess_dir + 'neu_ratio_sbi.nc')

In [None]:
# Per-network empirical FC condition contrast (Presence minus Absence), averaged within each
# gender group and rescaled to [-0.5, 0.5].
fc_emp_groups = fc_sum_emp_netnorm.sel(source='epochs').squeeze().rename(subject='group')

# Relabel the (renamed) subject axis with each subject's gender so groupby averages per gender.
# NOTE(review): assumes the subject order of fc_sum_emp_netnorm matches gen_list, which
# follows pong_results' subject order — confirm.
fc_emp_groups.coords['group'] = gen_list
fc_emp_groups.coords['network'] = yeo_networks_shortened
fc_emp_diff = (fc_emp_groups.sel(condition='Presence') - fc_emp_groups.sel(condition='Absence')).groupby('group').mean()

fcg_mins = fc_emp_diff.min()
fcg_maxs = fc_emp_diff.max()

# Global min-max rescale: the minimum maps to (-1 + 0)/2 = -0.5 and the maximum to (-1 + 2)/2 = 0.5.
# (Undefined if every value is equal, since fcg_maxs - fcg_mins is then 0.)
fc_emp_diff = (-1 + (fc_emp_diff - fcg_mins)*2/(fcg_maxs - fcg_mins))/2

## Multi-Panel

In [None]:
def return_brain_axes_from_grid(grid, num_row=2, num_col=2, row_start=3, col_slice=21, col_offset=0, figure=None, verbose=False):
    """Create a ``num_row`` x ``num_col`` array of axes from horizontal slices of ``grid``.

    Axis ``(r, c)`` occupies grid row ``r + row_start`` and grid columns
    ``c * col_slice + col_offset`` up to (excluding) ``(c + 1) * col_slice + col_offset``.

    Parameters
    ----------
    grid : gridspec-like object indexed as ``grid[row, col_start:col_stop]``.
    num_row, num_col : shape of the returned axes array.
    row_start : first grid row used.
    col_slice : number of grid columns spanned by each axis.
    col_offset : grid column at which the first axis starts.
    figure : figure the axes are added to. ``None`` (default) uses the notebook-level ``fig``.
    verbose : print each grid slice as it is used.

    Returns
    -------
    numpy.ndarray of dtype object and shape ``(num_row, num_col)`` holding the new axes.
    """
    if figure is None:
        figure = fig

    axes_all = np.empty((num_row, num_col), dtype=object)
    for row in range(num_row):
        for col in range(num_col):
            start = col_slice * col + col_offset
            grid_slice = grid[row + row_start, start:start + col_slice]
            if verbose:
                print(grid_slice)
            axes_all[row, col] = figure.add_subplot(grid_slice)

    return axes_all

In [None]:
def plot_fcdiff_heatmap(fig, axes, hemi='lh', view='lateral', group_pad=-5, view_pad=10, labelpad=10, titlepad=5, cmap_data=None, image_data=None, draw_title=True,
                       wspace=-0.1, hspace=-0.1, left=0, right=0, cax_rect=[0.7, 0.2, 0.03, 0.6], cbar_width=0.3, cbar_pad=0.1, parent_figure=True, title_x=0.3, title_y=1.05,):
    """Show one pre-rendered brain image per group, with masked pixels made transparent.

    For each entry of ``image_data.group``, the image ``image_data[g_ind]`` gets a full
    (255) alpha channel. Pixels listed in ``brain_masks[hemi + '-' + view]`` are set to
    ``[0, 0, 0, 0]``, the image is shown on ``axes[g_ind]``, and the group name is written
    above it. A shared colourbar is built from ``compute_pdiff_colormap(data=cmap_data)``.

    Parameters
    ----------
    fig : figure receiving the colourbar (and the optional 'INI' title).
    axes : sequence of axes, one per group.
    cmap_data : data passed to ``compute_pdiff_colormap``; ``None`` means ``[]``.
    image_data : xarray object with a leading ``group`` dimension holding RGB images;
        ``None`` means ``[]``, which is not plottable.
        NOTE(review): the alpha channel is uint8 255, so the images are presumably uint8 — confirm.
    parent_figure : if True, append the colourbar axis to the right of the last image axis;
        otherwise create it at ``cax_rect`` and apply ``wspace``/``hspace`` to ``fig``.
    view_pad, titlepad, left, right : accepted for compatibility; unused.
    """
    cmap_data = [] if cmap_data is None else cmap_data
    image_data = [] if image_data is None else image_data

    _, sm = compute_pdiff_colormap(data=cmap_data)

    for g_ind, _ in enumerate(image_data.group):

        ax = axes[g_ind]

        data = image_data[g_ind, ...].copy().squeeze()
        group = data.group.to_numpy()
        data = data.to_numpy()

        # Add a fully opaque alpha channel
        alpha_channel = np.ones((data.shape[0],data.shape[1]), dtype=np.uint8)*255
        data_rgba = np.dstack((data, alpha_channel))

        view_mask = brain_masks[hemi+'-'+view]

        # Make the masked (row, col) pixels fully transparent
        data_rgba[view_mask[:, 0], view_mask[:, 1]] = [0, 0, 0, 0]

        ax.imshow(data_rgba)

        ax.set_xlabel(group, labelpad=group_pad)
        ax.xaxis.set_label_position('top')

        # Hide spines, ticks, and tick labels
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        ax.set_xticks([])
        ax.set_yticks([])

    if parent_figure:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        cax = fig.add_axes(rect=cax_rect)
        fig.subplots_adjust(wspace=wspace, hspace=hspace)

    cbar = fig.colorbar(sm, extendrect=False, cax=cax)
    cbar.set_label('Condition difference', labelpad=labelpad)
    cbar.ax.tick_params(length=0)  # Remove colorbar ticks
    cbar.outline.set_visible(False)  # Remove colorbar frame
    cbar.ax.set_yticks([-0.5, 0.5])

    if draw_title:
        fig.text(x=title_x, y=title_y, s='INI', fontsize=plt.rcParams['font.size'])

In [None]:
def plot_con_beh_correlation(fig, axes, num_ticks=2, rho_fontsize=10, markersize=50, markerscale=3, linewidth=1, labelpad=10, bbox_to_anchor=(.6,1.),
                             wspace=1, hspace=0.5, left=0.1, bottom=0.2, lax_rect=[0.5, 0.9, 0.1, 0.1], xlabel_pos=0, ylabel_pos=0, legend_spacing=1,
                             ncol=1, rho_pos_groups=np.array([[0.7, 0.5],[0.6, 0.4]]), rho_offset=np.array([0, -0.1]), cbar_width=0.3, cbar_pad=0.1, parent_figure=True,
                            g_colors=['lightsteelblue', 'bisque'], g_edgecolor='k', g_styles=['D', 'o'], stat_weight='normal', net_weight='normal', handletextpad=1):
    r"""Scatter per-network g_2 ratios against the behavioural ratio, one row per subject group.

    Row ``gInd`` of ``axes`` shows group ``subject_group_names[gInd]``. Column ``net_ind``
    shows network ``fc_stats_sbi.network[net_ind]``; the last network entry is not plotted.
    Each panel contains:

    * the subjects;
    * the ``fc_linreg_sbi`` band ('l_bound'..'u_bound') and its mean line;
    * two correlation values, labelled $\rho_p$ (from ``fc_stats_sbi``) and $\rho_s$
      (from ``fc_stats_nl_sbi``), presumably Pearson and Spearman.

    A shared legend (one entry per group) goes on an extra axis. Network names label
    the top row.

    Parameters
    ----------
    axes : 2-D array of axes indexed ``axes[group, network]``.
    rho_pos_groups : per-group (x, y) axes-fraction position of the $\rho_p$ text; $\rho_s$
        is placed at that position plus ``rho_offset``.
    parent_figure : if True, the legend axis is appended to the last panel and axis labels
        go on ``axes[1, 3]``; otherwise the legend axis is created at ``lax_rect`` and
        figure-level labels are used.
    bbox_to_anchor, wspace, hspace, left, bottom : accepted for compatibility; unused.
    """
    xticks = np.linspace(20,100, num_ticks)
    # Loop-invariant: every network except the last entry of fc_stats_sbi gets a column.
    plotted_networks = fc_stats_sbi.network.to_numpy()[:-1]

    for gInd, group_name in enumerate(subject_group_names):

        # Group arrays can be longer than the group; keep only its real subjects.
        n_subj = subject_group_dict[group_name].sum()
        b_ratio = beh_ratio_array.sel(group = group_name)[:n_subj].to_numpy()

        rho_pos = rho_pos_groups[gInd, :]
        # The bottom row (gInd == 1) moves its lowest y tick up, the top row moves it down.
        ymin_offset = 5 if gInd == 1 else -5
        yticks = np.linspace(b_ratio.min().round(-1)+ymin_offset, b_ratio.max().round(-1), num_ticks)

        for net_ind, network in enumerate(plotted_networks):

            ax = axes[gInd, net_ind]

            if gInd != 1:
                # Network names are written above the top row only.
                ax.set_xlabel(yeo_networks_shortened[net_ind], weight=net_weight, c='darkslategray', labelpad=labelpad)
                ax.xaxis.set_label_position('top')

            # Only the first panel of each row contributes a legend entry.
            label = group_name if net_ind == 0 else ''

            n_ratio = neu_ratio_sbi.sel(group=group_name, network=network).to_numpy()[:n_subj]
            sel_linreg = fc_linreg_sbi.sel(group=group_name, network=network)[:, :n_subj]
            n_ratio_sorted = np.sort(n_ratio)

            ax.scatter(n_ratio, b_ratio, c=g_colors[gInd], marker=g_styles[gInd], edgecolors=g_edgecolor,
                       alpha=0.65, label=label, s=markersize, zorder=3, linewidth=linewidth)

            ax.fill_between(n_ratio_sorted, sel_linreg.sel(bound='l_bound'), sel_linreg.sel(bound='u_bound'), color='grey', alpha=0.5,
                            linewidth=0, zorder=1)

            ax.plot(n_ratio_sorted, sel_linreg.mean('bound'), c='k', linewidth=linewidth, alpha=0.8, zorder=2)

            rho_p = fc_stats_sbi.sel(stat='rvalue', group=group_name, network=network).to_numpy().round(2)
            ax.annotate(r'$\rho_p$ = '+ str(rho_p), fontweight=stat_weight,
                        xy=rho_pos, va='top', xycoords='axes fraction', fontsize=rho_fontsize, zorder=4)

            rho_s = fc_stats_nl_sbi.sel(stat='rvalue', group = group_name, network=network).to_numpy().round(2)
            ax.annotate(r'$\rho_s$ = '+ str(rho_s), fontweight=stat_weight,
                        xy=rho_pos+rho_offset, va='top', xycoords='axes fraction', fontsize=rho_fontsize, zorder=4)

            modify_axis_spines(ax, which=['x', 'y'], yticks=yticks, xticks=xticks)

    for ax in axes.ravel():
        ax.patch.set_alpha(0.0)

    if parent_figure:
        # `ax` is the last panel visited above; the legend axis is appended to its right.
        divider = make_axes_locatable(ax)
        lax = divider.append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)
        axes[1,3].set_xlabel(r'$g_2$' +' ratio', labelpad=labelpad, fontsize=plt.rcParams['font.size'])
        axes[1,3].set_ylabel('Behavioral ratio', labelpad=labelpad, fontsize=plt.rcParams['font.size'])
    else:
        lax = fig.add_axes(rect=lax_rect)
        fig.supxlabel(r'$g_2$' +' ratio', y=xlabel_pos, fontsize=plt.rcParams['font.size'])
        fig.supylabel('Behavioral ratio', x=ylabel_pos, fontsize=plt.rcParams['font.size'])

    # One legend entry per group, taken from the first panel of each row.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    handle_m, label_m = axes[1, 0].get_legend_handles_labels()

    handles.extend(handle_m)
    labels.extend(label_m)

    lax.legend(handles, labels, loc="upper center", ncol=ncol, frameon=False, markerscale=markerscale, labelspacing=legend_spacing, prop={'weight':'normal'},
              handletextpad=handletextpad)
    lax.set_frame_on(False)
    lax.axis(False)

In [None]:
def plot_simulation_timeseries(fig, ax, s_ind=-10, num_plotted_labels=400, lw=1, alpha=0.1, color='slategray', titlepad=5):
    """Draw cubic-upsampled simulated cortical time series on ``ax``.

    Reads the global ``sim_array_sbi`` restricted to ``ctx_labels``,
    simulation ``s_ind`` and time points 2400, 2405, ..., 2990. It keeps
    ``num_plotted_labels`` randomly drawn columns and upsamples each column 5x
    with a cubic ``interp1d``. All traces are drawn in a single ``ax.plot`` call.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Not used in the body; kept so the panel functions share a signature.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    s_ind : int
        Passed as ``simulation=`` to ``sim_array_sbi.sel``.
    num_plotted_labels : int
        Number of random column draws (and therefore of traces).
    lw, alpha, color
        Line width, opacity and colour of the traces.
    titlepad : float
        ``pad`` of the axes title.
    """
    
    label_inds = np.arange(num_plotted_labels)
    time_slice = np.arange(2400, 2995, 5)
    # NOTE(review): np.random.choice samples WITH replacement by default, so
    # some columns can be drawn twice and others not at all -- confirm this is
    # intended (replace=False would give a permutation of label_inds).
    plotted_labels = np.random.choice(label_inds, size=num_plotted_labels)
    
    # Positional indexing on the second axis; assumes the selection is
    # (time, label) -- TODO confirm the dim order of sim_array_sbi.
    plot_data = sim_array_sbi.sel(label=ctx_labels, simulation=s_ind, time=time_slice)[:, plotted_labels].load()
    
    # Upsample the time axis 5x so the curves look smooth.
    time_vec = plot_data.time.to_numpy()
    intrp_length = time_vec.shape[0]*5
    time_vec_intrp = np.linspace(time_vec[0], time_vec[-1], intrp_length)
    plot_data_intrp = np.zeros((intrp_length, num_plotted_labels))
    
    # One cubic interpolant per selected column.
    for l_ind in range(num_plotted_labels):
        interpFunc = interpolate.interp1d(time_vec, plot_data[:, l_ind].to_numpy(), kind='cubic')
        data = interpFunc(time_vec_intrp)
        plot_data_intrp[:, l_ind] = data
        
    # 2-D y: matplotlib draws one line per column.
    ax.plot(time_vec_intrp, plot_data_intrp[:,:], c=color, lw=lw, alpha=alpha);
    
    # NOTE(review): x ticks sit at time_slice[0] and time_slice[-1] rounded to
    # hundreds (2400 and 3000, i.e. 600 apart) but are relabelled 0 and 500
    # below -- confirm the intended labels.
    modify_axis_spines(ax, which=['x', 'y'], yticks=np.arange(7, 9.5, 2),
                       xticks=[time_slice[0].round(-2), time_slice[-1].round(-2)])

    ax.set_xlabel('Time (ms)', labelpad=-5)
    ax.set_ylabel('Voltage (mV)', labelpad=5)
    ax.set_title('Sample simulation', fontsize=plt.rcParams['font.size'], pad=titlepad)
    ax.set_xticklabels([0, 500])
    ax.patch.set_alpha(0.0)
    
    # ax.set_ylim([plot_data_intrp.min().round(1)-0.5, plot_data_intrp.max().round()+0.25])
    
    # ax.margins(x=0.15)
    # # ax.margins(y=-0.4)
    
    # ax.axis('off')
    
    # scalebar_y = ScaleBar(dx=1, rotation='vertical', length_fraction=0.3, fixed_value=1, frameon=False, scale_formatter=lambda value,unit: f"{value} mV",
    #                           location = 'center left', scale_loc='left', font_properties={'size':tick_size})    
    # scalebar_x = ScaleBar(dx=1, length_fraction=0.3, fixed_value=50, frameon=False, scale_formatter=lambda value,unit: f"{value} ms", location = 'lower center', font_properties={'size':tick_size})
    
    # ax.add_artist(scalebar_y)
    # ax.add_artist(scalebar_x)
    
    # fig.tight_layout()    

In [None]:
def pval_to_string(pvals, format='*'):
    r"""Convert p-values into significance annotation strings.

    Parameters
    ----------
    pvals : iterable of float
        P-values to convert.
    format : str, optional
        ``'*'`` (default) returns star notation: ``'****'`` for p <= 1e-5,
        ``'***'`` for p <= 1e-3, ``'**'`` for p <= 1e-2, ``'*'`` for
        p <= 0.05 and ``'ns'`` otherwise. Any other value returns matplotlib
        mathtext strings: ``'$p \leq $<threshold>'`` with the tightest
        threshold the p-value meets, or ``'$p = $<pval>'`` when it meets none.

    Returns
    -------
    list of str
        One string per input p-value, in input order.
    """
    # Checked from the strictest threshold to the loosest; the first match wins.
    # NOTE(review): '****' uses 1e-5, whereas statannotations' default is 1e-4.
    p_dict = {'****': 0.00001, '***': 0.001, '**': 0.01, '*': 0.05}

    pstrings = []
    for pval in pvals:
        # Key of the tightest threshold met, or None when non-significant.
        star = next((key for key, thr in p_dict.items() if pval <= thr), None)
        if format == '*':
            pstrings.append('ns' if star is None else star)
        elif star is None:
            pstrings.append(r'$p = $' + str(pval))
        else:
            # Report the threshold that was met (not the raw p-value). Wrap it in
            # $...$ so matplotlib renders \leq as a symbol instead of literal text.
            pstrings.append(r'$p \leq $' + str(p_dict[star]))

    return pstrings

In [None]:
def wasserstein_distance(sample1, sample2, n_permutations=100, rng=None):
    """Compute the 1-D Wasserstein distance between two samples, with a permutation p-value.

    The observed distance is compared against distances obtained after
    randomly re-assigning the pooled observations to two groups of the
    original sizes.

    Parameters
    ----------
    sample1, sample2 : array_like
        One-dimensional samples.
    n_permutations : int, optional
        Number of random permutations; must be at least 1.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        results follow ``np.random.seed``.

    Returns
    -------
    tuple of (float, float)
        ``(observed_distance, p_value)`` with
        ``p_value = (1 + #{shuffled >= observed}) / (1 + n_permutations)``.
        The +1 counts the observed labelling as one permutation, so the
        p-value is never exactly zero.

    Raises
    ------
    ValueError
        If ``n_permutations`` is smaller than 1.
    """
    if n_permutations < 1:
        raise ValueError('n_permutations must be at least 1')

    sample1 = np.asarray(sample1)
    sample2 = np.asarray(sample2)
    permute = np.random.permutation if rng is None else rng.permutation

    combined_data = np.concatenate([sample1, sample2])
    n1 = len(sample1)
    observed_distance = stats.wasserstein_distance(sample1, sample2)

    # Count permutations at least as extreme as the observed labelling.
    n_extreme = 0
    for _ in range(n_permutations):
        shuffled_data = permute(combined_data)
        shuffled_distance = stats.wasserstein_distance(shuffled_data[:n1], shuffled_data[n1:])
        if shuffled_distance >= observed_distance:
            n_extreme += 1

    p_value = (n_extreme + 1) / (n_permutations + 1)

    return (observed_distance, p_value)

In [None]:
def plot_speed_boxes(fig, ax, linewidth=1, alpha=1, saturation=1, order=['Female', 'Male'], point_size=3, linecolor='k',
                     width_viol=0.5, bw=0.7, width_box=0.3, stat_loc='inside', dodge=False):
    """Draw a raincloud plot of ``pong_beh_df['value']`` by gender and condition, with significance stars.

    Draws with ``plot_RainClouds`` (x='gender', hue='condition') and
    annotates Presence vs Absence within each gender using statannotations'
    Mann-Whitney test. It then formats the axes and saves
    ``fig_save_loc + save_string + 'Speed_Accuracy_Ratio_Groups_Vrt.svg'``.

    NOTE: a later cell of this notebook defines ``plot_speed_boxes`` again
    (adding ``box_capprops``); whichever cell ran last is the live definition.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is saved to SVG.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    linewidth, alpha, saturation, point_size, linecolor, width_viol, bw, width_box
        Styling forwarded (via ``fig_args``) to ``plot_RainClouds`` and ``Annotator``.
    order : list of str
        Gender order on the x axis.
    stat_loc : str
        ``loc`` of the annotations (``'inside'`` or ``'outside'``).
    dodge : bool
        Not used in the body.
    """

    palette = {'Presence': presence_color, 'Absence':absence_color}
    
    # Compare conditions within each gender.
    sig_pairs = [(('Male','Presence'), ('Male', 'Absence')), (('Female', 'Presence'), ('Female', 'Absence'))]

    # Shared by the raincloud plot and the annotator so both see the same layout.
    fig_args = {'x': 'gender', 'y': 'value', 'hue':'condition', 'data': pong_beh_df, 'palette':palette, 'linecolor': linecolor,
                    'point_size':point_size, 'linewidth':linewidth, 'box_linewidth': linewidth, 'saturation':saturation, 'alpha': alpha,
                    'width_viol':width_viol, 'width_box':width_box, 'bw':bw, 'edgecolor':linecolor}

    rainclouds = plot_RainClouds(ax=ax, order=order, cut=0, pointplot=False, box_fliersize=0,
                              box_medianprops=dict(linewidth=linewidth, color=linecolor,),
                              box_whiskerprops=dict(linewidth=linewidth), **fig_args)

    
    annotator = Annotator(**fig_args, ax=ax, pairs=sig_pairs, verbose=False, order=order, plot='stripplot')
    annotator.configure(test='Mann-Whitney', text_format='star', loc=stat_loc, line_width=linewidth)
    annotator.apply_and_annotate()

    rainclouds.legend_.remove()

    
    modify_axis_spines(ax, which = ['y'], yticks = np.linspace(0.0, 1.0, 3)) 
    
    ax.set_xlabel('')
    ax.set_ylabel('Avg speed$_p$', labelpad=10)
    ax.tick_params(axis = 'x', length=0)
    ax.tick_params(axis = 'y', )
    ax.patch.set_alpha(0.0)

    fig.savefig(fig_save_loc + save_string + 'Speed_Accuracy_Ratio_Groups_Vrt.svg', transparent = True, bbox_inches = 'tight')


In [None]:
def plot_group_posteriors(fig, axes, method_prefix='WS', m_ind=1, linewidth=1, labelsize=10, labelpad=15, stat_loc='outside', linecolor='k',
                          width=0.8, hspace=0.7):
    """Draw box plots of posterior g_2 per network and condition, one panel per group.

    For each unique value of the global ``group_posteriors.group``, the first
    six networks are stacked into long format and drawn with ``sns.boxplot``
    (x='network', hue='condition'). Each Presence/Absence pair is annotated
    with ``method_prefix + ': ' + <value>``, where the value comes from
    ``eeg_posterior_stats`` (method position ``m_ind``, stat ``'value'``).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Receives ``subplots_adjust(hspace=hspace)``.
    axes : sequence of matplotlib.axes.Axes
        One axes per group, in ``np.unique`` order.
    method_prefix : str
        Prefix of each annotation text.
    m_ind : int
        Position along the ``method`` dimension of ``eeg_posterior_stats``.
    linewidth, width : float
        Box line width and box width.
    labelsize, labelpad : float
        Tick/label font size and label padding (title pad is ``2 * labelpad``).
    stat_loc : str
        ``loc`` of the annotations.
    linecolor : str
        Not used in the body.
    hspace : float
        Vertical spacing between panels.
    """
    stat_list = [((net, 'Presence'), (net, 'Absence')) for net in yeo_networks_shortened[:-1]]
    palette = dict(Presence='dodgerblue', Absence='Crimson')
    # markersize=0 hides outlier markers.
    flierprops = dict(marker='o', markerfacecolor='None', markersize=0, markeredgecolor='black')

    for g_ind, group in enumerate(np.unique(group_posteriors.group)):

        ax = axes[g_ind]

        g_posterior = group_posteriors.sel(group=group)
        # Presumably several entries share each group label; give them unique
        # integers so ('group', 'sample', 'network') stacks into a unique index.
        g_posterior.coords['group'] = np.arange(len(g_posterior.group))

        test_rvals = eeg_posterior_stats.isel(method=m_ind).sel(stat='value', group=group).to_numpy()
        # One custom label per pair in stat_list, in the same order.
        sig_annotations = [method_prefix + ': ' + str(test_rvals[n_ind].round(2)) for n_ind in range(len(stat_list))]

        # Long format with columns network / condition / value.
        # NOTE(review): the first six networks are assumed to match yeo_networks_shortened[:-1].
        group_posteriors_stack = g_posterior.isel(network=np.arange(6)).stack({'aggSample': ('group', 'sample', 'network')})
        group_posteriors_stack = group_posteriors_stack.to_pandas().T.reset_index().melt(id_vars='network', value_vars=['Presence', 'Absence'], var_name='condition')

        violins = sns.boxplot(data=group_posteriors_stack, x='network', y='value', hue='condition', flierprops=flierprops,
                              palette=palette, saturation=0.8, dodge=True, ax=ax, linewidth=linewidth, showcaps=False, whis=1, width=width)

        annotator = Annotator(pairs=stat_list, data=group_posteriors_stack, x='network', y='value', hue='condition', verbose=False, plot='boxplot', ax=violins)
        annotator.set_custom_annotations(sig_annotations)
        annotator.configure(text_format='star', loc=stat_loc, line_width=linewidth)
        annotator.annotate()

        violins.legend_.remove()

        ax.tick_params(axis='y', which='major', labelsize=labelsize)
        ax.tick_params(axis='x', which='major', labelsize=labelsize, length=0)
        modify_axis_spines(ax, which=['y'], yticks=np.arange(100, 115, 5))
        ax.set_title(group, size=labelsize, pad=labelpad*2)
        ax.set_xlabel('')
        ax.set_ylabel(r'$g_2$', size=labelsize, labelpad=labelpad)

    fig.subplots_adjust(hspace=hspace)

In [None]:
def plot_group_posteriors_attn(fig, axes, group_posteriors, num_samples=5000, swarm_sample=True, bw=0.2, width_viol=0.5, width_box=0.3, linewidth=1, labelpad=15, saturation=1, alpha=0.8,
                               method_prefix='WS', m_ind=1, stat_loc='outside', linecolor='k', plot_clouds=False, order=['Female', 'Male'],
                               point_size=1, jitter=True, xlim_multiplier=0.6):
    """Draw raincloud plots of posterior g_2 for the attention networks, one panel per subject group.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Not used in the body.
    axes : sequence of matplotlib.axes.Axes
        One axes per entry of the global ``subject_group_names``.
    group_posteriors : pandas.DataFrame
        Long-format table with columns ``'group'``, ``'network'``,
        ``'condition'`` and ``'value'``.
    num_samples, swarm_sample, bw, width_viol, width_box, linewidth, saturation, alpha, linecolor, plot_clouds, point_size, jitter
        Forwarded to ``plot_RainClouds``.
    labelpad : float
        Y-label padding; the title pad is ``2 * labelpad``.
    method_prefix : str
        Prefix of each annotation text.
    m_ind : int
        Position along the ``method`` dimension of ``eeg_posterior_stats``.
    stat_loc : str
        ``loc`` of the annotations.
    order : list of str
        Ignored; the x axis always follows ``attn_networks``. Kept for
        signature compatibility.
    xlim_multiplier : float
        Right x limit is ``(len(attn_networks) + 1) * xlim_multiplier``.
    """
    stat_list = [((net, 'Presence'), (net, 'Absence')) for net in attn_networks]
    palette = dict(Presence='dodgerblue', Absence='Crimson')
    hue_order = ['Presence', 'Absence']

    for g_ind, group in enumerate(subject_group_names):

        ax = axes[g_ind]

        test_rvals = eeg_posterior_stats.isel(method=m_ind).sel(stat='value', network=attn_networks, group=group).to_numpy()
        # One custom label per pair in stat_list, in the same order.
        sig_annotations = [method_prefix + ': ' + str(test_rvals[n_ind].round(2)) for n_ind in range(len(attn_networks))]

        g_posterior = group_posteriors.loc[group_posteriors['group'] == group]

        fig_args = {'x': 'network', 'y': 'value', 'hue': 'condition', 'data': g_posterior, 'hue_order': hue_order, 'palette': palette,
                    'linecolor': linecolor, 'point_size': point_size, 'linewidth': linewidth, 'box_linewidth': linewidth, 'saturation': saturation,
                    'alpha': alpha, 'width_viol': width_viol, 'width_box': width_box, 'bw': bw, 'edgecolor': linecolor, 'ax': ax, 'order': attn_networks}

        rainclouds = plot_RainClouds(cut=0, pointplot=False, box_fliersize=0, box_medianprops=dict(linewidth=linewidth, color=linecolor),
                                     swarm_sample=swarm_sample, num_samples=num_samples, box_whiskerprops=dict(linewidth=linewidth), box_capprops={'linewidth': 0},
                                     **fig_args, clouds=plot_clouds, jitter=jitter)

        # Give the annotator the plot's category and hue order so brackets land on the drawn boxes.
        annotator = Annotator(pairs=stat_list, data=g_posterior, x='network', y='value', hue='condition', order=attn_networks,
                              hue_order=hue_order, verbose=False, plot='boxplot', ax=rainclouds)
        annotator.set_custom_annotations(sig_annotations)
        annotator.configure(text_format='star', loc=stat_loc, line_width=linewidth)
        annotator.annotate()

        rainclouds.legend_.remove()

        ax.tick_params(axis='x', which='major', length=0)
        modify_axis_spines(ax, which=['y'], yticks=[101, 111])
        ax.set_xlim([-0.5, (len(attn_networks)+1)*xlim_multiplier])
        ax.set_title(group, size=plt.rcParams['font.size'], pad=labelpad*2)
        ax.set_xlabel('')

        # Only the first panel carries the y label.
        if g_ind == 0:
            ax.set_ylabel(r'$g_2$', labelpad=labelpad, fontsize=plt.rcParams['font.size'])
        else:
            ax.set_ylabel('')
        

In [None]:
def plot_simulation_INI(fig, axes, markersize=3, labelpad=20, color='slategray', net_color='darkslategray', weight='normal', wspace=-0.1, hspace=-0.1, left=-0.1):
    """Scatter simulated INI against g_2, one panel per network.

    Panel ``i`` (in ``axes.ravel()`` order) shows the global ``theta_eeg[:, i]``
    against ``fc_sum_sim[i, :]``. It is titled with ``yeo_networks_shortened[i]``
    via a top-positioned x label.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Receives the shared axis labels and spacing.
    axes : numpy array of matplotlib.axes.Axes
        Must have no more entries than ``yeo_networks_shortened``.
    markersize : float
        Scatter marker size.
    labelpad : float
        Padding of the network label.
    color, net_color : str
        Marker colour and network-label colour.
    weight : str
        Font weight of the network label.
    wspace, hspace, left : float
        Passed to ``fig.subplots_adjust``.
    """
    xticks = np.linspace(100, 110, 3)
    yticks = np.linspace(0, 2500, 3)

    for net_ind, ax in enumerate(axes.ravel()):

        network = yeo_networks_shortened[net_ind]

        ax.scatter(theta_eeg[:, net_ind], fc_sum_sim[net_ind, :], s=markersize, color=color)

        # Network name shown above each panel.
        ax.set_xlabel(network, weight=weight, c=net_color, labelpad=labelpad)
        ax.xaxis.set_label_position('top')

        modify_axis_spines(ax, which=['x', 'y'], yticks=yticks, xticks=xticks)
        # Transparent background for every panel.
        ax.patch.set_alpha(0)

    fig.supxlabel(r'$g_2$', y=0.01)
    fig.supylabel('INI', x=-0.005)
    fig.subplots_adjust(wspace=wspace, hspace=hspace, left=left)

In [None]:
def plot_simulation_INI_single(fig, ax, fc_data=fc_sum_sim, step=1, ncol=1, handletextpad=0.1, markersize=3, markerscale=2, legend_spacing=1,
                                  scale_y=True, yscale_x=0.015, yscale_y=0.9, alpha=0.3, labelpad=20, weight='normal', net_scale=1,
                                  wspace=-0.1, hspace=-0.1, left=-0.1, lax_rect=[0.88, 0.2, 0.1, 0.7], cbar_width=0.3, lgd_ax_loc='right',
                                  cbar_pad=0.1, parent_figure=True, normalize=True, yticks=[0, 1.3], plot_LIM=False):
    """Scatter INI against g_2 for all networks on one axes, with a legend axes, and save the figure.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is saved to
        ``fig_save_loc + save_string + 'Int-Seg_Sim_shortened_EqualScale.svg'``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    fc_data : array
        Indexed as ``fc_data[net_ind, :]``; defaults to the global
        ``fc_sum_sim`` as bound when this cell ran.
    step : int
        Plot every ``step``-th point.
    ncol, handletextpad, markerscale, legend_spacing
        Legend layout; ``legend_spacing`` is column spacing when
        ``ncol > 1`` and row spacing otherwise.
    markersize, alpha : float
        Scatter marker size and opacity.
    net_scale : float
        Largest vertical offset used when ``normalize`` is True.
    lax_rect : list of float
        Legend-axes rectangle when ``parent_figure`` is False.
    cbar_width, cbar_pad, lgd_ax_loc
        Size (percent), padding and side of the legend axes appended to
        ``ax`` when ``parent_figure`` is True.
    parent_figure : bool
        Append the legend axes to ``ax`` instead of adding it to ``fig``.
    normalize : bool
        Shift each network by its entry of ``np.linspace(0, net_scale, 7)``
        to separate them vertically.
    yticks : list of float
        Y ticks passed to ``modify_axis_spines``.
    plot_LIM : bool
        Whether to include the ``'LIM'`` network (skipped by default).
    scale_y, yscale_x, yscale_y, labelpad, weight, wspace, hspace, left
        Not used in the body.
    """
    xticks = np.linspace(101, 111, 3)
    # NOTE(review): 7 offsets are hard-coded; assumes len(yeo_networks_shortened) <= 7.
    net_offsets = np.linspace(0, net_scale, 7)

    for net_ind, network in enumerate(yeo_networks_shortened):

        # Logical `not`: bitwise `~` on a Python bool gives -1/-2, both truthy.
        if not plot_LIM and network == 'LIM':
            continue

        net_sum = fc_data[net_ind, :].copy()
        if normalize:
            net_sum += net_offsets[net_ind]

        ax.scatter(theta_eeg[::step, net_ind], net_sum[::step], s=markersize, color=network_color_dict[network], alpha=alpha, label=network)

    modify_axis_spines(ax, which=['x', 'y'], yticks=yticks, xticks=xticks)

    ax.set_xlabel(r'$g_2$')
    ax.set_ylabel('INI')
    ax.patch.set_alpha(0)

    if parent_figure:
        divider = make_axes_locatable(ax)
        lax = divider.append_axes(lgd_ax_loc, size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        lax = fig.add_axes(rect=lax_rect)

    handles, labels = ax.get_legend_handles_labels()
    # Multi-column legends space their columns; single-column legends space rows.
    spacing = {'columnspacing': legend_spacing} if ncol > 1 else {'labelspacing': legend_spacing}
    lax.legend(handles, labels, loc="center", ncol=ncol, frameon=False, markerscale=markerscale,
               prop={'weight': 'normal'}, handletextpad=handletextpad, handlelength=0.1, **spacing)
    lax.set_frame_on(False)
    lax.axis(False)

    fig.savefig(fig_save_loc + save_string + 'Int-Seg_Sim_shortened_EqualScale.svg', transparent = True, bbox_inches = 'tight')

In [None]:
def plot_schaefer2018_parcellation(fig, axes, image_array, handletextpad=0.5, view_pad=10, hemi_pad=-5, ncol=1, lax_rect=[0.7, 0.2, 0.03, 0.6], legend_spacing=1,
                                   markerscale=1, wspace=-1, hspace=-0.5, left=-0.1, labelpad=1, cbar_width=0.3, cbar_pad=0.1, parent_figure=True):
    """Show pre-rendered parcellation images in a (view x hemisphere) grid with a network legend.

    ``axes[row, col]`` displays ``image_array[col, row, ...]``, where rows
    follow the global ``views`` and columns the global ``hemis``. All
    decorations are stripped from every panel. A legend with one coloured
    rectangle per entry of ``network_label_colors`` is placed in a dedicated
    axes. The figure is saved as
    ``fig_save_loc + save_string + 'Schaefer2018_400Parcels_7Networks.svg'``.

    Parameters
    ----------
    view_pad, hemi_pad : float
        Padding of the hemisphere (top) and view (left) labels.
    lax_rect : list of float
        Legend-axes rectangle when ``parent_figure`` is False.
    cbar_width, cbar_pad : float
        Size (percent) and padding of the legend axes appended to the last
        panel when ``parent_figure`` is True.
    ncol, handletextpad, legend_spacing, markerscale
        Legend layout.
    wspace, hspace : float
        Figure spacing, applied only when ``parent_figure`` is False.
    left, labelpad
        Not used in the body.
    """
    pretty_hemi = {'rh': 'Right', 'lh': 'Left'}

    for row, view in enumerate(views):
        for col, hemi in enumerate(hemis):
            panel = axes[row, col]
            panel.imshow(image_array[col, row, ...])

            # Leave only the image: no spines, tick marks or tick labels.
            for side in ('top', 'right', 'left', 'bottom'):
                panel.spines[side].set_visible(False)
            panel.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            panel.set_xticks([])
            panel.set_yticks([])

            if row == 0:
                # Hemisphere name above the top row.
                panel.set_xlabel(pretty_hemi[hemi], labelpad=view_pad)
                panel.xaxis.set_label_position('top')
            if col == 0:
                # View name left of the first column.
                panel.set_ylabel(view.capitalize(), labelpad=hemi_pad)

    # One solid rectangle per network colour.
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=rgb, label=name) for name, rgb in network_label_colors.items()]
    legend_labels = [name for name in network_label_colors]

    if not parent_figure:
        lax = fig.add_axes(rect=lax_rect)
        fig.subplots_adjust(wspace=wspace, hspace=hspace)
    else:
        lax = make_axes_locatable(panel).append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)

    lax.legend(legend_handles, legend_labels, frameon=False, markerscale=markerscale, ncol=ncol, handletextpad=handletextpad, labelspacing=legend_spacing)
    lax.set_frame_on(False)
    lax.axis(False)

    fig.savefig(fig_save_loc + save_string + 'Schaefer2018_400Parcels_7Networks.svg', format='svg', bbox_inches='tight', transparent=True, dpi=800)

In [None]:
def plot_pdiff_heatmap(fig, axes, hemi='lh', views=['lateral', 'medial'], group_pad=-5, view_pad=10, labelpad=10, title_x=0.5, title_y=0.9, draw_title=True,
                       wspace=-0.1, hspace=-0.1, left=0, right=0, cax_rect=[0.7, 0.2, 0.03, 0.6], cbar_width=0.3, cbar_pad=0.1, parent_figure=True):
    """Show per-group brain images of the posterior condition difference, with a colorbar.

    Panels are laid out as ``axes[view, group]``. Each shows
    ``pdiff_image_array[group, view]`` (global) with the pixels listed in
    ``brain_masks[hemi + '-' + view]`` made fully transparent. Group names go
    above the top row and view names left of the first column. The colorbar
    uses the colormap and normalisation of the global ``posterior_difference``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Receives the optional title and the colorbar.
    axes : 2-D array of matplotlib.axes.Axes
        ``len(views)`` rows by ``len(pdiff_image_array.group)`` columns.
    hemi : str
        Hemisphere prefix used to look up ``brain_masks``.
    views : list of str
        Controls the number of rows iterated; view names are read from the images.
    group_pad, view_pad, labelpad : float
        Padding of group labels, view labels and the colorbar label.
    title_x, title_y : float
        Figure coordinates of the title.
    draw_title : bool
        Whether to draw the ``'Posterior $g_2$'`` title.
    wspace, hspace : float
        Figure spacing, applied only when ``parent_figure`` is False.
    cax_rect : list of float
        Colorbar-axes rectangle when ``parent_figure`` is False.
    cbar_width, cbar_pad : float
        Size (percent) and padding of the colorbar axes appended to the last
        panel when ``parent_figure`` is True.
    left, right
        Not used in the body.
    """
    _, sm = compute_pdiff_colormap(data=posterior_difference)

    for g_ind, _ in enumerate(pdiff_image_array.group):

        for v_ind, _ in enumerate(views):

            ax = axes[v_ind, g_ind]

            data = pdiff_image_array[g_ind, v_ind, ...].copy()
            group = data.group.to_numpy()
            view = str(data.view.to_numpy())
            data = data.to_numpy()

            # Append an opaque alpha channel, then clear it on the masked pixels.
            alpha_channel = np.ones((data.shape[0], data.shape[1]), dtype=np.uint8)*255
            data_rgba = np.dstack((data, alpha_channel))

            # Presumably an (N, 2) array of (row, col) background-pixel coordinates -- confirm.
            view_mask = brain_masks[hemi+'-'+view]
            data_rgba[view_mask[:, 0], view_mask[:, 1]] = [0, 0, 0, 0]

            ax.imshow(data_rgba)

            # Group name above the top row.
            if v_ind == 0:
                ax.set_xlabel(group, labelpad=group_pad)
                ax.xaxis.set_label_position('top')

            # View name left of the first column (only when several views are shown).
            if g_ind == 0 and len(views) > 1:
                ax.set_ylabel(view.capitalize(), labelpad=view_pad)

            # Leave only the image: no spines, tick marks or tick labels.
            for side in ('top', 'right', 'left', 'bottom'):
                ax.spines[side].set_visible(False)
            ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            ax.set_xticks([])
            ax.set_yticks([])

    if parent_figure:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        cax = fig.add_axes(rect=cax_rect)
        fig.subplots_adjust(wspace=wspace, hspace=hspace)

    if draw_title:
        fig.text(x=title_x, y=title_y, s='Posterior $g_2$', fontsize=plt.rcParams['font.size'])

    cbar = fig.colorbar(sm, extendrect=False, cax=cax)
    cbar.set_label('Condition difference', labelpad=labelpad)
    cbar.ax.tick_params(length=0)  # no tick marks on the colorbar
    cbar.outline.set_visible(False)  # no colorbar frame
    cbar.ax.set_yticks([-0.5, 0.5])

In [None]:
def generate_schaefer2018_image_array(hemis=['lh', 'rh'], views=['lateral', 'medial'], width=400, height=400, ctxcolor='k', bgcolor='w', verbose=False):
    """Render the Schaefer-2018 network parcellation on fsaverage and collect screenshots.

    For each hemisphere this opens an ``mne.viz.Brain`` on the inflated
    surface and colours every label of annotation ``parc`` that matches
    ``'7Network'`` by its network. It draws annotation borders in black and
    stores one screenshot per view.

    NOTE: another cell of this notebook redefines this function with
    ``hemis=['rh', 'lh']`` and ``ctxcolor='dimgray'``; the last-executed cell
    is the live definition.

    Parameters
    ----------
    hemis : list of str
        Hemispheres to render (passed as ``hemi=`` to MNE).
    views : list of str
        Views passed to ``Brain.show_view``.
    width, height : int
        Window size passed as ``Brain(size=(width, height))``.
    ctxcolor, bgcolor
        Cortex and background colours.
    verbose : bool
        Print one line per hemisphere.

    Returns
    -------
    xarray.DataArray
        int16 array with dims ``('hemi', 'view', 'width', 'height', 'channel')``.

    Notes
    -----
    Relies on notebook globals ``parc``, ``fs_label_dir``, ``yeo_networks``,
    ``yeo_networks_shortened_dict`` and ``network_label_colors``.
    NOTE(review): ``mne`` is not in the top-of-file import cell -- confirm it
    is imported elsewhere before this runs.
    """

    num_hemi = len(hemis)
    num_view = len(views)
    
    # NOTE(review): buffer is (..., width, height, 3); screenshots are usually
    # (rows, cols, 3) = (height, width, 3), so this only lines up when
    # width == height -- confirm.
    image_array = xr.DataArray(np.zeros((num_hemi, num_view, width, height, 3), dtype=np.int16), dims=('hemi', 'view', 'width', 'height', 'channel'),
                               coords=dict(hemi=hemis, view=views, channel=['r', 'g', 'b']))
    
    for h_ind, hemi in enumerate(hemis):

        if verbose:
            print('Drawing ' + hemi + ' brain')
        
        labels = mne.read_labels_from_annot('fsaverage', parc=parc, regexp='7Network', hemi=hemi, subjects_dir=fs_label_dir)
        
        brain = mne.viz.Brain(subject="fsaverage", surf="inflated", hemi=hemi, subjects_dir=fs_label_dir,
            cortex=ctxcolor,
            background=bgcolor,
            size=(width, height))
        
        for label in labels:
            # print(label.name)
            # Network name = text between 'H_' and the next '_' in the label name.
            label_network_name = label.name.split('H_')[1].split('_')[0]
            # Labels whose network is not in yeo_networks are drawn black.
            if label_network_name not in yeo_networks:
                label_color = 'k'
            else:
                label_color = network_label_colors[yeo_networks_shortened_dict[label_network_name]]
            brain.add_label(label, color=label_color)
        
        brain.add_annotation(parc, borders=True, color='k')    
    
        for v_ind, view in enumerate(views):
    
            brain.show_view(view=view)
            image_array[h_ind, v_ind, ...] = np.array(brain.screenshot())
        # One Brain window per hemisphere; closed once all views are captured.
        brain.close()
    return image_array

In [None]:
def plot_EEG_timeseries(fig, ax, s_ind=0, num_plotted_labels=400, lw=1, alpha=0.1, color='slategray', titlepad=5):
    """Draw cubic-upsampled source-space time series for one random trial on ``ax``.

    Uses the global ``source_epochs`` for subject ``subjects[s_ind]`` and
    condition ``'Presence'``. The trial label is drawn at random from 0..79.
    The data are restricted to ``num_plotted_labels`` randomly drawn entries
    of ``sample_source_epochs.label``. Each trace is upsampled 5x with a
    cubic ``interp1d`` before plotting.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Not used in the body.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    s_ind : int
        Index into the global ``subjects`` list.
    num_plotted_labels : int
        Number of random label draws (and therefore of traces).
    lw, alpha, color
        Line width, opacity and colour of the traces.
    titlepad : float
        ``pad`` of the axes title.
    """

    label_inds = np.arange(400)
    # NOTE(review): np.random.choice samples WITH replacement by default, so
    # duplicate labels are possible -- confirm this is intended.
    plotted_labels = np.random.choice(label_inds, size=num_plotted_labels)
    plotted_trials = np.random.choice(np.arange(80), size=1)
    
    plot_data = source_epochs.sel(subject=subjects[s_ind], condition='Presence', label=sample_source_epochs.label[plotted_labels], trial=plotted_trials[0]).copy()
    
    # Upsample the time axis 5x so the curves look smooth.
    time_vec = plot_data.time.to_numpy()
    intrp_length = time_vec.shape[0]*5
    time_vec_intrp = np.linspace(time_vec[0], time_vec[-1], intrp_length)
    plot_data_intrp = np.zeros((intrp_length, num_plotted_labels))
    
    # Positional column access; assumes the selection is (time, label) -- TODO confirm.
    for l_ind in range(num_plotted_labels):
        interpFunc = interpolate.interp1d(time_vec, plot_data[:, l_ind].to_numpy(), kind='cubic')
        data = interpFunc(time_vec_intrp)
        plot_data_intrp[:, l_ind] = data
    
    ax.plot(time_vec_intrp, plot_data_intrp[:,:], c=color, lw=lw, alpha=alpha);
    
    # Two x ticks (0 and the last interpolated time), relabelled 'BM' and '500' below.
    # NOTE(review): '500' is hard-coded regardless of time_vec_intrp[-1] -- confirm.
    modify_axis_spines(ax, which=['x', 'y'], yticks=np.arange(0, 3e-13, 1e-13), xticks=np.arange(0, time_vec_intrp[-1]+1, time_vec_intrp[-1]))
    ax.set_xlabel('Time (ms)', labelpad=-5)
    ax.set_xticklabels(['BM', str(500)])
    ax.set_ylabel('MNE', labelpad=5)
    ax.patch.set_alpha(0.0)
    ax.set_title('Sample source activity', fontsize=plt.rcParams['font.size'], pad=titlepad)
    # ax.margins(x=0.15)
    # ax.margins(y=1.5e-1)
    
    # ax.axis('off')
    
    # scalebar_y = ScaleBar(dx=1, rotation='vertical', fixed_value=1e-13, frameon=False, scale_formatter=lambda value,unit: f"{value} MNE",
    #                           location = 'center left', scale_loc='left', font_properties={'size':tick_size}, width_fraction=0.008)
    # scalebar_x = ScaleBar(dx=1, fixed_value=50, frameon=False, scale_formatter=lambda value,unit: f"{value} ms", location = 'lower center', font_properties={'size':tick_size}, width_fraction=0.01)
    
    # ax.add_artist(scalebar_y)
    # ax.add_artist(scalebar_x)
    

In [None]:
def plot_beh_timeseries(fig, ax, lw=1, titlepad=5, labelpad=-10, s_ind=16, color='k', alpha=0.6, scale_pixels=True):
    """Plot one subject's paddle-movement traces against time and save the figure.

    Movement comes from the global ``pong_movement`` (``source='movement'``)
    for ``subjects[s_ind]``. The x values are that subject's trial-0
    timestamps, shifted to start at 0. The figure is saved as
    ``fig_save_loc + save_string + 'Movement_Samples.svg'``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is saved.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    lw, color, alpha
        Line width, colour and opacity of the traces.
    titlepad, labelpad : float
        Title padding and y-label padding.
    s_ind : int
        Index into the global ``subjects`` list.
    scale_pixels : bool
        Relabel the y ticks -800/0/800 as -8/0/8.
    """
    sel_subj_data = pong_movement.sel(subject=subjects[s_ind])
    sel_movement = sel_subj_data.sel(source='movement')
    sel_time = sel_subj_data.sel(source='timestamp', trial=0)
    # Shift to start at 0 without ``-=``: label-based selection may return a
    # view, and an in-place update would overwrite the timestamps stored in
    # the global ``pong_movement``.
    sel_time = sel_time - sel_time[0]
    time_max = sel_time[-1].round(1).to_numpy()
    time_step = time_max/2

    ax.plot(sel_time, sel_movement, c=color, alpha=alpha, lw=lw)
    modify_axis_spines(ax, which=['x', 'y'], yticks=np.arange(-800, 801, 800), xticks=np.arange(0, time_max+0.01, time_step))
    if scale_pixels:
        # Show the -800/0/800 ticks in units of 100.
        ax.set_yticklabels([-8, 0, 8])

    ax.patch.set_alpha(0.0)
    ax.set_xlabel('Time (ms)')
    # Ticks are labelled relative to the last one ('FB'); the *1000 suggests
    # timestamps are in seconds -- confirm.
    ax.set_xticklabels([str(round(-time_max*1000)), str(round(-time_step*1000)), 'FB'])
    ax.set_ylabel('Distance (px)', labelpad=labelpad)
    ax.set_title('Sample paddle movement', fontsize=plt.rcParams['font.size'], pad=titlepad)

    # Figure.savefig (Figure has no ``save_fig`` method).
    fig.savefig(fig_save_loc + save_string + 'Movement_Samples.svg', transparent = True, bbox_inches = 'tight')

In [None]:
def half_violinplot_pt(fig, ax, fig_args, vis_args, violin=True, strip=True, width_viol=0.5, point_size=3, jitter=1, width_box=0.3):
    """Layer an optional half-violin, an optional strip plot and a box plot from shared arguments.

    Parameters
    ----------
    fig, ax
        Not used in the body; the target axes presumably arrives through
        ``fig_args`` -- confirm with callers.
    fig_args : dict
        Keyword arguments shared by every plotting call (x, y, data, ...).
    vis_args : dict
        Must provide ``'edgecolor'``; ``'linewidth'`` sets the box whiskers.
    violin, strip : bool
        Draw the half violin / the strip plot. The box plot is always drawn.
    width_viol, point_size, jitter, width_box
        Size parameters forwarded to the respective calls.
    """
    edge = vis_args['edgecolor']

    if violin:
        # Half violins share one position per category; their legend is dropped.
        pt.half_violinplot(**fig_args, cut=0, scale="area", width=width_viol, inner=None, edgecolor=edge, dodge=False).legend_.remove()

    if strip:
        # Dodged points drawn above the boxes; their legend is dropped.
        sns.stripplot(**fig_args, size=point_size, jitter=jitter, zorder=10, edgecolor=edge, dodge=True).legend_.remove()

    # White boxes without outliers or caps; this layer keeps its legend.
    sns.boxplot(**fig_args, width=width_box, zorder=1, boxprops={'facecolor': 'white', "zorder": 1}, color=edge,
                showfliers=False, showcaps=False, dodge=True,
                whiskerprops={'linewidth': vis_args['linewidth'], "zorder": 10})
    

In [None]:
def plot_speed_boxes(fig, ax, linewidth=1, alpha=1, saturation=1, order=['Female', 'Male'], point_size=3, linecolor='k',
                     width_viol=0.5, bw=0.7, width_box=0.3, stat_loc='inside', dodge=False):
    """Draw a raincloud plot of average speed by gender and condition, with Mann-Whitney stars, and save it.

    Data come from the global ``pong_beh_df`` (x='gender', y='value',
    hue='condition'). Presence and Absence are compared within each gender.
    The figure is saved as
    ``fig_save_loc + save_string + 'Speed_Accuracy_Ratio_Groups_Vrt.svg'``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is saved.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    linewidth, alpha, saturation, point_size, linecolor, width_viol, bw, width_box
        Styling forwarded to ``plot_RainClouds`` and ``Annotator``.
    order : list of str
        Gender order on the x axis.
    stat_loc : str
        ``loc`` of the annotations.
    dodge : bool
        Not used in the body.
    """
    comparisons = [((sex, 'Presence'), (sex, 'Absence')) for sex in ('Male', 'Female')]

    # The same mapping feeds both the plot and the annotator so their layouts agree.
    shared = dict(
        x='gender', y='value', hue='condition', data=pong_beh_df,
        palette={'Presence': presence_color, 'Absence': absence_color},
        linecolor=linecolor, point_size=point_size, linewidth=linewidth, box_linewidth=linewidth,
        saturation=saturation, alpha=alpha, width_viol=width_viol, width_box=width_box, bw=bw,
        edgecolor=linecolor, box_capprops={"linewidth": 0},
    )

    clouds = plot_RainClouds(ax=ax, order=order, cut=0, pointplot=False, box_fliersize=0,
                             box_medianprops=dict(linewidth=linewidth, color=linecolor),
                             box_whiskerprops=dict(linewidth=linewidth), **shared)

    stars = Annotator(**shared, ax=ax, pairs=comparisons, verbose=False, order=order, plot='stripplot')
    stars.configure(test='Mann-Whitney', text_format='star', loc=stat_loc, line_width=linewidth)
    stars.apply_and_annotate()

    clouds.legend_.remove()

    modify_axis_spines(ax, which=['y'], yticks=np.linspace(0.0, 1.0, 3))

    ax.set_xlabel('')
    ax.set_ylabel('Avg speed$_p$', labelpad=10)
    ax.tick_params(axis='x', length=0)
    ax.tick_params(axis='y')
    ax.patch.set_alpha(0.0)

    fig.savefig(fig_save_loc + save_string + 'Speed_Accuracy_Ratio_Groups_Vrt.svg', transparent=True, bbox_inches='tight')


In [None]:
def plot_IS_polar_groups(fig, axes, title='INI', title_x=1, title_y=0.965, tickpad=1, subfigure=True, fontsize=None, rotation=None):
    """Draw radar plots of per-network values for each group, with Presence and Absence overlaid.

    One axes per item of the global ``fc_sum_polar``; each ``group_data``
    entry is filled in turn with the Presence, then the Absence colour.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Receives the title when ``subfigure`` is True.
    axes : sequence of axes
        Radar axes (``set_varlabels`` is called on them), one per group.
    title : str
        Title text.
    title_x, title_y : float
        Title position: figure coordinates when ``subfigure`` is True,
        otherwise data coordinates of ``axes[0]``.
    tickpad : float
        Padding of the angular tick labels.
    subfigure : bool
        Draw the title with ``fig.text`` (True) or on ``axes[0]`` (False).
    fontsize : float, optional
        Font size of the network labels.
    rotation : float, optional
        Angular offset in radians passed to ``set_theta_offset`` on every axes.
    """
    absence_color, presence_color = 'crimson', 'dodgerblue'
    palette = {'Presence': presence_color, 'Absence': absence_color}

    net_polar_labels = yeo_networks_shortened
    theta = radar_factory(len(net_polar_labels), frame='circle')

    # Build the label font in one dict: set_fontproperties replaces the whole
    # FontProperties, so separate calls would drop the earlier setting.
    tick_font = {'weight': 'normal'}
    if fontsize is not None:
        tick_font['size'] = fontsize

    for g_ind, (group, group_data) in enumerate(fc_sum_polar.items()):

        ax = axes[g_ind]

        ax.text(x=0.5, y=0.15, s=group, weight='bold', transform=ax.transAxes, horizontalalignment='center', verticalalignment='center')
        ax.set_varlabels(net_polar_labels)
        ax.grid(False)
        ax.spines['polar'].set_visible(False)
        ax.set_facecolor('#E5ECF6')
        ax.tick_params(pad=tickpad)

        for tick in ax.get_xticklabels()[:len(net_polar_labels)]:
            tick.set_fontproperties(tick_font)

        ax.set_ylim([0, 1])
        # Hide radial tick labels but keep the ticks.
        ax.set_yticklabels(['' for _ in range(len(ax.get_yticks()))])

        for d, color in zip(group_data, palette.values()):
            ax.fill(theta, d, facecolor=color, alpha=0.5, label='_nolegend_')

        if rotation is not None:
            ax.set_theta_offset(rotation)

    if subfigure:
        fig.text(title_x, title_y, title, horizontalalignment='center', color='black', weight='bold')
    else:
        axes[0].text(title_x, title_y, title, horizontalalignment='center', color='black', weight='bold')


In [None]:
def generate_schaefer2018_image_array(hemis=['rh', 'lh'], views=['lateral', 'medial'], width=400, height=400, ctxcolor='dimgray', bgcolor='w', verbose=False):
    """Render the Schaefer-2018 network parcellation on fsaverage and collect screenshots.

    This redefinition (right hemisphere first, dim-gray cortex) replaces the
    earlier cell's version once executed. For each hemisphere it opens an
    ``mne.viz.Brain`` on the inflated surface and colours every label of
    annotation ``parc`` that matches ``'7Network'`` by its network. It draws
    annotation borders in black and stores one screenshot per view.

    Parameters
    ----------
    hemis : list of str
        Hemispheres to render (passed as ``hemi=`` to MNE).
    views : list of str
        Views passed to ``Brain.show_view``.
    width, height : int
        Window size passed as ``Brain(size=(width, height))``.
    ctxcolor, bgcolor
        Cortex and background colours.
    verbose : bool
        Print one line per hemisphere.

    Returns
    -------
    xarray.DataArray
        int16 array with dims ``('hemi', 'view', 'width', 'height', 'channel')``.

    Notes
    -----
    Relies on notebook globals ``parc``, ``fs_label_dir``, ``yeo_networks``,
    ``yeo_networks_shortened_dict`` and ``network_label_colors``.
    NOTE(review): ``mne`` is not in the top-of-file import cell -- confirm it
    is imported elsewhere before this runs.
    """

    num_hemi = len(hemis)
    num_view = len(views)
    
    # NOTE(review): buffer is (..., width, height, 3); screenshots are usually
    # (rows, cols, 3) = (height, width, 3), so this only lines up when
    # width == height -- confirm.
    image_array = xr.DataArray(np.zeros((num_hemi, num_view, width, height, 3), dtype=np.int16), dims=('hemi', 'view', 'width', 'height', 'channel'),
                               coords=dict(hemi=hemis, view=views, channel=['r', 'g', 'b']))
    
    for h_ind, hemi in enumerate(hemis):

        if verbose:
            print('Drawing ' + hemi + ' brain')
        
        labels = mne.read_labels_from_annot('fsaverage', parc=parc, regexp='7Network', hemi=hemi, subjects_dir=fs_label_dir)
        
        brain = mne.viz.Brain(subject="fsaverage", surf="inflated", hemi=hemi, subjects_dir=fs_label_dir,
            cortex=ctxcolor,
            background=bgcolor,
            size=(width, height))
        
        for label in labels:
            # print(label.name)
            # Network name = text between 'H_' and the next '_' in the label name.
            label_network_name = label.name.split('H_')[1].split('_')[0]
            # Labels whose network is not in yeo_networks are drawn black.
            if label_network_name not in yeo_networks:
                label_color = 'k'
            else:
                label_color = network_label_colors[yeo_networks_shortened_dict[label_network_name]]
            brain.add_label(label, color=label_color)
        
        brain.add_annotation(parc, borders=True, color='k')    
    
        for v_ind, view in enumerate(views):
    
            brain.show_view(view=view)
            image_array[h_ind, v_ind, ...] = np.array(brain.screenshot())
        # One Brain window per hemisphere; closed once all views are captured.
        brain.close()
    return image_array

In [None]:
def compute_pdiff_colormap(data=None, cmap=plt.cm.RdBu, round_limits=False):
    """Map values to RGBA colours with a colormap scaled to the data's range.

    Parameters
    ----------
    data : xarray.DataArray or array_like
        Values to colour. Objects with ``to_numpy`` (xarray/pandas) are
        converted with it; anything else goes through ``np.asarray``.
    cmap : matplotlib colormap
        Colormap to use.
    round_limits : bool
        Round vmin/vmax to one decimal place.

    Returns
    -------
    color_array : numpy.ndarray
        RGBA colours with shape ``data.shape + (4,)``.
    sm : matplotlib.cm.ScalarMappable
        Mappable with the same cmap/norm, e.g. for ``fig.colorbar``.

    Raises
    ------
    ValueError
        If ``data`` is None or empty.
    """
    if data is None:
        raise ValueError('data must be provided')
    values = data.to_numpy() if hasattr(data, 'to_numpy') else np.asarray(data)
    if values.size == 0:
        raise ValueError('data must not be empty')

    vmin, vmax = np.min(values), np.max(values)
    if round_limits:
        vmin, vmax = vmin.round(1), vmax.round(1)
    normalizer = Normalize(vmin=vmin, vmax=vmax)

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=normalizer)
    sm.set_array([])
    # Map a flat copy, then restore the input's own shape. to_rgba would treat
    # a 3-D array whose last axis is 3 or 4 as an image and skip the mapping.
    color_array = sm.to_rgba(values.ravel()).reshape(values.shape + (4,))

    return color_array, sm

In [None]:
def generate_conHeatMap_image_array(data=None, hemi='lh', views=('lateral', 'medial'), width=500, height=500,
                                    ctxcolor='k', bgcolor='w', cmap=plt.cm.RdBu, verbose=False):
    """Paint per-network values of ``data`` onto fsaverage and screenshot each group and view.

    Colours come from ``compute_pdiff_colormap(data=data, cmap=cmap)`` and are
    indexed as ``color_array[g_ind, net_ind, :]``. ``data`` is therefore assumed to
    have dims ``(group, network)`` in that order (TODO confirm against callers).
    Network ``net_ind`` is drawn with the labels of ``parc`` matching
    ``yeo_networks[net_ind]``. One ``Brain`` is reused for all groups. Labels of
    earlier groups are not removed, so each group is drawn over the previous one.

    Reads the notebook globals ``mne``, ``parc``, ``fs_label_dir`` and ``yeo_networks``.

    Parameters
    ----------
    data : xarray.DataArray
        Values with ``group`` and ``network`` coordinates.
    hemi : str
        Hemisphere to draw.
    views : sequence of str
        View names passed to ``Brain.show_view``.
    width, height : int
        Render size in pixels.
    ctxcolor, bgcolor :
        Cortex and background colours for ``Brain``.
    cmap : matplotlib colormap
        Colormap for the data values.
    verbose : bool
        Print one progress line per group.

    Returns
    -------
    xarray.DataArray
        int16 array with dims ``('group', 'view', 'width', 'height', 'channel')``;
        the ``group`` coordinate is taken from ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not given.
    """
    if data is None:
        raise ValueError('generate_conHeatMap_image_array requires a data array')

    views = list(views)
    groups = data.group

    color_array, sm = compute_pdiff_colormap(data=data, cmap=cmap)

    # Size and label the group axis from ``data`` itself (was a global, with 2 groups hard-coded).
    image_array = xr.DataArray(np.zeros((len(groups), len(views), width, height, 3), dtype=np.int16),
                               dims=('group', 'view', 'width', 'height', 'channel'),
                               coords=dict(view=views, group=groups, channel=['r', 'g', 'b']))

    # Parcel labels depend only on the network, not on the group: read them once.
    network_labels = [mne.read_labels_from_annot('fsaverage', parc=parc, hemi=hemi, regexp=yeo_networks[net_ind],
                                                 subjects_dir=fs_label_dir)
                      for net_ind in range(len(data.network))]

    brain = mne.viz.Brain(subject="fsaverage", surf="inflated", hemi=hemi, subjects_dir=fs_label_dir,
                          cortex=ctxcolor, background=bgcolor, size=(width, height))
    # Close the render window even if drawing or a screenshot raises.
    try:
        for g_ind, group in enumerate(groups):
            if verbose:
                print('Drawing group ' + str(group.values))

            for net_ind, labels in enumerate(network_labels):
                for label in labels:
                    brain.add_label(label, color=color_array[g_ind, net_ind, :])

            brain.add_annotation(parc, borders=True, color='k')

            for v_ind, view in enumerate(views):
                brain.show_view(view=view)
                image_array[g_ind, v_ind, ...] = np.array(brain.screenshot())
    finally:
        brain.close()

    return image_array

In [None]:
def generate_brain_mask_array(hemis=('rh', 'lh'), views=('lateral', 'medial'), width=400, height=400, ctxcolor='k', bgcolor='w', verbose=False):
    """Locate the background pixels of fsaverage inflated-surface screenshots.

    For each hemisphere and view, an inflated ``mne.viz.Brain`` is rendered with
    ``ctxcolor``/``bgcolor``. The (row, col) indices of every pixel whose RGB
    exactly equals ``bgcolor`` are collected. ``ctxcolor`` must therefore differ
    from ``bgcolor``, otherwise cortex pixels are reported too.

    Reads the notebook globals ``mne`` and ``fs_label_dir``.

    Parameters
    ----------
    hemis : sequence of str
        Hemispheres to render.
    views : sequence of str
        View names passed to ``Brain.show_view``.
    width, height : int
        Render size in pixels.
    ctxcolor, bgcolor :
        Cortex and background colours (defaults: black cortex, white background).
    verbose : bool
        Print one progress line per hemisphere.

    Returns
    -------
    dict
        Maps ``'<hemi>-<view>'`` to an ``(N, 2)`` int array of (row, col) indices
        of background-coloured pixels.
    """
    # Background colour as 0-255 RGB, for comparison with the uint8 screenshots.
    bg_rgb = np.round(np.array(mcolors.to_rgb(bgcolor)) * 255).astype(np.uint8)
    mask_dict = {}

    for hemi in hemis:
        if verbose:
            print('Masking ' + hemi + ' brain')

        brain = mne.viz.Brain(subject="fsaverage", surf="inflated", hemi=hemi, subjects_dir=fs_label_dir,
                              cortex=ctxcolor,
                              background=bgcolor,
                              size=(width, height))
        # Close the render window even if a view or screenshot raises.
        try:
            for view in views:
                brain.show_view(view)
                brain_image = np.asarray(brain.screenshot())

                # True where the pixel's RGB matches the background colour exactly.
                background_mask = np.all(brain_image[..., :3] == bg_rgb, axis=-1)
                mask_dict[hemi + '-' + view] = np.argwhere(background_mask)
        finally:
            brain.close()

    return mask_dict

In [None]:
def plot_fc_matrix(comm, ordered, ax, fig, cmap='viridis', ylabel=False, lw=1, cbar_pad=0.1, cbar_width=3, cax_rect=[0.9, 0.2, 0.03, 0.6], cbar_axis=True):
    """Show an ordered FC matrix with each community outlined on the diagonal.

    Communities are laid out consecutively along the diagonal. Community ``k``
    starts at the summed sizes of communities ``0..k-1``. Each gets an unfilled
    square outline of width ``lw``, with small offsets so that neighbouring
    outlines do not overlap. The matrix is drawn with ``vmin`` set to its minimum
    rounded to one decimal. The colour bar is ticked at that minimum and at 1.

    Parameters
    ----------
    comm : sequence of sequences
        Node groups in display order; only their lengths are used.
    ordered : array-like
        Square FC matrix already ordered to match ``comm``.
    ax, fig :
        Target axes and the (sub)figure that owns it.
    cmap : colormap name or object
    ylabel : bool
        Unused.
    lw : float
        Outline line width.
    cbar_pad, cbar_width :
        Padding and width (percent of ``ax``) of an appended colour-bar axes.
    cax_rect : list of float
        Explicit colour-bar axes rectangle, used when ``cbar_axis`` is False.
    cbar_axis : bool
        Append the colour bar next to ``ax`` (True) or place it at ``cax_rect``.
    """
    last_comm = len(comm) - 1
    corner = 0  # top-left diagonal index of the current community block

    for k, members in enumerate(comm):
        block_len = len(members)

        if k == 0:
            origin_offset, size_offset = 0, -lw / 2
        else:
            origin_offset, size_offset = -lw / 2, lw / 8
        if k == last_comm:
            size_offset = lw / 4

        outline = mpl.patches.Rectangle(
            (corner + origin_offset, corner + origin_offset),
            block_len + size_offset,
            block_len + size_offset,
            fill=None, lw=lw, alpha=1,
        )
        patch = ax.add_patch(outline)
        patch.set_clip_path(patch)

        corner += block_len

    vmin = ordered.min().round(1)

    im = ax.imshow(ordered, cmap=cmap, vmin=vmin)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
    ax.set_title('FC', fontsize=plt.rcParams['font.size'])

    if not cbar_axis:
        cax = fig.add_axes(rect=cax_rect)
    else:
        cax = make_axes_locatable(ax).append_axes('right', size=f'{cbar_width}%', pad=cbar_pad)

    cbar = fig.colorbar(im, extendrect=False, cax=cax)
    cbar.set_ticks([vmin, 1])
    cbar.ax.tick_params(size=0)
    cbar.outline.set_visible(False)

    # fig.tight_layout()

In [None]:
def draw_network_dimensions(network_dict, fig, ax, orientation='horizontal', ax_offset=0.01, x_value=0, text_offset=0.1, text_yoffset=0, ind_offsets=(0, 0),
                            width_ratio=0.05, linewidth=3, fontsize=10, fontweight='normal', ax_width=None, LIM_offset=-4, n_nodes=None):
    """Draw a coloured span bar and name for each network beside ``ax``.

    A new axes is added to ``fig`` and its axis is hidden at the end. When
    ``orientation`` starts with ``'v'``, the axes sits to the left of ``ax``
    (both axes inverted, ``text_offset`` negated). Otherwise it sits on top of
    ``ax`` and shares its x-axis. For each ``(net_name, net_inds)`` a line is
    drawn at ``x_value`` from ``net_inds[0] + ind_offsets[0]`` to
    ``net_inds[-1] + ind_offsets[1]``. The network name is placed next to it.
    Both use ``network_color_dict[net_name]`` (notebook global).

    In horizontal mode, reaching the ``'LIM'`` network adds ``LIM_offset`` to the
    end offset and sets ``text_yoffset`` to -2.5. Both adjustments persist for
    the networks that follow it in ``network_dict``.

    Parameters
    ----------
    network_dict : dict
        Network name -> sorted node indices.
    fig, ax :
        Parent (sub)figure and the matrix axes to annotate.
    orientation : str
        ``'horizontal'`` or ``'vertical'``; also used as the text rotation.
    ind_offsets : sequence of two numbers
        Start/end offsets added to each span. Not modified.
    n_nodes : int, optional
        Data limits of the label axes; defaults to ``sample_fc.shape[0]``
        (notebook global).
    """
    # Private copy so the LIM adjustment below never mutates the caller's
    # sequence or a shared default across calls.
    ind_offsets = list(ind_offsets)
    if n_nodes is None:
        n_nodes = sample_fc.shape[0]

    vertical = orientation.startswith('v')
    ax_pos = ax.get_position()

    if vertical:
        if ax_width is None:
            ax_width = ax_pos.height
        label_ax = fig.add_axes([ax_pos.x0 - ax_offset, ax_pos.y0, width_ratio, ax_width])
        label_ax.invert_yaxis()
        label_ax.invert_xaxis()
        text_offset = -text_offset
    else:
        if ax_width is None:
            ax_width = ax_pos.width
        label_ax = fig.add_axes([ax_pos.x0 + ax_offset, ax_pos.y1, ax_width, width_ratio], sharex=ax)

    # NOTE(review): explicit set_xlim/set_ylim may undo the inversions above; confirm
    # the vertical variant renders as intended.
    label_ax.set_xlim([0, n_nodes])
    label_ax.set_ylim([0, n_nodes])

    for net_name, net_inds in network_dict.items():
        if vertical:
            rect_x = [x_value, x_value]
            rect_y = [net_inds[0] + ind_offsets[0], net_inds[-1] + ind_offsets[1]]

            text_x = x_value + text_offset
            text_y = int(np.median(net_inds)) + text_yoffset
        else:
            if net_name == 'LIM':
                ind_offsets[1] += LIM_offset
                text_yoffset = -2.5

            rect_y = [x_value, x_value]
            rect_x = [net_inds[0] + ind_offsets[0], net_inds[-1] + ind_offsets[1]]

            text_x = int(np.median(net_inds)) + text_yoffset
            text_y = x_value + text_offset

        net_color = network_color_dict[net_name]
        label_ax.plot(rect_x, rect_y, color=net_color, linewidth=linewidth)
        # NOTE(review): text_yoffset is added to text_y again here.
        label_ax.text(text_x, text_y + text_yoffset, net_name, fontweight=fontweight,
                      ha='center', va='center', rotation=orientation, fontsize=fontsize, color=net_color)

    label_ax.axis('off')

In [None]:
def draw_network_dimensions_bottom(network_dict, fig, ax, orientation='horizontal', ax_xoffset=0.01, ax_yoffset=-0.1, x_value=0, text_offset=0.1, text_yoffset=0, ind_offsets=(0, 0),
                            width_ratio=0.05, linewidth=3, fontsize=10, fontweight='normal', ax_width=None, LIM_offset=-4, n_nodes=None):
    """Draw a coloured span bar and name for each network below ``ax``.

    A transparent, axis-less axes sharing ``ax``'s x-axis is added to ``fig`` at
    ``(ax.x0 + ax_xoffset, ax.y0 + ax_yoffset)``. For each
    ``(net_name, net_inds)`` a horizontal line is drawn at height ``x_value``
    from ``net_inds[0] + ind_offsets[0]`` to ``net_inds[-1] + ind_offsets[1]``.
    The name is placed at the mean index. Both use
    ``network_color_dict[net_name]`` (notebook global).

    Reaching the ``'LIM'`` network adds ``LIM_offset`` to the end offset and sets
    ``text_yoffset`` to -2.5. Both adjustments persist for the networks that
    follow it in ``network_dict``.

    Parameters
    ----------
    network_dict : dict
        Network name -> sorted node indices.
    fig, ax :
        Parent (sub)figure and the matrix axes to annotate.
    orientation : str
        Used as the text rotation.
    ind_offsets : sequence of two numbers
        Start/end offsets added to each span. Not modified.
    n_nodes : int, optional
        Data limits of the label axes; defaults to ``sample_fc.shape[0]``
        (notebook global).
    """
    # Private copy so the LIM adjustment below never mutates the caller's
    # sequence or a shared default across calls.
    ind_offsets = list(ind_offsets)
    if n_nodes is None:
        n_nodes = sample_fc.shape[0]

    ax_pos = ax.get_position()
    if ax_width is None:
        ax_width = ax_pos.width

    label_ax = fig.add_axes([ax_pos.x0 + ax_xoffset, ax_pos.y0 + ax_yoffset, ax_width, width_ratio], sharex=ax)

    label_ax.set_xlim([0, n_nodes])
    label_ax.set_ylim([0, n_nodes])

    for net_name, net_inds in network_dict.items():
        if net_name == 'LIM':
            ind_offsets[1] += LIM_offset
            text_yoffset = -2.5

        rect_y = [x_value, x_value]
        rect_x = [net_inds[0] + ind_offsets[0], net_inds[-1] + ind_offsets[1]]

        text_x = int(np.mean(net_inds)) + text_yoffset
        text_y = x_value + text_offset

        net_color = network_color_dict[net_name]
        label_ax.plot(rect_x, rect_y, color=net_color, linewidth=linewidth)
        label_ax.text(text_x, text_y + text_yoffset, net_name, fontweight=fontweight,
                      ha='center', va='center', rotation=orientation, fontsize=fontsize, color=net_color)

    label_ax.axis('off')
    label_ax.patch.set_alpha(0)

In [None]:
def plot_SC_average(fig, ax, cmap='inferno', labelpad=12, cbar_pad=0.1, cbar_width=3, cax_rect=(0.9, 0.2, 0.03, 0.6), cbar_axis=True, draw_yaxis=True,
                   titlepad=5, vmax=12, save_path=None):
    """Plot the structural-connectivity matrix ``SC`` (notebook global) with a colour bar.

    Parameters
    ----------
    fig, ax :
        (Sub)figure and axes to draw into.
    cmap : colormap name or object
    labelpad : float
        Padding of the 'Nodes' axis labels.
    cbar_pad, cbar_width :
        Padding and width (percent of ``ax``) of an appended colour-bar axes.
    cax_rect : sequence of float
        Explicit colour-bar axes rectangle, used when ``cbar_axis`` is False.
    cbar_axis : bool
        Append the colour bar next to ``ax`` (True) or place it at ``cax_rect``.
    draw_yaxis : bool
        Draw the y ticks and label (otherwise no y ticks).
    titlepad : float
        Padding of the 'SC' title.
    vmax : float
        Upper colour limit of the image (default 12).
    save_path : str, optional
        When given, write the top-level figure containing ``fig`` to this path
        (transparent, tight bbox), e.g.
        ``fig_save_loc + 'EEG/' + 'Schaefer2018_SC.svg'``. Nothing is saved by
        default, so embedding the plot in a larger figure has no file side effect.
    """
    sc_imag = ax.imshow(SC, cmap=cmap, vmax=vmax)

    if cbar_axis:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        cax = fig.add_axes(rect=cax_rect)

    cbar = fig.colorbar(sc_imag, extendrect=False, cax=cax)
    # Two ticks: SC min rounded to the nearest ten, SC max rounded to an integer.
    cbar.set_ticks(np.linspace(int(SC.min().round(-1)), int(SC.max().round()), 2))
    cbar.ax.tick_params(axis='both', length=0)
    cbar.outline.set_visible(False)

    ax.set_xticks([0, 400])
    ax.tick_params(axis='x', which='major', length=0)
    ax.set_xlabel('Nodes', labelpad=labelpad)
    ax.set_title('SC', fontsize=plt.rcParams['font.size'], pad=titlepad)
    if draw_yaxis:
        ax.set_yticks([0, 400])
        ax.tick_params(axis='y', which='major', length=0)
        ax.set_ylabel('Nodes', labelpad=labelpad)
    else:
        ax.set_yticks([])
    ax.set_frame_on(False)
    ax.patch.set_alpha(0)

    if save_path is not None:
        # ``fig`` may be a SubFigure, which has no ``savefig``; save its root Figure.
        root_fig = getattr(fig, 'figure', None) or fig
        root_fig.savefig(save_path, transparent=True, bbox_inches='tight')

In [None]:
def plot_eeg_setup(fig, ax, image_path):
    """Display the image file at ``image_path`` on ``ax`` with the axis hidden.

    Parameters
    ----------
    fig :
        Unused; kept for signature parity with the other ``plot_*`` helpers.
    ax :
        Axes to draw into.
    image_path : str or path-like
        Image file readable by PIL.
    """
    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(image_path) as img:
        image_array = np.array(img)
    ax.imshow(image_array)
    ax.axis('off')

In [None]:
def plot_xpos_corrs(fig, ax, cmap='magma', titlepad=5, ax_title='Dot product', cbar_pad=0.1, cbar_width=3, cax_rect=(0.9, 0.2, 0.03, 0.6), cbar_axis=True):
    """Show min-max-scaled dot products between ball and paddle x-position traces.

    Reads the notebook global ``pong_movement``. Its ``trial`` and ``subject``
    dims are stacked into ``aggTrial``. The ``'ball_x'`` and ``'movement'``
    sources are selected, and ``np.dot(ball, paddle.T)`` is computed. The whole
    product matrix is scaled jointly to [0, 1] and shown as an image, with a
    colour bar ticked at its rounded min and max.

    Parameters
    ----------
    fig, ax :
        (Sub)figure and axes to draw into.
    cmap : colormap name or object
    titlepad : float
        Padding of the title.
    ax_title : str
        Axes title.
    cbar_pad, cbar_width :
        Padding and width (percent of ``ax``) of an appended colour-bar axes.
    cax_rect : sequence of float
        Explicit colour-bar axes rectangle, used when ``cbar_axis`` is False.
    cbar_axis : bool
        Append the colour bar next to ``ax`` (True) or place it at ``cax_rect``.
    """
    if cbar_axis:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size=str(cbar_width) + '%', pad=cbar_pad)
    else:
        cax = fig.add_axes(rect=cax_rect)

    pong_mvm_stacked = pong_movement.stack(aggTrial=('trial', 'subject'))
    ballx_all = pong_mvm_stacked.sel(source='ball_x')
    pongx_all = pong_mvm_stacked.sel(source='movement')

    x_dots = np.dot(ballx_all, pongx_all.T)
    xdot_shape = x_dots.shape
    # Scale all entries jointly: flatten to one feature, scale, restore the shape.
    x_dots = preprocessing.MinMaxScaler().fit_transform(x_dots.reshape(-1, 1)).reshape(xdot_shape)
    corr_img = ax.imshow(x_dots, cmap=cmap)

    ax.set_title(ax_title, fontsize=plt.rcParams['font.size'], pad=titlepad)
    ax.set_xlabel('Paddle pos')
    ax.set_ylabel('Ball pos')
    # Columns (paddle) span the x axis and rows (ball) span the y axis.
    ax.set_xticks([0, xdot_shape[1] - 1])
    ax.set_yticks([0, xdot_shape[0] - 1])
    ax.set_frame_on(False)
    ax.patch.set_alpha(0)
    ax.tick_params(length=0)

    cbar = fig.colorbar(corr_img, extendrect=False, cax=cax)
    cbar.ax.set_yticks([x_dots.min().round(1), x_dots.max().round(1)])
    cbar.ax.tick_params(size=0)
    cbar.outline.set_visible(False)


### SubFigureMain

In [None]:
# Render all brain images used by the figure (slow: opens MNE render windows),
# then cache three of them as netCDF so the figure can be rebuilt without re-rendering.
brain_width, brain_height = 800, 800

yeo_parcellation_image = generate_schaefer2018_image_array(width=brain_width, height=brain_height)
brain_masks = generate_brain_mask_array(width=brain_width, height=brain_height)
pdiff_image_array = generate_conHeatMap_image_array(width=brain_width, height=brain_height, data=posterior_difference)
fcdiff_image_array = generate_conHeatMap_image_array(data=fc_emp_diff, views=['lateral'], width=brain_width, height=brain_height)

# NOTE(review): 'pdiff_image_array.nc.' ends with a stray '.', mirrored by the
# loading cell; rename both together if it is fixed.
for cached_array, cache_name in ((yeo_parcellation_image, 'yeo_parcellation.nc'),
                                 (pdiff_image_array, 'pdiff_image_array.nc.'),
                                 (fcdiff_image_array, 'fcdiff_image_array.nc')):
    cached_array.to_netcdf(parent_preprocess_dir + cache_name)

In [None]:
# Reload the cached brain renderings (file names must match the caching cell,
# including the trailing '.' of 'pdiff_image_array.nc.').
yeo_parcellation_image, pdiff_image_array, fcdiff_image_array = (
    xr.load_dataarray(parent_preprocess_dir + cached_name)
    for cached_name in ('yeo_parcellation.nc', 'pdiff_image_array.nc.', 'fcdiff_image_array.nc')
)

In [None]:
# Assemble Figure 3 from three stacked sub-figure rows (height ratios 3:1:2).
# Relies on objects from earlier cells: the plot_* / draw_* helpers, the cached
# brain images, network_inds_sorted, source_epochs, subjects, sInd, all_letters, ...
# radar_factory's return value is discarded; presumably it is called for a side
# effect such as projection registration -- confirm.
_ = radar_factory(7, frame='circle')

height_ratios = np.array([3, 1, 2])
right_space = 0.95  # NOTE(review): not read anywhere in this cell
left_space = 0.09   # left edge of the row-3 GridSpec
wspace = 0.5
hspace = 0.3


fig = plt.figure(figsize=(9,12), frameon=False)
(fig_r1, fig_r2, fig_r3) = fig.subfigures(3, 1, height_ratios=height_ratios, hspace=hspace, wspace=wspace)


####################### 1st Row sub-figure #########################################
# Column width ratios of the 3x3 row-1 grid (the name is reused for row 2 below).
fig_r2_wr = [1, 2, 2]

fig_r1.subplots_adjust(top=0.8)

# Row 1 is itself a 3x3 grid of sub-figures, one panel each.
((fig_mvm_cor, fig_eeg_mvm, fig_eeg_beh),
 (fig_FC_img, fig_eeg_ts, fig_IS_heatmap),
 (fig_SC_img, fig_sim_sample, fig_sim_IS)) = fig_r1.subfigures(3,3, frameon=False, width_ratios=fig_r2_wr, hspace=hspace+0.2, wspace=0.1)

######### 1st sub-row #############
ax_eeg_mvm = fig_eeg_mvm.subplots(1,1)
ax_eeg_beh = fig_eeg_beh.subplots(1,1)

plot_beh_timeseries(fig_eeg_mvm, ax_eeg_mvm, lw=0.3,labelpad=5)
plot_speed_boxes(fig_eeg_beh, ax_eeg_beh, bw=0.7, alpha=0.8, stat_loc='outside', dodge=True, point_size=2)

######### 1st sub-column #############
ax_mvm_cor = fig_mvm_cor.subplots(1,1)
ax_FC_img = fig_FC_img.subplots(1,1)
ax_SC_img = fig_SC_img.subplots(1,1)

# Shared colour-bar rectangle for the three first-column matrix panels.
sc_cax_rect = [0.75, 0.1, 0.03, 0.78]

plot_xpos_corrs(fig_mvm_cor, ax_mvm_cor, cax_rect=sc_cax_rect)

net_comm = list(network_inds_sorted.values())
# Example FC matrix: one subject (subjects[sInd]), 'Presence' condition, trial-averaged.
# It must be bound before draw_network_dimensions_bottom, which reads the global
# sample_fc for its axis limits.
sample_fc = compute_FC(source_epochs.sel(subject=subjects[sInd], condition='Presence', label=label_names_sorted).mean('trial').T)
# NOTE(review): network_inds_sorted_inverse is not used in this cell.
network_inds_sorted_inverse = {n_name:np.sort(399 - network_inds_sorted[n_name]) for n_name in yeo_networks_shortened[::-1]}
cmap_FC = cc.m_kbc
# cmap_FC = cc.m_kb
cmap_FC = 'plasma'  # overrides the assignment above

plot_fc_matrix(net_comm, sample_fc, ax=ax_FC_img, fig=fig_FC_img, cmap=cmap_FC, lw=2, cax_rect=sc_cax_rect)
draw_network_dimensions_bottom(network_inds_sorted, fig=fig_FC_img, ax=ax_FC_img, orientation='horizontal', x_value=10,
                               ax_width=0.495, text_offset=-200, ind_offsets=[0,0], linewidth=3,
                               fontsize=2, ax_xoffset=0.105, ax_yoffset=-0.05, LIM_offset=0)

cmap_SC = 'inferno'
plot_SC_average(fig_SC_img, ax_SC_img, cmap=cmap_SC, labelpad=-3, cax_rect=sc_cax_rect, draw_yaxis=False)

######### 2nd sub-column #############
ax_eeg_ts = fig_eeg_ts.subplots(1,1)
ax_sim_sample = fig_sim_sample.subplots(1,1)
plot_EEG_timeseries(fig_eeg_ts, ax_eeg_ts, alpha=0.2, num_plotted_labels=50, color='k')
plot_simulation_timeseries(fig_sim_sample, ax_sim_sample, alpha=0.2, num_plotted_labels=50, color='k')

######### 3rd sub-column #############

# Spacing/margins for the brain-image panels (negative values pull images together).
brain_wspace = -0.5
brain_hspace = -0.1
brain_top = 1
brain_bottom = 0
brain_left = -0.15
brain_right = 0.95

cax_brain = [0.8, 0.05, 0.03, 0.9] 


ax_IS_heatmap = fig_IS_heatmap.subplots(1,2)
ax_sim_IS = fig_sim_IS.subplots(1,1)

plot_fcdiff_heatmap(fig_IS_heatmap, ax_IS_heatmap, cmap_data=fc_emp_diff, image_data=fcdiff_image_array.squeeze(), wspace=brain_wspace, hspace=brain_hspace, labelpad=5,
                    cbar_pad=0.5, group_pad=-3, parent_figure=False, cax_rect=cax_brain, title_x=0.38, title_y=0.05)

plot_simulation_INI_single(fig_sim_IS, ax_sim_IS, parent_figure=False, step=20, markersize=1, markerscale=4, alpha=0.5,
                              lax_rect=[0.8, 0.2, 0.1, 0.5], ncol=1, handletextpad=0.5, legend_spacing=0.2)

fig_sim_IS.subplots_adjust(right=0.75)

fig_IS_heatmap.subplots_adjust(bottom=brain_bottom, top=brain_top, right=brain_right, left=brain_left)

# fig_r1.subplots_adjust(hspace=0.3)

####################### 2nd Row sub-figure #########################################
fig_r2_wr = [1, 0.5]
(fig_con_boxes, fig_grp_hmap) = fig_r2.subfigures(1,2, frameon=False, width_ratios=fig_r2_wr)

ax_con_boxes = fig_con_boxes.subplots(1, 2).ravel()
ax_grp_hmap = fig_grp_hmap.subplots(1, 2).reshape(1, 2)


# Only brain_wspace, brain_hspace and cax_brain are read in this row;
# NOTE(review): brain_top/brain_bottom/brain_left below are not used afterwards in this cell.
brain_wspace = 0.0
brain_hspace = 0.0
brain_top = 1
brain_bottom = -0.1
brain_left = -0.1

# cax_brain = [0.7, 0.0, 0.03, 0.9] 
cax_brain = [0.78, 0.17, 0.03, 0.68] 

plot_group_posteriors_attn(fig_con_boxes, ax_con_boxes, group_posteriors_attn, labelpad=15, bw=0.4, alpha=0.6, plot_clouds=False,
                          width_box=0.7, num_samples=1000, point_size=2, jitter=0.25)

# Drop any legend the helper drew, then call ax.legend('') in its place.
for ax in ax_con_boxes:
    if ax.legend_ is not None:
        ax.legend_.remove()
        ax.legend('')

plot_pdiff_heatmap(fig_grp_hmap, ax_grp_hmap, views=['lateral'], wspace=brain_wspace, hspace=brain_hspace, labelpad=5, cbar_pad=0.5, group_pad=-3,
                   parent_figure=False, cax_rect=cax_brain, title_x=0.25, title_y=1.1)

fig_con_boxes.subplots_adjust(left=0.1, right=0.85, wspace=0.2)
fig_grp_hmap.subplots_adjust(right=0.75, left=-0.1, wspace=0.0)

####################### 3rd Row sub-figure #########################################

# 2x6 grid for the connectivity-behaviour correlation panels. The GridSpec is
# created without a figure; its cells are added to fig_r3 via add_subplot.
beh_grid = gridspec.GridSpec(2, 6, wspace=wspace, hspace=hspace, left=left_space, bottom=0.15, right=0.91)

ax_con_beh_corr = np.zeros((beh_grid.nrows, beh_grid.ncols), dtype=object)  # object array holding Axes

for r_ind in range(beh_grid.nrows):
    for c_ind in range(beh_grid.ncols):
        ax_con_beh_corr[r_ind, c_ind] = fig_r3.add_subplot(beh_grid[r_ind, c_ind])

# Within each row, every panel shares its y-axis with the row's first panel.
for r_ind in range(2):
    for ax_ind, ax in enumerate(ax_con_beh_corr[r_ind,:]):
        if ax_ind != 0:
            ax.sharey(ax_con_beh_corr[r_ind,0])

plot_con_beh_correlation(fig_r3, ax_con_beh_corr, parent_figure=False, rho_fontsize=7, linewidth=1, markersize=10, markerscale=2.5, stat_weight='semibold', net_weight='semibold',
                         g_colors=['lightsteelblue', 'rosybrown'], lax_rect=[0.91, 0.9, 0.1, 0.1], ncol=1, wspace=1., hspace=0.6, legend_spacing=1,
                         ylabel_pos=0.02, rho_pos_groups=np.array([[0.53, 0.22],[0.53, 0.9]]), handletextpad=0)
            
# Hide x tick labels on the top row of correlation panels.
for c_ind in range(beh_grid.ncols):
    ax_con_beh_corr[0, c_ind].set_xticklabels([])
            
######################## Parent figure config #########################################

# Axes that receive a panel letter, in lettering order; letters start at all_letters[2].
annotated_axes = [ax_mvm_cor, ax_eeg_mvm, ax_eeg_beh, ax_FC_img, ax_eeg_ts, ax_IS_heatmap[0], ax_SC_img,
                  ax_sim_sample, ax_sim_IS, ax_con_boxes[0], ax_grp_hmap.ravel()[0], ax_con_beh_corr.ravel()[0]]
ax_letters = all_letters[2:len(annotated_axes)+2]
# ax_letters = [ltr+' )' for ltr in ax_letters]

# Custom x offsets: last entry (first correlation panel) and third-from-last (ax_con_boxes[0]).
for ax_ind, ax in enumerate(annotated_axes):
    x_offset, y_offset = -0.25, 1.1
    ax_annot = ax_letters[ax_ind]  # NOTE(review): unused; the call below indexes ax_letters again
    if ax_ind == len(annotated_axes)-1:
        x_offset = -0.85
    elif ax_ind == len(annotated_axes)-3:
        x_offset = -0.22
    letter_annotation(ax, x_offset, y_offset, ax_letters[ax_ind], fontsize=14)


for sub_fig in [fig_r1, fig_r2, fig_r3]:
    sub_fig.set_frameon(False)

fig.show()

In [None]:
# Export the assembled Figure 3 (transparent background, 800 dpi) as PNG, then SVG.
for fig3_path in (fig_save_loc + 'Article/Fig3.png', fig_save_loc + 'Article/' + 'Fig3.svg'):
    fig.savefig(fig3_path, transparent=True, dpi=800)

### GridSpecMain