# Test MountainSort5 on high-density probes

In [16]:
import numpy as np
import spikeinterface.preprocessing as spre
import mountainsort5 as ms5
from mountainsort5.util import create_cached_recording
from utils.loaders import rhd_load, load_intan_impedance
from pathlib import Path
from probeinterface import read_probeinterface
from utils import clean_channels_by_imp
import spikeinterface.extractors as se
import spikeinterface as si
import spikeinterface.widgets as sw
import spikeinterface.exporters as sexp

## Try with scheme1
Scheme 1 is used here to test the sorting parameters.

Load testing data

In [2]:
# Session folder with the Intan recording; the impedance CSV is also looked up here later.
rhd_folder = Path(
    r'D:\Yongzhi_Sun\01_Raw_Data\Yongzhi_Sun\intan\curved_120_design1\20240905_m1\w5'
)

# Take the first probe described in the probeinterface JSON file.
probe_group = read_probeinterface('probe/120_curved.json')
probe = probe_group.probes[0]

# NOTE(review): (3, 13) is interpreted by utils.loaders.rhd_load -- presumably a
# range of file indices (10 files are loaded); confirm against the helper.
recording = rhd_load(rhd_folder, (3, 13)).set_probe(probe)

Loading RHD2000 files: 100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


Filter channels

In [3]:
# Locate the impedance CSV stored in the session folder. Fail with a clear
# message instead of a bare StopIteration when the folder contains none.
impedance_csv = next(rhd_folder.glob('*.csv'), None)
if impedance_csv is None:
    raise FileNotFoundError(f'No impedance CSV found in {rhd_folder}')
imps = load_intan_impedance(impedance_csv)

# NOTE(review): 3e6 is presumably an impedance cutoff (ohms) applied by
# clean_channels_by_imp -- confirm units and direction against the helper.
recording = clean_channels_by_imp(recording, imps, 3e6)

Preprocessing

In [4]:
# Band-pass between 300 and 6000 Hz, producing float32 output.
bandpass_kwargs = dict(freq_min=300, freq_max=6000, dtype=np.float32)
recording_filtered = spre.bandpass_filter(recording, **bandpass_kwargs)

# Whiten the band-passed recording; this is what the sorter and the analyzer use.
recording_preprocessed = spre.whiten(recording_filtered)

Sorting

Parameters to care about:
* `detect_sign`: sign of the spikes to detect (default: negative)
* `detect_channel_radius`: merges duplicate detections of the same spike on nearby channels (default: infinite)
* `snippet_mask_radius`: removes interference to nearby channels

In [8]:
# Output folder: the session path relative to the raw-data root, re-rooted
# under the results drive.
raw_root = r'D:\Yongzhi_Sun\01_Raw_Data\Yongzhi_Sun\intan'
data_dir = Path(r'G:\mountainsort_result') / rhd_folder.relative_to(raw_root)
data_dir.mkdir(parents=True, exist_ok=True)

# Cache the preprocessed recording in data_dir for efficient reading by the sorter.
recording_cached = create_cached_recording(recording_preprocessed, folder=data_dir)

# Scheme 1 parameters: detect negative-going spikes with the two radii set
# explicitly (see the parameter notes above this cell).
scheme1_params = ms5.Scheme1SortingParameters(
    detect_sign=-1,
    detect_channel_radius=50,
    snippet_mask_radius=100,
)
sorting = ms5.sorting_scheme1(
    recording=recording_cached,
    sorting_parameters=scheme1_params,
)

write_binary_recording: 100%|##########| 601/601 [02:12<00:00,  4.53it/s]


Number of channels: 108
Number of timepoints: 12006400
Sampling frequency: 20000.0 Hz
Channel 0: [ 240.03134393 -117.87395681]
Channel 1: [ 140.21150298 -111.87401064]
Channel 2: [  40.39166204 -105.87406447]
Channel 3: [-59.4281789 -99.8741183]
Channel 4: [-159.24801985  -93.87417213]
Channel 5: [-259.06786079  -87.87422595]
Channel 6: [-358.88770173  -81.87427978]
Channel 7: [-458.70754268  -75.87433361]
Channel 8: [-558.52738362  -69.87438744]
Channel 9: [-658.34722456  -63.87444126]
Channel 10: [-758.16706551  -57.87449509]
Channel 11: [-857.98690645  -51.87454892]
Channel 12: [-957.80674739  -45.87460275]
Channel 13: [-1057.62658834   -39.87465658]
Channel 14: [-1157.44642928   -33.8747104 ]
Channel 15: [-1257.26627022   -27.87476423]
Channel 16: [-1357.08611117   -21.87481806]
Channel 17: [-1456.90595211   -15.87487189]
Channel 18: [-1556.72579305    -9.87492571]
Channel 19: [-1656.545634      -3.87497954]
Channel 20: [-1756.36547494     2.12496663]
Channel 21: [-1856.18531588   

Analysis

In [13]:
# Print the sorting summary, then save it as an NPZ file inside data_dir.
sorting_path = data_dir / 'sorting.npz'
print(sorting)
se.NpzSortingExtractor.write_sorting(sorting, sorting_path)

NumpySorting: 86 units - 1 segments - 20.0kHz


In [19]:
# Build a SortingAnalyzer stored on disk (binary_folder format) under data_dir / 'analysis',
# pairing the sorting with the preprocessed recording.
analysis_dir = data_dir / 'analysis'
analyzer = si.create_sorting_analyzer(
    sorting,
    recording_preprocessed,
    format='binary_folder',
    folder=analysis_dir,
)
print(analyzer)

estimate_sparsity: 100%|##########| 601/601 [01:16<00:00,  7.81it/s]


SortingAnalyzer: 108 channels - 86 units - 1 segments - binary_folder - sparse - has recording
Loaded 0 extensions


In [22]:
# Job options forwarded to the extension computations.
job_kwargs = {"n_jobs": 8, "chunk_duration": "1s", "progress_bar": True}

# Extension name -> parameters; an empty dict passes no explicit parameters.
# Key order is kept as originally written.
compute_dict = dict(
    spike_amplitudes={},
    random_spikes={'method': 'uniform', 'max_spikes_per_unit': 500},
    waveforms={'ms_before': 1.0, 'ms_after': 2.0},
    templates={'operators': ["average", "median", "std"]},
    correlograms={"bin_ms": 0.1},
    noise_levels={},
)
analyzer.compute(compute_dict, **job_kwargs)

compute_waveforms: 100%|██████████| 601/601 [00:16<00:00, 36.93it/s]
Compute : spike_amplitudes: 100%|██████████| 601/601 [00:22<00:00, 27.18it/s]


In [23]:
# Per-metric parameter overrides, keyed by quality-metric name.
qm = {
    'firing_range': {'bin_size_s': 5, 'percentiles': (5, 95)},
    'isi_violation': {'isi_threshold_ms': 1.5, 'min_isi_ms': 0},
    'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
}

# Pass the overrides by keyword. The second positional parameter of
# SortingAnalyzer.compute is `save`, so a dict passed positionally is treated
# as a truthy `save` flag and never reaches the quality-metric calculator.
# NOTE(review): newer SpikeInterface releases rename `qm_params` to
# `metric_params` -- confirm against the installed version.
analyzer.compute("quality_metrics", qm_params=qm)

  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


<spikeinterface.qualitymetrics.quality_metric_calculator.ComputeQualityMetrics at 0x20fab93d450>

In [26]:
# Export the analyzer to data_dir / "phy" so it can be opened with phy's template GUI.
phy_dir = data_dir / "phy"
sexp.export_to_phy(analyzer, phy_dir, verbose=True, n_jobs=8)

write_binary_recording: 100%|##########| 601/601 [02:15<00:00,  4.43it/s]
Fitting PCA: 100%|██████████| 86/86 [00:04<00:00, 21.00it/s]
Projecting waveforms: 100%|██████████| 86/86 [00:00<00:00, 632.34it/s]
extract PCs: 100%|##########| 601/601 [02:49<00:00,  3.54it/s]


To inspect and curate the result in phy, run:

    phy template-gui G:\mountainsort_result\curved_120_design1\20240905_m1\w5\phy\params.py
