Skip to content

Commit

Permalink
Merge pull request #93 from ThomasA/dev_iswt_iswt2
Browse files Browse the repository at this point in the history
Added implementation of iswt and iswt2
  • Loading branch information
aaren committed Sep 26, 2015
2 parents bb00e0b + 9343316 commit b9146ab
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 7 deletions.
166 changes: 163 additions & 3 deletions pywt/multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

from __future__ import division, print_function, absolute_import

__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2']

import numpy as np

from ._pywt import Wavelet
from ._pywt import dwt, idwt, dwt_max_level
from .multidim import dwt2, idwt2

__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2', 'iswt', 'iswt2']


def wavedec(data, wavelet, mode='sym', level=None):
"""
Expand Down Expand Up @@ -182,7 +182,7 @@ def waverec2(coeffs, wavelet, mode='sym'):
"""
Multilevel 2D Inverse Discrete Wavelet Transform.
coeffs : array_like
coeffs : list or tuple
Coefficients list [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)]
wavelet : Wavelet object or name string
Wavelet to use
Expand Down Expand Up @@ -220,3 +220,163 @@ def waverec2(coeffs, wavelet, mode='sym'):
a = idwt2((a, d), wavelet, mode)

return a


def iswt(coeffs, wavelet):
"""
Multilevel 1D Inverse Discrete Stationary Wavelet Transform.
Parameters
----------
coeffs : array_like
Coefficients list of tuples::
[(cA1, cD1), (cA2, cD2), ..., (cAn, cDn)]
where cA is approximation, cD is details, and n is start_level.
wavelet : Wavelet object or name string
Wavelet to use
Returns
-------
1D array of reconstructed data.
Examples
--------
>>> import pywt
>>> coeffs = pywt.swt([1,2,3,4,5,6,7,8], 'db2', level=2)
>>> pywt.iswt(coeffs, 'db2')
array([ 1., 2., 3., 4., 5., 6., 7., 8.])
"""

output = coeffs[0][0].copy() # Avoid modification of input data

# num_levels, equivalent to the decomposition level, n
num_levels = len(coeffs)
for j in range(num_levels,0,-1):
step_size = int(pow(2, j-1))
last_index = step_size
_, cD = coeffs[num_levels - j]
for first in range(last_index): # 0 to last_index - 1

# Getting the indices that we will transform
indices = np.arange(first, len(cD), step_size)

# select the even indices
even_indices = indices[0::2]
# select the odd indices
odd_indices = indices[1::2]

# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
x1 = idwt(output[even_indices], cD[even_indices], wavelet, 'per')
x2 = idwt(output[odd_indices], cD[odd_indices], wavelet, 'per')

# perform a circular shift right
x2 = np.roll(x2, 1)

# average and insert into the correct indices
output[indices] = (x1 + x2)/2.

return output


def iswt2(coeffs, wavelet):
"""
Multilevel 2D Inverse Discrete Stationary Wavelet Transform.
Parameters
----------
coeffs : list
Approximation and details coefficients::
[
(cA_1,
(cH_1, cV_1, cD_1)
),
(cA_2,
(cH_2, cV_2, cD_2)
),
...,
(cA_n
(cH_n, cV_n, cD_n)
)
]
where cA is approximation, cH is horizontal details, cV is
vertical details, cD is diagonal details and n is number of
levels.
wavelet : Wavelet object or name string
Wavelet to use
Returns
-------
2D array of reconstructed data.
Examples
--------
>>> import pywt
>>> coeffs = coeffs = pywt.swt2([[1,2,3,4],[5,6,7,8],
[9,10,11,12],[13,14,15,16]],
'db1', level=2)
>>> pywt.iswt2(coeffs, 'db1')
array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]])
"""

output = coeffs[-1][0].copy() # Avoid modification of input data

# num_levels, equivalent to the decomposition level, n
num_levels = len(coeffs)
for j in range(num_levels,0,-1):
step_size = int(pow(2, j-1))
last_index = step_size
_, (cH, cV, cD) = coeffs[j-1]
# We are going to assume cH, cV, and cD are square and of equal size
if (cH.shape != cV.shape) or (cH.shape != cD.shape) or (cH.shape[0] != cH.shape[1]):
raise RuntimeError("Mismatch in shape of intermediate coefficient arrays")
for first_h in range(last_index): # 0 to last_index - 1
for first_w in range(last_index): # 0 to last_index - 1
# Getting the indices that we will transform
indices_h = slice(first_h, cH.shape[0], step_size)
indices_w = slice(first_w, cH.shape[1], step_size)

even_idx_h = slice(first_h, cH.shape[0], 2*step_size)
even_idx_w = slice(first_w, cH.shape[1], 2*step_size)
odd_idx_h = slice(first_h + step_size, cH.shape[0], 2*step_size)
odd_idx_w = slice(first_w + step_size, cH.shape[1], 2*step_size)

# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
x1 = idwt2((output[even_idx_h, even_idx_w],
(cH[even_idx_h, even_idx_w],
cV[even_idx_h, even_idx_w],
cD[even_idx_h, even_idx_w])),
wavelet, 'per')
x2 = idwt2((output[even_idx_h, odd_idx_w],
(cH[even_idx_h, odd_idx_w],
cV[even_idx_h, odd_idx_w],
cD[even_idx_h, odd_idx_w])),
wavelet, 'per')
x3 = idwt2((output[odd_idx_h, even_idx_w],
(cH[odd_idx_h, even_idx_w],
cV[odd_idx_h, even_idx_w],
cD[odd_idx_h, even_idx_w])),
wavelet, 'per')
x4 = idwt2((output[odd_idx_h, odd_idx_w],
(cH[odd_idx_h, odd_idx_w],
cV[odd_idx_h, odd_idx_w],
cD[odd_idx_h, odd_idx_w])),
wavelet, 'per')

# perform a circular shifts
x2 = np.roll(x2, 1, axis=1)
x3 = np.roll(x3, 1, axis=0)
x4 = np.roll(x4, 1, axis=0)
x4 = np.roll(x4, 1, axis=1)
output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4

return output
64 changes: 60 additions & 4 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def test_swt_decomposition():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
assert_allclose(cA1, [7.07106781, 5.65685425, 2.82842712, 0.70710678,
2.82842712, 7.07106781, 7.07106781, 6.36396103])
assert_allclose(cD1, [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
-5.65685425, 1.41421356, -1.41421356, 2.12132034])
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
2.82842712, 7.07106781, 7.07106781, 6.36396103]
assert_allclose(cA1, expected_cA1)
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
assert_allclose(cD1, expected_cD1)
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
assert_allclose(cA2, expected_cA2, rtol=1e-12)
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
Expand All @@ -62,6 +64,33 @@ def test_swt_decomposition():
assert_(pywt.swt_max_level(len(x)) == 3)


def test_swt_iswt_integration():
# This function performs a round-trip swt/iswt transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt or iswt as such, but it does ensure
# that they are each other's inverse.

max_level = 3
wavelets = pywt.wavelist()
if 'dmey' in wavelets:
# The 'dmey' wavelet seems to be a bit special - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length)
coeffs = pywt.swt(X, current_wavelet, max_level)
Y = pywt.iswt(coeffs, current_wavelet)
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)


def test_swt_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
Expand All @@ -80,6 +109,33 @@ def test_swt_dtypes():
"swt2: " + errmsg)


def test_swt2_iswt2_integration():
# This function performs a round-trip swt2/iswt2 transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt2 or iswt2 as such, but it does ensure
# that they are each other's inverse.

max_level = 3
wavelets = pywt.wavelist()
if 'dmey' in wavelets:
# The 'dmey' wavelet seems to be a bit special - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length**2).reshape(input_length, input_length)
coeffs = pywt.swt2(X, current_wavelet, max_level)
Y = pywt.iswt2(coeffs, current_wavelet)
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)


def test_wavedec2():
coeffs = pywt.wavedec2(np.ones((4, 4)), 'db1')
assert_(len(coeffs) == 3)
Expand Down

0 comments on commit b9146ab

Please sign in to comment.