Skip to content

Commit

Permalink
Merge pull request #67 from grlee77/wavedecn_v2
Browse files Browse the repository at this point in the history
ENH: add wavedecn and waverecn functions
  • Loading branch information
Kai committed Dec 17, 2015
2 parents 136a029 + af75758 commit 2a139a3
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 59 deletions.
32 changes: 29 additions & 3 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
raise ValueError("Expected 2 axes")

coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

return idwtn(coeffs, wavelet, mode, axes)


Expand Down Expand Up @@ -168,6 +172,29 @@ def dwtn(data, wavelet, mode='symmetric', axes=None):
return dict(coeffs)


def _fix_coeffs(coeffs):
missing_keys = [k for k, v in coeffs.items() if
v is None]
if missing_keys:
raise ValueError(
"The following detail coefficients were set to None: "
"{}.".format(missing_keys))

invalid_keys = [k for k, v in coeffs.items() if
not set(k) <= set('ad')]
if invalid_keys:
raise ValueError(
"The following invalid keys were found in the detail "
"coefficient dictionary: {}.".format(invalid_keys))

key_lengths = [len(k) for k in coeffs.keys()]
if len(np.unique(key_lengths)) > 1:
raise ValueError(
"All detail coefficient names must have equal length.")

return dict((k, np.asarray(v)) for k, v in coeffs.items())


def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
"""
Single-level n-dimensional Inverse Discrete Wavelet Transform.
Expand Down Expand Up @@ -200,9 +227,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
wavelet = Wavelet(wavelet)
mode = Modes.from_object(mode)

# Ignore any invalid keys
coeffs = dict((k, np.asarray(v)) for k, v in coeffs.items()
if v is not None and set(k) <= set('ad'))
# Raise error for invalid key combinations
coeffs = _fix_coeffs(coeffs)

if any(np.iscomplexobj(v) for v in coeffs.values()):
real_coeffs = dict((k, v.real) for k, v in coeffs.items())
Expand Down
235 changes: 213 additions & 22 deletions pywt/_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,28 @@

from ._pywt import Wavelet
from ._pywt import dwt, idwt, dwt_max_level
from ._multidim import dwt2, idwt2
from ._multidim import dwt2, idwt2, dwtn, idwtn, _fix_coeffs

__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2', 'iswt', 'iswt2']
__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2', 'wavedecn',
'waverecn', 'iswt', 'iswt2']


def _check_level(size, dec_len, level):
"""
Set the default decomposition level or check if requested level is valid.
"""
if level is None:
level = dwt_max_level(size, dec_len)
elif level < 0:
raise ValueError(
"Level value of %d is too low . Minimum level is 0." % level)
else:
max_level = dwt_max_level(size, dec_len)
if level > max_level:
raise ValueError(
"Level value of %d is too high. Maximum allowed is %d." % (
level, max_level))
return level


def wavedec(data, wavelet, mode='symmetric', level=None):
Expand All @@ -32,8 +51,8 @@ def wavedec(data, wavelet, mode='symmetric', level=None):
mode : str, optional
Signal extension mode, see Modes (default: 'symmetric')
level : int, optional
Decomposition level. If level is None (default) then it will be
calculated using `dwt_max_level` function.
Decomposition level (must be >= 0). If level is None (default) then it
will be calculated using the ``dwt_max_level`` function.
Returns
-------
Expand All @@ -56,15 +75,12 @@ def wavedec(data, wavelet, mode='symmetric', level=None):
array([ 5., 13.])
"""
data = np.asarray(data)

if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)

if level is None:
level = dwt_max_level(len(data), wavelet.dec_len)
elif level < 0:
raise ValueError(
"Level value of %d is too low . Minimum level is 0." % level)
level = _check_level(min(data.shape), wavelet.dec_len, level)

coeffs_list = []

Expand Down Expand Up @@ -103,9 +119,12 @@ def waverec(coeffs, wavelet, mode='symmetric'):
if not isinstance(coeffs, (list, tuple)):
raise ValueError("Expected sequence of coefficient arrays.")

if len(coeffs) < 2:
if len(coeffs) < 1:
raise ValueError(
"Coefficient list too short (minimum 2 arrays required).")
"Coefficient list too short (minimum 1 arrays required).")
elif len(coeffs) == 1:
# level 0 transform (just returns the approximation coefficients)
return coeffs[0]

a, ds = coeffs[0], coeffs[1:]

Expand All @@ -130,8 +149,8 @@ def wavedec2(data, wavelet, mode='symmetric', level=None):
mode : str, optional
Signal extension mode, see Modes (default: 'symmetric')
level : int, optional
Decomposition level. If level is None (default) then it will be
calculated using `dwt_max_level` function.
Decomposition level (must be >= 0). If level is None (default) then it
will be calculated using the ``dwt_max_level`` function.
Returns
-------
Expand All @@ -151,7 +170,6 @@ def wavedec2(data, wavelet, mode='symmetric', level=None):
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]])
"""

data = np.asarray(data)

if data.ndim != 2:
Expand All @@ -160,12 +178,7 @@ def wavedec2(data, wavelet, mode='symmetric', level=None):
if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)

if level is None:
size = min(data.shape)
level = dwt_max_level(size, wavelet.dec_len)
elif level < 0:
raise ValueError(
"Level value of %d is too low . Minimum level is 0." % level)
level = _check_level(min(data.shape), wavelet.dec_len, level)

coeffs_list = []

Expand Down Expand Up @@ -212,9 +225,12 @@ def waverec2(coeffs, wavelet, mode='symmetric'):
if not isinstance(coeffs, (list, tuple)):
raise ValueError("Expected sequence of coefficient arrays.")

if len(coeffs) < 2:
if len(coeffs) < 1:
raise ValueError(
"Coefficient list too short (minimum 2 arrays required).")
"Coefficient list too short (minimum 1 array required).")
elif len(coeffs) == 1:
# level 0 transform (just returns the approximation coefficients)
return coeffs[0]

a, ds = coeffs[0], coeffs[1:]
a = np.asarray(a)
Expand Down Expand Up @@ -399,3 +415,178 @@ def iswt2(coeffs, wavelet):
output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4

return output


def wavedecn(data, wavelet, mode='symmetric', level=None):
"""
Multilevel nD Discrete Wavelet Transform.
Parameters
----------
data : ndarray
nD input data
wavelet : Wavelet object or name string
Wavelet to use
mode : str, optional
Signal extension mode, see MODES (default: 'sym')
level : int, optional
Dxecomposition level (must be >= 0). If level is None (default) then it
will be calculated using the ``dwt_max_level`` function.
Returns
-------
[cAn, {details_level_n}, ... {details_level_1}] : list
Coefficients list
Examples
--------
>>> from pywt import multilevel
>>> coeffs = multilevel.wavedecn(np.ones((4, 4, 4)), 'db1')
>>> # Levels:
>>> len(coeffs)-1
3
>>> multilevel.waverecn(coeffs, 'db1')
array([[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]]])
"""
data = np.asarray(data)

if len(data.shape) < 1:
raise ValueError("Expected at least 1D input data.")

if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)

level = _check_level(min(data.shape), wavelet.dec_len, level)
coeffs_list = []

a = data
for i in range(level):
coeffs = dwtn(a, wavelet, mode)
a = coeffs.pop('a' * data.ndim)
coeffs_list.append(coeffs)

coeffs_list.append(a)
coeffs_list.reverse()

return coeffs_list


def _match_coeff_dims(a_coeff, d_coeff_dict):
# For each axis, compare the approximation coeff shape to one of the
# stored detail coeffs and truncate the last element along the axis
# if necessary.
if a_coeff is None:
return None
if not d_coeff_dict:
return a_coeff
d_coeff = d_coeff_dict[next(iter(d_coeff_dict))]
size_diffs = np.subtract(a_coeff.shape, d_coeff.shape)
if np.any((size_diffs < 0) | (size_diffs > 1)):
raise ValueError("incompatible coefficient array sizes")
return a_coeff[[slice(s) for s in d_coeff.shape]]


def waverecn(coeffs, wavelet, mode='symmetric'):
"""
Multilevel nD Inverse Discrete Wavelet Transform.
coeffs : array_like
Coefficients list [cAn, {details_level_n}, ... {details_level_1}]
wavelet : Wavelet object or name string
Wavelet to use
mode : str, optional
Signal extension mode, see MODES (default: 'sym')
Returns
-------
nD array of reconstructed data.
Examples
--------
>>> from pywt import multilevel
>>> coeffs = multilevel.wavedecn(np.ones((4, 4, 4)), 'db1')
>>> # Levels:
>>> len(coeffs)-1
2
>>> multilevel.waverecn(coeffs, 'db1')
array([[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]]])
"""
if len(coeffs) < 1:
raise ValueError(
"Coefficient list too short (minimum 1 array required).")

a, ds = coeffs[0], coeffs[1:]

# Raise error for invalid key combinations
ds = list(map(_fix_coeffs, ds))

if not ds:
# level 0 transform (just returns the approximation coefficients)
return coeffs[0]
if a is None and not any(ds):
raise ValueError("At least one coefficient must contain a valid value.")

coeff_ndims = []
if a is not None:
a = np.asarray(a)
coeff_ndims.append(a.ndim)
for d in ds:
coeff_ndims += [v.ndim for k, v in d.items()]

# test that all coefficients have a matching number of dimensions
unique_coeff_ndims = np.unique(coeff_ndims)
if len(unique_coeff_ndims) == 1:
ndim = unique_coeff_ndims[0]
else:
raise ValueError(
"All coefficients must have a matching number of dimensions")

for idx, d in enumerate(ds):
if a is None and not d:
continue
# The following if statement handles the case where the approximation
# coefficient returned at the previous level may exceed the size of the
# stored detail coefficients by 1 on any given axis.
if idx > 0:
a = _match_coeff_dims(a, d)
d['a' * ndim] = a
a = idwtn(d, wavelet, mode)

return a
13 changes: 8 additions & 5 deletions pywt/tests/test__pywt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def test_compare_downcoef_coeffs():
# compare downcoef against wavedec outputs
for nlevels in [1, 2, 3]:
for wavelet in pywt.wavelist():
a = pywt.downcoef('a', r, wavelet, level=nlevels)
d = pywt.downcoef('d', r, wavelet, level=nlevels)
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
assert_allclose(a, coeffs[0])
assert_allclose(d, coeffs[1])
wavelet = pywt.Wavelet(wavelet)
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
if nlevels <= max_level:
a = pywt.downcoef('a', r, wavelet, level=nlevels)
d = pywt.downcoef('d', r, wavelet, level=nlevels)
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
assert_allclose(a, coeffs[0])
assert_allclose(d, coeffs[1])


def test_upcoef_multilevel():
Expand Down

0 comments on commit 2a139a3

Please sign in to comment.