In [1]:
# Imports
import numpy as np
import spikeextractors as se

# Fix the global NumPy RNG so the synthetic recording and spike events below are reproducible
np.random.seed(seed=0)

In [2]:
# Dimensions of the synthetic in-memory dataset (later cells refer to these names)
num_channels = 7                             # number of recording channels
samplerate = 30000                           # sampling rate (presumably Hz)
duration = 20                                # recording length (presumably seconds)
num_timepoints = int(samplerate * duration)  # samples per channel
num_units = 5                                # unit labels are drawn from 1..num_units
num_events = 20                              # total spike events shared among all units

In [3]:
# Pure Gaussian-noise traces of shape (num_channels, num_timepoints); the electrodes
# sit on a straight line: x = channel index, y = 0
timeseries = np.random.normal(0, 10, (num_channels, num_timepoints))
geom = np.zeros((num_channels, 2))
geom[:, 0] = np.arange(num_channels)

# In-memory recording extractor wrapping the noise traces
RX = se.NumpyRecordingExtractor(timeseries=timeseries, geom=geom, samplerate=samplerate)

# Random event times (sorted, truncated to integer sample indices), each tagged
# with a random unit label in 1..num_units
times = np.sort(np.random.uniform(0, num_timepoints, num_events)).astype(np.int_)
labels = np.random.randint(1, num_units + 1, size=num_events)

# In-memory sorting extractor: one unit per label, holding that label's event times
SX = se.NumpySortingExtractor()
for unit in range(1, num_units + 1):
    SX.addUnit(unit_id=unit, times=times[labels == unit])

In [4]:
# Query the sorting extractor through its public API: the unit ids and the
# number of events assigned to unit 1
print('Unit ids = {}'.format(SX.getUnitIds()))
print('Num. events for unit 1 = {}'.format(len(SX.getUnitSpikeTrain(unit_id=1))))

Unit ids = [1, 2, 3, 4, 5]
Num. events for unit 1 = 5


In [5]:
# Wrap the original sorting in a CuratedSortingExtractor. The curation operations
# below are called on CSX; the later printouts of SX show its units unchanged.
CSX = se.CuratedSortingExtractor(
    parent_sorting=SX,
)

In [6]:
# Before any curation, the curated view lists the same units and spike trains as the original
print('Curated Unit Ids: {}'.format(CSX.getUnitIds()))
print('Original Unit Ids: {}'.format(SX.getUnitIds()))

print('Curated ST: {}'.format(CSX.getUnitSpikeTrain(1)))
print('Original ST: {}'.format(SX.getUnitSpikeTrain(1)))

Curated Unit Ids: [1, 2, 3, 4, 5]
Original Unit Ids: [1, 2, 3, 4, 5]
Curated ST: [206907 220517 331138 430220 574290]
Original ST: [206907 220517 331138 430220 574290]


In [7]:
# Let's split one unit from the sorting result (it could be two units incorrectly
# clustered as one). The spikes at indices 0 and 1 form one new unit and the
# remaining spikes form another. The extractor assigns the new ids (6 and 7 here).
CSX.splitUnit(unit_id=1, indices=[0, 1])
print('Curated Unit Ids: {}'.format(CSX.getUnitIds()))
print('Original Spike Train: {}'.format(SX.getUnitSpikeTrain(1)))
print('Split Spike Train 1: {}'.format(CSX.getUnitSpikeTrain(6)))
print('Split Spike Train 2: {}'.format(CSX.getUnitSpikeTrain(7)))
for curated_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=curated_id)

Curated Unit Ids: [2, 3, 4, 5, 6, 7]
Original Spike Train: [206907 220517 331138 430220 574290]
Split Spike Train 1: [206907 220517]
Split Spike Train 2: [331138 430220 574290]
2

3

4

5

6
^-------1

7
^-------1



In [8]:
# A wrong split can be undone by merging the two halves again. The merge creates a
# new unit (id 8 here) whose spike train matches the original unit 1.
CSX.mergeUnits(unit_ids=[6, 7])
print('Curated Spike Train: {}'.format(CSX.getUnitSpikeTrain(8)))
print('Original Spike Train: {}'.format(SX.getUnitSpikeTrain(1)))
for curated_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=curated_id)

Curated Spike Train: [206907 220517 331138 430220 574290]
Original Spike Train: [206907 220517 331138 430220 574290]
2

3

4

5

8
^-------6
	^-------1
^-------7
	^-------1



In [9]:
# Units can also be dropped: exclude unit 8, since we are unsure about it.
# It then disappears from the curated unit ids.
CSX.excludeUnits(unit_ids=[8])
for curated_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=curated_id)

2

3

4

5



In [10]:
# Merge units 3 and 4. This creates a new unit (id 9 here) that contains the spikes
# of both; compare it with the sorted concatenation of the two original trains.
CSX.mergeUnits(unit_ids=[3, 4])
print('Curated Unit Ids: {}'.format(CSX.getUnitIds()))
print('Merged Spike Train: {}'.format(CSX.getUnitSpikeTrain(9)))
print('Original Spike Trains concatenated: {}'.format(
    np.sort(np.concatenate([SX.getUnitSpikeTrain(3), SX.getUnitSpikeTrain(4)]))))
print('\nCuration Tree')
for curated_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=curated_id)

Curated Unit Ids: [2, 5, 9]
Merged Spike Train: [183155 210132 220886 398518 445947 477836 507142]
Original Spike Trains concatenated: [183155 210132 220886 398518 445947 477836 507142]

Curation Tree
2

5

9
^-------3
^-------4



In [11]:
# Now let's merge 2 and 9 together (9 is itself the merge of 3 and 4). The result
# is a new unit (id 10 here) that holds the spikes of original units 2, 3 and 4.
CSX.mergeUnits(unit_ids=[2, 9])
print("Curated Unit Ids: " + str(CSX.getUnitIds()))
print("Merged Spike Train: " + str(CSX.getUnitSpikeTrain(10)))
# Rebuild the expected train from the original units 2, 3 and 4. The per-unit
# arrays can have different lengths, so concatenate the list directly. Wrapping
# the list in np.asarray first would build a ragged array, which NumPy >= 1.24
# rejects with a ValueError.
merged_spike_train = np.sort(np.concatenate(
    [SX.getUnitSpikeTrain(unit_id) for unit_id in (2, 3, 4)]
))
print("Original Spike Trains concatenated: " + str(merged_spike_train))
print("\nCuration Tree")
for unit_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=unit_id)

Curated Unit Ids: [5, 10]
Merged Spike Train: [183155 210132 220886 327869 398518 436875 445947 477836 507142 525257]
Original Spike Trains concatenated: [183155 210132 220886 327869 398518 436875 445947 477836 507142 525257]

Curation Tree
5

10
^-------2
^-------9
	^-------3
	^-------4



In [12]:
# Split unit 5. The spikes at indices 0 and 1 become one new unit (id 11 here) and
# the remaining spikes become another (id 12 here).
CSX.splitUnit(unit_id=5, indices=[0, 1])
print('Curated Unit Ids: {}'.format(CSX.getUnitIds()))
print('Original Spike Train: {}'.format(SX.getUnitSpikeTrain(5)))
print('Split Spike Train 1: {}'.format(CSX.getUnitSpikeTrain(11)))
print('Split Spike Train 2: {}'.format(CSX.getUnitSpikeTrain(12)))
print('\nCuration Tree')
for curated_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=curated_id)

Curated Unit Ids: [10, 11, 12]
Original Spike Train: [168716 256926 272397 318528 470153]
Split Spike Train 1: [168716 256926]
Split Spike Train 2: [272397 318528 470153]

Curation Tree
10
^-------2
^-------9
	^-------3
	^-------4

11
^-------5

12
^-------5



In [13]:
# Finally, we can merge units 10 and 11. The result (id 13 here) combines the
# 2 + 9 (= 2 + 3 + 4) merge with the part of unit 5 that was split off into unit 11.
CSX.mergeUnits(unit_ids=[10, 11])
print("Curated Unit Ids: " + str(CSX.getUnitIds()))
print("Merged Spike Train: " + str(CSX.getUnitSpikeTrain(13)))
# Expected train: every spike of original units 2, 3 and 4, plus the spikes of
# unit 5 at indices [0, 1] (the same indices passed to splitUnit above).
original_spike_train = np.sort(np.concatenate((
    SX.getUnitSpikeTrain(2),
    SX.getUnitSpikeTrain(3),
    SX.getUnitSpikeTrain(4),
    SX.getUnitSpikeTrain(5)[[0, 1]],
)))
print("Original Spike Train: " + str(original_spike_train))
print("\nCuration Tree")
for unit_id in CSX.getUnitIds():
    CSX.printCurationTree(unit_id=unit_id)

Curated Unit Ids: [12, 13]
Merged Spike Train: [168716 183155 210132 220886 256926 327869 398518 436875 445947 477836
 507142 525257]
Original Spike Train: [168716 183155 210132 220886 256926 327869 398518 436875 445947 477836
 507142 525257]

Curation Tree
12
^-------5

13
^-------10
	^-------2
	^-------9
		^-------3
		^-------4
^-------11
	^-------5



In [14]:
# Export the recording and the *curated* sorting (CSX) in the MountainSort MDA format.
# NOTE(review): the firings file is named "firings_true.mda" even though it holds
# curated units, not ground truth. Consider renaming it (and the read-back cell).
se.MdaRecordingExtractor.writeRecording(
    recording=RX,
    save_path='sample_mountainsort_dataset',
)
se.MdaSortingExtractor.writeSorting(
    sorting=CSX,
    save_path='sample_mountainsort_dataset/firings_true.mda',
)

In [15]:
# Load the exported files back with the MDA extractors. SX2 is a plain sorting
# extractor that now carries the curated units.
RX2 = se.MdaRecordingExtractor(
    dataset_directory='sample_mountainsort_dataset',
)
SX2 = se.MdaSortingExtractor(
    firings_file='sample_mountainsort_dataset/firings_true.mda',
)

In [16]:
# The reloaded sorting should reproduce the curated unit ids and spike trains
print('New Unit Ids: {}'.format(SX2.getUnitIds()))
print('New Unit Spike Train: {}'.format(SX2.getUnitSpikeTrain(13)))
print('Previous Curated Unit Spike Train: {}'.format(CSX.getUnitSpikeTrain(13)))

New Unit Ids: [12 13]
New Unit Spike Train: [168716 183155 210132 220886 256926 327869 398518 436875 445947 477836
 507142 525257]
Previous Curated Unit Spike Train: [168716 183155 210132 220886 256926 327869 398518 436875 445947 477836
 507142 525257]
