Skip to content

Commit

Permalink
Multiresolution Analysis (#527)
Browse files Browse the repository at this point in the history
* add functions for mra additive decomposition

* add MRA vs. SWT timecourse alignment demo

add imra functions to the namespace

* DOC: add MRA functions to the API documentation

* Add 2D MRA example

* DOC: add notes on the redundancy of the MRA transforms

* Use NumPy's AxisError for axis-related errors

* TST: add test suite for Multiresolution Analysis functions

* BUG: fix swt, swt2 and swtn when not built with C99 complex support

norm, axis or axes kwargs were not being appropriately passed on in this case!

BUG: fix iswt of complex-valued inputs when C99 complex unavailable

logic to split coefficients into real and imaginary parts was missing!

* ENH: support axis argument in iswt

ENH: support axes argument in swt2, iswt2

* TST: expand and simplify axis/axes tests

both dwt and swt now support axis-specific transforms of nD data
  • Loading branch information
grlee77 committed Nov 7, 2021
1 parent a092954 commit ce076bd
Show file tree
Hide file tree
Showing 13 changed files with 932 additions and 44 deletions.
40 changes: 40 additions & 0 deletions demo/mra2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python

import matplotlib.pyplot as plt

import pywt
import pywt.data

camera = pywt.data.camera()

wavelet = pywt.Wavelet('sym2')
level = 5
# Note: Running with transform="dwtn" is faster, but the resulting images will
# look substantially worse.
coeffs = pywt.mran(camera, wavelet=wavelet, level=level, transform='swtn')
ca = coeffs[0]
details = coeffs[1:]

# Plot all coefficient subbands and the original
gridspec_kw = dict(hspace=0.1, wspace=0.1)
fontdict = dict(verticalalignment='center', horizontalalignment='center',
color='k')
fig, axes = plt.subplots(len(details) + 1, 3, figsize=[5, 8], sharex=True,
sharey=True, gridspec_kw=gridspec_kw)
imshow_kw = dict(interpolation='nearest', cmap=plt.cm.gray)
for i, x in enumerate(details):
axes[i][0].imshow(details[-i - 1]['ad'], **imshow_kw)

axes[i][1].imshow(details[-i - 1]['da'], **imshow_kw)
axes[i][2].imshow(details[-i - 1]['dd'], **imshow_kw)
axes[i][0].text(256, 50, 'ad%d' % (i + 1), fontdict=fontdict)
axes[i][1].text(256, 50, 'da%d' % (i + 1), fontdict=fontdict)
axes[i][2].text(256, 50, 'dd%d' % (i + 1), fontdict=fontdict)

axes[-1][0].imshow(ca, **imshow_kw)
axes[-1][0].text(256, 50, 'approx.', fontdict=fontdict)
axes[-1][1].imshow(camera, **imshow_kw)
axes[-1][1].text(256, 50, 'original', fontdict=fontdict)

for ax in axes.ravel():
ax.set_axis_off()
72 changes: 72 additions & 0 deletions demo/mra_vs_swt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt

import pywt
import pywt.data

ecg = pywt.data.ecg()

wavelet = pywt.Wavelet('db8') #'db16')
level = 4
coeffs = pywt.mra(ecg, wavelet=wavelet, level=level)
ca = coeffs[0]
details = coeffs[1:]

# Create a plot using the same axis limits for all coefficient arrays to
# illustrate the preservation of alignment across decomposition levels.
ylim = [ecg.min(), ecg.max()]

def mark_peaks(ax):
# add dashed lines at the locations of the ECG peaks for reference
ylim = ax.get_ylim()
ax.plot([190, 190], ylim, 'k--')
ax.plot([518, 518], ylim, 'k--')
ax.plot([848, 848], ylim, 'k--')

fig, axes = plt.subplots(len(coeffs) + 1, 2, figsize=[12, 4])
axes[0][0].set_title("MRA decomposition")
axes[0][0].plot(ecg)
axes[0][0].set_ylabel('ECG Signal')
axes[0][0].set_xlim(0, len(ecg) - 1)
axes[0][0].set_ylim(ylim[0], ylim[1])
mark_peaks(axes[0][0])

for i, x in enumerate(coeffs):
ax = axes[-i - 1][0]
ax.plot(coeffs[i], 'g')
if i == 0:
ax.set_ylabel("A%d" % (len(coeffs) - 1))
else:
ax.set_ylabel("D%d" % (len(coeffs) - i))
# Scale axes
ax.set_xlim(0, len(ecg) - 1)
ax.set_ylim(ylim[0], ylim[1])
mark_peaks(ax)

"""
repeat using the SWT instead of MRA as the decomposition
"""
coeffs = pywt.swt(ecg, wavelet=wavelet, level=level, norm=True, trim_approx=True)
ca = coeffs[0]
details = coeffs[1:]

axes[0][1].set_title("normalized SWT decomposition")
axes[0][1].plot(ecg)
axes[0][1].set_ylabel('ECG Signal')
axes[0][1].set_xlim(0, len(ecg) - 1)
axes[0][1].set_ylim(ylim[0], ylim[1])
mark_peaks(axes[0][1])

for i, x in enumerate(coeffs):
ax = axes[-i - 1][1]
ax.plot(coeffs[i], 'g')
if i == 0:
ax.set_ylabel("A%d" % (len(coeffs) - 1))
else:
ax.set_ylabel("D%d" % (len(coeffs) - i))
# Scale axes
ax.set_xlim(0, len(ecg) - 1)
ax.set_ylim(ylim[0], ylim[1])
mark_peaks(ax)
1 change: 1 addition & 0 deletions doc/source/ref/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ API Reference
dwt-coefficient-handling
swt-stationary-wavelet-transform
iswt-inverse-stationary-wavelet-transform
mra
wavelet-packets
cwt
thresholding-functions
Expand Down
43 changes: 43 additions & 0 deletions doc/source/ref/mra.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
.. _ref-swt:

.. currentmodule:: pywt

Multiresolution Analysis
------------------------

The functions in this module can be used to project a signal onto wavelet
subspaces and an approximation subspace. This is an additive decomposition such
that the sum of the coefficients equals the original signal. The projected
signal coefficients remains temporally aligned with the original, regardless
of the symmetry of the wavelet used for the analysis.

Multilevel 1D ``mra``
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mra

Multilevel 2D ``mra2``
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mra2

Multilevel n-dimensional ``mran``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mran

Inverse Multilevel 1D ``imra``
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: imra

Inverse Multilevel 2D ``imra2``
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: imra2

Inverse Multilevel n-dimensional ``imran``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: imran

1 change: 1 addition & 0 deletions pywt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._dwt import *
from ._swt import *
from ._cwt import *
from ._mra import *

from . import data

Expand Down
2 changes: 1 addition & 1 deletion pywt/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
if np.isscalar(scales):
scales = np.array([scales])
if not np.isscalar(axis):
raise ValueError("axis must be a scalar.")
raise np.AxisError("axis must be a scalar.")

dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
Expand Down
4 changes: 2 additions & 2 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise ValueError("Axis greater than data dimensions")
raise np.AxisError("Axis greater than data dimensions")

if data.ndim == 1:
cA, cD = dwt_single(data, wavelet, mode)
Expand Down Expand Up @@ -282,7 +282,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + ndim
if not 0 <= axis < ndim:
raise ValueError("Axis greater than coefficient dimensions")
raise np.AxisError("Axis greater than coefficient dimensions")

if ndim == 1:
rec = idwt_single(cA, cD, wavelet, mode)
Expand Down

0 comments on commit ce076bd

Please sign in to comment.