In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel

In [None]:
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.widgets as sw
from pprint import pprint
import numpy as np
from pathlib import Path
import os

import warnings
warnings.simplefilter("ignore")

from spikeinterface import NumpyRecording, NumpySorting
from spikeinterface import append_recordings, concatenate_recordings

# %matplotlib widget
%matplotlib inline


In [None]:
# Root folder holding the two Open Ephys sessions ("before" and "after").
base_folder = Path('/Users/shiqi/Desktop/spikeinterface/POD/o1_part6')
# base_folder = Path("/home/alessio/Documents/data/spiketutorials/Official_Tutorial_SI_0.99_Nov23/")
oe_folder1 = base_folder / 'before'
oe_folder2 = base_folder / 'after'

# Load both sessions.
full_raw_rec1 = si.read_openephys(oe_folder1)
full_raw_rec2 = si.read_openephys(oe_folder2)
fs = full_raw_rec1.get_sampling_frequency()

# Keep the first `sub_duration_s` seconds of each session.
# Frame indices must be integers; `fs` may be a float, so convert explicitly
# instead of passing `10*fs` straight through.
sub_duration_s = 10
sub_end_frame = int(sub_duration_s * fs)
full_raw_rec1_sub = full_raw_rec1.frame_slice(start_frame=0, end_frame=sub_end_frame)
full_raw_rec2_sub = full_raw_rec2.frame_slice(start_frame=0, end_frame=sub_end_frame)

In [None]:
# Sampling frequency, taken from the first ("before") recording.
fs = full_raw_rec1.get_sampling_frequency()
# Concatenate the two sub-recordings, "before" first, then "after".
full_raw_rec = si.concatenate_recordings([full_raw_rec1_sub, full_raw_rec2_sub])
#full_raw_rec = full_raw_rec1_sub
#full_raw_rec.get_probe().to_dataframe()

In [None]:
import scipy.signal as signal
from scipy.ndimage import gaussian_filter1d
# Preprocessing chain: band-pass -> drop bad channels -> phase shift -> median reference.
# 1) Band-pass filter between 300 Hz and 6000 Hz.
rec1 = si.bandpass_filter(full_raw_rec, freq_min=300, freq_max=6000)
# 2) Detect bad channels on the filtered data and remove them.
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1)
rec2 = rec1.remove_channels(bad_channel_ids)
print('bad_channel_ids', bad_channel_ids)

# 3) Phase shift, then 4) global median common reference.
rec3 = si.phase_shift(rec2)
rec4 = si.common_reference(rec3, operator="median", reference="global")
# `rec` is the preprocessed recording used by the rest of the notebook.
rec = rec4

In [None]:
#sorting_SC = se.read_spykingcircus(folder_path="/Users/shiqi/Desktop/spikeinterface/POD/y_part7/folder_SC_y_p7/sorter_output")
# Keyword arguments unpacked (**job_kwargs) into the chunked SpikeInterface calls in this notebook.
# NOTE(review): n_jobs=40 is hard-coded; confirm it suits the machine running the notebook.
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)

In [None]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
# Detect peaks on the preprocessed recording (threshold 6, 50 um exclusion radius).
peaks = detect_peaks(rec,  method='locally_exclusive', detect_threshold=6, radius_um=50., **job_kwargs)
# Estimate a position for each detected peak with the centre-of-mass method.
peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)

In [None]:
# Position and time (in seconds) of every detected peak.
spike_x_locations = peak_locations['x']
spike_y_locations = peak_locations['y']
spike_times = peaks['sample_index'] / fs

# Keep only the peaks within the first `time_window` seconds.
time_window = 300  # seconds
valid_time_indices = spike_times <= time_window
spike_times = spike_times[valid_time_indices]
spike_x_locations = spike_x_locations[valid_time_indices]
spike_y_locations = spike_y_locations[valid_time_indices]

# Number of shanks; one horizontal subplot per shank.
num_shanks = 4
fig, axes = plt.subplots(1, num_shanks, figsize=(15, 5))

# Split the observed x-range into `num_shanks` equal-width intervals.
shank_boundaries = np.linspace(np.min(spike_x_locations), np.max(spike_x_locations), num_shanks + 1)

# One colour per shank.
colors = [(1, 0.5, 0.5), (1, 0.5, 0), (0, 0.75, 0.5), (0.5, 0.5, 1)]

# Assign each peak to a shank by its x-coordinate and scatter time vs. depth.
for shank_id in range(num_shanks):
    lower = shank_boundaries[shank_id]
    upper = shank_boundaries[shank_id + 1]
    # Intervals are half-open [lower, upper) except the last one, which is
    # closed so peaks sitting exactly at the maximum x are not silently dropped.
    if shank_id == num_shanks - 1:
        below_upper = spike_x_locations <= upper
    else:
        below_upper = spike_x_locations < upper
    shank_mask = (spike_x_locations >= lower) & below_upper

    shank_spike_times = spike_times[shank_mask]
    shank_spike_y_locations = spike_y_locations[shank_mask]

    ax = axes[shank_id]
    ax.scatter(shank_spike_times, shank_spike_y_locations, s=1, color=colors[shank_id])

    # Remove spines (boundaries)
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

# Adjust layout, save and show.
plt.tight_layout()
plt.savefig('spike_plot.png', format='png', dpi=300)
plt.show()

In [None]:
#sorting_SC = se.SpykingCircusSortingExtractor(base_folder / 'folder_POD/sorter_output')

In [None]:
from spikeinterface.sorters import run_sorter
# Run Spyking Circus on the preprocessed recording.
# `folder=` replaces the deprecated `output_folder=` keyword of run_sorter.
sorting_SC = run_sorter(
    sorter_name="spykingcircus",
    recording=rec,
    folder=base_folder / 'folder_part6',
    verbose=True,
    # Sorter-specific parameters: negative-going spikes, threshold 6, and no
    # sorter-side filtering (the recording was band-passed earlier in the notebook).
    detect_sign=-1,
    detect_threshold=6,
    filter=False,
)

In [None]:
sorting_SC

In [None]:
# Build a SortingAnalyzer for (sorting_SC, rec), stored as a binary folder under base_folder.
sorting_analyzer = si.create_sorting_analyzer(sorting_SC, rec,
                                              format="binary_folder", folder=base_folder / 'my_sorting_analyzer_part6',
                                              **job_kwargs)
# Compute the extensions used later; the value returned by each compute() call is kept.
# Uniformly sample at most 500 spikes per unit.
sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
waveform=sorting_analyzer.compute("waveforms", **job_kwargs)
templates=sorting_analyzer.compute("templates", **job_kwargs)
noise=sorting_analyzer.compute("noise_levels")
unit_loc=sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
spike_loc = sorting_analyzer.compute(input="spike_locations",method="center_of_mass")
isi=sorting_analyzer.compute("isi_histograms")
# Correlograms with window_ms=100 and bin_ms=0.5.
corr=sorting_analyzer.compute("correlograms", window_ms=100, bin_ms=0.5)
#pc=sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs)
#quality_metrics=sorting_analyzer.compute("quality_metrics", metric_names=["snr", "firing_rate"])
similarity=sorting_analyzer.compute("template_similarity")
amplitute=sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

In [None]:
# Fetch the traces of the preprocessed recording (no frame or channel limits are passed).
# NOTE(review): this presumably loads the whole recording into memory — confirm it fits.
traces = rec.get_traces()

In [None]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# Snippet to display: `time_window` seconds starting at `start_sample`,
# taken from trace columns 130..139 (ten columns).
start_sample = 12000
channel_range = range(130, 140)
time_window = 2  # seconds
num_samples = int(fs * time_window)
end_sample = start_sample + num_samples

selected_traces = traces[start_sample:end_sample, channel_range]

# Time axis in seconds, relative to the start of the recording.
time_axis = np.arange(start_sample, end_sample) / fs

plt.figure(figsize=(15, 10))
cmap = plt.get_cmap('cool', len(channel_range))

# Draw one trace per column, each shifted up by 800 units more than the previous one.
for i, channel in enumerate(channel_range):
    shifted_trace = selected_traces[:, i] + i * 800
    plt.plot(time_axis, shifted_trace, color=cmap(i), label=f'Channel {channel + 1}')

plt.xlabel('Time (s)')
plt.yticks([])

plt.savefig('raw_traces_channels_10_to_21_10s_gradient.png', format='png', dpi=300)

plt.show()

In [None]:
# Promote the snippet to float64 (rather than dtype=object) so the offsets,
# negation and statistics applied later stay vectorised and cannot overflow a
# narrow integer dtype.
selected_traces = np.asarray(selected_traces, dtype=np.float64)
selected_traces.shape

In [None]:
from scipy.signal import find_peaks
num_channels = selected_traces.shape[1]
all_peaks = {}  # "Channel_<n>" -> {'peaks': sample indices, 'peak_heights': heights}

for i in range(num_channels):
    channel_trace = selected_traces[:, i]

    # Negative detection threshold: -3 standard deviations of this trace.
    std_dev = np.std(channel_trace)
    threshold = -3 * std_dev

    # find_peaks looks for maxima, so negate the trace to find troughs that are
    # deeper than the threshold. The result goes into a local name so the
    # `peaks` array from detect_peaks is not overwritten.
    trough_indices, properties = find_peaks(-channel_trace, height=-threshold)

    # Indices are relative to the start of `selected_traces`.
    all_peaks[f"Channel_{i + channel_range.start}"] = {
        'peaks': trough_indices,
        'peak_heights': properties['peak_heights']
    }

# Summary of how many peaks were found per channel.
for channel, peak_info in all_peaks.items():
    print(f"{channel}: Found {len(peak_info['peaks'])} peaks")

In [None]:
# Convert peak indices to times in seconds. The indices are relative to the
# snippet, so add `start_sample` to get recording time (the same convention
# as the trace plots).
time_points = {}
for channel, peak_info in all_peaks.items():
    time_points[channel] = (peak_info['peaks'] + start_sample) / fs

# One row of points per channel.
plt.figure(figsize=(10, 6))

for i, (channel, times) in enumerate(time_points.items()):
    plt.scatter(times, [i+channel_range.start] * len(times), label=channel)

# Adding labels and title
plt.xlabel('Time (s)')
plt.ylabel('Channel')
plt.title('Scatter Plot of Peak Time Points')
plt.yticks(range(channel_range.start, channel_range.stop))  # Set y-axis ticks to the channel numbers
plt.legend(title="Channels")
plt.grid(True)
plt.show()

In [None]:
# Use the sampling frequency read from the recording instead of overriding
# it with a hard-coded 30000 Hz.
num_samples = selected_traces.shape[0]
time_axis = np.arange(start_sample, end_sample) / fs  # seconds, recording time

# Create the figure with desired size
plt.figure(figsize=(15, 10))

# One colour per channel.
cmap = plt.get_cmap('cool', len(channel_range))

# Plot each trace with a vertical offset and overlay its detected peaks.
for i, channel in enumerate(channel_range):
    offset = i * 800
    plt.plot(time_axis, selected_traces[:, i] + offset, color=cmap(i), label=f'Channel {channel}')

    channel_label = f"Channel_{channel}"
    if channel_label in all_peaks:
        peak_indices = all_peaks[channel_label]['peaks']
        # Peak indices are relative to the snippet; shift them to recording time.
        peak_times = (peak_indices + start_sample) / fs
        peak_amplitudes = selected_traces[peak_indices, i] + offset  # same offset as the trace
        plt.scatter(peak_times, peak_amplitudes, color=cmap(i), edgecolor='k', zorder=3)

# Adding labels and title
plt.xlabel('Time (s)')
plt.ylabel('Amplitude + Offset')
plt.title('Traces with Peak Detection')
plt.yticks([])  # Remove y-ticks for clarity
plt.legend(title="Channels", bbox_to_anchor=(1.05, 1), loc='upper left')  # Place legend outside

# Save the figure
plt.savefig('raw_traces_with_peaks.png', format='png', dpi=300)

# Show the figure
plt.show()

In [None]:
time_window_ms = 3  # 3 ms window
window_samples = int(fs * (time_window_ms / 1000))  # Number of samples for 3 ms window
half_window = window_samples // 2  # Half window for centering

# "Channel_<n>" -> {'waveforms': array of snippets cut around each peak}
extracted_spikes = {}

# Peak indices and column offsets in `all_peaks` refer to `selected_traces`
# (the snippet the peaks were detected on), so the waveforms are cut from that
# array. Indexing the full `traces` array with them would read a different
# time span and different channels.
snippet_length = selected_traces.shape[0]
for channel in channel_range:
    channel_label = f"Channel_{channel}"
    if channel_label in all_peaks:
        peak_indices = all_peaks[channel_label]['peaks']
        column = channel - channel_range.start

        channel_waveforms = []
        for peak in peak_indices:
            start_idx = peak - half_window
            end_idx = peak + half_window

            # Skip peaks whose window would run past either end of the snippet.
            if start_idx >= 0 and end_idx < snippet_length:
                channel_waveforms.append(selected_traces[start_idx:end_idx, column])

        extracted_spikes[channel_label] = {
            'waveforms': np.array(channel_waveforms)
        }

# Check the extracted spikes for each channel
for channel, data in extracted_spikes.items():
    print(f"{channel}: {len(data['waveforms'])} spikes extracted.")

In [None]:
# Window cut around each peak: 3 ms in total, centred on the peak sample.
time_window_ms = 3
window_samples = int(fs * (time_window_ms / 1000))
half_window = window_samples // 2

# "Channel_<n>" -> 2-D array of waveforms, one row per kept peak.
extracted_waveforms = {}

for channel in channel_range:
    channel_label = f"Channel_{channel}"
    if channel_label not in all_peaks:
        continue

    peak_indices = all_peaks[channel_label]['peaks']
    waveforms = []
    for peak in peak_indices:
        start_idx, end_idx = peak - half_window, peak + half_window
        # Drop peaks whose window does not fit inside the snippet.
        if not (start_idx >= 0 and end_idx < selected_traces.shape[0]):
            continue
        waveform = selected_traces[start_idx:end_idx, channel - channel_range.start]
        waveforms.append(waveform)

    extracted_waveforms[channel_label] = np.array(waveforms)

# Report how many waveforms were kept per channel and their shape.
for channel, waveforms in extracted_waveforms.items():
    print(f"{channel}: {waveforms.shape[0]} spikes extracted, each with shape {waveforms.shape[1:]}.")

In [None]:
# Time axis (ms) of a single extracted window.
single_time_axis = np.linspace(-half_window, half_window, window_samples) / fs * 1000

# One colour per channel.
cmap = plt.get_cmap('cool', len(channel_range))

# Channel labels ordered by channel number, highest first.
sorted_channels = sorted(extracted_waveforms.keys(), key=lambda x: int(x.split('_')[1]), reverse=True)

# One subplot per channel. squeeze=False always returns a 2-D array of axes,
# so a single channel needs no special handling.
fig, axes_grid = plt.subplots(len(sorted_channels), 1, figsize=(15, len(sorted_channels) * 3),
                              sharey=True, squeeze=False)
axes = axes_grid[:, 0]

# Gap, in samples, left between consecutive waveforms on the x-axis.
waveform_gap = 10

for i, channel_label in enumerate(sorted_channels):
    waveforms = extracted_waveforms[channel_label]

    # Colour index is the channel's position within `channel_range`.
    channel_index = int(channel_label.split('_')[1]) - min(channel_range)
    channel_color = cmap(channel_index)

    # Lay the waveforms out one after another along the x-axis.
    x_offset = 0
    for waveform in waveforms:
        flattened_waveform = waveform.flatten()
        time_axis = np.arange(len(flattened_waveform)) + x_offset
        axes[i].plot(time_axis, flattened_waveform, alpha=0.8, color=channel_color)
        x_offset += len(flattened_waveform) + waveform_gap

    # Strip grid, ticks and frame.
    axes[i].grid(False)
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    axes[i].set_frame_on(False)

# The title reports the window length actually used (time_window_ms), not a fixed value.
plt.suptitle(f'Flattened {time_window_ms} ms Window Spikes for Each Channel (Sequential on X-axis)', fontsize=16)
plt.subplots_adjust(top=0.9, bottom=0.1, left=0.05, right=0.95, hspace=0.4)  # Adjust layout
plt.show()

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np

# Step 1: pool the waveforms of every channel into one list, remembering the
# channel each waveform came from.
all_waveforms = []
channel_labels = []

for channel_label, waveforms in extracted_waveforms.items():
    flattened_waveforms = []
    for waveform in waveforms:
        flattened_waveforms.append(waveform.flatten())
    all_waveforms.extend(flattened_waveforms)
    channel_labels.extend([channel_label] * len(flattened_waveforms))

all_waveforms = np.array(all_waveforms)
print(f"Total waveforms shape: {all_waveforms.shape}")

# Step 2: project the waveforms onto their first principal components.
n_components = 3
pca = PCA(n_components=n_components)
pca_result = pca.fit_transform(all_waveforms)

explained_variance = pca.explained_variance_ratio_
print(f"Explained variance ratio: {explained_variance}")

# Step 3: cluster the projected waveforms with K-means (fixed seed).
n_clusters = 4  # You can change this based on expected neuron types
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(pca_result)

# Step 4: first two principal components, coloured by cluster.
fig, ax = plt.subplots(figsize=(10, 7))
pc1, pc2 = pca_result[:, 0], pca_result[:, 1]
scatter = ax.scatter(pc1, pc2, c=cluster_labels, cmap='viridis', alpha=0.8, edgecolor='k')

plt.colorbar(scatter, label='Cluster')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('PCA of Extracted Waveforms with K-means Clustering')
plt.grid(True)
plt.show()

# Step 5: per cluster, every member waveform in grey with the mean in red.
fig, axs = plt.subplots(n_clusters, 1, figsize=(10, n_clusters * 3))

for cluster in range(n_clusters):
    cluster_ax = axs[cluster]
    cluster_waveforms = all_waveforms[cluster_labels == cluster]
    average_waveform = np.mean(cluster_waveforms, axis=0)

    for waveform in cluster_waveforms:
        cluster_ax.plot(waveform, color='gray', alpha=0.3)
    cluster_ax.plot(average_waveform, color='red', linewidth=2, label=f'Cluster {cluster + 1} Average')

    cluster_ax.set_title(f'Cluster {cluster + 1} Waveforms')
    cluster_ax.set_xlabel('Sample Index')
    cluster_ax.set_ylabel('Amplitude')
    cluster_ax.legend()
    cluster_ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
from matplotlib.colors import ListedColormap
noise_clusters = [1, 2]  # Replace with the actual noise cluster numbers

# Boolean mask of waveforms that are NOT in a noise cluster (computed once).
is_signal = ~np.isin(cluster_labels, noise_clusters)
filtered_waveforms = all_waveforms[is_signal]
filtered_pca_result = pca_result[is_signal]
filtered_cluster_labels = cluster_labels[is_signal]

# Check the new shape of the filtered data
print(f"Filtered waveforms shape: {filtered_waveforms.shape}")
print(f"Filtered PCA result shape: {filtered_pca_result.shape}")

# Assign a colour to every remaining cluster. The palette is cycled, so this
# works for any number of surviving clusters, not only exactly two.
palette = ['#1E90FF', '#9370DB']
unique_clusters = np.unique(filtered_cluster_labels)
cluster_colors = {cluster: palette[k % len(palette)] for k, cluster in enumerate(unique_clusters)}

# Per-waveform colour for the scatter plot.
cluster_colors_array = [cluster_colors[cluster] for cluster in filtered_cluster_labels]

# First two principal components without the noise clusters.
fig, ax = plt.subplots(figsize=(10, 7))
scatter = ax.scatter(filtered_pca_result[:, 0], filtered_pca_result[:, 1],
                     c=cluster_colors_array, alpha=0.8, edgecolor='k')

# Add labels
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('PCA of Extracted Waveforms without Noise Clusters (Light Blue and Light Purple)')
ax.grid(False)  # Remove grid
plt.show()

# Waveforms and mean waveform of every non-noise cluster, in the cluster colour.
if len(unique_clusters) > 0:
    # squeeze=False keeps `axs` indexable even when only one cluster remains.
    fig, axs_grid = plt.subplots(len(unique_clusters), 1, figsize=(6, len(unique_clusters) * 3), squeeze=False)
    axs = axs_grid[:, 0]

    for i, cluster in enumerate(unique_clusters):
        cluster_waveforms = filtered_waveforms[filtered_cluster_labels == cluster]
        average_waveform = np.mean(cluster_waveforms, axis=0)

        for waveform in cluster_waveforms:
            axs[i].plot(waveform, color=cluster_colors[cluster], alpha=0.3)
        axs[i].plot(average_waveform, color=cluster_colors[cluster], linewidth=2, label=f'Cluster {cluster + 1} Average')

        axs[i].grid(False)  # Remove grid

    plt.tight_layout()
    plt.show()
else:
    print("No clusters left after removing the noise clusters.")

In [None]:
# Compute the template metrics for the sorting analyzer and show the resulting table.
template_metrics = si.compute_template_metrics(sorting_analyzer)
display(template_metrics)

In [None]:
# Take the correlograms from the analyzer built in this notebook, so they
# match sorting_SC's units. The previous code read .npy files from a different
# analyzer folder ('my_sorting_analyzer_part7_2').
# `arr` holds the correlogram counts indexed [unit_i, unit_j, bin]; `bins` the bin edges.
arr, bins = sorting_analyzer.get_extension("correlograms").get_data()

In [None]:
from scipy.optimize import curve_fit
def fit_ACG(acg_narrow, plots=True):
    """Fit a triple-exponential model to the positive half of an auto-correlogram.

    The model is the one used by CellExplorer's ``fit_ACG``::

        max(c*(exp(-(x-f)/a) - d*exp(-(x-f)/b)) + h*exp(-(x-f)/g) + e, 0)

    Parameters
    ----------
    acg_narrow : array-like of 100 values
        Auto-correlogram counts for lags 0.5 ms, 1.0 ms, ..., 50 ms (0.5 ms
        bins). The input is copied and never modified.
    plots : bool, default True
        If True, plot the data together with the fitted curve.

    Returns
    -------
    dict or numpy.ndarray
        On success, a dict with the eight fitted parameters plus
        ``'acg_fit_rsquare'`` (nine entries). If the optimiser fails to
        converge, an array of eight NaNs.

    Raises
    ------
    ValueError
        If ``acg_narrow`` does not contain exactly 100 values.
    """
    x = np.arange(1, 101) / 2.0  # lag of each bin in milliseconds

    # Work on a float copy so the caller's array (often a view into the full
    # correlogram array) is left untouched.
    acg = np.array(acg_narrow, dtype=float)
    if acg.shape != x.shape:
        raise ValueError(f"acg_narrow must contain {x.size} values, got shape {acg.shape}")

    # Zero the bin next to zero lag. The input is already the positive half of
    # the ACG, so that bin is index 0 (not index 99, which is the 50 ms bin).
    acg[0] = 0

    # Initial guess and bounds for (a, b, c, d, e, f, g, h).
    a0 = [20, 1, 30, 2, 0.5, 5, 1.5, 2]
    lb = [1, 0.1, 0, 0, -30, 0, 0.1, 0]
    ub = [500, 50, 500, 15, 50, 20, 5, 100]

    # Returned unchanged if the fit fails.
    fit_params = np.nan * np.ones((8,))

    def predicted(x, a, b, c, d, e, f, g, h):
        # a: decay tau, b: rise tau, c/d: amplitudes, e: asymptote,
        # f: refractory shift, g: burst tau, h: burst amplitude.
        shifted = x - f
        return np.maximum(
            c * (np.exp(-shifted / a) - d * np.exp(-shifted / b))
            + h * np.exp(-shifted / g)
            + e,
            0,
        )

    try:
        popt, _ = curve_fit(predicted, x, acg, p0=a0, bounds=(lb, ub))
    except RuntimeError:
        print("Fit failed. Returning NaN values.")
        return fit_params

    fitted_curve = predicted(x, *popt)

    # R-squared; undefined (NaN) when the data has no variance.
    ss_res = np.sum((acg - fitted_curve) ** 2)
    ss_tot = np.sum((acg - np.mean(acg)) ** 2)
    rsquare = 1 - (ss_res / ss_tot) if ss_tot > 0 else np.nan

    if plots:
        plt.figure(figsize=(10, 6))
        plt.plot(x, acg, label='Auto-correlogram Data', color='b')
        plt.plot(x, fitted_curve, 'r-', label=f'Rise={popt[1]:.3f}, Decay={popt[0]:.3f}')
        plt.title('Exponential Fit of Auto-Correlogram')
        plt.xlabel('Time Bins (ms)')
        plt.ylabel('Spike Count')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    fit_params = {
        'acg_tau_decay': popt[0],
        'acg_tau_rise': popt[1],
        'acg_c': popt[2],
        'acg_d': popt[3],
        'acg_asymptote': popt[4],
        'acg_refrac': popt[5],
        'acg_tau_burst': popt[6],
        'acg_h': popt[7],
        'acg_fit_rsquare': rsquare
    }
    return fit_params

In [None]:
acg_length = 100
num_units = arr.shape[0]

# Positive-lag half of each auto-correlogram: `acg_length` bins starting at
# the centre of the lag axis (index 100 for the 200-bin correlograms used here).
center_bin = arr.shape[2] // 2
acg_bins = slice(center_bin, center_bin + acg_length)

# Fit and plot one example unit. The index is clamped so the cell also runs
# when fewer than 55 units exist.
sample_unit_index = min(54, num_units - 1)
# .copy(): never hand fit_ACG a view into `arr`, so `arr` cannot be altered.
sample_acg = arr[sample_unit_index, sample_unit_index, acg_bins].copy()

# Fit exponential curve and plot results
fit_params = fit_ACG(sample_acg)
print(f"Fit parameters: {fit_params}")

# Fit every unit's auto-correlogram. Results live on the diagonal of a
# (num_units, num_units) object array; off-diagonal entries stay None.
fit_params_units = np.empty((num_units, num_units), dtype=object)

for i in range(num_units):
    acg_unit = arr[i, i, acg_bins].copy()  # auto-correlogram of unit i
    fit_params = fit_ACG(acg_unit, plots=False)
    fit_params_units[i, i] = fit_params

In [None]:
# Collect the ACG rise time of every successfully fitted unit and note the
# row indices whose rise time exceeds 6.
num_units = fit_params_units.shape[0]
units_with_acg_tau_rise_gt_6 = []
tau_rise_values = []
for i in range(num_units):
    fit_params = fit_params_units[i, i]  # fit result stored on the diagonal
    # A successful fit is a dict with 9 entries; anything else is skipped.
    if fit_params is None or len(fit_params_units[i, i]) != 9:
        continue
    tau_rise = fit_params['acg_tau_rise']
    tau_rise_values.append(tau_rise)
    if tau_rise > 6:
        units_with_acg_tau_rise_gt_6.append(i)

# Unit ids in sorting order; row i of the fit results is matched to unit_ids[i].
unit_ids = sorting_SC.get_unit_ids()

# Boolean flag per unit: True when the unit's rise time is above 6.
wide_unit_ids = {unit_ids[j] for j in units_with_acg_tau_rise_gt_6}
result_WI = [unit_id in wide_unit_ids for unit_id in unit_ids]

len(units_with_acg_tau_rise_gt_6)

In [None]:
valid_units, tau_rise_values = [], []

# Keep only units whose fit succeeded (9-entry dict) and gave a non-NaN rise time.
for i in range(num_units):
    fit_params = fit_params_units[i, i]
    if fit_params is None or len(fit_params) != 9:
        continue
    tau_rise = fit_params['acg_tau_rise']
    if np.isnan(tau_rise):
        continue
    valid_units.append(i)
    tau_rise_values.append(tau_rise)

# Unit ids in sorting order.
unit_ids = sorting_SC.get_unit_ids()

# Unit id -> rise time, for the valid units only.
filtered_unit_ids = [unit_ids[i] for i in valid_units]
mapped_tau_rise = dict(zip(filtered_unit_ids, tau_rise_values))

# Boolean flag per unit: True when the unit's rise time is above 6.
wide_unit_ids = {unit_ids[j] for j in units_with_acg_tau_rise_gt_6}
result_WI = [unit_id in wide_unit_ids for unit_id in unit_ids]

len(units_with_acg_tau_rise_gt_6)

In [None]:
# Same scan as for the wide-interneuron flag, but selecting rise times below 6.
num_units = fit_params_units.shape[0]
units_with_acg_tau_rise_lt_6 = []

tau_rise_values = []
for i in range(num_units):
    fit_params = fit_params_units[i, i]  # fit result stored on the diagonal
    if fit_params is None or len(fit_params_units[i, i]) != 9:
        continue
    tau_rise = fit_params['acg_tau_rise']
    tau_rise_values.append(tau_rise)
    if tau_rise < 6:
        units_with_acg_tau_rise_lt_6.append(i)

# Boolean flag per unit: True when the unit's rise time is below 6.
pyramidal_unit_ids = {unit_ids[j] for j in units_with_acg_tau_rise_lt_6}
result_pyramidal = [unit_id in pyramidal_unit_ids for unit_id in unit_ids]
len(units_with_acg_tau_rise_lt_6)

In [None]:
import spikeinterface.extractors as se
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import (
    compute_snrs,
    compute_firing_rates,
    compute_isi_violations,
    calculate_pc_metrics,
    compute_quality_metrics,
    compute_quality_metrics,
    compute_amplitude_medians
)

In [None]:
# Quality metrics that do not need principal components.
metrics1 = compute_quality_metrics(sorting_analyzer, metric_names=["firing_rate", "snr", "amplitude_cutoff", "amplitude_median"])
print(metrics1)
# NOTE(review): index=False leaves the DataFrame index out of the CSV; if the index
# holds the unit ids, they are not saved — confirm this is intended.
metrics1.to_csv(base_folder /"metrics1.csv", index=False)  # Replace with your desired filename

In [None]:
# Compute principal components (3 per channel, whitened) before requesting the
# metrics below.
sorting_analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True)

# Cluster-separation metrics.
metrics2 = compute_quality_metrics(
    sorting_analyzer,
    metric_names=[
        "isolation_distance",
        "d_prime",
    ],
)
# NOTE(review): index=False leaves the DataFrame index out of the CSV — confirm the
# unit ids are not needed in the file.
metrics2.to_csv(base_folder /"metrics2.csv", index=False) 

In [None]:
# Narrow-waveform units: SNR above 3, isolation distance above 50 and a
# peak-to-valley duration below 0.00045.
has_good_snr = metrics1["snr"] > 3
is_well_isolated = metrics2["isolation_distance"] > 50
is_narrow = template_metrics["peak_to_valley"] < 0.00045
keep_mask = has_good_snr & is_well_isolated & is_narrow
print(keep_mask)

# Ids of the units that pass every criterion.
keep_unit_ids1 = list(keep_mask[keep_mask].index.values)
#keep_unit_ids1 = [8, 290, 81, 238, 299, 164, 239, 132, 379, 302, 383, 312, 303, 385, 384, 135, 345, 7, 393, 347, 138, 40, 349, 359, 421, 427, 37, 438, 15, 199, 451, 155, 192]

# Spike train of every kept unit, keyed by unit id.
spike_trains_narrow = {}
for unit_id in keep_unit_ids1:
    spike_trains_narrow[unit_id] = sorting_SC.get_unit_spike_train(unit_id)
print(keep_unit_ids1)
analyzer_narrow = sorting_analyzer.select_units(keep_unit_ids1)

In [None]:
# Sorting restricted to the narrow-waveform units.
narrow_sorting = sorting_SC.select_units(keep_unit_ids1)
print(narrow_sorting)
#narrow_sorting.save(folder=base_folder/'narrow_sorting')

# Unit counts before and after the selection.
n_units_total = len(sorting_SC.get_unit_ids())
print(f"Number of units before curation: {n_units_total}")
n_units_narrow = len(narrow_sorting.get_unit_ids())
print(f"Narrow interneuron: {n_units_narrow}")

In [None]:
# Pyramidal candidates: SNR above 3, isolation distance above 50, a
# peak-to-valley duration above 0.00045 and the ACG-based pyramidal flag.
# `result_pyramidal` is a plain list; pandas deprecates `Series & list`, so it
# is converted to a boolean array first (position i <-> i-th unit).
is_pyramidal_acg = np.asarray(result_pyramidal, dtype=bool)
keep_mask = (
    (metrics1["snr"] > 3)
    & (metrics2["isolation_distance"] > 50)
    & (template_metrics["peak_to_valley"] > 0.00045)
    & is_pyramidal_acg
)
print(keep_mask)

# Ids of the units that pass every criterion.
keep_unit_ids2 = list(keep_mask[keep_mask].index.values)

# Spike train of every kept unit, keyed by unit id.
spike_trains_pyramidal = {unit_id: sorting_SC.get_unit_spike_train(unit_id) for unit_id in keep_unit_ids2}
print(keep_unit_ids2)
analyzer_pyramidal = sorting_analyzer.select_units(keep_unit_ids2)

In [None]:
# Sorting restricted to the pyramidal candidates.
pyramidal_sorting = sorting_SC.select_units(keep_unit_ids2)
print(pyramidal_sorting)
#pyramidal_sorting.save(folder=base_folder/'pyramidal_sorting')

# Unit counts before and after the selection.
n_units_total = len(sorting_SC.get_unit_ids())
print(f"Number of units before curation: {n_units_total}")
n_units_pyramidal = len(pyramidal_sorting.get_unit_ids())
print(f"pyramidal waveform: {n_units_pyramidal}")

In [None]:
# Wide-interneuron candidates: SNR above 3, isolation distance above 50, a
# peak-to-valley duration above 0.00045 and the ACG-based WI flag.
# `result_WI` is a plain list; pandas deprecates `Series & list`, so it is
# converted to a boolean array first (position i <-> i-th unit).
is_WI_acg = np.asarray(result_WI, dtype=bool)
keep_mask = (
    (metrics1["snr"] > 3)
    & (metrics2["isolation_distance"] > 50)
    & (template_metrics["peak_to_valley"] > 0.00045)
    & is_WI_acg
)
print(keep_mask)

# Ids of the units that pass every criterion.
keep_unit_ids3 = list(keep_mask[keep_mask].index.values)

# Spike train of every kept unit, keyed by unit id.
spike_trains_WI = {unit_id: sorting_SC.get_unit_spike_train(unit_id) for unit_id in keep_unit_ids3}
print(keep_unit_ids3)
analyzer_WI = sorting_analyzer.select_units(keep_unit_ids3)

In [None]:
# Sorting restricted to the wide-interneuron candidates.
WI_sorting = sorting_SC.select_units(keep_unit_ids3)
print(WI_sorting)
#WI_sorting.save(folder=base_folder/'WI_sorting')

# Unit counts before and after the selection.
n_units_total = len(sorting_SC.get_unit_ids())
print(f"Number of units before curation: {n_units_total}")
n_units_WI = len(WI_sorting.get_unit_ids())
print(f"WI waveform: {n_units_WI}")

In [None]:
# One-second bins: bin width in samples and number of whole bins in the
# concatenated recording.
bin_size = int(fs)
num_samples = full_raw_rec.get_num_samples()
num_bins = num_samples // bin_size

# Take the unit locations from the analyzer built in this notebook so the rows
# correspond to sorting_SC's units. The previous code read a .npy file from a
# different analyzer folder ('my_sorting_analyzer_part7_2').
unit_locations = sorting_analyzer.get_extension("unit_locations").get_data()

In [None]:
import pandas as pd
spike_counts = np.zeros((len(keep_unit_ids1), num_bins), dtype=int)

# Row i of `unit_locations` is looked up via the position of the unit in the
# FULL unit list, not its position in the kept subset.
# NOTE(review): assumes unit_locations rows follow sorting_SC.get_unit_ids() order — confirm.
row_of_unit = {unit_id: row for row, unit_id in enumerate(sorting_SC.get_unit_ids())}
unit_id_to_location = {unit_id: unit_locations[row_of_unit[unit_id]] for unit_id in keep_unit_ids1}

# Order the kept units by the second location coordinate.
sorted_unit_ids = sorted(keep_unit_ids1, key=lambda unit_id: unit_id_to_location[unit_id][1])

# Spike counts per unit in bins of exactly `bin_size` samples; samples after
# the last whole bin are ignored.
spike_counts1 = np.zeros((len(sorted_unit_ids), num_bins), dtype=int)
for unit_idx, unit_id in enumerate(sorted_unit_ids):
    spike_train = spike_trains_narrow[unit_id]
    binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_bins * bin_size))
    spike_counts1[unit_idx, :] = binned_spikes

# Scale each unit to its own maximum; units without spikes stay at 0 instead of NaN.
row_max = spike_counts1.max(axis=1, keepdims=True)
normalized_spike_counts1 = spike_counts1 / np.where(row_max == 0, 1, row_max)

# Save the raw counts with the unit id in the first column.
spike_counts_CA1_df = pd.DataFrame(spike_counts1)
spike_counts_CA1_df.insert(0, 'Unit_ID', sorted_unit_ids)
spike_counts_CA1_df.to_csv(base_folder / "spike_count_NI_2.csv", index=False)

# Heatmap of the raw counts (units x time bins).
plt.figure(figsize=(6, 6))
plt.imshow(spike_counts1, aspect='auto', cmap='viridis', interpolation='nearest')
#plt.xlim(0,600)
plt.colorbar(label='Spike Count')
plt.xlabel('Time (bins)')
plt.ylabel('Units')
plt.tight_layout()
#plt.savefig(base_folder/'spike_raster_NI.png', format='png', dpi=300)
plt.show()

In [None]:
# Split the bins at the end of the first ("before") sub-recording instead of
# at a hard-coded 300, so the split follows the slice length actually loaded.
split_bin = full_raw_rec1_sub.get_num_samples() // bin_size
spike_counts_first_300 = normalized_spike_counts1[:, :split_bin]
spike_counts_second_300 = normalized_spike_counts1[:, split_bin:2 * split_bin]

# Plotting the heatmaps as subplots
plt.figure(figsize=(12, 8))

# First half ("before")
plt.subplot(1, 2, 1)  # 1 row, 2 columns, first subplot
plt.imshow(spike_counts_first_300, aspect='auto', cmap='viridis', interpolation='nearest')
plt.colorbar(label='Spike Count')
plt.title(f'First {split_bin} Bins')
plt.xlabel('Time (bins)')
plt.ylabel('Units')

# Second half ("after")
plt.subplot(1, 2, 2)  # 1 row, 2 columns, second subplot
plt.imshow(spike_counts_second_300, aspect='auto', cmap='viridis', interpolation='nearest')
plt.colorbar(label='Spike Count')
plt.title(f'Second {split_bin} Bins')
plt.xlabel('Time (bins)')
plt.ylabel('Units')

# Adjust layout for better spacing
plt.tight_layout()
plt.savefig(base_folder/'spike_raster_NI.pdf', format='pdf', dpi=600)
# Show the plot
plt.show()

In [None]:
spike_counts2 = np.zeros((len(keep_unit_ids2), num_bins), dtype=int)

# Row i of `unit_locations` is looked up via the position of the unit in the
# FULL unit list, not its position in the kept subset.
# NOTE(review): assumes unit_locations rows follow sorting_SC.get_unit_ids() order — confirm.
row_of_unit = {unit_id: row for row, unit_id in enumerate(sorting_SC.get_unit_ids())}
unit_id_to_location = {unit_id: unit_locations[row_of_unit[unit_id]] for unit_id in keep_unit_ids2}

# Order the kept units by the second location coordinate.
sorted_unit_ids = sorted(keep_unit_ids2, key=lambda unit_id: unit_id_to_location[unit_id][1])

spike_counts = np.zeros((len(sorted_unit_ids), num_bins), dtype=int)
differences = []

# Spike counts per unit in bins of exactly `bin_size` samples; samples after
# the last whole bin are ignored.
for unit_idx, unit_id in enumerate(sorted_unit_ids):
    spike_train = spike_trains_pyramidal[unit_id]
    binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_bins * bin_size))
    spike_counts2[unit_idx, :] = binned_spikes

# Scale each unit to its own maximum; units without spikes stay at 0 instead of NaN.
row_max = spike_counts2.max(axis=1, keepdims=True)
normalized_spike_counts2 = spike_counts2 / np.where(row_max == 0, 1, row_max)

# Save the raw counts with the unit id in the first column.
spike_counts_CA1_df = pd.DataFrame(spike_counts2)
spike_counts_CA1_df.insert(0, 'Unit_ID', sorted_unit_ids)
spike_counts_CA1_df.to_csv(base_folder / "spike_count_PC_2.csv", index=False)

# Heatmap of the raw counts (units x time bins).
plt.figure(figsize=(4, 6))
plt.imshow(spike_counts2, aspect='auto', cmap='viridis', interpolation='nearest')
#plt.xlim(0,600)
plt.colorbar(label='Spike Count')
plt.xlabel('Time (bins)')
plt.ylabel('Units')
plt.tight_layout()
#plt.savefig(base_folder/'spike_raster_PC.png', format='png', dpi=300)
plt.show()

In [None]:
# Side-by-side heatmaps of bins 0-299 and 300-599 of normalized_spike_counts2,
# saved as spike_raster_PC.pdf.
# NOTE(review): the colorbar reads 'Spike Count' although the matrix name says
# "normalized" -- confirm the label is what you want in the exported figure.
spike_counts_first_300 = normalized_spike_counts2[:, :300]
spike_counts_second_300 = normalized_spike_counts2[:, 300:600]

# One row, two columns: left = first 300 bins, right = second 300 bins.
fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 8))
panels = (
    (left_ax, spike_counts_first_300, 'First 300 Bins'),
    (right_ax, spike_counts_second_300, 'Second 300 Bins'),
)
for panel_ax, panel_data, panel_title in panels:
    image = panel_ax.imshow(panel_data, aspect='auto', cmap='viridis', interpolation='nearest')
    fig.colorbar(image, ax=panel_ax, label='Spike Count')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Time (bins)')
    panel_ax.set_ylabel('Units')

# Tighten the spacing, write the figure to disk, then render it.
fig.tight_layout()
fig.savefig(base_folder/'spike_raster_PC.pdf', format='pdf', dpi=600)
plt.show()

In [None]:
# Bin the spike train of every unit in keep_unit_ids3 (read from spike_trains_WI)
# into num_bins equal-width bins over [0, num_samples], save the raw count matrix
# to CSV and show it as a heatmap.
spike_counts3 = np.zeros((len(keep_unit_ids3), num_bins), dtype=int)

# NOTE(review): this pairs the i-th kept unit with unit_locations[i]; that is only
# right if unit_locations is ordered exactly like keep_unit_ids3 -- confirm.
unit_id_to_location = {unit_id: unit_locations[i] for i, unit_id in enumerate(keep_unit_ids3)}

# Rows are ordered by the second location coordinate (index 1).
sorted_unit_ids = sorted(keep_unit_ids3, key=lambda unit_id: unit_id_to_location[unit_id][1])

# Not read anywhere in this cell; kept in case another cell expects these names.
spike_counts = np.zeros((len(sorted_unit_ids), num_bins), dtype=int)
differences = []

for unit_idx, unit_id in enumerate(sorted_unit_ids):
    spike_train = spike_trains_WI[unit_id]
    binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_samples))
    spike_counts3[unit_idx, :] = binned_spikes

# Scale each unit by its own busiest bin. A unit without any spike has a peak of 0;
# its row is left at 0.0 instead of turning into NaN through 0/0.
peak_counts3 = spike_counts3.max(axis=1, keepdims=True)
normalized_spike_counts3 = np.divide(
    spike_counts3,
    peak_counts3,
    out=np.zeros(spike_counts3.shape, dtype=float),
    where=peak_counts3 > 0,
)

# Despite the "CA1" in its name, this frame holds every unit in sorted_unit_ids.
spike_counts_CA1_df = pd.DataFrame(spike_counts3)
spike_counts_CA1_df.insert(0, 'Unit_ID', sorted_unit_ids)
spike_counts_CA1_df.to_csv(base_folder / "spike_count_WI_2.csv", index=False)

# Heatmap of the raw (un-normalized) counts: one row per unit, one column per bin.
plt.figure(figsize=(4, 6))
plt.imshow(spike_counts3, aspect='auto', cmap='viridis', interpolation='nearest')
#plt.xlim(0,60)
plt.colorbar(label='Spike Count')
plt.xlabel('Time (bins)')
plt.ylabel('Units')
plt.tight_layout()
#plt.savefig(base_folder/'spike_raster_WI.png', format='png', dpi=300)
plt.show()

In [None]:
# Side-by-side heatmaps of bins 0-299 and 300-599 of normalized_spike_counts3,
# saved as spike_raster_WI.pdf.
# NOTE(review): the colorbar reads 'Spike Count' although the matrix name says
# "normalized" -- confirm the label is what you want in the exported figure.
spike_counts_first_300 = normalized_spike_counts3[:, :300]
spike_counts_second_300 = normalized_spike_counts3[:, 300:600]

# One row, two columns: left = first 300 bins, right = second 300 bins.
fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 8))
panels = (
    (left_ax, spike_counts_first_300, 'First 300 Bins'),
    (right_ax, spike_counts_second_300, 'Second 300 Bins'),
)
for panel_ax, panel_data, panel_title in panels:
    image = panel_ax.imshow(panel_data, aspect='auto', cmap='viridis', interpolation='nearest')
    fig.colorbar(image, ax=panel_ax, label='Spike Count')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Time (bins)')
    panel_ax.set_ylabel('Units')

# Tighten the spacing, write the figure to disk, then render it.
fig.tight_layout()
fig.savefig(base_folder/'spike_raster_WI.pdf', format='pdf', dpi=600)
plt.show()

In [None]:
# Split the units in keep_unit_ids1 into two groups by their second location
# coordinate, bin each unit's spike train (from spike_trains_narrow) into num_bins
# bins over [0, num_samples], and write one CSV per group.
# NOTE(review): this pairs the i-th kept unit with unit_locations[i]; that is only
# right if unit_locations is ordered exactly like keep_unit_ids1 -- confirm.
unit_id_to_location = {unit_id: unit_locations[i] for i, unit_id in enumerate(keep_unit_ids1)}
sorted_unit_ids = sorted(keep_unit_ids1, key=lambda unit_id: unit_id_to_location[unit_id][1])

# Group 1 is strictly below 4000, group 2 is at or above it. `>=` is used so that a
# unit sitting exactly on the boundary is not silently lost from both groups.
filtered_unit_ids1 = [unit_id for unit_id in sorted_unit_ids if unit_id_to_location[unit_id][1] < 4000]
filtered_unit_ids2 = [unit_id for unit_id in sorted_unit_ids if unit_id_to_location[unit_id][1] >= 4000]

spike_counts_CA1 = np.zeros((len(filtered_unit_ids1), num_bins), dtype=int)
spike_counts_Cortex = np.zeros((len(filtered_unit_ids2), num_bins), dtype=int)

# Same binning for both groups; each matrix is filled in place, one row per unit.
for group_unit_ids, group_counts in (
    (filtered_unit_ids1, spike_counts_CA1),
    (filtered_unit_ids2, spike_counts_Cortex),
):
    for unit_idx, unit_id in enumerate(group_unit_ids):
        spike_train = spike_trains_narrow[unit_id]
        binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_samples))
        group_counts[unit_idx, :] = binned_spikes

spike_counts_CA1_df = pd.DataFrame(spike_counts_CA1)
spike_counts_CA1_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids1)  # unit IDs as the first column

spike_counts_Cortex_df = pd.DataFrame(spike_counts_Cortex)
spike_counts_Cortex_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids2)  # unit IDs as the first column

# Save the DataFrames to CSV files.
# NOTE(review): the variables are named CA1/Cortex but the files are named DG/CA1 --
# confirm which region labels are the intended ones.
spike_counts_CA1_df.to_csv(base_folder / "spike_count_NI_DG.csv", index=False)
spike_counts_Cortex_df.to_csv(base_folder / "spike_count_NI_CA1.csv", index=False)


In [None]:
# Preview the first rows of the most recently built group table (rich notebook display
# instead of a plain-text print of the whole frame).
spike_counts_CA1_df.head()

In [None]:
# Split the units in keep_unit_ids2 into two groups by their second location
# coordinate, bin each unit's spike train (from spike_trains_pyramidal) into num_bins
# bins over [0, num_samples], and write one CSV per group.
# NOTE(review): this pairs the i-th kept unit with unit_locations[i]; that is only
# right if unit_locations is ordered exactly like keep_unit_ids2 -- confirm.
unit_id_to_location = {unit_id: unit_locations[i] for i, unit_id in enumerate(keep_unit_ids2)}
sorted_unit_ids2 = sorted(keep_unit_ids2, key=lambda unit_id: unit_id_to_location[unit_id][1])

# Group 1 is strictly below 4000, group 2 is at or above it. `>=` is used so that a
# unit sitting exactly on the boundary is not silently lost from both groups.
filtered_unit_ids1 = [unit_id for unit_id in sorted_unit_ids2 if unit_id_to_location[unit_id][1] < 4000]
filtered_unit_ids2 = [unit_id for unit_id in sorted_unit_ids2 if unit_id_to_location[unit_id][1] >= 4000]

spike_counts_CA1 = np.zeros((len(filtered_unit_ids1), num_bins), dtype=int)
spike_counts_Cortex = np.zeros((len(filtered_unit_ids2), num_bins), dtype=int)

# Same binning for both groups; each matrix is filled in place, one row per unit.
for group_unit_ids, group_counts in (
    (filtered_unit_ids1, spike_counts_CA1),
    (filtered_unit_ids2, spike_counts_Cortex),
):
    for unit_idx, unit_id in enumerate(group_unit_ids):
        spike_train = spike_trains_pyramidal[unit_id]
        binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_samples))
        group_counts[unit_idx, :] = binned_spikes

spike_counts_CA1_df = pd.DataFrame(spike_counts_CA1)
spike_counts_CA1_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids1)  # unit IDs as the first column

spike_counts_Cortex_df = pd.DataFrame(spike_counts_Cortex)
spike_counts_Cortex_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids2)  # unit IDs as the first column

# Save the DataFrames to CSV files.
# NOTE(review): the variables are named CA1/Cortex but the files are named DG/CA1 --
# confirm which region labels are the intended ones.
spike_counts_CA1_df.to_csv(base_folder / "spike_count_PC_DG.csv", index=False)
spike_counts_Cortex_df.to_csv(base_folder / "spike_count_PC_CA1.csv", index=False)

In [None]:
# Split the units in keep_unit_ids3 into two groups by their second location
# coordinate, bin each unit's spike train (from spike_trains_WI) into num_bins
# bins over [0, num_samples], and write one CSV per group.
# NOTE(review): this pairs the i-th kept unit with unit_locations[i]; that is only
# right if unit_locations is ordered exactly like keep_unit_ids3 -- confirm.
unit_id_to_location = {unit_id: unit_locations[i] for i, unit_id in enumerate(keep_unit_ids3)}
sorted_unit_ids3 = sorted(keep_unit_ids3, key=lambda unit_id: unit_id_to_location[unit_id][1])

# Group 1 is strictly below 4000, group 2 is at or above it. `>=` is used so that a
# unit sitting exactly on the boundary is not silently lost from both groups.
filtered_unit_ids1 = [unit_id for unit_id in sorted_unit_ids3 if unit_id_to_location[unit_id][1] < 4000]
filtered_unit_ids2 = [unit_id for unit_id in sorted_unit_ids3 if unit_id_to_location[unit_id][1] >= 4000]

spike_counts_CA1 = np.zeros((len(filtered_unit_ids1), num_bins), dtype=int)
spike_counts_Cortex = np.zeros((len(filtered_unit_ids2), num_bins), dtype=int)

# Same binning for both groups; each matrix is filled in place, one row per unit.
for group_unit_ids, group_counts in (
    (filtered_unit_ids1, spike_counts_CA1),
    (filtered_unit_ids2, spike_counts_Cortex),
):
    for unit_idx, unit_id in enumerate(group_unit_ids):
        spike_train = spike_trains_WI[unit_id]
        binned_spikes, _ = np.histogram(spike_train, bins=num_bins, range=(0, num_samples))
        group_counts[unit_idx, :] = binned_spikes

spike_counts_CA1_df = pd.DataFrame(spike_counts_CA1)
spike_counts_CA1_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids1)  # unit IDs as the first column

spike_counts_Cortex_df = pd.DataFrame(spike_counts_Cortex)
spike_counts_Cortex_df.insert(0, 'Mapped_Unit_ID', filtered_unit_ids2)  # unit IDs as the first column

# Save the DataFrames to CSV files.
# NOTE(review): the variables are named CA1/Cortex but the files are named DG/CA1 --
# confirm which region labels are the intended ones.
spike_counts_CA1_df.to_csv(base_folder / "spike_count_WI_DG.csv", index=False)
spike_counts_Cortex_df.to_csv(base_folder / "spike_count_WI_CA1.csv", index=False)

In [None]:
import spikeinterface.widgets as sw
# Show the sorting-summary widget for analyzer_WI, rendered with the
# "spikeinterface_gui" backend.
sw.plot_sorting_summary(analyzer_WI, backend="spikeinterface_gui")

In [None]:
# One 50-bin amplitude histogram per entry of unit_ids, stacked vertically in a single
# figure that grows by 3 inches of height per unit.
n_units = len(unit_ids)
plt.figure(figsize=(15, n_units * 3))
for row, unit_id in enumerate(unit_ids, start=1):
    # Subplot rows are 1-based, hence the enumerate offset.
    plt.subplot(n_units, 1, row)
    unit_amplitudes = amplitudes[unit_id]
    plt.hist(unit_amplitudes, bins=50, color='blue', edgecolor='black', alpha=0.7)
    plt.title(f'Unit {unit_id} - Peak Spike Amplitudes')
    plt.xlabel('Amplitude')
    plt.ylabel('Frequency')
plt.tight_layout()
plt.show()

In [None]:
import spikeinterface.widgets as sw
# Select the entries at positions 10-19 of sorting_SC.unit_ids.
# NOTE(review): this rebinds the notebook-wide `unit_ids`, which other cells also read
# (amplitude histograms, unit-location plots) -- confirm they are meant to see only
# this subset.
unit_ids = sorting_SC.unit_ids[10:20]

# Plot the waveforms of the selected units from sorting_analyzer.
sw.plot_unit_waveforms(sorting_analyzer, unit_ids=unit_ids, figsize=(12, 3))

In [None]:
# Plot the estimated locations of the selected units and clamp the y-range.
# The widget returned by the call is kept so the limits are applied to the axes this
# plot was drawn on, not to whatever `ax` another cell happened to leave in the kernel.
# NOTE(review): `.ax` is the attribute exposed by the matplotlib backend -- confirm no
# other default backend is configured.
unit_locations_widget = sw.plot_unit_locations(sorting_analyzer, unit_ids=unit_ids, figsize=(4, 8))
unit_locations_widget.ax.set_ylim(-100, 3000)

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from mpl_toolkits.mplot3d import Axes3D
import umap
from sklearn.preprocessing import StandardScaler

In [None]:
# Extract waveforms for the units of sorting_SC from the preprocessed recording `rec`,
# passing <base_folder>/waveforms as the extractor's folder.
waveform_folder = base_folder / 'waveforms'
we = si.extract_waveforms(rec, sorting_SC, folder=waveform_folder)


In [None]:
import umap
from sklearn.preprocessing import StandardScaler

# UMAP embedding of the mean waveform of every kept unit, coloured by unit group.
# Requires keep_unit_ids1/2/3 and the waveform extractor `we` from earlier cells.

# Fixed label -> colour mapping, so a group keeps its colour even when another group
# happens to be empty (positional indexing into np.unique(labels) would shift them).
colors = ['blue', 'green', 'red']
label_colors = dict(zip(('NI', 'PC', 'WI'), colors))

waveform_features = []
labels = []

# Unpacking concatenates the three id collections whether they are lists or arrays
# (`+` on arrays would add element-wise instead of concatenating).
all_kept_unit_ids = [*keep_unit_ids1, *keep_unit_ids2, *keep_unit_ids3]

for unit_id in all_kept_unit_ids:
    # Mean over axis 0 of the unit's waveforms, flattened to one feature vector.
    waveforms = we.get_waveforms(unit_id)
    avg_waveform = np.mean(waveforms, axis=0)
    waveform_features.append(avg_waveform.flatten())

    # Group label; the first matching group wins if an id appears in more than one.
    if unit_id in keep_unit_ids1:
        labels.append('NI')
    elif unit_id in keep_unit_ids2:
        labels.append('PC')
    else:
        labels.append('WI')

waveform_features = np.array(waveform_features)
labels = np.array(labels)

# Standardize every feature column before the embedding.
scaler = StandardScaler()
waveform_features = scaler.fit_transform(waveform_features)

# Five components are computed (seeded for reproducibility); only the first two are plotted.
umap_reducer = umap.UMAP(n_components=5, random_state=42)
waveform_umap = umap_reducer.fit_transform(waveform_features)

fig, ax = plt.subplots(figsize=(10, 8))

unique_labels = np.unique(labels)
for label in unique_labels:
    in_group = labels == label  # boolean row mask for this group
    ax.scatter(waveform_umap[in_group, 0], waveform_umap[in_group, 1],
               c=label_colors[label], label=label, alpha=0.5)

ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.legend()
plt.show()

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import matplotlib.pyplot as plt

# Project the waveform features onto two linear discriminants, using the group
# labels as classes, and plot the result coloured by group.
# Requires `waveform_features` and the `labels` array from the UMAP cell.
lda = LDA(n_components=2)
waveform_lda = lda.fit_transform(waveform_features, labels)

# Fixed label -> colour mapping, so a group keeps its colour even when another group
# happens to be empty (positional indexing into np.unique(labels) would shift them).
colors = ['blue', 'green', 'red']
label_colors = dict(zip(('NI', 'PC', 'WI'), colors))

fig, ax = plt.subplots(figsize=(8, 8))
unique_labels = np.unique(labels)

for label in unique_labels:
    in_group = labels == label  # boolean row mask for this group
    ax.scatter(waveform_lda[in_group, 0], waveform_lda[in_group, 1],
               c=label_colors[label], label=label, alpha=0.5)

ax.set_xlabel('LDA 1')
ax.set_ylabel('LDA 2')
ax.legend()
plt.title('LDA of Neuron Waveforms')
plt.show()

In [None]:
# Reduce template_metrics to its "peak_to_valley" entry. The variable name is reused
# because the scatter cell indexes `template_metrics[unit_id]` directly.
# The guard makes the cell safe to re-run: once the column has been extracted the
# object is a Series, and indexing it with "peak_to_valley" again would fail.
if not isinstance(template_metrics, pd.Series):
    template_metrics = template_metrics["peak_to_valley"]

In [None]:
# Scatter of the template metric (x) against tau rise (y), one colour per unit group.
# NOTE: the list names below do not match the groups they hold -- the legend labels
# are what identify each group in the figure:
#   x_wide / y_wide           <- keep_unit_ids1, drawn as 'NI'
#   x_narrow / y_narrow       <- keep_unit_ids2, drawn as 'PC'
#   x_pyramidal / y_pyramidal <- keep_unit_ids3, drawn as 'WI'
x_wide = [template_metrics[unit_id] for unit_id in keep_unit_ids1]
y_wide = [mapped_tau_rise[unit_id] for unit_id in keep_unit_ids1]

x_narrow = [template_metrics[unit_id] for unit_id in keep_unit_ids2]
y_narrow = [mapped_tau_rise[unit_id] for unit_id in keep_unit_ids2]

x_pyramidal = [template_metrics[unit_id] for unit_id in keep_unit_ids3]
y_pyramidal = [mapped_tau_rise[unit_id] for unit_id in keep_unit_ids3]

colors = ['blue', 'green', 'red']
fig, ax = plt.subplots(figsize=(8, 8))

# (x values, y values, colour, legend label) for each group, drawn in this order.
scatter_series = (
    (x_wide, y_wide, colors[0], 'NI'),
    (x_narrow, y_narrow, colors[1], 'PC'),
    (x_pyramidal, y_pyramidal, colors[2], 'WI'),
)
for xs, ys, group_color, group_label in scatter_series:
    plt.scatter(xs, ys, edgecolor=group_color, facecolor=group_color, label=group_label,
                marker='o', alpha=0.5, linewidths=1.5)

plt.xlabel('Template Metric (e.g., Peak-to-Valley)')
plt.ylabel('Tau Rise')
plt.title('Neurons by Template Metric and Tau Rise')
plt.legend()
plt.show()


In [None]:
import matplotlib.cm as cm
# Draw the probe contacts in gray and overlay the location of every unit in
# keep_unit_ids1, each in its own rainbow colour.
probe = rec.get_probe()
channel_locations = probe.contact_positions  # electrode positions

# Rows of unit_locations that belong to the kept units.
# NOTE(review): assumes unit_locations is row-aligned with `unit_ids` and that
# `unit_ids` is a numpy array -- confirm against the cell that computes them.
filtered_indices = np.isin(unit_ids, keep_unit_ids1)
filtered_x_coords = unit_locations[filtered_indices, 0]
filtered_y_coords = unit_locations[filtered_indices, 1]
kept_unit_ids_in_order = unit_ids[filtered_indices]  # loop-invariant, computed once

# One distinct colour per kept unit.
colors = cm.rainbow(np.linspace(0, 1, len(keep_unit_ids1)))

plt.figure(figsize=(8, 10))
plt.scatter(channel_locations[:, 0], channel_locations[:, 1], c='gray', marker='o', label='Electrodes', alpha=0.5)

for i, unit_id in enumerate(keep_unit_ids1):
    is_this_unit = kept_unit_ids_in_order == unit_id
    unit_x = filtered_x_coords[is_this_unit]
    unit_y = filtered_y_coords[is_this_unit]
    plt.scatter(unit_x, unit_y, c=[colors[i]], marker='o', label=f'Unit {unit_id}')

# keep_unit_ids1 is the group labelled 'NI' (narrow interneuron) in the other cells,
# so the title names that group.
plt.title('Narrow Interneuron Locations on Neuropixels Probe')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
#plt.legend()
plt.show()