In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio

# Import the CADopti function and other necessary functions
from assembly_assignment_matrix import assembly_assignment_matrix
from assembly_activity_function import assembly_activity_function

from importnb import Notebook
with Notebook():
    from CADopti_notebook import CADopti


In [2]:
# Load the recording from Data.mat. `spM` holds one row per unit; rows are
# presumably NaN-padded to a common length (NaNs are stripped below) — TODO
# confirm against how Data.mat was produced.
mat_data = sio.loadmat('Data.mat')

# Fail early with the list of stored variables instead of a bare KeyError,
# replacing the old commented-out troubleshooting prints.
if 'spM' not in mat_data:
    # loadmat adds metadata entries such as '__header__'; hide those.
    data_keys = sorted(k for k in mat_data if not k.startswith('__'))
    raise KeyError(f"'spM' not found in Data.mat; available variables: {data_keys}")

spM = mat_data['spM']

# One array of spike times per unit, with NaN padding removed.
spike_times = [row[~np.isnan(row)] for row in spM]

# Quick check that the data looks sensible: unit count and spikes per unit.
print("Number of neurons:", len(spike_times))
print("Number of spikes per neuron:")
print([len(spikes) for spikes in spike_times])

Number of neurons: 50
Number of spikes per neuron:
[6722, 6551, 6631, 6631, 6765, 6679, 6617, 6594, 6568, 6599, 7061, 7301, 6727, 7350, 6899, 7746, 7616, 7695, 7723, 7747, 8945, 9033, 9134, 9133, 9003, 6403, 6335, 6350, 6218, 6479, 6333, 6458, 6383, 6324, 6392, 6310, 6332, 6457, 6306, 6372, 6476, 6272, 6305, 6334, 6332, 6367, 6291, 6357, 6289, 6374]


In [4]:
# Sanity-check every spike train before running CADopti. Each problem found
# is printed; a final line confirms when nothing suspicious was found.
# NaNs were already stripped when `spike_times` was built, so the NaN check
# should never fire; it stays as a guard in case that step changes.
empty_trains = [i for i, spikes in enumerate(spike_times) if len(spikes) == 0]
found_problem = bool(empty_trains)
if empty_trains:
    print(f"Empty spike trains found at indices: {empty_trains}")

for i, spikes in enumerate(spike_times):
    if np.isnan(spikes).any():
        print(f"NaNs found in spike train {i}")
        found_problem = True
    if np.isinf(spikes).any():
        print(f"Infinities found in spike train {i}")
        found_problem = True
    # A single spike is trivially "all identical"; only flag trains with at
    # least two spikes that all share one timestamp.
    if len(spikes) > 1 and len(np.unique(spikes)) == 1:
        print(f"Identical spike times found in spike train {i}: {spikes}")
        found_problem = True
    # Duplicate or out-of-order timestamps give zero/negative inter-spike
    # intervals, which the all-identical check above does not catch.
    n_nonpositive = int(np.count_nonzero(np.diff(spikes) <= 0))
    if n_nonpositive:
        print(f"Spike train {i} has {n_nonpositive} non-positive inter-spike interval(s) "
              "(duplicate or unsorted spike times)")
        found_problem = True

if not found_problem:
    print("No empty, non-finite, duplicate or unsorted spike trains found.")

In [5]:
# Report the minimum, maximum and mean inter-spike interval (ISI) of every
# spike train. A train needs at least two spikes to have any interval.
for train_idx, train in enumerate(spike_times):
    if len(train) < 2:
        print(f"Spike train {train_idx} has fewer than 2 spikes, cannot compute intervals.")
        continue
    isi = np.diff(train)
    isi_min, isi_max, isi_mean = np.min(isi), np.max(isi), np.mean(isi)
    print(f"Spike train {train_idx} inter-spike intervals (min, max, mean): {isi_min}, {isi_max}, {isi_mean}")


Spike train 0 inter-spike intervals (min, max, mean): 0.006989250650462964, 1.575759369102741, 0.2033603596023709
Spike train 1 inter-spike intervals (min, max, mean): 0.006989250650462964, 1.880338923038238, 0.208710458019369
Spike train 2 inter-spike intervals (min, max, mean): 0.006989250650462964, 2.0024313813671597, 0.20611428288051573
Spike train 3 inter-spike intervals (min, max, mean): 0.006989250650462964, 1.5707407287889055, 0.20614927576253922
Spike train 4 inter-spike intervals (min, max, mean): 0.006989250650462964, 1.922688834340903, 0.2021202567704867
Spike train 5 inter-spike intervals (min, max, mean): 0.007057196032974389, 1.9981355697564993, 0.20467791812197883
Spike train 6 inter-spike intervals (min, max, mean): 0.007057196032974389, 1.8155607412246582, 0.20660231871383578
Spike train 7 inter-spike intervals (min, max, mean): 0.007057196032974389, 2.2983896190800124, 0.2072584084052289
Spike train 8 inter-spike intervals (min, max, mean): 0.007057196032974389, 1.78

In [3]:
import traceback

# Detect cell assemblies with CADopti, then plot and summarise them.
# Writes assembly_assignment_matrix.png and assembly_activity.png to the
# working directory, shows both inline, and prints one entry per assembly.

# Candidate bin sizes (presumably in the same time units as spM) and the
# maximal lag tested at each one. MaxLags is derived from BinSizes so the two
# lists cannot drift to different lengths when BinSizes is edited.
BinSizes = [0.015, 0.025, 0.04, 0.06, 0.085, 0.15, 0.25, 0.4, 0.6, 0.85, 1.5]
MaxLags = [10] * len(BinSizes)  # 10 is the maximum lag for each bin size


def _plot_assignment_matrix(Amatrix, path='assembly_assignment_matrix.png'):
    """Draw the assignment matrix (neurons on y, assemblies on x), save it to `path`, and show it."""
    fig, ax = plt.subplots()
    image = ax.imshow(Amatrix, aspect='auto', interpolation='nearest')
    ax.set_title('Assembly Assignment Matrix')
    ax.set_xlabel('Assembly')
    ax.set_ylabel('Neuron')
    fig.colorbar(image, ax=ax, label='Assignment')
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def _plot_assembly_activity(assembly_activity, path='assembly_activity.png'):
    """Draw one activity trace per assembly in stacked panels, save to `path`, and show it."""
    n_panels = len(assembly_activity)
    if n_panels == 0:
        # plt.subplots(0, 1) raises; there is nothing to draw.
        print("No assembly activity to plot.")
        return
    # Grow the figure with the number of panels so traces stay readable.
    fig, axes = plt.subplots(n_panels, 1, figsize=(10, max(10, 2 * n_panels)), squeeze=False)
    for i, (ax, activity) in enumerate(zip(axes[:, 0], assembly_activity)):
        # NOTE(review): assumes each entry is a 2-column array of
        # (time, activity) rows, as the original indexing implied — TODO confirm.
        activity = np.asarray(activity)
        ax.plot(activity[:, 0], activity[:, 1])
        ax.set_title(f'Assembly {i+1} Activity')
        ax.set_xlabel('Time')
        ax.set_ylabel('Activity')
    fig.tight_layout()
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def _print_assembly_summary(assemblies):
    """Print elements, bin size, lags, p-values and occurrence count of each assembly."""
    print("\nDetected Assemblies:")
    # `asm`, not `assembly`: reusing that name would overwrite the CADopti
    # output of the same name in the notebook namespace.
    for i, asm in enumerate(assemblies):
        print(f"Assembly {i+1}:")
        print(f"  Elements: {asm['elements']}")
        print(f"  Bin size: {asm['bin']}")
        print(f"  Lags: {asm['lag']}")
        print(f"  p-values: {asm['pr']}")
        print(f"  Occurrences: {asm['Noccurrences']}")
        print()


try:
    As_across_bins, As_across_bins_index, assembly = CADopti(spM, MaxLags, BinSizes)

    # `is None` alone would let an empty result through to the plots.
    if As_across_bins is None or len(As_across_bins) == 0:
        raise ValueError("No valid assemblies found. Check the input data and parameters.")

    # Visualization
    nneu = len(spike_times)  # number of recorded units
    # Named display_mode (not `display`) so IPython's built-in display()
    # is not shadowed for the rest of the notebook.
    display_mode = 'raw'  # or 'clustered'
    Amatrix, Binvector, Unit_order, As_order = assembly_assignment_matrix(
        As_across_bins, nneu, BinSizes, display_mode)
    _plot_assignment_matrix(Amatrix)

    # Assembly Activation
    # NOTE(review): CADopti above receives the NaN-padded `spM` matrix, while
    # this call receives the NaN-stripped `spike_times` list — confirm each
    # function expects the format it is given.
    lagChoice = 'duration'
    act_count = 'full'
    assembly_activity = assembly_activity_function(
        As_across_bins, assembly, spike_times, BinSizes, lagChoice, act_count)
    _plot_assembly_activity(assembly_activity)

    _print_assembly_summary(As_across_bins)
except Exception as e:
    # Keep the notebook running, but show the full traceback: the bare message
    # alone does not say which call (CADopti, plotting, activity) failed.
    print(f"An error occurred: {e}")
    traceback.print_exc()

An error occurred: Couldn't compute a valid inter-spike interval
