# Cell Assembly Detection with CADopti

**Author:** Ambarish S. Ghatpande  
**Date:** 2024-09-02  

This Jupyter notebook implements the `CADopti` function, which detects cell assemblies in spike train data through a multi-step pipeline. It tests neuron pairs in parallel with `multiprocessing`, applies a Holm-Bonferroni correction for multiple comparisons, and refines candidate assemblies across several temporal resolutions (bin sizes).

## Overview

The notebook is divided into several sections, each corresponding to a key step in the cell assembly detection pipeline:

1. **Initialization and Preprocessing**: This section initializes the parameters and preprocesses the input spike train data by removing NaNs and padding the spike times to ensure consistent lengths across all neurons.

2. **Matrix Binning**: Spike train data is binned at various temporal resolutions, and the number of pairwise comparisons required for assembly detection is calculated.

3. **First Order Assembly Detection**: The notebook identifies potential first-order assemblies by evaluating pairwise neuron interactions using parallel processing. A Holm-Bonferroni correction is applied to control for multiple comparisons.

4. **Higher Order Assembly Detection**: This section extends the assembly detection process to identify higher-order assemblies. The notebook iteratively adds new neurons to existing assemblies and prunes less significant assemblies to refine the results.

5. **Assembly Refinement and Final Output**: The identified assemblies are further refined by comparing them across different temporal resolutions, and the final assembly structures are returned along with their associated statistical parameters.
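
As a rough illustration of steps 1 and 2, the sketch below drops NaN padding, bins two toy spike trains on common edges with `np.histogram`, and counts the pairwise tests for one bin size. The spike times, bin width, and lag are made-up values chosen for readability, not data from this notebook.

```python
import numpy as np

# Hypothetical toy input: two neurons, NaN-padded to the same length.
spike_times = [
    np.array([0.10, 0.35, 0.90, 1.40]),
    np.array([0.20, 1.20, np.nan, np.nan]),
]
bin_size = 0.5   # assumed bin width, same time unit as the spike times
max_lag = 2      # assumed maximum lag, in bins

# Step 1: drop the NaN padding from each neuron.
clean = [st[~np.isnan(st)] for st in spike_times]

# Step 2: bin every neuron on a common set of bin edges.
t_min = min(st.min() for st in clean)
t_max = max(st.max() for st in clean)
edges = np.arange(t_min, t_max + bin_size, bin_size)
binned = np.vstack([np.histogram(st, edges)[0] for st in clean])

# Pairwise tests at this bin size: number of pairs times number of lags.
n = len(clean)
n_tests = n * (n - 1) * (2 * max_lag + 1) // 2
```

Each row of `binned` holds the spike counts of one neuron, and `n_tests` is later used as the denominator base of the multiple-comparison correction.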

## Prerequisites

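This notebook assumes familiarity with Python programming, spike train analysis, and statistical methods for multiple comparisons. It requires the following third-party packages:

- `numpy`
- `scipy`

It also uses the standard-library modules `multiprocessing` and `itertools`, and imports three helper modules that must be available on the Python path: `FindAssemblies_recursive_prepruned`, `TestPair_ref`, and `assemblies_across_bins`.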

## How to Use

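Run the cells in sequence to define `CADopti`, then call it on your spike train data with `spike_times` (one array of spike times per neuron), `MaxLags`, and `BinSizes` (one maximum lag per bin size). The function returns the detected assemblies and their associated statistical information.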
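
A minimal sketch of assembling those inputs, assuming synthetic uniformly distributed spike times. The recording length, spike counts, bin sizes, and lags are illustrative values only, and the final call is left commented out because it needs the helper modules imported in the first code cell.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical recording: 5 neurons, 60 s, spike times in seconds.
n_neurons, duration = 5, 60.0
raw = [np.sort(rng.uniform(0.0, duration, size=rng.integers(200, 400)))
       for _ in range(n_neurons)]

# NaN-pad to a common length; CADopti strips the NaNs again internally.
longest = max(len(st) for st in raw)
spike_times = [np.pad(st, (0, longest - len(st)), constant_values=np.nan)
               for st in raw]

# One maximum lag (in bins) per bin size. BinSizes is kept as a plain list
# because CADopti looks bin sizes up with BinSizes.index(...).
BinSizes = [0.015, 0.025, 0.04]
MaxLags = [10, 10, 10]

# As_across_bins, As_across_bins_index, assembly, Assemblies_all_orders = \
#     CADopti(spike_times, MaxLags, BinSizes, ref_lag=2, alph=0.05)
```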

For further details on the `CADopti` function and the methods used in this notebook, refer to the relevant sections in the code and the accompanying comments.

## Contact Information

For any questions or further assistance, please contact **A.S. Ghatpande** at `aghatpande@gmail.com`.


In [8]:
# First Cell: Importing the necessary libraries
import numpy as np
from scipy import stats
import multiprocessing as mp
from itertools import combinations

# Import helper functions
from FindAssemblies_recursive_prepruned import find_assemblies_recursive_prepruned
from TestPair_ref import test_pair_ref
from assemblies_across_bins import assemblies_across_bins

In [6]:
# Second Cell: Defining the process_pair function
def process_pair(args):
    w1, w2, binM, MaxLags, BinSizes, Dc, ref_lag = args
    assemblybin = [None] * len(BinSizes)
    p_by_bin = []
    for gg in range(len(BinSizes)):
        assemblybin[gg] = find_assemblies_recursive_prepruned(
            np.vstack((binM[gg][w1, :], binM[gg][w2, :])),
            w1, w2, MaxLags[gg], Dc, ref_lag
        )
        if assemblybin[gg] is not None:
            p_by_bin.append(assemblybin[gg]['pr'][-1])
            assemblybin[gg]['bin'] = BinSizes[gg]
        else:
            print(f"find_assemblies_recursive_prepruned returned None for w1={w1}, w2={w2}, binSize={BinSizes[gg]}")
            p_by_bin.append(float('inf'))  # Assign a high p-value if the result is None

    b = np.argmin(p_by_bin)
    return assemblybin[b], p_by_bin[b]
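
The selection at the end of `process_pair` keeps, for each neuron pair, the bin size with the smallest p-value and treats a missing result as `inf`. A toy sketch of that rule with invented p-values (the dictionaries below only mimic the `'pr'` field and are not real test output):

```python
import numpy as np

# Hypothetical per-bin results for one neuron pair; None marks a missing result.
bin_sizes = [0.015, 0.025, 0.04]
results = [{"pr": [0.20]}, None, {"pr": [0.003]}]

# The last entry of 'pr' is the p-value; use inf when there is no result.
p_by_bin = [r["pr"][-1] if r is not None else float("inf") for r in results]

best = int(np.argmin(p_by_bin))  # index of the smallest p-value
best_bin, best_p = bin_sizes[best], p_by_bin[best]
```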


In [9]:
# Third Cell: Defining the CADopti function
def CADopti(spike_times, MaxLags, BinSizes, ref_lag=None, alph=None, No_th=None, O_th=None, bytelimit=None):
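    """Detect cell assemblies in spike train data across several bin sizes.

    Parameters
    ----------
    spike_times : sequence of 1-D float arrays
        Spike times, one array per neuron; NaN entries are ignored.
    MaxLags : sequence of int
        Maximum lag to test (in bins), one entry per bin size.
    BinSizes : list of float
        Bin widths to test, in the same time unit as spike_times.
    ref_lag : int, optional
        Reference lag passed to the pairwise tests (default 2).
    alph : float, optional
        Significance level before multiple-comparison correction (default 0.05).
    No_th, O_th, bytelimit : optional
        Limits on assembly occurrences, assembly order, and assembly dimension;
        by default no limit is applied.

    Returns
    -------
    As_across_bins, As_across_bins_index, assembly, Assemblies_all_orders
    """
    # Initialization of default parameters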
    if ref_lag is None:
        ref_lag = 2
    if alph is None:
        alph = 0.05
    if No_th is None:
        No_th = 0  # no limitation on the number of assembly occurrences
    if O_th is None:
        O_th = float('inf')  # no limitation on the assembly order (=number of elements in the assembly)
    if bytelimit is None:
        bytelimit = float('inf')  # no limitation on assembly dimension

    # Initialize variables and pre-process spike_times
    nneu = len(spike_times)  # number of units
    testit = np.ones(len(BinSizes))
    binM = [None] * len(BinSizes)
    number_tests = 0

    # Remove NaNs and get valid spike times for each neuron
    spike_times = [neuron[~np.isnan(neuron)] for neuron in spike_times]
    
    # Check for empty input before padding; padding would hide empty arrays
    if len(spike_times) == 0 or any(len(spikes) == 0 for spikes in spike_times):
        raise ValueError("Invalid input: spike_times is empty or contains empty arrays after processing")

    # After removing NaNs, pad spike times to equal lengths
    max_spikes = max(len(spikes) for spikes in spike_times)
    padded_spike_times = [np.pad(spikes, (0, max_spikes - len(spikes)),
                                 mode='constant', constant_values=np.nan)
                          for spikes in spike_times]

    # Calculate the minimum interval between spikes and overall min and max times.
    # Use the NaN-free arrays: np.min/np.max over NaN-padded data would return NaN.
    isis = [np.min(np.diff(times)) for times in spike_times if len(times) > 1]
    int_val = np.min(isis) if isis else np.nan
    min_val = np.min([np.min(times) for times in spike_times])
    max_val = np.max([np.max(times) for times in spike_times])

    # Validations and checks
    if not np.isfinite(int_val):
        raise ValueError("Couldn't compute a valid inter-spike interval")
    
    if min_val >= max_val:
        raise ValueError(f"min value ({min_val}) is greater than or equal to max value ({max_val})")
    
    if int_val == 0:
        raise ValueError("int_val is zero")
    
    print(f"Minimum inter-spike interval: {int_val}")
    print(f"Min time: {min_val}, Max time: {max_val}")
    
    # matrix binning at all bins
    for gg in range(len(BinSizes)):
        bin_size = BinSizes[gg]
        tb = np.arange(min_val, max_val + bin_size, bin_size)
    
        binM[gg] = np.zeros((nneu, len(tb) - 1), dtype=np.uint8)
        number_tests += nneu * (nneu - 1) * (2 * MaxLags[gg] + 1) // 2
    
        for n in range(nneu):
            binM[gg][n, :], _ = np.histogram(padded_spike_times[n][~np.isnan(padded_spike_times[n])], tb)
        
        assembly = {'bin': [{'n': [], 'bin_edges': tb} for _ in range(len(BinSizes))]}
        
        if binM[gg].shape[1] - MaxLags[gg] < 100:
            print(f'Warning: testing bin size={bin_size}. The time series is too short; consider using a longer portion of the spike train or reducing the bin size to be tested')
            testit[gg] = 0

    # Detecting First Order Assemblies
    print('order 1')
    Assemblies_all_orders = []
    O = 1
    Dc = 100  # length (in # bins) of the segments into which the spike train is divided to compute #abba variance (parameter k)

    assembly_selected_xy = []
    p_values = []

    # Prepare arguments for parallel processing: one tuple per neuron pair
    pair_args = [
        (w1, w2, binM, MaxLags, BinSizes, Dc, ref_lag)
        for w1, w2 in combinations(range(nneu), 2)
    ]

    # Use multiprocessing to parallelize the computation
    with mp.Pool() as pool:
        results = pool.map(process_pair, pair_args)
        
    # Process the results
    for result, p_value in results:
        if result is not None:
            assembly_selected_xy.append(result)
            p_values.append(p_value)
        
    if not assembly_selected_xy:
        raise ValueError("No valid assemblies found. Check the input data and parameters.")
        
    assembly_selected = assembly_selected_xy
        
    # Holm-Bonferroni correction
    # Sort into a separate array so p_values stays a list that higher orders can append to
    x = np.arange(1, len(p_values) + 1)
    sorted_p = np.sort(p_values)
    p_values_alpha = alph / (number_tests + 1 - x)

    ANfo = np.zeros((nneu, nneu))

    # Largest sorted p-value that is still below its Holm-Bonferroni threshold
    aus = np.where((sorted_p - p_values_alpha) < 0)[0]
    HBcorrected_p = 0 if len(aus) == 0 else sorted_p[aus[-1]]
        
    for oo in range(len(assembly_selected) - 1, -1, -1):
        if assembly_selected[oo]['pr'][-1] > HBcorrected_p:
            assembly_selected.pop(oo)
        else:
            ANfo[assembly_selected[oo]['elements'][0], assembly_selected[oo]['elements'][1]] = 1
        
    Assemblies_all_orders = [assembly_selected]

    # Detecting Higher Order Assemblies
    # Higher orders
    Oincrement = 1
    while Oincrement and O < (O_th - 1):
        O += 1
        print(f'order {O}')
        Oincrement = 0
        assembly_selected_aus = []
        xx = 0  # Python uses 0-based indexing
        
        for w1 in range(len(assembly_selected)):
            # bin at which to test w1
            ggg = BinSizes.index(assembly_selected[w1]['bin'])
        
            # element to test with w1
            w1_elements = assembly_selected[w1]['elements']
            w2_to_test = np.where(ANfo[w1_elements, :] == 1)[1]  # Using numpy for efficiency
            w2_to_test = w2_to_test[~np.isin(w2_to_test, w1_elements)]  # Remove elements already in the assembly
            w2_to_test = np.unique(w2_to_test)
        
            for w2 in w2_to_test:
                spikeTrain2 = binM[ggg][w2, :]
                assemblybin_aus = test_pair_ref(assembly_selected[w1], spikeTrain2, w2, MaxLags[ggg], Dc, ref_lag)
                p_values.append(assemblybin_aus['pr'][-1])
                number_tests += 2 * MaxLags[ggg] + 1
                if assemblybin_aus['pr'][-1] < HBcorrected_p:
                    assembly_selected_aus.append(assemblybin_aus)
                    assembly_selected_aus[-1]['bin'] = BinSizes[ggg]
                    xx += 1
                    Oincrement = 1

        if Oincrement:
            # Pruning within the same size
            na = len(assembly_selected_aus)
            nelement = len(assembly_selected_aus[0]['elements'])
            selection = np.full((na, nelement + 2), np.nan)
            assembly_final = [None] * na
            nns = 0

            for i in range(na):
                elem = sorted(assembly_selected_aus[i]['elements'])
                ism = np.all(selection[:, :nelement] == elem, axis=1)
                if not np.any(ism):
                    assembly_final[nns] = assembly_selected_aus[i]
                    selection[nns, :nelement] = elem
                    selection[nns, nelement] = assembly_selected_aus[i]['pr'][-1]
                    selection[nns, nelement + 1] = i
                    nns += 1
                else:
                    indx = np.where(ism)[0][0]
                    if selection[indx, nelement] > assembly_selected_aus[i]['pr'][-1]:
                        assembly_final[indx] = assembly_selected_aus[i]
                        selection[indx, nelement] = assembly_selected_aus[i]['pr'][-1]
                        selection[indx, nelement + 1] = i

            assembly_final = [a for a in assembly_final if a is not None]
            assembly_selected = assembly_final
            Assemblies_all_orders.append(assembly_final)

        # Holm-Bonferroni correction
        # Sort into a separate array so p_values stays a list for the next order
        x = np.arange(1, len(p_values) + 1)
        sorted_p = np.sort(p_values)
        p_values_alpha = alph / (number_tests + 1 - x)
        aus = np.where((sorted_p - p_values_alpha) < 0)[0]
        HBcorrected_p = 0 if len(aus) == 0 else sorted_p[aus[-1]]

        for o in range(len(Assemblies_all_orders)):
            Assemblies_all_orders[o] = [a for a in Assemblies_all_orders[o] if a['pr'][-1] <= HBcorrected_p]

        # Pruning between different assembly sizes
        Element_template = []
        for assembly in Assemblies_all_orders[-1]:
            Element_template.append(assembly['elements'])

        for o in range(len(Assemblies_all_orders) - 2, -1, -1):
            new_assemblies = []
            for assembly in Assemblies_all_orders[o]:
                if not any(set(assembly['elements']).issubset(set(template)) for template in Element_template):
                    new_assemblies.append(assembly)
                    Element_template.append(assembly['elements'])
            Assemblies_all_orders[o] = new_assemblies

    # Reformat dividing by bins.
    # Done once after the loop so the output is also built when no higher order is tested.
    assembly = {'bin': [{} for _ in range(len(BinSizes))]}
    for order_assemblies in Assemblies_all_orders:
        for a in order_assemblies:
            bx = BinSizes.index(a['bin'])
            if 'n' not in assembly['bin'][bx]:
                assembly['bin'][bx]['n'] = []
            assembly['bin'][bx]['n'].append(a)

    # Remove empty bins
    assembly['bin'] = [b for b in assembly['bin'] if b]

    # Add parameters to assembly
    assembly['parameters'] = {
        'alph': alph,
        'Dc': Dc,
        'No_th': No_th,
        'O_th': O_th,
        'bytelimit': bytelimit,
        'ref_lag': ref_lag
    }

    # Compare assemblies across bin sizes and return all results
    As_across_bins, As_across_bins_index = assemblies_across_bins(assembly, BinSizes)
    return As_across_bins, As_across_bins_index, assembly, Assemblies_all_orders
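
The significance cutoff used in `CADopti` follows a Holm-Bonferroni-style rule: sorted p-values are compared against `alph / (number_tests + 1 - rank)`, and the largest p-value still below its own threshold becomes the cutoff. The sketch below reproduces that rule on invented numbers; `hb_cutoff` is a hypothetical helper name, not part of this notebook. Taking the last passing rank, as the function above does, is more permissive than the textbook step-down procedure, which stops at the first failing rank.

```python
import numpy as np

def hb_cutoff(p_values, number_tests, alph=0.05):
    """Return the largest sorted p-value below its rank-specific threshold.

    Hypothetical helper mirroring the rule in CADopti; returns 0.0 when no
    p-value passes.
    """
    p_sorted = np.sort(np.asarray(p_values, dtype=float))
    rank = np.arange(1, len(p_sorted) + 1)
    thresholds = alph / (number_tests + 1 - rank)
    passing = np.where(p_sorted < thresholds)[0]
    return 0.0 if passing.size == 0 else float(p_sorted[passing[-1]])

# Invented example: 4 stored p-values out of 10 tests in total.
cutoff = hb_cutoff([0.004, 0.2, 0.001, 0.03], number_tests=10)
```

Note that `number_tests` counts every tested pair and lag, not only the stored p-values, so the thresholds are stricter than they would be with `len(p_values)` in the denominator.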