Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ContinuousWavelet: add tests for dtype and remove unused **kwargs #547

Merged
merged 2 commits into from Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/ref/wavelets.rst
Expand Up @@ -260,13 +260,14 @@ from plain Python lists of filter coefficients and a *filter bank-like* object.
``ContinuousWavelet`` object
----------------------------

.. class:: ContinuousWavelet(name)
.. class:: ContinuousWavelet(name, dtype=np.float64)

Describes properties of a continuous wavelet identified by the specified wavelet ``name``.
In order to use a built-in wavelet the ``name`` parameter must be a valid
wavelet name from the :func:`pywt.wavelist` list.

:param name: Wavelet name
:param dtype: numpy.dtype to use for the wavelet. Can be numpy.float64 or numpy.float32.

**Example:**

Expand Down
13 changes: 7 additions & 6 deletions pywt/_extensions/_pywt.pyx
Expand Up @@ -677,23 +677,24 @@ cdef public class Wavelet [type WaveletType, object WaveletObject]:

cdef public class ContinuousWavelet [type ContinuousWaveletType, object ContinuousWaveletObject]:
"""
ContinuousWavelet(name) object describe properties of
ContinuousWavelet(name, dtype) object describe properties of
a continuous wavelet identified by name.

In order to use a built-in wavelet the parameter name must be
a valid name from the wavelist() list.

"""
#cdef readonly properties
def __cinit__(self, name=u"", dtype = None, **kwargs):
def __cinit__(self, name=u"", dtype=np.float64):
cdef object family_code, family_number

# builtin wavelet
self.name = name.lower()
if (dtype is None):
self.dt = np.float64
else:
self.dt = dtype
self.dt = dtype
if np.dtype(self.dt) not in [np.float32, np.float64]:
raise ValueError(
"Only np.float32 and np.float64 dtype are supported for "
"ContinuousWavelet objects.")
if len(self.name) >= 4 and self.name[:4] in ['cmor', 'shan', 'fbsp']:
base_name = self.name[:4]
if base_name == self.name:
Expand Down
15 changes: 15 additions & 0 deletions pywt/tests/test_cwt_wavelets.py
Expand Up @@ -115,6 +115,21 @@ def test_gaus():
assert_allclose(X, x)


@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_continuous_wavelet_dtype(dtype):
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0', dtype)
int_psi, x = pywt.integrate_wavelet(wavelet)
assert int_psi.real.dtype == dtype
assert x.dtype == dtype


def test_continuous_wavelet_invalid_dtype():
with pytest.raises(ValueError):
pywt.ContinuousWavelet('gaus5', np.complex64)
with pytest.raises(ValueError):
pywt.ContinuousWavelet('gaus5', np.int)


def test_cgau():
LB = -5
UB = 5
Expand Down