Skip to content

Commit

Permalink
Merge pull request #547 from grlee77/ContinuousWavelet_fixes
Browse files Browse the repository at this point in the history
ContinuousWavelet: add tests for dtype and remove unused **kwargs
  • Loading branch information
rgommers committed Mar 28, 2020
2 parents 6c7bfa4 + a10c3b6 commit 6e3641b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
3 changes: 2 additions & 1 deletion doc/source/ref/wavelets.rst
Expand Up @@ -260,13 +260,14 @@ from plain Python lists of filter coefficients and a *filter bank-like* object.
``ContinuousWavelet`` object
----------------------------

.. class:: ContinuousWavelet(name)
.. class:: ContinuousWavelet(name, dtype=np.float64)

Describes properties of a continuous wavelet identified by the specified wavelet ``name``.
In order to use a built-in wavelet the ``name`` parameter must be a valid
wavelet name from the :func:`pywt.wavelist` list.

:param name: Wavelet name
:param dtype: numpy.dtype to use for the wavelet. Can be numpy.float64 or numpy.float32.

**Example:**

Expand Down
13 changes: 7 additions & 6 deletions pywt/_extensions/_pywt.pyx
Expand Up @@ -677,23 +677,24 @@ cdef public class Wavelet [type WaveletType, object WaveletObject]:

cdef public class ContinuousWavelet [type ContinuousWaveletType, object ContinuousWaveletObject]:
"""
ContinuousWavelet(name) object describe properties of
ContinuousWavelet(name, dtype) object describe properties of
a continuous wavelet identified by name.
In order to use a built-in wavelet the parameter name must be
a valid name from the wavelist() list.
"""
#cdef readonly properties
def __cinit__(self, name=u"", dtype = None, **kwargs):
def __cinit__(self, name=u"", dtype=np.float64):
cdef object family_code, family_number

# builtin wavelet
self.name = name.lower()
if (dtype is None):
self.dt = np.float64
else:
self.dt = dtype
self.dt = dtype
if np.dtype(self.dt) not in [np.float32, np.float64]:
raise ValueError(
"Only np.float32 and np.float64 dtype are supported for "
"ContinuousWavelet objects.")
if len(self.name) >= 4 and self.name[:4] in ['cmor', 'shan', 'fbsp']:
base_name = self.name[:4]
if base_name == self.name:
Expand Down
15 changes: 15 additions & 0 deletions pywt/tests/test_cwt_wavelets.py
Expand Up @@ -115,6 +115,21 @@ def test_gaus():
assert_allclose(X, x)


@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_continuous_wavelet_dtype(dtype):
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0', dtype)
int_psi, x = pywt.integrate_wavelet(wavelet)
assert int_psi.real.dtype == dtype
assert x.dtype == dtype


def test_continuous_wavelet_invalid_dtype():
with pytest.raises(ValueError):
pywt.ContinuousWavelet('gaus5', np.complex64)
with pytest.raises(ValueError):
pywt.ContinuousWavelet('gaus5', np.int)


def test_cgau():
LB = -5
UB = 5
Expand Down

0 comments on commit 6e3641b

Please sign in to comment.