# Ground truth recordings for validation of spike sorting algorithms


Spampinato et al. from the Institut de la Vision have published data from mouse retina recorded with a dense multi-electrode array (MEA).
Each recording contains one ground-truth unit, which makes the dataset suitable for benchmarking spike sorting tools.

The official publication of this open dataset is here:
https://zenodo.org/record/1205233#.W9mq1HWLTIF


This dataset was used by Pierre Yger in the SpyKING CIRCUS publication:
https://elifesciences.org/articles/34518


This notebook compares several sorters on these recordings.

Each recording has several units, and **one** of them has a ground truth recorded with a juxtacellular electrode.
The SNR on the MEA differs from file to file, so we can easily compare the false positive and true positive scores by sorter and by SNR.
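
To make the true positive and false negative scores concrete, here is a minimal sketch of how ground-truth spikes could be matched to the spikes of one sorted unit. The function name `count_matches`, the `tolerance` of 10 samples, and the toy spike indexes are assumptions for illustration only; this is not the matching code used by spiketoolkit. It is also a simplification, since one sorted spike may match several ground-truth spikes.

```python
import numpy as np

def count_matches(gt_indexes, sorted_indexes, tolerance=10):
    """Count GT spikes that have a sorted spike within `tolerance` samples.

    Hypothetical helper: returns (true positives, false negatives).
    """
    gt_indexes = np.sort(np.asarray(gt_indexes))
    sorted_indexes = np.sort(np.asarray(sorted_indexes))
    if sorted_indexes.size == 0:
        return 0, int(gt_indexes.size)

    # nearest sorted spike on each side of every ground-truth spike
    pos = np.searchsorted(sorted_indexes, gt_indexes)
    left = sorted_indexes[np.clip(pos - 1, 0, sorted_indexes.size - 1)]
    right = sorted_indexes[np.clip(pos, 0, sorted_indexes.size - 1)]
    nearest = np.minimum(np.abs(gt_indexes - left), np.abs(gt_indexes - right))

    tp = int(np.sum(nearest <= tolerance))
    fn = int(gt_indexes.size - tp)
    return tp, fn

# toy example with made-up sample indexes
gt = [100, 500, 900, 1300]
unit = [102, 495, 2000]
tp, fn = count_matches(gt, unit)
tp_rate = 100.0 * tp / len(gt)
fn_rate = 100.0 * fn / len(gt)
```

With these toy values, two of the four ground-truth spikes find a partner, so both rates are 50 %.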


All tar.gz files must be in the "rawfiles" folder.


In [25]:
import zipfile, tarfile
import re
import os, shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import spiketoolkit as st
import spikeextractors as se
import spikewidgets as sw

%matplotlib notebook

## global variables

In [12]:
# my working path
# basedir = '/media/samuel/SamCNRS/DataSpikeSorting/pierre/zenodo/'
basedir = '/mnt/data/sam/DataSpikeSorting/pierre_zenodo/'

# input file
recording_folder = basedir + 'rawfiles/'

# where output will be
working_folder = basedir + 'run_comparison/'

# ground truth information
ground_truth_folder = basedir + 'ground_truth/'

# file_list
rec_names = ['20160415_patch2']
             
# ambiguous juxta detection
# ['20170630_patch1', '20170627_patch1']
             

# sorter list
sorter_list = ['tridesclous', 'spykingcircus']


## Step 1 : unzip all

This extracts each tar.gz file into its own folder.

In [None]:
for rec_name in rec_names:
    filename = recording_folder + rec_name + '.tar.gz'
    dest = recording_folder + rec_name

    # skip archives that have already been extracted
    if os.path.isdir(dest):
        continue
    # the context manager closes the archive once extraction is done
    with tarfile.open(filename, mode='r|gz') as t:
        t.extractall(dest)

## Step 2 : detect ground truth spikes on the juxtacellular recording

This is done in a separate script to keep this notebook simple.

The script:
  * detects the GT spikes
  * finds the MEA channel on which the spike is largest
  * gets the MEA peak value
  * measures the noise on this channel
  * computes the ratio peak_value/noise_mad
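
The last three steps above can be sketched as follows. This is not the content of `detect_ground_truth_spike_on_juxta.py`, which is not shown here; the function name `peak_snr`, the use of the median peak amplitude, the 1.4826 scaling of the MAD, and the synthetic trace are all assumptions made for illustration.

```python
import numpy as np

def peak_snr(trace, spike_indexes):
    """Hypothetical helper: peak_value / noise_mad on a single MEA channel."""
    trace = np.asarray(trace, dtype='float64')
    centered = trace - np.median(trace)
    # noise estimate: median absolute deviation
    # (scaled by 1.4826 to approximate a Gaussian std; this scaling is an assumption)
    noise_mad = np.median(np.abs(centered)) * 1.4826
    # peak value: median absolute amplitude at the ground-truth spike times
    peak_value = np.median(np.abs(centered[spike_indexes]))
    return peak_value / noise_mad, peak_value, noise_mad

# synthetic channel: Gaussian noise plus 20 negative spikes of amplitude 100
rng = np.random.RandomState(0)
trace = rng.normal(0, 5.0, size=20000)
spike_indexes = np.arange(500, 20000, 1000)
trace[spike_indexes] -= 100.0

snr, peak_value, noise_mad = peak_snr(trace, spike_indexes)
```

Using the MAD rather than the standard deviation keeps the noise estimate largely insensitive to the spikes themselves.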


In [None]:
%run detect_ground_truth_spike_on_juxta.py

In [11]:
gt_info = pd.read_excel(ground_truth_folder+'gt_info.xlsx', index_col=0)
gt_info.loc[rec_names, :]

| | nb_spike | max_on_channel | max_value | mea_peak_snr | noise_mad |
|---|---|---|---|---|---|
| 20160415_patch2 | 3514 | 69 | 32369.0 | 704.277142 | 45.9606 |
| 20170630_patch1 | 2368 | 232 | 32106.0 | 832.89232 | 38.5476 |
| 20170627_patch1 | 3144 | 221 | 32331.0 | 681.467523 | 47.4432 |


## Step 3: run sorters on all files

Important note: the files have 256 channels but only 252 are useful.
The PRB file lists all the channels needed, so we must explicitly pass **grouping_property='group'**
to make sure that only the channels of the single group are taken into account.
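
As an illustration of why grouping by probe group drops the unused channels, here is a sketch of a PRB-style `channel_groups` dictionary. The actual content of `mea_256.prb` is not shown in this notebook, so which four channels are excluded and the geometry values below are purely hypothetical.

```python
# Hypothetical PRB-style description: one group holding 252 of the 256 channels.
total_channels = 256
unused = [0, 1, 254, 255]  # assumption: the real unused channels may differ

channels = [c for c in range(total_channels) if c not in unused]

channel_groups = {
    1: {
        'channels': channels,
        # placeholder positions on a 16 x 16 grid, not the real electrode layout
        'geometry': {c: (float(c % 16), float(c // 16)) for c in channels},
    }
}

# channels that belong to some group; everything else is ignored when sorting by group
grouped = sorted(c for group in channel_groups.values() for c in group['channels'])
ignored = sorted(set(range(total_channels)) - set(grouped))
```

Channels that belong to no group are left out when the sorters are run group by group.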

In [None]:
# make a recordings dict
recordings = {}
for rec_name in rec_names:
    dirname = recording_folder + rec_name + '/'

    for f in os.listdir(dirname):
        if f.endswith('.raw') and not f.endswith('juxta.raw'):
            raw_filename = dirname + f

    # raw files start with an internal header whose size (offset) depends on the channel count
    # the accompanying text header can be parsed to get it
    with open(raw_filename.replace('.raw', '.txt'), mode='r') as f:
        offset = int(re.findall(r'padding = (\d+)', f.read())[0])

    # recording
    rec = se.BinDatRecordingExtractor(raw_filename, 20000, 256, 'uint16', offset=offset, frames_first=True)
    chan_ids = rec.getChannelIds()

    rec = se.loadProbeFile(rec, basedir + 'mea_256.prb')

    recordings[rec_name] = rec
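
The offset handling in the cell above can be tried on a small synthetic file. In this sketch the header text, the padding value of 1024, and the 4-channel layout are made up; only the `padding = (\d+)` pattern and the `uint16`, frames-first layout come from the cell above.

```python
import os
import re
import tempfile

import numpy as np

# hypothetical header text; the real .txt files contain a similar "padding = N" entry
header_text = "padding = 1024\n"
offset = int(re.findall(r'padding = (\d+)', header_text)[0])

# small made-up recording: 10 frames x 4 channels (the real files have 256 channels)
n_channels, n_frames = 4, 10
data = np.arange(n_frames * n_channels, dtype='uint16').reshape(n_frames, n_channels)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'example.raw')
    with open(path, 'wb') as f:
        f.write(b'\x00' * offset)  # header bytes that must be skipped
        f.write(data.tobytes())    # samples stored frame by frame

    # skip the header, then reshape to (frames, channels)
    raw = np.memmap(path, dtype='uint16', mode='r', offset=offset)
    traces = np.array(raw.reshape(-1, n_channels))
    del raw  # release the file before the temporary folder is removed
```

Reading without the offset would interpret the header bytes as samples and shift every channel.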



In [None]:
# run them all
results = st.sorters.run_sorters(sorter_list, recordings, working_folder, grouping_property='group', debug=False)

## Step 4 : compare sorter to ground truth

In [39]:
columns = ['tp_rate', 'fn_rate']
index = pd.MultiIndex.from_product([rec_names, sorter_list])
performance = pd.DataFrame(index=index, columns=columns)

results = st.sorters.collect_results(working_folder)
for rec_name, result_one_dataset in results.items():
    print()
    print(rec_name)
    for sorter_name, sorting in result_one_dataset.items():
        print(sorter_name)
        print(sorting)
        print(sorting.getUnitIds())

        # make the ground truth
        gt_indexes = np.fromfile(ground_truth_folder + rec_name + '/juxta_peak_indexes.raw', dtype='int64')
        gt_sorting = se.NumpySortingExtractor()
        gt_sorting.setTimesLabels(gt_indexes, np.zeros(gt_indexes.size, dtype='int64'))

        # run comparison
        comp = st.comparison.SortingComparison(gt_sorting, sorting, count=True)
        
        counts = comp.counts
        performance.loc[(rec_name, sorter_name), 'tp_rate'] = float(counts['TP']) / counts['TOT_ST1'] * 100
        performance.loc[(rec_name, sorter_name), 'fn_rate'] = float(counts['FN']) / counts['TOT_ST1'] * 100
        
        # st.comparison.compute_performance(comp)

        # comp_table = sw.SortingComparisonTable(comp)
        # comp_table.display()

performance

AssertionError: files does not exist

## Compare run speed

In [27]:


run_times = pd.read_csv(working_folder + 'run_time.csv', sep='\t')
run_times.columns = ['rec_name', 'sorter_name', 'run_time']
run_times





| | rec_name | sorter_name | run_time |
|---|---|---|---|
| 0 | 20170630_patch1 | tridesclous | 417.43079 |
| 1 | 20170627_patch1 | tridesclous | 1684.29996 |


In [28]:
import seaborn as sns
ax = sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name')
ax.axhline(300, ls='--', color='r')

[Figure: bar plot of run_time per rec_name, grouped by sorter_name, with a dashed red horizontal line at 300]