Skip to content

Commit

Permalink
Butterworth supports sosfiltfilt filter_function (NeuralEnsemble#234)
Browse files Browse the repository at this point in the history
* Butterworth supports sosfiltfilt filter_function

* higher order filters comment
  • Loading branch information
dizcza authored and mdenker committed Jul 10, 2019
1 parent 5dfdab9 commit 3c30574
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 48 deletions.
111 changes: 65 additions & 46 deletions elephant/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,16 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
order : int
Order of Butterworth filter. Default is 4.
filter_function : string
Filtering function to be used. Either 'filtfilt'
(`scipy.signal.filtfilt()`) or 'lfilter' (`scipy.signal.lfilter()`). In
most applications 'filtfilt' should be used, because it doesn't bring
about phase shift due to filtering. Default is 'filtfilt'.
Filtering function to be used. Available filters:
* 'filtfilt': `scipy.signal.filtfilt()`;
* 'lfilter': `scipy.signal.lfilter()`;
* 'sosfiltfilt': `scipy.signal.sosfiltfilt()`.
In most applications 'filtfilt' should be used, because it doesn't
bring about phase shift due to filtering. For numerically stable
filtering, in particular higher order filters, use 'sosfiltfilt'
(see issue
https://github.com/NeuralEnsemble/elephant/issues/220).
Default is 'filtfilt'.
fs : Quantity or float
The sampling frequency of the input time series. When given as float,
its value is taken as frequency in Hz. When the input is given as neo
Expand All @@ -322,42 +328,53 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
Filtered input data. The shape and type is identical to those of the
input.
"""

def _design_butterworth_filter(Fs, hpfreq=None, lpfreq=None, order=4):
# set parameters for filter design
Fn = Fs / 2.
# - filter type is determined according to the values of cut-off
# frequencies
if lpfreq and hpfreq:
if hpfreq < lpfreq:
Wn = (hpfreq / Fn, lpfreq / Fn)
btype = 'bandpass'
else:
Wn = (lpfreq / Fn, hpfreq / Fn)
btype = 'bandstop'
elif lpfreq:
Wn = lpfreq / Fn
btype = 'lowpass'
elif hpfreq:
Wn = hpfreq / Fn
btype = 'highpass'
else:
raise ValueError(
"Either highpass_freq or lowpass_freq must be given"
)

# return filter coefficients
return scipy.signal.butter(order, Wn, btype=btype)
Raises
------
ValueError
If `filter_function` is not one of 'lfilter', 'filtfilt',
or 'sosfiltfilt'.
When both `highpass_freq` and `lowpass_freq` are None.
"""
available_filters = 'lfilter', 'filtfilt', 'sosfiltfilt'
if filter_function not in available_filters:
raise ValueError("Invalid `filter_function`: {filter_function}. "
"Available filters: {available_filters}".format(
filter_function=filter_function,
available_filters=available_filters))
# design filter
Fs = signal.sampling_rate.rescale(pq.Hz).magnitude \
if hasattr(signal, 'sampling_rate') else fs
Fh = highpass_freq.rescale(pq.Hz).magnitude \
if isinstance(highpass_freq, pq.quantity.Quantity) else highpass_freq
Fl = lowpass_freq.rescale(pq.Hz).magnitude \
if isinstance(lowpass_freq, pq.quantity.Quantity) else lowpass_freq
b, a = _design_butterworth_filter(Fs, Fh, Fl, order)
if hasattr(signal, 'sampling_rate'):
fs = signal.sampling_rate.rescale(pq.Hz).magnitude
if isinstance(highpass_freq, pq.quantity.Quantity):
highpass_freq = highpass_freq.rescale(pq.Hz).magnitude
if isinstance(lowpass_freq, pq.quantity.Quantity):
lowpass_freq = lowpass_freq.rescale(pq.Hz).magnitude
Fn = fs / 2.
# filter type is determined according to the values of cut-off
# frequencies
if lowpass_freq and highpass_freq:
if highpass_freq < lowpass_freq:
Wn = (highpass_freq / Fn, lowpass_freq / Fn)
btype = 'bandpass'
else:
Wn = (lowpass_freq / Fn, highpass_freq / Fn)
btype = 'bandstop'
elif lowpass_freq:
Wn = lowpass_freq / Fn
btype = 'lowpass'
elif highpass_freq:
Wn = highpass_freq / Fn
btype = 'highpass'
else:
raise ValueError(
"Either highpass_freq or lowpass_freq must be given"
)
if filter_function == 'sosfiltfilt':
output = 'sos'
else:
output = 'ba'
designed_filter = scipy.signal.butter(order, Wn, btype=btype,
output=output)

# When the input is AnalogSignal, the axis for time index (i.e. the
# first axis) needs to be rolled to the last
Expand All @@ -366,17 +383,19 @@ def _design_butterworth_filter(Fs, hpfreq=None, lpfreq=None, order=4):
data = np.rollaxis(data, 0, len(data.shape))

# apply filter
if filter_function is 'lfilter':
filtered_data = scipy.signal.lfilter(b, a, data, axis=axis)
elif filter_function is 'filtfilt':
filtered_data = scipy.signal.filtfilt(b, a, data, axis=axis)
if filter_function == 'lfilter':
b, a = designed_filter
filtered_data = scipy.signal.lfilter(b=b, a=a, x=data, axis=axis)
elif filter_function == 'filtfilt':
b, a = designed_filter
filtered_data = scipy.signal.filtfilt(b=b, a=a, x=data, axis=axis)
else:
raise ValueError(
"filter_func must to be either 'filtfilt' or 'lfilter'"
)
filtered_data = scipy.signal.sosfiltfilt(sos=designed_filter,
x=data, axis=axis)

if isinstance(signal, neo.AnalogSignal):
return signal.duplicate_with_new_data(np.rollaxis(filtered_data, -1, 0))
filtered_data = np.rollaxis(filtered_data, -1, 0)
return signal.duplicate_with_new_data(filtered_data)
elif isinstance(signal, pq.quantity.Quantity):
return filtered_data * signal.units
else:
Expand Down
15 changes: 13 additions & 2 deletions elephant/test/test_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,13 +380,18 @@ def test_butter_filter_type(self):
self.assertAlmostEqual(psd[0, 256], 0)

def test_butter_filter_function(self):
"""
`elephant.signal_processing.butter` return values test for all
available filters (result has to be almost equal):
* lfilter
* filtfilt
* sosfiltfilt
"""
# generate white noise AnalogSignal
noise = neo.AnalogSignal(
np.random.normal(size=5000),
sampling_rate=1000 * pq.Hz, units='mV')

# test if the filter performance is as well with filftunc=lfilter as
# with filtfunc=filtfilt (i.e. default option)
kwds = {'signal': noise, 'highpass_freq': 250.0 * pq.Hz,
'lowpass_freq': None, 'filter_function': 'filtfilt'}
filtered_noise = elephant.signal_processing.butter(**kwds)
Expand All @@ -398,7 +403,13 @@ def test_butter_filter_function(self):
_, psd_lfilter = spsig.welch(
filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)

kwds['filter_function'] = 'sosfiltfilt'
filtered_noise = elephant.signal_processing.butter(**kwds)
_, psd_sosfiltfilt = spsig.welch(
filtered_noise.T, nperseg=1024, fs=1000.0, detrend=lambda x: x)

self.assertAlmostEqual(psd_filtfilt[0, 0], psd_lfilter[0, 0])
self.assertAlmostEqual(psd_filtfilt[0, 0], psd_sosfiltfilt[0, 0])

def test_butter_invalid_filter_function(self):
# generate a dummy AnalogSignal
Expand Down

0 comments on commit 3c30574

Please sign in to comment.