Skip to content

Commit

Permalink
Merge pull request #393 from grlee77/wpacket_nd_and_axis
Browse files Browse the repository at this point in the history
Wavelet packets: extend to nD and support subsets of the axes
  • Loading branch information
rgommers committed Nov 8, 2021
2 parents 70bc050 + 5969b7f commit 9996424
Show file tree
Hide file tree
Showing 7 changed files with 794 additions and 63 deletions.
32 changes: 32 additions & 0 deletions demo/wp_nd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# Note: This demo is a repeat of wp_2d, but using WaveletPacketND instead

import numpy as np
import matplotlib.pyplot as plt

from pywt import WaveletPacketND
import pywt.data


arr = pywt.data.aero()

maxlevel = 2
wp2 = WaveletPacketND(arr, 'db2', 'symmetric', maxlevel=maxlevel)

# Show original figure
plt.imshow(arr, interpolation="nearest", cmap=plt.cm.gray)

fig = plt.figure()
i = 1
nsubplots = len(wp2.get_level(maxlevel, 'natural'))
nrows = int(np.floor(np.sqrt(nsubplots)))
ncols = int(np.ceil(nsubplots/nrows))
for node in wp2.get_level(maxlevel, 'natural'):
ax = fig.add_subplot(nrows, ncols, i)
ax.set_title("%s" % (node.path_tuple, ))
ax.imshow(np.sqrt(np.abs(node.data)), origin='upper',
interpolation="nearest", cmap=plt.cm.gray)
ax.set_axis_off()
i += 1

plt.show()
183 changes: 150 additions & 33 deletions doc/source/ref/wavelet-packets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,27 @@
Wavelet Packets
===============

.. versionadded:: 0.2

Version `0.2` of PyWavelets includes many new features and improvements. One of such
new feature is a two-dimensional wavelet packet transform structure that is
almost completely sharing programming interface with the one-dimensional tree
PyWavelets implements one-dimensional, two-dimensional and n-dimensional
wavelet packet transform structures. The higher dimensional structures almost
completely sharing programming interface with the one-dimensional tree
structure.

In order to achieve this simplification, a new inheritance scheme was used
in which a :class:`~pywt.BaseNode` base node class is a superclass for both
:class:`~pywt.Node` and :class:`~pywt.Node2D` node classes.

The node classes are used as data wrappers and can be organized in trees (binary
trees for 1D transform case and quad-trees for the 2D one). They are also
superclasses to the :class:`~pywt.WaveletPacket` class and
:class:`~pywt.WaveletPacket2D` class that are used as the decomposition tree
roots and contain a couple additional methods.
in which a :class:`~pywt.BaseNode` base node class is a superclass for the
:class:`~pywt.Node`, :class:`~pywt.Node2D` and :class:`~pywt.NodeND`
classes.

The node classes are used as data wrappers and can be organized in trees (
binary trees for 1D transform case, quad-trees for the 2D one and 2**N-ary
trees in ND). They are also superclasses to the :class:`~pywt.WaveletPacket`,
:class:`~pywt.WaveletPacket2D` and :class:`~pywt.WaveletPacketND` classes that
are used as the decomposition tree roots and contain a couple additional
methods.

Here 1D, 2D and ND refer to the number of axes of the data to be transformed.
All wavelet packet objects can operate on general n-dimensional arrays, but the
1D or 2D classes apply transforms along only 1 or 2 dimensions. The ND classes
allow transforms over an arbtirary number of axes of n-dimensional data.

The below diagram illustrates the inheritance tree:

Expand All @@ -36,33 +41,36 @@ The below diagram illustrates the inheritance tree:

- :class:`~pywt.WaveletPacket2D` - 2D decomposition tree root node

- :class:`~pywt.NodeND` - data carrier node in a ND decomposition tree

- :class:`~pywt.WaveletPacketND` - ND decomposition tree root node

BaseNode - a common interface of WaveletPacket and WaveletPacket2D
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

BaseNode - a common interface of WaveletPacket, WaveletPacket2D and WaveletPacketND
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. class:: BaseNode

.. note:: The BaseNode is a base class for :class:`Node` and :class:`Node2D`.
It should not be used directly unless creating a new transformation
type. It is included here to document the common interface of 1D
and 2D node an wavelet packet transform classes.
.. note:: The BaseNode is a base class for :class:`Node`, :class:`Node2D`,
and :class:`NodeND`. It should not be used directly unless creating
a new transformation type. It is included here to document the
common interface of the node and wavelet packet transform classes.

.. method:: __init__(parent, data, node_name)

:param parent: parent node. If parent is ``None`` then the node is
considered detached.

:param data: data associated with the node. 1D or 2D numeric array,
depending on the transform type.
:param data: The data associated with the node. An n-dimensional
numeric array.

:param node_name: a name identifying the coefficients type.
See :attr:`Node.node_name` and :attr:`Node2D.node_name`
for information on the accepted subnodes names.

.. attribute:: data

Data associated with the node. 1D or 2D numeric array (depends on the
transform type).
Data associated with the node. An n-dimensional numeric array.

.. attribute:: parent

Expand All @@ -73,6 +81,11 @@ BaseNode - a common interface of WaveletPacket and WaveletPacket2D
:class:`~pywt.Wavelet` used for decomposition and reconstruction. Inherited
from parent node.

.. attribute:: axes

A tuple of ints containing the axes along which the wavelet packet
transform is to be applied.

.. attribute:: mode

Signal extension :ref:`mode <ref-modes>` for the :func:`dwt` (:func:`dwt2`)
Expand All @@ -88,6 +101,13 @@ BaseNode - a common interface of WaveletPacket and WaveletPacket2D

Path string defining position of the node in the decomposition tree.

.. attribute:: path_tuple

A version of :attr:`path`, but in tuple form rather than as a single
string. The tuple form is easier to work with for n-dimensional transforms.
The length of the tuple will be equal to the number of levels of
decomposition at the current node.

.. attribute:: node_name

Node name describing :attr:`~BaseNode.data` coefficients type of the
Expand Down Expand Up @@ -214,8 +234,8 @@ BaseNode - a common interface of WaveletPacket and WaveletPacket2D
:attr:`maximum level <BaseNode.maxlevel>`.


WaveletPacket and WaveletPacket tree Node
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
WaveletPacket and Node
~~~~~~~~~~~~~~~~~~~~~~

.. class:: Node(BaseNode)

Expand All @@ -230,16 +250,22 @@ WaveletPacket and WaveletPacket tree Node

.. method:: decompose()

.. seealso::
.. seealso::

:func:`dwt` for 1D Discrete Wavelet Transform output coefficients.

.. method:: reconstruct()

- :func:`dwt` for 1D Discrete Wavelet Transform output coefficients.
.. seealso::

:func:`idwt` for 1D Inverse Discrete Wavelet Transform


.. class:: WaveletPacket(Node)

.. method:: __init__(data, wavelet, [mode='symmetric', [maxlevel=None]])
.. method:: __init__(data, wavelet, [mode='symmetric', [maxlevel=None, [axis=-1]]])

:param data: data associated with the node. 1D numeric array.
:param data: data associated with the node. N-dimensional numeric array.

:param wavelet: |wavelet|

Expand All @@ -250,6 +276,8 @@ WaveletPacket and WaveletPacket tree Node
it will be calculated based on the ``wavelet`` and
``data`` length using :func:`pywt.dwt_max_level`.

:param axis: The axis of the array that is to be transformed.

.. method:: get_level(level, [order="natural", [decompose=True]])

Collects nodes from the given level of decomposition.
Expand All @@ -267,8 +295,17 @@ WaveletPacket and WaveletPacket tree Node
decomposed) and the ``decompose`` is set to ``False``, only existing nodes
will be returned.

WaveletPacket2D and WaveletPacket2D tree Node2D
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. method:: reconstruct([update=True])

Reconstruct data from the subnodes.

:param update: A boolean indicating whether the coefficients of the
current node and its subnodes will be replaced with values
from the reconstruction.


WaveletPacket2D and Node2D
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. class:: Node2D(BaseNode)

Expand All @@ -286,14 +323,20 @@ WaveletPacket2D and WaveletPacket2D tree Node2D

:func:`dwt2` for 2D Discrete Wavelet Transform output coefficients.

.. method:: reconstruct()

.. seealso::

:func:`idwt2` for 2D Inverse Discrete Wavelet Transform

.. method:: expand_2d_path(self, path):


.. class:: WaveletPacket2D(Node2D)

.. method:: __init__(data, wavelet, [mode='symmetric', [maxlevel=None]])
.. method:: __init__(data, wavelet, [mode='symmetric', [maxlevel=None, [axes=(-2, -1)]]])

:param data: data associated with the node. 2D numeric array.
:param data: data associated with the node. N-dimensional numeric array.

:param wavelet: |wavelet|

Expand All @@ -304,6 +347,8 @@ WaveletPacket2D and WaveletPacket2D tree Node2D
it will be calculated based on the ``wavelet`` and
``data`` length using :func:`pywt.dwt_max_level`.

:param axes: The axes of the array that are to be transformed.

.. method:: get_level(level, [order="natural", [decompose=True]])

Collects nodes from the given level of decomposition.
Expand All @@ -320,3 +365,75 @@ WaveletPacket2D and WaveletPacket2D tree Node2D
If nodes at the given level are missing (i.e. the tree is partially
decomposed) and the ``decompose`` is set to ``False``, only existing nodes
will be returned.

.. method:: reconstruct([update=True])

Reconstruct data from the subnodes.

:param update: A boolean indicating whether the coefficients of the
current node and its subnodes will be replaced with values
from the reconstruction.

WaveletPacketND and NodeND
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. class:: NodeND(BaseNode)

.. attribute:: node_name

For :class:`WaveletPacketND` case it is just as in :func:`dwtn`:
- in 1D it has keys 'a' and 'd'
- in 2D it has keys 'aa', 'ad', 'da', 'dd'
- in 3D it has keys 'aaa', 'aad', 'ada', 'daa', ..., 'ddd'

.. method:: decompose()

.. seealso::

:func:`dwtn` for ND Discrete Wavelet Transform output coefficients.

.. method:: reconstruct()

.. seealso::

:func:`idwtn` for ND Inverse Discrete Wavelet Transform


.. class:: WaveletPacketND(NodeND)

.. method:: __init__(data, wavelet, [mode='symmetric', [maxlevel=None, [axes=None]]])

:param data: data associated with the node. N-dimensional numeric array.

:param wavelet: |wavelet|

:param mode: Signal extension :ref:`mode <ref-modes>` for the :func:`dwt`
and :func:`idwt` decomposition and reconstruction functions.

:param maxlevel: Maximum allowed level of decomposition. If not specified
it will be calculated based on the ``wavelet`` and
``data`` length using :func:`pywt.dwt_max_level`.

:param axes: The axes of the array that are to be transformed.

.. method:: get_level(level, [decompose=True])

Collects nodes from the given level of decomposition.

:param level: Specifies decomposition ``level`` from which the nodes will
be collected.

:param decompose: If set then the method will try to decompose the data up
to the specified ``level``.

If nodes at the given level are missing (i.e. the tree is partially
decomposed) and the ``decompose`` is set to ``False``, only existing nodes
will be returned.

.. method:: reconstruct([update=True])

Reconstruct data from the subnodes.

:param update: A boolean indicating whether the coefficients of the
current node and its subnodes will be replaced with values
from the reconstruction.
3 changes: 3 additions & 0 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

# Raise error for invalid key combinations
coeffs = _fix_coeffs(coeffs)

Expand Down

0 comments on commit 9996424

Please sign in to comment.