Skip to content

Commit

Permalink
Merge pull request #70 from JesseLivezey/fast_fft_len
Browse files Browse the repository at this point in the history
use fast fft lens by default and loop resample
  • Loading branch information
JesseLivezey committed Apr 14, 2022
2 parents 70acae3 + 84d224e commit 800461e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 41 deletions.
5 changes: 3 additions & 2 deletions src/process_nwb/linenoise_notch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _apply_notches(X, notches, rate, fft=True, precision='single'):
return Xp.astype(X_dtype, copy=False)


def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad=0, precision='single'):
def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad='fast', precision='single'):
"""Apply notch filters at 60 Hz (by default) and its harmonics.
Filters +/- 1 Hz around the frequencies.
Expand All @@ -68,7 +68,8 @@ def apply_linenoise_notch(X, rate, fft=True, noise_hz=60., npad=0, precision='si
noise_hz: float
Frequency to notch out
npad : int
Padding to add to beginning and end of timeseries. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
precision : str
Either `single` for float32/complex64 or `double` for float/complex.
Expand Down
32 changes: 17 additions & 15 deletions src/process_nwb/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"""


def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='single'):
def resample_func(X, num, npad='fast', pad='reflect_limited', real=True, precision='single'):
"""Resample an array. Operates along the first dimension of the array. This is the low-level
code. Users shoud likely use `resample()` rather than this function.
Expand All @@ -51,7 +51,8 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s
num : int
Number of samples in resampled signal.
npad : int
Padding to add to beginning and end of timeseries. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
pad : str
Type of padding. The default is ``'reflect_limited'``.
real : bool
Expand All @@ -74,9 +75,9 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s
n_time = X.shape[0]
ratio = float(num) / n_time
npads, to_removes, new_len = _npads(X, npad, ratio=ratio)
X = _smart_pad(X, npads, pad)

# do the resampling using an adaptation of scipy's FFT-based resample()
X = _smart_pad(X, npads, pad)
old_len = len(X)
shorter = new_len < old_len
use_len = new_len if shorter else old_len
Expand All @@ -101,7 +102,7 @@ def resample_func(X, num, npad=0, pad='reflect_limited', real=True, precision='s
return y


def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single'):
def resample(X, new_freq, old_freq, real=True, axis=0, npad='fast', precision='single', loop=True):
"""Resamples the timeseries from the original sampling frequency to a new frequency.
Parameters
Expand All @@ -117,9 +118,12 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single
axis : int
Which axis to resample.
npad : int
Padding to add to beginning and end of timeseries. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
precision : str
Either `single` for float32/complex64 or `double` for float/complex.
loop : bool
Whether or not to loop across channels when resampling.
Returns
-------
Expand All @@ -135,15 +139,12 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single
n_time = X.shape[0]
new_n_time = int(np.ceil(n_time * new_freq / old_freq))

loop = False
if X.size >= 10**8 and X.shape[1] > 1:
loop = True

if loop:
Xds = np.zeros((new_n_time,) + X.shape[1:], dtype=X_dtype)
for ii in range(X.shape[1]):
Xds[:, ii] = resample_func(X[:, [ii]], new_n_time, npad=npad, real=real,
precision=precision)[:, 0]
Xds = np.zeros((new_n_time, np.prod(X.shape[1:])), dtype=X_dtype)
for ii in range(np.prod(X.shape[1:])):
Xds[:, ii] = resample_func(X.reshape(X.shape[0], -1)[:, [ii]], new_n_time, npad=npad,
real=real, precision=precision)[:, 0]
Xds = Xds.reshape((new_n_time,) + X.shape[1:])
else:
Xds = resample_func(X, new_n_time, npad=npad, real=real, precision=precision)
if axis != 0:
Expand All @@ -152,7 +153,7 @@ def resample(X, new_freq, old_freq, real=True, axis=0, npad=0, precision='single
return Xds


def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad=0, precision='single'):
def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad='fast', precision='single'):
"""Resamples the `ElectricalSeries` from the original sampling frequency to a new frequency and
store the results in a new ElectricalSeries.
Expand All @@ -170,7 +171,8 @@ def store_resample(elec_series, processing, new_freq, axis=0, scaling=None, npad
Scale the values by this. Can help with accuracy of downstream operations if the raw values
are too small. Default = 1e6.
npad : int
Padding to add to beginning and end of timeseries. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
precision : str
Either `single` for float32/complex64 or `double` for float/complex.
Expand Down
20 changes: 13 additions & 7 deletions src/process_nwb/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.fft import next_fast_len
from datetime import datetime
from dateutil.tz import tzlocal

Expand Down Expand Up @@ -106,15 +107,20 @@ def _npads(X, npad, ratio=1.):
"""Calculate padding parameters.
"""
n_time = X.shape[0]
bad_msg = 'npad must be "auto" or an integer'
bad_msg = 'npad must be "auto", "fast", or an integer'
if isinstance(npad, str):
if npad != 'auto':
if npad == 'auto':
# Figure out reasonable pad that gets us to a power of 2
min_add = min(n_time // 8, 100) * 2
npad = 2 ** int(np.ceil(np.log2(n_time + min_add))) - n_time
npad, extra = divmod(npad, 2)
npads = np.array([npad, npad + extra], int)
elif npad == 'fast':
npad = next_fast_len(n_time) - n_time
npad, extra = divmod(npad, 2)
npads = np.array([npad, npad + extra], int)
else:
raise ValueError(bad_msg)
# Figure out reasonable pad that gets us to a power of 2
min_add = min(n_time // 8, 100) * 2
npad = 2 ** int(np.ceil(np.log2(n_time + min_add))) - n_time
npad, extra = divmod(npad, 2)
npads = np.array([npad, npad + extra], int)
else:
if npad != int(npad):
raise ValueError(bad_msg)
Expand Down
12 changes: 7 additions & 5 deletions src/process_nwb/wavelet_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(self, X, rate, filters='rat', npad=None, hg_only=True, post_resampl

# Need to pad X before predicting chunk and filter shape:
self.npads, self.to_removes, _ = _npads(X, npad)
self.wavelet_time = X.shape[0] + 2 * npad
self.wavelet_time = X.shape[0] + self.npads.sum()
self.filterbank, self.cfs, self.sds = get_filterbank(filters, self.wavelet_time, self.rate,
hg_only, precision=self.precision)
self.resample_time = self.X.shape[0]
Expand Down Expand Up @@ -231,7 +231,7 @@ def recommended_data_shape(self):
return (self.resample_time, self.nch, self.nbands)


def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0, to_removes=None,
def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad='fast', to_removes=None,
precision='single'):
"""Apply a wavelet transform using a prespecified set of filters.
Expand All @@ -256,7 +256,8 @@ def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0
Precomputed product of X_fft and heavyside. Useful for when bands are computed
independently.
npad : int
Length of padding in samples. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
to_removes : int
Number of samples to remove at the beginning and end of the timeseries. Default None.
precision : str
Expand Down Expand Up @@ -307,7 +308,7 @@ def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad=0


def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True, abs_only=True,
npad=0, post_resample_rate=None, chunked=True, precision='single',
npad='fast', post_resample_rate=None, chunked=True, precision='single',
source_series=None):
"""Apply a wavelet transform using a prespecified set of filters. Results are stored in the
NWB file as a `DecompositionSeries`.
Expand All @@ -332,7 +333,8 @@ def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True
abs_only : bool
If True, only the amplitude is stored.
npad : int
Padding to add to beginning and end of timeseries. Default 0.
Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next
fastest length.
post_resample_rate : float
If not `None`, resample the computed wavelet amplitudes to this rate.
chunked : bool
Expand Down
29 changes: 17 additions & 12 deletions tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,29 @@

from process_nwb.resample import resample

import pytest


def test_resample_shape():
X = np.random.randn(2000, 32)

Xp = resample(X, 100, 200)
assert Xp.shape == (1000, 32)

X = np.random.randn(2000, 32, 2)

Xp = resample(X, 100, 200)
assert Xp.shape == (1000, 32, 2)


def test_resample_ones():
chs = [2, 32, 100]
ts = [999, 1000, 1001, 5077]
@pytest.mark.parametrize("ch", [2, 32, 100])
@pytest.mark.parametrize("t", [999, 1000, 1001, 5077])
@pytest.mark.parametrize("fr", [.5, 1, 1.5, 2])
def test_resample_ones(ch, t, fr):
rate = 200
fracs = [.5, 1, 1.5, 2]
for ch in chs:
for t in ts:
for fr in fracs:
X = np.ones((t, ch))
Xp = resample(X, rate * fr, rate)
assert_allclose(Xp, 1., atol=1e-3)
X = np.ones((t, ch))
Xp = resample(X, rate * fr, rate)
assert_allclose(Xp, 1., atol=1e-3)


def test_resample_low_freqs():
Expand All @@ -33,17 +37,18 @@ def test_resample_low_freqs():
t = np.linspace(0, dt, int(dt * rate))
t = np.tile(t[:, np.newaxis], (1, 5))
freqs = np.linspace(1, 5.33, 20)
phase = np.linspace(0, np.pi / 2., 5)[np.newaxis]

X = np.zeros_like(t)
for f in freqs:
X += np.sin(2 * np.pi * f * t)
X += np.sin(2 * np.pi * f * t + phase)

new_rate = 211. # Hz
t = np.linspace(0, dt, int(dt * new_rate))
t = np.tile(t[:, np.newaxis], (1, 5))
X_new_rate = np.zeros_like(t)
for f in freqs:
X_new_rate += np.sin(2 * np.pi * f * t)
X_new_rate += np.sin(2 * np.pi * f * t + phase)

Xds = resample(X, new_rate, rate)
assert_allclose(cosine(Xds.ravel(), X_new_rate.ravel()), 0., atol=1e-3)
Expand Down

0 comments on commit 800461e

Please sign in to comment.