Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Wavelet support #315

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions modopt/opt/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""LINEAR OPERATORS.

This module contains linear operator classes.

:Author: Samuel Farrens <samuel.farrens@cea.fr>
:Author: Pierre-Antoine Comby <pierre-antoine.comby@cea.fr>
"""

from .base import LinearParent, Identity, MatrixOperator, LinearCombo

from .wavelet import WaveletConvolve, WaveletTransform


__all__ = [
"LinearParent",
"Identity",
"MatrixOperator",
"LinearCombo",
"WaveletConvolve",
"WaveletTransform",
]
50 changes: 2 additions & 48 deletions modopt/opt/linear.py → modopt/opt/linear/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
# -*- coding: utf-8 -*-

"""LINEAR OPERATORS.

This module contains linear operator classes.

:Author: Samuel Farrens <samuel.farrens@cea.fr>

"""
"""Base classes for linear operators."""

import numpy as np

from modopt.base.types import check_callable, check_float
from modopt.base.types import check_callable
from modopt.base.backend import get_array_module
from modopt.signal.wavelet import filter_convolve_stack


class LinearParent(object):
"""Linear Operator Parent Class.
Expand Down Expand Up @@ -99,42 +89,6 @@ def __init__(self, array):
self.adj_op = lambda x: array.T @ x


class WaveletConvolve(LinearParent):
"""Wavelet Convolution Class.

This class defines the wavelet transform operators via convolution with
predefined filters.

Parameters
----------
filters: numpy.ndarray
Array of wavelet filter coefficients
method : str, optional
Convolution method (default is ``'scipy'``)

See Also
--------
LinearParent : parent class
modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution

"""

def __init__(self, filters, method='scipy'):

self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
self._filters,
method=method,
)
self.adj_op = lambda input_data: filter_convolve_stack(
input_data,
self._filters,
filter_rot=True,
method=method,
)


class LinearCombo(LinearParent):
"""Linear Combination Class.

Expand Down
216 changes: 216 additions & 0 deletions modopt/opt/linear/wavelet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#!/usr/bin/env python3
"""Wavelet operator, using either scipy filter or pywavelet."""
import warnings

import numpy as np

from modopt.base.types import check_float
from modopt.signal.wavelet import filter_convolve_stack

from .base import LinearParent

pywt_available = True
try:
import pywt
from joblib import Parallel, cpu_count, delayed
except ImportError:
pywt_available = False


class WaveletConvolve(LinearParent):
"""Wavelet Convolution Class.

This class defines the wavelet transform operators via convolution with
predefined filters.

Parameters
----------
filters: numpy.ndarray
Array of wavelet filter coefficients
method : str, optional
Convolution method (default is ``'scipy'``)

See Also
--------
LinearParent : parent class
modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution

"""

def __init__(self, filters, method='scipy'):

self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
self._filters,
method=method,
)
self.adj_op = lambda input_data: filter_convolve_stack(
input_data,
self._filters,
filter_rot=True,
method=method,
)



class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.

This is a light wrapper around PyWavelet, with multicoil support.

Parameters
----------
wavelet_name: str
the wavelet name to be used during the decomposition.
shape: tuple[int,...]
Shape of the input data. The shape should be a tuple of length 2 or 3.
It should not contains coils or batch dimension.
nb_scales: int, default 4
the number of scales in the decomposition.
n_batchs: int, default 1
the number of channel/ batch dimension
n_jobs: int, default 1
the number of cores to use for multichannel.
backend: str, default "threading"
the backend to use for parallel multichannel linear operation.
verbose: int, default 0
the verbosity level.

Attributes
----------
nb_scale: int
number of scale decomposed in wavelet space.
n_jobs: int
number of jobs for parallel computation
n_batchs: int
number of coils use f
backend: str
Backend use for parallel computation
verbose: int
Verbosity level
"""

def __init__(
self,
wavelet_name,
shape,
level=4,
n_batch=1,
n_jobs=1,
decimated=True,
backend="threading",
mode="symmetric",
):
if not pywt_available:
raise ImportError(
"PyWavelet and/or joblib are not available. Please install it to use WaveletTransform."
)
if wavelet_name not in pywt.wavelist(kind="all"):
raise ValueError(
"Invalid wavelet name. Availables are ``pywt.waveletlist(kind='all')``"
)

self.wavelet = wavelet_name
if isinstance(shape, int):
shape = (shape,)
self.shape = shape
self.n_jobs = n_jobs
self.mode = mode
self.level = level
if not decimated:
raise NotImplementedError(
"Undecimated Wavelet Transform is not implemented yet."
)
ca, *cds = pywt.wavedecn_shapes(
self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level
)
self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()]

if len(shape) > 1:
self.dwt = pywt.wavedecn
self.idwt = pywt.waverecn
self._pywt_fun = "wavedecn"
else:
self.dwt = pywt.wavedec
self.idwt = pywt.waverec
self._pywt_fun = "wavedec"

self.n_batch = n_batch
if self.n_batch == 1 and self.n_jobs != 1:
warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1")
self.n_jobs = 1
self.backend = backend
n_proc = self.n_jobs
if n_proc < 0:
n_proc = cpu_count() + self.n_jobs + 1

def op(self, data):
"""Define the wavelet operator.

This method returns the input data convolved with the wavelet filter.

Parameters
----------
data: ndarray or Image
input 2D data array.

Returns
-------
coeffs: ndarray
the wavelet coefficients.
"""
if self.n_batch > 1:
coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip(
*Parallel(
n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
)(delayed(self._op)(data[i]) for i in np.arange(self.n_batch))
)
coeffs = np.asarray(coeffs)
else:
coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data)
return coeffs

def _op(self, data):
"""Single coil wavelet transform."""
return pywt.ravel_coeffs(
self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet)
)

def adj_op(self, coeffs):
"""Define the wavelet adjoint operator.

This method returns the reconstructed image.

Parameters
----------
coeffs: ndarray
the wavelet coefficients.

Returns
-------
data: ndarray
the reconstructed data.
"""
if self.n_batch > 1:
images = Parallel(
n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose
)(
delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i])
for i in np.arange(self.n_batch)
)
images = np.asarray(images)
else:
images = self._adj_op(coeffs)
return images

def _adj_op(self, coeffs):
"""Single coil inverse wavelet transform."""
return self.idwt(
pywt.unravel_coeffs(
coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun
),
wavelet=self.wavelet,
mode=self.mode,
)
21 changes: 20 additions & 1 deletion modopt/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
except ImportError:
SKLEARN_AVAILABLE = False

PYWT_AVAILABLE = True
try:
import pywt
import joblib
except ImportError:
PYWT_AVAILABLE = False

# Basic functions to be used as operators or as dummy functions
func_identity = lambda x_val: x_val
Expand Down Expand Up @@ -156,7 +162,7 @@ def case_linear_identity(self):

return linop, data_op, data_adj_op, res_op, res_adj_op

def case_linear_wavelet(self):
def case_linear_wavelet_convolve(self):
"""Case linear operator wavelet."""
linop = linear.WaveletConvolve(
filters=np.arange(8).reshape(2, 2, 2).astype(float)
Expand All @@ -168,6 +174,19 @@ def case_linear_wavelet(self):

return linop, data_op, data_adj_op, res_op, res_adj_op

@pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")
def case_linear_wavelet_transform(self):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
level=2,
)
data_op = np.arange(64).reshape(8, 8).astype(float)
res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
data_adj_op = linop.op(data_op)
res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
return linop, data_op, data_adj_op, res_op, res_adj_op

@parametrize(weights=[[1.0, 1.0], None])
def case_linear_combo(self, weights):
"""Case linear operator combo with weights."""
Expand Down
Loading