# Test Phylib
### This is just a playground to test out different functions and features of the `phylib` library, to better understand it
### It seems to be able to do similar things to `multirecording_spikeanalysis.py`, but not as directly and requires the entire phy folder, including the recording.dat

In [1]:
import pandas as pd
import numpy as np

import sys
import matplotlib.pyplot as plt
from phylib.io.model import load_model
from phylib.utils.color import selected_cluster_color

`phylib` is a library used by Phy

In [2]:
# Read the TSV file
cluster_info = pd.read_csv(r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\cluster_info.tsv', sep='\t')

good_clusters = np.intersect1d(
    cluster_info['cluster_id'][cluster_info['group'] == 'good'],
    cluster_info['cluster_id'][cluster_info['fr'] > 0.5]
)

`cluster_info.tsv` is essentially just a csv of the table you see when you open Phy

In [3]:
# Directly specify the path to params.py
params_path = r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\params.py'

# Load the TemplateModel
model = load_model(params_path)

`params.py` just points to `recording.dat`, and defines the number of channels (32) and sample rate (20 kHz), but is what phylib uses to be able to load a recording with `load_model()`

In [4]:
cluster_waveforms = {}

for cluster_id in good_clusters:
    # Get the waveforms for the cluster
    waveforms = model.get_cluster_spike_waveforms(cluster_id)
    
    # Get the channel IDs and find the best channel (assuming it's the first one returned by get_cluster_channels)
    channel_ids = model.get_cluster_channels(cluster_id)
    best_channel = channel_ids[0]
    
    # Get the waveform for just the best channel. Since the data is already in 40 time points, no subsampling is needed
    best_channel_waveforms = waveforms[:, :, 0]  # Since channel_ids[0] is the best channel
    # Store the waveforms in the dictionary. 
    cluster_waveforms[cluster_id] = best_channel_waveforms 

This gets the waveforms for only the channels marked "good" on Phy, and `get_cluster_spike_waveforms()` returns a 3-dimensional array that includes all 40 timepoint waveform for every spike for every channel, but channel[0] is always the best channel

In [5]:
cluster_waveforms

{9: array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
          0.89965886,  0.1729701 ],
        [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
          1.6655885 ,  1.0949929 ],
        [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
          0.83251476,  1.9278888 ],
        ...,
        [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
          3.7053027 ,  0.17694335],
        [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
          0.7814573 ,  0.6139724 ],
        [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
          1.1730937 ,  1.1040318 ]], dtype=float32),
 13: array([[0.74050546, 1.1471487 , 1.1523566 , ..., 2.8801389 , 2.9853294 ,
         3.0963612 ],
        [0.33376965, 0.1129124 , 0.25598168, ..., 1.8844459 , 2.5375013 ,
         3.7338994 ],
        [3.7924511 , 3.189696  , 2.8261626 , ..., 5.365091  , 5.273031  ,
         4.5386353 ],
        ...,
        [1.6787068 , 2.0034997 , 2.0164783 , ..., 

In [6]:
# Assuming your dictionary is named `cluster_waveforms`
# Initialize an empty list to collect waveforms
all_waveforms = []

# Iterate through each unit/spike_cluster in the dictionary
for unit, spikes_3d in cluster_waveforms.items():
    # Check the shape of the current array
    if spikes_3d.ndim == 3:
        # Extract the waveforms from the first channel (index 0)
        waveforms_first_channel = spikes_3d[:, :, 0]  # Shape: (number_of_spikes, timepoints)
        
        # Append these waveforms to the list
        all_waveforms.append(waveforms_first_channel)
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes_3d.shape}")

# Combine all collected waveforms into a single 2D Numpy array
if all_waveforms:
    W = np.vstack(all_waveforms)  # Shape: (total_number_of_spikes, timepoints)
    print(f"Combined waveform shape: {W.shape}")
else:
    print("No valid waveforms to process.")

# Now W is ready to be used for further processing, such as calculating SNR

Skipping unit 9 due to unexpected shape: (13768, 40)
Skipping unit 13 due to unexpected shape: (6260, 40)
Skipping unit 18 due to unexpected shape: (78443, 40)
Skipping unit 20 due to unexpected shape: (77825, 40)
Skipping unit 115 due to unexpected shape: (13468, 40)
Skipping unit 117 due to unexpected shape: (1691, 40)
Skipping unit 123 due to unexpected shape: (33886, 40)
Skipping unit 150 due to unexpected shape: (5366, 40)
Skipping unit 154 due to unexpected shape: (29374, 40)
Skipping unit 190 due to unexpected shape: (2437, 40)
No valid waveforms to process.


In [7]:
# Assuming your dictionary is named `cluster_waveforms`
# Initialize an empty list to collect waveforms
all_waveforms = []

# Iterate through each unit/spike_cluster in the dictionary
for unit, spikes in cluster_waveforms.items():
    if spikes.ndim == 2:
        # If the array is 2D, use it directly
        all_waveforms.append(spikes)
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes.shape}")

# Combine all collected waveforms into a single 2D Numpy array
if all_waveforms:
    W = np.vstack(all_waveforms)  # Shape: (total_number_of_spikes, timepoints)
    print(f"Combined waveform shape: {W.shape}")
else:
    print("No valid waveforms to process.")

# Now W is ready to be used for further processing, such as calculating SNR

Combined waveform shape: (262518, 40)


In [8]:
len(all_waveforms)

10

In [9]:
len(all_waveforms[0])

13768

In [10]:
len(cluster_waveforms[9])

13768

In [11]:
all_waveforms

[array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
          0.89965886,  0.1729701 ],
        [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
          1.6655885 ,  1.0949929 ],
        [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
          0.83251476,  1.9278888 ],
        ...,
        [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
          3.7053027 ,  0.17694335],
        [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
          0.7814573 ,  0.6139724 ],
        [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
          1.1730937 ,  1.1040318 ]], dtype=float32),
 array([[0.74050546, 1.1471487 , 1.1523566 , ..., 2.8801389 , 2.9853294 ,
         3.0963612 ],
        [0.33376965, 0.1129124 , 0.25598168, ..., 1.8844459 , 2.5375013 ,
         3.7338994 ],
        [3.7924511 , 3.189696  , 2.8261626 , ..., 5.365091  , 5.273031  ,
         4.5386353 ],
        ...,
        [1.6787068 , 2.0034997 , 2.0164783 , ..., 2.54044

In [12]:
all_waveforms[0]

array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
         0.89965886,  0.1729701 ],
       [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
         1.6655885 ,  1.0949929 ],
       [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
         0.83251476,  1.9278888 ],
       ...,
       [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
         3.7053027 ,  0.17694335],
       [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
         0.7814573 ,  0.6139724 ],
       [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
         1.1730937 ,  1.1040318 ]], dtype=float32)

In [13]:
all_waveforms[0][0]

array([  0.40537423,   0.9204547 ,   1.3042495 ,   0.66315186,
        -0.83373624,  -1.3520111 ,  -0.45797062,  -0.0220464 ,
        -0.36910826,   0.3050875 ,   1.4935703 ,   0.8006043 ,
        -0.44051605,   1.1329583 ,   4.469434  ,   7.3652925 ,
        10.524259  ,  10.669554  ,   2.2784607 ,  -9.069311  ,
       -12.297354  ,  -8.102026  ,  -4.433738  ,  -3.3412938 ,
        -3.1506605 ,  -3.1593354 ,  -2.5206325 ,  -1.4534206 ,
        -1.2984093 ,  -1.4807814 ,  -0.7342157 ,   0.31443456,
         0.7896247 ,   0.92220134,   1.1487566 ,   1.8166375 ,
         2.5917692 ,   2.269646  ,   0.89965886,   0.1729701 ],
      dtype=float32)

In [14]:
len(all_waveforms[0][0])

40

In [15]:
cluster_waveforms[9]

array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
         0.89965886,  0.1729701 ],
       [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
         1.6655885 ,  1.0949929 ],
       [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
         0.83251476,  1.9278888 ],
       ...,
       [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
         3.7053027 ,  0.17694335],
       [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
         0.7814573 ,  0.6139724 ],
       [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
         1.1730937 ,  1.1040318 ]], dtype=float32)

In [16]:
len(cluster_waveforms[9])

13768

In [17]:
cluster_waveforms[9][0]

array([  0.40537423,   0.9204547 ,   1.3042495 ,   0.66315186,
        -0.83373624,  -1.3520111 ,  -0.45797062,  -0.0220464 ,
        -0.36910826,   0.3050875 ,   1.4935703 ,   0.8006043 ,
        -0.44051605,   1.1329583 ,   4.469434  ,   7.3652925 ,
        10.524259  ,  10.669554  ,   2.2784607 ,  -9.069311  ,
       -12.297354  ,  -8.102026  ,  -4.433738  ,  -3.3412938 ,
        -3.1506605 ,  -3.1593354 ,  -2.5206325 ,  -1.4534206 ,
        -1.2984093 ,  -1.4807814 ,  -0.7342157 ,   0.31443456,
         0.7896247 ,   0.92220134,   1.1487566 ,   1.8166375 ,
         2.5917692 ,   2.269646  ,   0.89965886,   0.1729701 ],
      dtype=float32)

In [18]:
len(cluster_waveforms[9][0])

40

In [19]:
len(cluster_waveforms[150])

5366

In [20]:
len(all_waveforms[1])

6260

In [21]:
len(all_waveforms[2])

78443

In [22]:
# Function to pad arrays to the same number of columns with NaNs
def pad_array(arr, max_length):
    pad_width = max_length - arr.shape[1]
    if pad_width > 0:
        return np.pad(arr, ((0, 0), (0, pad_width)), mode='constant', constant_values=np.nan)
    return arr

# Assuming your dictionary is named `cluster_waveforms`
# Initialize an empty list to collect waveforms
all_waveforms = []

# Determine the maximum number of timepoints across all clusters
max_timepoints = max(spikes.shape[1] for spikes in cluster_waveforms.values() if spikes.ndim == 2)

# Iterate through each unit/spike_cluster in the dictionary
for unit, spikes in cluster_waveforms.items():
    if spikes.ndim == 2:
        # Pad the array to the maximum number of timepoints
        padded_spikes = pad_array(spikes, max_timepoints)
        all_waveforms.append(padded_spikes)
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes.shape}")

# Combine all collected waveforms into a single 2D Numpy array
if all_waveforms:
    W = np.vstack(all_waveforms)  # Shape: (total_number_of_spikes, max_timepoints)
    print(f"Combined waveform shape: {W.shape}")

    # Calculate the mean waveform across all spikes, ignoring NaNs
    W_bar = np.nanmean(W, axis=0)

    # Calculate the signal amplitude (max - min of the mean waveform), ignoring NaNs
    sig_amp = np.nanmax(W_bar) - np.nanmin(W_bar)

    # Subtract the mean waveform from each spike to get the noise
    noise = W - np.tile(W_bar, (np.shape(W)[0], 1))

    # Calculate the signal-to-noise ratio, ignoring NaNs
    snr = sig_amp / (2 * np.nanstd(noise.flatten()))
    
    print(f"Signal-to-Noise Ratio (SNR): {snr}")
else:
    print("No valid waveforms to process.")

Combined waveform shape: (262518, 40)
Signal-to-Noise Ratio (SNR): 3.0603712211149254


In [23]:
# Assuming you already have all_waveforms as a list of 2D arrays
# Initialize an empty list to collect waveforms with SNR >= 3
no_noise_waveforms = []

# Function to calculate SNR for a given 2D array
def calculate_snr(waveforms):
    W_bar = np.nanmean(waveforms, axis=0)
    sig_amp = np.nanmax(W_bar) - np.nanmin(W_bar)
    noise = waveforms - np.tile(W_bar, (waveforms.shape[0], 1))
    snr = sig_amp / (2 * np.nanstd(noise.flatten()))
    return snr

# Determine the maximum number of timepoints across all waveforms
max_timepoints = max(waveforms.shape[1] for waveforms in all_waveforms)

# Pad arrays to the maximum number of timepoints with NaNs
padded_waveforms = [np.pad(waveforms, ((0, 0), (0, max_timepoints - waveforms.shape[1])), mode='constant', constant_values=np.nan)
                    for waveforms in all_waveforms]

# Filter waveforms with SNR >= 3
for waveforms in padded_waveforms:
    snr = calculate_snr(waveforms)
    if snr >= 3:
        no_noise_waveforms.append(waveforms)

# Combine the filtered waveforms into a single 2D Numpy array if any are found
if no_noise_waveforms:
    filtered_W = np.vstack(no_noise_waveforms)
    print(f"Filtered waveform shape: {filtered_W.shape}")
else:
    print("No waveforms with SNR >= 3 were found.")

# For debugging and display purposes, you can also print the SNR values of all waveforms
for i, waveforms in enumerate(padded_waveforms):
    snr = calculate_snr(waveforms)
    print(f"Waveform {i} SNR: {snr}")

Filtered waveform shape: (257152, 40)
Waveform 0 SNR: 6.2093730912001135
Waveform 1 SNR: 7.192345780132867
Waveform 2 SNR: 5.006380243680825
Waveform 3 SNR: 6.128031714773328
Waveform 4 SNR: 3.4471185002757503
Waveform 5 SNR: 4.568458520622393
Waveform 6 SNR: 5.476385349031534
Waveform 7 SNR: 2.5879878802736367
Waveform 8 SNR: 4.649149235301674
Waveform 9 SNR: 7.711252008938313


In [24]:
len(no_noise_waveforms)

9

In [25]:
len(all_waveforms)

10

In [26]:
cluster_waveforms

{9: array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
          0.89965886,  0.1729701 ],
        [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
          1.6655885 ,  1.0949929 ],
        [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
          0.83251476,  1.9278888 ],
        ...,
        [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
          3.7053027 ,  0.17694335],
        [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
          0.7814573 ,  0.6139724 ],
        [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
          1.1730937 ,  1.1040318 ]], dtype=float32),
 13: array([[0.74050546, 1.1471487 , 1.1523566 , ..., 2.8801389 , 2.9853294 ,
         3.0963612 ],
        [0.33376965, 0.1129124 , 0.25598168, ..., 1.8844459 , 2.5375013 ,
         3.7338994 ],
        [3.7924511 , 3.189696  , 2.8261626 , ..., 5.365091  , 5.273031  ,
         4.5386353 ],
        ...,
        [1.6787068 , 2.0034997 , 2.0164783 , ..., 

In [27]:
cluster_waveforms[9]

array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
         0.89965886,  0.1729701 ],
       [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
         1.6655885 ,  1.0949929 ],
       [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
         0.83251476,  1.9278888 ],
       ...,
       [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
         3.7053027 ,  0.17694335],
       [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
         0.7814573 ,  0.6139724 ],
       [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
         1.1730937 ,  1.1040318 ]], dtype=float32)

In [28]:
len(cluster_waveforms)

10

In [29]:
len(cluster_waveforms[9])

13768

In [30]:
len(cluster_waveforms[9][0])

40

In [31]:
def calculate_snr(waveforms):
    # Calculate the mean waveform across all spikes in a single cluster, ignoring NaNs
    W_bar = np.nanmean(waveforms, axis=0)
    # Calculate the signal amplitude (max - min of the mean waveform), ignoring NaNs
    sig_amp = np.nanmax(W_bar) - np.nanmin(W_bar)
    # Subtract the mean waveform from each spike to get the noise
    noise = waveforms - np.tile(W_bar, (waveforms.shape[0], 1))
    # Calculate the signal-to-noise ratio, ignoring NaNs
    snr = sig_amp / (2 * np.nanstd(noise.flatten()))
    return snr

# Initialize a dictionary to hold the SNR values for each cluster
snr_dict = {}

# Iterate through each unit/spike_cluster in the cluster_waveforms dictionary
for unit, spikes in cluster_waveforms.items():
    if spikes.ndim == 2:
        # Calculate SNR for each cluster's waveforms
        snr = calculate_snr(spikes)
        # Store the SNR value in the dictionary with its corresponding unit
        snr_dict[unit] = snr
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes.shape}")

# Filter and retain only waveforms with SNR >= 3
high_snr_waveforms = {unit: spikes for unit, spikes in cluster_waveforms.items() if snr_dict[unit] >= 3}

# Determine the number of units excluded due to low SNR
excluded_units = len(snr_dict) - len(high_snr_waveforms)

# Print the result
print(f"Number of excluded units due to low SNR: {excluded_units}")

Number of excluded units due to low SNR: 1


In [32]:
high_snr_waveforms[9][0]

array([  0.40537423,   0.9204547 ,   1.3042495 ,   0.66315186,
        -0.83373624,  -1.3520111 ,  -0.45797062,  -0.0220464 ,
        -0.36910826,   0.3050875 ,   1.4935703 ,   0.8006043 ,
        -0.44051605,   1.1329583 ,   4.469434  ,   7.3652925 ,
        10.524259  ,  10.669554  ,   2.2784607 ,  -9.069311  ,
       -12.297354  ,  -8.102026  ,  -4.433738  ,  -3.3412938 ,
        -3.1506605 ,  -3.1593354 ,  -2.5206325 ,  -1.4534206 ,
        -1.2984093 ,  -1.4807814 ,  -0.7342157 ,   0.31443456,
         0.7896247 ,   0.92220134,   1.1487566 ,   1.8166375 ,
         2.5917692 ,   2.269646  ,   0.89965886,   0.1729701 ],
      dtype=float32)

In [33]:
model

<phylib.io.model.TemplateModel at 0x216f2d3d2b0>

In [34]:
type(model)

phylib.io.model.TemplateModel

In [35]:
print(dir(model))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_amplitudes', '_channels', '_compute_wmi', '_find_best_channels', '_find_path', '_get_template_dense', '_get_template_from_spikes', '_get_template_sparse', '_load_amplitudes', '_load_channel_map', '_load_channel_positions', '_load_channel_probes', '_load_channel_shanks', '_load_data', '_load_features', '_load_metadata', '_load_similar_templates', '_load_spike_attributes', '_load_spike_clusters', '_load_spike_reorder', '_load_spike_samples', '_load_spike_templates', '_load_spike_waveforms', '_load_template_features', '_load_templates', '_load_traces', '_load_wm', '_load_wmi', '_read_array', '_template_n_channels', '_unwhiten', '_waveform_durations

In [36]:
model.spike_times

array([1.9500000e-03, 3.1000000e-03, 7.3500000e-03, ..., 3.1796972e+03,
       3.1797020e+03, 3.1797067e+03])

In [37]:
len(model.spike_times)

719400

In [38]:
len(model.spike_clusters)

719400

In [39]:
model.spike_clusters

array([18,  9,  7, ...,  8,  8,  3])

In [40]:
model.spike_times[0]

0.00195

In [41]:
model.spike_times[1]

0.0031

In [42]:
spike_times = np.load(r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\spike_times.npy')

In [43]:
spike_times[0]

array([39], dtype=int64)

In [44]:
spike_times[1]

array([62], dtype=int64)

In [45]:
model.amplitudes

array([ -8.934715 , -12.297354 ,  -6.8265376, ...,  -6.7597466,
        -6.7344522,  -7.820067 ], dtype=float32)

In [46]:
len(model.amplitudes)

719400

In [47]:
model.cluster_ids

array([  1,   2,   3,   4,   7,   8,   9,  12,  13,  18,  19,  20,  21,
        93, 104, 115, 117, 123, 127, 143, 144, 147, 148, 149, 150, 154,
       171, 189, 190, 193, 194, 195])

In [48]:
model.get_cluster_channels(9)

array([16, 18,  6, 17, 11,  7,  9, 25, 14, 29,  8, 15, 21, 22, 26, 20, 12,
       23, 27,  3, 24, 30, 13, 19,  2, 31,  4,  1,  0, 28, 10,  5],
      dtype=uint32)

In [49]:
model.get_cluster_channels(190)

array([ 6, 18, 16, 17, 11,  7,  9, 14, 13, 25, 22,  8, 26, 15,  2, 29, 12,
       30, 20, 23, 24, 27, 21,  3, 19, 28,  0, 31,  4,  5,  1, 10],
      dtype=uint32)

In [50]:
model.clusters_waveforms_durations

array([ 0.6 ,  0.7 ,  0.15,  0.7 , -0.15,  0.6 , -0.15,  0.65,  0.2 ,
       -0.15,  0.6 , -0.15,  0.75,  0.7 , -0.15,  0.7 , -0.15,  0.7 ,
        0.25,  0.8 ,  0.8 , -0.15])

In [51]:
model.wm

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

In [52]:
len(model.wm)

32

In [53]:
len(model.wm[0])

32

In [54]:
model.wmi

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

need:\
    the previous step needs to clusters/waveforms identifiable by original #\
    `spike_train` for each unit\
    `duration` is the max - min spiketrain, but needs to be in seconds I think?

In [59]:
model.spike_times

array([1.9500000e-03, 3.1000000e-03, 7.3500000e-03, ..., 3.1796972e+03,
       3.1797020e+03, 3.1797067e+03])

In [60]:
model.spike_clusters

array([18,  9,  7, ...,  8,  8,  3])

In [61]:
model.spike_times[0]

0.00195

In [62]:
spike_times

array([[      39],
       [      62],
       [     147],
       ...,
       [63593944],
       [63594040],
       [63594134]], dtype=int64)

In [63]:
spike_clusters = np.load(r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\spike_clusters.npy')

In [64]:
spike_clusters

array([18,  9,  7, ...,  8,  8,  3])

In [65]:
len(spike_times)

719400

In [66]:
len(spike_clusters)

719400

In [69]:
end_time = int(spike_times[-1])

In [70]:
end_time

63594134

In [71]:
# Initialize an empty dictionary to hold the spike times for each cluster
cluster_spike_times = {}

# Iterate through each unique cluster ID
for cluster_id in np.unique(spike_clusters):
    # Get the indices of the spikes that belong to the current cluster
    cluster_indices = np.where(spike_clusters == cluster_id)[0]
    # Get the spike times for these indices
    cluster_spike_times[cluster_id] = spike_times[cluster_indices]