Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions neo/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def setUp(self):

class TestUtilsWithoutProxyObjects(unittest.TestCase):
def test__get_events(self):
starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s)
starts_1.annotate(event_type='trial start', pick='me')
starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s,
labels=['label1', 'label2', 'label3'],
name='pick_me')
starts_1.annotate(event_type='trial start')
starts_1.array_annotate(trial_id=[1, 2, 3])

stops_1 = Event(times=[5.5, 14.9, 30.1] * pq.s)
Expand All @@ -56,9 +58,9 @@ def test__get_events(self):
block = Block()
block.segments = [seg, seg2]

# test getting one whole event via annotation
# test getting one whole event via annotation or attribute
extracted_starts1 = get_events(seg, event_type='trial start')
extracted_starts1b = get_events(block, pick='me')
extracted_starts1b = get_events(block, name='pick_me')

self.assertEqual(len(extracted_starts1), 1)
self.assertEqual(len(extracted_starts1b), 1)
Expand Down Expand Up @@ -128,6 +130,20 @@ def test__get_events(self):
assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2]))
self.assertIsInstance(trials_1_2.array_annotations, ArrayDict)

# test selecting event times by label
trials_1_2 = get_events(block, labels=['label1', 'label2'])

self.assertEqual(len(trials_1_2), 1)

trials_1_2 = trials_1_2[0]

self.assertEqual(starts_1.name, trials_1_2.name)
self.assertEqual(starts_1.description, trials_1_2.description)
self.assertEqual(starts_1.file_origin, trials_1_2.file_origin)
self.assertEqual(starts_1.annotations['event_type'], trials_1_2.annotations['event_type'])
assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2]))
self.assertIsInstance(trials_1_2.array_annotations, ArrayDict)

# test getting more than one event time of more than one event
trials_1_2b = get_events(block, trial_id=[1, 2])

Expand Down
35 changes: 23 additions & 12 deletions neo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import neo
import copy
import warnings
import inspect
import numpy as np
import quantities as pq

Expand Down Expand Up @@ -49,23 +48,26 @@ def get_events(container, **properties):

Example:
--------
>>> import neo
>>> from neo.utils import get_events
>>> import quantities as pq
>>> event = neo.Event(times=[0.5, 10.0, 25.2] * pq.s)
>>> event.annotate(event_type='trial start')
>>> event.array_annotate(trial_id=[1, 2, 3])
>>> seg = neo.Segment()
>>> seg.events = [event]

# Will return a list with the complete event object
>>> get_events(seg, properties={'event_type': 'trial start'})
>>> get_events(seg, event_type='trial start')

# Will return an empty list
>>> get_events(seg, properties={'event_type': 'trial stop'})
>>> get_events(seg, event_type='trial stop')

# Will return a list with an Event object, but only with trial 2
>>> get_events(seg, properties={'trial_id': 2})
>>> get_events(seg, trial_id=2)

# Will return a list with an Event object, but only with trials 1 and 2
>>> get_events(seg, properties={'trial_id': [1, 2]})
>>> get_events(seg, trial_id=[1, 2])
"""
if isinstance(container, neo.Segment):
return _get_from_list(container.events, prop=properties)
Expand Down Expand Up @@ -118,24 +120,27 @@ def get_epochs(container, **properties):

Example:
--------
>>> import neo
>>> from neo.utils import get_epochs
>>> import quantities as pq
>>> epoch = neo.Epoch(times=[0.5, 10.0, 25.2] * pq.s,
durations = [100, 100, 100] * pq.ms)
>>> epoch.annotate(event_type='complete trial',
trial_id=[1, 2, 3])
... durations=[100, 100, 100] * pq.ms,
... epoch_type='complete trial')
>>> epoch.array_annotate(trial_id=[1, 2, 3])
>>> seg = neo.Segment()
>>> seg.epochs = [epoch]

# Will return a list with the complete event object
>>> get_epochs(seg, prop={'epoch_type': 'complete trial'})
>>> get_epochs(seg, epoch_type='complete trial')

# Will return an empty list
>>> get_epochs(seg, prop={'epoch_type': 'error trial'})
>>> get_epochs(seg, epoch_type='error trial')

# Will return a list with an Event object, but only with trial 2
>>> get_epochs(seg, prop={'trial_id': 2})
>>> get_epochs(seg, trial_id=2)

# Will return a list with an Event object, but only with trials 1 and 2
>>> get_epochs(seg, prop={'trial_id': [1, 2]})
>>> get_epochs(seg, trial_id=[1, 2])
"""
if isinstance(container, neo.Segment):
return _get_from_list(container.epochs, prop=properties)
Expand Down Expand Up @@ -236,6 +241,12 @@ def _get_valid_ids(obj, annotation_key, annotation_value):
if annotation_key in obj.annotations and obj.annotations[annotation_key] == annotation_value:
valid_mask = np.ones(obj.shape)

elif annotation_key == 'labels':
# wrap annotation value to be list
if not type(annotation_value) in [list, np.ndarray]:
annotation_value = [annotation_value]
valid_mask = np.in1d(obj.labels, annotation_value)

elif annotation_key in obj.array_annotations:
# wrap annotation value to be list
if not type(annotation_value) in [list, np.ndarray]:
Expand Down