Skip to content

Commit

Permalink
fix/multitaper_psd_tests (NeuralEnsemble#529)
Browse files Browse the repository at this point in the history
* Fixed the intermittently failing test_multitaper_psd_behavior
* Added new unit tests
* Fixed multitaper_psd docstring
  • Loading branch information
rjurkus committed Nov 25, 2022
1 parent d12d80f commit 4de1232
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 21 deletions.
2 changes: 1 addition & 1 deletion elephant/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def multitaper_psd(signal, n_segments=1, len_segment=None,
Parameters
----------
signal : neo.AnalogSignal
signal : neo.AnalogSignal or pq.Quantity or np.ndarray
Time series data of which PSD is estimated. When `signal` is np.ndarray
sampling frequency should be given through keyword argument `fs`.
Signal should be passed as (n_channels, n_samples)
Expand Down
113 changes: 93 additions & 20 deletions elephant/test/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import scipy.signal as spsig
import scipy.fft
import quantities as pq
import neo.core as n
from numpy.testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -170,32 +171,66 @@ def test_welch_psd_multidim_input(self):
class MultitaperPSDTestCase(unittest.TestCase):
def test_multitaper_psd_errors(self):
# generate dummy data
signal = n.AnalogSignal(np.zeros(5000), sampling_period=0.001 * pq.s,
data_length = 5000
signal = n.AnalogSignal(np.zeros(data_length),
sampling_period=0.001 * pq.s,
units='mV')
fs = 1000 * pq.Hz
nw = 3
fs = signal.sampling_rate

# check for invalid parameter values
# - number of tapers
self.assertRaises(ValueError, elephant.spectral.multitaper_psd, signal,
fs, nw, num_tapers=-5)
num_tapers=-5)
self.assertRaises(TypeError, elephant.spectral.multitaper_psd, signal,
fs, nw, num_tapers=-5.0)
# - frequency resolution
num_tapers=-5.0)
# - peak resolution
self.assertRaises(ValueError, elephant.spectral.multitaper_psd, signal,
fs, nw, peak_resolution=-1)
peak_resolution=-1)

# - frequency resolution
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
frequency_resolution=-10)

# - n per segment
# n_per_seg = int(fs / dF), where dF is the frequency_resolution
broken_freq_resolution = fs / (data_length+1)
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
frequency_resolution=broken_freq_resolution)

# - length of segment (negative)
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
len_segment=-10)

# - length of segment (larger than data length)
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
len_segment=data_length+1)

# - number of segments (negative)
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
n_segments=-10)

# - number of segments (larger than data length)
self.assertRaises(ValueError,
elephant.spectral.multitaper_psd, signal,
n_segments=data_length+1)

def test_multitaper_psd_behavior(self):
# generate data by adding white noise and a sinusoid
data_length = 5000
# generate data (frequency domain to time domain)
r = np.ones(2501) * 0.2
r[0], r[500] = 0, 10 # Zero DC, peak at 100 Hz
phi = np.random.uniform(-np.pi, np.pi, len(r))
fake_coeffs = r*np.exp(1j * phi)
fake_ts = scipy.fft.irfft(fake_coeffs)
sampling_period = 0.001
signal_freq = 100.0
noise = np.random.normal(size=data_length)
signal = [np.sin(2 * np.pi * signal_freq * t)
for t in np.arange(0, data_length * sampling_period,
sampling_period)]
data = n.AnalogSignal(np.array(signal + noise),
sampling_period=sampling_period * pq.s,
freqs = scipy.fft.rfftfreq(len(fake_ts), d=sampling_period)
signal_freq = freqs[r.argmax()]

data = n.AnalogSignal(fake_ts, sampling_period=sampling_period * pq.s,
units='mV')

# consistency between different ways of specifying number of tapers
Expand All @@ -208,14 +243,28 @@ def test_multitaper_psd_behavior(self):
num_tapers=6)
self.assertTrue((psd1 == psd2).all() and (freqs1 == freqs2).all())

# frequency resolution and consistency with data
freq_res = 1.0 * pq.Hz
# consistency between different ways of specifying n_per_seg
# n_per_seg = int(fs/dF) and n_per_seg = len_segment
frequency_resolution = 1 * pq.Hz
len_segment = int(data.sampling_rate / frequency_resolution)

freqs_fr, psd_fr = elephant.spectral.multitaper_psd(
data, frequency_resolution=frequency_resolution)

freqs_ls, psd_ls = elephant.spectral.multitaper_psd(
data, len_segment=len_segment)

np.testing.assert_array_equal(freqs_fr, freqs_ls)
np.testing.assert_array_equal(psd_fr, psd_ls)

# peak resolution and consistency with data
peak_res = 1.0 * pq.Hz
freqs, psd = elephant.spectral.multitaper_psd(
data, peak_resolution=freq_res)
data, peak_resolution=peak_res)
self.assertEqual(freqs[psd.argmax()], signal_freq)
freqs_np, psd_np = elephant.spectral.multitaper_psd(
data.magnitude.flatten(), fs=1 / sampling_period,
peak_resolution=freq_res)
peak_resolution=peak_res)
self.assertTrue((freqs == freqs_np).all() and (psd == psd_np).all())

def test_multitaper_psd_parameter_hierarchy(self):
Expand Down Expand Up @@ -306,6 +355,30 @@ def test_multitaper_psd_input_types(self):
self.assertFalse(isinstance(freqs_np, pq.quantity.Quantity))
self.assertFalse(isinstance(psd_np, pq.quantity.Quantity))

# frequency resolution with and without units
freq_res_hz = 1 * pq.Hz
freq_res_int = 1

freqs_int, psd_int = elephant.spectral.multitaper_psd(
data, frequency_resolution=freq_res_int)

freqs_hz, psd_hz = elephant.spectral.multitaper_psd(
data, frequency_resolution=freq_res_hz)

np.testing.assert_array_equal(freqs_int, freqs_hz)
np.testing.assert_array_equal(psd_int, psd_hz)

# fs with and without units
fs_hz = 1 * pq.Hz
fs_int = 1
freqs_fs_hz, psd_fs_hz = elephant.spectral.multitaper_psd(
data.magnitude.T, fs=fs_hz)
freqs_fs_int, psd_fs_int = elephant.spectral.multitaper_psd(
data.magnitude.T, fs=fs_int)

np.testing.assert_array_equal(freqs_fs_hz, freqs_fs_int)
np.testing.assert_array_equal(psd_fs_hz, psd_fs_int)

# check if the results from different input types are identical
self.assertTrue(
(freqs_neo == freqs_pq).all() and (
Expand Down

0 comments on commit 4de1232

Please sign in to comment.