In [1]:
import pandas as pd
import numpy as np
import concurrent.futures
import sys
import matplotlib.pyplot as plt
from phylib.io.model import load_model
from phylib.utils.color import selected_cluster_color

In [2]:
# Read the TSV file
cluster_info = pd.read_csv(r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\cluster_info.tsv', sep='\t')

good_clusters = np.intersect1d(
    cluster_info['cluster_id'][cluster_info['group'] == 'good'],
    cluster_info['cluster_id'][cluster_info['fr'] > 0.5]
)

In [3]:
# Directly specify the path to params.py
params_path = r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\params.py'

# Load the TemplateModel
model = load_model(params_path)

In [4]:
cluster_waveforms = {}

for cluster_id in good_clusters:
    # Get the waveforms for the cluster
    waveforms = model.get_cluster_spike_waveforms(cluster_id)
    
    # Get the channel IDs and find the best channel (assuming it's the first one returned by get_cluster_channels)
    channel_ids = model.get_cluster_channels(cluster_id)
    
    # Get the waveform for just the best channel. Since the data is already in 40 time points, no subsampling is needed
    best_channel_waveforms = waveforms[:, :, 0]  # Since channel_ids[0] is the best channel
    # Store the waveforms in the dictionary. 
    cluster_waveforms[cluster_id] = best_channel_waveforms 

In [5]:
cluster_waveforms

{9: array([[ 0.40537423,  0.9204547 ,  1.3042495 , ...,  2.269646  ,
          0.89965886,  0.1729701 ],
        [ 1.4716856 , -0.18663864, -1.8644416 , ...,  1.7085094 ,
          1.6655885 ,  1.0949929 ],
        [ 1.2226917 ,  0.27370277,  0.15729961, ..., -0.11720731,
          0.83251476,  1.9278888 ],
        ...,
        [-0.6192026 , -1.5477538 , -0.3860595 , ...,  2.0028834 ,
          3.7053027 ,  0.17694335],
        [-1.674623  , -0.97858775, -0.06450201, ...,  1.6264546 ,
          0.7814573 ,  0.6139724 ],
        [-0.1571035 , -0.70310694, -0.836848  , ...,  2.4313078 ,
          1.1730937 ,  1.1040318 ]], dtype=float32),
 13: array([[0.74050546, 1.1471487 , 1.1523566 , ..., 2.8801389 , 2.9853294 ,
         3.0963612 ],
        [0.33376965, 0.1129124 , 0.25598168, ..., 1.8844459 , 2.5375013 ,
         3.7338994 ],
        [3.7924511 , 3.189696  , 2.8261626 , ..., 5.365091  , 5.273031  ,
         4.5386353 ],
        ...,
        [1.6787068 , 2.0034997 , 2.0164783 , ..., 

In [6]:
def process_cluster(cluster_id):
    waveforms = model.get_cluster_spike_waveforms(cluster_id)
    channel_ids = model.get_cluster_channels(cluster_id)
    best_channel_waveforms = waveforms[:, :, 0]  # Assuming channel_ids[0] is already the best channel
    return (cluster_id, best_channel_waveforms)

# Dictionary to store the results
cluster_waveforms = {}

# Use ThreadPoolExecutor to handle clusters in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_to_cluster = {executor.submit(process_cluster, cluster_id): cluster_id for cluster_id in good_clusters}
    for future in concurrent.futures.as_completed(future_to_cluster):
        cluster_id, best_channel_waveforms = future.result()
        cluster_waveforms[cluster_id] = best_channel_waveforms

In [7]:
cluster_waveforms

{117: array([[ 3.0652988e-03,  7.6683294e-03,  2.0962958e-01, ...,
          3.4154861e+00,  2.4242194e+00,  2.3850234e+00],
        [ 8.5359031e-01,  7.1648288e-01,  1.0315886e+00, ...,
          1.2333052e+00,  1.5011677e+00,  7.0454276e-01],
        [-7.1695483e-01, -4.2525470e-01, -8.0104136e-01, ...,
          1.6927334e+00,  2.1589284e+00,  2.2851896e+00],
        ...,
        [-1.9302584e+00, -7.6183051e-01,  1.4385498e+00, ...,
          2.2544894e+00,  1.2180524e+00,  7.4693960e-01],
        [ 3.5493067e-01,  5.8153677e-01,  6.3269429e-02, ...,
          2.1838758e+00,  2.3007801e+00,  1.8068649e+00],
        [ 1.4697404e+00,  9.6029615e-01,  2.2761877e+00, ...,
          1.1059222e+00, -2.0028825e-01, -1.4888754e+00]], dtype=float32),
 190: array([[-0.3583337 , -1.0804446 , -0.22278394, ...,  3.1371636 ,
          3.3423917 ,  2.3632755 ],
        [ 1.2022593 ,  1.7476974 ,  1.4472492 , ...,  2.7672553 ,
          0.4109821 , -0.5041535 ],
        [ 2.1082544 ,  2.5640907 ,  

In [8]:
def calculate_snr(waveforms):
    # Calculate the mean waveform across all spikes in a single cluster, ignoring NaNs
    W_bar = np.nanmean(waveforms, axis=0)
    # Calculate the signal amplitude (max - min of the mean waveform), ignoring NaNs
    sig_amp = np.nanmax(W_bar) - np.nanmin(W_bar)
    # Subtract the mean waveform from each spike to get the noise
    noise = waveforms - np.tile(W_bar, (waveforms.shape[0], 1))
    # Calculate the signal-to-noise ratio, ignoring NaNs
    snr = sig_amp / (2 * np.nanstd(noise.flatten()))
    return snr

# Initialize a dictionary to hold the SNR values for each cluster
snr_dict = {}

# Iterate through each unit/spike_cluster in the cluster_waveforms dictionary
for unit, spikes in cluster_waveforms.items():
    if spikes.ndim == 2:
        # Calculate SNR for each cluster's waveforms
        snr = calculate_snr(spikes)
        # Store the SNR value in the dictionary with its corresponding unit
        snr_dict[unit] = snr
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes.shape}")

# Filter and retain only waveforms with SNR >= 3
high_snr_waveforms = {unit: spikes for unit, spikes in cluster_waveforms.items() if snr_dict[unit] >= 3}

# Determine the number of units excluded due to low SNR
excluded_units = len(snr_dict) - len(high_snr_waveforms)

# Print the result
print(f"Number of excluded units due to low SNR: {excluded_units}")

Number of excluded units due to low SNR: 1


In [12]:
def process_cluster(cluster_id):
    waveforms = model.get_cluster_spike_waveforms(cluster_id)[:, :, 0]
    return (cluster_id, waveforms)

# Dictionary to store the results
cluster_waveforms = {}

# Use ThreadPoolExecutor to handle clusters in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_to_cluster = {executor.submit(process_cluster, cluster_id): cluster_id for cluster_id in good_clusters}
    for future in concurrent.futures.as_completed(future_to_cluster):
        cluster_id, best_channel_waveforms = future.result()
        cluster_waveforms[cluster_id] = best_channel_waveforms

In [13]:
cluster_waveforms.keys()

dict_keys([117, 190, 150, 13, 115, 9, 154, 123, 18, 20])