Skip to content

Commit

Permalink
Merge pull request #546 from grlee77/pickling
Browse files Browse the repository at this point in the history
ENH: make Wavelet, WaveletPacket, WaveletPacket2D and ContinuousWavelet pickleable
  • Loading branch information
rgommers committed Nov 6, 2021
2 parents b8cd694 + d42bc24 commit 300ddd9
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 5 deletions.
6 changes: 6 additions & 0 deletions pywt/_extensions/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ cdef public class Wavelet [type WaveletType, object WaveletObject]:
wavelet.free_discrete_wavelet(self.w)
self.w = NULL

def __reduce__(self):
return (Wavelet, (self.name, self.filter_bank))

def __len__(self):
return self.w.dec_len

Expand Down Expand Up @@ -763,6 +766,9 @@ cdef public class ContinuousWavelet [type ContinuousWaveletType, object Continuo
wavelet.free_continuous_wavelet(self.w)
self.w = NULL

def __reduce__(self):
return (ContinuousWavelet, (self.name, self.dt))

property family_number:
"Wavelet family number"
def __get__(self):
Expand Down
8 changes: 8 additions & 0 deletions pywt/_wavelet_packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ def __init__(self, data, wavelet, mode='symmetric', maxlevel=None):

self._maxlevel = maxlevel

def __reduce__(self):
return (WaveletPacket,
(self.data, self.wavelet, self.mode, self.maxlevel))

def reconstruct(self, update=True):
"""
Reconstruct data value using coefficients from subnodes.
Expand Down Expand Up @@ -667,6 +671,10 @@ def __init__(self, data, wavelet, mode='smooth', maxlevel=None):
self.data_size = None
self._maxlevel = maxlevel

def __reduce__(self):
return (WaveletPacket2D,
(self.data, self.wavelet, self.mode, self.maxlevel))

def reconstruct(self, update=True):
"""
Reconstruct data using coefficients from subnodes.
Expand Down
14 changes: 13 additions & 1 deletion pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import os
from itertools import product
import pickle

from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises, assert_equal)
Expand Down Expand Up @@ -447,3 +448,14 @@ def test_cwt_method_fft():
# compare with the fft based convolution
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)


def test_continuous_wavelet_pickle(tmpdir):
wavelet = pywt.ContinuousWavelet('cmor1.5-1.0')
filename = os.path.join(tmpdir, 'cwav.pickle')
with open(filename, 'wb') as f:
pickle.dump(wavelet, f)
with open(filename, 'rb') as f:
wavelet2 = pickle.load(f)
assert isinstance(wavelet2, pywt.ContinuousWavelet)
assert wavelet2.name == wavelet.name
15 changes: 13 additions & 2 deletions pywt/tests/test_wavelet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import

import os
import pickle
import numpy as np
from numpy.testing import assert_allclose, assert_

Expand Down Expand Up @@ -264,3 +264,14 @@ def test_wavefun_bior13():
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)


def test_wavelet_pickle(tmpdir):
wavelet = pywt.Wavelet('sym4')
filename = os.path.join(tmpdir, 'wav.pickle')
with open(filename, 'wb') as f:
pickle.dump(wavelet, f)
with open(filename, 'rb') as f:
wavelet2 = pickle.load(f)
assert isinstance(wavelet2, pywt.Wavelet)
assert wavelet2.name == wavelet.name
13 changes: 12 additions & 1 deletion pywt/tests/test_wp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python

from __future__ import division, print_function, absolute_import
import os
import pickle

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
Expand Down Expand Up @@ -195,3 +196,13 @@ def test_db3_roundtrip():
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)


def test_wavelet_packet_pickle(tmpdir):
packet = pywt.WaveletPacket(np.arange(16), 'sym4')
filename = os.path.join(tmpdir, 'wp.pickle')
with open(filename, 'wb') as f:
pickle.dump(packet, f)
with open(filename, 'rb') as f:
packet2 = pickle.load(f)
assert isinstance(packet2, pywt.WaveletPacket)
13 changes: 12 additions & 1 deletion pywt/tests/test_wp2d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python

from __future__ import division, print_function, absolute_import
import os
import pickle

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
Expand Down Expand Up @@ -175,3 +176,13 @@ def test_2d_roundtrip():
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)


def test_wavelet_packet2d_pickle(tmpdir):
packet = pywt.WaveletPacket2D(np.arange(256).reshape(16, 16), 'sym4')
filename = os.path.join(tmpdir, 'wp2d.pickle')
with open(filename, 'wb') as f:
pickle.dump(packet, f)
with open(filename, 'rb') as f:
packet2 = pickle.load(f)
assert isinstance(packet2, pywt.WaveletPacket2D)

0 comments on commit 300ddd9

Please sign in to comment.