In [9]:
import pandas as pd
import numpy as np
import concurrent.futures
import sys
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from wavemap_paper.helper_functions import RAND_STATE, set_rand_state
from phylib.io.model import load_model
from umap import umap_ as umap
from phylib.utils.color import selected_cluster_color
import networkx as nx

In [2]:
# Read the TSV file
cluster_info = pd.read_csv(r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\cluster_info.tsv', sep='\t')

good_clusters = np.intersect1d(
    cluster_info['cluster_id'][cluster_info['group'] == 'good'],
    cluster_info['cluster_id'][cluster_info['fr'] > 0.5]
)

In [3]:
# Directly specify the path to params.py
params_path = r'.\test1\20240320_142408_alone_comp_subj_3-1_t6b6_merged.rec\phy\params.py'

# Load the TemplateModel
model = load_model(params_path)

In [4]:
def process_cluster(cluster_id):
    waveforms = model.get_cluster_spike_waveforms(cluster_id)[:, :, 0]
    return (cluster_id, waveforms)

# Dictionary to store the results
cluster_waveforms = {}

# Use ThreadPoolExecutor to handle clusters in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_to_cluster = {executor.submit(process_cluster, cluster_id): cluster_id for cluster_id in good_clusters}
    for future in concurrent.futures.as_completed(future_to_cluster):
        cluster_id, best_channel_waveforms = future.result()
        cluster_waveforms[cluster_id] = best_channel_waveforms

In [5]:
def calculate_snr(waveforms):
    # Calculate the mean waveform across all spikes in a single cluster, ignoring NaNs
    W_bar = np.nanmean(waveforms, axis=0)
    # Calculate the signal amplitude (max - min of the mean waveform), ignoring NaNs
    sig_amp = np.nanmax(W_bar) - np.nanmin(W_bar)
    # Subtract the mean waveform from each spike to get the noise
    noise = waveforms - np.tile(W_bar, (waveforms.shape[0], 1))
    # Calculate the signal-to-noise ratio, ignoring NaNs
    snr = sig_amp / (2 * np.nanstd(noise.flatten()))
    return snr

# Initialize a dictionary to hold the SNR values for each cluster
snr_dict = {}

# Iterate through each unit/spike_cluster in the cluster_waveforms dictionary
for unit, spikes in cluster_waveforms.items():
    if spikes.ndim == 2:
        # Calculate SNR for each cluster's waveforms
        snr = calculate_snr(spikes)
        # Store the SNR value in the dictionary with its corresponding unit
        snr_dict[unit] = snr
    else:
        print(f"Skipping unit {unit} due to unexpected shape: {spikes.shape}")

# Filter and retain only waveforms with SNR >= 3
high_snr_waveforms = {unit: spikes for unit, spikes in cluster_waveforms.items() if snr_dict[unit] >= 3}

# Determine the number of units excluded due to low SNR
excluded_units = len(snr_dict) - len(high_snr_waveforms)

# Print the result
print(f"Number of excluded units due to low SNR: {excluded_units}")

Number of excluded units due to low SNR: 1


In [6]:
mean_waveforms = {}
for key, waveforms in high_snr_waveforms.items():
    # Calculate the mean along the first axis (axis=0) to average all spikes
    mean_waveform = np.mean(waveforms, axis=0)
    mean_waveforms[key] = mean_waveform

In [12]:
# Initialize a dictionary to hold the normalized waveforms
normWFs = {}

for cluster_id, waveform in mean_waveforms.items():
    # Mean subtraction
    mean_subtracted_waveform = waveform - np.mean(waveform)
    
    # Normalization: reshape the waveform to have a shape of (1, length of waveform)
    # for compatibility with sklearn's normalize function
    reshaped_waveform = mean_subtracted_waveform.reshape(1, -1)
    
    # Normalize the waveform to have a maximum value of 1
    normalized_waveform = normalize(reshaped_waveform, norm='max')
    
    # Store the normalized waveform (squeeze to remove single-dimensional entries from the shape)
    normWFs[cluster_id] = normalized_waveform.squeeze()

In [13]:
set_rand_state(RAND_STATE)

In [14]:
reducer = umap.UMAP(random_state = RAND_STATE, n_neighbors = 15)

mapper = reducer.fit(normWFs)

TypeError: float() argument must be a string or a number, not 'dict'