From 9e0707f72c1805bfc13b52dd32fd429d2b14cc8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 4 Jul 2018 16:41:26 +0200 Subject: [PATCH 01/10] correct the new gridsearch script, post_processing not working --- pysap/scripts/pysap_gridsearch | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pysap/scripts/pysap_gridsearch b/pysap/scripts/pysap_gridsearch index 76d9eb93..d1c1ec78 100644 --- a/pysap/scripts/pysap_gridsearch +++ b/pysap/scripts/pysap_gridsearch @@ -18,7 +18,7 @@ import json import pickle from datetime import datetime from argparse import RawTextHelpFormatter -from pprint import pprint +import pprint from collections import OrderedDict import logging if sys.version_info[0] < 3: @@ -237,13 +237,15 @@ def _launch(sigma, mask_type, acc_factor, dirname, max_nb_of_iter, n_jobs, } # principal gridsearch params grid - mu_list = list(np.logspace(-8, -1, 20)) - nb_scales = [3, 4, 5] + # mu_list = list(np.logspace(-8, -1, 20)) + mu_list = list(np.logspace(-8, -1, 2)) + # nb_scales = [3, 4, 5] + nb_scales = [3] list_wts = ["MallatWaveletTransform79Filters", - "UndecimatedBiOrthogonalTransform", - "MeyerWaveletsCompactInFourierSpace", - "BsplineWaveletTransformATrousAlgorithm", - "FastCurveletTransform"] + # "UndecimatedBiOrthogonalTransform", + # "MeyerWaveletsCompactInFourierSpace", + # "BsplineWaveletTransformATrousAlgorithm", + "FastCurveletTransform"] for wt in list_wts: @@ -344,5 +346,3 @@ for section in config.sections(): for sigma in sigma_list: params["sigma"] = sigma _launch(**params) - - From 0a9554a6828f2d0e9ec21a2ec5be3e48a67943d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 4 Jul 2018 16:46:12 +0200 Subject: [PATCH 02/10] remove harcoded verbose value --- pysap/scripts/pysap_gridsearch | 1 - 1 file changed, 1 deletion(-) diff --git a/pysap/scripts/pysap_gridsearch b/pysap/scripts/pysap_gridsearch index d1c1ec78..d247bc58 100644 --- a/pysap/scripts/pysap_gridsearch +++ b/pysap/scripts/pysap_gridsearch @@ -326,7 +326,6 @@ global_params["verbose_gridsearch"] = bool(global_params[ "verbose_gridsearch"]) global_params["max_nb_of_iter"] = int(global_params["max_nb_of_iter"]) -global_params["verbose_reconstruction"] = True global_params["verbose_gridsearch"] = True for section in config.sections(): if "Run" in section: From 9ff9a00753bd56d592074724278c479069af3f72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 18 Jul 2018 16:55:32 +0200 Subject: [PATCH 03/10] adding options for 3D Wavelet transform: can call bindings on cubic data, no wrapping --- pysap/base/transform.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/pysap/base/transform.py b/pysap/base/transform.py index 9cf44d9e..d56d1104 100644 --- a/pysap/base/transform.py +++ b/pysap/base/transform.py @@ -63,7 +63,7 @@ class WaveletTransformBase(with_metaclass(MetaRegister)): Available transforms are define in 'pysap.transform'. """ - def __init__(self, nb_scale, verbose=0, **kwargs): + def __init__(self, nb_scale, verbose=0, dim=2, **kwargs): """ Initialize the WaveletTransformBase class. Parameters @@ -90,6 +90,7 @@ def __init__(self, nb_scale, verbose=0, **kwargs): self.scales_lengths = None self.scales_padds = None self.use_wrapping = pysparse is None + self.data_dim = dim # Data that can be decalred afterward self._data = None @@ -110,8 +111,18 @@ def __init__(self, nb_scale, verbose=0, **kwargs): self.__isap_transform_id__) kwargs["number_of_scales"] = self.nb_scale self.trf = pysparse.MRTransform(**self.kwargs) + if self.data_dim == 2: + self.trf = pysparse.MRTransform(**self.kwargs) + elif self.data_dim == 3: + self.trf = pysparse.MRTransform3D(**self.kwargs) + else: + raise NameError('Please define a correct dimension for data') else: self.trf = None + if self.data_dim == 2: + self.trf = None + elif self.data_dim == 3: + raise NameError('For 3D, only the bindings work for now') def __reduce__(self): """ The interface to pickle dump call. @@ -235,6 +246,16 @@ def _set_data(self, data): raise ValueError("Expect a square shape data.") if data.ndim != 2: raise ValueError("Expect a two-dim data array.") + if data.ndim != self.data_dim: + if self.data_dim == 2: + raise ValueError("This wavelet can only be applied on 2D" + " square images") + if self.data_dim == 3: + raise ValueError("This wavelet can only be applied on 3D" + " cubic images") + else: + raise ValueError("Those data dimensions aren't managed by" + " current transformation") if self.is_decimated and not (data.shape[0] // 2**(self.nb_scale) > 0): raise ValueError("Can't decimate the data with the specified " "number of scales.") From ad89af565dc7a979e40cb1c28f756dccf8d7070e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 18 Jul 2018 17:01:35 +0200 Subject: [PATCH 04/10] adding 3D transforms in extension + functions to use wrapper if needed in the future (notImplemented yet) --- pysap/extensions/__init__.py | 3 ++ pysap/extensions/tools.py | 99 +++++++++++++++++++++++++++++++++++ pysap/extensions/transform.py | 47 +++++++++++++++++ 3 files changed, 149 insertions(+) diff --git a/pysap/extensions/__init__.py b/pysap/extensions/__init__.py index 2c30953b..2c666410 100755 --- a/pysap/extensions/__init__.py +++ b/pysap/extensions/__init__.py @@ -16,5 +16,8 @@ from .tools import mr_filter from .tools import mr_deconv from .tools import mr_recons +from .tools import mr3d_recons +from .tools import mr3d_transform +from .tools import mr3d_filter from .formating import FLATTENING_FCTS as ISAP_FLATTEN from .formating import INFLATING_FCTS as ISAP_UNFLATTEN diff --git a/pysap/extensions/tools.py b/pysap/extensions/tools.py index 7f5ead8e..9a089bb6 100644 --- a/pysap/extensions/tools.py +++ b/pysap/extensions/tools.py @@ -182,3 +182,102 @@ def mr_recons( # Execute the command process = Sparse2dWrapper(verbose=verbose) process(cmd) + + +def mr3d_recons(in_mr_file, out_image, verbose=False): + """ Wrap the Sparse2d 'mr3d_recons'. + """ + # Generate the command + cmd = ["mr3d_recons"] + if verbose: + cmd.append("-v") + cmd += [in_mr_file, out_image] + + # Execute the command + process = Sparse2dWrapper(verbose=verbose) + process(cmd) + + +def mr3d_transform( + in_image, out_mr_file, type_of_multiresolution_transform=2, + type_of_lifting_transform=3, number_of_scales=4, + type_of_filters=1, use_l2_norm=False, + verbose=False): + """ Wrap the Sparse2d 'mr3d_trans'. + """ + # Generate the command + cmd = [ + "mr3d_trans", + "-t", type_of_multiresolution_transform, + "-n", number_of_scales] + for key, value in [("-v", verbose)]: + if value: + cmd.append(key) + + # Bi orthogonal transform + if type_of_multiresolution_transform == 1: + if type_of_filters == 10: + raise ValueError('Wrong type of filters with orthogonal transform') + if type_of_lifting_transform != 3 and\ + type_of_lifting_transform is not None: + raise ValueError('Wrong type of lifting transform with orthogonal') + for key, value in [("-l", type_of_lifting_transform), + ("-T", type_of_filters)]: + if value is not None: + cmd += [key, value] + for key, value in [("-L", use_l2_norm)]: + if value: + cmd.append(key) + + # (bi) orthogonal transform with lifting + if type_of_multiresolution_transform == 2: + for key, value in [("-l", type_of_lifting_transform)]: + if value is not None: + cmd += [key, value] + + # A trous wavelet transform + if type_of_multiresolution_transform == 3: + if type_of_lifting_transform != 3 and\ + type_of_lifting_transform is not None: + raise ValueError('Wrong type of lifting transform with orthogonal') + for key, value in [("-l", type_of_lifting_transform)]: + if value is not None: + cmd += [key, value] + + cmd += [in_image, out_mr_file] + + # Execute the command + process = Sparse2dWrapper(verbose=verbose) + process(cmd) + + +def mr3d_filter( + in_image, out_image, + type_of_multiresolution_transform=2, type_of_filters=1, + sigma=None, correlated_noise=None, number_of_scales=4, + nsigma=3, + verbose=False): + """ Wrap the Sparse2d 'mr3d_filter'. + """ + # WARNING: relative path with ~ doesn't work, use absolute path from /home + # Generate the command + cmd = [ + "mr3d_filter", + "-t", type_of_multiresolution_transform, + "-T", type_of_filters, + "-n", number_of_scales] + for key, value in [ + ("-C", correlated_noise), + ("-v", verbose)]: + if value: + cmd.append(key) + for key, value in [ + ("-g", sigma), + ("-s", nsigma)]: + if value is not None: + cmd += [key, value] + cmd += [in_image, out_image] + + # Execute the command + process = Sparse2dWrapper(verbose=verbose) + process(cmd) diff --git a/pysap/extensions/transform.py b/pysap/extensions/transform.py index 04aa4617..4c6b4d93 100644 --- a/pysap/extensions/transform.py +++ b/pysap/extensions/transform.py @@ -43,6 +43,12 @@ - WaveletTransformViaLiftingScheme - OnLine53AndOnColumn44 - OnLine44AndOnColumn53 + +For 3D: + +- BiOrthogonalTransform3D +- Wavelet3DTransformViaLiftingScheme +- ATrou3D """ # System import @@ -595,3 +601,44 @@ class OnLine44AndOnColumn53(ISAPWaveletTransformBase): def _update_default_transformation_parameters(self): self.bands_names = ["a", "a", "a"] self.bands_lengths[-1, 1:] = 0 + +##################### +# 3D Transforms ##### +##################### + + +class BiOrthogonalTransform3D(ISAPWaveletTransformBase): + """ Mallat's 3D wavelet transform (7/9 biorthogonal filters) + """ + def __init__(self, nb_scale, verbose, **kwargs): + ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale, + dim=3, **kwargs) + + __isap_transform_id__ = 1 + __isap_name__ = "3D Wavelet transform via lifting scheme" + __is_decimated__ = True + __isap_nb_bands__ = 7 + + +class Wavelet3DTransformViaLiftingScheme(ISAPWaveletTransformBase): + """ Wavelet transform via lifting scheme. + """ + def __init__(self, nb_scale, verbose): + ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale, dim=3) + + __isap_transform_id__ = 2 + __isap_name__ = "Wavelet transform via lifting scheme" + __is_decimated__ = True + __isap_nb_bands__ = 7 + + +class ATrou3D(ISAPWaveletTransformBase): + """ Wavelet transform with the A trou algorithm. + """ + def __init__(self, nb_scale, verbose): + ISAPWaveletTransformBase.__init__(self, nb_scale=nb_scale, dim=3) + + __isap_transform_id__ = 3 + __isap_name__ = "3D Wavelet A Trou" + __is_decimated__ = False + __isap_nb_bands__ = 1 From aa58c89710885bb21a40086c162a2d8f69402e3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 18 Jul 2018 17:06:12 +0200 Subject: [PATCH 05/10] adding kwargs on linear op to change the filters on 3D wavelets --- pysap/plugins/mri/reconstruct/linear.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pysap/plugins/mri/reconstruct/linear.py b/pysap/plugins/mri/reconstruct/linear.py index bd2c059a..6d2e325b 100644 --- a/pysap/plugins/mri/reconstruct/linear.py +++ b/pysap/plugins/mri/reconstruct/linear.py @@ -24,7 +24,7 @@ class Wavelet2(object): """ The 2D wavelet transform class. """ - def __init__(self, wavelet_name, nb_scale=4, verbose=0): + def __init__(self, wavelet_name, nb_scale=4, verbose=0, **kwargs): """ Initialize the 'Wavelet2' class. Parameters @@ -37,14 +37,22 @@ def __init__(self, wavelet_name, nb_scale=4, verbose=0): the verbosity level. """ self.nb_scale = nb_scale + self.flatten = flatten + self.unflatten = unflatten if wavelet_name not in pysap.AVAILABLE_TRANSFORMS: raise ValueError( "Unknown transformation '{0}'.".format(wavelet_name)) transform_klass = pysap.load_transform(wavelet_name) self.transform = transform_klass( - nb_scale=self.nb_scale, verbose=verbose) + nb_scale=self.nb_scale, verbose=verbose, **kwargs) self.coeffs_shape = None + def get_coeff(self): + return self.transform.analysis_data + + def set_coeff(self, coeffs): + self.transform.analysis_data = coeffs + def op(self, data): """ Define the wavelet operator. From 203f7cc646c7b11c7b92a414a92e2c2f76830afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 18 Jul 2018 17:08:23 +0200 Subject: [PATCH 06/10] adding test for 3D wavelets --- pysap/test/test_binding.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pysap/test/test_binding.py b/pysap/test/test_binding.py index b4d667d0..1b75615f 100644 --- a/pysap/test/test_binding.py +++ b/pysap/test/test_binding.py @@ -42,17 +42,26 @@ def setUp(self): def test_wavelet_transformations(self): """ Test all the registered transformations. """ - for image in self.images: + for image_i in self.images: print("Process test with image '{0}'...".format( - image.metadata["path"])) + image_i.metadata["path"])) for nb_scale in self.nb_scales: print("- Number of scales: {0}".format(nb_scale)) for transform in self.transforms: print(" Transform: {0}".format(transform)) transform = transform(nb_scale=nb_scale, verbose=0) + + image = numpy.copy(image_i) + + if transform.data_dim == 3: + image = image[64:192, 64:192] + image = numpy.tile(image, (image.shape[0], 1, 1)) + transform.data = image + else: + transform.data = image + self.assertFalse(transform.use_wrapping) transform.info - transform.data = image transform.analysis() # transform.show() recim = transform.synthesis() From 39690307375117f8dfbcd6305b3aaf11fb419dbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Wed, 18 Jul 2018 17:19:30 +0200 Subject: [PATCH 07/10] adding cpp files for Boost Sparse3D binding --- sparse2d/python/NumPyArrayData.h | 172 ++++++---- .../python/cmake/Modules/BuildSparse2D.cmake | 7 +- sparse2d/python/pysparse.cpp | 77 ++++- sparse2d/python/transform.hpp | 16 +- sparse2d/python/transform_3D.hpp | 311 ++++++++++++++++++ 5 files changed, 500 insertions(+), 83 deletions(-) create mode 100644 sparse2d/python/transform_3D.hpp diff --git a/sparse2d/python/NumPyArrayData.h b/sparse2d/python/NumPyArrayData.h index c53b5db1..348740bc 100644 --- a/sparse2d/python/NumPyArrayData.h +++ b/sparse2d/python/NumPyArrayData.h @@ -27,60 +27,60 @@ namespace bn = boost::python::numpy; // Helper class for fast access to array elements template class NumPyArrayData { - char* m_data; - const Py_intptr_t* m_strides; + char* m_data; + const Py_intptr_t* m_strides; public: - NumPyArrayData(const bn::ndarray &arr) - { - bn::dtype dtype = arr.get_dtype(); - bn::dtype dtype_expected = bn::dtype::get_builtin(); - - if (dtype != dtype_expected) - { - std::stringstream ss; - ss << "NumPyArrayData: Unexpected data type (" << bp::extract(dtype.attr("__str__")()) << ") received. "; - ss << "Expected " << bp::extract(dtype_expected.attr("__str__")()); - throw std::runtime_error(ss.str().c_str()); - } - - m_data = arr.get_data(); - m_strides = arr.get_strides(); - } - - T* data() - { - return reinterpret_cast(m_data); - } - - const Py_intptr_t* strides() - { - return m_strides; - } - - // 1D array access - inline T& operator()(int i) - { - return *reinterpret_cast(m_data + i*m_strides[0]); - } - - // 2D array access - inline T& operator()(int i, int j) - { - return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1]); - } - - // 3D array access - inline T& operator()(int i, int j, int k) - { - return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1] + k*m_strides[2]); - } - - // 4D array access - inline T& operator()(int i, int j, int k, int l) - { - return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1] + k*m_strides[2] + l*m_strides[3]); - } + NumPyArrayData(const bn::ndarray &arr) + { + bn::dtype dtype = arr.get_dtype(); + bn::dtype dtype_expected = bn::dtype::get_builtin(); + + if (dtype != dtype_expected) + { + std::stringstream ss; + ss << "NumPyArrayData: Unexpected data type (" << bp::extract(dtype.attr("__str__")()) << ") received. "; + ss << "Expected " << bp::extract(dtype_expected.attr("__str__")()); + throw std::runtime_error(ss.str().c_str()); + } + + m_data = arr.get_data(); + m_strides = arr.get_strides(); + } + + T* data() + { + return reinterpret_cast(m_data); + } + + const Py_intptr_t* strides() + { + return m_strides; + } + + // 1D array access + inline T& operator()(int i) + { + return *reinterpret_cast(m_data + i*m_strides[0]); + } + + // 2D array access + inline T& operator()(int i, int j) + { + return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1]); + } + + // 3D array access + inline T& operator()(int i, int j, int k) + { + return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1] + k*m_strides[2]); + } + + // 4D array access + inline T& operator()(int i, int j, int k, int l) + { + return *reinterpret_cast(m_data + i*m_strides[0] + j*m_strides[1] + k*m_strides[2] + l*m_strides[3]); + } }; @@ -90,29 +90,71 @@ bn::ndarray image2array_2d(const Ifloat& im){ bn::ndarray arr = bn::zeros( bp::make_tuple(im.nl(), im.nc()), bn::dtype::get_builtin()); - NumPyArrayData arr_data(arr); - for (int i=0; i arr_data(arr); + for (int i=0; i arr_data(arr); - for (int i=0; i arr_data(arr); + for (int i=0; i()); + + NumPyArrayData arr_data(arr); + for (int i=0; i arr_data(arr); + + for (int i=0; i MRTransform3D_exposer_t; + MRTransform3D_exposer_t MRTransform3D_exposer = MRTransform3D_exposer_t( + "MRTransform3D", + bp::init< int, bp::optional< int, int, int, int, bool, int, int > >( + ( bp::arg("type_of_multiresolution_transform"), + bp::arg("type_of_lifting_transform")=(int)(3), + bp::arg("number_of_scales")=(int)(4), + bp::arg("iter")=(int)(3), + bp::arg("type_of_filters")=(int)(1), + bp::arg("use_l2_norm")=(bool)(false), + bp::arg("nb_procs")=(int)(0), + bp::arg("verbose")=(int)(0) ) + ) + ); + bp::scope MRTransform3D_scope( MRTransform3D_exposer ); + bp::implicitly_convertible< int, MRTransform3D >(); + + // Information method + { + typedef void ( ::MRTransform3D::*Info_function_type)( ) ; + MRTransform3D_exposer.def( + "info", + Info_function_type( &::MRTransform3D::Info ) + ); + } + + // Transform method + { + typedef ::bp::list ( ::MRTransform3D::*Transform_function_type)( ::bn::ndarray, bool ) ; + MRTransform3D_exposer.def( + "transform", + Transform_function_type( &::MRTransform3D::Transform ), + ( bp::arg("arr"), bp::arg("save")=(bool)(0) ) + ); + } + + // Reconstruction method + { + typedef ::bn::ndarray ( ::MRTransform3D::*Reconstruct_function_type)( bp::list ) ; + MRTransform3D_exposer.def( + "reconstruct", + Reconstruct_function_type( &::MRTransform3D::Reconstruct ) ); + } + + // Output path accessors + { + typedef ::std::string ( ::MRTransform3D::*get_opath_function_type)( ) const; + typedef void ( ::MRTransform3D::*set_opath_function_type)( ::std::string ) ; + MRTransform3D_exposer.add_property( + "opath", + get_opath_function_type( &::MRTransform3D::get_opath ), + set_opath_function_type( &::MRTransform3D::set_opath ) ); } } + // Module property bp::scope().attr("__version__") = "0.0.1"; bp::scope().attr("__doc__") = "Python bindings for ISAP"; diff --git a/sparse2d/python/transform.hpp b/sparse2d/python/transform.hpp index be6dcfdb..ec5c922d 100644 --- a/sparse2d/python/transform.hpp +++ b/sparse2d/python/transform.hpp @@ -124,13 +124,13 @@ MRTransform::MRTransform( // Load the lifting transform if ((this->type_of_lifting_transform > 0) && (this->type_of_lifting_transform <= NBR_LIFT)) this->lift_transform = type_lift(this->type_of_lifting_transform); - else + else throw std::invalid_argument("Invalid lifting transform number."); // Check the number of scales if ((this->number_of_scales <= 1) || (this->number_of_scales > MAX_SCALE)) throw std::invalid_argument("Bad number of scales ]1; MAX_SCALE]."); - + // Check the number of iterations if ((this->iter <= 1) || (this->iter > 20)) throw std::invalid_argument("Bad number of iteration ]1; 20]."); @@ -147,7 +147,7 @@ MRTransform::MRTransform( this->norm = NORM_L2; // Load the non orthogonal filter - if ((this->type_of_non_orthog_filters > 0) && (this->type_of_non_orthog_filters <= NBR_UNDEC_FILTER)) + if ((this->type_of_non_orthog_filters > 0) && (this->type_of_non_orthog_filters <= NBR_UNDEC_FILTER)) this->no_filter = type_undec_filter(this->type_of_non_orthog_filters - 1); // Check input parameters @@ -231,7 +231,7 @@ bp::list MRTransform::Transform(const bn::ndarray& arr, bool save){ } mr.alloc(data.nl(), data.nc(), this->number_of_scales, this->mr_transform, ptrfas, this->norm, - this->nb_of_undecimated_scales, this->no_filter); + this->nb_of_undecimated_scales, this->no_filter); if (this->mr_transform == TO_LIFTING) mr.LiftingTrans = this->lift_transform; mr.Border = this->bord; @@ -274,7 +274,6 @@ bp::list MRTransform::Transform(const bn::ndarray& arr, bool save){ if (nb_bands_count != mr.nbr_band()) { mr_scale[-1] = 1; } - // Format the result bp::list mr_result; mr_result.append(mr_data); @@ -295,8 +294,12 @@ bn::ndarray MRTransform::Reconstruct(bp::list mr_data){ // Update transformation for (int s=0; s(mr_data[s])); + // cout << "Size of inserted band "; + // cout << "nb_e:"<< band_data.n_elem() << "/ndim:" << band_data.naxis()\ + // << "/nx:" << band_data.nx() << "/ny:" << band_data.ny() << endl; + mr.insert_band(band_data, s); - } + } // Start the reconstruction Ifloat data(mr.size_ima_nl(), mr.size_ima_nc(), "Reconstruct"); @@ -304,4 +307,3 @@ bn::ndarray MRTransform::Reconstruct(bp::list mr_data){ return image2array_2d(data); } - diff --git a/sparse2d/python/transform_3D.hpp b/sparse2d/python/transform_3D.hpp new file mode 100644 index 00000000..efafaaf1 --- /dev/null +++ b/sparse2d/python/transform_3D.hpp @@ -0,0 +1,311 @@ +/*########################################################################## +# XXX - Copyright (C) XXX, 2017 +# Distributed under the terms of the CeCILL-B license, as published by +# the CEA-CNRS-INRIA. Refer to the LICENSE file or to +# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html +# for details. +##########################################################################*/ +/*Availables transforms: +1: Mallat 3D +2: Lifting +3: A trous*/ + +// Includes +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "NumPyArrayData.h" + + + +#define ASSERT_THROW(a,msg) if (!(a)) throw std::runtime_error(msg); + +class MRTransform3D { + +public: + // Constructor + MRTransform3D( + int type_of_multiresolution_transform, + int type_of_lifting_transform=3, + int number_of_scales=4, + int iter=3, + int type_of_filters=1, + bool use_l2_norm=false, + int nb_procs=0, + int verbose=0); + + // Destructor + ~MRTransform3D(); + + // Save transformation method + void Save(MR_3D &mr); + + // Information method + void Info(); + + // Transform method + bp::list Transform(const bn::ndarray& arr, bool save=false); + + // Reconstruction method + bn::ndarray Reconstruct(bp::list mr_data); + + // Getter/setter functions for the input/output image path + void set_opath(std::string path) {this->m_opath = path;} + string get_opath() const {return m_opath;} + +private: + MR_3D mr; + FilterAnaSynt fas; + FilterAnaSynt *ptrfas = NULL; + bool mr_initialized; + std::string m_opath; + int type_of_multiresolution_transform; + int type_of_lifting_transform; + int number_of_scales; + int iter; + int type_of_filters; + bool use_l2_norm; + int nb_procs; + int verbose; + + type_trans_3d mr_transform = TO3_MALLAT; + type_lift lift_transform = DEF_LIFT; + type_sb_filter filter = F_MALLAT_7_9; + + sb_type_norm norm = NORM_L1; +}; + +// Constructor +MRTransform3D::MRTransform3D( + int type_of_multiresolution_transform, + int type_of_lifting_transform, + int number_of_scales, + int iter, + int type_of_filters, + bool use_l2_norm, + int nb_procs, + int verbose){ + // Define instance attributes + this->type_of_multiresolution_transform = type_of_multiresolution_transform; + this->type_of_lifting_transform = type_of_lifting_transform; + this->number_of_scales = number_of_scales; + this->iter = iter; + this->type_of_filters = type_of_filters; + this->use_l2_norm = use_l2_norm; + this->verbose = verbose; + this->mr_initialized = false; + bool use_filter = false; + // The maximum number of threads returned by omp_get_max_threads() + // (which is the default number of threads used by OMP in parallel + // regions) can sometimes be far below the number of CPUs. + // It is then required to set it in relation to the real number of CPUs + // (the -1 is used to live one thread to the main process which ensures + // better and more constant performances). - Fabrice Poupon 2013/03/09 + #ifdef _OPENMP + if (nb_procs <= 0) + this->nb_procs = omp_get_num_procs() - 1; + else + this->nb_procs = nb_procs; + omp_set_num_threads(this->nb_procs); + #endif + + // Load the mr transform + if ((this->type_of_multiresolution_transform > 0) && (this->type_of_multiresolution_transform <= NBR_TRANS_3D+1)) + this->mr_transform = type_trans_3d(this->type_of_multiresolution_transform - 1); + else + throw std::invalid_argument("Invalid MR transform number."); + + // Load the lifting transform + if ((this->type_of_lifting_transform > 0) && (this->type_of_lifting_transform <= NBR_LIFT)) + this->lift_transform = type_lift(this->type_of_lifting_transform); + else + throw std::invalid_argument("Invalid lifting transform number."); + + // Check the number of scales + if ((this->number_of_scales <= 1) || (this->number_of_scales > MAX_SCALE_1D)) + throw std::invalid_argument("Bad number of scales ]1; MAX_SCALE]."); + + // Check the number of iterations + if ((this->iter <= 1) || (this->iter > 20)) + throw std::invalid_argument("Bad number of iteration ]1; 20]."); + + // Load the filter + if (this->type_of_filters != 1){ + std::stringstream strs; + strs << type_of_filters; + this->filter = get_filter_bank((char *)strs.str().c_str()); + use_filter = true; + } + + // Change the norm + if (this->use_l2_norm) + this->norm = NORM_L2; + + // Check compatibility between parameters + if ((this->mr_transform != TO3_MALLAT) && (this->use_l2_norm || use_filter)) + throw std::invalid_argument("transforms other than Mallat should not be used with filters and L2 norm"); + if ((this->mr_transform != TO3_LIFTING) && (this->lift_transform != DEF_LIFT)) + throw std::invalid_argument("Non lifting transforms can only be used with integer Haar WT as lifting scheme:"); + +} + +// Destructor +MRTransform3D::~MRTransform3D(){ +} + +// Save transformation method +void MRTransform3D::Save(MR_3D &mr){ + + // Welcome message + if (this->verbose > 0) + cout << " Output path: " << this->m_opath << endl; + + // Check inputs + if (this->m_opath == "") + throw std::invalid_argument( + "Please specify an output image path in 'opath'."); + + // Write the results + mr.write((char *)this->m_opath.c_str()); +} + +void MRTransform3D::Info(){ + // Information message + cout << "---------" << endl; + cout << "Information" << endl; + cout << "Runtime parameters:" << endl; + cout << " Number of procs: " << this->nb_procs << endl; + cout << " MR transform ID: " << this->type_of_multiresolution_transform << endl; + cout << " MR transform name: " << StringTransf3D(this->mr_transform) << endl; + if ((this->mr_transform == TO3_MALLAT)) { + cout << " Filter ID: " << this->type_of_filters << endl; + cout << " Filter name: " << StringSBFilter(this->filter) << endl; + if (this->use_l2_norm) + cout << " Use L2-norm." << endl; + } + if (this->mr_transform == TO3_LIFTING) { + cout << " Lifting transform ID: " << this->type_of_lifting_transform << endl; + cout << " Lifting transform name: " << StringLSTransform(this->lift_transform) << endl; + } + cout << " Number of scales: " << this->number_of_scales << endl; + cout << "---------" << endl; + } + + +// Transform method +bp::list MRTransform3D::Transform(const bn::ndarray& arr, bool save){ + + // Create the transformation + fltarray data = array2image_3d(arr); + if (!this->mr_initialized) { + if ((this->mr_transform == TO3_MALLAT)) { + fas.Verbose = (Bool)this->verbose; + fas.alloc(this->filter); + ptrfas = &fas; + } + + mr.alloc(data.nx(), data.ny(), data.nz(), this->mr_transform, + this->number_of_scales, ptrfas, this->norm); + + if (this->mr_transform == TO3_LIFTING) + mr.LiftingTrans = this->lift_transform; + + mr.Verbose = (Bool)this->verbose; + this->mr_initialized = true; + } + + // Perform the transformation + if (this->verbose > 0) { + cout << "Starting transformation" << endl; + cout << "Runtime parameters:" << endl; + cout << " Number of bands: " << mr.nbr_band() << endl; + cout << " Data dimension: " << arr.get_nd() << endl; + cout << " Array shape: " << arr.shape(0) << ", " << arr.shape(1) << ", " << arr.shape(2) << endl; + cout << " Save transform: " << save << endl; + } + + + ASSERT_THROW( + ((int)pow(2, this->number_of_scales) <= (int)min(arr.shape(0), min(arr.shape(1), arr.shape(2)))), + "Number of scales is too damn high (for the size of the data)"); + + mr.transform(data); + + // Save transform if requested + if (save) + Save(mr); + + // Return the generated bands data + bp::list mr_data; + for (int s=0; smr_transform == TO3_ATROUS ){nbr_band_per_resol_cst = 1;} + + int nb_bands_count = 0; + for (int s=0; s(bp::object(mr_result[1]).attr("__str__")())() << endl; + + return mr_result; +} + +// Reconstruction method +bn::ndarray MRTransform3D::Reconstruct(bp::list mr_data){ + // Welcome message + if (this->verbose > 0) { + cout << "Starting Reconstruction" << endl; + cout << "Runtime parameters:" << endl; + cout << " Number of bands: " << bp::len(mr_data) << endl; + } + + // Update transformation + for (int s=0; s(mr_data[s])); + // cout << "Size of inserted band "; + // cout << "nb_e:"<< band_data.n_elem() << "/ndim:" << band_data.naxis()\ + // << "/nx:" << band_data.nx() << "/ny:" << band_data.ny() << "nz:" << band_data.nz() << endl; + + mr.insert_band(s, band_data); + } + + int Nx = mr.size_cube_nx(); + int Ny = mr.size_cube_ny(); + int Nz = mr.size_cube_nz(); + + // Start the reconstruction + fltarray data(Nx, Ny, Nz, "Reconstruct"); + mr.recons(data); + + return image2array_3d(data); +} From 27af4d3fbfb0b9c884f34bddac9c361ec464b619 Mon Sep 17 00:00:00 2001 From: bensarthou <37699244+bensarthou@users.noreply.github.com> Date: Fri, 20 Jul 2018 13:52:20 +0200 Subject: [PATCH 08/10] bonker commit to rerun travis test --- pysap/extensions/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysap/extensions/transform.py b/pysap/extensions/transform.py index 4c6b4d93..7da1aaf9 100644 --- a/pysap/extensions/transform.py +++ b/pysap/extensions/transform.py @@ -49,6 +49,7 @@ - BiOrthogonalTransform3D - Wavelet3DTransformViaLiftingScheme - ATrou3D + """ # System import From 145ea696e83b5136a4990c8c6baca55a5dd6cd7a Mon Sep 17 00:00:00 2001 From: bensarthou <37699244+bensarthou@users.noreply.github.com> Date: Fri, 20 Jul 2018 16:04:50 +0200 Subject: [PATCH 09/10] Changing tag on Sparse2D version --- sparse2d/python/cmake/Modules/BuildSparse2D.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sparse2d/python/cmake/Modules/BuildSparse2D.cmake b/sparse2d/python/cmake/Modules/BuildSparse2D.cmake index a0d8bafb..5a5f0ae8 100644 --- a/sparse2d/python/cmake/Modules/BuildSparse2D.cmake +++ b/sparse2d/python/cmake/Modules/BuildSparse2D.cmake @@ -2,12 +2,12 @@ # Build the CfitsIO dependencies for the project # #========================================================# -set(sparse2dVersion 2.1.0) +set(sparse2dVersion 2.1.1) ExternalProject_Add(sparse2d PREFIX sparse2d GIT_REPOSITORY https://github.com/CosmoStat/Sparse2D.git - GIT_TAG v2.1.0 + GIT_TAG v2.1.1 # GIT_TAG master DEPENDS cfitsio CONFIGURE_COMMAND cmake ../sparse2d From a3dfdbd4bdcde19060492606aad95d561e3df86e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20SARTHOU?= Date: Fri, 20 Jul 2018 17:10:05 +0200 Subject: [PATCH 10/10] remove a miscopy of merge + change the optimizer test (wrong wavelet used with FISTA) --- pysap/base/transform.py | 2 -- pysap/test/test_optimizer.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pysap/base/transform.py b/pysap/base/transform.py index d56d1104..6092cc89 100644 --- a/pysap/base/transform.py +++ b/pysap/base/transform.py @@ -244,8 +244,6 @@ def _set_data(self, data): print("[info] Replacing existing input data array.") if not all([e == data.shape[0] for e in data.shape]): raise ValueError("Expect a square shape data.") - if data.ndim != 2: - raise ValueError("Expect a two-dim data array.") if data.ndim != self.data_dim: if self.data_dim == 2: raise ValueError("This wavelet can only be applied on 2D" diff --git a/pysap/test/test_optimizer.py b/pysap/test/test_optimizer.py index e2bc526a..61088352 100644 --- a/pysap/test/test_optimizer.py +++ b/pysap/test/test_optimizer.py @@ -38,7 +38,7 @@ def setUp(self): print("[info] Image loaded for test: {0}.".format( [im.data.shape for im in self.images])) self.mask = get_sample_data("mri-mask").data - self.names = ["BsplineWaveletTransformATrousAlgorithm"] + self.names = ["MallatWaveletTransform79Filters"] print("[info] Found {0} transformations.".format(len(self.names))) self.nb_scales = [4] self.nb_iter = 10