From 84d224e1c22e71f01b4dab4c4726669966c77449 Mon Sep 17 00:00:00 2001 From: Jesse Livezey Date: Thu, 14 Apr 2022 15:57:14 -0700 Subject: [PATCH] use fast fft lens by default and loop resample --- src/process_nwb/linenoise_notch.py | 5 +++-- src/process_nwb/resample.py | 32 +++++++++++++++------------- src/process_nwb/utils.py | 20 +++++++++++------ src/process_nwb/wavelet_transform.py | 12 ++++++----- tests/test_resample.py | 29 ++++++++++++++----------- 5 files changed, 57 insertions(+), 41 deletions(-) diff --git a/src/process_nwb/linenoise_notch.py b/src/process_nwb/linenoise_notch.py index afbc6e0..6750357 100644 --- a/src/process_nwb/linenoise_notch.py +++ b/src/process_nwb/linenoise_notch.py @@ -52,7 +52,7 @@ def _apply_notches(X, notches, rate, fft=True, precision='single'): return Xp.astype(X_dtype, copy=False) -def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad=0, precision='single'): +def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad='fast', precision='single'): """Apply notch filters at 60 Hz (by default) and its harmonics. Filters +/- 1 Hz around the frequencies. @@ -68,7 +68,8 @@ def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad=0, precision='si noise_hz: float Frequency to notch out npad : int - Padding to add to beginning and end of timeseries. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. precision : str Either `single` for float32/complex64 or `double` for float/complex. diff --git a/src/process_nwb/resample.py b/src/process_nwb/resample.py index 199b15c..e4a65b9 100644 --- a/src/process_nwb/resample.py +++ b/src/process_nwb/resample.py @@ -40,7 +40,7 @@ """ -def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='single'): +def resample_func(X, num, npad='fast', pad='reflect_limited', real=True, precision='single'): """Resample an array. Operates along the first dimension of the array. 
This is the low-level code. Users shoud likely use `resample()` rather than this function. @@ -51,7 +51,8 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s num : int Number of samples in resampled signal. npad : int - Padding to add to beginning and end of timeseries. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. pad : str Type of padding. The default is ``'reflect_limited'``. real : bool @@ -74,9 +75,9 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s n_time = X.shape[0] ratio = float(num) / n_time npads, to_removes, new_len = _npads(X, npad, ratio=ratio) + X = _smart_pad(X, npads, pad) # do the resampling using an adaptation of scipy's FFT-based resample() - X = _smart_pad(X, npads, pad) old_len = len(X) shorter = new_len < old_len use_len = new_len if shorter else old_len @@ -101,7 +102,7 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s return y -def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single'): +def resample(X, new_freq, old_freq, real=True, axis=0, npad='fast', precision='single', loop=True): """Resamples the timeseries from the original sampling frequency to a new frequency. Parameters @@ -117,9 +118,12 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single axis : int Which axis to resample. npad : int - Padding to add to beginning and end of timeseries. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. precision : str Either `single` for float32/complex64 or `double` for float/complex. + loop : bool + Whether or not to loop across channels when resampling. 
Returns ------- @@ -135,15 +139,12 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single n_time = X.shape[0] new_n_time = int(np.ceil(n_time * new_freq / old_freq)) - loop = False - if X.size >= 10**8 and X.shape[1] > 1: - loop = True - if loop: - Xds = np.zeros((new_n_time,) + X.shape[1:], dtype=X_dtype) - for ii in range(X.shape[1]): - Xds[:, ii] = resample_func(X[:, [ii]], new_n_time, npad=npad, real=real, - precision=precision)[:, 0] + X2d = X.reshape(n_time, -1)  # flatten trailing dims once (int channel count, also covers 1-D X) + Xds = np.zeros((new_n_time, X2d.shape[1]), dtype=X_dtype) + for ii in range(X2d.shape[1]): + Xds[:, ii] = resample_func(X2d[:, [ii]], new_n_time, npad=npad, real=real, precision=precision)[:, 0] + Xds = Xds.reshape((new_n_time,) + X.shape[1:]) else: Xds = resample_func(X, new_n_time, npad=npad, real=real, precision=precision) if axis != 0: @@ -152,7 +153,7 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single return Xds -def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad=0, precision='single'): +def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad='fast', precision='single'): """Resamples the `ElectricalSeries` from the original sampling frequency to a new frequency and store the results in a new ElectricalSeries. @@ -170,7 +171,8 @@ def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad Scale the values by this. Can help with accuracy of downstream operations if the raw values are too small. Default = 1e6. npad : int - Padding to add to beginning and end of timeseries. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. precision : str Either `single` for float32/complex64 or `double` for float/complex. 
diff --git a/src/process_nwb/utils.py b/src/process_nwb/utils.py index d7b2062..4b9c3e6 100644 --- a/src/process_nwb/utils.py +++ b/src/process_nwb/utils.py @@ -1,4 +1,5 @@ import numpy as np +from scipy.fft import next_fast_len from datetime import datetime from dateutil.tz import tzlocal @@ -106,15 +107,20 @@ def _npads(X, npad, ratio=1.): """Calculate padding parameters. """ n_time = X.shape[0] - bad_msg = 'npad must be "auto" or an integer' + bad_msg = 'npad must be "auto", "fast", or an integer' if isinstance(npad, str): - if npad != 'auto': + if npad == 'auto': + # Figure out reasonable pad that gets us to a power of 2 + min_add = min(n_time // 8, 100) * 2 + npad = 2 ** int(np.ceil(np.log2(n_time + min_add))) - n_time + npad, extra = divmod(npad, 2) + npads = np.array([npad, npad + extra], int) + elif npad == 'fast': + npad = next_fast_len(n_time) - n_time + npad, extra = divmod(npad, 2) + npads = np.array([npad, npad + extra], int) + else: raise ValueError(bad_msg) - # Figure out reasonable pad that gets us to a power of 2 - min_add = min(n_time // 8, 100) * 2 - npad = 2 ** int(np.ceil(np.log2(n_time + min_add))) - n_time - npad, extra = divmod(npad, 2) - npads = np.array([npad, npad + extra], int) else: if npad != int(npad): raise ValueError(bad_msg) diff --git a/src/process_nwb/wavelet_transform.py b/src/process_nwb/wavelet_transform.py index df3045d..744f794 100644 --- a/src/process_nwb/wavelet_transform.py +++ b/src/process_nwb/wavelet_transform.py @@ -173,7 +173,7 @@ def __init__(self, X, rate, filters='rat', npad=None, hg_only=True, post_resampl # Need to pad X before predicting chunk and filter shape: self.npads, self.to_removes, _ = _npads(X, npad) - self.wavelet_time = X.shape[0] + 2 * npad + self.wavelet_time = X.shape[0] + self.npads.sum() self.filterbank, self.cfs, self.sds = get_filterbank(filters, self.wavelet_time, self.rate, hg_only, precision=self.precision) self.resample_time = self.X.shape[0] @@ -231,7 +231,7 @@ def 
recommended_data_shape(self): return (self.resample_time, self.nch, self.nbands) -def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0, to_removes=None, +def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad='fast', to_removes=None, precision='single'): """Apply a wavelet transform using a prespecified set of filters. @@ -256,7 +256,8 @@ def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0 Precomputed product of X_fft and heavyside. Useful for when bands are computed independently. npad : int - Length of padding in samples. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. to_removes : int Number of samples to remove at the beginning and end of the timeseries. Default None. precision : str @@ -307,7 +308,7 @@ def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0 def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True, abs_only=True, - npad=0, post_resample_rate=None, chunked=True, precision='single', + npad='fast', post_resample_rate=None, chunked=True, precision='single', source_series=None): """Apply a wavelet transform using a prespecified set of filters. Results are stored in the NWB file as a `DecompositionSeries`. @@ -332,7 +333,8 @@ def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True abs_only : bool If True, only the amplitude is stored. npad : int - Padding to add to beginning and end of timeseries. Default 0. + Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next + fastest length. post_resample_rate : float If not `None`, resample the computed wavelet amplitudes to this rate. 
chunked : bool diff --git a/tests/test_resample.py b/tests/test_resample.py index 33f28c4..e53e80e 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -4,6 +4,8 @@ from process_nwb.resample import resample +import pytest + def test_resample_shape(): X = np.random.randn(2000, 32) @@ -11,18 +13,20 @@ def test_resample_shape(): Xp = resample(X, 100, 200) assert Xp.shape == (1000, 32) + X = np.random.randn(2000, 32, 2) + + Xp = resample(X, 100, 200) + assert Xp.shape == (1000, 32, 2) + -def test_resample_ones(): - chs = [2, 32, 100] - ts = [999, 1000, 1001, 5077] +@pytest.mark.parametrize("ch", [2, 32, 100]) +@pytest.mark.parametrize("t", [999, 1000, 1001, 5077]) +@pytest.mark.parametrize("fr", [.5, 1, 1.5, 2]) +def test_resample_ones(ch, t, fr): rate = 200 - fracs = [.5, 1, 1.5, 2] - for ch in chs: - for t in ts: - for fr in fracs: - X = np.ones((t, ch)) - Xp = resample(X, rate * fr, rate) - assert_allclose(Xp, 1., atol=1e-3) + X = np.ones((t, ch)) + Xp = resample(X, rate * fr, rate) + assert_allclose(Xp, 1., atol=1e-3) def test_resample_low_freqs(): @@ -33,17 +37,18 @@ def test_resample_low_freqs(): t = np.linspace(0, dt, int(dt * rate)) t = np.tile(t[:, np.newaxis], (1, 5)) freqs = np.linspace(1, 5.33, 20) + phase = np.linspace(0, np.pi / 2., 5)[np.newaxis] X = np.zeros_like(t) for f in freqs: - X += np.sin(2 * np.pi * f * t) + X += np.sin(2 * np.pi * f * t + phase) new_rate = 211. # Hz t = np.linspace(0, dt, int(dt * new_rate)) t = np.tile(t[:, np.newaxis], (1, 5)) X_new_rate = np.zeros_like(t) for f in freqs: - X_new_rate += np.sin(2 * np.pi * f * t) + X_new_rate += np.sin(2 * np.pi * f * t + phase) Xds = resample(X, new_rate, rate) assert_allclose(cosine(Xds.ravel(), X_new_rate.ravel()), 0., atol=1e-3)