In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tsmoothie.smoother import GaussianSmoother
import spikeinterface
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.postprocessing as sp
import spikeinterface.preprocessing as spre
import spikeinterface.qualitymetrics as qm
import helper_functions as helper

In [2]:
#Reading the file, BP filtering
local_path= '/mnt/disk15tb/mmpatil/Spikesorting/Data/Mandar/230705/18712/Network/000009/data.raw.h5' #network data from chip 16848

recording1 = se.read_maxwell(local_path)



#recording = si.ConcatenateSegmentRecording([recording1,recording2])
channel_ids = recording1.get_channel_ids()
fs = recording1.get_sampling_frequency()
num_chan = recording1.get_num_channels()
num_seg = recording1.get_num_segments()
total_recording = recording1.get_total_duration()

#print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
print(f"total_recording: {total_recording} s")

recording_bp = spre.bandpass_filter(recording1, freq_min=300, freq_max=6000)

recodring_cmr = spre.common_reference(recording_bp, reference='global', operator='median')
#recording_chunk = recodring_cmr.frame_slice(start_frame= 1*fs,end_frame=425*fs)
recording_chunk = recodring_cmr.frame_slice(start_frame= 0*fs,end_frame=300*fs)
print(f"chunk duration: {recording_chunk.get_total_duration()} s")


Sampling frequency: 20000.0
Number of channels: 1011
Number of segments: 1
total_recording: 300.09 s
chunk duration: 300.0 s


In [None]:

from flac_numcodecs import Flac
compressor = Flac(level =8)
local_path= '/mnt/disk15tb/mmpatil/Spikesorting/Data/May16_analysison1024Amp/16848/Network/000041/data.raw.h5'
rec_original = se.read_maxwell(local_path)
rec_int32 = spre.scale(rec_original, dtype="int32")
# remove the 2^15 offset
rec_rm_offset = spre.scale(rec_int32, offset=-2 ** 15)
# now we can safely cast to int16
rec_int16 = spre.scale(rec_rm_offset, dtype="int16")
recording_zarr = rec_int16.save(format = "zarr",folder="/mnt/disk15tb/mmpatil/Spikesorting/Data/May16_analysison1024Amp/16848/Network/000041/compressed.zarr",compressor=compressor,
                                channel_chunk_size =2,n_jobs=64,chunk_duration="1s")

In [None]:
# Create a sample dictionary
my_dict = {'key1': 'value1', 'key2': 'value2', 'key3': 'value1', 'key4': 'value3', 'key5':'value2'}

# Create an empty dictionary to store our results
result_dict = {}

# Loop through each key-value pair in my_dict
for key, value in my_dict.items():
    # Check if the value already exists in result_dict
    if value in result_dict:
        # If it does, append the current key to the list of keys that have the same value
        result_dict[value].append(key)
    else:
        # If it doesn't, create a new entry in result_dict with the value as the key and a list containing the current key
        result_dict[value] = [key]

# Create a list called output that contains only the values from result_dict that have more than one key
output = []
for value in result_dict.values():
    if len(value) > 1:
        output.append(value)

# Print the output to the console
print(output)

In [None]:
default_KS2_params = ss.get_default_sorter_params('kilosort2')
print(default_KS2_params)

In [None]:

default_KS2_params['keep_good_only'] = True
default_KS2_params['detect_threshold'] = 12
default_KS2_params['projection_threshold']=[18, 10]
default_KS2_params['preclust_threshold'] = 8
run_sorter = ss.run_kilosort2(recording_chunk, output_folder="/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sorting_7jul/ks2FR/", docker_image= "kilosort2-maxwellcomplib:latest",verbose=True, **default_KS2_params)
#run_sorter = ss.run_sorter('kilosort2',recording= recording_chunk, output_folder="/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/kilosort2",docker_image= True,verbose=True, **default_KS2_params)

In [3]:
sorting_KS3 = ss.Kilosort3Sorter._get_result_from_folder('/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sorting_7jul/ks2FR/sorter_output/')
total_units = sorting_KS3.get_unit_ids()
print(total_units)
print(len(total_units))
channel_ids = recording_chunk.get_channel_ids()
print(channel_ids)
channel_association_dict = {int(y):x for x,y in enumerate(channel_ids) }
print(channel_association_dict)

[  0   1   2   3   4   5   6   8  13  14  15  16  17  18  19  23  24  25
  27  28  29  30  31  32  39  40  42  43  44  45  46  47  48  49  50  51
  52  53  54  55  56  57  58  59  60  61  63  65  66  67  68  70  71  72
  73  74  75  76  77  78  79  81  83  84  85  86  88  89  90  91  93  94
  95  96  97  98  99 100 102 104 105 106 107 108 109 110 111 112 113 114
 116 117 118 120 121 122 123 124 125 126 127 128 129 130 132 133 134 136
 137 138 140 141 143 144 145 146 148 150 151 152 153 154 155 156 157 158
 159 160 161 162 163 164 165 166 167 168 169 170 171 173 174 175 177 178
 179 180 181 182 183 184 186 187 188 189 190 191 192 193 194 195 196]
161
['0' '1' '2' ... '1021' '1022' '1023']
{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, 18: 18, 19: 19, 20: 20, 21: 21, 22: 22, 23: 23, 24: 24, 25: 25, 26: 26, 27: 27, 28: 28, 29: 29, 30: 30, 31: 31, 32: 32, 33: 33, 34: 34, 35: 35, 36: 36, 37: 37, 38: 38, 39: 39, 40

In [4]:
job_kwargs = dict(n_jobs=64, chunk_duration="1s", progress_bar=True)
#waveforms = si.extract_waveforms(recording_chunk,sorting_KS3,folder="./waveformsblock1_7min",overwrite=True, ms_before=1., ms_after=2.,**job_kwargs)
waveforms = si.extract_waveforms(recording_chunk,sorting_KS3,folder='/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sorting_7jul/waveformsFR',overwrite=True)
print(waveforms)

extract waveforms memmap:   0%|          | 0/300 [00:00<?, ?it/s]

WaveformExtractor: 1011 channels - 161 units - 1 segments
  before:60 after:80 n_per_units:500


In [None]:
waveforms = si.load_waveforms('/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sortingMaxwelljun16/waveforms_goodFR')
print(waveforms)

In [5]:
job_kwargs = dict(n_jobs=64, chunk_duration="1s", progress_bar=True)
sp.compute_spike_amplitudes(waveforms,load_if_exists=True,**job_kwargs)

extract amplitudes:   0%|          | 0/300 [00:00<?, ?it/s]

[array([ -69.236755,  -56.648254, -113.29651 , ...,  -50.354004,
         -69.236755, -132.17926 ], dtype=float32)]

In [6]:
import spikeinterface.qualitymetrics as qm
job_kwargs = dict(n_jobs=64, chunk_duration="1s", progress_bar=True)
metrics = qm.compute_quality_metrics(waveforms,load_if_exists=True,**job_kwargs)

  snrs[unit_id] = np.abs(amplitude) / noise


In [7]:
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows',None)
display(metrics)

Unnamed: 0,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,sliding_rp_violation,amplitude_cutoff,amplitude_median
0,2753,9.176667,1.0,inf,0.0,0,0.0,0,0.01,0.000751,163.650513
1,487,1.623333,1.0,inf,0.0,0,0.0,0,0.15,,138.473511
2,1650,5.5,1.0,inf,0.0,0,0.0,0,0.015,0.00397,88.119507
3,2430,8.1,1.0,inf,0.0,0,0.0,0,0.01,0.0046,81.825256
4,703,2.343333,1.0,inf,0.0,0,0.0,0,0.125,0.007227,37.765503
5,375,1.25,1.0,inf,0.0,0,0.0,0,0.25,,188.827515
6,1841,6.136667,1.0,inf,0.0,0,0.0,0,0.015,0.000919,37.765503
8,1873,6.243333,1.0,inf,0.0,0,0.0,0,0.025,0.001655,31.471252
13,1299,4.33,1.0,inf,0.0,0,0.0,0,0.025,0.027464,69.236755
14,1703,5.676667,1.0,inf,0.0,0,0.0,0,0.035,0.000544,56.648254


In [8]:
import importlib
import mea_analysis_pipeline as msp

In [None]:
msp.remove_violated_units(metrics)

In [9]:
import mea_analysis_pipeline as msp
unit_ids = msp.remove_violated_units(metrics)
print(f"{unit_ids} {len(unit_ids)}")

[  0   2   3   4   6   8  13  14  15  16  17  18  19  23  24  25  27  28
  29  30  31  32  39  40  42  43  44  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  63  65  66  67  68  71  72  73  74  75  76
  77  78  79  81  84  85  86  88  89  90  91  93  94  96  98  99 100 102
 104 105 106 107 108 109 110 112 113 114 116 120 121 123 127 129 130 132
 134 136 137 140 141 143 145 146 148 150 152 153 154 155 156 157 159 162
 164 165 166 167 169 171 173 174 175 178 179 180 181 183 184 186 187 188
 189 191 192 193 194 195 196] 133


In [10]:
redundant_units = msp.remove_similar_templates(waveforms)
print(f"redundant-units : {redundant_units}")
non_violated_units = [item for item in unit_ids if item not in redundant_units]

196 97
redundant-units : [196]


In [None]:
sorting_auto_KS3 = sorting_KS3.select_units(unit_ids)

In [11]:
waveform_good = waveforms.select_units(non_violated_units,new_folder='/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sorting_7jul/waveforms_goodFR')


In [13]:
waveform_good = si.load_waveforms('/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/sorting_7jul/waveforms_goodFR')

In [14]:
%matplotlib widget
locations = sp.compute_unit_locations(waveform_good)
print(len(locations))




132


In [15]:
import pickle

file_path = './Firingratelocation.pkl'

with open(file_path,'wb') as file:
    pickle.dump(locations,file)

In [None]:
unit_extremum_channel =spikeinterface.full.get_template_extremum_channel(waveform_good, peak_sign='neg')
#Step 1: keep only units that are in good_units 
unit_extremum_channel = {key:value for key,value in unit_extremum_channel.items() }
print(f"extremum channel : {unit_extremum_channel}")

#unique channels and count:
unique_channel_count = set([value for key,value in unit_extremum_channel.items()])
print(f"{unique_channel_count} and len : {len(unique_channel_count)}")

#Step3: get units that correspond to same electrodes.
output_units = [[key for key, value in unit_extremum_channel.items() if value == v] for v in set(unit_extremum_channel.values()) if list(unit_extremum_channel.values()).count(v) > 1]
print(f"Units that correspond to same electrode: {output_units}")

In [None]:
print(type(metrics['amplitude_median'][int('0')]))

In [None]:
my_list = [[5, 11], [9, 19, 27], [46, 48], [1, 2, 3], [50, 53, 59], [17, 18]]

# Flatten the list using a nested list comprehension
flattened_list = [element for sublist in my_list for element in sublist]

print(flattened_list)

In [None]:
extremum_channels_ids =spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign='neg')
print(extremum_channels_ids)



In [None]:

output = [[key for key, value in extremum_channels_ids.items() if value == v] for v in set(extremum_channels_ids.values()) if list(extremum_channels_ids.values()).count(v) > 1]
print(output)

In [None]:


print(helper.get_key_by_value(extremum_channels_ids,'221'))

In [None]:

isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(waveforms, isi_threshold_ms=1.0)
print(isi_violations_ratio)
print(isi_violations_count)

rp_contamination,rp_violation = qm.compute_refrac_period_violations(waveforms)
print(rp_contamination)
print(rp_violation)
snr_ratio = qm.compute_snrs(waveforms,peak_sign="both", peak_mode='at_index')
print(snr_ratio)

firing_rate = qm.compute_firing_rates(waveforms)
print(firing_rate)

amp_cutoff = qm.compute_amplitude_cutoffs(waveforms)
print(amp_cutoff)

In [None]:
for i in range(29):
    record_num = str(i).zfill(4)
    record_name = 'rec' + record_num
    print(record_name)

In [None]:
import helper_functions as helper


filename = 'Extremechannels_4min.json'
helper.dumpdicttofile(extremum_channels_ids,filename)


In [None]:
violated_units = [unit for unit, ratio in isi_violations_ratio.items() if ratio > 0.0]
print(violated_units)
print(f"isi violated units:{len(violated_units)}")

refrct_violated_units = [unit for unit,ratio in rp_contamination.items() if ratio >0.0]
print(refrct_violated_units)
print(f"refract vio units:{len(refrct_violated_units)}")

In [None]:
frrate_units = [ unit for unit,ratio in firing_rate.items() if ratio < 0.1]
print(frrate_units)

In [None]:
deletion_candidates = list(set(violated_units+frrate_units))
print(deletion_candidates)

In [None]:
print(sorting_KS3)

clean_sorting = sorting_KS3.remove_units(deletion_candidates)
print(clean_sorting)
good_units = [units for units in total_units if units not in deletion_candidates ]
print(good_units)

#now getting the wavefrom extractor

waveform_good = waveforms.select_units(good_units,new_folder='/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/waveforms_good')
print(waveform_good)

In [None]:
unit_extremum_channel =sp.get_template_extremum_channel(waveforms, peak_sign='neg')
#Step 1: keep only units that are in good_units 
print(unit_extremum_channel)
unit_extremum_channel = {key:value for key,value in unit_extremum_channel.items() if key  in good_units}

print(unit_extremum_channel)
#Step3: get units that correspond to same electrodes.
output_units = [[key for key, value in unit_extremum_channel.items() if value == v] for v in set(unit_extremum_channel.values()) if list(unit_extremum_channel.values()).count(v) > 1]

print(output_units)
output=[]
if output_units :
    for sublist in output_units :
        amp_max = 0 
        for unit in sublist:
            if metrics['amplitude_median'][int(unit)] > amp_max :
                amp_max = metrics['amplitude_median'][int(unit)]
                reqd_unit = unit
        output.append(reqd_unit)
#Step 5 --> unit_extremum_channel - output_units + output
output_units = [element for sublist in output_units for element in sublist]
new_list = [ item for item in output_units if item not in output]

print(f"Output : {output}")

print(f"Output units : {output_units}")

print(f"new list {new_list}")
required_templates = {key:value for key,value in unit_extremum_channel.items() if key not in new_list}
print(f"reqd templates {required_templates}")

In [None]:
%matplotlib widget

In [None]:
import spikeinterface.postprocessing as sp

locations = sp.compute_unit_locations(waveforms)
print(type(locations))
import numpy as np
#np.savetxt("unitloc_10mins.txt",locations)
ax = plt.subplot(111)
sw.plot_probe_map(recording1,ax=ax,with_channel_ids=False)
for x,y in locations:
    ax.scatter(x,y)

In [None]:
import spikeinterface_gui
app = spikeinterface_gui.mkQApp() 

# create the mainwindow and show
win = spikeinterface_gui.MainWindow(waveforms)
win.show()
# run the main Qt6 loop
app.exec_()#Need to implement compute noise levels.

In [None]:
channel_locations = recording_chunk.get_channel_locations()
channel_ids = recording_chunk.get_channel_ids()
_ = [print(f"{channel_id}: {location}") for location, channel_id in zip(channel_locations, channel_ids)]



In [None]:
fig, ax1 = plt.subplots(figsize=(15,5))
spike_times = {}
for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    spike_train = clean_sorting.get_unit_spike_train(unit_id,start_frame=1*fs,end_frame=100*fs)
    print(spike_train)
    if len(spike_train) > 0:
        spike_times[idx] = spike_train / float(fs)
        #print(spike_times[unit_id])
       # print(unit_id*np.ones_like(spike_times[unit_id]))
        ax1.plot(spike_times[idx],idx*np.ones_like(spike_times[idx]),
                             marker='|', mew=1, markersize=3,
                             ls='',color='black')
                       

In [None]:
t_start = 0 
t_end = int(600*fs)
dt = 1
#initialising the spike train.
units= clean_sorting.get_num_units()
frame_numbers = t_end
spike_array = np.zeros((units,frame_numbers), dtype= int)
for idx, unit_id in enumerate(clean_sorting.get_unit_ids()):
    spike_train = clean_sorting.get_unit_spike_train(unit_id,start_frame=t_start,end_frame=t_end)
    for spike_time in spike_train:
        spike_array[idx,spike_time] = 1

print(spike_array)

print(spike_array[0,63782])

In [None]:
np.savez_compressed('spike_array_compressed_blockactivity.npz',spike_array)

In [None]:
with np.load('spike_array_compressed_blockactivity.npz') as data:
    decompressed_data = data['arr_0']

print(np.array_equal(spike_array, decompressed_data))


In [None]:
extremum_channels_ids =spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign='neg')
print(extremum_channels_ids)



In [None]:
colors = [ 'Lime','Gold', 'Orange','Orangered']
fig, ax = plt.subplots()
wf=[]
for i, unit_id in enumerate([22,23,27,28]):
    wf = waveforms.get_waveforms(unit_id)
    color = colors[i]
    ax.plot(wf[:, :,channel_association_dict[20]].T, color=color, lw=0.3)
print(wf.shape)

In [None]:
colors = [ 'Fuchsia','Olive', 'Teal']
fig, ax = plt.subplots()
wf=[]
for i, unit_id in enumerate([218]):
    wf = waveform_good.get_waveforms(unit_id)
    ax.plot(wf[:, :,channel_association_dict[902]].T, color=colors[0], lw=0.3)
    ax.plot(wf[:, :,channel_association_dict[613]].T, color=colors[1], lw=0.3)
    ax.plot(wf[:, :,channel_association_dict[663]].T, color=colors[2], lw=0.3)
print(wf.shape)

In [None]:
peak_shift=si.get_template_extremum_channel_peak_shift(waveform_good)

print(peak_shift)

In [None]:
colors = [ 'Fuchsia','Olive', 'Teal']
fig, ax = plt.subplots()
for i, unit_id in enumerate([26, 40 , 46]):
    template = waveforms.get_template(unit_id)
    color = colors[i]
    ax.plot(template[:, channel_association_dict[780]].T, color=color, lw=3)
print(template.shape)

In [None]:
w = sw.plot_unit_templates(waveform_good, unit_ids=[183],plot_channels=False )

In [None]:
w = sw.plot_unit_waveforms(waveform_good, unit_ids=[2,4,7])
w = sw.plot_unit_templates(waveform_good, unit_ids=[2,4,7])
w = sw.plot_unit_probe_map(waveform_good, unit_ids=[2,4,7])

In [None]:
import spikeinterface_gui
# This creates a Qt app
app = spikeinterface_gui.mkQApp() 

# create the mainwindow and show
win = spikeinterface_gui.MainWindow(waveforms)
win.show()
# run the main Qt6 loop
app.exec_()

