In [23]:
from collections import namedtuple
import file_utility
import os
import pandas as pd

import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as sorters
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import json
import pickle
import spikeinterfaceHelper
from tqdm import tqdm
import numpy as np
import settings
import logging
from types import SimpleNamespace
import matplotlib.pylab as plt

In [21]:
from scipy.signal import butter,filtfilt

def filterRecording(recording, sampling_freq, lp_freq=300,hp_freq=6000,order=3):
    fn = sampling_freq / 2.
    band = np.array([lp_freq, hp_freq]) / fn

    b, a = butter(order, band, btype='bandpass')

    if not (np.all(np.abs(np.roots(a)) < 1) and np.all(np.abs(np.roots(a)) < 1)):
        raise ValueError('Filter is not stable')
    
    for i in tqdm(range(recording._timeseries.shape[0])):
        recording._timeseries[i,:] = filtfilt(b,a,recording._timeseries[i,:])

    return recording

def plot_waveforms(sorted_df, figure_path):
    print('I will plot the waveform shapes for each cluster.')
    for cluster in tqdm(range(len(sorted_df))):
        #extract waveforms from dataframe
        waveforms = sorted_df.waveforms[cluster]
        waveforms = np.stack([w for w in waveforms if w is not None])
        max_channel = sorted_df.max_channel.values[cluster]
        cluster_id = sorted_df.unit_id[cluster]
        tetrode = max_channel//settings.num_tetrodes #get the tetrode number
        
        #plot spike waveform from the same tetrode
        fig = plt.figure()
        for i in range(4):
            ax = fig.add_subplot(2,2,i+1)
            ax.plot(waveforms[:,tetrode+i,:].T,color='lightslategray')
            template = waveforms[:,tetrode+i,:].mean(0)
            ax.plot(template, color='red')
            
        plt.savefig(figure_path + '/' + sorted_df.session_id[cluster] + '_' + str(cluster_id) + '_waveforms.png', dpi=300, bbox_inches='tight', pad_inches=0)
        plt.close()


### Define file paths

In [26]:
sinput = SimpleNamespace()
soutput = SimpleNamespace()

sinput.recording_to_sort = '/media/data2/pipeline_testing_data/M5_2018-03-06_15-34-44_of'

#make output folder
try:
    os.mkdir(sinput.recording_to_sort+'/processed/')
except FileExistsError:
    print('Folder already there')
    
sinput.probe_file =   'sorting_files/tetrode_16.prb'
sinput.sort_param = 'sorting_files/params.json'
sinput.tetrode_geom = 'sorting_files/geom_all_tetrodes_original.csv'
sinput.dead_channel = sinput.recording_to_sort +'/dead_channels.txt'

soutput.sorter_df = sinput.recording_to_sort +'/processed/sorted_df.pkl'
soutput.sorter_curated_df = sinput.recording_to_sort +'/processed/sorted_curated_df.pkl'

soutput.waveform_figure = sinput.recording_to_sort+'/processed'

Folder already there


### Loading files

In [4]:
signal = file_utility.load_OpenEphysRecording(sinput.recording_to_sort)
geom = pd.read_csv(sinput.tetrode_geom, header=None).values
bad_channel = file_utility.getDeadChannel(sinput.dead_channel)

Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...
Loading continuous data...


### Creating extractor and filtering

In [5]:
recording = se.NumpyRecordingExtractor(signal,settings.sampling_rate,geom)
recording = recording.load_probe_file(sinput.probe_file) #load probe definition
filterRecording(recording,settings.sampling_rate) #filer recording


100%|██████████| 16/16 [00:18<00:00,  1.15s/it]


<spikeextractors.extractors.numpyextractors.numpyextractors.NumpyRecordingExtractor at 0x7f70d1413c90>

In [6]:
recording = st.preprocessing.remove_bad_channels(recording, bad_channel_ids=bad_channel) #remove bad channel

### Sort recordings

In [7]:
with open(sinput.sort_param) as f:
    param = json.load(f)
sorting_ms4 = sorters.run_sorter(settings.sorterName,recording, output_folder=settings.sorterName,
    adjacency_radius=param['adjacency_radius'], detect_sign=param['detect_sign'],verbose=True)

    

Using 2 workers.
Using tmpdir: /tmp/tmp6tovl658
Num. workers = 2
Preparing /tmp/tmp6tovl658/timeseries.hdf5...
Preparing neighborhood sorters (M=15, N=57851904)...
Neighboorhood of channel 0 has 15 channels.
Neighboorhood of channel 2 has 15 channels.
Detecting events on channel 1 (phase1)...
Detecting events on channel 3 (phase1)...
Elapsed time for detect on neighborhood: 0:00:20.761078
Num events detected on channel 1 (phase1): 89243
Computing PCA features for channel 1 (phase1)...
Elapsed time for detect on neighborhood: 0:00:20.902034
Num events detected on channel 3 (phase1): 101180
Computing PCA features for channel 3 (phase1)...
Clustering for channel 3 (phase1)...
Clustering for channel 1 (phase1)...
Found 1 clusters for channel 1 (phase1)...
Computing templates for channel 1 (phase1)...
Found 1 clusters for channel 3 (phase1)...
Computing templates for channel 3 (phase1)...
Re-assigning events for channel 1 (phase1)...
Re-assigning events for channel 3 (phase1)...
Neighboorho

Computing PCA features for channel 9 (phase2)...
No duplicate events found for channel 8 in phase2
Found 2 clusters for channel 6 (phase2)...
Neighboorhood of channel 10 has 15 channels.
Computing PCA features for channel 11 (phase2)...
No duplicate events found for channel 10 in phase2
Clustering for channel 9 (phase2)...
Clustering for channel 11 (phase2)...
Found 1 clusters for channel 9 (phase2)...
Neighboorhood of channel 9 has 15 channels.
Computing PCA features for channel 10 (phase2)...
No duplicate events found for channel 9 in phase2
Found 2 clusters for channel 11 (phase2)...
Neighboorhood of channel 11 has 15 channels.
Computing PCA features for channel 12 (phase2)...
No duplicate events found for channel 11 in phase2
Clustering for channel 10 (phase2)...
Clustering for channel 12 (phase2)...
Found 1 clusters for channel 12 (phase2)...
Neighboorhood of channel 12 has 15 channels.
Computing PCA features for channel 13 (phase2)...
No duplicate events found for channel 12 in p

AttributeError: 'types.SimpleNamespace' object has no attribute 'sorter'

### Calculate some sorting metrics

In [8]:
st.postprocessing.get_unit_max_channels(recording, sorting_ms4, max_spikes_per_unit=100)
st.postprocessing.get_unit_waveforms(recording, sorting_ms4, max_spikes_per_unit=100)

for id in sorting_ms4.get_unit_ids():
    number_of_spikes = len(sorting_ms4.get_unit_spike_train(id))
    mean_firing_rate = number_of_spikes/(recording._recording._timeseries.shape[1]/settings.sampling_rate)
    sorting_ms4.set_unit_property(id,'number_of_spikes',number_of_spikes)
    sorting_ms4.set_unit_property(id, 'mean_firing_rate', mean_firing_rate)


### Save sorted result

In [27]:
session_id = sinput.recording_to_sort.split('/')[-1]
sorter_df=spikeinterfaceHelper.sorter2dataframe(sorting_ms4,session_id)
sorter_df.to_pickle(soutput.sorter_df)

### Curate sortings

In [13]:
sorting_ms4_curated = st.curation.threshold_snr(sorting=sorting_ms4, recording = recording,
  threshold =2, threshold_sign='less', max_snr_spikes_per_unit=100, apply_filter=False) #remove when less than threshold
print(sorting_ms4_curated.get_unit_ids())

sorting_ms4_curated=st.curation.threshold_firing_rate(sorting_ms4_curated,
    threshold=0.5, threshold_sign='less')
print(sorting_ms4_curated.get_unit_ids())

sorting_ms4_curated=st.curation.threshold_isi_violations(sorting_ms4_curated, threshold = 0.9)
print(sorting_ms4_curated.get_unit_ids())

sorting_ms4_curated = st.curation.threshold_firing_rate(sorting=sorting_ms4_curated,threshold=0.5,threshold_sign='less')
print(sorting_ms4_curated.get_unit_ids())


[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20, 21]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20, 21]
[1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20, 21]
[1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20, 21]


### Save curated data

In [28]:
#save curated data
curated_sorter_df = spikeinterfaceHelper.sorter2dataframe(sorting_ms4_curated, session_id)
curated_sorter_df.to_pickle(soutput.sorter_curated_df)

### plot spike waveforms

In [29]:
plot_waveforms(curated_sorter_df, soutput.waveform_figure)

  0%|          | 0/18 [00:00<?, ?it/s]

I will plot the waveform shapes for each cluster.


100%|██████████| 18/18 [00:16<00:00,  1.10it/s]
