In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tsmoothie.smoother import GaussianSmoother
import spikeinterface
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.postprocessing as sp
import spikeinterface.preprocessing as spre
import spikeinterface.qualitymetrics as qm
import helper_functions as helper

In [2]:
# Read the MaxWell recording, band-pass filter it (300-6000 Hz), apply a global
# median common reference and keep a fixed window of the result for sorting.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
# NOTE(review): the old comment said "chip 16848" but the path (and the .mat export
# name used later) say 16657 — TODO confirm which chip this is.
local_path = '/mnt/disk15tb/jonathan/Syngap3/Syngap3/230103/16657/Network/000138/data.raw.h5'

recording1 = se.read_maxwell(local_path)

channel_ids = recording1.get_channel_ids()
fs = recording1.get_sampling_frequency()
num_chan = recording1.get_num_channels()
num_seg = recording1.get_num_segments()
total_recording = recording1.get_total_duration()

print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
print(f"total_recording: {total_recording} s")

recording_bp = spre.bandpass_filter(recording1, freq_min=300, freq_max=6000)
# Fixed typo: was `recodring_cmr`.
recording_cmr = spre.common_reference(recording_bp, reference='global', operator='median')

# Window (in seconds) kept for sorting. fs is a float (e.g. 20000.0), so the frame
# bounds are cast to int explicitly.
CHUNK_START_S = 0
CHUNK_END_S = 300
recording_chunk = recording_cmr.frame_slice(start_frame=int(CHUNK_START_S * fs),
                                            end_frame=int(CHUNK_END_S * fs))
print(f"chunk duration: {recording_chunk.get_total_duration()} s")


Sampling frequency: 20000.0
Number of channels: 1009
Number of segments: 1
total_recording: 300.09 s
chunk duration: 300.0 s


In [3]:
# Fetch spikeinterface's default Kilosort2 parameter dict so it can be inspected
# and then edited before the sorter is run.
default_KS2_params = ss.get_default_sorter_params('kilosort2')
print(default_KS2_params)

{'detect_threshold': 6, 'projection_threshold': [10, 4], 'preclust_threshold': 8, 'car': True, 'minFR': 0.1, 'minfr_goodchannels': 0.1, 'freq_min': 150, 'sigmaMask': 30, 'nPCs': 3, 'ntbuff': 64, 'nfilt_factor': 4, 'NT': None, 'AUCsplit': 0.9, 'wave_length': 61, 'keep_good_only': False, 'skip_kilosort_preprocessing': False, 'scaleproc': None, 'save_rez_to_mat': False, 'delete_tmp_files': True, 'delete_recording_dat': False, 'n_jobs': 48, 'chunk_duration': '1s', 'progress_bar': True}


In [None]:

# Override a few Kilosort2 defaults (stricter detection/projection thresholds,
# keep only 'good' units) and run the sorter inside the given Docker image on the
# preprocessed chunk. Results land in ./sorting/KS_Syngap3_03Jan/.
# NOTE(review): recording_chunk is already band-passed and median-referenced, while
# the KS2 defaults include 'car': True and 'freq_min': 150 — confirm that this
# second round of filtering/CAR inside Kilosort is intended.
default_KS2_params['keep_good_only'] = True
default_KS2_params['detect_threshold'] = 12
default_KS2_params['projection_threshold']=[18, 10]
default_KS2_params['preclust_threshold'] = 8
run_sorter = ss.run_kilosort2(recording_chunk, output_folder="./sorting/KS_Syngap3_03Jan/", docker_image= "kilosort2-maxwellcomplib:latest",verbose=True, **default_KS2_params)
#run_sorter = ss.run_sorter('kilosort2',recording= recording_chunk, output_folder="/mnt/disk15tb/mmpatil/Spikesorting/sorter_output/kilosort2",docker_image= True,verbose=True, **default_KS2_params)

In [4]:
# Load the Kilosort2 output from its sorter_output folder and report the unit count.
# NOTE(review): `_get_result_from_folder` is a private spikeinterface API; also, the
# variable is named KS3 although the sorter is Kilosort2.
# NOTE(review): this absolute path is assumed to be the same directory as the relative
# output_folder "./sorting/KS_Syngap3_03Jan/" used for the sorter run — TODO confirm.
sorting_KS3 = ss.Kilosort2Sorter._get_result_from_folder('/home/mmp/disktb/mmpatil/MEA_Analysis/Python/sorting/KS_Syngap3_03Jan/sorter_output')
total_units = sorting_KS3.get_unit_ids()
print(len(total_units))
#print(len(total_units))
channel_ids = recording_chunk.get_channel_ids()



496


In [5]:
# Drop units that have no spikes.
sorting_good = sorting_KS3.remove_empty_units()

In [6]:
# Remove spikes that fall outside recording_chunk (presumably frames past its end),
# which would otherwise break waveform extraction.
sorting_good = spikeinterface.curation.remove_excess_spikes(sorting_good,recording_chunk)

In [None]:
# Extract (dense) waveforms for the curated sorting into ./sorting/waveforms_Syngap3_03Jan/,
# overwriting any previous extraction. 4 workers, 1 s chunks.
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
#waveforms = si.extract_waveforms(recording_chunk,sorting_KS3,folder="./waveformsblock1_7min",overwrite=True, ms_before=1., ms_after=2.,**job_kwargs)
waveforms = si.extract_waveforms(recording_chunk,sorting_good,folder='./sorting/waveforms_Syngap3_03Jan/',overwrite=True,**job_kwargs)
print(waveforms)

In [7]:
# Reload previously extracted waveforms instead of re-extracting them.
# Fixed: in spikeinterface 0.9x `load_waveforms(folder, with_recording=True, sorting=None)`
# takes the sorting as a keyword. Passed positionally, `sorting_good` was bound to
# `with_recording` and silently ignored as a sorting.
waveforms = si.load_waveforms('/home/mmp/disktb/mmpatil/MEA_Analysis/Python/sorting/waveforms_Syngap3_03Jan',
                              sorting=sorting_good)
print(waveforms)

WaveformExtractor: 1009 channels - 496 units - 1 segments
  before:60 after:80 n_per_units:500


In [8]:
# Compute per-spike amplitudes (64 workers), reusing a saved result if one exists.
# NOTE(review): n_jobs=64 here vs 4 elsewhere — consider one shared job_kwargs in a config cell.
job_kwargs = dict(n_jobs=64, chunk_duration="1s", progress_bar=True)
sp.compute_spike_amplitudes(waveforms,load_if_exists=True,**job_kwargs)

[memmap([-207.71027 , -276.94702 , -528.71704 , ..., -100.70801 ,
          -25.177002, -264.35852 ], dtype=float32)]

In [9]:
# Compute quality metrics for every unit, reusing saved metrics if present.
# NOTE(review): `qm` is already imported in the first cell; this re-import is redundant.
import spikeinterface.qualitymetrics as qm
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
metrics = qm.compute_quality_metrics(waveforms,load_if_exists=True,**job_kwargs)

In [10]:
# Show the full quality-metrics table (all rows and columns).
# NOTE(review): 'snr' is inf for every unit in the saved output — possibly noise
# levels were never computed; TODO investigate before using SNR for curation.
# NOTE(review): max_rows=None prints all ~500 rows; prefer .head()/.describe().
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows',None)
display(metrics)

Unnamed: 0,num_spikes,firing_rate,presence_ratio,snr,isi_violations_ratio,isi_violations_count,rp_contamination,rp_violations,sliding_rp_violation,amplitude_cutoff,amplitude_median
0,222.0,0.74,1.0,inf,0.0,0.0,0.0,0.0,,,201.41602
1,501.0,1.67,1.0,inf,0.0,0.0,0.0,0.0,0.14,0.308915,44.059753
2,1596.0,5.32,1.0,inf,0.392585,10.0,0.759594,8.0,,0.000442,50.354004
3,122.0,0.406667,1.0,inf,0.0,0.0,0.0,0.0,,,258.06427
4,296.0,0.986667,1.0,inf,0.0,0.0,0.0,0.0,,,37.765503
6,1591.0,5.303333,1.0,inf,0.0,0.0,0.0,0.0,0.07,0.001068,157.35626
7,1376.0,4.586667,1.0,inf,0.0,0.0,0.0,0.0,0.07,0.00685,25.177002
8,25.0,0.083333,0.8,inf,0.0,0.0,0.0,0.0,,,811.9583
9,491.0,1.636667,1.0,inf,0.829597,2.0,1.0,1.0,,,157.35626
10,307.0,1.023333,1.0,inf,0.0,0.0,0.0,0.0,,,107.00226


In [11]:
import mea_analysis_pipeline as msp

# Step 1: units surviving the project's violation filter on `metrics`
# (presumably ISI/refractory criteria — see mea_analysis_pipeline).
update_qual_metrics = msp.remove_violated_units(metrics)
non_violated_units  = update_qual_metrics.index.values

# Step 2: units flagged as redundant by template similarity (threshold 0.5).
redundant_units = msp.remove_similar_templates(waveforms,sim_score=0.5)
print(redundant_units)

# TODO: need to extract metrics here.
# Step 3: drop redundant units, keeping the original order.
non_violated_units = list(filter(lambda unit: unit not in redundant_units, non_violated_units))
print(non_violated_units)
print(len(non_violated_units))

247 146
392 260
415 412
511 508
660 642
704 545
705 594
710 605
712 10
[146, 392, 412, 511, 642, 545, 705, 605, 10]
[2, 6, 7, 12, 14, 15, 21, 22, 23, 34, 40, 42, 49, 50, 58, 60, 62, 64, 70, 77, 78, 84, 88, 90, 91, 92, 93, 95, 97, 98, 106, 108, 109, 111, 115, 116, 118, 124, 128, 129, 137, 138, 139, 141, 147, 149, 154, 158, 159, 160, 167, 168, 171, 175, 176, 177, 180, 185, 190, 192, 197, 198, 202, 204, 208, 210, 211, 215, 219, 226, 230, 234, 236, 238, 239, 240, 243, 245, 248, 250, 253, 254, 255, 256, 257, 261, 262, 263, 273, 274, 280, 288, 291, 297, 301, 303, 307, 310, 312, 315, 321, 323, 325, 327, 329, 331, 332, 334, 335, 339, 340, 342, 345, 348, 350, 351, 352, 355, 356, 357, 358, 365, 367, 369, 374, 375, 377, 381, 385, 387, 388, 393, 395, 407, 408, 410, 415, 419, 423, 424, 433, 438, 441, 442, 446, 447, 448, 450, 451, 452, 461, 463, 468, 477, 481, 488, 493, 494, 495, 496, 498, 502, 503, 505, 508, 509, 510, 512, 513, 518, 519, 521, 522, 528, 532, 535, 538, 541, 544, 547, 557, 564, 569, 5

In [12]:
# Map each unit to the channel where its template is most negative
# ({unit_id: channel_id}; channel ids are strings, per the output).
extremum_channels_ids =spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign='neg')
print(extremum_channels_ids)



{0: '3', 1: '5', 2: '6', 3: '7', 4: '8', 6: '12', 7: '197', 8: '16', 9: '18', 10: '21', 11: '23', 12: '26', 13: '28', 14: '33', 15: '38', 16: '38', 19: '41', 21: '47', 22: '48', 23: '50', 24: '54', 27: '57', 28: '59', 29: '59', 30: '61', 31: '67', 32: '67', 34: '70', 36: '72', 37: '72', 40: '76', 42: '78', 45: '81', 48: '83', 49: '84', 50: '85', 52: '87', 55: '97', 56: '100', 58: '106', 59: '108', 60: '111', 62: '117', 64: '118', 65: '120', 66: '121', 68: '123', 69: '126', 70: '133', 71: '191', 72: '136', 73: '137', 74: '138', 76: '264', 77: '145', 78: '147', 79: '148', 82: '156', 84: '157', 87: '163', 88: '165', 89: '167', 90: '168', 91: '169', 92: '172', 93: '173', 95: '175', 96: '176', 97: '177', 98: '178', 99: '180', 101: '180', 103: '183', 104: '188', 105: '189', 106: '192', 108: '197', 109: '198', 110: '201', 111: '202', 112: '202', 113: '203', 115: '206', 116: '207', 117: '209', 118: '210', 119: '211', 121: '213', 122: '214', 124: '216', 125: '217', 127: '224', 128: '225', 129: 

In [13]:
from spikeinterface.postprocessing import compute_spike_amplitudes, compute_principal_components
from spikeinterface.exporters import export_to_phy

# Export a sparse copy of the waveforms to phy for manual curation
# (sparse waveforms make the export faster).
folder = '/home/mmp/disktb/mmpatil/codbase/MEA_Analysis/SpikeSortingPipeline/sorting_intermediate/waveforms/'

# Fixed: extract_waveforms raises if `folder` already exists, and export_to_phy does
# the same for its output folder. overwrite/remove_if_exists make this cell re-runnable.
we = si.extract_waveforms(recording_chunk, sorting_good, folder=folder, sparse=True, overwrite=True)
# Pre-compute amplitudes and PCs so that the export itself is fast.
compute_spike_amplitudes(we)
compute_principal_components(we, n_components=3, mode='by_channel_global')

# NOTE(review): 'path/to/phy_folder' is a placeholder relative path; phy is launched
# from it in a later cell, so rename both together.
export_to_phy(we, output_folder='path/to/phy_folder', remove_if_exists=True)

extract waveforms shared_memory:   0%|          | 0/300 [00:00<?, ?it/s]

extract waveforms shared_memory:   0%|          | 0/300 [00:00<?, ?it/s]

extract waveforms shared_memory:   0%|          | 0/300 [00:00<?, ?it/s]

extract waveforms memmap:   0%|          | 0/300 [00:00<?, ?it/s]

extract amplitudes:   0%|          | 0/300 [00:00<?, ?it/s]

Fitting PCA:   0%|          | 0/496 [00:00<?, ?it/s]

Projecting waveforms:   0%|          | 0/496 [00:00<?, ?it/s]



write_binary_recording:   0%|          | 0/300 [00:00<?, ?it/s]

extract PCs:   0%|          | 0/300 [00:00<?, ?it/s]

Run:
phy template-gui  /mnt/disk15tb/mmpatil/codbase/MEA_Analysis/SpikeSortingPipeline/path/to/phy_folder/params.py


In [14]:
# Launch the phy template GUI on the exported folder (shell command).
# NOTE(review): the saved run failed with "attempt to get argmax of an empty sequence";
# the cause is not visible here — TODO investigate (e.g. units/channels with no data).
!phy template-gui  /mnt/disk15tb/mmpatil/codbase/MEA_Analysis/SpikeSortingPipeline/path/to/phy_folder/params.py

[33m16:18:13.711 [W] model:667            Skipping spike waveforms that do not exist, they will be extracted on the fly from the raw data as needed.[0m
[31m16:18:14.509 [E] __init__:62          An error has occurred (ValueError): attempt to get argmax of an empty sequence
Traceback (most recent call last):
  File "/home/mmp/.local/bin/phy", line 8, in <module>
    sys.exit(phycli())
  File "/usr/lib/python3/dist-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/usr/lib/python3/dist-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/usr/lib/python3/dist-packages/click/core.py", line 1659, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/lib/python3/dist-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/lib/python3/dist-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/usr/lib/p

In [15]:
# Print each unit's template shape (samples x channels) for the sparse extractor `we`;
# the output shows the channel count varies per unit because of sparsity.
# NOTE(review): this loop leaves `template`/`unit_id` as notebook globals (the last
# sparse unit), which later cells that read `template` can pick up by accident.
for unit_id in we.unit_ids:
    template = we.get_template(unit_id)
    print(f"Unit {unit_id} - template shape: {template.shape}")

Unit 0 - template shape: (140, 4)
Unit 1 - template shape: (140, 4)
Unit 2 - template shape: (140, 5)
Unit 3 - template shape: (140, 5)
Unit 4 - template shape: (140, 5)
Unit 6 - template shape: (140, 4)
Unit 7 - template shape: (140, 2)
Unit 8 - template shape: (140, 2)
Unit 9 - template shape: (140, 5)
Unit 10 - template shape: (140, 4)
Unit 11 - template shape: (140, 3)
Unit 12 - template shape: (140, 3)
Unit 13 - template shape: (140, 4)
Unit 14 - template shape: (140, 1)
Unit 15 - template shape: (140, 5)
Unit 16 - template shape: (140, 5)
Unit 19 - template shape: (140, 5)
Unit 21 - template shape: (140, 5)
Unit 22 - template shape: (140, 4)
Unit 23 - template shape: (140, 4)
Unit 24 - template shape: (140, 6)
Unit 27 - template shape: (140, 5)
Unit 28 - template shape: (140, 5)
Unit 29 - template shape: (140, 5)
Unit 30 - template shape: (140, 3)
Unit 31 - template shape: (140, 5)
Unit 32 - template shape: (140, 5)
Unit 34 - template shape: (140, 3)
Unit 36 - template shape: (14

In [None]:


# Plot the electrode (channel) positions of the recording chunk.
locations = recording_chunk.get_channel_locations()
print(type(locations))

fig, ax = plt.subplots(figsize=(10.5, 6.5))
# One vectorised scatter instead of one call (and one artist) per channel (~1000).
channel_xy = np.asarray(locations)
ax.scatter(channel_xy[:, 0], channel_xy[:, 1], s=3)
ax.invert_yaxis()
ax.set_title('Channel locations')
ax.set_xlabel('x')
ax.set_ylabel('y')

In [None]:
# Plot the estimated location of every sorted unit.
locations = sp.compute_unit_locations(waveforms)
print(len(locations))

fig, ax = plt.subplots(figsize=(10.5, 6.5))
# Vectorised scatter. Column indexing also works if the locations carry a third (z)
# column, which the previous `for x, y in locations` unpacking could not handle.
unit_xy = np.asarray(locations)
ax.scatter(unit_xy[:, 0], unit_xy[:, 1], s=3)
ax.invert_yaxis()
ax.set_title('Unit locations')
ax.set_xlabel('x')
ax.set_ylabel('y')

In [None]:
# Open the spikeinterface-gui Qt viewer on the waveform extractor.
import spikeinterface_gui
app = spikeinterface_gui.mkQApp() 

# create the mainwindow and show
win = spikeinterface_gui.MainWindow(waveforms)
win.show()
# run the main Qt6 loop
app.exec_()#Need to implement compute noise levels.

In [None]:
# Number of channels in the sorted chunk.
len(recording_chunk.channel_ids)

In [None]:
# Raster plot (1 s to 120 s) of every unit in the waveform extractor.
fig, ax1 = plt.subplots(figsize=(8, 4))
spike_times = {}
for idx, unit_id in enumerate(waveforms.unit_ids):
    spike_train = sorting_good.get_unit_spike_train(unit_id, start_frame=int(1 * fs), end_frame=int(120 * fs))
    if len(spike_train) > 0:
        spike_times[idx] = spike_train / float(fs)
        # One row per unit: the y value is the unit id.
        ax1.plot(spike_times[idx], unit_id * np.ones_like(spike_times[idx]),
                 marker='|', mew=1, markersize=3, ls='', color='royalblue')

# Fixed: labels are set once, outside the loop (previously they were never set if no
# unit had spikes), and the y label now says what is plotted (unit id, not channel).
ax1.set_title('Raster plot')
ax1.set_xlabel('time s')
ax1.set_ylabel('unit id')
                       

In [None]:
# Raster plot (1 s to 120 s) restricted to the curated (non-violated) units.
fig, ax1 = plt.subplots(figsize=(8, 4))
spike_times = {}
for idx, unit_id in enumerate(non_violated_units):
    spike_train = sorting_good.get_unit_spike_train(unit_id, start_frame=int(1 * fs), end_frame=int(120 * fs))
    if len(spike_train) > 0:
        spike_times[idx] = spike_train / float(fs)
        # One row per unit: the y value is the unit id.
        ax1.plot(spike_times[idx], unit_id * np.ones_like(spike_times[idx]),
                 marker='|', mew=1, markersize=3, ls='', color='royalblue')

# Fixed: labels set once outside the loop; the y axis shows unit ids, not channels.
ax1.set_title('Raster plot')
ax1.set_xlabel('time s')
ax1.set_ylabel('unit id')

In [None]:
# Estimate one location per unit, then keep only the curated (non-violated) units.
unit_locations = sp.compute_unit_locations(waveforms)
print(len(unit_locations))

# Fixed: key locations by waveforms.unit_ids rather than by position in the metrics
# table, so a different row order or length in `metrics` cannot mislabel units.
# Assumes compute_unit_locations returns rows in waveforms.unit_ids order (the
# spikeinterface convention) — TODO confirm for the installed version.
unit_locations_dict = dict(zip(waveforms.unit_ids, unit_locations))

# Set lookup instead of a linear scan of the list for every unit.
non_violated_set = set(non_violated_units)
filtered_locations = {key: value for key, value in unit_locations_dict.items() if key in non_violated_set}

print(filtered_locations)

In [None]:
# Flatten the curated units' spikes (from 1 s to 300 s) into three parallel per-spike
# lists: frame, unit id, and index of the unit's extremum channel.
# NOTE: despite the name, `spike_times` holds sample frames, not seconds.
corr_channel_ids, corr_unit_ids, spike_times = [], [], []
for idx, unit_id in enumerate(non_violated_units):
    spike_train = sorting_good.get_unit_spike_train(unit_id, start_frame=1*fs, end_frame=300*fs)
    n_spikes = len(spike_train)
    spike_times.extend(spike_train)
    corr_unit_ids.extend([unit_id] * n_spikes)
    channel = waveforms.channel_ids_to_indices([str(int(extremum_channels_ids[unit_id]))])
    corr_channel_ids.extend([channel[0]] * n_spikes)



In [None]:
# Collect the per-spike lists into one table and order it by spike frame.
data = dict(
    spike_times=spike_times,
    spike_channel_ids=corr_channel_ids,
    spike_unit_ids=corr_unit_ids,
)
df = pd.DataFrame(data)
df_sorted = df.sort_values(by='spike_times')



In [None]:
# Split the curated unit locations into parallel x and y coordinate lists.
putative_unit_locX, putative_unit_locY = [], []
for loc_x, loc_y in list(filtered_locations.values()):
    putative_unit_locX.append(loc_x)
    putative_unit_locY.append(loc_y)

In [None]:
import numpy as np
from scipy.io import savemat

# Export the curated, frame-sorted spikes and the unit locations to a MATLAB file.
# The 'spike_times' column holds sample frames, hence the 'spike_frames' key.
spiking_data = {
    'spike_frames': list(df_sorted['spike_times']),
    # Fixed: these two keys were swapped (units were written under 'spike_channels'
    # and channel indices under 'spike_units').
    'spike_units': list(df_sorted['spike_unit_ids']),
    'spike_channels': list(df_sorted['spike_channel_ids']),
    'putative_unit_ids': list(filtered_locations.keys()),
    'putative_unit_locX': putative_unit_locX,
    'putative_unit_locY': putative_unit_locY,
}

# Save the data as a .mat file
savemat('Syngap3_3Jan_16657.mat', {"spiking_data": spiking_data})

In [None]:
# Display all sorted spike frames.
# NOTE(review): this renders every spike as output; prefer df_sorted['spike_times'].head().
list(df_sorted['spike_times'])

In [None]:
# Build a binary (units x frames) spike matrix over frames [t_start, t_end).
t_start = 0
t_end = int(300*fs)
dt = 1  # one-sample resolution (kept for reference; not used below)
units = len(waveforms.unit_ids)
frame_numbers = t_end - t_start
# uint8 instead of the default int64: ~500 units x 6e6 frames would need ~24 GB as
# int64 and ~3 GB as uint8. The 0/1 values are unchanged.
# NOTE(review): a scipy.sparse matrix would be far smaller still.
spike_array = np.zeros((units, frame_numbers), dtype=np.uint8)
for idx, unit_id in enumerate(waveforms.unit_ids):
    spike_train = sorting_good.get_unit_spike_train(unit_id, start_frame=t_start, end_frame=t_end)
    # Vectorised fill. Frames are offset by t_start so that a non-zero start no longer
    # indexes past the end of the row.
    spike_array[idx, np.asarray(spike_train, dtype=np.int64) - t_start] = 1

print(spike_array)

spike_array.shape

In [None]:
# Save the spike matrix compressed; passed positionally, it is stored under key 'arr_0'.
# NOTE(review): the file is named '22Aug' but the data is from the 03 Jan recording — TODO rename.
np.savez_compressed('spike_array_22Aug.npz',spike_array)

In [None]:
# Save the spike matrix for MATLAB under the variable name 'spikearray'.
# NOTE(review): '22aug' in the file name does not match the 03 Jan recording — TODO rename.
from scipy.io import savemat
savemat('spikearray_matfile_22aug.mat',{'spikearray':spike_array})

In [None]:
# Round-trip check: reload the compressed spike matrix saved as 'spike_array_22Aug.npz'
# and confirm it equals the in-memory array.
# Fixed: this previously loaded 'spike_array_compressed_blockactivity.npz', a different
# file from the one this notebook saves.
with np.load('spike_array_22Aug.npz') as npz_file:
    decompressed_data = npz_file['arr_0']

print(np.array_equal(spike_array, decompressed_data))


In [None]:

# Recompute each unit's extremum channel using both polarities, taking the value at
# the template's peak index.
# Uses spikeinterface.full like the other extremum-channel calls in this notebook.
# NOTE(review): this overwrites the 'neg' extremum_channels_ids computed earlier, so
# cells that read it afterwards depend on execution order.
extremum_channels_ids = spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign='both', mode='at_index')
print(extremum_channels_ids)



In [None]:

from pathlib import Path

# Overlay every extracted waveform of one unit on its extremum channel.
unit_id = 15
fig, ax = plt.subplots()
wf = waveforms.get_waveforms(unit_id)
# Index (within the extractor's channels) of the unit's extremum channel.
number = waveforms.channel_ids_to_indices([str(int(extremum_channels_ids[unit_id]))])
print(number)
ax.plot(wf[:, :, number[0]].T, lw=3, color='cornflowerblue')
ax.set_title(f"waveforms unit {unit_id}")
ax.set_ylabel("Amplitude (µV)")
ax.set_xlabel("Sampled timepoints (5e-2 ms)")

# Fixed: the file name now follows the plotted unit (it said unit12 while unit 15 was
# plotted). The output directory is created if missing.
Path('./plots').mkdir(exist_ok=True)
plt.savefig(f'./plots/waveforms_unit{unit_id}.pdf', format='pdf')

In [None]:
# Shape of the waveforms array from the waveform-plot cell.
wf.shape

In [None]:




from pathlib import Path

# Plot one unit's median template on its extremum channel.
unit_id = 15
fig, ax = plt.subplots()
template = waveforms.get_template(unit_id, mode='median')
print(template.shape)
# Index (within the extractor's channels) of the unit's extremum channel.
number = waveforms.channel_ids_to_indices([str(int(extremum_channels_ids[unit_id]))])
print(number)
ax.plot(template[:, number[0]], lw=3, color='royalblue')
ax.set_title(f"template unit {unit_id}")
ax.set_ylabel("Amplitude (µV)")
ax.set_xlabel("Sampled timepoints (5e-2 ms)")

# Fixed: the file name now follows the plotted unit (it said unit12 while unit 15 was
# plotted). The output directory is created if missing.
Path('./plots').mkdir(exist_ok=True)
plt.savefig(f'./plots/template_unit{unit_id}.pdf', format='pdf')

In [None]:
# First 0.2 s of the preprocessed chunk (expected shape: frames x channels).
# Fixed: get_traces' first positional parameter is segment_index, so
# get_traces(0, int(0.2*fs)) read from frame 4000 to the END of the 300 s chunk.
trace = recording_chunk.get_traces(segment_index=0, start_frame=0, end_frame=int(0.2 * fs))
trace.shape
# plt.figure(figsize=(10,3))
# plt.plot(trace)

In [None]:
# Interactive (ipympl) summary figure for unit 27.
%matplotlib widget
sw.plot_unit_summary(waveforms,unit_id=27)

In [None]:
%matplotlib widget
sw.plot_unit_templates(waveform_good,unit_ids=[73])

In [None]:
# Template features for a single unit, using `template` and `number` from the
# template-plot cell: trough/peak values, trough-to-peak time and peak-to-half-peak
# (repolarisation) time, all in samples.
# Fixed: `channel_association_dict` is never defined, and `number` is an index array
# (unhashable as a dict key). `number[0]` is already the template column to use.
trace = template[:, number[0]]
trough = trace.min()
peak = trace.max()
trough_index = int(np.argmin(trace))
peak_index = int(np.argmax(trace))
trough_to_peak = peak_index - trough_index
# First sample after the peak at or below half the peak value (NaN if none, instead
# of the IndexError the old code raised).
below_half = np.flatnonzero(trace[peak_index + 1:] <= 0.5 * peak)
half_peak_index = peak_index + 1 + int(below_half[0]) if below_half.size else None
repol_time = (half_peak_index - peak_index) if half_peak_index is not None else np.nan


In [None]:
# Unit ids held by the waveform extractor.
# Fixed: `waveform_good` is never defined in this notebook; use `waveforms`.
waveforms.unit_ids

In [None]:
def get_channel_association_dict(recording):
    """
    Build a lookup from channel id to channel position for a recording.

    The keys are ``int(channel_id)`` for each id from ``recording.get_channel_ids()``.
    The values are the 0-based position of that id in the list. If two ids convert
    to the same int, the later position wins.
    """
    association = {}
    for position, channel_id in enumerate(recording.get_channel_ids()):
        association[int(channel_id)] = position
    return association

def get_template_characteristics(waveforms, extremum_channels_ids=None, peak_sign='both'):
    """
    Return a DataFrame of per-unit template characteristics.

    For every unit in ``waveforms.unit_ids`` the template is read on the unit's
    extremum channel and summarised as:

    - ``Amp``: the template minimum (trough value),
    - ``TTP``: trough-to-peak time, ``argmax - argmin`` of the trace, in samples,
    - ``REP``: repolarisation time, in samples from the peak to the first later sample
      at or below half the peak value (NaN if the trace never gets there).

    Parameters
    ----------
    waveforms : WaveformExtractor-like
        Must provide ``unit_ids``, ``get_template(unit_id)`` and
        ``channel_ids_to_indices(list_of_channel_ids)``.
    extremum_channels_ids : dict, optional
        ``{unit_id: channel_id}``. When omitted it is computed with
        ``spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign=...)``.
    peak_sign : str, default 'both'
        Forwarded to ``get_template_extremum_channel`` when channels are computed here.

    Returns
    -------
    pandas.DataFrame
        One row per unit with columns ``Unit``, ``Amp``, ``TTP``, ``REP``.
    """
    if extremum_channels_ids is None:
        # Fixed: previously used the global `waveform_good` instead of the argument.
        extremum_channels_ids = spikeinterface.full.get_template_extremum_channel(waveforms, peak_sign=peak_sign)

    rows = []
    for unit_id in waveforms.unit_ids:
        template = np.asarray(waveforms.get_template(unit_id))
        # Fixed: previously looked up the undefined name `extremum_channels`.
        # NOTE(review): assumes a dense template (one column per channel). With sparse
        # waveforms this global channel index does not address the template columns.
        channel_index = waveforms.channel_ids_to_indices([str(int(extremum_channels_ids[unit_id]))])[0]
        trace = template[:, channel_index]  # 1-D trace instead of an (n, 1) slice

        trough_index = int(np.argmin(trace))
        peak_index = int(np.argmax(trace))
        peak = trace[peak_index]

        # Samples strictly after the peak that are at or below half the peak value.
        after_peak = np.flatnonzero(trace[peak_index + 1:] <= 0.5 * peak)
        repol_time = int(after_peak[0]) + 1 if after_peak.size else np.nan

        rows.append({"Unit": unit_id,
                     "Amp": trace[trough_index],
                     "TTP": peak_index - trough_index,
                     "REP": repol_time})
    return pd.DataFrame(rows)
        
    

In [None]:

# Per-unit template characteristics (Amp, TTP, REP) for the waveform extractor.
# Fixed: `waveform_good` is never defined in this notebook; use `waveforms`.
data = get_template_characteristics(waveforms)
print(data)

In [None]:
# Write the template-characteristics table to Excel.
# NOTE(review): DataFrame.to_excel needs an Excel writer engine (e.g. openpyxl) installed.
data.to_excel('new_temp_metrics.xlsx')

In [None]:
# NOTE(review): leftover post-mortem debugger magic; remove before sharing — it
# blocks "Run All" waiting for debugger input after an exception.
%debug

In [None]:
# Template plot for unit 183 (without the per-channel panels).
# Fixed: `waveform_good` is never defined in this notebook; use `waveforms`.
w = sw.plot_unit_templates(waveforms, unit_ids=[183], plot_channels=False)

In [None]:
# Waveforms, templates and probe-map view for a few example units.
# Fixed: `waveform_good` is never defined in this notebook; use `waveforms`.
EXAMPLE_UNITS = [2, 4, 7]
w = sw.plot_unit_waveforms(waveforms, unit_ids=EXAMPLE_UNITS)
w = sw.plot_unit_templates(waveforms, unit_ids=EXAMPLE_UNITS)
w = sw.plot_unit_probe_map(waveforms, unit_ids=EXAMPLE_UNITS)

In [None]:
import spikeinterface_gui

In [None]:
# Re-import spikeinterface_gui after editing its source, without restarting the kernel.
import importlib
importlib.reload(spikeinterface_gui)

In [None]:




# NOTE(review): duplicate of the earlier GUI-launch cell; keep only one.
# This creates a Qt app
app = spikeinterface_gui.mkQApp() 

# create the mainwindow and show
win = spikeinterface_gui.MainWindow(waveforms)
win.show()
# run the main Qt6 loop
app.exec_()



In [None]:
# Create a sample dictionary
my_dict = {'key1': 'value1', 'key2': 'value2', 'key3': 'value1', 'key4': 'value3', 'key5':'value2'}

# Create an empty dictionary to store our results
result_dict = {}

# Loop through each key-value pair in my_dict
for key, value in my_dict.items():
    # Check if the value already exists in result_dict
    if value in result_dict:
        # If it does, append the current key to the list of keys that have the same value
        result_dict[value].append(key)
    else:
        # If it doesn't, create a new entry in result_dict with the value as the key and a list containing the current key
        result_dict[value] = [key]

# Create a list called output that contains only the values from result_dict that have more than one key
output = []
for value in result_dict.values():
    if len(value) > 1:
        output.append(value)

# Print the output to the console
print(output)