diff --git a/neo/test/test_utils.py b/neo/test/test_utils.py index 7fbb6bea0..bf00a7584 100644 --- a/neo/test/test_utils.py +++ b/neo/test/test_utils.py @@ -32,8 +32,10 @@ def setUp(self): class TestUtilsWithoutProxyObjects(unittest.TestCase): def test__get_events(self): - starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s) - starts_1.annotate(event_type='trial start', pick='me') + starts_1 = Event(times=[0.5, 10.0, 25.2] * pq.s, + labels=['label1', 'label2', 'label3'], + name='pick_me') + starts_1.annotate(event_type='trial start') starts_1.array_annotate(trial_id=[1, 2, 3]) stops_1 = Event(times=[5.5, 14.9, 30.1] * pq.s) @@ -56,9 +58,9 @@ def test__get_events(self): block = Block() block.segments = [seg, seg2] - # test getting one whole event via annotation + # test getting one whole event via annotation or attribute extracted_starts1 = get_events(seg, event_type='trial start') - extracted_starts1b = get_events(block, pick='me') + extracted_starts1b = get_events(block, name='pick_me') self.assertEqual(len(extracted_starts1), 1) self.assertEqual(len(extracted_starts1b), 1) @@ -128,6 +130,20 @@ def test__get_events(self): assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2])) self.assertIsInstance(trials_1_2.array_annotations, ArrayDict) + # test selecting event times by label + trials_1_2 = get_events(block, labels=['label1', 'label2']) + + self.assertEqual(len(trials_1_2), 1) + + trials_1_2 = trials_1_2[0] + + self.assertEqual(starts_1.name, trials_1_2.name) + self.assertEqual(starts_1.description, trials_1_2.description) + self.assertEqual(starts_1.file_origin, trials_1_2.file_origin) + self.assertEqual(starts_1.annotations['event_type'], trials_1_2.annotations['event_type']) + assert_arrays_equal(trials_1_2.array_annotations['trial_id'], np.array([1, 2])) + self.assertIsInstance(trials_1_2.array_annotations, ArrayDict) + # test getting more than one event time of more than one event trials_1_2b = get_events(block, trial_id=[1, 2]) diff --git a/neo/utils.py b/neo/utils.py index 7cfd08bd4..31564f7a5 100644 --- a/neo/utils.py +++ b/neo/utils.py @@ -7,7 +7,6 @@ import neo import copy import warnings -import inspect import numpy as np import quantities as pq @@ -49,6 +48,9 @@ def get_events(container, **properties): Example: -------- + >>> import neo + >>> from neo.utils import get_events + >>> import quantities as pq >>> event = neo.Event(times=[0.5, 10.0, 25.2] * pq.s) >>> event.annotate(event_type='trial start') >>> event.array_annotate(trial_id=[1, 2, 3]) @@ -56,16 +58,16 @@ def get_events(container, **properties): >>> seg.events = [event] # Will return a list with the complete event object - >>> get_events(seg, properties={'event_type': 'trial start'}) + >>> get_events(seg, event_type='trial start') # Will return an empty list - >>> get_events(seg, properties={'event_type': 'trial stop'}) + >>> get_events(seg, event_type='trial stop') # Will return a list with an Event object, but only with trial 2 - >>> get_events(seg, properties={'trial_id': 2}) + >>> get_events(seg, trial_id=2) # Will return a list with an Event object, but only with trials 1 and 2 - >>> get_events(seg, properties={'trial_id': [1, 2]}) + >>> get_events(seg, trial_id=[1, 2]) """ if isinstance(container, neo.Segment): return _get_from_list(container.events, prop=properties) @@ -118,24 +120,27 @@ def get_epochs(container, **properties): Example: -------- + >>> import neo + >>> from neo.utils import get_epochs + >>> import quantities as pq >>> epoch = neo.Epoch(times=[0.5, 10.0, 25.2] * pq.s, - durations = [100, 100, 100] * pq.ms) - >>> epoch.annotate(event_type='complete trial', - trial_id=[1, 2, 3]) + ... durations=[100, 100, 100] * pq.ms, + ... epoch_type='complete trial') + >>> epoch.array_annotate(trial_id=[1, 2, 3]) >>> seg = neo.Segment() >>> seg.epochs = [epoch] # Will return a list with the complete event object - >>> get_epochs(seg, prop={'epoch_type': 'complete trial'}) + >>> get_epochs(seg, epoch_type='complete trial') # Will return an empty list - >>> get_epochs(seg, prop={'epoch_type': 'error trial'}) + >>> get_epochs(seg, epoch_type='error trial') # Will return a list with an Event object, but only with trial 2 - >>> get_epochs(seg, prop={'trial_id': 2}) + >>> get_epochs(seg, trial_id=2) # Will return a list with an Event object, but only with trials 1 and 2 - >>> get_epochs(seg, prop={'trial_id': [1, 2]}) + >>> get_epochs(seg, trial_id=[1, 2]) """ if isinstance(container, neo.Segment): return _get_from_list(container.epochs, prop=properties) @@ -236,6 +241,12 @@ def _get_valid_ids(obj, annotation_key, annotation_value): if annotation_key in obj.annotations and obj.annotations[annotation_key] == annotation_value: valid_mask = np.ones(obj.shape) + elif annotation_key == 'labels': + # wrap annotation value to be list + if not type(annotation_value) in [list, np.ndarray]: + annotation_value = [annotation_value] + valid_mask = np.in1d(obj.labels, annotation_value) + elif annotation_key in obj.array_annotations: # wrap annotation value to be list if not type(annotation_value) in [list, np.ndarray]: