## Demonstration: applying latent_ensembles_detector to a neural dataset

Welcome! This notebook demonstrates how to use latent_ensembles_detector to detect neural ensembles with FastICA on a simple, controlled dataset.

We'll use a synthetic dataset generated with my companion repo place-cells-simulations: 100 simulated place cells recorded while an animal runs along a 1-meter linear track, some of them with overlapping place fields. Because the data are simulated, everything is reproducible and the ground truth is known.
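To make the ground truth more concrete, here is a minimal sketch of how place-cell spiking of this kind can be simulated. Each neuron gets a Gaussian place field along the track, and spike counts are drawn from a Poisson distribution. This is not the code from place-cells-simulations: the function `simulate_place_cells`, the back-and-forth trajectory and all parameter values (peak rate, field width, running speed, bin width) are hypothetical choices for illustration.

```python
import numpy as np


def simulate_place_cells(n_neurons=100, duration=600.0, dt=0.01, track_length=1.0,
                         speed=0.25, peak_rate=10.0, field_width=0.1, seed=0):
    """Hypothetical place-cell simulator (not the place-cells-simulations code).

    Returns spike counts (n_neurons, n_timepoints), time (s), position (m)
    and the place-field centers (m).
    """
    rng = np.random.default_rng(seed)
    time = np.arange(0.0, duration, dt)

    # Back-and-forth run at constant speed: a triangle wave between 0 and track_length
    period = 2 * track_length / speed
    phase = (time % period) / period
    pos = track_length * (1 - np.abs(2 * phase - 1))

    # Gaussian place fields with random centers; nearby centers give overlapping fields
    centers = rng.uniform(0, track_length, size=n_neurons)
    rates = peak_rate * np.exp(-((pos[None, :] - centers[:, None]) ** 2) / (2 * field_width**2))

    # Poisson spike counts per time bin (rate in Hz times bin width in s)
    spikes = rng.poisson(rates * dt)
    return spikes, time, pos, centers


spikes_sim, time_sim, pos_sim, centers_sim = simulate_place_cells(duration=120.0)
print(spikes_sim.shape)
```

Poisson counts with a Gaussian rate profile are the simplest reasonable model; real simulators often add refractoriness, speed modulation or theta rhythm.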

In [None]:
import latent_ensembles_detector as led
import json

In [None]:
# Load data
loaded_data = led.load_data("../data/experiment_data.npz")
data = loaded_data[0]
# print(data.files)  # uncomment to list the arrays stored in the file
spikes = data['spikes']  # shape (n_neurons, n_timepoints)
time = data['time']  # shape (n_timepoints,)
pos = data['pos']  # shape (n_timepoints,), animal position on the track
metadata = json.loads(data['meta'].item())  # metadata dictionary, stored as a JSON string
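The loading cell expects the .npz file to hold `spikes`, `time` and `pos` arrays plus a `meta` entry containing a JSON string. What `led.load_data` does internally is not shown here, so the sketch below uses plain `np.load` to illustrate that layout. `save_experiment` and `load_experiment` are hypothetical helper names, and the toy values are made up.

```python
import json
import os
import tempfile

import numpy as np


def save_experiment(path, spikes, time, pos, meta):
    # The metadata dict is stored as a JSON string, i.e. a 0-d unicode array in the .npz
    np.savez(path, spikes=spikes, time=time, pos=pos, meta=json.dumps(meta))


def load_experiment(path):
    with np.load(path) as data:
        print(data.files)  # names of the stored arrays
        spikes = data["spikes"]
        time = data["time"]
        pos = data["pos"]
        # .item() turns the 0-d array back into a Python str before JSON decoding
        meta = json.loads(data["meta"].item())
    return spikes, time, pos, meta


# Round trip in a temporary folder with toy values
with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "experiment_data.npz")
    save_experiment(path, np.zeros((3, 5), dtype=int), np.arange(5) * 0.01,
                    np.linspace(0.0, 1.0, 5), {"dt": 0.01})
    spikes_rt, time_rt, pos_rt, meta_rt = load_experiment(path)
```

Storing metadata as a JSON string keeps the file loadable without `allow_pickle=True`, which is why the notebook decodes it with `json.loads`.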

In [None]:
# OPTION 1: Run this cell only if you DO NOT have a spike matrix yet.
# It converts spike times and neuron IDs (i.e., spike-sorted clusters) into a binned spike matrix.
# To use it, uncomment the lines below and define spike_times and spike_clusters with your own data.

# Define experimental parameters
# start_time = 0  # in seconds
# end_time = 600  # in seconds
# sampling_rate = 30000  # in Hz

# spike_matrix, _ = led.compute_spike_matrix(spikeTimes=spike_times, spikeClusters=spike_clusters,
#                                            time=time, start_time=start_time,
#                                            end_time=end_time, sampling_rate=sampling_rate)
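`led.compute_spike_matrix` is not shown in this notebook. To illustrate what binning spike times involves, here is a NumPy-only sketch that assumes spike times in seconds and one integer cluster ID per spike. `bin_spike_times` is a hypothetical helper, not the package's implementation, and the 25 ms bin width and toy spikes are illustrative.

```python
import numpy as np


def bin_spike_times(spike_times, spike_clusters, start_time, end_time, bin_size):
    """Hypothetical binning helper: spike times (s) + cluster IDs -> count matrix.

    Returns (spike_matrix, cluster_ids); row i of spike_matrix holds the
    counts of cluster_ids[i]. Assumes (end_time - start_time) is a whole
    number of bins.
    """
    spike_times = np.asarray(spike_times, dtype=float)
    spike_clusters = np.asarray(spike_clusters)

    cluster_ids = np.unique(spike_clusters)  # sorted unique neuron IDs
    n_bins = int(round((end_time - start_time) / bin_size))
    edges = start_time + bin_size * np.arange(n_bins + 1)

    # Ignore spikes outside [start_time, end_time)
    keep = (spike_times >= edges[0]) & (spike_times < edges[-1])
    rows = np.searchsorted(cluster_ids, spike_clusters[keep])
    cols = np.searchsorted(edges, spike_times[keep], side="right") - 1

    spike_matrix = np.zeros((cluster_ids.size, n_bins), dtype=np.int64)
    np.add.at(spike_matrix, (rows, cols), 1)  # unbuffered: repeated (row, col) pairs all count
    return spike_matrix, cluster_ids


# If spike times are stored as sample indices (e.g. at 30 kHz), convert them to seconds first
sample_indices = np.array([300, 600, 900, 1530, 2970, 15000])
clusters = np.array([5, 5, 7, 5, 7, 7])
matrix, ids = bin_spike_times(sample_indices / 30000, clusters, 0.0, 0.1, 0.025)
print(matrix.tolist())  # [[2, 0, 1, 0], [0, 1, 0, 1]]
```

`np.add.at` is used instead of `spike_matrix[rows, cols] += 1` because fancy-index assignment would count two spikes in the same bin only once.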

In [None]:
# OPTION 2: Run this cell if you ALREADY have a spike matrix but need to re-bin it to a different time resolution.
# The original paper recommends 25 ms time bins (new_dt = 0.025, assuming dt is in seconds).
# Note: this cell re-bins the loaded `spikes` array and overwrites any spike_matrix built in OPTION 1.

spike_matrix = led.rebin_spikes(spike_matrix=spikes, old_dt=metadata["dt"], new_dt=0.025)

print(spike_matrix.shape)  # shape (n_neurons, n_time_bins)
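`led.rebin_spikes` is used as a black box here. For count data, the simplest form of re-binning sums groups of consecutive bins. The sketch below assumes that behavior and that the new bin width is an integer multiple of the old one (for example 5 ms to 25 ms). `rebin_counts` is a hypothetical name, and the package may handle leftover bins or non-integer ratios differently.

```python
import numpy as np


def rebin_counts(spike_matrix, old_dt, new_dt):
    """Hypothetical re-binning helper: sum groups of consecutive bins.

    Assumes new_dt is an integer multiple of old_dt; trailing bins that do
    not fill a whole new bin are dropped.
    """
    ratio = new_dt / old_dt
    factor = int(round(ratio))
    if factor < 1 or not np.isclose(ratio, factor):
        raise ValueError("new_dt must be an integer multiple of old_dt")

    spike_matrix = np.asarray(spike_matrix)
    n_neurons, n_old = spike_matrix.shape
    n_new = n_old // factor
    trimmed = spike_matrix[:, : n_new * factor]
    # (n_neurons, n_new, factor) -> sum within each group of `factor` old bins
    return trimmed.reshape(n_neurons, n_new, factor).sum(axis=2)


counts_5ms = np.arange(20).reshape(2, 10)  # 2 neurons, ten 5 ms bins
counts_25ms = rebin_counts(counts_5ms, old_dt=0.005, new_dt=0.025)
print(counts_25ms.tolist())  # [[10, 35], [60, 85]]
```

Summing (rather than averaging) keeps the values as spike counts, so the re-binned matrix can be fed to the same downstream steps.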

In [None]:
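# Find neural ensembles and the principal neurons of each ensemble

# 1. Estimate how many ensembles are present in the spike matrix
n_ensembles = led.estimate_ensembles_number(spike_matrix=spike_matrix)

# 2. Run FastICA to obtain the neuron weights of each ensemble (the second return value is not used here)
weights, _ = led.perform_fastICA(n_ensembles=n_ensembles, spike_matrix=spike_matrix)

# 3. Identify the neurons that contribute most strongly to each ensemble
principalCells = led.find_principal_neurons(weights=weights)

print(principalCells)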
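The notebook does not show how `led.estimate_ensembles_number` and `led.find_principal_neurons` work. A common approach in FastICA-based ensemble detection, assumed here rather than confirmed for this package, is to (1) count the eigenvalues of the neuron-by-neuron correlation matrix of z-scored spike counts that exceed the Marchenko-Pastur upper bound $(1 + \sqrt{N/T})^2$, with $N$ neurons and $T$ time bins, and (2) call a neuron a principal neuron of an ensemble when its weight exceeds the mean weight by more than two standard deviations. The helpers below are hypothetical, use NumPy only, and leave out the FastICA step itself. The toy data are made up.

```python
import numpy as np


def marchenko_pastur_upper(n_neurons, n_bins):
    """Upper edge of the Marchenko-Pastur eigenvalue distribution for the
    correlation matrix of n_neurons independent signals over n_bins samples."""
    return (1.0 + np.sqrt(n_neurons / n_bins)) ** 2


def count_significant_components(spike_matrix):
    """Hypothetical ensemble-count estimate: eigenvalues above the MP bound."""
    x = np.asarray(spike_matrix, dtype=float)
    n_neurons, n_bins = x.shape
    std = x.std(axis=1, keepdims=True)
    std[std == 0] = 1.0  # silent neurons become all-zero rows instead of NaN
    z = (x - x.mean(axis=1, keepdims=True)) / std
    corr = z @ z.T / n_bins  # neuron-by-neuron correlation matrix
    eigvals = np.linalg.eigvalsh(corr)  # symmetric matrix -> real eigenvalues
    return int(np.sum(eigvals > marchenko_pastur_upper(n_neurons, n_bins)))


def principal_neurons(weights, n_sd=2.0):
    """Hypothetical member selection: weight > mean + n_sd * SD, per ensemble.

    weights has shape (n_neurons, n_ensembles); returns one index array per ensemble.
    """
    members = []
    for w in np.asarray(weights, dtype=float).T:
        w = w / np.linalg.norm(w)  # ICA weights have arbitrary scale...
        if w[np.argmax(np.abs(w))] < 0:  # ...and arbitrary sign: make the largest weight positive
            w = -w
        threshold = w.mean() + n_sd * w.std()
        members.append(np.flatnonzero(w > threshold))
    return members


# Toy data: 30 neurons with Poisson background, three 5-neuron groups that co-fire
rng = np.random.default_rng(0)
counts = rng.poisson(0.5, size=(30, 5000))
for group in (range(0, 5), range(5, 10), range(10, 15)):
    active = rng.random(5000) < 0.05
    for i in group:
        counts[i] += active * rng.poisson(3, size=5000)
print(count_significant_components(counts))
```

Eigenvalues above the bound indicate correlation structure that independent neurons would be unlikely to produce. The sign flip matters because ICA cannot determine the sign of a component.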

In [None]:
# Plot heatmap of principal cells
import os

os.makedirs("results", exist_ok=True)  # create the output folder if it does not exist yet
led.plot_principal_cells_heatmap(principal_cells=principalCells, save_path="results/principal_cells_heatmap.png")

In [None]:
# Save results
led.save_data("results/spike_matrix.npy", spike_matrix)
led.save_data("results/weights.npy", weights)
led.save_data("results/principal_cells.json", principalCells)
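If you save results yourself instead of through `led.save_data`, note that `json.dump` cannot serialize NumPy arrays or NumPy integer scalars. Assuming `principalCells` is a list of NumPy index arrays (its exact type depends on the package), a small converter such as the hypothetical `to_jsonable` below makes it JSON-friendly.

```python
import json

import numpy as np


def to_jsonable(obj):
    """Recursively convert NumPy arrays/scalars to plain Python types for json."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # e.g. np.int64, np.float32
        return obj.item()
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    return obj


# Hypothetical principal-cell result: one index array per ensemble
example_principal_cells = [np.array([3, 7]), np.array([12])]
text = json.dumps(to_jsonable(example_principal_cells))
print(text)  # [[3, 7], [12]]
```

Dictionary keys are converted with `str` because JSON object keys must be strings and `json` does not accept NumPy integer keys.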