In [79]:
from mountainlab_pytools import mlproc as mlp


def bandpass_filter(*, timeseries, timeseries_out, samplerate, freq_min, freq_max, opts={}):
    """Run the ``ephys.bandpass_filter`` processor via ``mlp.runProcess``.

    Parameters
    ----------
    timeseries
        Passed as the processor's ``timeseries`` input.
    timeseries_out
        Passed as the processor's ``timeseries_out`` output.
    samplerate, freq_min, freq_max
        Forwarded unchanged as processor parameters of the same names.
    opts
        Forwarded unchanged as the last argument of ``mlp.runProcess``.

    Returns
    -------
    Whatever ``mlp.runProcess`` returns.
    """
    process_inputs = dict(timeseries=timeseries)
    process_outputs = dict(timeseries_out=timeseries_out)
    process_params = dict(samplerate=samplerate, freq_min=freq_min, freq_max=freq_max)
    return mlp.runProcess('ephys.bandpass_filter', process_inputs, process_outputs, process_params, opts)


def _mask_artifacts(*, timeseries, timeseries_out, threshold=6, chunk_size=2000, num_write_chunks=150, opts={}):
    """Run the ``ephys.mask_out_artifacts`` processor via ``mlp.runProcess``.

    Parameters
    ----------
    timeseries
        Passed as the processor's ``timeseries`` input.
    timeseries_out
        Passed as the processor's ``timeseries_out`` output.
    threshold, chunk_size, num_write_chunks
        Forwarded unchanged as processor parameters of the same names.
    opts
        Forwarded unchanged as the last argument of ``mlp.runProcess``.

    Returns
    -------
    Whatever ``mlp.runProcess`` returns.
    """
    process_inputs = dict(timeseries=timeseries)
    process_outputs = dict(timeseries_out=timeseries_out)
    process_params = dict(
        threshold=threshold,
        chunk_size=chunk_size,
        num_write_chunks=num_write_chunks,
    )
    return mlp.runProcess('ephys.mask_out_artifacts', process_inputs, process_outputs, process_params, opts)


def _whiten(*, timeseries, timeseries_out, opts={}):
    """Run the ``ephys.whiten`` processor via ``mlp.runProcess``.

    ``timeseries`` is passed as the processor input, ``timeseries_out`` as
    its output; no processor parameters are sent. ``opts`` is forwarded
    unchanged. Returns whatever ``mlp.runProcess`` returns.
    """
    process_inputs = dict(timeseries=timeseries)
    process_outputs = dict(timeseries_out=timeseries_out)
    no_params = {}
    return mlp.runProcess('ephys.whiten', process_inputs, process_outputs, no_params, opts)


def ms4alg_sort(*, timeseries, geom, firings_out, detect_sign, adjacency_radius, detect_threshold, detect_interval,
                clip_size, num_workers=None, opts={}):
    """Run the ``ms4alg.sort`` processor via ``mlp.runProcess``.

    Parameters
    ----------
    timeseries
        Passed as the processor's ``timeseries`` input.
    geom
        Passed as the processor's ``geom`` input; omitted from the inputs
        entirely when ``None``.
    firings_out
        Passed as the processor's ``firings_out`` output.
    detect_sign, adjacency_radius, detect_threshold, detect_interval, clip_size
        Forwarded unchanged as processor parameters of the same names.
    num_workers
        Forwarded as the ``num_workers`` parameter. ``None`` (the default)
        resolves to ``os.cpu_count()`` at call time.
    opts
        Forwarded unchanged as the last argument of ``mlp.runProcess``.

    Returns
    -------
    Whatever ``mlp.runProcess`` returns (like the sibling wrappers).
    """
    if num_workers is None:
        # Resolved lazily with a local import: the defining cell does not
        # import ``os``, so a def-time ``os.cpu_count()`` default fails on
        # a fresh kernel.
        import os
        num_workers = os.cpu_count() or 1

    pp = {
        'detect_sign': detect_sign,
        'adjacency_radius': adjacency_radius,
        'detect_threshold': detect_threshold,
        'clip_size': clip_size,
        'detect_interval': detect_interval,
        'num_workers': num_workers,
    }

    inputs = {'timeseries': timeseries}
    if geom is not None:
        inputs['geom'] = geom

    return mlp.runProcess(
        'ms4alg.sort',
        inputs,
        {
            'firings_out': firings_out
        },
        pp,
        opts
    )


def compute_cluster_metrics(*, timeseries, firings, metrics_out, samplerate, opts={}):
    """Compute two sets of metrics and merge them into ``metrics_out``.

    Three processors are run in sequence through ``mlp.runProcess``:

    1. ``ms3.cluster_metrics`` on ``timeseries``/``firings`` with ``samplerate``.
    2. ``ms3.isolation_metrics`` on the same inputs with
       ``compute_bursting_parents`` set to ``'true'``.
    3. ``ms3.combine_cluster_metrics`` on the two intermediate results,
       written to ``metrics_out``.

    The first two calls request their output with ``True`` and the entry of
    the same name is read back from the returned mapping. ``opts`` is
    forwarded unchanged to every call.

    Returns
    -------
    Whatever the final ``mlp.runProcess`` call returns.
    """
    cluster_result = mlp.runProcess(
        'ms3.cluster_metrics',
        dict(timeseries=timeseries, firings=firings),
        dict(cluster_metrics_out=True),
        dict(samplerate=samplerate),
        opts,
    )
    cluster_part = cluster_result['cluster_metrics_out']

    isolation_result = mlp.runProcess(
        'ms3.isolation_metrics',
        dict(timeseries=timeseries, firings=firings),
        dict(metrics_out=True),
        dict(compute_bursting_parents='true'),
        opts,
    )
    isolation_part = isolation_result['metrics_out']

    combined = mlp.runProcess(
        'ms3.combine_cluster_metrics',
        dict(metrics_list=[cluster_part, isolation_part]),
        dict(metrics_out=metrics_out),
        {},
        opts,
    )
    return combined


def add_curation_tags(*, cluster_metrics, output_filename, firing_rate_thresh=0.05,
                      isolation_thresh=0.95, noise_overlap_thresh=0.03, peak_snr_thresh=1.5, opts={}):
    """Run the ``pyms.add_curation_tags`` processor (automated curation).

    Parameters
    ----------
    cluster_metrics
        Passed as the processor's ``metrics`` input.
    output_filename
        Passed as the processor's ``metrics_tagged`` output.
    firing_rate_thresh, isolation_thresh, noise_overlap_thresh, peak_snr_thresh
        Forwarded unchanged as processor parameters of the same names.
    opts
        Forwarded unchanged as the last argument of ``mlp.runProcess``.

    Returns
    -------
    Whatever ``mlp.runProcess`` returns, consistent with the other processor
    wrappers in this notebook (previously the result was discarded).
    """
    return mlp.runProcess(
        'pyms.add_curation_tags',
        {
            'metrics': cluster_metrics
        },
        {
            'metrics_tagged': output_filename
        },
        {
            'firing_rate_thresh': firing_rate_thresh,
            'isolation_thresh': isolation_thresh,
            'noise_overlap_thresh': noise_overlap_thresh,
            'peak_snr_thresh': peak_snr_thresh
        },
        opts
    )

In [104]:
from mountainlab_pytools import mlproc as mlp
import os
import json


def sort_dataset(*,
                 raw_fname=None, filt_fname = None, pre_fname=None, geom_fname=None, params_fname=None,
                 firings_out, filt_out_fname='', pre_out_fname='', metrics_out_fname='', masked_out_fname='',
                 freq_min=300, freq_max=7000, samplerate=30000, detect_sign=1,
                 adjacency_radius=-1, detect_threshold=3, detect_interval=10, clip_size=50,
                 firing_rate_thresh=0.05, isolation_thresh=0.95, noise_overlap_thresh=0.03,
                 peak_snr_thresh=1.5, mask_artifacts='true', whiten='true',
                 mask_threshold=6, mask_chunk_size=2000,
                 mask_num_write_chunks=15, num_workers=os.cpu_count()):
    """
    Custom Sorting Pipeline. It will pre-process, sort, and curate (using ms_taggedcuration pipeline).

    Exactly one processing path is taken, chosen in this order of priority:
    ``raw_fname`` (filter -> mask -> whiten -> sort), ``filt_fname``
    (mask -> whiten -> sort), ``pre_fname`` (sort only). Empty strings for
    any filename argument are treated as "not provided".

    Parameters
    ----------
    raw_fname : INPUT
        MxN raw timeseries array (M = #channels, N = #timepoints). If you input this it will pre-process the data.
    filt_fname : INPUT
        MxN timeseries array (M = #channels, N = #timepoints) that has already been filtered.
    pre_fname : INPUT
        MxN pre-processed timeseries array (M = #channels, N = #timepoints). Use this to analyze already pre-processed data.
    geom_fname : INPUT
        (Optional) geometry file (.csv format).
    params_fname : INPUT
        (Optional) parameter file (.json format), where each key is any of the parameters for this pipeline.
        Any values in this .json file overwrite the corresponding arguments.

    firings_out : OUTPUT
        The filename that will contain the spike data (.mda file); defaults to 'firings.mda' in the input file's directory.
    filt_out_fname : OUTPUT
        Optional filename for the filtered data (just filtered, no whitening).
    masked_out_fname : OUTPUT
        Optional filename for the masked data.
    pre_out_fname : OUTPUT
        Optional filename for the pre-processed data (filtered and whitened).
    metrics_out_fname : OUTPUT
        The optional output filename (.json) for the metrics that will be computed for each unit.

    samplerate : float
        (Optional) The sampling rate in Hz
    freq_min : float
        (Optional) The lower endpoint of the frequency band (Hz)
    freq_max : float
        (Optional) The upper endpoint of the frequency band (Hz)
    adjacency_radius : float
        (Optional) Radius of local sorting neighborhood, corresponding to the geometry file (same units). 0 means each channel is sorted independently. -1 means all channels are included in every neighborhood.
    detect_sign : int
        (Optional) Use 1, -1, or 0 to detect positive peaks, negative peaks, or both, respectively
    detect_threshold : float
        (Optional) Threshold for event detection, corresponding to the input file. So if the input file is normalized to have noise standard deviation 1 (e.g., whitened), then this is in units of std. deviations away from the mean.
    detect_interval : int
        (Optional) The minimum number of timepoints between adjacent spikes detected in the same channel neighborhood.
    clip_size : int
        (Optional) Size of extracted clips or snippets, used throughout
    firing_rate_thresh : float64
        (Optional) firing rate must be above this
    isolation_thresh : float64
        (Optional) isolation must be above this
    noise_overlap_thresh : float64
        (Optional) noise_overlap_thresh must be below this
    peak_snr_thresh : float64
        (Optional) peak snr must be above this
    mask_artifacts : str
        (Optional) if set to 'true', it will mask the large amplitude artifacts, if 'false' it will not. A bool is also accepted.
    whiten : str
        (Optional) if set to 'true', it will whiten the signal (when the input is raw_fname or filt_fname), if 'false' it will not. A bool is also accepted.
    mask_threshold : int
        (Optional) Number of standard deviations away from the mean RSS for the chunk to be considered as artifact.
    mask_chunk_size: int
        This chunk size will be the number of samples that will be set to zero if the RSS of this chunk is above threshold.
    mask_num_write_chunks: int
        How many mask_chunks will be simultaneously written to masked_out_fname (default of 15 here).
    num_workers : int
        (Optional) Number of simultaneous workers (or processes). The default is os.cpu_count().

    Returns
    -------
    bool
        Always ``True`` once every stage has completed.

    Raises
    ------
    Exception
        If a flag is not 'true'/'false', if no input (or both raw and
        pre-processed input) is given, or if an input/parameter file does
        not exist.
    """
    # Validate the flags first so a typo fails before any work is started.
    mask_artifacts = _parse_flag(mask_artifacts, 'mask_artifacts')
    whiten = _parse_flag(whiten, 'whiten')

    # mountainlab passes an empty string for every input/output that was not
    # provided; treat those as None from here on.
    raw_fname = _blank_to_none(raw_fname)
    filt_fname = _blank_to_none(filt_fname)
    pre_fname = _blank_to_none(pre_fname)
    geom_fname = _blank_to_none(geom_fname)
    params_fname = _blank_to_none(params_fname)
    firings_out = _blank_to_none(firings_out)
    filt_out_fname = _blank_to_none(filt_out_fname)
    pre_out_fname = _blank_to_none(pre_out_fname)
    metrics_out_fname = _blank_to_none(metrics_out_fname)
    masked_out_fname = _blank_to_none(masked_out_fname)

    if raw_fname is None and pre_fname is None and filt_fname is None:
        raise Exception('You must input a raw_fname, filt_fname, or a pre_fname!')

    if raw_fname is not None and pre_fname is not None:
        raise Exception('You defined both the raw_fname and the pre_fname, can only use one!')

    params = {'freq_min': freq_min,
              'freq_max': freq_max,
              'samplerate': samplerate,
              'detect_sign': detect_sign,
              'adjacency_radius': adjacency_radius,
              'detect_threshold': detect_threshold,
              'detect_interval': detect_interval,
              'clip_size': clip_size,
              'firing_rate_thresh': firing_rate_thresh,
              'isolation_thresh': isolation_thresh,
              'noise_overlap_thresh': noise_overlap_thresh,
              'peak_snr_thresh': peak_snr_thresh,
              'mask_threshold': mask_threshold,
              'mask_chunk_size': mask_chunk_size,
              'mask_num_write_chunks': mask_num_write_chunks,
              'mask_artifacts': mask_artifacts,
              'whiten': whiten,
              'num_workers': num_workers,
              }

    if params_fname is not None:
        # Fail with a clear message instead of hitting an unbound variable
        # when the parameter file is missing.
        if not os.path.exists(params_fname):
            raise Exception('Dataset parameter file does not exist: %s' % params_fname)

        # Values from the file override the arguments.
        params.update(read_dataset_params(params_fname))

        # The file may carry the flags as JSON booleans or as strings.
        params['mask_artifacts'] = _parse_flag(params['mask_artifacts'], 'mask_artifacts')
        params['whiten'] = _parse_flag(params['whiten'], 'whiten')

    if raw_fname is not None:
        # No pre-processing has been done yet: filter, then mask/whiten.
        _require_timeseries(raw_fname)
        output_dir = os.path.dirname(raw_fname)

        if filt_out_fname is None:
            filt_out_fname = os.path.join(output_dir, 'filt.mda.prv')

        bandpass_filter(
            timeseries=raw_fname,
            timeseries_out=filt_out_fname,
            samplerate=params['samplerate'],
            freq_min=params['freq_min'],
            freq_max=params['freq_max'],
        )

        sort_fname = _mask_and_whiten(
            filtered_fname=filt_out_fname,
            output_dir=output_dir,
            masked_out_fname=masked_out_fname,
            pre_out_fname=pre_out_fname,
            params=params,
        )

    elif filt_fname is not None:
        # The data has already been filtered: only mask/whiten if desired.
        _require_timeseries(filt_fname)
        output_dir = os.path.dirname(filt_fname)

        sort_fname = _mask_and_whiten(
            filtered_fname=filt_fname,
            output_dir=output_dir,
            masked_out_fname=masked_out_fname,
            pre_out_fname=pre_out_fname,
            params=params,
        )

    else:
        # The data has already been pre-processed; sort it as-is.
        _require_timeseries(pre_fname)
        output_dir = os.path.dirname(pre_fname)
        sort_fname = pre_fname

    # Sort
    if firings_out is None:
        firings_out = os.path.join(output_dir, 'firings.mda')

    ms4alg_sort(
        timeseries=sort_fname,
        geom=geom_fname,
        firings_out=firings_out,
        adjacency_radius=params['adjacency_radius'],
        detect_sign=params['detect_sign'],
        detect_threshold=params['detect_threshold'],
        detect_interval=params['detect_interval'],
        clip_size=params['clip_size'],
        num_workers=params['num_workers'],
    )

    temp_metrics = os.path.join(output_dir, 'temp_metrics.json')

    if metrics_out_fname is None:
        metrics_out_fname = os.path.join(output_dir, 'cluster_metrics.json')

    try:
        # Compute cluster metrics into a temporary file ...
        compute_cluster_metrics(
            timeseries=sort_fname,
            firings=firings_out,
            metrics_out=temp_metrics,
            samplerate=params['samplerate'],
        )

        # ... then tag them and write the final metrics file.
        add_curation_tags(
            cluster_metrics=temp_metrics,
            output_filename=metrics_out_fname,
            firing_rate_thresh=params['firing_rate_thresh'],
            isolation_thresh=params['isolation_thresh'],
            noise_overlap_thresh=params['noise_overlap_thresh'],
            peak_snr_thresh=params['peak_snr_thresh'],
        )
    finally:
        # Never leave the intermediate metrics file behind, even on failure.
        if os.path.exists(temp_metrics):
            os.remove(temp_metrics)

    return True


def _parse_flag(value, name):
    """Convert a 'true'/'false' string (or a bool) to a bool; raise otherwise."""
    if isinstance(value, bool):
        return value
    if value == 'true':
        return True
    if value == 'false':
        return False
    raise Exception("%s must be set to 'true' or 'false'!" % name)


def _blank_to_none(value):
    """Map the empty string (mountainlab's "not provided") to None."""
    return None if value == '' else value


def _require_timeseries(fname):
    """Raise if the given timeseries file does not exist on disk."""
    if not os.path.exists(fname):
        raise Exception('The following timeseries does not exist: %s!' % fname)


def _mask_and_whiten(*, filtered_fname, output_dir, masked_out_fname, pre_out_fname, params):
    """Optionally mask artifacts and whiten filtered data; return the file to sort."""
    if params['mask_artifacts']:
        if masked_out_fname is None:
            masked_out_fname = os.path.join(output_dir, 'masked.mda.prv')

        _mask_artifacts(
            timeseries=filtered_fname,
            timeseries_out=masked_out_fname,
            threshold=params['mask_threshold'],
            chunk_size=params['mask_chunk_size'],
            num_write_chunks=params['mask_num_write_chunks'],
        )
        whiten_input = masked_out_fname
    else:
        # Masking disabled: whiten (or sort) the filtered data directly.
        whiten_input = filtered_fname

    if not params['whiten']:
        return whiten_input

    if pre_out_fname is None:
        pre_out_fname = os.path.join(output_dir, 'pre.mda.prv')

    _whiten(
        timeseries=whiten_input,
        timeseries_out=pre_out_fname,
    )
    return pre_out_fname

def read_dataset_params(params_fname):
    """Load a JSON parameter file and return its parsed content.

    The name is first passed through ``mlp.realizeFile``; the path it
    returns is the one that is checked and opened.

    Raises
    ------
    Exception
        If the resolved path does not exist.
    """
    resolved_path = mlp.realizeFile(params_fname)
    if os.path.exists(resolved_path):
        with open(resolved_path) as handle:
            return json.load(handle)
    raise Exception('Dataset parameter file does not exist: ' + resolved_path)


In [81]:
import os
import numpy as np

In [82]:
def find_sub(string, sub):
    '''Return the start indices of every (possibly overlapping) occurrence of sub in string.'''
    positions = []
    total = len(string)
    search_from = 0
    while search_from < total:
        hit = string.find(sub, search_from)
        if hit < 0:
            break
        positions.append(hit)
        # advance by one so overlapping matches are found as well
        # (advance by len(sub) instead to skip overlaps)
        search_from = hit + 1
    return positions


def get_ubuntu_path(filepath):
    """Convert a Windows path ('E:\\dir\\file') to its '/mnt/<drive>/...' form.

    Backslashes are converted to forward slashes and the result is
    normalised with POSIX rules (repeated separators collapse). A path
    without a leading '<letter>:' drive is assumed to be a POSIX path
    already and is only normalised, so the function is safe to apply twice.

    If the resulting path contains parentheses it is wrapped in single
    quotes (``get_windows_filename`` strips them again).

    Parameters
    ----------
    filepath : str
        Windows or POSIX style path.

    Returns
    -------
    str
        The POSIX style path.
    """
    import posixpath  # always POSIX rules, regardless of the host OS

    if not filepath:
        return filepath

    # Convert separators *before* normalising so doubled backslashes collapse.
    unified = filepath.replace('\\', '/')

    has_drive = len(unified) >= 3 and unified[0].isalpha() and unified[1:3] == ':/'
    if has_drive:
        drive_letter = unified[0].lower()
        # Drop the ':' and every separator that follows the drive letter.
        remaining_path = unified[2:].lstrip('/')
        linux_path = posixpath.normpath('/mnt/%s/%s' % (drive_letter, remaining_path))
    else:
        linux_path = posixpath.normpath(unified)

    # add single quotes so linux can understand the special characters
    already_quoted = linux_path.startswith("'") and linux_path.endswith("'")
    if ('(' in linux_path or ')' in linux_path) and not already_quoted:
        linux_path = "'%s'" % linux_path

    return linux_path

def get_windows_filename(filename):
    """Convert a '/mnt/<drive>/...' path back to a Windows path ('E:\\...').

    Surrounding single quotes (as added by ``get_ubuntu_path`` for paths
    with parentheses) are removed first. The path component following the
    first 'mnt' component is taken as the drive letter.

    Parameters
    ----------
    filename : str
        POSIX style path containing an 'mnt/<drive>' prefix.

    Returns
    -------
    str
        The Windows style path with an upper-case drive letter.

    Raises
    ------
    ValueError
        If the path has no 'mnt' component followed by a drive component.
    """
    # remove the single quotes if added
    if len(filename) >= 2 and filename.startswith("'") and filename.endswith("'"):
        filename = filename[1:-1]

    parts = filename.split('/')

    # A plain list lookup is enough here; no need for numpy.
    drive_letter = ''
    if 'mnt' in parts:
        mnt_i = parts.index('mnt')
        if mnt_i + 1 < len(parts):
            drive_letter = parts[mnt_i + 1].upper()

    if not drive_letter:
        raise ValueError("Expected a path of the form '/mnt/<drive>/...', got: %r" % filename)

    remaining = '\\'.join(parts[mnt_i + 2:])

    return '%s:\\%s' % (drive_letter, remaining)

In [144]:
# Directory holding the exported .mda files, converted to its /mnt/<drive> form.
# Previously used datasets are kept below for convenience.
# directory = get_ubuntu_path('E:\\Apollo_D_Drive\\data\\VirtualMazeData\\b6_march_18_1')
# directory = get_ubuntu_path('E:\\Apollo_D_Drive\\data\\VirtualMazeData\\b6_august_18_1\\SimpleCircularTrack')
directory = get_ubuntu_path('E:\\Apollo_D_Drive\\data\\MSData\\whatever')
# directory = get_ubuntu_path('E:\\Apollo_D_Drive\\data\\VirtualMazeData\\b6_august_18_2\\NoiseTesting')
# directory = get_ubuntu_path('E:\\Apollo_D_Drive\\data\\VirtualMazeData\\b6_august_18_2\\SimpleCircularTrack')

# Read the directory once and sort it: os.listdir() returns entries in
# arbitrary order, and the file to sort is later picked by its position
# (file_index) in these lists.
dir_entries = sorted(os.listdir(directory))

raw_fnames = [os.path.join(directory, entry) for entry in dir_entries if '_raw.mda' in entry]
filt_fnames = [os.path.join(directory, entry) for entry in dir_entries if '_filt.mda' in entry]

for file in raw_fnames:
    print(file)

print('-----------------------')

for file in filt_fnames:
    print(file)

-----------------------
/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T1_filt.mda
/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_filt.mda
/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T3_filt.mda
/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T4_filt.mda


In [145]:
# tint = True  -> start from an already filtered file (taken from filt_fnames)
# tint = False -> start from a raw file (taken from raw_fnames)
tint = True

# Position of the file to sort within the selected list.
file_index = 1

source_fnames = filt_fnames if tint else raw_fnames

# Show the selected file as a Windows path.
input_fname = get_windows_filename(source_fnames[file_index])
print(input_fname)

# Drop the extension and the trailing '_<suffix>' (e.g. '_filt' / '_raw') so
# the output names share the input's stem. rsplit leaves the name intact
# when it contains no underscore instead of raising.
mda_basename = os.path.splitext(input_fname)[0].rsplit('_', 1)[0]

# Output files shared by both modes.
masked_out_fname = get_ubuntu_path(mda_basename + '_masked.mda')
firings_out = get_ubuntu_path(mda_basename + '_firings.mda')
pre_out_fname = get_ubuntu_path(mda_basename + '_pre.mda')
metrics_out_fname = get_ubuntu_path(mda_basename + '_metrics.json')

if tint:
    # The filtered file is the pipeline input; no filtered output is written.
    filt_fname = get_ubuntu_path(input_fname)
else:
    # The raw file is the pipeline input; the filtered data is an output.
    filt_out_fname = get_ubuntu_path(mda_basename + '_filt.mda')
    raw_fname = get_ubuntu_path(input_fname)

E:\Apollo_D_Drive\data\MSData\whatever\20181001-1050-1-raw-downstairs_T2_filt.mda


In [146]:
# Sampling rate (Hz) depends on the input mode selected via `tint`.
# (An alternative non-tint rate used previously: int(30e3).)
samplerate = int(48e3) if tint else int(24e3)

# Whitening flag as expected by sort_dataset: 'true' or 'false'.
whiten = 'true'

detect_interval = 10
detect_sign = 0

# Lower detection threshold when the data is whitened, higher otherwise.
detect_threshold = 3 if whiten == 'true' else 30

freq_min = 300
freq_max = 6000
mask_threshold = 6
# One tenth of a second worth of samples per masking chunk.
masked_chunk_size = samplerate // 10
mask_num_write_chunks = 100
clip_size = 50

# Summary of the run configuration.
print('filename: %s ' % os.path.basename(filt_fname if tint else raw_fname))
print('whiten: %s' % whiten)
print('Fs: %d' % samplerate)
print('Tint: %s' % (str(tint).lower()))
print('threshold: %d' % detect_threshold)

filename: 20181001-1050-1-raw-downstairs_T2_filt.mda 
whiten: true
Fs: 48000
Tint: true
threshold: 3


In [147]:
# Keyword arguments that are identical for both input modes.
common_sort_kwargs = dict(
    pre_out_fname=pre_out_fname,
    metrics_out_fname=metrics_out_fname,
    firings_out=firings_out,
    masked_out_fname=masked_out_fname,
    samplerate=samplerate,
    detect_interval=detect_interval,
    detect_sign=detect_sign,
    detect_threshold=detect_threshold,
    freq_min=freq_min,
    freq_max=freq_max,
    mask_threshold=mask_threshold,
    mask_chunk_size=masked_chunk_size,
    mask_num_write_chunks=mask_num_write_chunks,
    whiten=whiten,
    clip_size=clip_size,
)

if tint:
    # Already filtered input: the pipeline starts at the masking stage.
    sort_dataset(filt_fname=filt_fname, **common_sort_kwargs)
else:
    # Raw input: the pipeline filters first and writes the filtered file.
    sort_dataset(raw_fname=raw_fname, filt_out_fname=filt_out_fname, **common_sort_kwargs)

RUNNING: ml-run-process ephys.mask_out_artifacts --inputs timeseries:/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_filt.mda --parameters chunk_size:4800 num_write_chunks:100 threshold:6 --outputs timeseries_out:/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_masked.mda
[34m[ Getting processor spec... ][0m
[34m[ Checking inputs and substituting prvs ... ][0m
[34m[ Computing process signature ... ][0m
[34mProcess signature: c3e5a94b0d47cc5e5c0b32810878e3ac2c282a1c[0m
[34m[ Checking outputs... ][0m
[34m{"timeseries_out":"/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_masked.mda"}[0m
[34mProcessing ouput - /mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_masked.mda[0m
[34mfalse[0m
[34m{"timeseries_out":"/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_masked.mda"}[0m
[34m[ Checking process cache ... ][0m
[34m[ Creating tempor

RUNNING: ml-run-process ms3.cluster_metrics --inputs firings:/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_firings.mda timeseries:/mnt/e/Apollo_D_Drive/data/MSData/whatever/20181001-1050-1-raw-downstairs_T2_pre.mda --parameters samplerate:48000 --outputs cluster_metrics_out:/mnt/e/Apollo_D_Drive/MountainSortTempJs/mountainlab/tmp_short_term/output_cluster_metrics_out_0a999ba82efa122cf04a229ecdd8aad35af64e2c.prv
[34m[ Getting processor spec... ][0m
[34m[ Checking inputs and substituting prvs ... ][0m
[34m[ Computing process signature ... ][0m
[34mProcess signature: 01d6244ac39272b4389f911bcd0cbe9e3e4ce0ff[0m
[34m[ Checking outputs... ][0m
[34m{"cluster_metrics_out":"/mnt/e/Apollo_D_Drive/MountainSortTempJs/mountainlab/tmp_short_term/output_cluster_metrics_out_0a999ba82efa122cf04a229ecdd8aad35af64e2c.prv"}[0m
[34mProcessing ouput - /mnt/e/Apollo_D_Drive/MountainSortTempJs/mountainlab/tmp_short_term/output_cluster_metrics_out_0a999ba82efa122cf04