In [1]:
# Imports
import os

import numpy as np
import spikeextractors as se
import spiketoolkit as st
np.random.seed(0)

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  from collections import Mapping, Set, Iterable


In [15]:
import copy
import spikeextractors as se

class SpikeElement:
    """Common base for the pipeline elements (extractor, preprocessor, sorter).

    Stores the wrapped interface class, an id and a display name, plus a
    private deep copy of ``interface_class.gui_params()`` so that editing one
    element's parameters never touches the list returned by the class.
    """

    def __init__(self, interface_id, interface_class, interface_name):
        self._interface_id, self._interface_class, self._interface_name = (
            interface_id, interface_class, interface_name)
        # Deep copy so per-element edits cannot leak into the class's defaults.
        self._params = copy.deepcopy(interface_class.gui_params())

    interface_id = property(lambda self: self._interface_id,
                            doc="Identifier given to this element at construction.")

    interface_class = property(lambda self: self._interface_class,
                               doc="The wrapped extractor/preprocessor/sorter class.")

    interface_name = property(lambda self: self._interface_name,
                              doc="Human-readable name of the wrapped class.")

    params = property(lambda self: self._params,
                      doc="This element's (mutable) list of GUI parameter dicts.")

    def setup(self):
        """Hook for subclasses; the base implementation does nothing."""

    def run(self, input_payload=None):
        """Execute the element; the base implementation does nothing and returns None."""

    
class Extractor(SpikeElement):
    """Pipeline element that instantiates a recording extractor from its GUI params.

    When the wrapped class reports ``has_locations`` as falsy, the LAST entry of
    ``params`` is not a constructor argument: its value is the path of a probe
    file that is loaded onto the recording after construction.
    """

    def __init__(self, interface_class, interface_id):
        SpikeElement.__init__(self, interface_id, interface_class, interface_class.extractor_name)

    def run(self, input_payload=None):
        """Build and return the recording extractor.

        Parameters
        ----------
        input_payload : optional
            Unused; accepted so every element shares the ``run(payload)`` signature.

        Returns
        -------
        The recording extractor built from ``params``, with the probe file
        applied when the class has no channel locations of its own.
        """
        needs_probe = not self._interface_class.has_locations
        if needs_probe:
            # Read the probe path WITHOUT popping it. Popping mutated the list
            # exposed as `params`: the probe entry vanished, and a second run()
            # would pop (and lose) a real constructor argument.
            constructor_params = self._params[:-1]
            probe_path = self._params[-1]['value']
        else:
            constructor_params = self._params
        params_dict = {param['name']: param['value'] for param in constructor_params}
        recording = self._interface_class(**params_dict)
        if needs_probe:
            # NOTE(review): load_probe_file returns the (possibly channel-remapped)
            # extractor in spikeextractors; fall back to the original object if a
            # version returns None.
            probed_recording = se.load_probe_file(recording, probe_path)
            if probed_recording is not None:
                recording = probed_recording
        return recording
    
class Preprocessor(SpikeElement):
    """Pipeline element that wraps a preprocessing class around a recording."""

    def __init__(self, interface_class, interface_id):
        super().__init__(interface_id, interface_class, interface_class.preprocessor_name)

    def run(self, input_payload):
        """Return ``interface_class(recording=input_payload, **params)``.

        Every GUI param contributes ``name=value``; because the params are
        applied after 'recording', a param with that name would override it.
        """
        kwargs = {'recording': input_payload}
        kwargs.update((entry['name'], entry['value']) for entry in self._params)
        return self._interface_class(**kwargs)

class Sorter(SpikeElement):
    """Pipeline element that runs a spike sorter on a recording.

    ``params[0]`` and ``params[1]`` (the output-folder and parallel settings)
    go to the sorter constructor together with the recording; every remaining
    param is forwarded to ``set_params`` before the sorter is run.
    """

    def __init__(self, interface_class, interface_id):
        super().__init__(interface_id, interface_class, interface_class.sorter_name)

    def run(self, input_payload):
        """Construct the sorter on ``input_payload``, configure it, run it and return its result."""
        folder_param, parallel_param = self._params[0], self._params[1]
        constructor_kwargs = {
            'recording': input_payload,
            folder_param['name']: folder_param['value'],
            parallel_param['name']: parallel_param['value'],
        }
        sorter = self._interface_class(**constructor_kwargs)

        sorter.set_params(**{entry['name']: entry['value'] for entry in self._params[2:]})
        sorter.run()
        return sorter.get_result()




#####EXTRACTOR######
# The dataset directory is produced by the "Write the input/output in the
# MountainSort format" cell further down; on a fresh kernel run that cell first.
DATASET_DIR = 'sample_mountainsort_dataset'
if not os.path.exists(DATASET_DIR):
    raise FileNotFoundError(
        "'{}' not found: run the MountainSort export cell before this one".format(DATASET_DIR))

extractor_class = se.extractorlist.installed_recording_extractor_list[0]
extractor = Extractor(extractor_class, 0)
##User
# NOTE(review): assumes the first GUI param of this extractor is its dataset location — confirm.
extractor.params[0]['value'] = DATASET_DIR
##
recording = extractor.run()

#####PREPROCESSOR######
preprocessor_class = st.preprocessing.preprocessinglist.installed_preprocessers_list[0]
preprocessor = Preprocessor(preprocessor_class, 1)
preprocessed_recording = preprocessor.run(recording)

######SORTER#######
sorter_class = st.sorters.sorterlist.installed_sorter_list[0]
# One id per pipeline stage: extractor 0, preprocessor 1, sorter 2.
sorter = Sorter(sorter_class, 2)
sorting_result = sorter.run(preprocessed_recording)

{'detect_sign': -1, 'adjacency_radius': -1, 'freq_min': 300.0, 'freq_max': 6000.0, 'filter': False, 'curation': True, 'whiten': True, 'clip_size': 50, 'detect_threshold': 3, 'detect_interval': 10, 'noise_overlap_threshold': 0.15}
Using 2 workers.
Using tmpdir: /var/folders/s8/mkz4gjm57z3_4wqbv20t2jcm0000gn/T/tmp5l_5pf9k
Num. workers = 2
Preparing /var/folders/s8/mkz4gjm57z3_4wqbv20t2jcm0000gn/T/tmp5l_5pf9k/timeseries.hdf5...
Preparing neighborhood sorters (M=7, N=600000)...
Neighboorhood of channel 0 has 7 channels.
Neighboorhood of channel 1 has 7 channels.
Detecting events on channel 1 (phase1)...
Detecting events on channel 2 (phase1)...
Elapsed time for detect on neighborhood: 0:00:00.152869
Num events detected on channel 1 (phase1): 576
Elapsed time for detect on neighborhood: 0:00:00.156341
Num events detected on channel 2 (phase1): 685
Computing PCA features for channel 1 (phase1)...
Computing PCA features for channel 2 (phase1)...
Clustering for channel 1 (phase1)...
Found 1 cl

In [None]:
# Does the installed recording extractor at index 4 provide its own channel
# locations? (Extractor.run loads a probe file when `has_locations` is falsy.)
candidate_extractor_class = se.extractorlist.installed_recording_extractor_list[4]
candidate_extractor_class.has_locations

Neighboorhood of channel 2 has 7 channels.
Neighboorhood of channel 3 has 7 channels.
Neighboorhood of channel 4 has 7 channels.
Neighboorhood of channel 5 has 7 channels.
Neighboorhood of channel 6 has 7 channels.


In [None]:
# Properties of the in-memory dataset
num_channels = 7                               # number of recording channels
samplerate = 30000                             # sampling rate
duration = 20                                  # length of the recording; see num_timepoints
num_timepoints = int(samplerate * duration)    # samples per channel
num_units = 5                                  # ground-truth units to simulate
num_events = 20                                # total spikes, shared among the units

In [None]:
# Generate a pure-noise timeseries dataset and a linear geometry
timeseries = np.random.normal(0, 10, (num_channels, num_timepoints))
geom = np.zeros((num_channels, 2))
geom[:, 0] = range(num_channels)

# Define the in-memory recording extractor
RX = se.NumpyRecordingExtractor(timeseries=timeseries, geom=geom, samplerate=samplerate)

# Generate some random events
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_events)))
labels = np.random.randint(1, num_units + 1, size=num_events)

# Define the in-memory sorting extractor
SX = se.NumpySortingExtractor()
for k in range(1, num_units + 1):
    times_k = times[np.where(labels == k)[0]]
    SX.add_unit(unit_id=k, times=times_k)

# Add some features to the sorting extractor. These will be merged and split appropriately during curation.
# Each spike gets a running index that is unique across all units:
#   'f_int'   -> spikes, spikes + 1, ...
#   'f_float' -> spikes + 0.1, spikes + 1.1, ...
spikes = 0
for unit_id in SX.get_unit_ids():
    n_unit_spikes = len(SX.get_unit_spike_train(unit_id))
    SX.set_unit_spike_features(unit_id, feature_name='f_int',
                               value=range(spikes, spikes + n_unit_spikes))
    # Stop at the integer bound (not bound + 0.1) so the float arange yields
    # exactly n_unit_spikes values without floating-point rounding surprises.
    SX.set_unit_spike_features(unit_id, feature_name='f_float',
                               value=np.arange(spikes + .1, spikes + n_unit_spikes))
    spikes += n_unit_spikes

# Features that are not shared across ALL units will not be merged and split correctly (will disappear)
for unit_id in (1, 2, 3):
    SX.set_unit_spike_features(unit_id, feature_name='bad_feature',
                               value=np.repeat(unit_id, len(SX.get_unit_spike_train(unit_id))))

In [None]:
# Demonstrate the API for extracting information
print('Unit ids = {}'.format(SX.get_unit_ids()))
# Don't name this `st`: that is the spiketoolkit module alias from the imports
# cell, and rebinding it breaks later uses of `st.preprocessing` / `st.sorters`.
spike_train = SX.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(spike_train)))

In [None]:
# Now we can curate the results using a CurationSortingExtractor that wraps SX as
# its parent sorting; the following cells compare its curated units against SX.
CSX = se.CurationSortingExtractor(parent_sorting=SX)

In [None]:
# Before any curation, compare the curated view with the parent sorting
print("Curated Unit Ids: {}".format(CSX.get_unit_ids()))
print("Original Unit Ids: {}".format(SX.get_unit_ids()))

print("Curated ST: {}".format(CSX.get_unit_spike_train(1)))
print("Original ST: {}".format(SX.get_unit_spike_train(1)))

In [None]:
# Let's split one unit from the sorting result (this could be two units incorrectly
# clustered as one). The two pieces show up as new units 6 and 7.
CSX.split_unit(unit_id=1, indices=[0, 1])
print("Curated Unit Ids: {}".format(CSX.get_unit_ids()))
print("Original Spike Train: {}".format(SX.get_unit_spike_train(1)))
print("Split Spike Train 1: {}".format(CSX.get_unit_spike_train(6)))
print("Split Spike Train 2: {}".format(CSX.get_unit_spike_train(7)))
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# If the split was incorrect, the two pieces can be merged back together;
# the merged unit is inspected as unit 8 and compared with the original unit 1.
CSX.merge_units(unit_ids=[6, 7])
print("Curated Spike Train: {}".format(CSX.get_unit_spike_train(8)))
print("Original Spike Train: {}".format(SX.get_unit_spike_train(1)))
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Units can also be excluded: drop unit 8 (the re-merged unit 1) since we are unsure about it
excluded_unit_ids = [8]
CSX.exclude_units(unit_ids=excluded_unit_ids)
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Now let's merge 3 and 4 together; the result (unit 9) encapsulates both previous units
CSX.merge_units(unit_ids=[3, 4])
print("Curated Unit Ids: {}".format(CSX.get_unit_ids()))
print("Merged Spike Train: {}".format(CSX.get_unit_spike_train(9)))
expected_merged_train = np.sort(np.concatenate([SX.get_unit_spike_train(u) for u in (3, 4)]))
print("Original Spike Trains concatenated: {}".format(expected_merged_train))
print("\nCuration Tree")
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Now let's merge 2 and 9 together (9 is the earlier merge of 3 and 4)

CSX.merge_units(unit_ids=[2, 9])
print("Curated Unit Ids: " + str(CSX.get_unit_ids()))
print("Merged Spike Train: " + str(CSX.get_unit_spike_train(10)))
# Unit 10 should hold exactly the original spikes of every unit except 1
# (split, re-merged as 8, then excluded) and 5 (untouched so far).
# Concatenate the list directly: np.asarray on spike trains of different
# lengths builds a ragged array, which NumPy >= 1.24 rejects with ValueError.
merged_spike_train = np.sort(np.concatenate(
    [SX.get_unit_spike_train(uid) for uid in SX.get_unit_ids() if uid not in (1, 5)]
))
print("Original Spike Trains concatenated: " + str(merged_spike_train))
print("\nCuration Tree")
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Now let's split unit 5 at the given indices; the pieces are inspected as units 11 and 12

CSX.split_unit(unit_id=5, indices=[0, 1])
print("Curated Unit Ids: {}".format(CSX.get_unit_ids()))
print("Original Spike Train: {}".format(SX.get_unit_spike_train(5)))
print("Split Spike Train 1: {}".format(CSX.get_unit_spike_train(11)))
print("Split Spike Train 2: {}".format(CSX.get_unit_spike_train(12)))
print("\nCuration Tree")
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Finally, we can merge units 10 and 11
# (10 = units 2, 3 and 4; 11 = the split of unit 5 at indices [0, 1])

CSX.merge_units(unit_ids=[10, 11])
print("Curated Unit Ids: " + str(CSX.get_unit_ids()))
print("Merged Spike Train: " + str(CSX.get_unit_spike_train(13)))
# Expected contents of unit 13, rebuilt from the untouched parent sorting.
original_spike_train = np.sort(np.concatenate([
    SX.get_unit_spike_train(2),
    SX.get_unit_spike_train(3),
    SX.get_unit_spike_train(4),
    SX.get_unit_spike_train(5)[[0, 1]],
]))
print("Original Spike Train: " + str(original_spike_train))
print("\nCuration Tree")
for unit_id in CSX.get_unit_ids():
    CSX.printCurationTree(unit_id=unit_id)

In [None]:
# Persist the recording and the curated sorting in MountainSort (.mda) format.
# The pipeline cell earlier in the notebook reads 'sample_mountainsort_dataset',
# so on a fresh kernel this cell has to run before that one.
se.MdaRecordingExtractor.write_recording(
    recording=RX,
    save_path='sample_mountainsort_dataset',
)
se.MdaSortingExtractor.write_sorting(
    sorting=CSX,
    save_path='sample_mountainsort_dataset/firings_true.mda',
)

In [None]:
# Read the dataset back with the Mda extractors: SX2 is an ordinary sorting
# extractor whose units are the curated units written above.
RX2 = se.MdaRecordingExtractor(
    dataset_directory='sample_mountainsort_dataset',
)
SX2 = se.MdaSortingExtractor(
    firings_file='sample_mountainsort_dataset/firings_true.mda',
)

In [None]:
# Compare unit 13 after the .mda round trip with the in-memory curated version
print("New Unit Ids: {}".format(SX2.get_unit_ids()))
print("New Unit Spike Train: {}".format(SX2.get_unit_spike_train(13)))
print("Previous Curated Unit Spike Train: {}".format(CSX.get_unit_spike_train(13)))

In [None]:
# Spike feature names currently held by the curated sorting
curated_feature_names = CSX.get_unit_spike_feature_names()
print(curated_feature_names)

In [None]:
# All features have been merged and split along with the units: inspect them for
# unit 12 (a piece of the split unit 5) and unit 13 (the final merge)
for curated_unit in (12, 13):
    for feature_name in ('f_int', 'f_float'):
        print(CSX.get_unit_spike_features(curated_unit, feature_name))