Skip to content

Commit

Permalink
Fixed an issue where the first stage of the DTCWT was not shifted pro…
Browse files Browse the repository at this point in the history
…perly (#36)
  • Loading branch information
LaurentRDC committed Jul 21, 2021
1 parent 2222142 commit 02dd54b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
=========

Release 2.1.7
-------------

* Fixed an issue where the first stage of the dual-tree complex wavelet transform was not shifted properly (#36).

Release 2.1.6
-------------

Expand Down
2 changes: 1 addition & 1 deletion skued/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
__author__ = "Laurent P. René de Cotret"
__email__ = "laurent.renedecotret@mail.mcgill.ca"
__license__ = "GPLv3"
__version__ = "2.1.6"
__version__ = "2.1.7"

from .affine import (
affine_map,
Expand Down
1 change: 1 addition & 0 deletions skued/baseline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
available_dt_filters,
available_first_stage_filters,
dt_max_level,
dt_first_stage,
dtcwt,
idtcwt,
)
14 changes: 7 additions & 7 deletions skued/baseline/dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,18 @@ def dt_first_stage(wavelet):

# extend filter bank with zeros
filter_bank = [np.array(f, copy=True) for f in wavelet.filter_bank]
for filt in filter_bank:
for i, filt in enumerate(filter_bank):
extended = np.zeros(shape=(filt.shape[0] + 2,), dtype=float)
extended[1:-1] = filt
filt = extended
filter_bank[i] = extended

# Shift deconstruction filters to one side, and reconstruction
# to the other side
shifted_fb = [np.array(f, copy=True) for f in wavelet.filter_bank]
for filt in shifted_fb[::2]: # Deconstruction filters
filt = np.roll(filt, 1)
for filt in shifted_fb[2::]: # Reconstruction filters
filt = np.roll(filt, -1)
shifted_fb = [np.array(f, copy=True) for f in filter_bank]
for i, filt in enumerate(shifted_fb[:2]): # Deconstruction filters
shifted_fb[i] = np.roll(filt, 1)
for i, filt in enumerate(shifted_fb[2:], start=2): # Reconstruction filters
shifted_fb[i] = np.roll(filt, -1)

return (
Wavelet(name=wavelet.name, filter_bank=filter_bank),
Expand Down
27 changes: 22 additions & 5 deletions skued/baseline/tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from skued.baseline import (
available_dt_filters,
available_first_stage_filters,
dt_first_stage,
dt_max_level,
dtcwt,
idtcwt,
Expand All @@ -17,13 +18,29 @@
np.random.seed(23)


def test_first_stage():
@pytest.mark.parametrize("first_stage_wavelet", available_first_stage_filters())
def test_first_stage(first_stage_wavelet):
"""Test of perfect reconstruction of first stage wavelets."""
array = np.sin(np.arange(0, 10, step=0.01))
for wavelet in available_first_stage_filters():
# Using waverec and wavedec instead of dwt and idwt because parameters
# don't need as much parsing.
assert np.allclose(array, pywt.waverec(pywt.wavedec(array, wavelet), wavelet))
# Using waverec and wavedec instead of dwt and idwt because parameters
# don't need as much parsing.
assert np.allclose(
array,
pywt.waverec(pywt.wavedec(array, first_stage_wavelet), first_stage_wavelet),
)


@pytest.mark.parametrize("first_stage_wavelet", available_first_stage_filters())
def test_first_stage_issue_36(first_stage_wavelet):
"""Test that first-stage wavelets are properly shifted. See Issue 36"""
w1, w2 = dt_first_stage(first_stage_wavelet)
# Reconstruction should be shifted back
for f1, f2 in zip(w1.filter_bank[:2], w2.filter_bank[:2]):
assert np.allclose(f1[1:-1], f2[2::])

# Deconstruction should be shifted forward
for f1, f2 in zip(w1.filter_bank[2::], w2.filter_bank[2::]):
assert np.allclose(f1[1:-1], f2[0:-2])


def gen_input(n_dimensions):
Expand Down

0 comments on commit 02dd54b

Please sign in to comment.