From 869af9a3ad4e4c0df77f184896833f2bd7eb76e1 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 15 Nov 2023 11:15:35 +0100 Subject: [PATCH 1/2] feat: create a linear operator module, add wavelet transform. --- modopt/opt/linear/__init__.py | 21 +++ modopt/opt/{linear.py => linear/base.py} | 50 +----- modopt/opt/linear/wavelet.py | 216 +++++++++++++++++++++++ 3 files changed, 239 insertions(+), 48 deletions(-) create mode 100644 modopt/opt/linear/__init__.py rename modopt/opt/{linear.py => linear/base.py} (85%) create mode 100644 modopt/opt/linear/wavelet.py diff --git a/modopt/opt/linear/__init__.py b/modopt/opt/linear/__init__.py new file mode 100644 index 00000000..d5c0d21f --- /dev/null +++ b/modopt/opt/linear/__init__.py @@ -0,0 +1,21 @@ +"""LINEAR OPERATORS. + +This module contains linear operator classes. + +:Author: Samuel Farrens +:Author: Pierre-Antoine Comby +""" + +from .base import LinearParent, Identity, MatrixOperator, LinearCombo + +from .wavelet import WaveletConvolve, WaveletTransform + + +__all__ = [ + "LinearParent", + "Identity", + "MatrixOperator", + "LinearCombo", + "WaveletConvolve", + "WaveletTransform", +] diff --git a/modopt/opt/linear.py b/modopt/opt/linear/base.py similarity index 85% rename from modopt/opt/linear.py rename to modopt/opt/linear/base.py index 1fd146fb..e347970d 100644 --- a/modopt/opt/linear.py +++ b/modopt/opt/linear/base.py @@ -1,19 +1,9 @@ -# -*- coding: utf-8 -*- - -"""LINEAR OPERATORS. - -This module contains linear operator classes. - -:Author: Samuel Farrens - -""" +"""Base classes for linear operators.""" import numpy as np -from modopt.base.types import check_callable, check_float +from modopt.base.types import check_callable from modopt.base.backend import get_array_module -from modopt.signal.wavelet import filter_convolve_stack - class LinearParent(object): """Linear Operator Parent Class. @@ -99,42 +89,6 @@ def __init__(self, array): self.adj_op = lambda x: array.T @ x -class WaveletConvolve(LinearParent): - """Wavelet Convolution Class. - - This class defines the wavelet transform operators via convolution with - predefined filters. - - Parameters - ---------- - filters: numpy.ndarray - Array of wavelet filter coefficients - method : str, optional - Convolution method (default is ``'scipy'``) - - See Also - -------- - LinearParent : parent class - modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution - - """ - - def __init__(self, filters, method='scipy'): - - self._filters = check_float(filters) - self.op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - method=method, - ) - self.adj_op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - filter_rot=True, - method=method, - ) - - class LinearCombo(LinearParent): """Linear Combination Class. diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py new file mode 100644 index 00000000..6e22a2b0 --- /dev/null +++ b/modopt/opt/linear/wavelet.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +"""Wavelet operator, using either scipy filter or pywavelet.""" +import warnings + +import numpy as np + +from modopt.base.types import check_float +from modopt.signal.wavelet import filter_convolve_stack + +from .base import LinearParent + +pywt_available = True +try: + import pywt + from joblib import Parallel, cpu_count, delayed +except ImportError: + pywt_available = False + + +class WaveletConvolve(LinearParent): + """Wavelet Convolution Class. + + This class defines the wavelet transform operators via convolution with + predefined filters. + + Parameters + ---------- + filters: numpy.ndarray + Array of wavelet filter coefficients + method : str, optional + Convolution method (default is ``'scipy'``) + + See Also + -------- + LinearParent : parent class + modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution + + """ + + def __init__(self, filters, method='scipy'): + + self._filters = check_float(filters) + self.op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + method=method, + ) + self.adj_op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + filter_rot=True, + method=method, + ) + + + +class WaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + + This is a light wrapper around PyWavelet, with multicoil support. + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + n_batchs: int, default 1 + the number of channel/ batch dimension + n_jobs: int, default 1 + the number of cores to use for multichannel. + backend: str, default "threading" + the backend to use for parallel multichannel linear operation. + verbose: int, default 0 + the verbosity level. + + Attributes + ---------- + nb_scale: int + number of scale decomposed in wavelet space. + n_jobs: int + number of jobs for parallel computation + n_batchs: int + number of coils use f + backend: str + Backend use for parallel computation + verbose: int + Verbosity level + """ + + def __init__( + self, + wavelet_name, + shape, + level=4, + n_batch=1, + n_jobs=1, + decimated=True, + backend="threading", + mode="symmetric", + ): + if not pywt_available: + raise ImportError( + "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform." + ) + if wavelet_name not in pywt.wavelist(kind="all"): + raise ValueError( + "Invalid wavelet name. Availables are ``pywt.waveletlist(kind='all')``" + ) + + self.wavelet = wavelet_name + if isinstance(shape, int): + shape = (shape,) + self.shape = shape + self.n_jobs = n_jobs + self.mode = mode + self.level = level + if not decimated: + raise NotImplementedError( + "Undecimated Wavelet Transform is not implemented yet." + ) + ca, *cds = pywt.wavedecn_shapes( + self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level + ) + self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()] + + if len(shape) > 1: + self.dwt = pywt.wavedecn + self.idwt = pywt.waverecn + self._pywt_fun = "wavedecn" + else: + self.dwt = pywt.wavedec + self.idwt = pywt.waverec + self._pywt_fun = "wavedec" + + self.n_batch = n_batch + if self.n_batch == 1 and self.n_jobs != 1: + warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1") + self.n_jobs = 1 + self.backend = backend + n_proc = self.n_jobs + if n_proc < 0: + n_proc = cpu_count() + self.n_jobs + 1 + + def op(self, data): + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: ndarray or Image + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ + if self.n_batch > 1: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip( + *Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )(delayed(self._op)(data[i]) for i in np.arange(self.n_batch)) + ) + coeffs = np.asarray(coeffs) + else: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data) + return coeffs + + def _op(self, data): + """Single coil wavelet transform.""" + return pywt.ravel_coeffs( + self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet) + ) + + def adj_op(self, coeffs): + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ + if self.n_batch > 1: + images = Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )( + delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i]) + for i in np.arange(self.n_batch) + ) + images = np.asarray(images) + else: + images = self._adj_op(coeffs) + return images + + def _adj_op(self, coeffs): + """Single coil inverse wavelet transform.""" + return self.idwt( + pywt.unravel_coeffs( + coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun + ), + wavelet=self.wavelet, + mode=self.mode, + ) From 906ddc47f3e6a2e421772b51a70a88f8bf4c5bd5 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 15 Nov 2023 11:41:34 +0100 Subject: [PATCH 2/2] feat: add test case for wavelet transform. --- modopt/tests/test_opt.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 0e45ffb8..dace6d18 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -22,6 +22,12 @@ except ImportError: SKLEARN_AVAILABLE = False +PYWT_AVAILABLE = True +try: + import pywt + import joblib +except ImportError: + PYWT_AVAILABLE = False # Basic functions to be used as operators or as dummy functions func_identity = lambda x_val: x_val @@ -156,7 +162,7 @@ def case_linear_identity(self): return linop, data_op, data_adj_op, res_op, res_adj_op - def case_linear_wavelet(self): + def case_linear_wavelet_convolve(self): """Case linear operator wavelet.""" linop = linear.WaveletConvolve( filters=np.arange(8).reshape(2, 2, 2).astype(float) @@ -168,6 +174,19 @@ def case_linear_wavelet(self): return linop, data_op, data_adj_op, res_op, res_adj_op + @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.") + def case_linear_wavelet_transform(self): + linop = linear.WaveletTransform( + wavelet_name="haar", + shape=(8, 8), + level=2, + ) + data_op = np.arange(64).reshape(8, 8).astype(float) + res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2)) + data_adj_op = linop.op(data_op) + res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar") + return linop, data_op, data_adj_op, res_op, res_adj_op + @parametrize(weights=[[1.0, 1.0], None]) def case_linear_combo(self, weights): """Case linear operator combo with weights."""