Skip to content

Commit

Permalink
Merge 0aa102b into df58724
Browse files Browse the repository at this point in the history
  • Loading branch information
bensarthou committed Jul 20, 2018
2 parents df58724 + 0aa102b commit be0316a
Show file tree
Hide file tree
Showing 13 changed files with 704 additions and 102 deletions.
25 changes: 22 additions & 3 deletions pysap/base/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class WaveletTransformBase(with_metaclass(MetaRegister)):
Available transforms are define in 'pysap.transform'.
"""
def __init__(self, nb_scale, verbose=0, **kwargs):
def __init__(self, nb_scale, verbose=0, dim=2, **kwargs):
""" Initialize the WaveletTransformBase class.
Parameters
Expand All @@ -90,6 +90,7 @@ def __init__(self, nb_scale, verbose=0, **kwargs):
self.scales_lengths = None
self.scales_padds = None
self.use_wrapping = pysparse is None
self.data_dim = dim

# Data that can be decalred afterward
self._data = None
Expand All @@ -110,8 +111,18 @@ def __init__(self, nb_scale, verbose=0, **kwargs):
self.__isap_transform_id__)
kwargs["number_of_scales"] = self.nb_scale
self.trf = pysparse.MRTransform(**self.kwargs)
if self.data_dim == 2:
self.trf = pysparse.MRTransform(**self.kwargs)
elif self.data_dim == 3:
self.trf = pysparse.MRTransform3D(**self.kwargs)
else:
raise NameError('Please define a correct dimension for data')
else:
self.trf = None
if self.data_dim == 2:
self.trf = None
elif self.data_dim == 3:
raise NameError('For 3D, only the bindings work for now')

def __reduce__(self):
""" The interface to pickle dump call.
Expand Down Expand Up @@ -233,8 +244,16 @@ def _set_data(self, data):
print("[info] Replacing existing input data array.")
if not all([e == data.shape[0] for e in data.shape]):
raise ValueError("Expect a square shape data.")
if data.ndim != 2:
raise ValueError("Expect a two-dim data array.")
if data.ndim != self.data_dim:
if self.data_dim == 2:
raise ValueError("This wavelet can only be applied on 2D"
" square images")
if self.data_dim == 3:
raise ValueError("This wavelet can only be applied on 3D"
" cubic images")
else:
raise ValueError("Those data dimensions aren't managed by"
" current transformation")
if self.is_decimated and not (data.shape[0] // 2**(self.nb_scale) > 0):
raise ValueError("Can't decimate the data with the specified "
"number of scales.")
Expand Down
3 changes: 3 additions & 0 deletions pysap/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@
from .tools import mr_filter
from .tools import mr_deconv
from .tools import mr_recons
from .tools import mr3d_recons
from .tools import mr3d_transform
from .tools import mr3d_filter
from .formating import FLATTENING_FCTS as ISAP_FLATTEN
from .formating import INFLATING_FCTS as ISAP_UNFLATTEN
99 changes: 99 additions & 0 deletions pysap/extensions/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,102 @@ def mr_recons(
# Execute the command
process = Sparse2dWrapper(verbose=verbose)
process(cmd)


def mr3d_recons(in_mr_file, out_image, verbose=False):
""" Wrap the Sparse2d 'mr3d_recons'.
"""
# Generate the command
cmd = ["mr3d_recons"]
if verbose:
cmd.append("-v")
cmd += [in_mr_file, out_image]

# Execute the command
process = Sparse2dWrapper(verbose=verbose)
process(cmd)


def mr3d_transform(
in_image, out_mr_file, type_of_multiresolution_transform=2,
type_of_lifting_transform=3, number_of_scales=4,
type_of_filters=1, use_l2_norm=False,
verbose=False):
""" Wrap the Sparse2d 'mr3d_trans'.
"""
# Generate the command
cmd = [
"mr3d_trans",
"-t", type_of_multiresolution_transform,
"-n", number_of_scales]
for key, value in [("-v", verbose)]:
if value:
cmd.append(key)

# Bi orthogonal transform
if type_of_multiresolution_transform == 1:
if type_of_filters == 10:
raise ValueError('Wrong type of filters with orthogonal transform')
if type_of_lifting_transform != 3 and\
type_of_lifting_transform is not None:
raise ValueError('Wrong type of lifting transform with orthogonal')
for key, value in [("-l", type_of_lifting_transform),
("-T", type_of_filters)]:
if value is not None:
cmd += [key, value]
for key, value in [("-L", use_l2_norm)]:
if value:
cmd.append(key)

# (bi) orthogonal transform with lifting
if type_of_multiresolution_transform == 2:
for key, value in [("-l", type_of_lifting_transform)]:
if value is not None:
cmd += [key, value]

# A trous wavelet transform
if type_of_multiresolution_transform == 3:
if type_of_lifting_transform != 3 and\
type_of_lifting_transform is not None:
raise ValueError('Wrong type of lifting transform with orthogonal')
for key, value in [("-l", type_of_lifting_transform)]:
if value is not None:
cmd += [key, value]

cmd += [in_image, out_mr_file]

# Execute the command
process = Sparse2dWrapper(verbose=verbose)
process(cmd)


def mr3d_filter(
in_image, out_image,
type_of_multiresolution_transform=2, type_of_filters=1,
sigma=None, correlated_noise=None, number_of_scales=4,
nsigma=3,
verbose=False):
""" Wrap the Sparse2d 'mr3d_filter'.
"""
# WARNING: relative path with ~ doesn't work, use absolute path from /home
# Generate the command
cmd = [
"mr3d_filter",
"-t", type_of_multiresolution_transform,
"-T", type_of_filters,
"-n", number_of_scales]
for key, value in [
("-C", correlated_noise),
("-v", verbose)]:
if value:
cmd.append(key)
for key, value in [
("-g", sigma),
("-s", nsigma)]:
if value is not None:
cmd += [key, value]
cmd += [in_image, out_image]

# Execute the command
process = Sparse2dWrapper(verbose=verbose)
process(cmd)
48 changes: 48 additions & 0 deletions pysap/extensions/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
- WaveletTransformViaLiftingScheme
- OnLine53AndOnColumn44
- OnLine44AndOnColumn53
For 3D:
- BiOrthogonalTransform3D
- Wavelet3DTransformViaLiftingScheme
- ATrou3D
"""

# System import
Expand Down Expand Up @@ -595,3 +602,44 @@ class OnLine44AndOnColumn53(ISAPWaveletTransformBase):
def _update_default_transformation_parameters(self):
self.bands_names = ["a", "a", "a"]
self.bands_lengths[-1, 1:] = 0

#####################
# 3D Transforms #####
#####################


class BiOrthogonalTransform3D(ISAPWaveletTransformBase):
""" Mallat's 3D wavelet transform (7/9 biorthogonal filters)
"""
def __init__(self, nb_scale, verbose, **kwargs):
ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale,
dim=3, **kwargs)

__isap_transform_id__ = 1
__isap_name__ = "3D Wavelet transform via lifting scheme"
__is_decimated__ = True
__isap_nb_bands__ = 7


class Wavelet3DTransformViaLiftingScheme(ISAPWaveletTransformBase):
""" Wavelet transform via lifting scheme.
"""
def __init__(self, nb_scale, verbose):
ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale, dim=3)

__isap_transform_id__ = 2
__isap_name__ = "Wavelet transform via lifting scheme"
__is_decimated__ = True
__isap_nb_bands__ = 7


class ATrou3D(ISAPWaveletTransformBase):
""" Wavelet transform with the A trou algorithm.
"""
def __init__(self, nb_scale, verbose):
ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale, dim=3)

__isap_transform_id__ = 3
__isap_name__ = "3D Wavelet A Trou"
__is_decimated__ = False
__isap_nb_bands__ = 1
12 changes: 10 additions & 2 deletions pysap/plugins/mri/reconstruct/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class Wavelet2(object):
""" The 2D wavelet transform class.
"""
def __init__(self, wavelet_name, nb_scale=4, verbose=0):
def __init__(self, wavelet_name, nb_scale=4, verbose=0, **kwargs):
""" Initialize the 'Wavelet2' class.
Parameters
Expand All @@ -37,14 +37,22 @@ def __init__(self, wavelet_name, nb_scale=4, verbose=0):
the verbosity level.
"""
self.nb_scale = nb_scale
self.flatten = flatten
self.unflatten = unflatten
if wavelet_name not in pysap.AVAILABLE_TRANSFORMS:
raise ValueError(
"Unknown transformation '{0}'.".format(wavelet_name))
transform_klass = pysap.load_transform(wavelet_name)
self.transform = transform_klass(
nb_scale=self.nb_scale, verbose=verbose)
nb_scale=self.nb_scale, verbose=verbose, **kwargs)
self.coeffs_shape = None

def get_coeff(self):
return self.transform.analysis_data

def set_coeff(self, coeffs):
self.transform.analysis_data = coeffs

def op(self, data):
""" Define the wavelet operator.
Expand Down
19 changes: 9 additions & 10 deletions pysap/scripts/pysap_gridsearch
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import json
import pickle
from datetime import datetime
from argparse import RawTextHelpFormatter
from pprint import pprint
import pprint
from collections import OrderedDict
import logging
if sys.version_info[0] < 3:
Expand Down Expand Up @@ -237,13 +237,15 @@ def _launch(sigma, mask_type, acc_factor, dirname, max_nb_of_iter, n_jobs,
}

# principal gridsearch params grid
mu_list = list(np.logspace(-8, -1, 20))
nb_scales = [3, 4, 5]
# mu_list = list(np.logspace(-8, -1, 20))
mu_list = list(np.logspace(-8, -1, 2))
# nb_scales = [3, 4, 5]
nb_scales = [3]
list_wts = ["MallatWaveletTransform79Filters",
"UndecimatedBiOrthogonalTransform",
"MeyerWaveletsCompactInFourierSpace",
"BsplineWaveletTransformATrousAlgorithm",
"FastCurveletTransform"]
# "UndecimatedBiOrthogonalTransform",
# "MeyerWaveletsCompactInFourierSpace",
# "BsplineWaveletTransformATrousAlgorithm",
"FastCurveletTransform"]

for wt in list_wts:

Expand Down Expand Up @@ -324,7 +326,6 @@ global_params["verbose_gridsearch"] = bool(global_params[
"verbose_gridsearch"])
global_params["max_nb_of_iter"] = int(global_params["max_nb_of_iter"])

global_params["verbose_reconstruction"] = True
global_params["verbose_gridsearch"] = True
for section in config.sections():
if "Run" in section:
Expand All @@ -344,5 +345,3 @@ for section in config.sections():
for sigma in sigma_list:
params["sigma"] = sigma
_launch(**params)


15 changes: 12 additions & 3 deletions pysap/test/test_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,26 @@ def setUp(self):
def test_wavelet_transformations(self):
""" Test all the registered transformations.
"""
for image in self.images:
for image_i in self.images:
print("Process test with image '{0}'...".format(
image.metadata["path"]))
image_i.metadata["path"]))
for nb_scale in self.nb_scales:
print("- Number of scales: {0}".format(nb_scale))
for transform in self.transforms:
print(" Transform: {0}".format(transform))
transform = transform(nb_scale=nb_scale, verbose=0)

image = numpy.copy(image_i)

if transform.data_dim == 3:
image = image[64:192, 64:192]
image = numpy.tile(image, (image.shape[0], 1, 1))
transform.data = image
else:
transform.data = image

self.assertFalse(transform.use_wrapping)
transform.info
transform.data = image
transform.analysis()
# transform.show()
recim = transform.synthesis()
Expand Down
2 changes: 1 addition & 1 deletion pysap/test/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def setUp(self):
print("[info] Image loaded for test: {0}.".format(
[im.data.shape for im in self.images]))
self.mask = get_sample_data("mri-mask").data
self.names = ["BsplineWaveletTransformATrousAlgorithm"]
self.names = ["MallatWaveletTransform79Filters"]
print("[info] Found {0} transformations.".format(len(self.names)))
self.nb_scales = [4]
self.nb_iter = 10
Expand Down
Loading

0 comments on commit be0316a

Please sign in to comment.