# Kilosort2 vs. pykilosort: Comparison Report

This report simulates electrophysiological data and compares the outputs of the spike sorters Kilosort2 and pykilosort when run on this data. All configuration for the sortings should be changed either within the report or, ideally, through environment variables. Don't rely on configuration from external files, as they may become out of sync between MATLAB and Python.

In [None]:
import enum
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import logging
# Use the root logger and let INFO-level (and more severe) records through.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

## Configuration

In [None]:
# Locations of the two sorter checkouts; each can be overridden through an
# environment variable of the same name.
PYKILOSORT_DIR = os.environ.get('PYKILOSORT_DIR', '.')
KILOSORT2_DIR = os.environ.get('KILOSORT2_DIR', f'{PYKILOSORT_DIR}/../Kilosort2')

# Working directory for the simulated data, plus one output directory per
# sorter.  Both output directory strings end with a path separator.
BASE_PATH = os.environ.get('BASE_PATH', f'{PYKILOSORT_DIR}/examples/eMouse/data')
PYKILOSORT_SORTING_RESULTS_DIR = f'{BASE_PATH}/python_output/'
MATLAB_SORTING_RESULTS_DIR = f'{BASE_PATH}/matlab_output/'

# Create the output directories up front.  A plain loop is used because the
# call is made only for its side effect (no list needs to be built).
for results_dir in (PYKILOSORT_SORTING_RESULTS_DIR, MATLAB_SORTING_RESULTS_DIR):
    os.makedirs(results_dir, exist_ok=True)

class Operations(enum.Enum):
    """Stages of this report that can be listed in ``FORCE_RUN``."""

    simulation = 'SIMULATION'
    matlab_sorting = 'MATLAB_SORTING'
    pykilosort_sorting = 'PYKILOSORT_SORTING'
    
# Stages to force on this run.  The first assignment selects every stage; the
# second one immediately overrides it, leaving only `pykilosort_sorting` in
# the set.  Remove (or comment out) the second line to force all stages.
FORCE_RUN = set(Operations)
FORCE_RUN = {Operations.pykilosort_sorting}

# Options handed to the MATLAB simulation helpers.
simulation_opts = dict(
    chanMapName='chanMap_3B_64sites.mat',
    NchanTOT=64.0,
)

# Sorter options.  The whole dict is converted to a MATLAB struct for the
# Kilosort2 run, and a few entries (NchanTOT, fs, fbinary) are also passed to
# pykilosort.  Key order and value types (float / int / list / bool / str)
# are kept exactly as they are handed over.
opts = {
    # Channel map file produced for the simulation.
    'chanMap': f'{BASE_PATH}/{simulation_opts["chanMapName"]}',
    'fs': 30000.0,
    'fshigh': 150.0,
    'minfr_goodchannels': 0.1,
    'Th': [6.0, 2.0],
    'lam': 10.0,
    'AUCsplit': 0.9,
    'minFR': 0.02,
    'momentum': [20.0, 400],
    'sigmaMask': 30.0,
    'ThPre': 8.0,
    'reorder': 1,
    'nskip': 25.0,
    'spkTh': -6.0,
    'GPU': 1,
    'nfilt_factor': 4.0,
    'ntbuff': 64.0,
    'NT': 65600.0,
    'whiteningRange': 32.0,
    'nSkipCov': 25.0,
    'scaleproc': 200.0,
    'nPCs': 3.0,
    'useRAM': 0,
    'sorting': 2,
    # Same channel count as used for the simulation.
    'NchanTOT': float(simulation_opts['NchanTOT']),
    'trange': [0.0, float('inf')],
    # File locations.
    'fproc': '/tmp/temp_wh.dat',
    'rootZ': MATLAB_SORTING_RESULTS_DIR,
    'fbinary': f'{BASE_PATH}/sim_binary.imec.ap.bin',
    'fig': False,
}

## Setup MATLAB<sup>TM</sup> engine

In [None]:
import matlab.engine

# Set to True to launch a fresh MATLAB engine; set to False to attach to an
# already-open MATLAB workspace instead, which is handy for debugging.
new_session = True
eng = matlab.engine.start_matlab() if new_session else matlab.engine.connect_matlab()

# Put Kilosort2 and npy-matlab (looked up next to the Kilosort2 directory) on
# the MATLAB path, including all of their sub-folders.
for matlab_pkg_dir in (KILOSORT2_DIR, f'{KILOSORT2_DIR}/../npy-matlab'):
    eng.addpath(eng.genpath(matlab_pkg_dir))

## Generate simulated data using MATLAB<sup>TM</sup> 

In [None]:
if Operations.simulation in FORCE_RUN:
    useGPU = True
    useParPool = False

    # The value returned by the MATLAB helper is joined with BASE_PATH, i.e.
    # it is treated as a file name inside the data directory.
    chan_map_name = eng.make_eMouseChannelMap_3B_short(BASE_PATH, simulation_opts["NchanTOT"])
    opts["chanMap"] = f'{BASE_PATH}/{chan_map_name}'
    eng.make_eMouseData_drift(BASE_PATH, KILOSORT2_DIR, simulation_opts["chanMapName"], useGPU, useParPool, nargout=0)
elif not os.path.isfile(opts["chanMap"]):
    # The simulation is skipped, so the channel map must already exist.  An
    # explicit exception (instead of `assert`) keeps the check active under
    # `python -O` and tells the user how to recover.
    raise FileNotFoundError(
        f'Channel map {opts["chanMap"]!r} not found; '
        'add Operations.simulation to FORCE_RUN to generate it.'
    )

Write out the channel data to numpy files too 

In [None]:
if Operations.simulation in FORCE_RUN:
    # Load the channel map in MATLAB and write three of its fields out as
    # .npy files in BASE_PATH.
    x = eng.load(opts['chanMap'])
    for field, npy_name in (
        ('chanMap', 'chanMap.npy'),
        ('xcoords', 'xc.npy'),
        ('ycoords', 'yc.npy'),
    ):
        eng.writeNPY(x[field], f'{BASE_PATH}/{npy_name}', nargout=0)

## Sort simulated data using Kilosort2 via MATLAB<sup>TM</sup> engine

In [None]:
if Operations.matlab_sorting in FORCE_RUN:
    # Python lists are wrapped as matlab.double arrays; every other value is
    # passed to the struct unchanged.
    matlab_fields = {}
    for key, value in opts.items():
        matlab_fields[key] = matlab.double(value) if isinstance(value, list) else value
    ops = eng.struct(matlab_fields)
    rootZ = eng.char(opts['rootZ'])

    if not new_session:
        # Attached to an existing MATLAB workspace: publish the inputs there
        # as well.
        eng.workspace['ops'] = ops
        eng.workspace['rootZ'] = rootZ

    rez = eng.function_kilosort(rootZ, ops)

# Listed whether or not the sorting was re-run in this session.
print(f"Files generated: {os.listdir(MATLAB_SORTING_RESULTS_DIR)}")

## Sort simulated data using pykilosort

In [None]:
import pykilosort
from pathlib import Path
from importlib import reload
from pykilosort import main
# Re-execute the already-imported pykilosort.main module (presumably so that
# edits to the package source are picked up without restarting the kernel).
reload(main)

# NOTE(review): presumably installs pykilosort's default logging handler — confirm.
pykilosort.add_default_handler()

In [None]:
# Describe the probe for pykilosort from the .npy channel-map files in BASE_PATH.
n_channels_total = int(opts['NchanTOT'])

probe = pykilosort.Bunch()
probe.NchanTOT = n_channels_total
probe.chanMap = np.load(f'{BASE_PATH}/chanMap.npy').flatten().astype(int)
probe.kcoords = np.ones(n_channels_total)
probe.xc = np.load(f'{BASE_PATH}/xc.npy').flatten()
probe.yc = np.load(f'{BASE_PATH}/yc.npy').flatten()

# The pykilosort results directory is used both as dir_path and as output_dir.
sorting_dir = Path(PYKILOSORT_SORTING_RESULTS_DIR)

rez = main.run(
    dat_path=opts['fbinary'],
    dir_path=sorting_dir,
    output_dir=sorting_dir,
    params=None,
    probe=probe,
    dtype=np.int16,
    n_channels=n_channels_total,
    sample_rate=opts['fs'],
    # Hard-coded; tying this to `Operations.pykilosort_sorting in FORCE_RUN`
    # was left disabled by the author.
    clear_context=False,
)
print(f"Files generated: {os.listdir(PYKILOSORT_SORTING_RESULTS_DIR)}")

---

# Intermediate results comparison

In [None]:
import pydantic

class SortingResults(pydantic.BaseModel):
    """Arrays loaded from one sorter's output directory (see ``get_results``)."""
    # Indexed as [unit, :, channel] by the plotting cells of this report.
    templates: np.ndarray
    # spike_times is masked with `spike_clusters == unit` when plotting, so the
    # two arrays are expected to line up element by element.
    spike_times: np.ndarray
    spike_clusters: np.ndarray
    # Loaded from channel_positions.npy; not used elsewhere in the visible report.
    channel_positions: np.ndarray

    class Config:
        # Lets pydantic accept np.ndarray, which it has no built-in validator for.
        arbitrary_types_allowed = True

def get_results(dirname):
    """Load the sorter output arrays stored in ``dirname``.

    Args:
        dirname: Directory containing ``templates.npy``, ``spike_times.npy``,
            ``spike_clusters.npy`` and ``channel_positions.npy``.  A trailing
            path separator is accepted but not required.

    Returns:
        A ``SortingResults`` holding the four loaded arrays.
    """
    def load(array_name):
        # os.path.join inserts the separator only when it is missing, so all
        # four files are resolved the same way.
        return np.load(os.path.join(dirname, f"{array_name}.npy"))

    return SortingResults(
        templates=load("templates"),
        spike_times=load("spike_times"),
        spike_clusters=load("spike_clusters"),
        channel_positions=load("channel_positions"),
    )

# Final outputs of both sorters, keyed by implementation (MATLAB first).
results = {
    label: get_results(output_dir)
    for label, output_dir in (
        ('matlab', MATLAB_SORTING_RESULTS_DIR),
        ('python', PYKILOSORT_SORTING_RESULTS_DIR),
    )
}

import h5py

def h5_to_dict(group):
    """Recursively copy an HDF5 group into a plain ``dict``.

    Args:
        group: An ``h5py`` group-like object exposing ``items()``.

    Returns:
        A dict mapping each member name to
        * a nested dict for sub-groups,
        * a ``str`` for members whose ``MATLAB_class`` attribute is ``b'char'``,
        * an ``np.ndarray`` copy for any other member that has a ``shape``,
        * the member itself otherwise.
    """
    d = {}
    for k, v in group.items():
        if isinstance(v, h5py.Group):
            d[k] = h5_to_dict(v)
        elif v.attrs.get('MATLAB_class') == b'char':
            # Flatten before decoding: iterating a multi-dimensional dataset
            # yields row arrays, which chr() cannot convert to a code point.
            d[k] = ''.join(chr(c) for c in np.asarray(v).ravel())
        elif hasattr(v, 'shape'):
            d[k] = np.array(v)
        else:
            d[k] = v
    return d

# Read the `rez` group of the rez.mat file in the MATLAB output directory.
with h5py.File(f'{MATLAB_SORTING_RESULTS_DIR}/rez.mat', 'r') as f:
    matlab_rez = h5_to_dict(f['rez'])

# Expose MATLAB's 'ccb' under the 'ccb0' key as well, the key that is read
# from the Python results below.
matlab_rez['ccb0'] = matlab_rez['ccb']

python_rez = rez['intermediate']

# Per-implementation intermediates plus the element-wise MATLAB-minus-Python
# difference of the two matrices compared in this report.
intermediate_results = {
    'matlab': matlab_rez,
    'python': python_rez,
    'matlab - python': {
        key: matlab_rez[key] - python_rez[key] for key in ('ccbsort', 'ccb0')
    },
}

## Alignment at the start of the recording

As of 2020/06/29, this seems to be off by 1000 samples.

In [None]:
from spikeextractors import BinDatRecordingExtractor
f, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 12))

# Both extractors take their sampling rate and channel count from the MATLAB ops.
matlab_ops = intermediate_results['matlab']['ops']
extractor_kwargs = dict(
    sampling_frequency=matlab_ops['fs'],
    numchan=matlab_ops['NchanTOT'],
    dtype=np.int16,
)

# Top panel: first 4000 frames of the file MATLAB recorded as `fproc`.
matlab_proc_dat = BinDatRecordingExtractor(matlab_ops['fproc'], **extractor_kwargs)
ax = axs[0]
ax.plot(matlab_proc_dat.get_traces(start_frame=0, end_frame=4000).T)

# Bottom panel: first 3000 frames of pykilosort's proc.dat, drawn at
# x = 1000..3999, i.e. shifted right by 1000 samples.
python_proc_dat = BinDatRecordingExtractor(rez['context_path'] / 'proc.dat', **extractor_kwargs)
ax = axs[1]
ax.plot(range(1000, 4000), python_proc_dat.get_traces(start_frame=0, end_frame=3000).T);

## Difference in Whitening Matrices

In [None]:
plt.title("Difference in Whitening Matrices")

# Element-wise difference of the two `Wrot` arrays, Python minus MATLAB.
wrot_difference = intermediate_results['python']['Wrot'] - intermediate_results['matlab']['Wrot']
plt.imshow(wrot_difference)
plt.colorbar();

## Difference in initial batch re-ordering

In [None]:
# Sanity check: both implementations must yield `ccb0` arrays of equal shape.
# The message reports both shapes so a mismatch can be diagnosed directly.
assert matlab_rez['ccb0'].shape == python_rez['ccb0'].shape, (
    f"ccb0 shape mismatch: matlab {matlab_rez['ccb0'].shape} "
    f"vs python {python_rez['ccb0'].shape}"
)

In [None]:
f, axs = plt.subplots(2, 3, figsize=(20, 10))

# Symmetric colour range taken from MATLAB's `ccbsort`, shared by every panel.
vmax = np.max(intermediate_results['matlab']['ccbsort'])
vmin = -vmax

# One row per matrix; one column per entry of `intermediate_results`
# (matlab, python, matlab - python).
for row, ccb in enumerate(('ccb0', 'ccbsort')):
    for col, (name, res) in enumerate(intermediate_results.items()):
        ax = axs[row, col]
        im = ax.imshow(res[ccb], vmin=vmin, vmax=vmax)
        ax.set_title(name)
    axs[row, 0].set_ylabel(ccb)

# The colour bar is built from the last image drawn.
plt.colorbar(im)
f.suptitle("Batch Dissimilarity Matrices", size=16)

The final order of the batches is, as a result, not the same.

In [None]:
# Overlay the `iorig` arrays of both implementations (MATLAB's first row, then
# pykilosort's array as-is).
for batch_order in (matlab_rez['iorig'][0], python_rez['iorig']):
    plt.plot(batch_order)

# Output comparison

## Similarity matrix between the templates of the identified units

In [None]:
# Flatten each unit's template into a single vector.
matlab_templates = [template.ravel() for template in results['matlab'].templates]
python_templates = [template.ravel() for template in results['python'].templates]

# Entry [m, p] is the dot product of MATLAB unit m with Python unit p.
similarity_matrix = np.array([
    [np.dot(matlab_vec, python_vec) for python_vec in python_templates]
    for matlab_vec in matlab_templates
])

plt.figure(figsize=(13, 7))
plt.imshow(similarity_matrix, vmin=0, vmax=1)
plt.xlabel('Python Units')
plt.ylabel('MATLAB Units')

plt.colorbar()

In [None]:
# Candidate matches: (MATLAB unit, Python unit) index pairs whose entry in the
# similarity matrix exceeds 0.7.  units[0] holds MATLAB indices and units[1]
# Python indices, matching the order of `results`.
units = np.where(similarity_matrix > 0.7)

for u in range(len(units[0])):
    f, axs = plt.subplots(2, 2, figsize=(8, 10), gridspec_kw={'height_ratios': [4, 1]})
    f.suptitle(f"Unit pair {u+1}: {(units[0][u],units[1][u])}", size=20)

    for i, (name, res) in enumerate(results.items()):
        templates = res.templates
        unit = units[i][u]

        # Top row: one trace per channel, offset vertically by 0.1 per channel.
        for channel in range(templates.shape[2]):
            axs[0, i].plot(templates[unit, :, channel].T + 0.1 * channel)
        axs[0, i].set_title(name)

        # Bottom row: spike times of this pair's unit.  The unit of pair `u`
        # is used here, not the unit of the first pair.
        axs[1, i].vlines(res.spike_times[res.spike_clusters == unit], 0 + i, 1 + i)
        axs[1, i].set_xlim(0, 100000)

In [None]:
# Overlay the channel-0 trace of every MATLAB template.  The unit count is
# read from the array instead of being hard-coded, so the cell works for any
# number of detected units.
matlab_unit_templates = results['matlab'].templates
for unit in range(matlab_unit_templates.shape[0]):
    plt.plot(matlab_unit_templates[unit, :, 0].T)