Skip to content

Commit

Permalink
Merge pull request #129 from kwohlfahrt/unify_2D
Browse files Browse the repository at this point in the history
Unify 2D and nD dwt
  • Loading branch information
grlee77 committed Nov 27, 2015
2 parents 36c15ca + 4a2c175 commit 935c55b
Showing 1 changed file with 6 additions and 106 deletions.
112 changes: 6 additions & 106 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,40 +56,8 @@ def dwt2(data, wavelet, mode='symmetric'):
if data.ndim != 2:
raise ValueError("Expected 2-D data array")

if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)

mode = Modes.from_object(mode)

# filter rows
H, L = [], []
for row in data:
cA, cD = dwt(row, wavelet, mode)
L.append(cA)
H.append(cD)

# filter columns
H = np.transpose(H)
L = np.transpose(L)

LL, HL = [], []
for row in L:
cA, cD = dwt(row, wavelet, mode)
LL.append(cA)
HL.append(cD)

LH, HH = [], []
for row in H:
cA, cD = dwt(row, wavelet, mode)
LH.append(cA)
HH.append(cD)

# build result structure: (approx,
# (horizontal, vertical, diagonal))
ret = (np.transpose(LL),
(np.transpose(HL), np.transpose(LH), np.transpose(HH)))

return ret
coefs = dwtn(data, wavelet, mode)
return coefs['aa'], (coefs['da'], coefs['ad'], coefs['dd'])


def idwt2(coeffs, wavelet, mode='symmetric'):
Expand Down Expand Up @@ -118,81 +86,13 @@ def idwt2(coeffs, wavelet, mode='symmetric'):
[ 3., 4.]])
"""
if len(coeffs) != 2 or len(coeffs[1]) != 3:
raise ValueError("Invalid coeffs param")

# L -low-pass data, H - high-pass data
LL, (HL, LH, HH) = coeffs
if not all(c.ndim == 2 for c in (LL, HL, LH, HL) if c is not None):
raise TypeError("All input coefficients arrays must be 2D.")

if LL is not None:
LL = np.transpose(LL)
if LH is not None:
LH = np.transpose(LH)
if HL is not None:
HL = np.transpose(HL)
if HH is not None:
HH = np.transpose(HH)

all_none = True
for arr in (LL, LH, HL, HH):
if arr is not None:
all_none = False
if arr.ndim != 2:
raise TypeError("All input coefficients arrays must be 2D.")

if all_none:
raise ValueError(
"At least one input coefficients array must not be None.")

if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)

mode = Modes.from_object(mode)

# idwt columns
L = []
if LL is None and HL is None:
L = None
else:
if LL is None:
# IDWT can handle None input values - equals to zero-array
LL = cycle([None])
if HL is None:
# IDWT can handle None input values - equals to zero-array
HL = cycle([None])
for rowL, rowH in zip(LL, HL):
L.append(idwt(rowL, rowH, wavelet, mode))

H = []
if LH is None and HH is None:
H = None
else:
if LH is None:
# IDWT can handle None input values - equals to zero-array
LH = cycle([None])
if HH is None:
# IDWT can handle None input values - equals to zero-array
HH = cycle([None])
for rowL, rowH in zip(LH, HH):
H.append(idwt(rowL, rowH, wavelet, mode))

if L is not None:
L = np.transpose(L)
if H is not None:
H = np.transpose(H)

# idwt rows
data = []
if L is None:
# IDWT can handle None input values - equals to zero-array
L = cycle([None])
if H is None:
# IDWT can handle None input values - equals to zero-array
H = cycle([None])
for rowL, rowH in zip(L, H):
data.append(idwt(rowL, rowH, wavelet, mode))

return np.array(data)
coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
return idwtn(coeffs, wavelet, mode)


def dwtn(data, wavelet, mode='symmetric'):
Expand Down

0 comments on commit 935c55b

Please sign in to comment.