diff --git a/neo/test/test_utils.py b/neo/test/test_utils.py index 1623e5029..55c85d626 100644 --- a/neo/test/test_utils.py +++ b/neo/test/test_utils.py @@ -8,7 +8,7 @@ import quantities as pq from neo.rawio.examplerawio import ExampleRawIO from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy, - EventProxy, EpochProxy) + EventProxy, EpochProxy) from neo.core.dataobject import ArrayDict from neo.core import (Block, Segment, AnalogSignal, IrregularlySampledSignal, @@ -20,7 +20,10 @@ assert_same_attributes, assert_same_annotations) -from neo.utils import (get_events, get_epochs, add_epoch, match_events, cut_block_by_epochs) +from neo.utils import (get_events, get_epochs, add_epoch, match_events, + cut_block_by_epochs, merge_anasiglist) + +from copy import copy class BaseProxyTest(unittest.TestCase): @@ -465,6 +468,65 @@ def test__cut_block_by_epochs(self): epoch2.time_shift(- epoch.times[0]).time_slice(t_start=0 * pq.s, t_stop=epoch.durations[0])) + def test_merge_anasiglist(self): + baselist = [AnalogSignal(np.arange(55.0).reshape((11, 5)), + units="mV", + sampling_rate=1 * pq.kHz)] * 2 + + # Sanity of inputs + self.assertRaises(TypeError, merge_anasiglist, baselist[0]) + self.assertRaises(TypeError, merge_anasiglist, baselist, axis=1.0) + self.assertRaises(TypeError, merge_anasiglist, baselist, axis=9) + self.assertRaises(ValueError, merge_anasiglist, []) + self.assertRaises(ValueError, merge_anasiglist, [baselist[0]]) + + # Different units + wrongunits = AnalogSignal(np.arange(55.0).reshape((11, 5)), + units="uV", + sampling_rate=1 * pq.kHz) + analist = baselist + [wrongunits] + self.assertRaises(ValueError, merge_anasiglist, analist) + + # Different sampling rate + wrongsampl = AnalogSignal(np.arange(55.0).reshape((11, 5)), + units="mV", + sampling_rate=100 * pq.kHz) + analist = baselist + [wrongsampl] + self.assertRaises(ValueError, merge_anasiglist, analist) + + # Different t_start + wrongstart = AnalogSignal(np.arange(55.0).reshape((11, 5)), + t_start=10 * pq.s, + units="mV", + sampling_rate=1 * pq.kHz) + analist = baselist + [wrongstart] + self.assertRaises(ValueError, merge_anasiglist, analist) + + # Different shape + wrongshape = AnalogSignal(np.arange(50.0).reshape((10, 5)), + units="mV", + sampling_rate=1 * pq.kHz) + analist = baselist + [wrongshape] + self.assertRaises(ValueError, merge_anasiglist, analist) + + # Different shape + wrongshape = AnalogSignal(np.arange(50.0).reshape((5, 10)), + units="mV", + sampling_rate=1 * pq.kHz) + analist = baselist + [wrongshape] + self.assertRaises(ValueError, merge_anasiglist, analist, axis=0) + + # Check that the generated analogsignals are the corresponding ones + for axis in [0, 1]: + ana = np.concatenate((np.arange(55.0).reshape((11, 5)), + np.arange(55.0).reshape((11, 5))), + axis=axis) + signal1 = AnalogSignal(ana, units="mV", sampling_rate=1 * pq.kHz) + signal2 = merge_anasiglist(copy(baselist), axis=axis) + assert_arrays_equal(signal1.magnitude, signal2.magnitude) + assert_same_attributes(signal1, signal2) + assert_same_annotations(signal1, signal2) + class TestUtilsWithProxyObjects(BaseProxyTest): def test__get_events(self): diff --git a/neo/utils.py b/neo/utils.py index 84eac6193..92edc0e20 100644 --- a/neo/utils.py +++ b/neo/utils.py @@ -10,6 +10,86 @@ import quantities as pq +def merge_anasiglist(anasiglist, axis=1): + """ + Merges neo.AnalogSignal objects into a single object. + + Units, sampling_rate, t_start, t_stop and signals shape must be the same + for all signals. Otherwise a ValueError is raised. + + Parameters + ---------- + anasiglist: list of neo.AnalogSignal + list of analogsignals that will be merged + axis: int + axis along which to perform the merging + `axis = 1` corresponds to stacking the analogsignals + `axis = 0` corresponds to concatenating the analogsignals + Default: 1 + + Returns + ------- + merged_anasig: neo.AnalogSignal + merged output signal + """ + + # Sanity of inputs + if not isinstance(anasiglist, list): + raise TypeError('anasiglist must be a list') + if not isinstance(axis, int) or axis not in [0, 1]: + raise TypeError('axis must be 0 or 1') + if len(anasiglist) == 0: + raise ValueError('Empty list! nothing to be merged') + if len(anasiglist) == 1: + raise ValueError('Passed list of length 1, nothing to be merged') + + # Check units, sampling_rate, t_start, t_stop and signal shape + for anasig in anasiglist: + if not anasiglist[0].units == anasig.units: + raise ValueError('Units must be the same for all signals') + if not anasiglist[0].sampling_rate == anasig.sampling_rate: + raise ValueError('Sampling rate must be the same for all signals') + if not anasiglist[0].t_start == anasig.t_start: + raise ValueError('t_start must be the same for all signals') + if axis == 0: + if not anasiglist[0].magnitude.shape[1] == \ + anasig.magnitude.shape[1]: + raise ValueError('Analogsignals appear to contain different ' + 'number of channels!') + if axis == 1: + if not anasiglist[0].magnitude.shape[0] == \ + anasig.magnitude.shape[0]: + raise ValueError('Cannot merge signals of different length.') + + # Initialize the arrays + anasig0 = anasiglist.pop(0) + data_array = anasig0.magnitude + sr = anasig0.sampling_rate + t_start = anasig0.t_start + units = anasig0.units + + # Get the full array annotations + for anasig in anasiglist: + anasig0.array_annotations = anasig0._merge_array_annotations(anasig) + + array_annot = anasig0.array_annotations + del anasig0 + + while len(anasiglist) != 0: + anasig = anasiglist.pop(0) + data_array = np.concatenate((data_array, anasig.magnitude), + axis=axis) + del anasig + + # Create new analogsignal object to contain the analogsignals + merged_anasig = neo.AnalogSignal(data_array, + sampling_rate=sr, + t_start=t_start, + units=units, + array_annotations=array_annot) + return merged_anasig + + def get_events(container, **properties): """ This function returns a list of Event objects, corresponding to given