Skip to content

Commit

Permalink
Merge pull request #99 from grlee77/dwt_complex
Browse files Browse the repository at this point in the history
complex support in all dwt and idwt related functions
  • Loading branch information
aaren committed Aug 14, 2015
2 parents e3bab45 + aa9ca60 commit f188e57
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 17 deletions.
37 changes: 35 additions & 2 deletions pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,12 @@ def dwt(object data, object wavelet, object mode='sym'):
[-0.70710678 -0.70710678 -0.70710678]
"""
if np.iscomplexobj(data):
data = np.asarray(data)
cA_r, cD_r = dwt(data.real, wavelet, mode)
cA_i, cD_i = dwt(data.imag, wavelet, mode)
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)

# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.array(data, dtype=dt)
Expand Down Expand Up @@ -908,6 +914,18 @@ def idwt(cA, cD, object wavelet, object mode='sym', int correct_size=0):
raise ValueError("At least one coefficient parameter must be "
"specified.")

# for complex inputs: compute real and imaginary separately then combine
if ((cA is not None) and np.iscomplexobj(cA)) or ((cD is not None) and
np.iscomplexobj(cD)):
if cA is None:
cD = np.asarray(cD)
cA = np.zeros_like(cD)
elif cD is None:
cA = np.asarray(cA)
cD = np.zeros_like(cA)
return (idwt(cA.real, cD.real, wavelet, mode, correct_size) +
1j*idwt(cA.imag, cD.imag, wavelet, mode, correct_size))

if cA is not None:
dt = _check_dtype(cA)
cA = np.array(cA, dtype=dt)
Expand All @@ -925,9 +943,9 @@ def idwt(cA, cD, object wavelet, object mode='sym', int correct_size=0):
cA = cA.astype(np.float64)
cD = cD.astype(np.float64)
elif cA is None:
cA = np.zeros(cD.shape, dtype=cD.dtype)
cA = np.zeros_like(cD)
elif cD is None:
cD = np.zeros(cA.shape, dtype=cA.dtype)
cD = np.zeros_like(cA)

return _idwt(cA, cD, wavelet, mode, correct_size)

Expand Down Expand Up @@ -1044,6 +1062,9 @@ def upcoef(part, coeffs, wavelet, level=1, take=0):
[ 1. 2. 3. 4. 5. 6.]
"""
if np.iscomplexobj(coeffs):
return (upcoef(part, coeffs.real, wavelet, level, take) +
1j*upcoef(part, coeffs.imag, wavelet, level, take))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(coeffs)
coeffs = np.array(coeffs, dtype=dt)
Expand Down Expand Up @@ -1151,6 +1172,9 @@ def downcoef(part, data, wavelet, mode='sym', level=1):
upcoef
"""
if np.iscomplexobj(data):
return (downcoef(part, data.real, wavelet, mode, level) +
1j*downcoef(part, data.imag, wavelet, mode, level))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.array(data, dtype=dt)
Expand Down Expand Up @@ -1266,6 +1290,15 @@ def swt(data, object wavelet, object level=None, int start_level=0):
[(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)]
"""
if np.iscomplexobj(data):
data = np.asarray(data)
coeffs_real = swt(data.real, wavelet, level, start_level)
coeffs_imag = swt(data.imag, wavelet, level, start_level)
coeffs_cplx = []
for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag):
coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i))
return coeffs_cplx

# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.array(data, dtype=dt)
Expand Down
25 changes: 22 additions & 3 deletions pywt/tests/test_dwt_idwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

import pywt

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
dtypes_in = [np.int8, np.float32, np.float64, np.complex64, np.complex128]
dtypes_out = [np.float64, np.float32, np.float64, np.complex64, np.complex128]


def test_dwt_idwt_basic():
x = [3, 7, 1, 1, -2, 5, 4, 6]
Expand All @@ -22,9 +27,6 @@ def test_dwt_idwt_basic():


def test_dwt_idwt_dtypes():
# Check that float32 is preserved. Other types get converted to float64.
dtypes_in = [np.int8, np.float32, np.float64]
dtypes_out = [np.float64, np.float32, np.float64]
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones(4, dtype=dt_in)
Expand All @@ -37,6 +39,23 @@ def test_dwt_idwt_dtypes():
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)


def test_dwt_idwt_basic_complex():
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 0.5j*x
cA, cD = pywt.dwt(x, 'db2')
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
7.77817459])
cA_expect = cA_expect + 0.5j*cA_expect
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
1.22474487])
cD_expect = cD_expect + 0.5j*cD_expect
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)

x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)


def test_dwt_input_error():
data = np.ones((16, 1))
assert_raises(ValueError, pywt.dwt, data, 'haar')
Expand Down
49 changes: 43 additions & 6 deletions pywt/tests/test_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

import pywt

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
dtypes_in = [np.int8, np.float32, np.float64, np.complex64, np.complex128]
dtypes_out = [np.float64, np.float32, np.float64, np.complex64, np.complex128]


def test_dwtn_input():
# Array-like must be accepted
Expand Down Expand Up @@ -74,6 +79,27 @@ def test_byte_offset():
assert_allclose(padded_dwtn[key], expected[key])


def test_3D_reconstruct_complex():
# All dimensions even length so `take` does not need to be specified
data = np.array([
[[0, 4, 1, 5, 1, 4],
[0, 5, 26, 3, 2, 1],
[5, 8, 2, 33, 4, 9],
[2, 5, 19, 4, 19, 1]],
[[1, 5, 1, 2, 3, 4],
[7, 12, 6, 52, 7, 8],
[2, 12, 3, 52, 6, 8],
[5, 2, 6, 78, 12, 2]]])
data = data + 1j

wavelet = pywt.Wavelet('haar')
d = pywt.dwtn(data, wavelet)
# idwtn creates even-length shapes (2x dwtn size)
original_shape = [slice(None, s) for s in data.shape]
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
rtol=1e-13, atol=1e-13)


def test_idwtn_idwt2():
data = np.array([
[0, 4, 1, 5, 1, 4],
Expand All @@ -91,6 +117,23 @@ def test_idwtn_idwt2():
rtol=1e-14, atol=1e-14)


def test_idwtn_idwt2_complex():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
data = data + 1j
wavelet = pywt.Wavelet('haar')

LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}

for mode in pywt.MODES.modes:
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-14, atol=1e-14)


def test_idwtn_missing():
# Test to confirm missing data behave as zeroes
data = np.array([
Expand Down Expand Up @@ -156,9 +199,6 @@ def test_error_mismatched_size():


def test_dwt2_idwt2_dtypes():
# Check that float32 is preserved. Other types get converted to float64.
dtypes_in = [np.int8, np.float32, np.float64]
dtypes_out = [np.float64, np.float32, np.float64]
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
Expand All @@ -173,9 +213,6 @@ def test_dwt2_idwt2_dtypes():


def test_dwtn_idwtn_dtypes():
# Check that float32 is preserved. Other types get converted to float64.
dtypes_in = [np.int8, np.float32, np.float64]
dtypes_out = [np.float64, np.float32, np.float64]
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
Expand Down
25 changes: 19 additions & 6 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

import pywt

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
dtypes_in = [np.int8, np.float32, np.float64, np.complex64, np.complex128]
dtypes_out = [np.float64, np.float32, np.float64, np.complex64, np.complex128]


def test_wavedec():
x = [3, 7, 1, 1, -2, 5, 4, 6]
Expand All @@ -26,6 +31,13 @@ def test_waverec():
assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12)


def test_waverec_complex():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 1j
coeffs = pywt.wavedec(x, 'db1')
assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12)


def test_swt_decomposition():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
Expand All @@ -51,9 +63,6 @@ def test_swt_decomposition():


def test_swt_dtypes():
# Check that float32 is preserved. Other types get converted to float64.
dtypes_in = [np.int8, np.float32, np.float64]
dtypes_out = [np.float64, np.float32, np.float64]
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
errmsg = "wrong dtype returned for {0} input".format(dt_in)
Expand All @@ -78,9 +87,6 @@ def test_wavedec2():


def test_multilevel_dtypes():
# Check that float32 is preserved. Other types get converted to float64.
dtypes_in = [np.int8, np.float32, np.float64]
dtypes_out = [np.float64, np.float32, np.float64]
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
# wavedec, waverec
Expand All @@ -105,5 +111,12 @@ def test_multilevel_dtypes():
assert_(x_roundtrip.dtype == dt_out, "waverec2: " + errmsg)


def test_wavedec2_complex():
data = np.ones((4, 4)) + 1j
coeffs = pywt.wavedec2(data, 'db1')
assert_(len(coeffs) == 3)
assert_allclose(pywt.waverec2(coeffs, 'db1'), data, rtol=1e-12)


if __name__ == '__main__':
run_module_suite()

0 comments on commit f188e57

Please sign in to comment.