In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
from scipy.io import loadmat
from sklearn.metrics import mean_squared_error
import pysindy as ps
import pandas as pd
from matplotlib import cm
import time
import pickle
import seaborn as sns
from statistics import mean 
from scipy.stats import ks_2samp
import matplotlib.gridspec as gridspec
from IPython.display import display, HTML
# Inject CSS so the notebook's .container spans 98% of the browser width.
display(HTML("<style>.container { width:98% !important; }</style>"))

In [2]:
def confidenceInterval(alist, ciColumn, trials):
    """
    Summarise one column of every dataset: bootstrap CI bounds, mean and standard deviation.

    Args:
        alist (list): Datasets to summarise; ``mean_of_list`` / ``std_of_list`` read
            ``dataset[ciColumn]`` from each one.
        ciColumn (int): Which column of each dataset to summarise.
        trials (int): Accepted for backward compatibility; currently ignored.

    Returns:
        tuple: ``(lower_bounds, upper_bounds, means, stds)`` as returned by the helper
        functions, one entry per dataset in each.
    """
    # Keep the helper call order (mean, std, bootstrap) so any exception and the
    # consumption of the global RNG happen exactly as before.
    per_dataset_means = mean_of_list(alist, ciColumn)
    per_dataset_stds = std_of_list(alist, ciColumn)
    ci_low, ci_high = bootstrap_confidence_interval_column(alist, ciColumn)
    return (ci_low, ci_high, per_dataset_means, per_dataset_stds)

In [3]:
def bootstrap_confidence_interval_column(data, ciColumn, num_bootstrap_samples=1000, confidence_level=0.90, rng=None):
    """
    Percentile-bootstrap confidence interval for the mean of one column of each dataset.

    Each dataset is indexed as ``dataset[ciColumn]`` (a 1-D sequence of values), the same
    layout used by ``mean_of_list``, ``flatten`` and ``ks_statistical_analysis``, so the
    interval brackets the same values whose mean ``mean_of_list`` reports.

    Args:
        data (list): Datasets; ``dataset[ciColumn]`` must be a non-empty 1-D sequence.
        ciColumn (int): Which column of each dataset to bootstrap.
        num_bootstrap_samples (int): Number of bootstrap resamples per dataset.
        confidence_level (float): Two-sided confidence level, strictly between 0 and 1.
        rng (numpy.random.Generator | numpy.random.RandomState | None): Source of
            randomness for resampling. ``None`` uses the global ``np.random`` state
            (previous behaviour); pass a seeded generator for reproducible intervals.

    Returns:
        tuple: ``(lower_bounds, upper_bounds)``, two lists of floats with one entry per dataset.

    Raises:
        ValueError: If ``confidence_level`` is not in (0, 1) or a selected column is empty.
    """
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be strictly between 0 and 1")
    if rng is None:
        rng = np.random  # module-level legacy RNG, as in the original implementation

    # Percentiles are the same for every dataset, so compute them once.
    alpha = (1 - confidence_level) / 2
    percentiles = [100 * alpha, 100 * (1 - alpha)]

    lower_bounds = []
    upper_bounds = []
    for dataset in data:
        column_data = np.asarray(dataset[ciColumn], dtype=float)
        if len(column_data) == 0:
            raise ValueError("cannot bootstrap an empty column")
        n = len(column_data)
        bootstrap_means = [
            np.mean(rng.choice(column_data, size=n, replace=True))
            for _ in range(num_bootstrap_samples)
        ]
        low, high = np.percentile(bootstrap_means, percentiles)
        lower_bounds.append(float(low))
        upper_bounds.append(float(high))
    return lower_bounds, upper_bounds

In [4]:
def replace_column(matrix, new_column, snr_index, column_index):
    """
    Overwrite, element by element, the column ``matrix[snr_index][column_index]``.

    The target column object is mutated in place (its identity is preserved), and
    ``matrix`` itself is returned for convenient reassignment.

    Args:
        matrix (list): Nested structure indexed as ``matrix[snr_index][column_index][i]``.
        new_column (sequence): Replacement values; must have the same length as the target column.
        snr_index (int): Index of the dataset (row) inside ``matrix`` to modify.
        column_index (int): Non-negative index of the column to overwrite within that dataset.

    Returns:
        list: The same ``matrix`` object, with the column updated.

    Raises:
        ValueError: If ``column_index`` is out of range for ``matrix[snr_index]`` or the
            lengths do not match.
    """
    row = matrix[snr_index]
    # Validate against the row actually being modified, not matrix[0]: rows may differ in width.
    if column_index < 0 or column_index >= len(row):
        raise ValueError("Invalid column_index")
    target = row[column_index]
    if len(new_column) != len(target):
        raise ValueError("New column length must match the matrix height")
    for i, value in enumerate(new_column):
        target[i] = value
    return matrix

In [5]:
def flatten(data, column, column_sin):
    """
    Concatenate ``dataset[column]`` from every dataset into one flat list.

    Args:
        data (list): Datasets; each ``dataset[column]`` must be iterable.
        column (int): Which column to pull out of each dataset.
        column_sin (int): Ignored; kept so existing call sites keep working.

    Returns:
        list: All values of the selected column, dataset by dataset, in order.
    """
    # First pick every column, then unroll them, so indexing errors surface before iteration errors.
    selected_columns = [dataset[column] for dataset in data]
    return [value for values in selected_columns for value in values]

In [6]:
def normalise_decision_time(data, column_dt, column_dt_sindy, lcaddm=False, lcaddm_point=(4, 9999)):
    """
    Min-max normalise the model and SINDy decision-time columns of every dataset, in place.

    The scaling range is the min/max of the ``column_dt`` values pooled over all datasets
    (``flatten`` reads only ``column_dt``). The SINDy column is scaled with that same range
    so both share one axis, which means SINDy values can fall outside [0, 1].

    Not idempotent: the datasets are modified in place, so calling this twice on the same
    object rescales already-normalised values.

    Args:
        data (list): Datasets indexed as ``data[i][column][trial]``; mutated in place.
        column_dt (int): Index of the model decision-time column in each dataset.
        column_dt_sindy (int): Index of the SINDy decision-time column in each dataset.
        lcaddm (bool): If true, the single trial at ``lcaddm_point`` is replaced in both
            columns by that dataset's column mean (of the unnormalised values) before scaling.
        lcaddm_point (tuple[int, int]): ``(dataset_index, trial_index)`` used when ``lcaddm``
            is set. The default ``(4, 9999)`` is the previously hard-coded position.
            NOTE(review): presumably masks a known bad/outlier trial in the LCA-DDM data — confirm.

    Returns:
        list: ``data`` itself (same object), with both columns normalised.

    Raises:
        ValueError: If all model decision times are identical (zero range). Previously this
            raised ZeroDivisionError, or silently produced NaN/inf for numpy scalars.
    """
    model_dt_flat = flatten(data, column_dt, column_dt_sindy)
    max_dt = max(model_dt_flat)
    min_dt = min(model_dt_flat)
    dt_range = max_dt - min_dt
    if dt_range == 0:
        raise ValueError("Cannot normalise decision times: all values in column_dt are identical")

    special_dataset, special_trial = lcaddm_point

    def _scale(value):
        # Shared min-max transform for both columns.
        return (value - min_dt) / dt_range

    for i in range(len(data)):
        model_times = data[i][column_dt]
        sindy_times = data[i][column_dt_sindy]
        normalised = []
        normalised_sindy = []
        for c in range(len(model_times)):
            if lcaddm and i == special_dataset and c == special_trial:
                # Substitute the column mean for this one trial (both columns).
                normalised.append(_scale(mean(model_times)))
                normalised_sindy.append(_scale(mean(sindy_times)))
            else:
                normalised.append(_scale(model_times[c]))
                normalised_sindy.append(_scale(sindy_times[c]))
        # Write back only after the whole dataset is computed, so the means above use raw values.
        data = replace_column(data, normalised, i, column_dt)
        data = replace_column(data, normalised_sindy, i, column_dt_sindy)

    return data

In [7]:
def cliffs_delta(x, y):
    """
    Compute Cliff's Delta effect size between two samples.

    delta = (#{(i, j): x_i > y_j} - #{(i, j): x_i < y_j}) / (len(x) * len(y)),
    which lies in [-1, 1]. It is positive when values in ``x`` tend to exceed values
    in ``y``, and 0 for identical distributions.

    The previous implementation averaged ``np.argsort`` output (a sorting permutation,
    not ranks) and omitted the factor of 2 in the rank formula, so it returned wrong
    values (e.g. 0.5 instead of 1 for x=[4, 5], y=[1, 2]).

    Args:
        x (array-like): First sample (flattened).
        y (array-like): Second sample (flattened).

    Returns:
        float: Cliff's Delta.

    Raises:
        ValueError: If either sample is empty.
    """
    x_arr = np.asarray(x, dtype=float).ravel()
    y_sorted = np.sort(np.asarray(y, dtype=float).ravel())
    if x_arr.size == 0 or y_sorted.size == 0:
        raise ValueError("cliffs_delta requires two non-empty samples")
    # For each x value: how many y values lie strictly below / strictly above it.
    # Sorting + binary search gives O((n + m) log m) without the n*m pairwise matrix.
    n_y_below = np.searchsorted(y_sorted, x_arr, side="left")
    n_y_above = y_sorted.size - np.searchsorted(y_sorted, x_arr, side="right")
    dominance = int(np.sum(n_y_below)) - int(np.sum(n_y_above))
    return dominance / (x_arr.size * y_sorted.size)

In [8]:
def ks_statistical_analysis(data, column, column_sin, snr, no_column=False, data_sindy=None, alpha=0.05):
    """
    Two-sample Kolmogorov-Smirnov test plus Cliff's Delta for each dataset.

    Samples compared for dataset ``i``:
      * ``no_column=False``: ``data[i][column]`` vs ``data[i][column_sin]``;
      * ``no_column=True``:  ``data[i]`` vs ``data_sindy[i]``.

    Args:
        data (list): Datasets to analyse.
        column (int): Index of the first sample within each dataset (ignored if ``no_column``).
        column_sin (int): Index of the second sample within each dataset (ignored if ``no_column``).
        snr (sequence): Label per dataset (e.g. its SNR), reported for significant datasets.
        no_column (bool): If true, each ``data[i]`` is itself a flat sample compared to ``data_sindy[i]``.
        data_sindy (list, optional): Second samples; required when ``no_column`` is true.
        alpha (float): Significance threshold on the KS p-value (default 0.05, as before).

    Returns:
        tuple: ``(d_stat, p, significant_points, effect_size)``. ``d_stat``, ``p`` and
        ``effect_size`` are float arrays of length ``len(data)``. ``significant_points[i]``
        is ``snr[i]`` when ``p[i] < alpha``, else ``None``.

    Raises:
        ValueError: If ``no_column`` is true but ``data_sindy`` is not given.
    """
    if no_column and data_sindy is None:
        raise ValueError("data_sindy is required when no_column=True")

    n_datasets = len(data)
    p = np.empty(n_datasets)
    d_stat = np.empty(n_datasets)       # KS statistic
    effect_size = np.empty(n_datasets)  # Cliff's delta

    for i in range(n_datasets):
        if no_column:
            group1, group2 = data[i], data_sindy[i]
        else:
            group1, group2 = data[i][column], data[i][column_sin]
        result = ks_2samp(group1, group2)
        d_stat[i] = result.statistic
        p[i] = result.pvalue
        effect_size[i] = cliffs_delta(group1, group2)

    significant_points = [snr[i] if p[i] < alpha else None for i in range(n_datasets)]
    return d_stat, p, significant_points, effect_size

In [9]:
def mean_of_list(alist, column, no_column=False):
    """
    Arithmetic mean (``statistics.mean``) of each dataset's selected values.

    Args:
        alist (list): Datasets to average.
        column (int): Index of the values inside each dataset; ignored when ``no_column`` is true.
        no_column (bool): If true, each dataset is itself the flat sequence to average.

    Returns:
        list: One mean per dataset, in input order.
    """
    def _values_of(dataset):
        # Either the whole (flat) dataset or just its requested column.
        return dataset if no_column else dataset[column]

    return [mean(_values_of(dataset)) for dataset in alist]

In [10]:
def std_of_list(alist, column, no_column=False, ddof=0):
    """
    Standard deviation of each dataset's selected values.

    The original body called an undefined name ``std`` (only ``statistics.mean`` is imported),
    so it always raised NameError. It now uses ``np.std`` with ``ddof=0`` (population standard
    deviation) by default, matching numpy's ``std`` that the name suggests was intended.

    Args:
        alist (list): Datasets to summarise.
        column (int): Index of the values inside each dataset; ignored when ``no_column`` is true.
        no_column (bool): If true, each dataset is itself the flat sequence to summarise.
        ddof (int): Delta degrees of freedom passed to ``np.std`` (use 1 for the sample estimate).

    Returns:
        list: One standard deviation (float) per dataset, in input order.
    """
    def _values_of(dataset):
        # Either the whole (flat) dataset or just its requested column.
        return dataset if no_column else dataset[column]

    return [float(np.std(_values_of(dataset), ddof=ddof)) for dataset in alist]

In [11]:
def average_activity_multidimensional(data, trial_threshold=100):
    """
    Average each population's activity across samples, truncated where too few samples remain.

    Traces may have different lengths; they are NaN-padded at the end and averaged with
    ``np.nanmean``. For each population we find the last time index at which at least
    ``trial_threshold`` samples have a non-NaN value. All populations are then truncated
    to the smallest such index, so they share a common length.

    Args:
        data (list): Indexed as ``data[sample][population][time]``. Every sample must contain
            the same number of populations; trace lengths may differ.
        trial_threshold (int): Minimum number of non-NaN samples for a time point to count as valid.

    Returns:
        list: ``num_populations`` lists of floats, all of the same (truncated) length.

    Raises:
        ValueError: If ``data`` is empty, samples have different numbers of populations, or
            some population has no time point meeting ``trial_threshold``. Previously the
            last case silently returned the (invalid) first time point.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one sample")
    num_samples = len(data)
    num_populations = len(data[0])
    if any(len(sample) != num_populations for sample in data):
        raise ValueError("every sample must contain the same number of populations")
    max_length = max(len(population) for sample in data for population in sample)

    # NaN-filled (sample, population, time) cube. A float dtype lets integer traces
    # (which np.pad cannot pad with NaN) share the NaN padding.
    padded = np.full((num_samples, num_populations, max_length), np.nan)
    for s, sample in enumerate(data):
        for p, population in enumerate(sample):
            padded[s, p, :len(population)] = population

    # Number of samples contributing a value at each (population, time) point.
    valid_counts = np.sum(~np.isnan(padded), axis=0)
    is_valid = valid_counts >= trial_threshold
    if not is_valid.any(axis=1).all():
        raise ValueError(
            f"some population has no time point with at least {trial_threshold} valid trials"
        )
    # Last True per population: reverse the time axis and find the first True.
    last_valid_indices = max_length - 1 - np.argmax(is_valid[:, ::-1], axis=1)
    cutoff = int(last_valid_indices.min())

    return np.nanmean(padded[:, :, :cutoff + 1], axis=0).tolist()

In [12]:
def average_activity(data, trial_threshold=100):
    """
    Average variable-length traces across trials, truncated where too few trials remain.

    Traces are NaN-padded at the end to a common length and averaged with ``np.nanmean``.
    The result stops at the last time index where at least ``trial_threshold`` traces
    have a non-NaN value.

    Args:
        data (list): Indexed as ``data[trial][time]``; traces may differ in length.
        trial_threshold (int): Minimum number of non-NaN traces for a time point to count as valid.

    Returns:
        numpy.ndarray: 1-D float array of the per-time-point mean, truncated as described.

    Raises:
        ValueError: If ``data`` is empty or no time point meets ``trial_threshold``.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one trace")
    max_length = max(len(item) for item in data)

    # NaN-filled (trial, time) matrix. A float dtype lets integer traces hold NaN padding,
    # which np.pad(..., constant_values=np.nan) cannot do for int arrays.
    padded_data = np.full((len(data), max_length), np.nan)
    for row, item in enumerate(data):
        padded_data[row, :len(item)] = item

    # Number of traces contributing a value at each time point.
    valid_counts = np.sum(~np.isnan(padded_data), axis=0)
    valid_indices = np.flatnonzero(valid_counts >= trial_threshold)
    if valid_indices.size == 0:
        raise ValueError(f"no time point has at least {trial_threshold} valid trials")
    last_valid_index = int(valid_indices[-1])

    return np.nanmean(padded_data[:, :last_valid_index + 1], axis=0)