In [20]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import probeinterface as pi
from pathlib import Path
import os 
import pandas as pd 

# Parallel-processing settings, registered globally so SpikeInterface calls that
# accept job kwargs pick them up; also passed explicitly further down.
global_job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
si.set_global_job_kwargs(**global_job_kwargs)


# Root folder of the SpikeGLX run; every output of this notebook is written under it.
basefolder = "D:/M8_SNA-135381_17092024_1_g0_imec1"

# Folder for side products (LFP channel labels, TTL timestamp CSVs).
metapath = basefolder + '/Meta'
# exist_ok=True replaces the isdir-then-makedirs check: no race, and re-running is safe.
os.makedirs(metapath, exist_ok=True)


# Open three streams of the run: 'imec1.ap', 'imec1.lf' and the 'nidq' board stream.
recording = si.read_spikeglx(basefolder, stream_id='imec1.ap', load_sync_channel=False)
lfp = si.read_spikeglx(basefolder, stream_id='imec1.lf', load_sync_channel=False)
event = si.read_spikeglx(basefolder, stream_id='nidq', load_sync_channel=False)


# Label every LFP channel with the 'coherence+psd' detector, treating channels at
# both ends of the probe as candidates for "outside" labelling.
# The labels are only written to CSV in this cell (nothing is removed or interpolated).
bad_channel_ids, channel_labels = si.detect_bad_channels(
    lfp, method='coherence+psd', outside_channels_location='both'
)
names = lfp.channel_ids
# Second coordinate of each channel location, stored as "depth".
depth = lfp.get_channel_locations()[:, 1]

# One row per LFP channel: id, depth and detector label.
ar = pd.DataFrame({'name': names, 'depth': depth, 'labels': channel_labels})
ar.to_csv(metapath + '/lfp_labels.csv')



def extract_and_save_ttl_events(data, bits, save_path, channel_index=0):
    """Decode TTL rising edges for several bits of a digital channel and save one CSV per bit.

    Parameters
    ----------
    data : SpikeInterface recording
        Recording that carries the digital word (in this notebook, the 'nidq' stream).
        ``get_traces`` is called without a segment index, so presumably only
        single-segment recordings are supported.
    bits : iterable of int
        Bit positions to decode from the digital word.
    save_path : str or path-like
        Folder (must already exist) where ``ttl_<bit>.csv`` files are written.
    channel_index : int, default 0
        Column of ``data`` holding the digital word. Defaults to 0, as the original
        code hard-coded. NOTE(review): confirm that column 0 of the nidq stream really
        is the digital word for this acquisition configuration.

    Side effects
    ------------
    Writes one CSV per bit with a single ``timestamps`` column. Times are in seconds
    from the first sample of ``data``. NOTE(review): they are not synchronised to the
    imec clock here. Prints a confirmation line per bit.
    """
    # Read only the channel carrying the digital word instead of every nidq channel,
    # which can be large for long sessions.
    channel_id = data.channel_ids[channel_index]
    digital_word = data.get_traces(channel_ids=[channel_id])[:, 0]
    sampling_rate = data.get_sampling_frequency()

    for bit in bits:
        # Rising-edge times (seconds) for this bit.
        ttl_timestamps = extract_ttl_from_bit(digital_word, bit, sampling_rate)
        ttl_df = pd.DataFrame(ttl_timestamps, columns=['timestamps'])

        filename = f'ttl_{bit}.csv'
        ttl_df.to_csv(os.path.join(save_path, filename), index=False)
        print(f"Extracted TTL event timestamps for bit {bit} saved to {filename}")


def extract_ttl_from_bit(digital_word, bit, sampling_rate):
    """Return the times (in seconds) at which one bit of a digital word goes from 0 to 1.

    Parameters
    ----------
    digital_word : array-like of int
        1-D integer samples of the digital channel (signed or unsigned dtype).
    bit : int
        Bit position to isolate (0 = least significant bit).
    sampling_rate : float
        Samples per second, used to convert sample indices to seconds.

    Returns
    -------
    numpy.ndarray of float
        For each rising edge, the time of the first sample where the bit is high,
        relative to sample 0. A bit that is already high at sample 0 is not counted
        as an edge. Inputs shorter than two samples yield an empty array.
    """
    word = np.asarray(digital_word)
    if word.size < 2:
        # No transition is possible with fewer than two samples.
        return np.array([], dtype=float)

    # Isolate the requested bit as 0/1 values (right shift, then mask).
    ttl_signal = (word >> bit) & 1

    # Compare neighbours explicitly rather than using np.diff(...) > 0. For unsigned
    # dtypes (e.g. uint16), np.diff wraps a 1 -> 0 step to a large positive value,
    # which would make falling edges look like rising edges.
    is_rising = (ttl_signal[1:] == 1) & (ttl_signal[:-1] == 0)

    # +1 so the index points at the first high sample, not the last low one.
    ttl_rising_edges = np.flatnonzero(is_rising) + 1

    # Convert sample indices to seconds.
    return ttl_rising_edges / sampling_rate


# Decode rising edges on the three lowest bits (0, 1, 2) of the 'nidq' stream's
# digital word and write ttl_0.csv, ttl_1.csv and ttl_2.csv into the metadata folder.
bits_to_extract = list(range(3))
extract_and_save_ttl_events(data=event, bits=bits_to_extract, save_path=metapath)











import shutil  # only this cell needs it (clearing a stale analyzer folder)

# --- Preprocess the AP stream --------------------------------------------------
# High-pass at 400 Hz, then phase-shift correction, before bad-channel detection.
rec1 = si.highpass_filter(recording, freq_min=400.)
rec1 = si.phase_shift(rec1)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1, method='coherence+psd')
print(bad_channel_ids)
# Interpolate only when something was flagged; otherwise the recording is used unchanged.
if len(bad_channel_ids) > 0:
    rec1 = si.interpolate_bad_channels(recording=rec1, bad_channel_ids=bad_channel_ids)

# Global median common-average reference across all channels.
rec1 = si.common_reference(rec1, operator="median", reference="global")


# --- Spike sorting ----------------------------------------------------------------
# remove_existing_folder=True wipes <basefolder>/sorted (including any earlier phy export).
Sorting_KS4 = si.run_sorter(sorter_name="kilosort4", recording=rec1,
                            folder=basefolder + '/sorted', remove_existing_folder=True)


# --- Post-processing (kept in memory) ---------------------------------------------
analyzer = si.create_sorting_analyzer(Sorting_KS4, rec1, sparse=True, format="memory")

analyzer.compute(['random_spikes', 'waveforms', 'templates', 'noise_levels',
                  'unit_locations', 'correlograms'], **global_job_kwargs)
analyzer.compute('spike_amplitudes')
analyzer.compute('principal_components', n_components=5, mode="by_channel_local",
                 **global_job_kwargs)

metric_names = ['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff']
metrics = si.compute_quality_metrics(analyzer, metric_names=metric_names)


# --- Curation -----------------------------------------------------------------------
amplitude_cutoff_thresh = 0.1
isi_violations_ratio_thresh = 0.5
presence_ratio_thresh = 0.9

# Units with a NaN metric fail every comparison below and are therefore dropped.
our_query = f"(amplitude_cutoff < {amplitude_cutoff_thresh}) & (isi_violations_ratio < {isi_violations_ratio_thresh}) & (presence_ratio > {presence_ratio_thresh})"

keep_units = metrics.query(our_query)
keep_unit_ids = keep_units.index.values

# SpikeInterface's binary_folder writer raises when the target folder already exists
# (confirm for the installed version). Clear the previous run's output so this cell
# can be re-run, mirroring remove_existing_folder=True used for the sorter above.
clean_folder = basefolder + '/analyzer_clean'
if os.path.isdir(clean_folder):
    shutil.rmtree(clean_folder)
analyzer_clean = analyzer.select_units(keep_unit_ids, folder=clean_folder, format='binary_folder')

si.export_to_phy(analyzer_clean, output_folder=basefolder + '/sorted/phy', **global_job_kwargs)


In [17]:
# Launch phy's template GUI on the params.py given below.
# NOTE(review): this path (E:\Florian\test) differs from the phy export folder used in
# the processing cell (basefolder + '/sorted/phy'), and this cell executed before that
# cell (In [17] vs In [20]). Confirm it points at the intended dataset.
!phy template-gui  E:\Florian\test\params.py




15:14:23.038 [W] model:667            Skipping spike waveforms that do not exist, they will be extracted on the fly from the raw data as needed.
