Skip to content

Commit

Permalink
Merge pull request #98 from grlee77/downcoef_fix2
Browse files Browse the repository at this point in the history
BUG: fix downcoef detail coefficients for level > 1
  • Loading branch information
rgommers committed Nov 22, 2015
2 parents da1c6b4 + 9562db8 commit e2426ae
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
20 changes: 15 additions & 5 deletions pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,10 @@ def _upcoef(part, np.ndarray[data_t, ndim=1, mode="c"] coeffs, wavelet,
# reconstruct
rec = np.zeros(rec_len, dtype=coeffs.dtype)

# To mirror multi-level wavelet reconstruction behaviour, when detail
# reconstruction is requested, the dec_d variant is only called at the
# first level to generate the approximation coefficients at the second
# level. Subsequent levels apply the reconstruction filter.
if do_rec_a:
if data_t is np.float64_t:
if c_wt.double_rec_a(&coeffs[0], coeffs.size, w.w,
Expand All @@ -1112,13 +1116,14 @@ def _upcoef(part, np.ndarray[data_t, ndim=1, mode="c"] coeffs, wavelet,
if data_t is np.float64_t:
if c_wt.double_rec_d(&coeffs[0], coeffs.size, w.w,
&rec[0], rec.size) < 0:
raise RuntimeError("C rec_a failed.")
raise RuntimeError("C rec_d failed.")
elif data_t is np.float32_t:
if c_wt.float_rec_d(&coeffs[0], coeffs.size, w.w,
&rec[0], rec.size) < 0:
raise RuntimeError("C rec_a failed.")
raise RuntimeError("C rec_d failed.")
else:
raise RuntimeError("Invalid data type.")
# switch to approximation filter for subsequent levels
do_rec_a = 1

# TODO: this algorithm needs some explaining
Expand Down Expand Up @@ -1204,7 +1209,12 @@ def _downcoef(part, np.ndarray[data_t, ndim=1, mode="c"] data,
raise RuntimeError("Invalid output length.")
coeffs = np.zeros(output_len, dtype=data.dtype)

if do_dec_a:
# To mirror multi-level wavelet decomposition behaviour, when detail
# coefficients are requested, the dec_d variant is only called at the
# final level. All prior levels use dec_a. In other words, the detail
# coefficients at level n are those produced via the operation of the
# detail filter on the approximation coefficients of level n-1.
if do_dec_a or (i < level - 1):
if data_t is np.float64_t:
if c_wt.double_dec_a(&data[0], data.size, w.w,
&coeffs[0], coeffs.size, mode_) < 0:
Expand All @@ -1219,11 +1229,11 @@ def _downcoef(part, np.ndarray[data_t, ndim=1, mode="c"] data,
if data_t is np.float64_t:
if c_wt.double_dec_d(&data[0], data.size, w.w,
&coeffs[0], coeffs.size, mode_) < 0:
raise RuntimeError("C dec_a failed.")
raise RuntimeError("C dec_d failed.")
elif data_t is np.float32_t:
if c_wt.float_dec_d(&data[0], data.size, w.w,
&coeffs[0], coeffs.size, mode_) < 0:
raise RuntimeError("C dec_a failed.")
raise RuntimeError("C dec_d failed.")
else:
raise RuntimeError("Invalid data type.")
data = coeffs
Expand Down
23 changes: 19 additions & 4 deletions pywt/tests/test__pywt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,41 @@ def test_upcoef_reconstruct():


def test_downcoef_multilevel():
r = np.random.randn(16)
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.downcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.downcoef('a', r, 'haar', level=3)
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)


def test_compare_downcoef_coeffs():
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
# compare downcoef against wavedec outputs
for nlevels in [1, 2, 3]:
for wavelet in pywt.wavelist():
a = pywt.downcoef('a', r, wavelet, level=nlevels)
d = pywt.downcoef('d', r, wavelet, level=nlevels)
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
assert_allclose(a, coeffs[0])
assert_allclose(d, coeffs[1])


def test_upcoef_multilevel():
r = np.random.randn(4)
rstate = np.random.RandomState(1234)
r = rstate.randn(4)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.upcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.upcoef('a', r, 'haar', level=3)
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)


Expand Down

0 comments on commit e2426ae

Please sign in to comment.