In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
from scipy.io import loadmat
from sklearn.metrics import mean_squared_error
import pysindy as ps
import pandas as pd
from matplotlib import cm
import time
import pickle
import seaborn as sns
from statistics import mean 
from scipy.stats import ks_2samp
import matplotlib.gridspec as gridspec
from IPython.display import display, HTML
# Notebook-display tweak: force elements with the `.container` CSS class to 98% of the window width.
display(HTML("<style>.container { width:98% !important; }</style>"))

In [2]:
def confidenceInterval(alist, ciColumn, trials):
    """
    Summarise column ``ciColumn`` of every dataset in ``alist``.

    Args:
    alist (list): List of datasets.
    ciColumn (int): Index of the column to summarise in each dataset.
    trials (int): Accepted for backward compatibility; ignored.

    Returns:
    tuple: ``(ci_low, ci_high, means, stds)``, four lists with one value per dataset.
    The bounds come from ``bootstrap_confidence_interval_column``; the mean and
    standard deviation from ``mean_of_list`` / ``std_of_list``.
    """
    column_means = mean_of_list(alist, ciColumn)
    column_stds = std_of_list(alist, ciColumn)
    # Bootstrap last: it is the only step that draws random numbers.
    ci_low, ci_high = bootstrap_confidence_interval_column(alist, ciColumn)
    return ci_low, ci_high, column_means, column_stds

In [3]:
def bootstrap_confidence_interval_column(data, ciColumn, num_bootstrap_samples=1000, confidence_level=0.90, rng=None):
    """
    Percentile-bootstrap confidence interval of the mean of one column, for every dataset.

    For each dataset the values ``dataset[ciColumn]`` are resampled with replacement
    ``num_bootstrap_samples`` times, and the interval is the central ``confidence_level``
    range of the resampled means.

    Args:
        data (list): List of datasets. ``dataset[ciColumn]`` must be a 1-D sequence of
            numbers (the same indexing used by ``mean_of_list`` and ``ks_statistical_analysis``).
        ciColumn (int): Index of the column to bootstrap.
        num_bootstrap_samples (int): Number of bootstrap resamples per dataset.
        confidence_level (float): Confidence level of the interval, strictly between 0 and 1.
        rng (numpy.random.Generator | numpy.random.RandomState | int | None): Source of
            randomness. ``None`` uses NumPy's global random state (so ``np.random.seed``
            still controls it); an int seeds a fresh ``np.random.default_rng``.

    Returns:
        tuple: ``(lower_bounds, upper_bounds)``, two lists with one value per dataset.
        A dataset whose column is empty gets NaN bounds.

    Raises:
        ValueError: If ``confidence_level`` is not strictly between 0 and 1.
    """
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be strictly between 0 and 1")
    if rng is None:
        sampler = np.random
    elif isinstance(rng, (int, np.integer)):
        sampler = np.random.default_rng(rng)
    else:
        sampler = rng

    alpha = (1 - confidence_level) / 2  # equal tail mass on each side (two-sided interval)
    lower_bounds = []
    upper_bounds = []
    for dataset in data:
        # ``dataset[ciColumn]`` is the whole column, consistent with mean_of_list;
        # ``np.array(dataset)[:, ciColumn]`` would instead take element ``ciColumn`` of
        # every column (and fails when columns have different lengths).
        column_data = np.asarray(dataset[ciColumn], dtype=float).ravel()
        if column_data.size == 0:
            # Nothing to resample; keep one entry per dataset so outputs stay aligned.
            lower_bounds.append(np.nan)
            upper_bounds.append(np.nan)
            continue
        bootstrap_means = [
            sampler.choice(column_data, size=column_data.size, replace=True).mean()
            for _ in range(num_bootstrap_samples)
        ]
        lower_bounds.append(np.percentile(bootstrap_means, 100 * alpha))
        upper_bounds.append(np.percentile(bootstrap_means, 100 * (1 - alpha)))
    return lower_bounds, upper_bounds

In [4]:
def replace_column(matrix, new_column, snr_index, column_index):
    """
    Overwrite, element by element, column ``column_index`` of ``matrix[snr_index]``.

    The existing column object is modified in place (so other references to it see the
    new values); ``matrix`` itself is also returned for convenience.

    Args:
        matrix (list of lists): Outer list of datasets; ``matrix[snr_index][column_index]``
            must be a mutable sequence supporting item assignment.
        new_column (sequence): Replacement values; must have the same length as the column.
        snr_index (int): Index of the dataset inside ``matrix`` to modify.
        column_index (int): Index of the column to overwrite within that dataset.

    Returns:
        list of lists: The same ``matrix`` object, with the column updated.

    Raises:
        ValueError: If ``column_index`` is out of range for ``matrix[snr_index]`` or the
            lengths differ.
        IndexError: If ``snr_index`` is out of range.
    """
    target_row = matrix[snr_index]
    # Validate against the row actually being edited, not matrix[0]: datasets may differ
    # in their number of columns.
    if column_index < 0 or column_index >= len(target_row):
        raise ValueError("Invalid column_index")
    target_column = target_row[column_index]
    if len(new_column) != len(target_column):
        raise ValueError("New column length must match the matrix height")
    # Element-wise assignment keeps the identity (and, for NumPy arrays, the dtype) of the
    # existing column object.
    for i, value in enumerate(new_column):
        target_column[i] = value
    return matrix

In [5]:
def flatten(data, column, column_sin):
    """
    Concatenate column ``column`` of every dataset into one flat list.

    Args:
        data (list): Datasets; ``dataset[column]`` must be iterable.
        column (int): Index of the column to gather from each dataset.
        column_sin (int): Ignored; kept so existing call sites keep working.

    Returns:
        list: All values of ``dataset[column]``, dataset by dataset, in their original order.
    """
    return [value for dataset in data for value in dataset[column]]

In [6]:
def normalise_decision_time(data,column_dt,column_dt_sindy,lcaddm=False):
    """
    Min-max normalise model and SINDy decision times using the model's range.

    The minimum and maximum are taken over ``column_dt`` of every dataset in ``data``
    (the SINDy column ``column_dt_sindy`` does not contribute), and both columns are
    rescaled with the same transform ``(t - min_dt) / (max_dt - min_dt)``. Model values
    therefore land in [0, 1]; SINDy values can fall outside that interval.

    ``data`` is modified in place (via ``replace_column``) and also returned.

    Args:
    data (list): One entry per dataset; ``data[i][column_dt]`` and
        ``data[i][column_dt_sindy]`` are equal-length sequences of decision times.
    column_dt (int): Index of the model decision-time column.
    column_dt_sindy (int): Index of the SINDy decision-time column.
    lcaddm (bool): If True, apply the hard-coded special case for dataset index 4,
        element index 9999 (see the comment in the loop).

    Returns:
    list: The same ``data`` object with both columns normalised.
    """
    # All model decision times across every dataset (the third argument is unused by flatten).
    model_dt_flat=flatten(data,column_dt,column_dt_sindy)
    max_dt=max(model_dt_flat)
    min_dt=min(model_dt_flat)
    # NOTE(review): if all model decision times are equal, max_dt - min_dt is 0 and the
    # divisions below fail (ZeroDivisionError for Python numbers, inf/nan for NumPy scalars).
    
    for i in range(len(data)):
        normalised=[]
        normalised_sindy=[]
        for c in range(len(data[i][column_dt])):
            if i==4 and c==9999 and lcaddm==True:
                # Hard-coded special case: element 9999 of dataset 4 is replaced by the
                # normalised MEAN of its column (for both model and SINDy). The means use the
                # original, not-yet-normalised values, because the columns are only
                # overwritten after this inner loop finishes.
                # NOTE(review): presumably patches an outlier trial in the LCA-DDM data — confirm.
                norm_model_times = (mean(data[i][column_dt])- min_dt) / (max_dt - min_dt)
                norm_model_times_sindy = (mean(data[i][column_dt_sindy]) - min_dt) / (max_dt - min_dt)
                normalised.append(norm_model_times)
                normalised_sindy.append(norm_model_times_sindy)
            else:
                norm_model_times = (data[i][column_dt][c] - min_dt) / (max_dt - min_dt)
                norm_model_times_sindy = (data[i][column_dt_sindy][c] - min_dt) / (max_dt - min_dt)
                normalised.append(norm_model_times)
                normalised_sindy.append(norm_model_times_sindy)
        #print(len(normalised))
        # Write the normalised values back into the existing column objects of dataset i.
        data=replace_column(data, normalised, i, column_dt)
        data=replace_column(data, normalised_sindy, i, column_dt_sindy)
    
    return data

In [7]:
def cliffs_delta(x, y):
    """
    Compute Cliff's delta, a non-parametric effect size comparing two samples.

        delta = (#{(i, j): x_i > y_j} - #{(i, j): x_i < y_j}) / (len(x) * len(y))

    The result lies in [-1, 1]. It is +1 when every value of ``x`` exceeds every value
    of ``y``, -1 in the opposite case, and 0 when neither sample tends to be larger.
    Ties count as neither.

    Args:
    x (array-like): First sample (numeric; flattened to 1-D).
    y (array-like): Second sample (numeric; flattened to 1-D).

    Returns:
    float: Cliff's delta, or NaN if either sample is empty.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    if x.size == 0 or y.size == 0:
        return float('nan')
    # Count dominance pairs against sorted y: O((n + m) log m) time and O(n + m) memory,
    # instead of materialising the n*m comparison matrix (samples can be ~10^4 trials).
    y_sorted = np.sort(y)
    n_y_smaller = np.searchsorted(y_sorted, x, side='left')           # #{j: y_j < x_i}
    n_y_larger = y.size - np.searchsorted(y_sorted, x, side='right')  # #{j: y_j > x_i}
    return float((n_y_smaller.sum() - n_y_larger.sum()) / (x.size * y.size))

In [8]:
def ks_statistical_analysis(data, column, column_sin, snr, no_column=False, data_sindy=None, alpha=0.05):
    """
    Two-sample Kolmogorov-Smirnov test plus Cliff's delta for every dataset.

    For each index ``i`` two samples are compared:
      * default: ``data[i][column]`` vs ``data[i][column_sin]``;
      * ``no_column=True``: ``data[i]`` vs ``data_sindy[i]``.

    Args:
        data (list): A list of datasets.
        column (int): Index of the first column to compare (ignored when ``no_column``).
        column_sin (int): Index of the second column to compare (ignored when ``no_column``).
        snr (list): Label for each dataset (e.g. its signal-to-noise ratio); reported for
            significant comparisons.
        no_column (bool): If True, each ``data[i]`` is itself a flat sample.
        data_sindy (list, optional): Second samples, required when ``no_column`` is True.
        alpha (float): Significance threshold applied to each p-value independently
            (no multiple-comparison correction).

    Returns:
        tuple: ``(d_stat, p, significant_points, effect_size)`` — KS statistics and
        p-values (NumPy arrays), a list holding ``snr[i]`` where ``p[i] < alpha`` and
        ``None`` elsewhere, and Cliff's delta per dataset (NumPy array).

    Raises:
        ValueError: If ``no_column`` is True but ``data_sindy`` is None.
    """
    if no_column and data_sindy is None:
        raise ValueError("data_sindy is required when no_column=True")

    n_sets = len(data)
    p = np.empty(n_sets)
    d_stat = np.empty(n_sets)  # KS statistic
    effect_size = np.empty(n_sets)  # Cliff's delta

    for i in range(n_sets):
        if no_column:
            group1, group2 = data[i], data_sindy[i]
        else:
            group1, group2 = data[i][column], data[i][column_sin]
        d_stat[i], p[i] = ks_2samp(group1, group2)
        effect_size[i] = cliffs_delta(group1, group2)

    significant_points = [snr[i] if p[i] < alpha else None for i in range(n_sets)]
    return d_stat, p, significant_points, effect_size

In [9]:
def mean_of_list(alist, column, no_column=False):
    """
    Mean of one column of each dataset (or of each whole dataset).

    Args:
        alist (list): Datasets to average.
        column (int): Column to average; ignored when ``no_column`` is true.
        no_column (bool): If true, each dataset is treated as a flat sequence of numbers.

    Returns:
        list: ``statistics.mean`` of the selected values, one entry per dataset.
    """
    select = (lambda dataset: dataset) if no_column else (lambda dataset: dataset[column])
    return [mean(select(dataset)) for dataset in alist]

In [10]:
def std_of_list(alist, column, no_column=False, ddof=0):
    """
    Standard deviation of one column of each dataset (or of each whole dataset).

    Args:
        alist (list): List of datasets.
        column (int): Index of the column to use; ignored when ``no_column`` is True.
        no_column (bool): If True, each dataset is treated as a flat sequence of numbers.
        ddof (int): Delta degrees of freedom passed to ``np.std``. The default 0 gives the
            population SD; use 1 for the sample SD.

    Returns:
        list: One float per dataset.
    """
    # NOTE(review): the body previously called a bare ``std`` that is not imported in this
    # notebook (only ``statistics.mean`` is). NumPy's ``np.std`` (population SD by default)
    # is assumed to be the intended function — confirm.
    def _column_std(values):
        return float(np.std(np.asarray(values, dtype=float), ddof=ddof))

    if no_column:
        return [_column_std(dataset) for dataset in alist]
    return [_column_std(dataset[column]) for dataset in alist]

In [11]:
def average_activity_multidimensional(data, trial_threshold=100):
    """
    Average activity of every population across samples, truncated at a common cutoff.

    Every trace is NaN-padded to the longest trace in ``data``. The time-wise mean over
    samples ignores padding. The result is cut after the last time index at which EVERY
    population still has at least ``trial_threshold`` contributing samples (the minimum of
    the per-population last valid indices). A population with no qualifying time point
    counts as index 0, so the output then keeps only the first time point.

    Args:
        data (list of list of sequences): ``data[sample][population]`` is a 1-D trace; every
            sample must contain the same number of populations.
        trial_threshold (int): Minimum number of samples required at a time point.

    Returns:
        list: One list of floats per population, all of the same (truncated) length.
    """
    max_length = max(len(population) for sample in data for population in sample)

    # Cast to float before padding: NaN cannot be stored in an integer array.
    padded = np.array([
        [np.pad(np.asarray(population, dtype=float), (0, max_length - len(population)),
                'constant', constant_values=np.nan)
         for population in sample]
        for sample in data
    ])  # shape (n_samples, n_populations, max_length)

    valid = ~np.isnan(padded)
    valid_counts = valid.sum(axis=0)  # (n_populations, max_length)
    sums = np.where(valid, padded, 0.0).sum(axis=0)
    # Mean over samples ignoring NaNs; time points without data stay NaN
    # (avoids nanmean's "Mean of empty slice" warning).
    average = np.full(sums.shape, np.nan)
    np.divide(sums, valid_counts, out=average, where=valid_counts > 0)

    # Per population: last time index with enough samples (0 if none); then the common cutoff.
    time_index = np.arange(valid_counts.shape[1])
    last_valid = np.where(valid_counts >= trial_threshold, time_index, 0).max(axis=1)
    cutoff = last_valid.min()
    return average[:, :cutoff + 1].tolist()

In [12]:
def average_activity(data, trial_threshold=100):
    """
    Time-wise mean of variable-length traces, truncated where too few traces remain.

    Traces are NaN-padded to the longest one and averaged ignoring the padding. The result
    ends at the last time index where at least ``trial_threshold`` traces contribute.

    Args:
        data (list of sequences): 1-D traces, possibly of different lengths.
        trial_threshold (int): Minimum number of traces required at a time point.

    Returns:
        numpy.ndarray: Mean activity from time 0 up to and including the last valid index.

    Raises:
        ValueError: If ``data`` is empty or no time point has ``trial_threshold`` traces.
    """
    max_length = max(len(item) for item in data)

    # Cast to float before padding: NaN cannot be stored in an integer array.
    padded = np.array([
        np.pad(np.asarray(item, dtype=float), (0, max_length - len(item)),
               'constant', constant_values=np.nan)
        for item in data
    ])

    valid_counts = np.sum(~np.isnan(padded), axis=0)
    qualifying = np.flatnonzero(valid_counts >= trial_threshold)
    if qualifying.size == 0:
        raise ValueError(f"no time point has at least {trial_threshold} valid trials")
    last_valid_index = qualifying[-1]

    # The longest trace reaches every time index, so no column in this window is all-NaN.
    return np.nanmean(padded[:, :last_valid_index + 1], axis=0)

# Figure 2 

In [None]:
#fig2. setup
# Statistics for Figure 2, computed on the single-trial st_* datasets. For every model:
#   * accuracy: columns 1 (model) and 2 (SINDy), scaled to percent;
#   * decision time: columns 3 (model) and 4 (SINDy), left unscaled;
#   * a KS test + Cliff's delta between each model/SINDy column pair.
# Every *_acc / *_time entry is [ci_low, ci_high, mean, std], each part as an ndarray.
# confidenceInterval ignores its third argument (10000).
x_ddm = SnR_ddm
ddm_acc = [np.array(part) * 100 for part in confidenceInterval(st_ddm, 1, 10000)]
ddm_accsin = [np.array(part) * 100 for part in confidenceInterval(st_ddm, 2, 10000)]
c_ddm_acc, p_ddm_acc, sig_points_acc_ddm, effect_st_ddm = ks_statistical_analysis(st_ddm, 1, 2, x_ddm)

ddm_time = [np.array(part) for part in confidenceInterval(st_ddm, 3, 10000)]
ddm_timesin = [np.array(part) for part in confidenceInterval(st_ddm, 4, 10000)]
c_ddm_time, p_ddm_time, sig_points_time_ddm, effect_st_ddm_t = ks_statistical_analysis(st_ddm, 3, 4, x_ddm)

x_lcaddm = SnR_lcaddm
lcaddm_acc = [np.array(part) * 100 for part in confidenceInterval(st_lcaddm, 1, 10000)]
lcaddm_accsin = [np.array(part) * 100 for part in confidenceInterval(st_lcaddm, 2, 10000)]
c_lcaddm_acc, p_lcaddm_acc, sig_points_acc_lcaddm, effect_st_lcaddm = ks_statistical_analysis(st_lcaddm, 1, 2, x_lcaddm)

lcaddm_time = [np.array(part) for part in confidenceInterval(st_lcaddm, 3, 10000)]
lcaddm_timesin = [np.array(part) for part in confidenceInterval(st_lcaddm, 4, 10000)]
c_lcaddm_time, p_lcaddm_time, sig_points_time_lcaddm, effect_st_lcaddm_t = ks_statistical_analysis(st_lcaddm, 3, 4, x_lcaddm)

# Polynomial-library variant of the LCA-DDM fit; shares the LCA-DDM SNR axis.
lcaddm_poly_acc = [np.array(part) * 100 for part in confidenceInterval(st_lcaddm_poly, 1, 10000)]
lcaddm_poly_accsin = [np.array(part) * 100 for part in confidenceInterval(st_lcaddm_poly, 2, 10000)]
c_lcaddm_poly_acc, p_lcaddm_poly_acc, sig_points_acc_lcaddm_poly, effect_st_lcaddm_poly = ks_statistical_analysis(st_lcaddm_poly, 1, 2, x_lcaddm)

lcaddm_poly_time = [np.array(part) for part in confidenceInterval(st_lcaddm_poly, 3, 10000)]
lcaddm_poly_timesin = [np.array(part) for part in confidenceInterval(st_lcaddm_poly, 4, 10000)]
c_lcaddm_poly_time, p_lcaddm_poly_time, sig_points_time_lcaddm_poly, effect_st_lcaddm_poly_t = ks_statistical_analysis(st_lcaddm_poly, 3, 4, x_lcaddm)

x_lca = SnR_lca
lca_acc = [np.array(part) * 100 for part in confidenceInterval(st_lca, 1, 10000)]
lca_accsin = [np.array(part) * 100 for part in confidenceInterval(st_lca, 2, 10000)]
c_lca_acc, p_lca_acc, sig_points_acc_lca, effect_st_lca = ks_statistical_analysis(st_lca, 1, 2, x_lca)

lca_time = [np.array(part) for part in confidenceInterval(st_lca, 3, 10000)]
lca_timesin = [np.array(part) for part in confidenceInterval(st_lca, 4, 10000)]
c_lca_time, p_lca_time, sig_points_time_lca, effect_st_lca_t = ks_statistical_analysis(st_lca, 3, 4, x_lca)

x_nlb = SnR_nlb
nlb_acc = [np.array(part) * 100 for part in confidenceInterval(st_nlb, 1, 10000)]
nlb_accsin = [np.array(part) * 100 for part in confidenceInterval(st_nlb, 2, 10000)]
c_nlb_acc, p_nlb_acc, sig_points_acc_nlb, effect_st_nlb = ks_statistical_analysis(st_nlb, 1, 2, x_nlb)

nlb_time = [np.array(part) for part in confidenceInterval(st_nlb, 3, 10000)]
nlb_timesin = [np.array(part) for part in confidenceInterval(st_nlb, 4, 10000)]
c_nlb_time, p_nlb_time, sig_points_time_nlb, effect_st_nlb_t = ks_statistical_analysis(st_nlb, 3, 4, x_nlb)

In [None]:
###fig 2 
# Figure 2: one row per model. Columns 0-1 show one example model trace against its SINDy
# reconstruction, column 2 choice accuracy vs SNR, column 3 normalised decision time vs SNR.
sns.set(font_scale=4)
sns.set_style('white') # darkgrid, white grid, dark, white and ticks
plt.rc('axes', titlesize=36)     # fontsize of the axes title
plt.rc('axes', labelsize=48)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=48)    # fontsize of the tick labels
plt.rc('ytick', labelsize=48)    # fontsize of the tick labels
#plt.rc('legend', fontsize=16)    # legend fontsize
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rc('font', size=48)          # controls default text sizes

# One dict per figure row. Every statistics entry ('acc', 'accsin', 'time', 'timesin', ...)
# is the [ci_low, ci_high, mean, std] list built in the setup cell. Only the LCA-DDM row
# ('B') carries the *_poly entries, which are drawn in gold.
data_sets = [
    {'letter': 'A', 'label': 'DDM', 'model': st_ddm, 'x': x_ddm,
     'acc': ddm_acc, 'accsin': ddm_accsin, 'time': ddm_time, 'timesin': ddm_timesin,
     'sig_points_acc': sig_points_acc_ddm, 'sig_points_time': sig_points_time_ddm,
     'effect_st_acc': effect_st_ddm, 'effect_st_time': effect_st_ddm_t},
    {'letter': 'B', 'label': 'LCA-DDM', 'model': st_lcaddm, 'x': x_lcaddm,
     'acc': lcaddm_acc, 'accsin': lcaddm_accsin, 'time': lcaddm_time, 'timesin': lcaddm_timesin,
     'sig_points_acc': sig_points_acc_lcaddm, 'sig_points_time': sig_points_time_lcaddm,
     'effect_st_acc': effect_st_lcaddm, 'effect_st_time': effect_st_lcaddm_t,
     'acc_poly': lcaddm_poly_acc, 'accsin_poly': lcaddm_poly_accsin,
     'time_poly': lcaddm_poly_time, 'timesin_poly': lcaddm_poly_timesin,
     'sig_points_acc_poly': sig_points_acc_lcaddm_poly, 'sig_points_time_poly': sig_points_time_lcaddm_poly,
     'effect_st_acc_poly': effect_st_lcaddm_poly, 'effect_st_time_poly': effect_st_lcaddm_poly_t},
    {'letter': 'C', 'label': 'LCA', 'model': st_lca, 'x': x_lca,
     'acc': lca_acc, 'accsin': lca_accsin, 'time': lca_time, 'timesin': lca_timesin,
     'sig_points_acc': sig_points_acc_lca, 'sig_points_time': sig_points_time_lca,
     'effect_st_acc': effect_st_lca, 'effect_st_time': effect_st_lca_t},
    {'letter': 'D', 'label': 'NLB', 'model': st_nlb, 'x': x_nlb,
     'acc': nlb_acc, 'accsin': nlb_accsin, 'time': nlb_time, 'timesin': nlb_timesin,
     'sig_points_acc': sig_points_acc_nlb, 'sig_points_time': sig_points_time_nlb,
     'effect_st_acc': effect_st_nlb, 'effect_st_time': effect_st_nlb_t},
]

# Define the figure and gridspec
fig = plt.figure(figsize=(35, 30))
gs = gridspec.GridSpec(4, 4, wspace=0.65)  # 4 rows (one per model) x 4 columns; the trace plot spans columns 0-1

for row, model_data in enumerate(data_sets):  # one figure row per model
    # Accessing data for the current model
    model_trials = model_data['model']
    x = model_data['x']
    acc = model_data['acc']
    accsin = model_data['accsin']
    # Named decision_time (not `time`) so this cell does not shadow the imported `time` module.
    decision_time = model_data['time']
    timesin = model_data['timesin']
    sig_points_acc = model_data['sig_points_acc']
    effect_st_acc = model_data['effect_st_acc']
    sig_points_time = model_data['sig_points_time']
    effect_st_time = model_data['effect_st_time']
    label = model_data['label']
    letter = model_data['letter']
    if letter == "B":
        acc_poly = model_data['acc_poly']
        acc_polysin = model_data['accsin_poly']
        time_poly = model_data['time_poly']
        timesin_poly = model_data['timesin_poly']
        sig_points_acc_poly = model_data['sig_points_acc_poly']
        effect_st_acc_poly = model_data['effect_st_acc_poly']
        sig_points_time_poly = model_data['sig_points_time_poly']
        effect_st_time_poly = model_data['effect_st_time_poly']

    # Example traces: columns 5 (model) and 6 (SINDy) of the SNR entry at index 40.
    # NOTE(review): the length tests below are heuristics that distinguish the models'
    # storage layouts (single 1-D trace per trial vs a pair of accumulator traces).
    if len(model_trials[0][5][0]) > 2:

        if len(model_trials[0][5]) > 100:
            a_test = np.arange(0, 10000, .1)
            # First plot: Trial Data Plot, taking double space
            ax0 = fig.add_subplot(gs[row, :2])  # Span first two columns
            ax0.plot(a_test[0:len(model_trials[40][5][125])], model_trials[40][5][125], color='blue', label="Model", linewidth=7)
            ax0.plot(a_test[0:len(model_trials[40][6][125])], model_trials[40][6][125], color='orange', label='SINDy', linewidth=7)
            ax0.axhline(y=1, linestyle='dashed', linewidth=7, color='black', label='Threshold')
            ax0.axhline(y=-1, linestyle='dashed', linewidth=7, color='black')
            ax0.set_xlim(0)
            ax0.set(ylabel="$X$")
            sns.despine()
            ax0.legend(loc='best', fontsize=24)
            ax0.spines['left'].set_linewidth(7)
            ax0.spines['bottom'].set_linewidth(7)
            ax0.set_title(letter, fontsize=48, fontweight='bold')
            ax0.title.set_position([-.1, 1.05])
        else:
            a_test = np.arange(0, 10000, .01)

            # First plot: Trial Data Plot, taking double space
            ax0 = fig.add_subplot(gs[row, :2])  # Span first two columns
            ax0.plot(a_test[0:len(model_trials[40][5][2])], model_trials[40][5][2], color='blue', label=label, linewidth=7)
            ax0.plot(a_test[0:len(model_trials[40][6][2])], model_trials[40][6][2], color='orange', label='SINDy', linewidth=7)
            ax0.axhline(y=.75, linestyle='dashed', linewidth=7, color='black', label='Threshold')
            ax0.axhline(y=-.75, linestyle='dashed', linewidth=7, color='black')
            #ax0.set_ylabel('Decision variable')
            ax0.set(ylabel="$X$")
            ax0.set_xlim(0)
            sns.despine()
#             ax0.legend(loc='best', fontsize=24)
            ax0.spines['left'].set_linewidth(7)
            ax0.spines['bottom'].set_linewidth(7)
            ax0.set_title(letter, fontsize=48, fontweight='bold')
            ax0.title.set_position([-.1, 1.05])
    else:
        # Two-accumulator layout: traces [0][0] and [0][1] are plotted as y_1 and y_2.
        # First plot: Trial Data Plot, taking double space
        a_test = np.arange(0, 10000, .01)
        ax0 = fig.add_subplot(gs[row, :2])  # Span first two columns
        ax0.plot(a_test[0:len(model_trials[40][5][0][0])], (model_trials[40][5][0][0]), color='blue', label=label, linewidth=7)
        ax0.plot(a_test[0:len(model_trials[40][5][0][1])], (model_trials[40][5][0][1]), color='blue', alpha=.5, linewidth=7)
        ax0.plot(a_test[0:len(model_trials[40][6][0][0])], (model_trials[40][6][0][0]), color='orange', label='SINDy', linewidth=7)
        ax0.plot(a_test[0:len(model_trials[40][6][0][1])], (model_trials[40][6][0][1]), color='orange', alpha=.5, linewidth=7)
        ax0.axhline(y=1, linestyle='dashed', linewidth=7, color='black', label='Threshold')
        # NOTE(review): x=-1 lies left of set_xlim(0) below, so this line is never visible in
        # the axes; it only contributes an 'Error' legend entry (legend currently disabled).
        ax0.axvline(x=-1, linestyle='dashed', linewidth=7, color='grey', label='Error')
        ax0.set(ylabel="$y_1$, $y_2$")
        ax0.set_xlim(0)
        sns.despine()
#         ax0.legend(loc='best', fontsize=24)
        ax0.spines['left'].set_linewidth(7)
        ax0.spines['bottom'].set_linewidth(7)
        ax0.set_title(letter, fontsize=48, fontweight='bold')
        ax0.title.set_position([-.1, 1.05])

    # Second plot: Accuracy Plot (mean line plus CI band; stars mark significant SNRs)
    ax1 = fig.add_subplot(gs[row, 2])
    ax1.plot(sig_points_acc, np.full(len(sig_points_acc), 40), '*', color='black')
    ax1.plot(x, acc[2], '', color='blue', label=label, linewidth=4)
    ax1.plot(x, acc[0], color='blue', linewidth=3)
    ax1.plot(x, acc[1], color='blue', linewidth=3)
    ax1.fill_between(x, acc[0], acc[1], color='blue', alpha=0.25)
    ax1.plot(x, accsin[2], '', color='orange', label="SINDy Model", linewidth=4)
    ax1.plot(x, accsin[0], color='orange', linewidth=3)
    ax1.plot(x, accsin[1], color='orange', linewidth=3)
    ax1.fill_between(x, accsin[0], accsin[1], color='orange', alpha=0.25)
    if letter == 'B':
        ax1.plot(x, acc_polysin[2], '', color='gold', label="SINDy Model")
        ax1.plot(x, acc_polysin[0], color='gold')
        ax1.plot(x, acc_polysin[1], color='gold')
        ax1.fill_between(x, acc_polysin[0], acc_polysin[1], color='gold', alpha=0.25)

    # Third plot: Response Time Plot (mean line plus CI band; stars mark significant SNRs)
    ax2 = fig.add_subplot(gs[row, 3])
    ax2.plot(sig_points_time, np.full(len(sig_points_time), 0), '*', color='black')
    ax2.plot(x, decision_time[2], '', color='blue', label="Model Response Time", linewidth=4)
    ax2.plot(x, decision_time[0], color='blue', linewidth=3)
    ax2.plot(x, decision_time[1], color='blue', linewidth=3)
    ax2.fill_between(x, decision_time[0], decision_time[1], color='blue', alpha=0.25)
    ax2.plot(x, timesin[2], '', color='orange', label="SINDy Response Time", linewidth=4)
    ax2.plot(x, timesin[0], color='orange', linewidth=3)
    ax2.plot(x, timesin[1], color='orange', linewidth=3)
    ax2.fill_between(x, timesin[0], timesin[1], color='orange', alpha=0.25)
    if letter == 'B':
        ax2.plot(x, timesin_poly[2], '', color='gold', label="SINDy Response Time")
        ax2.plot(x, timesin_poly[0], color='gold')
        ax2.plot(x, timesin_poly[1], color='gold')
        ax2.fill_between(x, timesin_poly[0], timesin_poly[1], color='gold', alpha=0.25)

    # Apply styles
    sns.despine(ax=ax1)
    sns.despine(ax=ax2)
    ax1.spines['left'].set_linewidth(7)
    ax1.spines['bottom'].set_linewidth(7)
    ax2.spines['left'].set_linewidth(7)
    ax2.spines['bottom'].set_linewidth(7)

# x-label only on the bottom row's trace plot (ax0 is the last row after the loop).
ax0.set_xlabel('Time (a.u.)')
# Row labels and shared axis labels are placed in figure coordinates.
fig.text(.85, .85, 'DDM', ha='center', va='center', zorder=50)
fig.text(.86, .65, 'LCA-DDM', ha='center', va='center', zorder=50)
fig.text(.85, .45, 'LCA', ha='center', va='center', zorder=50)
fig.text(.85, .25, 'NLB', ha='center', va='center', zorder=50)

fig.text(0.71, 0.525, 'Normalised decision time (a.u.)', ha='center', va='center', rotation='vertical', zorder=50)
fig.text(0.50, 0.525, 'Choice accuracy (%)', ha='center', va='center', rotation='vertical', zorder=50)
fig.text(.73, 0.09, 'Signal-to-noise ratio', ha='center', va='center', zorder=50)
fig.text(0.075, 0.525, 'Decision variable', ha='center', va='center', rotation='vertical', zorder=50)

plt.tight_layout()
#plt.savefig('fig_2_single_trial_02_05.pdf', dpi=600, bbox_inches='tight')

# Figure 3

In [None]:
###fig 3 data
# Statistics for Figure 3, computed on the ave_* datasets with the same recipe as Figure 2:
# accuracy columns 1 (model) / 2 (SINDy) scaled to percent, decision-time columns 3 / 4
# unscaled, plus KS tests between each pair. Each *_acc / *_time entry is
# [ci_low, ci_high, mean, std] as ndarrays.
# NOTE: this cell overwrites the Figure 2 variables of the same names.
x_ddm = SnR_ddm
ddm_acc = [np.array(part) * 100 for part in confidenceInterval(ave_ddm, 1, 10000)]
ddm_accsin = [np.array(part) * 100 for part in confidenceInterval(ave_ddm, 2, 10000)]
c_ddm_acc, p_ddm_acc, sig_points_acc_ddm, effect_ave_ddm = ks_statistical_analysis(ave_ddm, 1, 2, x_ddm)

ddm_time = [np.array(part) for part in confidenceInterval(ave_ddm, 3, 10000)]
ddm_timesin = [np.array(part) for part in confidenceInterval(ave_ddm, 4, 10000)]
c_ddm_time, p_ddm_time, sig_points_time_ddm, effect_ave_ddm_t = ks_statistical_analysis(ave_ddm, 3, 4, x_ddm)

x_lcaddm = SnR_lcaddm
lcaddm_acc = [np.array(part) * 100 for part in confidenceInterval(ave_lcaddm, 1, 10000)]
lcaddm_accsin = [np.array(part) * 100 for part in confidenceInterval(ave_lcaddm, 2, 10000)]
c_lcaddm_acc, p_lcaddm_acc, sig_points_acc_lcaddm, effect_ave_lcaddm = ks_statistical_analysis(ave_lcaddm, 1, 2, x_lcaddm)

lcaddm_time = [np.array(part) for part in confidenceInterval(ave_lcaddm, 3, 10000)]
lcaddm_timesin = [np.array(part) for part in confidenceInterval(ave_lcaddm, 4, 10000)]
c_lcaddm_time, p_lcaddm_time, sig_points_time_lcaddm, effect_ave_lcaddm_t = ks_statistical_analysis(ave_lcaddm, 3, 4, x_lcaddm)

x_lca = SnR_lca
lca_acc = [np.array(part) * 100 for part in confidenceInterval(ave_lca, 1, 10000)]
lca_accsin = [np.array(part) * 100 for part in confidenceInterval(ave_lca, 2, 10000)]
c_lca_acc, p_lca_acc, sig_points_acc_lca, effect_ave_lca = ks_statistical_analysis(ave_lca, 1, 2, x_lca)

lca_time = [np.array(part) for part in confidenceInterval(ave_lca, 3, 10000)]
lca_timesin = [np.array(part) for part in confidenceInterval(ave_lca, 4, 10000)]
c_lca_time, p_lca_time, sig_points_time_lca, effect_ave_lca_t = ks_statistical_analysis(ave_lca, 3, 4, x_lca)

x_nlb = SnR_nlb
nlb_acc = [np.array(part) * 100 for part in confidenceInterval(ave_nlb, 1, 10000)]
nlb_accsin = [np.array(part) * 100 for part in confidenceInterval(ave_nlb, 2, 10000)]
c_nlb_acc, p_nlb_acc, sig_points_acc_nlb, effect_ave_nlb = ks_statistical_analysis(ave_nlb, 1, 2, x_nlb)

nlb_time = [np.array(part) for part in confidenceInterval(ave_nlb, 3, 10000)]
nlb_timesin = [np.array(part) for part in confidenceInterval(ave_nlb, 4, 10000)]
c_nlb_time, p_nlb_time, sig_points_time_nlb, effect_ave_nlb_t = ks_statistical_analysis(ave_nlb, 3, 4, x_nlb)

In [None]:
#fig 3#

# Figure 3 styling: plain white seaborn background, large fonts, Arial sans-serif.
sns.set_style('white')  # options: darkgrid, whitegrid, dark, white, ticks
plt.rcParams.update({
    'axes.titlesize': 36,       # axes title
    'axes.labelsize': 48,       # x and y labels
    'xtick.labelsize': 40,      # x tick labels
    'ytick.labelsize': 40,      # y tick labels
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 48,            # default text size
})

data_sets = [
{'label':'DDM','x': x_ddm,'acc': ddm_acc,'accsin': ddm_accsin,'time': ddm_time,'timesin': ddm_timesin,'sig_points_acc': sig_points_acc_ddm,'sig_points_time': sig_points_time_ddm,'effect_ave_acc': effect_ave_ddm,'effect_ave_time': effect_ave_ddm_t},
{'label':'LCA-DDM', 'x': x_lcaddm,'acc': lcaddm_acc,'accsin': lcaddm_accsin,'time': lcaddm_time,'timesin': lcaddm_timesin,'sig_points_acc': sig_points_acc_lcaddm,'sig_points_time': sig_points_time_lcaddm,'effect_ave_acc': effect_ave_lcaddm,'effect_ave_time': effect_ave_lcaddm_t},
{'label':'LCA', 'x': x_lca,'acc': lca_acc,'accsin': lca_accsin,'time': lca_time,'timesin': lca_timesin,'sig_points_acc': sig_points_acc_lca,'sig_points_time': sig_points_time_lca,'effect_ave_acc': effect_ave_lca,'effect_ave_time': effect_ave_lca_t},
# {'label':'DDM polynomial 1','x': x_ddm,'acc': ddm_acc_poly,'accsin': ddm_accsin_poly,'time': ddm_time_poly,'timesin': ddm_timesin_poly,'sig_points_acc': sig_points_acc_ddm_poly,'sig_points_time': sig_points_time_ddm_poly,'effect_ave_acc': effect_ave_ddm_poly,'effect_ave_time': effect_ave_ddm_t_poly},
# {'label':'LCA-DDM poly 0 ', 'x': x_lcaddm,'acc': lcaddm_acc_poly,'accsin': lcaddm_accsin_poly,'time': lcaddm_time_poly,'timesin': lcaddm_timesin_poly,'sig_points_acc': sig_points_acc_lcaddm_poly,'sig_points_time': sig_points_time_lcaddm_poly,'effect_ave_acc': effect_ave_lcaddm_poly,'effect_ave_time': effect_ave_lcaddm_t},
# {'label':'LCA poly 0 ', 'x': x_lca,'acc': lca_acc_poly,'accsin': lca_accsin_poly,'time': lca_time_poly,'timesin': lca_timesin_poly,'sig_points_acc': sig_points_acc_lca_poly,'sig_points_time': sig_points_time_lca_poly,'effect_ave_acc': effect_ave_lca_poly,'effect_ave_time': effect_ave_lca_t_poly},
{'label':'NLB','x': x_nlb,'acc': nlb_acc,'accsin': nlb_accsin,'time': nlb_time,'timesin': nlb_timesin,'sig_points_acc': sig_points_acc_nlb,'sig_points_time': sig_points_time_nlb,'effect_ave_acc': effect_ave_nlb,'effect_ave_time': effect_ave_nlb_t}

    # {'label':'DDM poly 2','model':average_ddm_poly2,'x': x_ddm,'acc': ddm_acc_poly2,'accsin': ddm_accsin_poly2,'time': ddm_time_poly2,'timesin': ddm_timesin_poly2,'sig_points_acc': sig_points_acc_ddm_poly2,'sig_points_time': sig_points_time_ddm_poly2,'effect_ave_acc': effect_ave_ddm_poly2,'effect_ave_time': effect_ave_ddm_t_poly2},
# {'label':'LCA-DDM poly 2','model':average_lcaddm_poly2, 'x': x_lcaddm,'acc': lcaddm_acc_poly2,'accsin': lcaddm_accsin_poly2,'time': lcaddm_time_poly2,'timesin': lcaddm_timesin_poly2,'sig_points_acc': sig_points_acc_lcaddm_poly2,'sig_points_time': sig_points_time_lcaddm_poly2,'effect_ave_acc': effect_ave_lcaddm_poly2,'effect_ave_time': effect_ave_lcaddm_t_poly2},
# {'label':'LCA poly 2','model':average_lca_poly2, 'x': x_lca,'acc': lca_acc_poly2,'accsin': lca_accsin_poly2,'time': lca_time_poly2,'timesin': lca_timesin_poly2,'sig_points_acc': sig_points_acc_lca_poly2,'sig_points_time': sig_points_time_lca_poly2,'effect_ave_acc': effect_ave_lca_poly2,'effect_ave_time': effect_ave_lca_t_poly2}
] 

# Define your figure and gridspec
fig = plt.figure(figsize=(30, 25))  # Adjust the figure size as needed
gs = gridspec.GridSpec(4, 2, wspace=0.25)  # Adjusted for 3 model groups, 2 plots (Accuracy, Response Time) each

# Helper: draw one bootstrap CI band (mean line, bound lines, shaded interval) on an axis.
def _plot_ci_band(ax, x, ci, color, label=None):
    """Plot one CI band.

    Args:
        ax: matplotlib Axes to draw on.
        x: x-axis values.
        ci: sequence laid out as returned by confidenceInterval:
            ci[0] = lower bound, ci[1] = upper bound, ci[2] = mean.
        color: colour for lines and fill.
        label: legend label for the mean line; None leaves it out of the legend.
    """
    ax.plot(x, ci[2], '', linewidth=4, color=color, label=label)
    ax.plot(x, ci[0], linewidth=3, color=color)
    ax.plot(x, ci[1], linewidth=3, color=color)
    ax.fill_between(x, ci[0], ci[1], linewidth=5, color=color, alpha=0.25)


def plot_for_model_group(ax_acc, ax_time, model_group_data, x, label_prefix, letter):
    """Plot model vs. SINDy accuracy and decision-time CI bands for one model group.

    Args:
        ax_acc: Axes for choice accuracy.
        ax_time: Axes for decision time.
        model_group_data: list of dicts with keys 'label', 'acc', 'accsin', 'time',
            'timesin', 'sig_points_acc', 'sig_points_time'. CI entries use the
            confidenceInterval layout [low, high, mean, std].
        x: shared x-axis values for the CI curves.
        label_prefix: model-group name; the accuracy legend is drawn only for "DDM".
        letter: panel letter, drawn as the accuracy axis title.
    """
    for model_data in model_group_data:
        # Accuracy: black stars at y=40 mark the significant points from ks_statistical_analysis.
        sig_acc = model_data['sig_points_acc']
        ax_acc.plot(sig_acc, np.full(len(sig_acc), 40), '*', color='black')
        _plot_ci_band(ax_acc, x, model_data['acc'], 'blue', label="Model")
        _plot_ci_band(ax_acc, x, model_data['accsin'], 'orange', label="SINDy")
        ax_acc.set_title(letter, fontsize=48, fontweight='bold')
        ax_acc.title.set_position([-.15, 1.05])

        # Decision time: stars at y=0. The y array is sized from sig_points_time itself
        # so the x and y lengths always match.
        sig_time = model_data['sig_points_time']
        ax_time.plot(sig_time, np.full(len(sig_time), 0), '*', color='black')
        _plot_ci_band(ax_time, x, model_data['time'], 'blue')
        _plot_ci_band(ax_time, x, model_data['timesin'], 'orange',
                      label=f"{model_data['label']} SINDy Time")

    # Only the DDM row carries the accuracy legend, so it is not repeated on every row.
    if label_prefix == "DDM":
        ax_acc.legend(loc='best', fontsize=24)

# Split data_sets into one row per model family. The label tests keep LCA-DDM out of the DDM and LCA rows.
ddm_models = [entry for entry in data_sets if 'DDM' in entry['label'] and 'LCA' not in entry['label']]
lca_ddm_models = [entry for entry in data_sets if 'LCA-DDM' in entry['label']]
lca_models = [entry for entry in data_sets if 'LCA' in entry['label'] and 'DDM' not in entry['label']]
nlb_models = [entry for entry in data_sets if 'NLB' in entry['label']]

# Row order in the figure, with the panel letter used as each row's accuracy title.
model_groups = [('DDM', ddm_models), ('LCA-DDM', lca_ddm_models), ('LCA', lca_models), ('NLB', nlb_models)]
letter = ['A', 'B', 'C', 'D']
for i, (label_prefix, model_group) in enumerate(model_groups):
    ax_acc = fig.add_subplot(gs[i, 0])
    ax_time = fig.add_subplot(gs[i, 1])
    # Every model in a group shares the x values of the group's first entry.
    plot_for_model_group(ax_acc, ax_time, model_group, model_group[0]['x'], label_prefix, letter[i])

    # Open-frame style: remove top/right spines and thicken the remaining ones.
    for ax in (ax_acc, ax_time):
        sns.despine(ax=ax)
        for side in ('left', 'bottom'):
            ax.spines[side].set_linewidth(7)

# Row names on the right-hand side of the figure.
for y_pos, row_name in ((.85, 'DDM'), (.65, 'LCA-DDM'), (.45, 'LCA'), (.25, 'NLB')):
    fig.text(.85, y_pos, row_name, ha='center', va='center', zorder=50)

# Shared axis labels, placed in figure coordinates.
fig.text(0.505, 0.525, 'Normalised decision time (a.u.)', ha='center', va='center', rotation='vertical', zorder=50)
fig.text(0.045, 0.525, 'Choice accuracy (%)', ha='center', va='center', rotation='vertical', zorder=50)
fig.text(0.49, 0.08, 'Signal-to-noise ratio', ha='center', va='center', zorder=50)
plt.tight_layout()
# plt.savefig("fig_average_choice_behaviour_29_04.pdf", dpi=600, bbox_inches='tight')

# Figure 4

In [None]:
## Fig 4. Average trial activity: model vs. SINDy decision variable at four SNR levels.

sns.set_style('white')  # darkgrid, whitegrid, dark, white and ticks
plt.rc('axes', titlesize=36)     # fontsize of the axes title
plt.rc('axes', labelsize=36)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=42)    # fontsize of the tick labels
plt.rc('ytick', labelsize=42)    # fontsize of the tick labels
plt.rc('legend', fontsize=24)    # legend fontsize
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rc('font', size=40)

# Indices into each model's result list for the Zero/Low/Medium/High SNR panels.
ddm_samples = [0, 9, 17, 26]
lcaddm_samples = [0, 3, 5, 8]
lca_samples = [0, 2, 5, 7]
nlb_samples = [0, 6, 12, 18]

samples = [ddm_samples, lcaddm_samples, lca_samples, nlb_samples]
sample_titles = ["Zero", "Low", "Medium", "High"]

# Row titles and y-axis labels, one per model (same order as `samples`).
model_titles = ['DDM', 'LCA-DDM', 'LCA', 'NLB']
model_axis = ['$X$', '$y_1$, $y_2$', '$y_1$, $y_2$', '$X_1$']

# Per-model settings: (time-grid step, averaging function, trial_threshold passed to it).
activity_settings = {
    'DDM': (.1, average_activity, 50),
    'LCA-DDM': (.01, average_activity_multidimensional, 24),
    'LCA': (.01, average_activity_multidimensional, 50),
    'NLB': (.01, average_activity, 10),
}


def _plot_activity_panel(ax, model_title, time_t, ave_act, ave_sin):
    """Draw the averaged model and SINDy traces plus decision thresholds on one panel.

    Args:
        ax: matplotlib Axes for this (model, SNR sample) panel.
        model_title: one of 'DDM', 'LCA-DDM', 'LCA', 'NLB'; selects the trace layout,
            the thresholds and the y-limits.
        time_t: time grid; each trace is plotted against its first len(trace) points.
        ave_act: averaged model activity (indexed [0] / [1] for LCA and LCA-DDM).
        ave_sin: averaged SINDy activity, same layout as ave_act.
    """
    if model_title in ('LCA-DDM', 'LCA'):
        # Two traces per fit (index 0 solid, index 1 faded and thicker).
        ax.plot(time_t[:len(ave_act[0])], ave_act[0], color="blue", linewidth=3, label=model_title)
        ax.plot(time_t[:len(ave_act[1])], ave_act[1], color="blue", alpha=.5, linewidth=5)
        ax.plot(time_t[:len(ave_sin[0])], ave_sin[0], color="orange", linewidth=3, label='SINDy')
        ax.plot(time_t[:len(ave_sin[1])], ave_sin[1], color="orange", alpha=.5, linewidth=5)
        ax.set_ylim(-1, 1)
        ax.axhline(y=1, linestyle='dashed', linewidth=5, color='black', label='Threshold')
    elif model_title == 'NLB':
        ax.plot(time_t[:len(ave_act)], ave_act, color="blue", linewidth=5, label=model_title)
        ax.plot(time_t[:len(ave_sin)], ave_sin, color="orange", linewidth=5, label='SINDy')
        ax.axhline(y=.75, linestyle='dashed', linewidth=5, color='black', label='Threshold')
        ax.axhline(y=-.75, linestyle='dashed', linewidth=5, color='black', label='')
        ax.set_ylim(-.85, .75)
    else:  # DDM
        ax.plot(time_t[:len(ave_act)], ave_act, color="blue", linewidth=5, label="Model")
        ax.plot(time_t[:len(ave_sin)], ave_sin, color="orange", linewidth=5, label='SINDy')
        ax.axhline(y=1, linestyle='dashed', linewidth=5, color='black', label='Threshold')
        ax.axhline(y=-1, linestyle='dashed', linewidth=5, color='black', label='')
        ax.set_ylim(-1.1, 1)


# squeeze=False keeps axs 2-D, so axs[row, col] is always valid.
fig, axs = plt.subplots(nrows=len(model_titles), ncols=len(sample_titles),
                        figsize=(30, 25), squeeze=False)

# To plot the alternative datasets instead, use [ave_ddm_nt, ave_lcaddm_nt, ave_lca_nt, ave_nt_nlb].
for model_index, model_data in enumerate([ave_ddm, ave_lcaddm, ave_lca, ave_nlb]):
    model_title = model_titles[model_index]
    step, averager, trial_threshold = activity_settings[model_title]
    time_t = np.arange(0, 10000, step)  # built once per model rather than once per panel

    for sample_index, sample in enumerate(samples[model_index]):
        ax = axs[model_index, sample_index]
        # Element 5 feeds the model trace and element 6 the SINDy trace.
        ave_act = averager(model_data[sample][5], trial_threshold=trial_threshold)
        ave_sin = averager(model_data[sample][6], trial_threshold=trial_threshold)
        _plot_activity_panel(ax, model_title, time_t, ave_act, ave_sin)

        # A single legend, on the top-left panel only.
        if sample_index == 0 and model_index == 0:
            ax.legend(loc='best', fontsize=30)

        # Column titles on the first row.
        if model_index == 0:
            ax.set_title(sample_titles[sample_index])
            if sample_titles[sample_index] in ("Medium", "High"):
                # NOTE(review): the -.85 lower x-limit mirrors the NLB y-limit; confirm 0 was not intended.
                ax.set_xlim(-.85, 75)

        # Row labels on the first column.
        if sample_index == 0:
            ax.set_ylabel(model_axis[model_index])

        sns.despine(ax=ax)
        ax.spines['left'].set_linewidth(7)
        ax.spines['bottom'].set_linewidth(7)

fig.text(.025, .95, 'A', ha='center', va='center', zorder=50, fontsize=48, fontweight='bold')
fig.text(.025, .71, 'B', ha='center', va='center', zorder=50, fontsize=48, fontweight='bold')
fig.text(.025, .48, 'C', ha='center', va='center', zorder=50, fontsize=48, fontweight='bold')
fig.text(.025, .255, 'D', ha='center', va='center', zorder=50, fontsize=48, fontweight='bold')
fig.text(.85, .81, 'DDM', ha='center', va='center', zorder=50)
fig.text(.86, .575, 'LCA-DDM', ha='center', va='center', zorder=50)
fig.text(.85, .35, 'LCA', ha='center', va='center', zorder=50)
fig.text(.85, .12, 'NLB', ha='center', va='center', zorder=50)
fig.text(0.0, 0.525, 'Decision Variable', ha='center', va='center', rotation='vertical', fontsize=48, zorder=50)
fig.supxlabel('Time (a.u.)', fontsize=48)
plt.tight_layout()
# plt.savefig("fig_average_activity_02_05.pdf", dpi=600, bbox_inches='tight')

# Supplementary

In [None]:
# Supplementary figure data: bootstrap CIs and KS comparisons for alternative SINDy polynomial orders.

# (model column, SINDy column) pairs in each result dataset.
_ACC_COLUMNS = (1, 2)   # choice accuracy  -> *_acc / *_accsin
_TIME_COLUMNS = (3, 4)  # decision time    -> *_time / *_timesin


def _model_vs_sindy_ci(ave, x, model_col, sindy_col, percent=False):
    """Bootstrap CIs for a model column and its SINDy counterpart, plus their KS comparison.

    Taking both columns from one pair ensures the CIs and the KS test always use
    the same columns. Calls happen in the order CI(model), CI(SINDy), KS, so any
    random-number consumption matches the previous per-line code.

    Args:
        ave: list of result datasets, as accepted by confidenceInterval.
        x: x-axis values passed to ks_statistical_analysis.
        model_col: column index of the original model's metric.
        sindy_col: column index of the SINDy model's metric.
        percent: if True, multiply every CI array by 100 (accuracy in %).

    Returns:
        (model_ci, sindy_ci, ks_result). Each *_ci is a list of numpy arrays
        [ci_low, ci_high, mean, std], as returned by confidenceInterval.
        ks_result is whatever ks_statistical_analysis returns; callers unpack it
        into (c, p, sig_points, effect).
    """
    def _to_arrays(ci):
        arrays = [np.array(part) for part in ci]
        return [arr * 100 for arr in arrays] if percent else arrays

    model_ci = _to_arrays(confidenceInterval(ave, model_col, 10000))
    sindy_ci = _to_arrays(confidenceInterval(ave, sindy_col, 10000))
    ks_result = ks_statistical_analysis(ave, model_col, sindy_col, x)
    return model_ci, sindy_ci, ks_result


# DDM, polynomial order 1
(ddm_acc_poly, ddm_accsin_poly,
 (c_ddm_acc_poly, p_ddm_acc_poly, sig_points_acc_ddm_poly, effect_ave_ddm_poly)
 ) = _model_vs_sindy_ci(ave_ddm_poly1, x_ddm, *_ACC_COLUMNS, percent=True)
(ddm_time_poly, ddm_timesin_poly,
 (c_ddm_time_poly, p_ddm_time_poly, sig_points_time_ddm_poly, effect_ave_ddm_t_poly)
 ) = _model_vs_sindy_ci(ave_ddm_poly1, x_ddm, *_TIME_COLUMNS)

# DDM, polynomial order 2 (ddm_timesin_poly2 now reads the SINDy column 4, not the model column 3)
(ddm_acc_poly2, ddm_accsin_poly2,
 (c_ddm_acc_poly2, p_ddm_acc_poly2, sig_points_acc_ddm_poly2, effect_ave_ddm_poly2)
 ) = _model_vs_sindy_ci(ave_ddm_poly2, x_ddm, *_ACC_COLUMNS, percent=True)
(ddm_time_poly2, ddm_timesin_poly2,
 (c_ddm_time_poly2, p_ddm_time_poly2, sig_points_time_ddm_poly2, effect_ave_ddm_t_poly2)
 ) = _model_vs_sindy_ci(ave_ddm_poly2, x_ddm, *_TIME_COLUMNS)

# LCA-DDM, polynomial order 0
(lcaddm_acc_poly, lcaddm_accsin_poly,
 (c_lcaddm_acc_poly, p_lcaddm_acc_poly, sig_points_acc_lcaddm_poly, effect_ave_lcaddm_poly)
 ) = _model_vs_sindy_ci(ave_lcaddm_poly0, x_lcaddm, *_ACC_COLUMNS, percent=True)
(lcaddm_time_poly, lcaddm_timesin_poly,
 (c_lcaddm_time_poly, p_lcaddm_time_poly, sig_points_time_lcaddm_poly, effect_ave_lcaddm_t_poly)
 ) = _model_vs_sindy_ci(ave_lcaddm_poly0, x_lcaddm, *_TIME_COLUMNS)

# LCA-DDM, polynomial order 2
(lcaddm_acc_poly2, lcaddm_accsin_poly2,
 (c_lcaddm_acc_poly2, p_lcaddm_acc_poly2, sig_points_acc_lcaddm_poly2, effect_ave_lcaddm_poly2)
 ) = _model_vs_sindy_ci(ave_lcaddm_poly2, x_lcaddm, *_ACC_COLUMNS, percent=True)
(lcaddm_time_poly2, lcaddm_timesin_poly2,
 (c_lcaddm_time_poly2, p_lcaddm_time_poly2, sig_points_time_lcaddm_poly2, effect_ave_lcaddm_t_poly2)
 ) = _model_vs_sindy_ci(ave_lcaddm_poly2, x_lcaddm, *_TIME_COLUMNS)

# LCA, polynomial order 0
(lca_acc_poly, lca_accsin_poly,
 (c_lca_acc_poly, p_lca_acc_poly, sig_points_acc_lca_poly, effect_ave_lca_poly)
 ) = _model_vs_sindy_ci(ave_lca_poly0, x_lca, *_ACC_COLUMNS, percent=True)
(lca_time_poly, lca_timesin_poly,
 (c_lca_time_poly, p_lca_time_poly, sig_points_time_lca_poly, effect_ave_lca_t_poly)
 ) = _model_vs_sindy_ci(ave_lca_poly0, x_lca, *_TIME_COLUMNS)

# LCA, polynomial order 2
(lca_acc_poly2, lca_accsin_poly2,
 (c_lca_acc_poly2, p_lca_acc_poly2, sig_points_acc_lca_poly2, effect_ave_lca_poly2)
 ) = _model_vs_sindy_ci(ave_lca_poly2, x_lca, *_ACC_COLUMNS, percent=True)
(lca_time_poly2, lca_timesin_poly2,
 (c_lca_time_poly2, p_lca_time_poly2, sig_points_time_lca_poly2, effect_ave_lca_t_poly2)
 ) = _model_vs_sindy_ci(ave_lca_poly2, x_lca, *_TIME_COLUMNS)

# NLB, polynomial order 4
(nlb_acc_poly, nlb_accsin_poly,
 (c_nlb_acc_poly, p_nlb_acc_poly, sig_points_acc_nlb_poly, effect_ave_nlb_poly)
 ) = _model_vs_sindy_ci(ave_nlb_poly4, x_nlb, *_ACC_COLUMNS, percent=True)
(nlb_time_poly, nlb_timesin_poly,
 (c_nlb_time_poly, p_nlb_time_poly, sig_points_time_nlb_poly, effect_ave_nlb_t_poly)
 ) = _model_vs_sindy_ci(ave_nlb_poly4, x_nlb, *_TIME_COLUMNS)

# NLB, polynomial order 6
(nlb_acc_poly6, nlb_accsin_poly6,
 (c_nlb_acc_poly6, p_nlb_acc_poly6, sig_points_acc_nlb_poly6, effect_ave_nlb_poly6)
 ) = _model_vs_sindy_ci(ave_nlb_poly6, x_nlb, *_ACC_COLUMNS, percent=True)
(nlb_time_poly6, nlb_timesin_poly6,
 (c_nlb_time_poly6, p_nlb_time_poly6, sig_points_time_nlb_poly6, effect_ave_nlb_t_poly6)
 ) = _model_vs_sindy_ci(ave_nlb_poly6, x_nlb, *_TIME_COLUMNS)

In [None]:
# Supplementary figure: model vs. SINDy choice behaviour for alternative SINDy polynomial orders.
sns.set(font_scale=4)
sns.set_style('white')  # darkgrid, whitegrid, dark, white and ticks
plt.rc('axes', titlesize=36)     # fontsize of the axes title
plt.rc('axes', labelsize=48)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=48)    # fontsize of the tick labels
plt.rc('ytick', labelsize=48)    # fontsize of the tick labels
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rc('font', size=48)          # controls default text sizes

data_sets1 = [
{'label':'DDM Poly 1','x': x_ddm,'acc': ddm_acc_poly,'accsin': ddm_accsin_poly,'time': ddm_time_poly,'timesin': ddm_timesin_poly,'sig_points_acc': sig_points_acc_ddm_poly,'sig_points_time': sig_points_time_ddm_poly,'effect_ave_acc': effect_ave_ddm_poly,'effect_ave_time': effect_ave_ddm_t_poly},
{'label':'DDM Poly 2','x': x_ddm,'acc': ddm_acc_poly2,'accsin': ddm_accsin_poly2,'time': ddm_time_poly2,'timesin': ddm_timesin_poly2,'sig_points_acc': sig_points_acc_ddm_poly2,'sig_points_time': sig_points_time_ddm_poly2,'effect_ave_acc': effect_ave_ddm_poly2,'effect_ave_time': effect_ave_ddm_t_poly2},
{'label':'LCA-DDM Poly 0', 'x': x_lcaddm,'acc': lcaddm_acc_poly,'accsin': lcaddm_accsin_poly,'time': lcaddm_time_poly,'timesin': lcaddm_timesin_poly,'sig_points_acc': sig_points_acc_lcaddm_poly,'sig_points_time': sig_points_time_lcaddm_poly,'effect_ave_acc': effect_ave_lcaddm_poly,'effect_ave_time': effect_ave_lcaddm_t_poly},
{'label':'LCA-DDM Poly 2', 'x': x_lcaddm,'acc': lcaddm_acc_poly2,'accsin': lcaddm_accsin_poly2,'time': lcaddm_time_poly2,'timesin': lcaddm_timesin_poly2,'sig_points_acc': sig_points_acc_lcaddm_poly2,'sig_points_time': sig_points_time_lcaddm_poly2,'effect_ave_acc': effect_ave_lcaddm_poly2,'effect_ave_time': effect_ave_lcaddm_t_poly2},
{'label':'LCA Poly 0', 'x': x_lca,'acc': lca_acc_poly,'accsin': lca_accsin_poly,'time': lca_time_poly,'timesin': lca_timesin_poly,'sig_points_acc': sig_points_acc_lca_poly,'sig_points_time': sig_points_time_lca_poly,'effect_ave_acc': effect_ave_lca_poly,'effect_ave_time': effect_ave_lca_t_poly},
{'label':'LCA Poly 2', 'x': x_lca,'acc': lca_acc_poly2,'accsin': lca_accsin_poly2,'time': lca_time_poly2,'timesin': lca_timesin_poly2,'sig_points_acc': sig_points_acc_lca_poly2,'sig_points_time': sig_points_time_lca_poly2,'effect_ave_acc': effect_ave_lca_poly2,'effect_ave_time': effect_ave_lca_t_poly2},
{'label':'NLB Poly 4','x': x_nlb,'acc': nlb_acc_poly,'accsin': nlb_accsin_poly,'time': nlb_time_poly,'timesin': nlb_timesin_poly,'sig_points_acc': sig_points_acc_nlb_poly,'sig_points_time': sig_points_time_nlb_poly,'effect_ave_acc': effect_ave_nlb_poly,'effect_ave_time': effect_ave_nlb_t_poly},
{'label':'NLB Poly 6','x': x_nlb,'acc': nlb_acc_poly6,'accsin': nlb_accsin_poly6,'time': nlb_time_poly6,'timesin': nlb_timesin_poly6,'sig_points_acc': sig_points_acc_nlb_poly6,'sig_points_time': sig_points_time_nlb_poly6,'effect_ave_acc': effect_ave_nlb_poly6,'effect_ave_time': effect_ave_nlb_t_poly6}]

# One row per dataset: accuracy in column 0, decision time in column 1.
fig = plt.figure(figsize=(35, 35))
gs = gridspec.GridSpec(len(data_sets1), 2, wspace=0.15, hspace=.65)

# Significance markers (sig_points_*) are not drawn in this figure.
# Loop names avoid `time`, which would shadow the `time` module imported at the top of the notebook.
for row, model_data in enumerate(data_sets1):
    x = model_data['x']

    # Accuracy: mean line plus shaded [low, high] CI for model (blue) and SINDy (orange).
    ax1 = fig.add_subplot(gs[row, 0])
    for ci, color, line_label in ((model_data['acc'], 'blue', "Model"),
                                  (model_data['accsin'], 'orange', "SINDy")):
        ax1.plot(x, ci[2], '', color=color, label=line_label, linewidth=4)
        ax1.plot(x, ci[0], color=color, linewidth=3)
        ax1.plot(x, ci[1], color=color, linewidth=3)
        ax1.fill_between(x, ci[0], ci[1], color=color, alpha=0.25)

    # Decision time: same layout.
    ax2 = fig.add_subplot(gs[row, 1])
    for ci, color, line_label in ((model_data['time'], 'blue', "Model Response Time"),
                                  (model_data['timesin'], 'orange', "SINDy Response Time")):
        ax2.plot(x, ci[2], '', color=color, label=line_label, linewidth=4)
        ax2.plot(x, ci[0], color=color, linewidth=3)
        ax2.plot(x, ci[1], color=color, linewidth=3)
        ax2.fill_between(x, ci[0], ci[1], color=color, alpha=0.25)

    for ax in (ax1, ax2):
        sns.despine(ax=ax)
        ax.spines['left'].set_linewidth(7)
        ax.spines['bottom'].set_linewidth(7)

    # A single legend, above the first accuracy panel.
    if row == 0:
        ax1.legend(loc=[0.5, 1], fontsize=26)

# Shared axis labels.
fig.text(0.5, 0.1, 'Signal-to-noise ratio', ha='center', va='center')
fig.text(0.084, 0.5, 'Choice accuracy (%)', ha='center', va='center', rotation='vertical')
fig.text(0.5, 0.5, 'Decision time (a.u.)', ha='center', va='center', rotation='vertical')

# Row labels (hand-tuned y positions).
fig.text(.875, .85, 'DDM Poly 1', ha='center', va='center', zorder=50)
fig.text(.875, .755, 'DDM Poly 2 ', ha='center', va='center', zorder=50)
fig.text(.875, .665, 'LCA-DDM Poly 0 ', ha='center', va='center', zorder=50)
fig.text(.875, .575, 'LCA-DDM Poly 2 ', ha='center', va='center', zorder=50)
fig.text(.875, .465, 'LCA Poly 0', ha='center', va='center', zorder=50)
fig.text(.875, .351, 'LCA Poly 2', ha='center', va='center', zorder=50)
fig.text(.875, .256, 'NLB Poly 4', ha='center', va='center', zorder=50)
fig.text(.875, .156, 'NLB Poly 6', ha='center', va='center', zorder=50)
# plt.savefig("fig_supp_choice_behaviour.tiff", dpi=600, bbox_inches='tight')