In [None]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import probeinterface as pi
from pathlib import Path
import os 
import pandas as pd 

# Job settings registered as SpikeInterface's global defaults and also passed
# explicitly to some calls further down in the notebook.
global_job_kwargs = {"n_jobs": 4, "chunk_duration": "1s", "progress_bar": True}
si.set_global_job_kwargs(**global_job_kwargs)

# Root folder of the SpikeGLX session; every output path below is derived from it.
basefolder = "F:/copydaya/M7_1_copy"

# Sub-folder that receives the CSV tables written by later cells.
metapath = f"{basefolder}/Meta"
os.makedirs(metapath, exist_ok=True)

# Open the three SpikeGLX streams of the session, selected by stream id.
recording = si.read_spikeglx(
    basefolder,
    stream_id='imec1.ap',
    load_sync_channel=False,
)
lfp = si.read_spikeglx(
    basefolder,
    stream_id='imec1.lf',
    load_sync_channel=False,
)
event = si.read_spikeglx(
    basefolder,
    stream_id='nidq',
    load_sync_channel=False,
)
print(recording)

#recording = si.ChannelSliceRecording(recording, channel_ids=recording.get_channel_ids()[180:330])


In [None]:


# Label every channel of the LFP stream with the 'coherence+psd' detector.
bad_channel_ids, channel_labels = si.detect_bad_channels(
    lfp,
    method='coherence+psd',
    outside_channels_location='both',
)

# Channel ids and the second column of the channel locations
# (used here as depth — presumably the along-shank coordinate; confirm).
names = lfp.channel_ids
depth = lfp.get_channel_locations()[:, 1]

# One row per channel, saved alongside the other metadata tables.
ar = pd.DataFrame({'name': names, 'depth': depth, 'labels': channel_labels})
ar.to_csv(metapath + '/lfp_labels.csv')

# Show only the channels whose label is 'out'.
is_out = ar['labels'] == 'out'
filtered_ar = ar.loc[is_out]

print(filtered_ar)

In [None]:
# Draw the probe layout with channel ids for the AP recording loaded above.
# `recording` is used here because no `recording4` exists anywhere in this
# notebook, so the previous name raised NameError on a fresh kernel.
fig, ax = plt.subplots(figsize=(15, 15))
si.plot_probe_map(recording, ax=ax, with_channel_ids=True)
# Restrict the vertical axis to the range of interest.
ax.set_ylim(-200, 3000)

In [None]:
rec1 = si.highpass_filter(recording, freq_min=400.)
rec1 = si.phase_shift(rec1)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1,method = 'coherence+psd')
print(bad_channel_ids)
rec1 = si.interpolate_bad_channels(recording=rec1, bad_channel_ids=bad_channel_ids)

rec1 = si.common_reference(rec1, operator="median", reference="global")
print(rec1)


%matplotlib widget
si.plot_traces({'raw':recording,'filtered':rec1}, backend='ipywidgets')

In [None]:


def extract_and_save_ttl_events(data, bits, save_path):
    """Write one CSV of rising-edge timestamps per requested bit.

    Parameters
    ----------
    data : recording-like
        Object providing ``get_traces()`` and ``get_sampling_frequency()``.
        Column 0 of the traces is taken as the digital word.
    bits : iterable of int
        Bit positions to extract from the digital word.
    save_path : str
        Folder that receives the ``ttl_<bit>.csv`` files.
    """
    word = data.get_traces()[:, 0]
    fs = data.get_sampling_frequency()

    for bit in bits:
        filename = f'ttl_{bit}.csv'
        # Rising-edge times, in seconds, for this bit only.
        edge_times = extract_ttl_from_bit(word, bit, fs)
        table = pd.DataFrame({'timestamps': edge_times})
        table.to_csv(f"{save_path}/{filename}", index=False)
        print(f"Extracted TTL event timestamps for bit {bit} saved to {filename}")


def extract_ttl_from_bit(digital_word, bit, sampling_rate):
    """Return the times, in seconds, of the rising edges of one bit.

    Parameters
    ----------
    digital_word : array-like of int
        One integer per sample; ``bit`` selects which binary digit to read.
    bit : int
        Zero-based position of the bit to isolate.
    sampling_rate : float
        Samples per second, used to convert sample indices to seconds.

    Returns
    -------
    numpy.ndarray
        Rising-edge times in seconds relative to the first sample. Each value
        is the index of the last low sample before a 0 -> 1 transition
        (the index produced by ``np.diff``) divided by ``sampling_rate``.
    """
    # Isolate the requested bit (right shift, then mask with 1).
    # The cast to a signed dtype is essential: on unsigned input np.diff
    # wraps a 1 -> 0 transition around to a large positive number, which
    # would be misreported as a rising edge.
    ttl_signal = ((np.asarray(digital_word) >> bit) & 1).astype(np.int8)

    # Detect rising edges (0 -> 1 transitions): the only positive differences.
    ttl_rising_edges = np.where(np.diff(ttl_signal) > 0)[0]

    # Convert sample indices to timestamps (in seconds).
    ttl_timestamps = ttl_rising_edges / sampling_rate

    return ttl_timestamps


# Bits 0, 1 and 2 of the digital word are each written to ttl_<bit>.csv in metapath.
bits_to_extract = list(range(3))
extract_and_save_ttl_events(event, bits_to_extract, metapath)

In [None]:
from spikeinterface.sorters import installed_sorters
import torch

# Only the last expression of a cell is displayed, so the diagnostics are
# printed explicitly; as bare expressions their results were discarded.
print(installed_sorters())

# Report GPU status. The device queries are only made when CUDA is available,
# because torch.cuda.current_device() raises otherwise.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))

# Run Kilosort4 on the preprocessed recording; any previous output in
# <basefolder>/sorted is removed first (remove_existing_folder=True).
Sorting_KS4 = si.run_sorter(
    sorter_name="kilosort4",
    recording=rec1,
    folder=basefolder + '/sorted',
    remove_existing_folder=True,
)

In [None]:
# To reuse a sorter output already on disk instead of the in-memory result, uncomment:
# Sorting_KS4 = si.read_kilosort(folder_path=basefolder + str('/sorted/sorter_output'))

# Pair the sorting with the preprocessed recording in an in-memory, sparse analyzer.
analyzer = si.create_sorting_analyzer(
    sorting=Sorting_KS4,
    recording=rec1,
    sparse=True,
    format="memory",
)



In [None]:
# Extensions requested together in a single compute call.
core_extensions = [
    'random_spikes',
    'waveforms',
    'templates',
    'noise_levels',
    'unit_locations',
    'correlograms',
]
analyzer.compute(core_extensions, **global_job_kwargs)

# Spike amplitudes are computed without explicit job kwargs.
analyzer.compute('spike_amplitudes')

# PCA with 5 components in "by_channel_local" mode.
pca_params = {"n_components": 5, "mode": "by_channel_local"}
analyzer.compute('principal_components', **pca_params, **global_job_kwargs)

In [None]:
# Quality metrics computed for every unit of the analyzer.
metric_names = [
    'firing_rate',
    'presence_ratio',
    'snr',
    'isi_violation',
    'amplitude_cutoff',
]
metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names)

# Thresholds a unit must satisfy to be kept.
amplitude_cutoff_thresh = 0.1
isi_violations_ratio_thresh = 0.5
presence_ratio_thresh = 0.9

# All three conditions are combined with a logical AND in one query string.
our_query = " & ".join(
    [
        f"(amplitude_cutoff < {amplitude_cutoff_thresh})",
        f"(isi_violations_ratio < {isi_violations_ratio_thresh})",
        f"(presence_ratio > {presence_ratio_thresh})",
    ]
)

keep_units = metrics.query(our_query)
keep_unit_ids = keep_units.index.values

# Save an analyzer restricted to the kept units as a binary folder under basefolder.
analyzer_clean = analyzer.select_units(
    keep_unit_ids,
    folder=basefolder + '/analyzer_clean',
    format='binary_folder',
)
print(analyzer)
print(analyzer_clean)


In [None]:
# Open the sorting summary GUI on the quality-filtered analyzer, the same object
# that is exported to Phy afterwards. The name `sorting_analyzer` is never defined
# in this notebook, so passing it raised NameError.
# Pass `analyzer` instead to inspect all units before filtering.
si.plot_sorting_summary(sorting_analyzer=analyzer_clean, curation=True, backend='spikeinterface_gui')

In [None]:
# Export the filtered analyzer to <basefolder>/sorted/phy for Phy.
phy_folder = basefolder + '/sorted/phy'
si.export_to_phy(
    analyzer_clean,
    output_folder=phy_folder,
    **global_job_kwargs,
)

In [3]:
!phy template-gui E:/Florian/Data/batch3/M9_1/M9_SNA-135383_19092024_1_g0_imec1/sorted/phy/params.py

[33m15:26:33.627 [W] model:625            Unreferenced clusters found in spike_clusters (generally not a problem)[0m
[33m15:26:33.744 [W] model:667            Skipping spike waveforms that do not exist, they will be extracted on the fly from the raw data as needed.[0m
[0m15:32:31.326 [I] supervisor:711       Change metadata_group for clusters 159 to mua.[0m
[0m15:32:32.766 [I] supervisor:711       Change metadata_group for clusters 158 to mua.[0m
[0m15:32:40.614 [I] supervisor:711       Change metadata_group for clusters 154 to mua.[0m
[0m15:32:50.199 [I] supervisor:711       Change metadata_group for clusters 149 to mua.[0m
[0m15:32:54.501 [I] supervisor:711       Change metadata_group for clusters 147 to mua.[0m
[0m15:32:58.390 [I] supervisor:711       Change metadata_group for clusters 145 to mua.[0m
[0m15:33:02.974 [I] supervisor:711       Change metadata_group for clusters 143 to mua.[0m
[0m15:33:07.494 [I] supervisor:711       Change metadata_group for clusters