In [None]:
from mountainlab_pytools import mdaio
from mountainlab_pytools import mlproc as mlp
import os
import json

def sort_dataset(*,dataset_dir,output_dir,adjacency_radius,detect_threshold,opts=None):
    """Whiten a dataset and spike-sort it with the ``ms4alg.sort`` processor.

    Reads from ``dataset_dir``:
      - ``raw.mda``     -- input timeseries passed to the whitening step
      - ``geom.csv``    -- electrode geometry passed to the sorter
      - ``params.json`` -- dataset parameters (via ``read_dataset_params``)

    Writes into ``output_dir`` (created, including parents, if missing):
      - ``pre.mda.prv`` -- output of the ``ephys.whiten`` step
      - ``firings.mda`` -- sorter output

    The detection sign comes from the dataset parameters: ``detect_sign`` if
    present, otherwise ``spike_sign``, otherwise 1.

    Args:
        dataset_dir: Directory holding the dataset files listed above.
        output_dir: Directory receiving the processing outputs.
        adjacency_radius: Forwarded to the sorter.
        detect_threshold: Forwarded to the sorter.
        opts: Options forwarded to each ``mlp.runProcess`` call
            (defaults to a fresh empty dict).
    """
    # Avoid a shared mutable default; normalize to a dict so the helpers
    # always receive the same kind of value the original default provided.
    if opts is None:
        opts = {}

    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(output_dir, exist_ok=True)

    # Dataset parameters
    ds_params = read_dataset_params(dataset_dir)

    # Whiten
    whiten(
        timeseries=dataset_dir + '/raw.mda',
        timeseries_out=output_dir + '/pre.mda.prv',
        opts=opts
    )

    # Sort: an explicit 'detect_sign' wins over the legacy 'spike_sign' key.
    detect_sign = ds_params.get('detect_sign', ds_params.get('spike_sign', 1))
    ms4alg_sort(
        timeseries=output_dir + '/pre.mda.prv',
        geom=dataset_dir + '/geom.csv',
        firings_out=output_dir + '/firings.mda',
        adjacency_radius=adjacency_radius,
        detect_sign=detect_sign,
        detect_threshold=detect_threshold,
        opts=opts
    )
    
    
def read_dataset_params(dsdir):
    """Load the JSON dataset parameters from ``<dsdir>/params.json``.

    The path is first resolved through ``mlp.realizeFile`` and the resulting
    file is parsed with ``json.load``.

    Args:
        dsdir: Dataset directory expected to contain ``params.json``.

    Returns:
        The parsed JSON content (callers treat it as a dict of parameters).

    Raises:
        Exception: If the parameter file could not be resolved or does not exist.
    """
    requested = dsdir + '/params.json'
    params_fname = mlp.realizeFile(requested)
    # NOTE(review): guard against realizeFile yielding a falsy value (e.g. None
    # -- confirm whether it can); without this, os.path.exists / the string
    # concatenation would raise TypeError instead of the descriptive error.
    if not params_fname or not os.path.exists(params_fname):
        raise Exception('Dataset parameter file does not exist: ' + (params_fname or requested))
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(params_fname, encoding='utf-8') as f:
        return json.load(f)

def whiten(*,timeseries,timeseries_out,opts=None):
    """Run the ``ephys.whiten`` processor on one timeseries.

    Args:
        timeseries: Path passed as the processor's ``timeseries`` input.
        timeseries_out: Path passed as the processor's ``timeseries_out`` output.
        opts: Options forwarded to ``mlp.runProcess`` (defaults to a fresh
            empty dict on every call).

    Returns:
        Whatever ``mlp.runProcess`` returns.
    """
    # A fresh dict per call: a `{}` default would be shared across calls and
    # leak any mutation runProcess might make to it.
    if opts is None:
        opts = {}
    return mlp.runProcess(
        'ephys.whiten',
        {'timeseries': timeseries},          # inputs
        {'timeseries_out': timeseries_out},  # outputs
        {},                                  # processor parameters (none)
        opts
    )

def ms4alg_sort(*,timeseries,geom,firings_out,detect_sign,adjacency_radius,detect_threshold=3,opts=None):
    """Run the ``ms4alg.sort`` processor.

    Args:
        timeseries: Path passed as the processor's ``timeseries`` input.
        geom: Path passed as the processor's ``geom`` input.
        firings_out: Path where the processor's ``firings_out`` output is written.
        detect_sign: Forwarded as the ``detect_sign`` parameter.
        adjacency_radius: Forwarded as the ``adjacency_radius`` parameter.
        detect_threshold: Forwarded as the ``detect_threshold`` parameter.
        opts: Options forwarded to ``mlp.runProcess`` (defaults to a fresh
            empty dict on every call).

    Returns:
        Whatever ``mlp.runProcess`` returns (consistent with ``whiten``).
    """
    # A fresh dict per call instead of a shared mutable `{}` default.
    if opts is None:
        opts = {}
    params = {
        'detect_sign': detect_sign,
        'adjacency_radius': adjacency_radius,
        'detect_threshold': detect_threshold,
    }
    return mlp.runProcess(
        'ms4alg.sort',
        {'timeseries': timeseries, 'geom': geom},  # inputs
        # Bug fix: honor the caller's output path instead of the previously
        # hard-coded 'output/firings.mda'.
        {'firings_out': firings_out},
        params,
        opts
    )