diff --git a/neo/core/analogsignal.py b/neo/core/analogsignal.py index 04c89ba81..476b732c2 100644 --- a/neo/core/analogsignal.py +++ b/neo/core/analogsignal.py @@ -29,7 +29,7 @@ import numpy as np import quantities as pq -from neo.core.baseneo import BaseNeo, MergeError, merge_annotations +from neo.core.baseneo import BaseNeo, MergeError, merge_annotations, intersect_annotations from neo.core.dataobject import DataObject from copy import copy, deepcopy @@ -657,3 +657,136 @@ def rectify(self, **kwargs): rectified_signal.array_annotations = self.array_annotations.copy() return rectified_signal + + def concatenate(self, *signals, overwrite=False, padding=False): + """ + Concatenate multiple neo.AnalogSignal objects across time. + + Units, sampling_rate and number of signal traces must be the same + for all signals. Otherwise a ValueError is raised. + Note that timestamps of concatenated signals might shift in oder to + align the sampling times of all signals. + + Parameters + ---------- + signals: neo.AnalogSignal objects + AnalogSignals that will be concatenated + overwrite : bool + If True, samples of the earlier (lower index in `signals`) + signals are overwritten by that of later (higher index in `signals`) + signals. + If False, samples of the later are overwritten by earlier signal. + Default: False + padding : bool, scalar quantity + Sampling values to use as padding in case signals do not overlap. + If False, do not apply padding. Signals have to align or + overlap. If True, signals will be padded using + np.NaN as pad values. If a scalar quantity is provided, this + will be used for padding. The other signal is moved + forward in time by maximum one sampling period to + align the sampling times of both signals. + Default: False + + Returns + ------- + signal: neo.AnalogSignal + concatenated output signal + """ + + # Sanity of inputs + if not hasattr(signals, '__iter__'): + raise TypeError('signals must be iterable') + if not all([isinstance(a, AnalogSignal) for a in signals]): + raise TypeError('Entries of anasiglist have to be of type neo.AnalogSignal') + if len(signals) == 0: + return self + + signals = [self] + list(signals) + + # Check required common attributes: units, sampling_rate and shape[-1] + shared_attributes = ['units', 'sampling_rate'] + attribute_values = [tuple((getattr(anasig, attr) for attr in shared_attributes)) + for anasig in signals] + # add shape dimensions that do not relate to time + attribute_values = [(attribute_values[i] + (signals[i].shape[1:],)) + for i in range(len(signals))] + if not all([attrs == attribute_values[0] for attrs in attribute_values]): + raise MergeError( + f'AnalogSignals have to share {shared_attributes} attributes to be concatenated.') + units, sr, shape = attribute_values[0] + + # find gaps between Analogsignals + combined_time_ranges = self._concatenate_time_ranges( + [(s.t_start, s.t_stop) for s in signals]) + missing_time_ranges = self._invert_time_ranges(combined_time_ranges) + if len(missing_time_ranges): + diffs = np.diff(np.asarray(missing_time_ranges), axis=1) + else: + diffs = [] + + if padding is False and any(diffs > signals[0].sampling_period): + raise MergeError(f'Signals are not continuous. Can not concatenate signals with gaps. ' + f'Please provide a padding value.') + if padding is not False: + logger.warning('Signals will be padded using {}.'.format(padding)) + if padding is True: + padding = np.NaN * units + if isinstance(padding, pq.Quantity): + padding = padding.rescale(units).magnitude + else: + raise MergeError('Invalid type of padding value. Please provide a bool value ' + 'or a quantities object.') + + t_start = min([a.t_start for a in signals]) + t_stop = max([a.t_stop for a in signals]) + n_samples = int(np.rint(((t_stop - t_start) * sr).rescale('dimensionless').magnitude)) + shape = (n_samples,) + shape + + # Collect attributes and annotations across all concatenated signals + kwargs = {} + common_annotations = signals[0].annotations + common_array_annotations = signals[0].array_annotations + for anasig in signals[1:]: + common_annotations = intersect_annotations(common_annotations, anasig.annotations) + common_array_annotations = intersect_annotations(common_array_annotations, + anasig.array_annotations) + + kwargs['annotations'] = common_annotations + kwargs['array_annotations'] = common_array_annotations + + for name in ("name", "description", "file_origin"): + attr = [getattr(s, name) for s in signals] + if all([a == attr[0] for a in attr]): + kwargs[name] = attr[0] + else: + kwargs[name] = f'concatenation ({attr})' + + conc_signal = AnalogSignal(np.full(shape=shape, fill_value=padding, dtype=signals[0].dtype), + sampling_rate=sr, t_start=t_start, units=units, **kwargs) + + if not overwrite: + signals = signals[::-1] + while len(signals) > 0: + conc_signal.splice(signals.pop(0), copy=False) + + return conc_signal + + def _concatenate_time_ranges(self, time_ranges): + time_ranges = sorted(time_ranges) + new_ranges = time_ranges[:1] + for t_start, t_stop in time_ranges[1:]: + # time range are non continuous -> define new range + if t_start > new_ranges[-1][1]: + new_ranges.append((t_start, t_stop)) + # time range is continuous -> extend time range + elif t_stop > new_ranges[-1][1]: + new_ranges[-1] = (new_ranges[-1][0], t_stop) + return new_ranges + + def _invert_time_ranges(self, time_ranges): + i = 0 + new_ranges = [] + while i < len(time_ranges) - 1: + new_ranges.append((time_ranges[i][1], time_ranges[i + 1][0])) + i += 1 + return new_ranges diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index a11c2ed62..f9e2120d6 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -3,6 +3,7 @@ used by all :module:`neo.core` classes. """ +from copy import deepcopy from datetime import datetime, date, time, timedelta from decimal import Decimal import logging @@ -109,6 +110,37 @@ def merge_annotations(A, *Bs): return merged +def intersect_annotations(A, B): + """ + Identify common entries in dictionaries A and B + and return these in a separate dictionary. + + Entries have to share key as well as value to be + considered common. + + Parameters + ---------- + A, B : dict + Dictionaries to merge. + """ + + result = {} + + for key in set(A.keys()) & set(B.keys()): + v1, v2 = A[key], B[key] + assert type(v1) == type(v2), 'type({}) {} != type({}) {}'.format(v1, type(v1), + v2, type(v2)) + if isinstance(v1, dict) and v1 == v2: + result[key] = deepcopy(v1) + elif isinstance(v1, str) and v1 == v2: + result[key] = A[key] + elif isinstance(v1, list) and v1 == v2: + result[key] = deepcopy(v1) + elif isinstance(v1, np.ndarray) and all(v1 == v2): + result[key] = deepcopy(v1) + return result + + def _reference_name(class_name): """ Given the name of a class, return an attribute name to be used for diff --git a/neo/core/basesignal.py b/neo/core/basesignal.py index f03868897..7d271b676 100644 --- a/neo/core/basesignal.py +++ b/neo/core/basesignal.py @@ -21,7 +21,7 @@ import numpy as np import quantities as pq -from neo.core.baseneo import BaseNeo, MergeError, merge_annotations +from neo.core.baseneo import MergeError, merge_annotations from neo.core.dataobject import DataObject, ArrayDict from neo.core.channelindex import ChannelIndex @@ -282,11 +282,52 @@ def merge(self, other): # merge channel_index (move to ChannelIndex.merge()?) if self.channel_index and other.channel_index: signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1]), - channel_ids=np.hstack( - [self.channel_index.channel_ids, other.channel_index.channel_ids]), - channel_names=np.hstack( - [self.channel_index.channel_names, other.channel_index.channel_names])) + channel_ids=np.hstack( + [self.channel_index.channel_ids, + other.channel_index.channel_ids]), + channel_names=np.hstack( + [self.channel_index.channel_names, + other.channel_index.channel_names])) else: signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1])) return signal + + def time_slice(self, t_start, t_stop): + ''' + Creates a new AnalogSignal corresponding to the time slice of the + original Signal between times t_start, t_stop. + ''' + NotImplementedError('Needs to be implemented for subclasses.') + + def concatenate(self, *signals): + ''' + Concatenate multiple signals across time. + + The signal objects are concatenated vertically + (row-wise, :func:`np.vstack`). Concatenation can be + used to combine signals across segments. + Note: Only (array) annotations common to + both signals are attached to the concatenated signal. + + If the attributes of the signals are not + compatible, an Exception is raised. + + Parameters + ---------- + signals : multiple neo.BaseSignal objects + The objects that is concatenated with this one. + + Returns + ------- + signal : neo.BaseSignal + Signal containing all non-overlapping samples of + the source signals. + + Raises + ------ + MergeError + If `other` object has incompatible attributes. + ''' + + NotImplementedError('Patching need to be implemented in sublcasses') diff --git a/neo/core/irregularlysampledsignal.py b/neo/core/irregularlysampledsignal.py index 7dee88f3d..e0f5a79b0 100644 --- a/neo/core/irregularlysampledsignal.py +++ b/neo/core/irregularlysampledsignal.py @@ -31,7 +31,7 @@ import numpy as np import quantities as pq -from neo.core.baseneo import BaseNeo, MergeError, merge_annotations +from neo.core.baseneo import MergeError, merge_annotations, intersect_annotations from neo.core.basesignal import BaseSignal from neo.core.analogsignal import AnalogSignal from neo.core.channelindex import ChannelIndex @@ -514,3 +514,94 @@ def merge(self, other): signal.channel_index = ChannelIndex(index=np.arange(signal.shape[1])) return signal + + def concatenate(self, other, allow_overlap=False): + ''' + Combine this and another signal along the time axis. + + The signal objects are concatenated vertically + (row-wise, :func:`np.vstack`). Patching can be + used to combine signals across segments. + Note: Only array annotations common to + both signals are attached to the concatenated signal. + + If the attributes of the two signal are not + compatible, an Exception is raised. + + Required attributes of the signal are used. + + Parameters + ---------- + other : neo.BaseSignal + The object that is merged into this one. + allow_overlap : bool + If false, overlapping samples between the two + signals are not permitted and an ValueError is raised. + If true, no check for overlapping samples is + performed and all samples are combined. + + Returns + ------- + signal : neo.IrregularlySampledSignal + Signal containing all non-overlapping samples of + both source signals. + + Raises + ------ + MergeError + If `other` object has incompatible attributes. + ''' + + for attr in self._necessary_attrs: + if not (attr[0] in ['signal', 'times', 't_start', 't_stop', 'times']): + if getattr(self, attr[0], None) != getattr(other, attr[0], None): + raise MergeError( + "Cannot concatenate these two signals as the %s differ." % attr[0]) + + if hasattr(self, "lazy_shape"): + if hasattr(other, "lazy_shape"): + if self.lazy_shape[-1] != other.lazy_shape[-1]: + raise MergeError("Cannot concatenate signals as they contain" + " different numbers of traces.") + merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0], self.lazy_shape[-1]) + else: + raise MergeError("Cannot concatenate a lazy object with a real object.") + if other.units != self.units: + other = other.rescale(self.units) + + new_times = np.hstack((self.times, other.times)) + sorting = np.argsort(new_times) + new_samples = np.vstack((self.magnitude, other.magnitude)) + + kwargs = {} + for name in ("name", "description", "file_origin"): + attr_self = getattr(self, name) + attr_other = getattr(other, name) + if attr_self == attr_other: + kwargs[name] = attr_self + else: + kwargs[name] = "merge({}, {})".format(attr_self, attr_other) + merged_annotations = merge_annotations(self.annotations, other.annotations) + kwargs.update(merged_annotations) + + kwargs['array_annotations'] = intersect_annotations(self.array_annotations, + other.array_annotations) + + if not allow_overlap: + if max(self.t_start, other.t_start) <= min(self.t_stop, other.t_stop): + raise ValueError('Can not combine signals that overlap in time. Allow for ' + 'overlapping samples using the "no_overlap" parameter.') + + t_start = min(self.t_start, other.t_start) + t_stop = max(self.t_start, other.t_start) + + signal = IrregularlySampledSignal(signal=new_samples[sorting], times=new_times[sorting], + units=self.units, dtype=self.dtype, copy=False, + t_start=t_start, t_stop=t_stop, **kwargs) + signal.segment = None + signal.channel_index = None + + if hasattr(self, "lazy_shape"): + signal.lazy_shape = merged_lazy_shape + + return signal diff --git a/neo/test/coretest/test_analogsignal.py b/neo/test/coretest/test_analogsignal.py index 608241960..de16b9ede 100644 --- a/neo/test/coretest/test_analogsignal.py +++ b/neo/test/coretest/test_analogsignal.py @@ -28,6 +28,7 @@ HAVE_SCIPY = True from numpy.testing import assert_array_equal +from neo.core.baseneo import MergeError from neo.core.analogsignal import AnalogSignal, _get_sampling_rate from neo.core.channelindex import ChannelIndex from neo.core import Segment @@ -369,18 +370,17 @@ def test__repr(self): def test__pretty(self): for i, signal in enumerate(self.signals): prepr = pretty(signal) - targ = ( - ('AnalogSignal with %d channels of length %d; units %s; datatype %s \n' - '' % (signal.shape[1], signal.shape[0], - signal.units.dimensionality.unicode, signal.dtype)) - + ('annotations: %s\n' % signal.annotations) - + ('sampling rate: {} {}\n'.format(float(signal.sampling_rate), - signal.sampling_rate.dimensionality.unicode)) - + ('time: {} {} to {} {}'.format(float(signal.t_start), - signal.t_start.dimensionality.unicode, - float(signal.t_stop), - signal.t_stop.dimensionality.unicode)) - ) + targ = (('AnalogSignal with %d channels of length %d; units %s; datatype %s \n' + '' % (signal.shape[1], signal.shape[0], + signal.units.dimensionality.unicode, signal.dtype)) + + ('annotations: %s\n' % signal.annotations) + + ('sampling rate: {} {}\n'.format( + float(signal.sampling_rate), + signal.sampling_rate.dimensionality.unicode)) + + ('time: {} {} to {} {}'.format(float(signal.t_start), + signal.t_start.dimensionality.unicode, + float(signal.t_stop), + signal.t_stop.dimensionality.unicode))) self.assertEqual(prepr, targ) @@ -1290,6 +1290,7 @@ def test_rectify(self): assert_arrays_equal(rectified_signal.array_annotations['anno2'], target_signal.array_annotations['anno2']) + class TestAnalogSignalEquality(unittest.TestCase): def test__signals_with_different_data_complement_should_be_not_equal(self): signal1 = AnalogSignal(np.arange(55.0).reshape((11, 5)), units="mV", @@ -1572,6 +1573,160 @@ def test__merge(self): assert_arrays_equal(mergeddata23, targdata23) assert_arrays_equal(mergeddata24, targdata24) + def test_concatenate_simple(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop) + + result = signal1.concatenate(signal2) + assert_array_equal(np.arange(7).reshape((-1, 1)), result.magnitude) + for attr in signal1._necessary_attrs: + self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None)) + + def test_concatenate_no_signals(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + self.assertIs(signal1, signal1.concatenate()) + + def test_concatenate_reverted_order(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop) + + result = signal2.concatenate(signal1) + assert_array_equal(np.arange(7).reshape((-1, 1)), result.magnitude) + for attr in signal1._necessary_attrs: + self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None)) + + def test_concatenate_no_overlap(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, t_start=10 * pq.s) + + with self.assertRaises(MergeError): + signal1.concatenate(signal2) + + def test_concatenate_multi_trace(self): + data1 = np.arange(4).reshape(2, 2) + data2 = np.arange(4, 8).reshape(2, 2) + signal1 = AnalogSignal(data1 * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal(data2 * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop) + + result = signal1.concatenate(signal2) + data_expected = np.array([[0, 1], [2, 3], [4, 5], [6, 7]]) + assert_array_equal(data_expected, result.magnitude) + for attr in signal1._necessary_attrs: + self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None)) + + def test_concatenate_overwrite_true(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop - signal1.sampling_period) + + result = signal1.concatenate(signal2, overwrite=True) + assert_array_equal(np.array([0, 1, 2, 4, 5, 6]).reshape((-1, 1)), result.magnitude) + + def test_concatenate_overwrite_false(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop - signal1.sampling_period) + + result = signal1.concatenate(signal2, overwrite=False) + assert_array_equal(np.array([0, 1, 2, 3, 5, 6]).reshape((-1, 1)), result.magnitude) + + def test_concatenate_padding_False(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=10 * pq.s) + + with self.assertRaises(MergeError): + result = signal1.concatenate(signal2, overwrite=False, padding=False) + + def test_concatenate_padding_True(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop + 3 * signal1.sampling_period) + + result = signal1.concatenate(signal2, overwrite=False, padding=True) + assert_array_equal( + np.array([0, 1, 2, 3, np.NaN, np.NaN, np.NaN, 4, 5, 6]).reshape((-1, 1)), + result.magnitude) + + def test_concatenate_padding_quantity(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop + 3 * signal1.sampling_period) + + result = signal1.concatenate(signal2, overwrite=False, padding=-1 * pq.mV) + assert_array_equal(np.array([0, 1, 2, 3, -1e-3, -1e-3, -1e-3, 4, 5, 6]).reshape((-1, 1)), + result.magnitude) + + def test_concatenate_padding_invalid(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop + 3 * signal1.sampling_period) + + with self.assertRaises(MergeError): + result = signal1.concatenate(signal2, overwrite=False, padding=1) + with self.assertRaises(MergeError): + result = signal1.concatenate(signal2, overwrite=False, padding=[1]) + with self.assertRaises(MergeError): + result = signal1.concatenate(signal2, overwrite=False, padding='a') + with self.assertRaises(MergeError): + result = signal1.concatenate(signal2, overwrite=False, padding=np.array([1, 2, 3])) + + def test_concatenate_array_annotations(self): + array_anno1 = {'first': ['a', 'b']} + array_anno2 = {'first': ['a', 'b'], + 'second': ['c', 'd']} + data1 = np.arange(4).reshape(2, 2) + data2 = np.arange(4, 8).reshape(2, 2) + signal1 = AnalogSignal(data1 * pq.V, sampling_rate=1 * pq.Hz, + array_annotations=array_anno1) + signal2 = AnalogSignal(data2 * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop, + array_annotations=array_anno2) + + result = signal1.concatenate(signal2) + assert_array_equal(array_anno1.keys(), result.array_annotations.keys()) + + for k in array_anno1.keys(): + assert_array_equal(np.asarray(array_anno1[k]), result.array_annotations[k]) + + def test_concatenate_complex(self): + signal1 = self.signal1 + assert_neo_object_is_compliant(self.signal1) + + signal2 = AnalogSignal(self.data1quant, sampling_rate=1 * pq.kHz, name='signal2', + description='test signal', file_origin='testfile.txt', + array_annotations=self.arr_ann1, + t_start=signal1.t_stop) + + concatenated12 = self.signal1.concatenate(signal2) + + for attr in signal1._necessary_attrs: + self.assertEqual(getattr(signal1, attr[0], None), + getattr(concatenated12, attr[0], None)) + + assert_array_equal(np.vstack((signal1.magnitude, signal2.magnitude)), + concatenated12.magnitude) + + def test_concatenate_multi_signal(self): + signal1 = AnalogSignal([0, 1, 2, 3] * pq.V, sampling_rate=1 * pq.Hz) + signal2 = AnalogSignal([4, 5, 6] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop + 3 * signal1.sampling_period) + signal3 = AnalogSignal([40] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop + 3 * signal1.sampling_period) + signal4 = AnalogSignal([30, 35] * pq.V, sampling_rate=1 * pq.Hz, + t_start=signal1.t_stop - signal1.sampling_period) + + concatenated = signal1.concatenate(signal2, signal3, signal4, padding=-1 * pq.V, + overwrite=True) + for attr in signal1._necessary_attrs: + self.assertEqual(getattr(signal1, attr[0], None), + getattr(concatenated, attr[0], None)) + assert_arrays_equal(np.array([0, 1, 2, 30, 35, -1, -1, 40, 5, 6]).reshape((-1, 1)), + concatenated.magnitude) + class TestAnalogSignalFunctions(unittest.TestCase): def test__pickle_1d(self): diff --git a/neo/test/coretest/test_base.py b/neo/test/coretest/test_base.py index 91e948784..bd7b010db 100644 --- a/neo/test/coretest/test_base.py +++ b/neo/test/coretest/test_base.py @@ -20,7 +20,8 @@ HAVE_IPYTHON = True from neo.core.baseneo import (BaseNeo, _check_annotations, - merge_annotations, merge_annotation) + merge_annotations, merge_annotation, + intersect_annotations) from neo.test.tools import assert_arrays_equal @@ -1266,5 +1267,50 @@ def test__pretty(self): self.assertEqual(res, targ) +class Test_intersect_annotations(unittest.TestCase): + ''' + TestCase for intersect_annotations + ''' + + def setUp(self): + self.dict1 = {1: '1', 2: '2'} + self.dict2 = {1: '1'} + self.dict3 = {'list1': [1, 2, 3]} + self.dict4 = {'list1': [1, 2, 3], 'list2': [1, 2, 3]} + self.dict5 = {'list1': [1, 2]} + self.dict6 = {'array1': np.array([1, 2])} + self.dict7 = {'array1': np.array([1, 2]), 'array2': np.array([1, 2]), + 'array3': np.array([1, 2, 3])} + + self.all_simple_dicts = [self.dict1, self.dict2, self.dict3, + self.dict4, self.dict5, ] + + def test_simple(self): + result = intersect_annotations(self.dict1, self.dict2) + self.assertDictEqual(self.dict2, result) + + def test_intersect_self(self): + for d in self.all_simple_dicts: + result = intersect_annotations(d, d) + self.assertDictEqual(d, result) + + def test_list(self): + result = intersect_annotations(self.dict3, self.dict4) + self.assertDictEqual(self.dict3, result) + + def test_list_values(self): + result = intersect_annotations(self.dict4, self.dict5) + self.assertDictEqual({}, result) + + def test_keys(self): + result = intersect_annotations(self.dict1, self.dict4) + self.assertDictEqual({}, result) + + def test_arrays(self): + result = intersect_annotations(self.dict6, self.dict7) + self.assertEqual(self.dict6.keys(), result.keys()) + np.testing.assert_array_equal([1, 2], result['array1']) + + if __name__ == "__main__": unittest.main() diff --git a/neo/test/coretest/test_irregularysampledsignal.py b/neo/test/coretest/test_irregularysampledsignal.py index 6b026878f..ecce9c7ec 100644 --- a/neo/test/coretest/test_irregularysampledsignal.py +++ b/neo/test/coretest/test_irregularysampledsignal.py @@ -239,13 +239,11 @@ def test_IrregularlySampledSignal_repr(self): if np.__version__.split(".")[:2] > ['1', '13']: # see https://github.com/numpy/numpy/blob/master/doc/release/1.14.0-notes.rst#many # -changes-to-array-printing-disableable-with-the-new-legacy-printing-mode - targ = ( - '') + targ = ('') else: - targ = ( - '') + targ = ('') res = repr(sig) self.assertEqual(targ, res) @@ -461,7 +459,7 @@ def test__time_slice_deepcopy_array_annotations(self): length = self.signal1.shape[-1] params1 = {'test0': ['y{}'.format(i) for i in range(length)], 'test1': ['deeptest' for i in range(length)], - 'test2': [(-1)**i > 0 for i in range(length)]} + 'test2': [(-1) ** i > 0 for i in range(length)]} self.signal1.array_annotate(**params1) result = self.signal1.time_slice(None, None) @@ -479,7 +477,7 @@ def test__time_slice_deepcopy_array_annotations(self): == result.array_annotations['test2'])) # Change annotations of result - params3 = {'test0': ['z{}'.format(i) for i in range(1, result.shape[-1]+1)]} + params3 = {'test0': ['z{}'.format(i) for i in range(1, result.shape[-1] + 1)]} result.array_annotate(**params3) result.array_annotations['test1'][0] = 'shallow2' self.assertFalse(all(self.signal1.array_annotations['test0'] @@ -493,12 +491,12 @@ def test__time_slice_deepcopy_data(self): result = self.signal1.time_slice(None, None) # Change values of original array - self.signal1[2] = 7.3*self.signal1.units + self.signal1[2] = 7.3 * self.signal1.units self.assertFalse(all(self.signal1 == result)) # Change values of sliced array - result[3] = 9.5*result.units + result[3] = 9.5 * result.units self.assertFalse(all(self.signal1 == result)) @@ -945,6 +943,73 @@ def test__merge(self): self.assertRaises(MergeError, signal1.merge, signal3) + def test_concatenate_simple(self): + signal1 = IrregularlySampledSignal(signal=[0, 1, 2, 3] * pq.s, + times=[1, 10, 11, 14] * pq.s) + signal2 = IrregularlySampledSignal(signal=[4, 5, 6] * pq.s, times=[15, 16, 21] * pq.s) + + result = signal1.concatenate(signal2) + assert_array_equal(np.array([0, 1, 2, 3, 4, 5, 6]).reshape((-1, 1)), result.magnitude) + assert_array_equal(np.array([1, 10, 11, 14, 15, 16, 21]), result.times) + for attr in signal1._necessary_attrs: + if attr[0] == 'times': + continue + self.assertEqual(getattr(signal1, attr[0], None), getattr(result, attr[0], None)) + + def test_concatenate_no_overlap(self): + signal1 = IrregularlySampledSignal(signal=[0, 1, 2, 3] * pq.s, times=range(4) * pq.s) + signal2 = IrregularlySampledSignal(signal=[4, 5, 6] * pq.s, times=range(4, 7) * pq.s) + + for allow_overlap in [True, False]: + result = signal1.concatenate(signal2, allow_overlap=allow_overlap) + assert_array_equal(np.arange(7).reshape((-1, 1)), result.magnitude) + assert_array_equal(np.arange(7), result.times) + + def test_concatenate_overlap_exception(self): + signal1 = IrregularlySampledSignal(signal=[0, 1, 2, 3] * pq.s, times=range(4) * pq.s) + signal2 = IrregularlySampledSignal(signal=[4, 5, 6] * pq.s, times=range(2, 5) * pq.s) + + self.assertRaises(ValueError, signal1.concatenate, signal2, allow_overlap=False) + + def test_concatenate_overlap(self): + signal1 = IrregularlySampledSignal(signal=[0, 1, 2, 3] * pq.s, times=range(4) * pq.s) + signal2 = IrregularlySampledSignal(signal=[4, 5, 6] * pq.s, times=range(2, 5) * pq.s) + + result = signal1.concatenate(signal2, allow_overlap=True) + assert_array_equal(np.array([0, 1, 2, 4, 3, 5, 6]).reshape((-1, 1)), result.magnitude) + assert_array_equal(np.array([0, 1, 2, 2, 3, 3, 4]), result.times) + + def test_concatenate_multi_trace(self): + data1 = np.arange(4).reshape(2, 2) + data2 = np.arange(4, 8).reshape(2, 2) + n1 = len(data1) + n2 = len(data2) + signal1 = IrregularlySampledSignal(signal=data1 * pq.s, times=range(n1) * pq.s) + signal2 = IrregularlySampledSignal(signal=data2 * pq.s, times=range(n1, n1 + n2) * pq.s) + + result = signal1.concatenate(signal2) + data_expected = np.array([[0, 1], [2, 3], [4, 5], [6, 7]]) + assert_array_equal(data_expected, result.magnitude) + + def test_concatenate_array_annotations(self): + array_anno1 = {'first': ['a', 'b']} + array_anno2 = {'first': ['a', 'b'], + 'second': ['c', 'd']} + data1 = np.arange(4).reshape(2, 2) + data2 = np.arange(4, 8).reshape(2, 2) + n1 = len(data1) + n2 = len(data2) + signal1 = IrregularlySampledSignal(signal=data1 * pq.s, times=range(n1) * pq.s, + array_annotations=array_anno1) + signal2 = IrregularlySampledSignal(signal=data2 * pq.s, times=range(n1, n1 + n2) * pq.s, + array_annotations=array_anno2) + + result = signal1.concatenate(signal2) + assert_array_equal(array_anno1.keys(), result.array_annotations.keys()) + + for k in array_anno1.keys(): + assert_array_equal(np.asarray(array_anno1[k]), result.array_annotations[k]) + class TestAnalogSignalFunctions(unittest.TestCase): def test__pickle(self):