# Introduction
This pipeline is based on Spikeinterface version > 0.100 to sort Intan data.

"recording", "sorting", and "analyzer" are three key components of Spikeinterface processing.
1. "recording" is recording data stored in Spikeinterface format.
2. "sorting" is sorted results of Spikeinterface.
3. "analyzer" is generated from "sorting" and "recording" to use the post-sorting analyzing packages in Spikeinterface version > 0.100.
For multi-session processing, the per-session data and results in this pipeline are called "recordings" and "sortings".

"analyzer" is used in NEW version of Spikeinterface to post-process the sorted results.
"waveform_extractors" is used in OLD version of Spikeinterface to post-process the sorted results.
We recommend doing post-processing, curation, and plotting with the following pipeline, which is built on "analyzer".
However, you may need plots that rely on the "waveform_extractors" of the old Spikeinterface version.
For example, you may want to check the averaged waveforms during curation, or you may want the templates on all electrodes as final results.
In that case, please jump to the compatible code and run it with a kernel that has the old Spikeinterface version installed.

IMPORTANT! YOU HAVE TO CHANGE THE PARAMETERS IN THE "Set I/O path and sorter threshold" AND "Define the 'probe'" SECTIONS TO ADAPT THE PIPELINE TO NEW DATA.

# Import package

In [None]:
import glob 
import matplotlib.pyplot as plt 
import numpy as np
import os
import pandas as pd
import spikeinterface as si
import spikeinterface.core as sc 
import spikeinterface.curation as scu
import spikeinterface.extractors as se 
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.sorters as ss
import spikeinterface.widgets as sw
import sys 
import quantities as pq
import neo
import umap
import json

from scipy.stats import ttest_ind
from scipy.io import savemat
from probeinterface import generate_multi_columns_probe
from probeinterface.plotting import plot_probe
from probeinterface.utils import combine_probes
from elephant.gpfa import GPFA
from sklearn.decomposition import PCA
from src.importrhdutilities import load_file

sys.path.append('src')

# Set I/O path and sorter threshold
Must be changed for new data: project_name, subject_name, processing_name, segment_paths, and (as needed) sorter_threshold. output_root is derived from these values in the following cell.

In [None]:
# The parameters in this cell have to be changed for every new dataset.

# Parent folder name of the recording data; recommended: the project name.
project_name = 'axolotl_tail_amputation_multisession_64ch'
# Folder name of the recording data; recommended: the subject name of each experiment.
subject_name = 'axolotl_A'
# Folder name of this individual processing run; recommended: a combination of
# the names of all recording segments that are used as input.
processing_name = '101424-101824-7sections'

# Folder paths of all recording segments, in the order they are concatenated.
segment_paths = [
    f'data/{project_name}//{subject_name}/{segment_folder}'
    for segment_folder in (
        'axolotl_embryo_before_amputation_1',
        'axolotl_embryo_before_amputation_2',
        'axolotl_embryo_after_amputation_3',
        'axolotl_embryo_after_amputation_4',
        'axolotl_embryo_after_amputation_5',
        'axolotl_embryo_after_amputation_6',
        'axolotl_embryo_after_amputation_7',
    )
]

# Detection threshold used for sorting; adjust as needed.
# The "sorting" and "analyzer" output folders are named with this value as suffix.
sorter_threshold = 3.0


In [None]:
# Derived output locations -- do not change this cell.

# Root folder of all processed results of this run.
output_root = f'data/processed/{project_name}/{subject_name}/{processing_name}'

# Folder where the per-segment "recordings" are saved.
recordings_folder = f'{output_root}/recordings'

# Threshold-suffixed folders: the "sortings", the "analyzer" of the concatenated
# recording, and the per-segment "analyzers".
sortings_folder, analyzer_folder, analyzers_folder = (
    f'{output_root}/{result_kind}-{sorter_threshold}'
    for result_kind in ('sortings', 'analyzer', 'analyzers')
)

# Define the "probe"
Have to change to adapt new data:   intan_channel_indices = np.array(), 
                                    probe0 = generate_multi_columns_probe(), 
                                    probe0.rotate(), 
                                    probe0.set_device_channel_indices()
                            
Add more probes as needed 

In [None]:
# All channel indices of the Intan data.
# Change this table to match the Intan channels that carry signal and need to be sorted.
#
# If you use multiple Intan chips, their channel indices are listed continuously.
# For example, with a 64ch Intan chip at port A and a 64ch Intan chip at port B,
# the recorded data is ch0-63 for chip A and ch64-127 for chip B.
# If you use ch16-47 of a 64ch chip at port A and ch16-47 of a 64ch chip at port B,
# the recorded data is ch0-31 for chip A and ch32-63 for chip B.

# 64 channel indices arranged as 8 rows of 8 consecutive channels (0-7, 8-15, ..., 56-63).
# Replace with an explicit np.array([[...], ...]) if the rows are not consecutive runs.
intan_channel_indices = np.arange(64).reshape(8, 8)

def find_shank(channel_index):
    """Return the row number of ``intan_channel_indices`` that holds ``channel_index``.

    The first matching row wins. ``None`` is returned when no row contains the
    channel index.
    """
    matching_rows = (
        row_number
        for row_number, row_channels in enumerate(intan_channel_indices)
        if channel_index in row_channels
    )
    return next(matching_rows, None)


def create_probe(channel_indices, savepath=None, show_probe=True):
    """Build the probe geometry, plot it, and return the probe used by the pipeline.

    Parameters
    ----------
    channel_indices : numpy.ndarray
        Intan channel indices. The flattened array is assigned as the device
        channel indices of the returned probe.
    savepath : str or None
        If given, the probe figure is saved to this path.
    show_probe : bool
        If True, the probe figure is displayed.

    Returns
    -------
    The probe returned by ``combine_probes`` with the device channel indices
    taken from ``channel_indices``.
    """
    plt.figure(figsize=(20, 40))
    ax = plt.gca()

    # Probe layout.
    # Change based on the geometry design of the probe in the mask.
    # It does not have to be a single probe0; several probes (probe0, probe1, ...)
    # can be used for convenience, as long as the final geometry looks right.
    probe0 = generate_multi_columns_probe(
        num_columns=8,
        num_contact_per_column=8,
        xpitch=50, ypitch=50,
        contact_shapes='circle', contact_shape_params={'radius': 8},
    )
    probe0.rotate(0)

    # Channel indices of the probe (device).
    # Change based on how the Intan data channel indices map to the probe (device)
    # channel indices; check the Intan headstage electrode connector pinout, the
    # PCB layout and the mask design.
    # Every index 0-63 must appear exactly once. The last row previously ended in a
    # second 58, which left channel 48 unassigned.
    probe0.set_device_channel_indices([
        31, 29, 27, 25, 23, 21, 19, 17,
        15, 13, 11, 9, 7, 5, 3, 1,
        0, 2, 4, 6, 8, 10, 12, 14,
        16, 18, 20, 22, 24, 26, 28, 30,
        49, 51, 53, 55, 57, 59, 61, 63,
        33, 35, 37, 39, 41, 43, 45, 47,
        46, 44, 42, 40, 38, 36, 34, 32,
        62, 60, 58, 56, 54, 52, 50, 48,
    ])
    plot_probe(probe0, with_device_index=True, ax=ax)

    # Add further probes to this list when the device consists of several probes.
    probes = [probe0]
    multi_shank_probe = combine_probes(probes)
    # NOTE(review): this assigns the flattened ``channel_indices`` to the combined
    # probe, so the wiring set on probe0 above is only what the figure shows.
    # Confirm that this is the intended mapping for the sorted data.
    multi_shank_probe.set_device_channel_indices(channel_indices.flatten())

    # Values shown in the figure title.
    n_channel = channel_indices.size
    n_shank = len(probes)

    plt.xlim(-50, 400)
    plt.ylim(-150, 400)
    plt.title(f'Probe - {n_channel}ch - {n_shank}shanks')
    if savepath is not None:
        plt.savefig(savepath, bbox_inches='tight')

    if show_probe:
        plt.show()
    plt.close()
    return multi_shank_probe

# Build the probe from the Intan channel table and display its layout; the
# returned probe is attached to the recordings in the preprocessing cell.
probe = create_probe(intan_channel_indices, show_probe=True)

# Load Intan data
Optional change for new data (channel selection): in the `if file_traces.shape[0] == 128` branch, set selected_channels = list(range()) + list(range()) to the channels that carry real signal.

In [None]:
# This cell also removes redundant Intan channels.
# For example, in the first recording you use a 64 channel Intan chip to record
# 32 channels, and in the second recording you forget to turn off the redundant
# channels. The code here deletes the redundant channels so that the data can be
# concatenated correctly.
#
# Use list(range()) + list(range()) to choose the channel indices that have real signal.

# Reading data in selected channels.
# The whole loading step is skipped when session_info.csv already exists.
session_info_file = f'{output_root}/session_info.csv'
if not os.path.isfile(session_info_file):
    session_info = []
    # Running sample offset of each file, accumulated over all segments.
    file_start = 0
    os.makedirs(recordings_folder, exist_ok=True)
    for segment_index, segment_path in enumerate(segment_paths):
        file_index = 0
        traces = []
        for recording_path in sorted(glob.glob(f'{segment_path}/*.rhd')):
            try:
                raw_data, data_present = load_file(recording_path)
            except Exception as error:
                # Unreadable files are still skipped, but the reason is reported
                # and Ctrl-C is no longer swallowed by a bare ``except``.
                print(f'Skipping unreadable file {recording_path}: {error}')
                continue
            if not data_present:
                continue

            sampling_frequency = raw_data['frequency_parameters']['amplifier_sample_rate']
            file_traces = raw_data['amplifier_data']

            # Check if the data has 128 channels
            if file_traces.shape[0] == 128:
                # Select channels 16-47 and 80-111
                selected_channels = list(range(16, 48)) + list(range(80, 112))
                file_traces = file_traces[selected_channels, :]

            session_info.append({
                'segment': segment_index,
                'segment_path': segment_path,
                'file_index': file_index,
                'file_path': recording_path,
                'file_start': file_start,
                'file_duration': file_traces.shape[1],
            })
            traces.append(file_traces)
            file_start += file_traces.shape[1]
            file_index += 1

        if not traces:
            # np.hstack([]) would raise a ValueError without naming the segment.
            raise ValueError(f'No readable .rhd data found in segment folder: {segment_path}')
        traces = np.hstack(traces)

        # Save the recording with selected channels only
        segment_recording = se.NumpyRecording(traces_list=traces.T, sampling_frequency=sampling_frequency)
        segment_recording.save(folder=f'{recordings_folder}/segment{segment_index}')
    session_info = pd.json_normalize(session_info)
    session_info.to_csv(session_info_file, index=False)

session_info = pd.read_csv(session_info_file)
session_info

# Preprocess and concatenate data to get "recording"
Optional to change to adapt new data: freq_min, freq_max

In [None]:
# Factual parameters
n_s_per_min = 60
n_ms_per_s = 1000

# Concatenate data
n_segment = len(session_info['segment'].unique())

recordings = []
for segment_index in range(n_segment):
    segment_recording = sc.load_extractor(f'{recordings_folder}/segment{segment_index}')

    # Adjust frequency (freq_min and freq_max) of filtering as needed
    segment_recording = spre.bandpass_filter(segment_recording, freq_min=300, freq_max=3000)

    segment_recording = spre.common_reference(segment_recording, reference='global', operator='median')
    segment_recording = segment_recording.set_probe(probe)
    recordings.append(segment_recording)
recording = sc.concatenate_recordings(recordings)
recording = recording.set_probe(probe)

# Get basic information of concatenated data
# (n_frames_per_ms was previously computed twice with the same expression).
n_frames_per_ms = recording.sampling_frequency // n_ms_per_s
channel_ids = recording.get_channel_ids()
channel_num = len(channel_ids)
sampling_frequency = recording.get_sampling_frequency()
recording_time = recording.get_total_duration()

# Sort data to get "sorting"
Optional changes for new data: other sorting parameters, as needed

In [None]:
"""
# Adjust other parameters for sorting as needed
"""
# Keyword arguments forwarded to the mountainsort4 sorter through ss.run_sorter below.
# The sorter's own filtering is disabled here ('filter': False, freq_min/freq_max None);
# the recording was already band-pass filtered in the preprocessing cell.
sorter_parameters = {
    'detect_sign': -1,
    'adjacency_radius': -1, 
    'freq_min': None, 
    'freq_max': None,
    'filter': False,
    'whiten': True,  
    'clip_size': 50,
    'num_workers': 8,
    'detect_interval': 9,
    'detect_threshold': sorter_threshold,  # set in the "Set I/O path and sorter threshold" cell
}

# Sorting only runs when no firings.npz exists yet for this threshold; because of
# remove_existing_folder=True, a run replaces whatever is in sortings_folder.
if not os.path.isfile(f'{sortings_folder}/sorter_output/firings.npz'):
    ss.run_sorter(
        sorter_name='mountainsort4',
        recording=recording,
        folder = sortings_folder,
        remove_existing_folder=True,
        with_output=False,
        **sorter_parameters,
    )

# The result is always read back from disk, so a fresh run and a cached run
# continue identically from here.
sorting = se.NpzSortingExtractor(f'{sortings_folder}/sorter_output/firings.npz')
# NOTE(review): presumably drops spikes that lie beyond the end of the recording -- confirm in the SpikeInterface docs.
sorting = scu.remove_excess_spikes(sorting, recording)
# Split the concatenated sorting by the per-segment recordings, then keep one sorting per segment in "sortings".
splitted_sorting = sc.split_sorting(sorting, recordings)
sortings = [sc.select_segment_sorting(splitted_sorting, segment_indices=segment) for segment in range(n_segment)]

print(f'Number of units: {len(sorting.unit_ids)}')

# Generate "analyzer" for all segments and "analyzers" for each segment

In [None]:
# Generate the "analyzer" that is used with new version of Spikeinterface

# Extensions needed by the following processing, in the order they are computed,
# each with the parameters passed to analyzer.compute().
analyzer_extension_params = [
    ("random_spikes", {"method": "all"}),
    ("waveforms", {"ms_before": 1.5, "ms_after": 2.0}),
    ("templates", {"operators": ["average", "median", "std"]}),
    ("noise_levels", {}),
    ("principal_components", {"n_components": 3, "mode": "by_channel_global", "whiten": True}),
    ("template_similarity", {"method": "cosine_similarity"}),
    ("spike_amplitudes", {"peak_sign": "neg"}),
    ("unit_locations", {"method": "monopolar_triangulation"}),
    ("template_metrics", {"include_multi_channel_metrics": True}),
    ("correlograms", {"window_ms": 50.0, "bin_ms": 1.0, "method": "auto"}),
    ("isi_histograms", {"window_ms": 50.0, "bin_ms": 1.0, "method": "auto"}),
]

# Generate "analyzer"
analyzer_is_new = not os.path.exists(analyzer_folder)
if analyzer_is_new:
    # If the folder doesn't exist, create a new analyzer
    analyzer = si.create_sorting_analyzer(sorting=sorting,
                                    recording=recording,
                                    format="binary_folder",
                                    return_scaled=True,  # Default is to return scaled
                                    folder=analyzer_folder
                                    )
else:
    # If the folder exists, load the existing analyzer
    analyzer = si.load_sorting_analyzer(analyzer_folder)
    print("Analyzer loaded from existing folder.")

# Compute every extension that is not available yet. For a new analyzer this
# computes all of them. For a loaded analyzer it fills in whatever is missing,
# e.g. after a previous run was interrupted part-way through this list.
for extension_name, extension_params in analyzer_extension_params:
    if not analyzer.has_extension(extension_name):
        analyzer.compute(extension_name, **extension_params)

if analyzer_is_new:
    print("New analyzer created and saved.")
print(analyzer)


# Generate "quality metrics" based on "analyzer" and "analyzers"

In [None]:
# CSV file that caches the quality metrics of the "analyzer" for this threshold
quality_metrics_path = f"{output_root}/quality_metrics-{sorter_threshold}.csv"

if not os.path.exists(quality_metrics_path):
    # No cached file yet: compute the metrics listed here ...
    metrics_to_compute = ['firing_rate', 'amplitude_median', 'isi_violation', 'silhouette', 'snr']
    quality_metrics = sqm.compute_quality_metrics(analyzer, metric_names=metrics_to_compute)

    # ... and store them as CSV for later runs
    quality_metrics.to_csv(quality_metrics_path)
    print(f"Metrics saved to {quality_metrics_path}")
else:
    # Cached file found: read it back instead of recomputing
    print(f"Metrics file already exists at {quality_metrics_path}. Skipping calculations.")
    quality_metrics = pd.read_csv(quality_metrics_path, index_col=0)

In [None]:
# List of metrics to compute
metrics_to_compute = [
    'firing_rate',
    'amplitude_median',
    'isi_violation',
    'silhouette',
    'snr',
]

# Check and process each segment independently.
# The loop uses its own names only. It previously rebound the notebook-level
# `recording`, `sorting` and `analyzer` to per-segment objects, so every later
# cell silently worked on the last segment instead of the concatenated data.
for segment_idx in range(min(len(recordings), len(sortings))):
    # Define the folder and file path for the current segment's metrics
    segment_analyzer_folder = os.path.join(analyzers_folder, f"segment_{segment_idx}")
    segment_metrics_path = f"{segment_analyzer_folder}/quality_metrics_segment_{segment_idx}.csv"

    if os.path.exists(segment_metrics_path):
        print(f"Metrics for segment {segment_idx} already exist. Skipping calculations.")
        continue

    # Load the analyzer for the current segment; this cell does not create it
    if not os.path.exists(segment_analyzer_folder):
        print(f"Analyzer for segment {segment_idx} does not exist. Skipping this segment.")
        continue
    segment_analyzer = si.load_sorting_analyzer(segment_analyzer_folder)
    print(f"Analyzer for segment {segment_idx} loaded.")

    # Compute quality metrics for this segment
    segment_metrics = sqm.compute_quality_metrics(
        segment_analyzer,
        metric_names=metrics_to_compute
    )

    # Save segment-specific metrics to a CSV file
    segment_metrics.to_csv(segment_metrics_path)
    print(f"Metrics for segment {segment_idx} saved to {segment_metrics_path}.")

# Define basic plot functions for curation

In [None]:
# Shared figure builder for the two unit-based plot functions below
def _plot_unit_summary(sorting, analyzer, unit_id):
    """Draw the four-panel summary figure of one unit and return the figure.

    Panels, left to right: template with ``plot_channels="all"``, template
    without ``plot_channels``, ISI distribution, autocorrelogram.
    """
    fig, axs = plt.subplots(1, 4, figsize=(15, 5))
    fig.suptitle(f'Plots for Unit {unit_id}', fontsize=16)

    # Template of the current unit on the whole probe
    sw.plot_unit_templates(
        analyzer,
        unit_ids=[unit_id],
        ax=axs[0],
        same_axis=True,
        plot_channels="all",
        sparsity=None  # This ensures all channels are plotted
    )
    axs[0].set_title('Template (Whole Probe)')

    # Template of the current unit, same call without plot_channels
    # NOTE(review): sparsity=None is passed here as well -- confirm that this
    # panel really differs from the whole-probe panel.
    sw.plot_unit_templates(
        analyzer,
        unit_ids=[unit_id],
        ax=axs[1],
        same_axis=True,
        sparsity=None
    )
    axs[1].set_title('Template (Local)')

    # ISI distribution of the current unit
    sw.plot_isi_distribution(sorting.select_units([unit_id]), ax=axs[2])
    axs[2].set_title('ISI Distribution')

    # Autocorrelogram of the current unit
    sw.plot_autocorrelograms(analyzer, unit_ids=[unit_id], ax=axs[3])
    axs[3].set_title('Autocorrelogram')

    # Leave room for the figure title
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


# Define plot functions based on unit without saving functions
def unit_based_plot(sorting, analyzer, selected_units):
    """
    Plot both template views, the ISI distribution and the autocorrelogram for each unit in selected_units.

    Parameters:
    - sorting: sorting of Spikeinterface
    - analyzer: analyzer of Spikeinterface
    - selected_units: A list of unit IDs to plot.
    """
    for unit_id in selected_units:
        _plot_unit_summary(sorting, analyzer, unit_id)
        plt.show()


# Define plot functions based on unit with saving function
def unit_based_plot_save(sorting, analyzer, selected_units):
    """
    Plot both template views, the ISI distribution and the autocorrelogram for each unit in selected_units,
    and save each figure as "unit_<unit_id>_properties.svg" in output_root.

    Parameters:
    - sorting: sorting of Spikeinterface
    - analyzer: analyzer of Spikeinterface
    - selected_units: A list of unit IDs to plot.
    """
    for unit_id in selected_units:
        fig = _plot_unit_summary(sorting, analyzer, unit_id)
        # Save before showing the figure
        fig.savefig(os.path.join(output_root, f"unit_{unit_id}_properties.svg"), format="svg")
        plt.show()
        
# Define plot functions based on probe
def probe_based_plot(sorting, analyzer, selected_units, time_range=(3610, 4210)):
    """
    Plot and save a raster plot, the unit presence and the unit locations for the selected units.

    Each figure is saved as SVG in output_root ("raster_plot.svg",
    "unit_presence.svg", "unit_locations.svg") and then shown.

    Parameters:
    - sorting: The sorting object containing spike times and unit information.
    - analyzer: The analyzer object containing computed waveform templates and other analysis data.
    - selected_units: A list of unit IDs to include in the plots.
    - time_range: (start, end) passed to the raster plot. The default is the
      window that was previously hard-coded.
    """
    # Select the units you want to plot
    sorting_selected = sorting.select_units(selected_units)
    analyzer_selected = analyzer.select_units(selected_units)

    # Plot the raster plot
    fig, ax = plt.subplots(figsize=(12, 8))
    sw.plot_rasters(sorting_selected, time_range=time_range, ax=ax)
    plt.title(f'Raster Plot for Units: {selected_units}')
    fig.savefig(os.path.join(output_root, "raster_plot.svg"), format="svg")
    plt.show()

    # Plot the unit presence over time
    fig, ax = plt.subplots(figsize=(12, 8))
    sw.plot_unit_presence(sorting_selected, ax=ax)
    plt.title(f'Unit Presence Plot for Units: {selected_units}')
    fig.savefig(os.path.join(output_root, "unit_presence.svg"), format="svg")
    plt.show()

    # Plot the unit locations
    fig, ax = plt.subplots(figsize=(12, 8))
    sw.plot_unit_locations(analyzer_selected, plot_legend=True, ax=ax)
    plt.title(f'Unit Locations for Units: {selected_units}')
    fig.savefig(os.path.join(output_root, "unit_locations.svg"), format="svg")
    plt.show()

# Curation 

In [None]:
"""
# First round curation based on unit template and ISI
"""
# Unit IDs kept after the first (manual) curation round; the list is empty in
# this template and has to be filled in by hand. The second curation round
# uses it to select rows of quality_metrics.
first_round_curated_units=[]  

In [None]:
# Second round curation based on "quality metrics"

# Metrics of the units that passed the first round
first_round_metrics = quality_metrics.loc[first_round_curated_units]

# Set the curation parameter based on the quality metrics------------
isi_violations_ratio_threshold = 2 # quantify the proportion of spikes that violate the refractory period, range >0, <0.5 is great, >2 is bad
snr_threshold = 3 # >5 is great, 3-5 is ok, <3 is bad
silhouette_threshold = 0.1  # silhouette_score quantify separation from other units, range from -1 to 1. Good Quality: Silhouette Score > 0.5, Moderate Quality: Silhouette Score 0.1 to 0.5
#---------------------------------------------------------------------

# Apply the keep_mask only to units from the first round of curation.
# Remove the silhouette term to curate on ISI violations and SNR only.
keep_mask = (
    (first_round_metrics["isi_violations_ratio"] < isi_violations_ratio_threshold)
    & (first_round_metrics["snr"] > snr_threshold)
    & (first_round_metrics["silhouette"] > silhouette_threshold)
)
final_curated_units = list(first_round_metrics[keep_mask].index.values)
print(f"Final curated units based on quality metrics: {final_curated_units}")

# Define the final curated data with native Python integers (required by json.dump)
final_curated_data = [int(unit_id) for unit_id in final_curated_units]

# A stray `fig.savefig(... "unit_locations.svg" ...)` was removed here: this cell
# creates no figure, so the call raised a NameError (or overwrote the unit
# locations plot with whatever `fig` another cell left behind).

# Save curated data in JSON format in the same output path
curated_data_path = os.path.join(output_root, "final_curated_data.json")
with open(curated_data_path, "w") as f:
    json.dump(final_curated_data, f)
print(f"Curated data saved to {curated_data_path}")

# Plot firing rate change based on metrics

In [None]:
# Firing rates of the curated units, grouped by segment index
curated_units_firing_rate = {}

# Read the per-segment metrics files and collect the firing rates
for segment_idx in range(len(recordings)):
    segment_metrics_path = os.path.join(
        analyzers_folder,
        f"segment_{segment_idx}",
        f"quality_metrics_segment_{segment_idx}.csv",
    )

    if not os.path.exists(segment_metrics_path):
        print(f"Metrics file for segment {segment_idx} not found. Skipping this segment.")
        continue

    segment_metrics = pd.read_csv(segment_metrics_path, index_col=0)

    for unit_id in final_curated_units:
        if unit_id not in segment_metrics.index:
            print(f"Unit {unit_id} not found in segment {segment_idx}. Skipping.")
            continue
        # A segment gets an entry as soon as one curated unit is found in it
        segment_rates = curated_units_firing_rate.setdefault(segment_idx, [])
        firing_rate = segment_metrics.loc[unit_id, "firing_rate"]
        if pd.notna(firing_rate):  # NaN firing rates are left out
            segment_rates.append(firing_rate)

# Bar heights and error bars, one per segment that has an entry
segment_indices = list(curated_units_firing_rate)
average_firing_rates = [np.mean(rates) for rates in curated_units_firing_rate.values()]
firing_rate_sds = [np.std(rates) for rates in curated_units_firing_rate.values()]

# Un-paired t-test (unequal variances) between each pair of consecutive segments
p_values = []
for first_segment, second_segment in zip(segment_indices, segment_indices[1:]):
    t_stat, p_value = ttest_ind(
        curated_units_firing_rate[first_segment],
        curated_units_firing_rate[second_segment],
        equal_var=False,
    )
    p_values.append((first_segment, second_segment, p_value))

# Bar plot of the mean firing rate with the standard deviation as error bar
plt.figure(figsize=(12, 6))
bar_positions = np.arange(len(segment_indices))
plt.bar(bar_positions, average_firing_rates, yerr=firing_rate_sds, capsize=5, color='lightblue', alpha=0.7, label='Average Firing Rate')

# One gray dot per unit on top of its segment's bar
for bar_position, firing_rates in zip(bar_positions, curated_units_firing_rate.values()):
    plt.scatter([bar_position] * len(firing_rates), firing_rates, color='gray', alpha=0.6, label='_nolegend_')

# Axis labels, title and legend
plt.xticks(bar_positions, segment_indices)
plt.title("Average Firing Rate Across Segments")
plt.xlabel("Segment Index")
plt.ylabel("Firing Rate (Hz)")
plt.legend(title="Data", loc='upper right')
plt.tight_layout()

# Write the figure to the output folder, then show it
figure_path = os.path.join(output_root, "average_firing_rate_across_segments.svg")
plt.savefig(figure_path, format='svg')
print(f"Figure saved to {figure_path}")

plt.show()

# Report the p-values
print("Un-paired t-test p-values between consecutive segments:")
for seg_1, seg_2, p_val in p_values:
    print(f"Segment {seg_1} vs. Segment {seg_2}: p-value = {p_val:.5f}")

# Generate file for DataHigh neural population analysis in Matlab

In [None]:
def generate_datahigh_input(
    curated_units,
    recording_duration,                # The total time of concatenated segments (in second)
    segment_duration,                   # The time of each segment (in second)
    time_range,                  # The time range in each segment that will be processed (in second)
    trial_time,                          # The trial time in each time range (in second)
    bin_size                           # The bin_size in each trial (in second, should be adjusted to allow the data array has reasonable number of '1')
):
    """Write binned binary spike trains of the curated units to a DataHigh input file.

    The recording is divided into ``recording_duration / segment_duration``
    conditions. Within each condition the window ``time_range`` is divided into
    consecutive trials of ``trial_time`` seconds, and each trial is binned at
    ``bin_size`` seconds into a (units x bins) matrix of 0/1 values. All trials
    are saved as the struct array ``D`` in ``<output_root>/datahigh_input.mat``.

    Spike trains are read from the notebook-level ``sorting`` and the output
    folder is the notebook-level ``output_root``.

    Parameters:
    - curated_units: unit IDs to export; one matrix row per unit, in this order.
    - recording_duration: total time of the concatenated segments (seconds).
    - segment_duration: time of each segment, i.e. of each condition (seconds).
    - time_range: (start, end) within each segment that is processed (seconds).
    - trial_time: length of each trial inside time_range (seconds).
    - bin_size: bin width inside each trial (seconds).

    Returns:
    - dict mapping each condition name to the fraction of bins that contain at
      least one spike.
    """
    def _whole_count(total, step):
        # Number of whole steps in total; the small tolerance keeps float
        # rounding (e.g. 9999.999999999998) from losing one step.
        return int(np.floor(total / step + 1e-9))

    # Define output file name
    output_file = os.path.join(output_root, "datahigh_input.mat")

    # Calculate the number of conditions
    num_conditions = _whole_count(recording_duration, segment_duration)
    conditions = [f"condition_{i+1}" for i in range(num_conditions)]

    # Condition colors; reused cyclically when there are more than 10 conditions
    base_colors = [
        [1, 0, 0],    # Red
        [0, 1, 0],    # Green
        [0, 0, 1],    # Blue
        [1, 1, 0],    # Yellow
        [1, 0, 1],    # Magenta
        [0, 1, 1],    # Cyan
        [0.5, 0.5, 0.5],  # Gray
        [0.5, 0, 0],  # Dark Red
        [0, 0.5, 0],  # Dark Green
        [0, 0, 0.5]   # Dark Blue
    ]

    # Calculate the number of trials within the time_range for each condition
    num_trials_per_condition = _whole_count(time_range[1] - time_range[0], trial_time)
    total_trials = num_trials_per_condition * num_conditions

    # Number of whole bins per trial and their edges relative to the trial start.
    # num_bins + 1 edges are needed for num_bins bins; using
    # np.arange(start, end, bin_size) as edges dropped the last bin of every trial.
    num_bins = _whole_count(trial_time, bin_size)
    relative_bin_edges = np.arange(num_bins + 1) * bin_size

    # Spike times in seconds, read once per unit instead of once per trial
    sampling_frequency = sorting.get_sampling_frequency()
    unit_spike_times = [
        sorting.get_unit_spike_train(unit_id) / sampling_frequency
        for unit_id in curated_units
    ]

    # Initialize the structured array with the specified field order
    D = np.empty(total_trials, dtype=[('data', 'O'), ('epochStarts', 'O'), ('epochColors', 'O'), ('condition', 'O')])

    # Initialize dictionary to store 1 ratios for each condition
    condition_ratios = {}

    # Loop over each condition
    trial_counter = 0
    for condition_idx, condition_name in enumerate(conditions):
        # Define the start time of the condition in the overall recording
        condition_start_time = condition_idx * segment_duration
        color = base_colors[condition_idx % len(base_colors)]

        # Start of the processed time range within this condition
        range_start = condition_start_time + time_range[0]

        # Initialize counters for this condition
        condition_ones = 0
        condition_bins = 0

        # Generate trials within the specified time range for this condition
        for trial_num in range(num_trials_per_condition):
            trial_start_time = range_start + trial_num * trial_time
            trial_end_time = trial_start_time + trial_time
            bins = trial_start_time + relative_bin_edges

            # One row of binary bins per unit
            trial_data = []
            for unit_spike_s in unit_spike_times:
                # Keep the spikes inside the trial window; the end is excluded
                # because np.histogram closes its last bin on the right.
                trial_spikes = unit_spike_s[(unit_spike_s >= trial_start_time) & (unit_spike_s < trial_end_time)]
                binned_spikes, _ = np.histogram(trial_spikes, bins=bins)

                # Convert to binary (0's and 1's)
                binary_spikes = (binned_spikes > 0).astype(int)

                # Count 1s and total bins for ratio calculation for this condition
                condition_ones += np.sum(binary_spikes)
                condition_bins += binary_spikes.size

                trial_data.append(binary_spikes)

            # Matrix of shape (units x time bins); the reshape keeps it 2-D
            # even when curated_units is empty.
            trial_matrix = np.array(trial_data, dtype=int).reshape(len(unit_spike_times), num_bins)

            # Assign values to the structured array fields
            D[trial_counter]['data'] = trial_matrix  # Set data to the spike train matrix
            D[trial_counter]['epochStarts'] = np.array([1])  # Set epochStarts to 1 for all trials
            D[trial_counter]['epochColors'] = np.array([color])  # Set color based on condition
            D[trial_counter]['condition'] = condition_name  # Assign the condition name

            # Increment the trial counter
            trial_counter += 1

        # Calculate the 1 ratio for this condition
        condition_ratio = condition_ones / condition_bins if condition_bins > 0 else 0
        condition_ratios[condition_name] = condition_ratio
        print(f"Condition {condition_name} - Ratio of 1s: {condition_ratio:.4f} ({condition_ratio * 100:.2f}%)")

    # Ensure the output directory exists
    os.makedirs(output_root, exist_ok=True)

    # Save as a .mat file
    savemat(output_file, {"D": D})

    # Return the condition ratios dictionary for further analysis if needed
    return condition_ratios

In [None]:
# Export the final curated units for DataHigh. With these values the function
# derives 4200 / 600 = 7 conditions and (500 - 300) / 10 = 20 trials per condition.
# NOTE(review): this assumes every recording segment lasts 600 s -- confirm against the data.
generate_datahigh_input(
    curated_units=final_curated_units,
    recording_duration=4200,                # The total time of concatenated segments (in second)
    segment_duration=600,                   # The time of each segment (in second, up to 10 conditions or segments)
    time_range=(300,500),                  # The time range in each segment that will be processed (in second)
    trial_time=10,                          # The trail time in each time range (in second)
    bin_size=0.001                           # The bin_size in each trail (in second, should be adjusted to allow the data array has reasonable number of '1')
)

# Guidance for choosing bin_size; compare it with the "Ratio of 1s" printed per condition.
"""
A 1-10% ratio of 1s in the data is generally effective for GPFA, with adjustments to bin size if needed to achieve this range. 
This provides a balance between capturing meaningful structure and avoiding an excess of zeros or too much noise.
"""

# Plot 2D UMAP

In [None]:
def plot_umap_with_filtering(analyzer, curated_unit_ids, output_root, nb_points=600, segment_label=5):
    """
    Generate, save and show a 2D UMAP projection of the PCA scores of curated units.

    For every curated unit the PCA projections are flattened to one row per
    spike, and only the `nb_points` spikes closest (Euclidean distance) to the
    unit's mean projection are kept before UMAP is fitted on all units together.

    Parameters:
    - analyzer: The SortingAnalyzer object with precomputed principal components.
    - curated_unit_ids: A list of unit IDs to include in the UMAP plot.
    - output_root: Directory in which the SVG plot is saved (created if missing).
    - nb_points: Number of closest PCA points to select for each unit.
    - segment_label: Label inserted into the plot title and the output file name
      (default 5, which reproduces "Segment 5" / "UMAP_segment_5_filtered.svg").

    Returns:
    - None. If no unit has PCA projections, a message is printed and nothing is plotted.
    """
    # Retrieve the PCA extension
    ext_pca = analyzer.get_extension("principal_components")

    # Collect filtered PCA scores for only the curated units
    all_pca_data = []
    all_labels = []
    for unit_id in curated_unit_ids:
        # Retrieve PCA projections for the current unit
        unit_pca = ext_pca.get_projections_one_unit(unit_id=unit_id, sparse=False)

        # Skip units with no PCA projections
        if unit_pca.size == 0:
            print(f"Skipping unit {unit_id} due to empty PCA projections.")
            continue

        # Flatten everything but the first axis so each spike becomes one row
        unit_pca_flattened = unit_pca.reshape(unit_pca.shape[0], -1)

        # Euclidean distance of every spike to the unit's mean projection
        mean_pca = np.mean(unit_pca_flattened, axis=0)
        distances = np.sqrt(np.sum(np.square(unit_pca_flattened - mean_pca), axis=1))

        # Keep only the `nb_points` spikes closest to the mean
        if len(distances) > nb_points:
            selected_indices = np.argsort(distances)[:nb_points]
            unit_pca_filtered = unit_pca_flattened[selected_indices]
        else:
            unit_pca_filtered = unit_pca_flattened

        # Append the filtered PCA projections and corresponding labels
        all_pca_data.append(unit_pca_filtered)
        all_labels.extend([unit_id] * len(unit_pca_filtered))

    # Check if there is any data to process
    if not all_pca_data:
        print("No valid PCA data found for the curated units. UMAP plot cannot be generated.")
        return

    # Concatenate all PCA data into a single 2D array for UMAP
    all_pca_data = np.vstack(all_pca_data)

    # Apply UMAP with enhanced separation settings
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=30,       # Increase for more global structure
        min_dist=0.1,         # Decrease for tighter clusters
        metric="cosine",      # Use cosine similarity for better high-dimensional structure
        random_state=42
    )
    embedding = reducer.fit_transform(all_pca_data)

    # Assign one colour per curated unit. `plt.get_cmap` is used because
    # `matplotlib.cm.get_cmap` (reached via `plt.cm.get_cmap`) was removed in
    # Matplotlib 3.9. NOTE: 'tab20' only has 20 base colours, so with more
    # units than that some colours will repeat.
    unique_units = np.unique(curated_unit_ids)
    num_colors = len(unique_units)
    cmap = plt.get_cmap('tab20', max(num_colors, 40))
    unit_color_dict = {unit: cmap(i / num_colors) for i, unit in enumerate(unique_units)}

    # Generate colors for the embedding
    colors = [unit_color_dict[label] for label in all_labels]

    # Plot the UMAP result
    plt.figure(figsize=(10, 8))
    plt.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=5, alpha=0.8)
    plt.xlabel("UMAP Dimension 1")
    plt.ylabel("UMAP Dimension 2")
    plt.title(f"UMAP of Curated Units (Filtered PCA) for Segment {segment_label}")

    # Add a legend for unit IDs
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=unit_color_dict[unit], markersize=10)
               for unit in unique_units]
    plt.legend(handles, [f"Unit {unit}" for unit in unique_units], title="Unit IDs", bbox_to_anchor=(1.05, 1), loc='upper left')

    # Save the plot (make sure the target directory exists first)
    os.makedirs(output_root, exist_ok=True)
    output_path = os.path.join(output_root, f"UMAP_segment_{segment_label}_filtered.svg")
    plt.savefig(output_path, format="svg", bbox_inches='tight')
    print(f"UMAP plot saved to {output_path}")
    plt.show()

# Build and save the UMAP plot of the final curated units using the segment-5 analyzer
plot_umap_with_filtering(analyzer_segment_5, final_curated_units, output_root)

# Plot results of final curated units

In [None]:
# Plot the final curated units. Only the probe-based plot is active; the other
# two calls are kept as commented-out alternatives.
# NOTE(review): the plotting helpers are presumably defined in earlier cells — confirm.
#unit_based_plot_save(sorting, analyzer, final_curated_units)
probe_based_plot(sorting, analyzer, final_curated_units)
#plot_umap(analyzer, final_curated_units)

# Plot Pearson correlation coefficients

In [None]:
def compute_correlation_matrix(curated_units=None, sorting=None, start_time=1200, end_time=1800, bin_size=0.1):
    """
    Compute pairwise Pearson correlations between binned spike counts of curated units.

    Spike times (in seconds) inside [start_time, end_time) are counted in
    consecutive bins of width `bin_size`, and the Pearson correlation
    coefficient of the resulting count vectors is computed for every pair of units.

    Parameters:
        curated_units (list of int): List of curated unit IDs. When None, the
            notebook-level `final_curated_units` is looked up at call time.
        sorting (SortingExtractor): The sorting extractor with the spike data.
            When None, the notebook-level `sorting` is looked up at call time.
        start_time (float): Start time of the interval for filtering spikes (seconds).
        end_time (float): End time of the interval for filtering spikes (seconds).
        bin_size (float): Width of the spike-count bins in seconds (default 0.1, i.e. 100 ms).

    Returns:
        np.ndarray: Symmetric (n_units, n_units) matrix of correlations. The
        diagonal is NaN; pairs where either unit has fewer than two spikes in
        the interval, or whose correlation is undefined, are 0.
    """
    # Resolve the notebook globals lazily so that re-running earlier cells is
    # picked up (defaults bound in the signature would go stale).
    if curated_units is None:
        curated_units = globals()["final_curated_units"]
    if sorting is None:
        sorting = globals()["sorting"]

    sampling_frequency = sorting.get_sampling_frequency()

    # Build the bin edges once, including the final edge so the last bin of the
    # interval is not dropped. The small epsilon absorbs float round-off in the division.
    n_bins = max(int(np.floor((end_time - start_time) / bin_size + 1e-9)), 0)
    bin_edges = start_time + bin_size * np.arange(n_bins + 1)

    # Bin every unit exactly once instead of once per pair
    binned_counts = []
    has_enough_spikes = []
    for unit_id in curated_units:
        spike_times = sorting.get_unit_spike_train(unit_id) / sampling_frequency
        in_window = spike_times[(spike_times >= start_time) & (spike_times < end_time)]
        has_enough_spikes.append(len(in_window) > 1)
        counts, _ = np.histogram(in_window, bins=bin_edges)
        binned_counts.append(counts)

    n_units = len(binned_counts)
    correlation_matrix = np.zeros((n_units, n_units))

    # The matrix is symmetric, so only the upper triangle is computed and mirrored
    for i in range(n_units):
        if not has_enough_spikes[i]:
            continue
        for j in range(i + 1, n_units):
            if not has_enough_spikes[j]:
                continue
            # Constant count vectors give an undefined (NaN) correlation; treat as 0
            with np.errstate(invalid="ignore", divide="ignore"):
                correlation = np.corrcoef(binned_counts[i], binned_counts[j])[0, 1]
            if not np.isnan(correlation):
                correlation_matrix[i, j] = correlation
                correlation_matrix[j, i] = correlation

    # Set diagonal values to NaN so autocorrelation is not displayed
    np.fill_diagonal(correlation_matrix, np.nan)

    return correlation_matrix

def plot_correlation_heatmap(correlation_matrix, curated_units=None, filename="correlation_heatmap_1200-1800.svg"):
    """
    Plot the correlation matrix as a heatmap, save it as SVG and show it.

    The figure is written to the notebook-level `output_root` directory.

    Parameters:
        correlation_matrix (np.ndarray): Matrix of correlations between units.
        curated_units (list of int): List of curated unit IDs for labeling. When
            None, the notebook-level `final_curated_units` is looked up at call time.
        filename (str): Name of the SVG file created inside `output_root`.
    """
    # Resolve the notebook global lazily so a re-run curation cell is picked up
    if curated_units is None:
        curated_units = globals()["final_curated_units"]

    fig, ax = plt.subplots(figsize=(8, 8))

    # Colour range is clipped to [-0.1, 0.1]; NaN cells (the diagonal) are not
    # coloured by the colormap.
    cax = ax.imshow(correlation_matrix, cmap='viridis', vmin=-0.1, vmax=0.1)

    # Add color bar with a range from -0.1 to 0.1
    fig.colorbar(cax, ax=ax, label='Correlation', ticks=[-0.1, 0, 0.1])

    # Label the axes with unit numbers
    tick_positions = np.arange(len(curated_units))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(curated_units)
    ax.set_yticklabels(curated_units)
    ax.set_xlabel("Unit")
    ax.set_ylabel("Unit")

    # Rotate x-axis labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    plt.tight_layout()

    # Make sure the output directory exists before saving
    os.makedirs(output_root, exist_ok=True)
    fig.savefig(os.path.join(output_root, filename), format="svg")
    plt.show()

# Compute the correlation matrix with the default units, sorting and time window,
# then plot and save the heatmap
correlation_matrix = compute_correlation_matrix()
plot_correlation_heatmap(correlation_matrix)

# Plot firing rate changes along recording

In [None]:
def plot_firing_rate_change(recording, sorting, bin_size, curated_unit_ids, time_ranges, segment_time):
    """
    Plot firing rates of curated units over time and their change between segments.

    Three figures are produced, saved as SVG into the notebook-level
    `output_root` directory, and shown:
    1. Firing rate of every curated unit over the (concatenated) `time_ranges`.
    2. Firing rate averaged across all curated units over the same time axis.
    3. Absolute firing-rate change between consecutive segments of length
       `segment_time`, per unit (dots) with mean (bars) and SD (error bars).
       The segments cover the whole recording duration, independent of `time_ranges`.
    The per-transition values and Welch t-test p-values between adjacent
    transitions are printed.

    Parameters:
    - recording: The recording object (provides sampling frequency and total duration).
    - sorting: The sorting object containing spike times and unit information.
    - bin_size: The bin size in seconds for the analysis.
    - curated_unit_ids: A list of unit IDs to include in the plot.
    - time_ranges: A list of tuples specifying time ranges (e.g., [(0, 60), (100, 120)]).
    - segment_time: The duration of each segment in seconds for calculating firing rate changes.
    """
    os.makedirs(output_root, exist_ok=True)  # Ensure output directory exists

    # Assign a color to each unit; colors repeat after 10 units (tab10 cycle)
    color_cycle = plt.cm.tab10.colors
    unit_colors = {unit_id: color_cycle[i % len(color_cycle)] for i, unit_id in enumerate(curated_unit_ids)}

    # Convert every unit's spike train to seconds once, instead of re-fetching
    # it inside each time-range / segment loop.
    sampling_frequency = recording.get_sampling_frequency()
    spike_times = {
        unit_id: sorting.get_unit_spike_train(unit_id) / sampling_frequency
        for unit_id in curated_unit_ids
    }

    def _spikes_between(unit_id, t_start, t_stop):
        """Return the unit's spike times (s) inside the half-open window [t_start, t_stop)."""
        times = spike_times[unit_id]
        return times[(times >= t_start) & (times < t_stop)]

    def _legend_if_labelled(axis):
        """Draw a legend only when the axis actually holds labelled artists."""
        if axis.get_legend_handles_labels()[0]:
            axis.legend()

    # Plot 1: Firing Rate Over Time
    concatenated_firing_rates = {unit_id: [] for unit_id in curated_unit_ids}
    concatenated_bin_centers = []
    current_start_time = 0  # running offset so successive time ranges are plotted back to back

    for start_time, end_time in time_ranges:
        bins = np.arange(start_time, end_time + bin_size, bin_size)
        bin_centers = bins[:-1] + bin_size / 2
        adjusted_bin_centers = bin_centers + current_start_time - start_time

        concatenated_bin_centers.extend(adjusted_bin_centers)

        for unit_id in curated_unit_ids:
            spike_counts, _ = np.histogram(_spikes_between(unit_id, start_time, end_time), bins=bins)
            # Counts per bin divided by bin width give spikes per second
            concatenated_firing_rates[unit_id].extend(spike_counts / bin_size)

        current_start_time += end_time - start_time

    fig, ax = plt.subplots(figsize=(12, 8))
    for unit_id in curated_unit_ids:
        ax.plot(concatenated_bin_centers, concatenated_firing_rates[unit_id], label=f'Unit {unit_id}',
                marker='o', linestyle='-', color=unit_colors[unit_id])

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Firing Rate (spikes per second)')
    ax.set_title('Firing Rate Change Over Time for Curated Units (Continuous)')
    _legend_if_labelled(ax)
    fig.savefig(os.path.join(output_root, "firing_rate_over_time_of_each_units.svg"), format="svg")
    plt.show()

    # Plot 2: Average Firing Rate Across Units
    fig, ax = plt.subplots(figsize=(12, 8))
    all_rates = np.array(list(concatenated_firing_rates.values()))
    if all_rates.size > 0:
        average_firing_rate = np.mean(all_rates, axis=0)
        ax.plot(concatenated_bin_centers, average_firing_rate, marker='o', linestyle='-', color='red', label='Average Firing Rate')
    else:
        print("No valid firing rates available for averaging.")

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Average Firing Rate (spikes per second)')
    ax.set_title('Average Firing Rate Across Curated Units (Continuous)')
    _legend_if_labelled(ax)
    fig.savefig(os.path.join(output_root, "average_firing_rate_across_units.svg"), format="svg")
    plt.show()

    # Plot 3: Firing Rate Changes Between Segments (Absolute Values with p-value Calculation for Adjacent Transitions)
    total_duration = recording.get_total_duration()
    num_segments = int(total_duration // segment_time)  # only complete segments are used
    segment_ranges = [(i * segment_time, (i + 1) * segment_time) for i in range(num_segments)]

    # Mean rate per segment is computed once per unit; the change is the absolute
    # difference between each segment and the one before it.
    firing_rate_changes = {}
    for unit_id in curated_unit_ids:
        segment_rates = [
            len(_spikes_between(unit_id, seg_start, seg_end)) / segment_time
            for seg_start, seg_end in segment_ranges
        ]
        firing_rate_changes[unit_id] = [
            abs(curr_rate - prev_rate)
            for prev_rate, curr_rate in zip(segment_rates, segment_rates[1:])
        ]

    fig, ax = plt.subplots(figsize=(12, 8))
    all_changes = []
    dot_values = {transition_index: [] for transition_index in range(len(segment_ranges) - 1)}

    for unit_id, changes in firing_rate_changes.items():
        for transition_index, rate_change in enumerate(changes):
            ax.scatter(transition_index, rate_change, color=unit_colors[unit_id])
            dot_values[transition_index].append(rate_change)
        all_changes.append(changes)

    avg_changes = np.mean(np.array(all_changes), axis=0) if len(all_changes) > 0 else []
    std_changes = np.std(np.array(all_changes), axis=0) if len(all_changes) > 0 else []

    if len(avg_changes) > 0:
        transition_positions = np.arange(len(avg_changes))
        ax.bar(transition_positions, avg_changes, color='gray', alpha=0.5, label='Average Change')
        ax.errorbar(
            transition_positions, avg_changes, yerr=std_changes, fmt='o', color='black', ecolor='red',
            elinewidth=2, capsize=5, label='SD Error'
        )

    ax.set_xlabel('Segment Transition Index')
    ax.set_ylabel('Absolute Firing Rate Change (spikes per second)')
    ax.set_title('Absolute Firing Rate Change Between Segments')
    _legend_if_labelled(ax)
    fig.savefig(os.path.join(output_root, "absolute_firing_rate_changes_between_segments_with_sd.svg"), format="svg")
    plt.show()

    # Print individual dot values grouped by transition index
    for transition_index, values in dot_values.items():
        print(f"Transition {transition_index + 1}:")
        print(f"{[f'{v:.4f}' for v in values]}")

    # Calculate and print p-values for adjacent transitions
    print("\nAdjacent Transition Comparisons (p-values):")
    for i in range(len(dot_values) - 1):
        # equal_var=False selects Welch's t-test (no equal-variance assumption)
        stat, p_value = ttest_ind(dot_values[i], dot_values[i + 1], equal_var=False)
        print(f"Transition {i + 1} vs Transition {i + 2}: p = {p_value:.4e}")

In [None]:
firing_rate_change_bin_size=30 # Bin size for the firing-rate histograms (in seconds).
time_ranges = [(0,4200)] # (start, end) ranges in seconds to plot; narrow these to exclude data that has human noise.
segment_time = 600 # Duration of each segment (in seconds) used for the between-segment rate changes.
plot_firing_rate_change(recording, sorting, bin_size=firing_rate_change_bin_size, curated_unit_ids=final_curated_units, time_ranges=time_ranges, segment_time=segment_time)
"""
    Plot the change of firing rate along the recording time for curated units with the specified bin size
    and plot the averaged firing rate across all curated units.
    
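    Parameters:
    - recording: The recording object.
    - sorting: The sorting object.
    - bin_size: The bin size in seconds for the analysis.
    - curated_unit_ids: A list of unit IDs to include in the plot.
    - time_ranges: A list of (start, end) tuples in seconds to include in the plots.
    - segment_time: The duration of each segment in seconds for calculating firing rate changes.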
"""

"""
    How to interpret the results:
    - The first plot shows the firing rate changes of the units; each line corresponds to one unit.
    - If the first plot contains many zeros, the bin size chosen here is too small for the subsequent neural population analysis.
    
    - Second plot shows the averaged firing rate changes of all units.
"""