Skip to content

Commit

Permalink
Merge pull request #125 from kwohlfahrt/correct_length
Browse files Browse the repository at this point in the history
Disallow mismatching coefficient lengths in idwt2
  • Loading branch information
grlee77 committed Nov 24, 2015
2 parents f443dbd + b113ebc commit b3dcaa6
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 87 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
Changelog

0.4.0
changes:
- idwt no longer takes a 'correct_length' parameter.
- Sizes of the arrays passed to all idwt functions must match exactly.
- use 'multilevel.wavecrec' for multilevel transforms

0.3.0
A major refactoring, providing support for Python 3.x while maintaining
full backwards compatiblity.
Expand Down
5 changes: 5 additions & 0 deletions doc/release/0.4.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ Deprecated features
Backwards incompatible changes
==============================

``idwt`` no longer takes a ``correct_length`` parameter. As a consequence,
``idwt2`` inputs must match exactly in length. For multilevel transforms, where
arrays differing in size by one element may be produced, use the ``waverec``
functions from the ``multilevel`` module instead.


Other changes
=============
Expand Down
6 changes: 3 additions & 3 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def idwt2(coeffs, wavelet, mode='symmetric'):
# IDWT can handle None input values - equals to zero-array
HL = cycle([None])
for rowL, rowH in zip(LL, HL):
L.append(idwt(rowL, rowH, wavelet, mode, 1))
L.append(idwt(rowL, rowH, wavelet, mode))

H = []
if LH is None and HH is None:
Expand All @@ -174,7 +174,7 @@ def idwt2(coeffs, wavelet, mode='symmetric'):
# IDWT can handle None input values - equals to zero-array
HH = cycle([None])
for rowL, rowH in zip(LH, HH):
H.append(idwt(rowL, rowH, wavelet, mode, 1))
H.append(idwt(rowL, rowH, wavelet, mode))

if L is not None:
L = np.transpose(L)
Expand All @@ -190,7 +190,7 @@ def idwt2(coeffs, wavelet, mode='symmetric'):
# IDWT can handle None input values - equals to zero-array
H = cycle([None])
for rowL, rowH in zip(L, H):
data.append(idwt(rowL, rowH, wavelet, mode, 1))
data.append(idwt(rowL, rowH, wavelet, mode))

return np.array(data)

Expand Down
19 changes: 17 additions & 2 deletions pywt/_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def waverec(coeffs, wavelet, mode='symmetric'):
a, ds = coeffs[0], coeffs[1:]

for d in ds:
a = idwt(a, d, wavelet, mode, 1)
if len(a) == len(d) + 1:
a = a[:-1]
a = idwt(a, d, wavelet, mode)

return a

Expand Down Expand Up @@ -215,9 +217,22 @@ def waverec2(coeffs, wavelet, mode='symmetric'):
"Coefficient list too short (minimum 2 arrays required).")

a, ds = coeffs[0], coeffs[1:]
a = np.asarray(a)

for d in ds:
a = idwt2((a, d), wavelet, mode)
d = tuple(np.asarray(coeff) if coeff is not None else None
for coeff in d)
d_shapes = (coeff.shape for coeff in d if coeff is not None)
try:
d_shape = next(d_shapes)
except StopIteration:
idxs = slice(None), slice(None)
else:
if not all(s == d_shape for s in d_shapes):
raise ValueError("All detail shapes must be the same length.")
idxs = tuple(slice(None, -1 if a_len == d_len + 1 else None)
for a_len, d_len in zip(a.shape, d_shape))
a = idwt2((a[idxs], d), wavelet, mode)

return a

Expand Down
3 changes: 1 addition & 2 deletions pywt/_wavelet_packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,7 @@ def _reconstruct(self, update):
raise ValueError("Node is a leaf node and cannot be reconstructed"
" from subnodes.")
else:
rec = idwt(data_a, data_d, self.wavelet, self.mode,
correct_size=True)
rec = idwt(data_a, data_d, self.wavelet, self.mode)
if update:
self.data = rec
return rec
Expand Down
44 changes: 14 additions & 30 deletions pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -928,9 +928,9 @@ cdef np.dtype _check_dtype(data):
return dt


def idwt(cA, cD, object wavelet, object mode='symmetric', int correct_size=0):
def idwt(cA, cD, object wavelet, object mode='symmetric'):
"""
idwt(cA, cD, wavelet, mode='symmetric', correct_size=0)
idwt(cA, cD, wavelet, mode='symmetric')
Single level Inverse Discrete Wavelet Transform
Expand All @@ -946,12 +946,6 @@ def idwt(cA, cD, object wavelet, object mode='symmetric', int correct_size=0):
Wavelet to use
mode : str, optional (default: 'symmetric')
Signal extension mode, see Modes
correct_size : int, optional (default: 0)
Under normal conditions (all data lengths dyadic) `cA` and `cD`
coefficients lists must have the same lengths. With `correct_size`
set to True, length of `cA` may be greater by one than length of `cD`.
Useful when doing multilevel decomposition and reconstruction of
non-dyadic length signals.
Returns
-------
Expand All @@ -974,8 +968,8 @@ def idwt(cA, cD, object wavelet, object mode='symmetric', int correct_size=0):
elif cD is None:
cA = np.asarray(cA)
cD = np.zeros_like(cA)
return (idwt(cA.real, cD.real, wavelet, mode, correct_size) +
1j*idwt(cA.imag, cD.imag, wavelet, mode, correct_size))
return (idwt(cA.real, cD.real, wavelet, mode) +
1j*idwt(cA.imag, cD.imag, wavelet, mode))

if cA is not None:
dt = _check_dtype(cA)
Expand All @@ -998,12 +992,12 @@ def idwt(cA, cD, object wavelet, object mode='symmetric', int correct_size=0):
elif cD is None:
cD = np.zeros_like(cA)

return _idwt(cA, cD, wavelet, mode, correct_size)
return _idwt(cA, cD, wavelet, mode)


def _idwt(np.ndarray[data_t, ndim=1, mode="c"] cA,
np.ndarray[data_t, ndim=1, mode="c"] cD,
object wavelet, object mode='symmetric', int correct_size=0):
object wavelet, object mode='symmetric'):
"""See `idwt` for details"""

cdef index_t input_len
Expand All @@ -1016,20 +1010,10 @@ def _idwt(np.ndarray[data_t, ndim=1, mode="c"] cA,

cdef np.ndarray[data_t, ndim=1, mode="c"] rec
cdef index_t rec_len
cdef index_t size_diff

# check for size difference between arrays
size_diff = cA.size - cD.size
if size_diff:
if correct_size:
if size_diff < 0 or size_diff > 1:
msg = ("Coefficients arrays must satisfy "
"(0 <= len(cA) - len(cD) <= 1).")
raise ValueError(msg)
input_len = cA.size - size_diff
else:
msg = "Coefficients arrays must have the same size."
raise ValueError(msg)
if cA.size != cD.size:
raise ValueError("Coefficients arrays must have the same size.")
else:
input_len = cA.size

Expand All @@ -1050,15 +1034,15 @@ def _idwt(np.ndarray[data_t, ndim=1, mode="c"] cA,
# reconstruction of non-null part will be performed
if data_t is np.float64_t:
if c_wt.double_idwt(&cA[0], cA.size,
&cD[0], cD.size, w.w,
&rec[0], rec.size, mode_,
correct_size) < 0:
&cD[0], cD.size,
&rec[0], rec.size,
w.w, mode_) < 0:
raise RuntimeError("C idwt failed.")
elif data_t == np.float32_t:
if c_wt.float_idwt(&cA[0], cA.size,
&cD[0], cD.size, w.w,
&rec[0], rec.size, mode_,
correct_size) < 0:
&cD[0], cD.size,
&rec[0], rec.size,
w.w, mode_) < 0:
raise RuntimeError("C idwt failed.")
else:
raise RuntimeError("Invalid data type.")
Expand Down
6 changes: 2 additions & 4 deletions pywt/src/c_wt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ cdef extern from "wt.h":

cdef int double_idwt(double * const coeffs_a, const size_t coeffs_a_len,
double * const coeffs_d, const size_t coeffs_d_len,
const Wavelet * const wavelet,
double * const output, const size_t output_len,
const MODE mode, const int correct_size)
const Wavelet * const wavelet, const MODE mode)

cdef int double_swt_a(double input[], index_t input_len, Wavelet* wavelet,
double output[], index_t output_len, int level)
Expand Down Expand Up @@ -71,9 +70,8 @@ cdef extern from "wt.h":

cdef int float_idwt(const float * const coeffs_a, const size_t coeffs_a_len,
const float * const coeffs_d, const size_t coeffs_d_len,
const Wavelet * const wavelet,
float * const output, const size_t output_len,
const MODE mode, const int correct_size)
const Wavelet * const wavelet, const MODE mode)

cdef int float_swt_a(float input[], index_t input_len, Wavelet* wavelet,
float output[], index_t output_len, int level)
Expand Down
38 changes: 7 additions & 31 deletions pywt/src/wt.template.c
Original file line number Diff line number Diff line change
Expand Up @@ -340,48 +340,24 @@ int CAT(TYPE, _rec_d)(const TYPE * const restrict coeffs_d, const size_t coeffs_


/*
* IDWT reconstruction from approximation and detail coeffs
* IDWT reconstruction from approximation and detail coeffs, either of which may
* be NULL.
*
* If fix_size_diff is 1 then coeffs arrays can differ by one in length (this
* is useful in multilevel decompositions and reconstructions of odd-length
* signals). Requires zero-filled output buffer.
* Requires zero-filled output buffer.
*/
int CAT(TYPE, _idwt)(const TYPE * const restrict coeffs_a, const size_t coeffs_a_len,
const TYPE * const restrict coeffs_d, const size_t coeffs_d_len,
const Wavelet * const restrict wavelet,
TYPE * const restrict output, const size_t output_len,
const MODE mode, const int fix_size_diff){

const Wavelet * const restrict wavelet, const MODE mode){
size_t input_len;

/*
* If one of coeffs array is NULL then the reconstruction will be performed
* using the other one
*/

if(coeffs_a != NULL && coeffs_d != NULL){

if(fix_size_diff){
if( (coeffs_a_len > coeffs_d_len ? coeffs_a_len - coeffs_d_len
: coeffs_d_len-coeffs_a_len) > 1){ /* abs(a-b) */
goto error;
}

input_len = coeffs_a_len>coeffs_d_len ? coeffs_d_len
: coeffs_a_len; /* min */
} else {
if(coeffs_a_len != coeffs_d_len)
goto error;

input_len = coeffs_a_len;
}

if(coeffs_a_len != coeffs_d_len)
goto error;
input_len = coeffs_a_len;
} else if(coeffs_a != NULL){
input_len = coeffs_a_len;

} else if (coeffs_d != NULL){
input_len = coeffs_d_len;

} else {
goto error;
}
Expand Down
3 changes: 1 addition & 2 deletions pywt/src/wt.template.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ int CAT(TYPE, _rec_d)(const TYPE * const restrict coeffs_d, const size_t coeffs_
/* Single level IDWT reconstruction */
int CAT(TYPE, _idwt)(const TYPE * const restrict coeffs_a, const size_t coeffs_a_len,
const TYPE * const restrict coeffs_d, const size_t coeffs_d_len,
const Wavelet * const wavelet,
TYPE * const restrict output, const size_t output_len,
const MODE mode, const int fix_size_diff);
const Wavelet * const wavelet, const MODE mode);

/* SWT decomposition at given level */
int CAT(TYPE, _swt_a)(TYPE input[], index_t input_len,
Expand Down
13 changes: 0 additions & 13 deletions pywt/tests/test_dwt_idwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,6 @@ def test_idwt_none_input():
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')


def test_idwt_correct_size_kw():
res = pywt.idwt([1, 2, 3, 4, 5], [1, 2, 3, 4], 'db2', 'symmetric',
correct_size=True)
expected = [1.76776695, 0.61237244, 3.18198052, 0.61237244, 4.59619408,
0.61237244]
assert_allclose(res, expected)

assert_raises(ValueError, pywt.idwt,
[1, 2, 3, 4, 5], [1, 2, 3, 4], 'db2', 'symmetric')
assert_raises(ValueError, pywt.idwt, [1, 2, 3, 4], [1, 2, 3, 4, 5], 'db2',
'symmetric', correct_size=True)


def test_idwt_invalid_input():
# Too short, min length is 4 for 'db4':
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
Expand Down
7 changes: 7 additions & 0 deletions pywt/tests/test_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,5 +226,12 @@ def test_dwtn_idwtn_dtypes():
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)


def test_idwt2_size_mismatch_error():
LL = np.zeros((6, 6))
LH = HL = HH = np.zeros((5, 5))

assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')


if __name__ == '__main__':
run_module_suite()
19 changes: 19 additions & 0 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def test_waverec():
assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12)


def test_waverec_odd_length():
x = [3, 7, 1, 1, -2, 5]
coeffs = pywt.wavedec(x, 'db1')
assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12)


def test_waverec_complex():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 1j
Expand Down Expand Up @@ -174,5 +180,18 @@ def test_wavedec2_complex():
assert_allclose(pywt.waverec2(coeffs, 'db1'), data, rtol=1e-12)


def test_waverec2_odd_length():
x = np.ones((10, 6))
coeffs = pywt.wavedec2(x, 'db1')
assert_allclose(pywt.waverec2(coeffs, 'db1'), x, rtol=1e-12)


def test_waverec2_none_coeffs():
x = np.arange(24).reshape(6, 4)
coeffs = pywt.wavedec2(x, 'db1')
coeffs[1] = (None, None, None)
assert_(x.shape == pywt.waverec2(coeffs, 'db1').shape)


if __name__ == '__main__':
run_module_suite()

0 comments on commit b3dcaa6

Please sign in to comment.