Skip to content

Commit

Permalink
Merge pull request #112 from kwohlfahrt/complex-fix
Browse files Browse the repository at this point in the history
Complex fix
  • Loading branch information
aaren committed Aug 14, 2015
2 parents f188e57 + 79566a8 commit c414763
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
24 changes: 18 additions & 6 deletions pywt/multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,19 @@ def dwtn(data, wavelet, mode='sym'):
"""
data = np.asarray(data)
ndim = data.ndim
if np.iscomplexobj(data):
keys = (''.join(k) for k in product('ad', repeat=data.ndim))
real = dwtn(data.real, wavelet, mode)
imag = dwtn(data.imag, wavelet, mode)
return dict((k, real[k] + 1j * imag[k]) for k in keys)

if data.dtype == np.dtype('object'):
raise TypeError("Input must be a numeric array-like")
if ndim < 1:
if data.ndim < 1:
raise ValueError("Input data must be at least 1D")
coeffs = [('', data)]

for axis in range(ndim):
for axis in range(data.ndim):
new_coeffs = []
for subband, x in coeffs:
cA, cD = dwt_axis(x, wavelet, mode, axis)
Expand Down Expand Up @@ -269,7 +273,15 @@ def idwtn(coeffs, wavelet, mode='sym'):
mode = MODES.from_object(mode)

# Ignore any invalid keys
coeffs = dict((k, v) for k, v in coeffs.items() if set(k) <= set('ad'))
coeffs = dict((k, np.asarray(v)) for k, v in coeffs.items()
if v is not None and set(k) <= set('ad'))

if any(np.iscomplexobj(v) for v in coeffs.values()):
real_coeffs = dict((k, v.real) for k, v in coeffs.items())
imag_coeffs = dict((k, v.imag) for k, v in coeffs.items())
return (idwtn(real_coeffs, wavelet, mode)
+ 1j * idwtn(imag_coeffs, wavelet, mode))

dims = max(len(key) for key in coeffs.keys())

try:
Expand All @@ -287,8 +299,8 @@ def idwtn(coeffs, wavelet, mode='sym'):
new_keys = [''.join(coeff) for coeff in product('ad', repeat=axis)]

for key in new_keys:
L = coeffs.get(key + 'a')
H = coeffs.get(key + 'd')
L = coeffs.get(key + 'a', None)
H = coeffs.get(key + 'd', None)

new_coeffs[key] = idwt_axis(L, H, wavelet, mode, axis)
coeffs = new_coeffs
Expand Down
1 change: 0 additions & 1 deletion pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,6 @@ cpdef dwt_axis(np.ndarray data, object wavelet, object mode='sym', unsigned int
return (cA, cD)


# TODO: Use idwt rather than upcoef, which requires `mode` but not `take`
cpdef idwt_axis(np.ndarray coefs_a, np.ndarray coefs_d, object wavelet,
object mode='sym', unsigned int axis=0):
cdef Wavelet w = c_wavelet_from_object(wavelet)
Expand Down

0 comments on commit c414763

Please sign in to comment.