In [None]:
# Basic scientific python imports
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook 
# Spikeinterface imports (could do this cleaner, oh well)
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.preprocessing as spre
import spikeinterface.exporters as sxp
import spikeinterface.widgets as sw
import spikeinterface.full as si
import probeinterface as pi
import spikeinterface.curation as scur

# import spikeinterface_gui
# Core python imports
import os
import time
from pathlib import Path
from IPython.display import Audio

In [None]:
## If reading a series of Intan recordings
# A session folder holds several .rhd files that are concatenated into one recording.
rec_name = 'poke1_230520_153135'
path_to_folder = Path('C:/Users/lwood39/Documents/VNCMP/2023-05-20') / rec_name
# os.listdir returns entries in arbitrary order, so sort the names before concatenating.
# Sorting by file name presumably gives chronological order, because Intan embeds a
# timestamp in each file name (confirm for your acquisition setup).
# endswith() avoids picking up stray files such as "x.rhd.bak".
file_names = sorted(x for x in os.listdir(path_to_folder) if x.lower().endswith('.rhd'))
if not file_names:
    raise FileNotFoundError(f'No .rhd files found in {path_to_folder}')
recording_list = [
    se.IntanRecordingExtractor(str(path_to_folder / file_name), stream_id='0')
    for file_name in file_names
]
recording = si.concatenate_recordings(recording_list)
display(recording)

# Uncomment to grab just a section of a recording
# recording = recording.frame_slice(start_frame=0, end_frame=int(231*30000))

## If reading open ephys recording session
# path_to_folder = Path('G:/SponbergLab/Data/Leo_2023-03-29_15-03-40/Record Node 103/experiment1')
# # path_to_folder = Path('G:/VNCMP/20230308/2023-03-08_13-08-30/Record Node 104')
# recording = se.read_openephys(path_to_folder, block_index=0, stream_id='0')
# recording = si.SelectSegmentRecording(recording, 0)

In [None]:
# Analog-input (ADC) channels carry no spikes, so drop any that the recording includes
adc_channel_ids = [ch for ch in recording.get_channel_ids() if 'ADC' in ch]
if adc_channel_ids:
    recording = recording.remove_channels(adc_channel_ids)
recording.get_channel_ids()

In [None]:
# Load the probe layout, plot it for a visual check, and attach its first probe to the recording
probe_file = 'A32_A1x32-Poly5-6mm-35s-100.json'
probe = pi.read_probeinterface(probe_file)
pi.plotting.plot_probe_group(
    probe,
    with_channel_index=True,
    with_device_index=True,
)
first_probe = probe.probes[0]
# recording.set_probegroup(probe)
recording.set_probe(first_probe, in_place=True)
recording.get_probe()

In [None]:
# Write two binary caches to disk: a band-pass filtered copy (350-5000 Hz) and the raw data
cache_kwargs = dict(format='binary', n_jobs=8, chunk_duration='10s')
recording_bandpassed = si.bandpass_filter(recording, freq_min=350, freq_max=5000)
recording_cache_filter = recording_bandpassed.save(**cache_kwargs)
recording_cache_raw = recording.save(**cache_kwargs)

## Run Sorter

In [None]:
# Show the sorter names SpikeInterface reports as available (cell output)
si.available_sorters()

In [None]:
sorter = 'mountainsort5'
# Print the default parameters, then display the per-parameter descriptions as the cell output
default_sorter_params = ss.get_default_sorter_params(sorter)
print(default_sorter_params)
ss.get_sorter_params_description(sorter)

In [None]:
# Start from the sorter defaults and override the settings tuned for this dataset
params = ss.get_default_sorter_params(sorter)
params.update({
    'scheme': '2',
    'detect_sign': 1,
    'detect_threshold': 5.5,
    'freq_min': 350,
    'freq_max': 7000,
    'npca_per_subdivision': 15,
    'scheme2_phase1_detect_channel_radius': 50,
    'scheme2_training_duration_sec': 800,
    'scheme2_max_num_snippets_per_training_batch': 10000,
    'scheme3_block_duration_sec': 500,
})

tic = time.perf_counter()
sort = ss.run_sorter(
    sorter,
    recording=recording_cache_raw,
    output_folder=sorter,
    # docker_image="spikeinterface/" + sorter + "-compiled-base:latest",  # <- use for kilosort
    verbose=False,
    **params)
print(f'{time.perf_counter() - tic} seconds elapsed')

# Only a cell's last expression is rendered automatically. A bare Audio(...) in the middle
# of the cell is never shown and never autoplays, so display it explicitly.
# If the sound file is missing, skip the notification; Audio() would otherwise raise
# right after a long sort has finished.
notification_sound = Path('notification-sound.wav')
if notification_sound.exists():
    display(Audio(str(notification_sound), autoplay=True))

print(sort)

In [None]:
# Parallelization settings shared by waveform extraction and the Phy export cell
job_kwargs = dict(n_jobs=8, chunk_duration="1s", progress_bar=True)

# Optional: find redundant units, then re-run the extraction below with
# `sort_no_redundant` in place of `sort`.
# wave_extract = si.extract_waveforms(
#     recording_cache_filter,
#     sort,
#     './waveforms',
#     ms_before=2., ms_after=2.,
#     max_spikes_per_unit=100000,
#     overwrite=True,
#     **job_kwargs)
# sort_no_redundant = scur.remove_redundant_units(wave_extract, align=True)

# Extract waveforms from the band-pass filtered cache created in the caching cell.
# recording_cache_filter is the name that cell actually defines.
wave_extract = si.extract_waveforms(
    recording_cache_filter,
    sort,  # or sort_no_redundant, if the optional step above is enabled
    './waveforms',
    ms_before=2., ms_after=2.,
    max_spikes_per_unit=100000,
    overwrite=True,
    **job_kwargs)

## Phy GUI

In [None]:
# Export the sorting to a Phy-compatible folder named after the sorter and recording
phy_save_path = f'./phy_folder/{sorter}_{rec_name}'
sxp.export_to_phy(
    wave_extract,
    phy_save_path,
    remove_if_exists=True,
    **job_kwargs,
)

# Keep a plain-text record of the sorter parameters next to the Phy output
params_log_lines = [f'{key}:{value}\n' for key, value in params.items()]
with open(os.path.join(phy_save_path, 'params_log.txt'), 'w') as log_file:
    log_file.writelines(params_log_lines)
        
        

 To run Phy, open cmd or PowerShell and do one of the following:
 
 1. Run the command printed by the cell above, which typically looks like:
 
 ```phy template-gui  C:\Users\lwood39\Documents\AutoSpikeSort\phy_folder\mountainsort5_poke1_230520_153135\params.py```
 
 
 2. Navigate to the output folder the cell above created inside `phy_folder`, then run the `phy` command:
 
 Example:
 
 ```cd C:/Users/lwood39/Documents/AutoSpikeSort/phy_folder/mountainsort5_poke1_230520_153135```
 
 ```phy template-gui params.py```
 

## WHY IS IT CRASHING? Some general notes

I've noticed a few common patterns that will lead to crashes. Here are some notes on those, in no particular order.

- **Kilosort is crashing: the stack trace mentions a gpuArray error and/or a NaN**

  This almost always seems to be caused by one of Kilosort's batches containing no detected spikes. Kilosort processes the data in independent "batches", which are sections of a fixed length. If one of those batches has no spikes, it typically produces a NaN value that breaks the rest of the run. The easy solution is to either (a) only feed in data that has spiking of some kind throughout, or (b) increase the batch size with the `NT` parameter.

# Further plotting or exploration:

In [None]:
# Optional exploration snippets: uncomment and edit the values as needed.
# NOTE(review): channel ids are passed as strings here; this presumably matches the
# recording's channel ids (the ADC filter above treats them as strings), so confirm.
# NOTE(review): newer SpikeInterface releases may expose this widget as `plot_traces`
# instead of `plot_timeseries`; check against the installed version.

# To plot traces straight:
# sw.plot_timeseries(
#     recording_cache_raw, 
#     time_range=(69, 420), 
#     channel_ids=['0'], 
#     return_scaled=True)
# plt.show()

# To get traces of a specific channel from specific frames:
# data = recording_cache_raw.get_traces(
#     start_frame=your_start_frame_here, 
#     end_frame=your_end_frame_here,
#     return_scaled=True, 
#     channel_ids=['3'])