Skip to content

Commit

Permalink
added optional package cvxpy
Browse files Browse the repository at this point in the history
  • Loading branch information
rutgerfick committed Mar 14, 2018
1 parent 04c8003 commit 170e1e8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
FittedMultiCompartmentSphericalHarmonicsModel)
from ..optimizers.brute2fine import (
GlobalBruteOptimizer, Brute2FineOptimizer)
from ..optimizers_fod.cvxpy_fod import MultiCompartmentCSDOptimizer
from ..optimizers_fod.cvxpy_fod import GeneralPurposeCSDOptimizer
from ..optimizers.mix import MixOptimizer
from dipy.utils.optpkg import optional_package
pathos, have_pathos, _ = optional_package("pathos")
Expand Down Expand Up @@ -1600,7 +1600,7 @@ def fit(self, acquisition_scheme, data, mask=None,

start = time()
if solver == 'cvxpy':
fit_func = MultiCompartmentCSDOptimizer(
fit_func = GeneralPurposeCSDOptimizer(
acquisition_scheme, self, self.sh_order, unity_constraint)
start = time()
for idx, pos in enumerate(zip(*mask_pos)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from dmipy.core import modeling_framework
from dmipy.data.saved_acquisition_schemes import wu_minn_hcp_acquisition_scheme
from dipy.data import get_sphere
import numpy as np
from numpy.testing import (
assert_array_almost_equal, assert_almost_equal, assert_raises)
from dipy.utils.optpkg import optional_package
cvxpy, have_cvxpy, _ = optional_package("cvxpy")

scheme = wu_minn_hcp_acquisition_scheme()
sphere = get_sphere('symmetric724')


@np.testing.dec.skipif(not have_cvxpy)
def test_equivalence_csd_and_parametric_fod(
odi=0.15, mu=[0., 0.], lambda_par=1.7e-9):
stick = cylinder_models.C1Stick()
Expand Down Expand Up @@ -37,6 +41,7 @@ def test_equivalence_csd_and_parametric_fod(
assert_array_almost_equal(data, fitted_signal[0], 4)


@np.testing.dec.skipif(not have_cvxpy)
def test_multi_compartment_fod_with_parametric_model(
odi=0.15, mu=[0., 0.], lambda_iso=3e-9, lambda_par=1.7e-9,
vf_intra=0.7):
Expand Down Expand Up @@ -70,6 +75,7 @@ def test_multi_compartment_fod_with_parametric_model(
assert_array_almost_equal(data, predicted_signal[0], 4)


@np.testing.dec.skipif(not have_cvxpy)
def test_spherical_harmonics_model_raises(
odi=0.15, mu=[0., 0.], lambda_par=1.7e-9):
stick = cylinder_models.C1Stick()
Expand Down
7 changes: 4 additions & 3 deletions dmipy/optimizers_fod/cvxpy_fod.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import numpy as np
import cvxpy
from dipy.data import get_sphere, HemiSphere
from dipy.reconst.shm import real_sym_sh_mrtrix
from dipy.utils.optpkg import optional_package
cvxpy, have_cvxpy, _ = optional_package("cvxpy")


__all__ = [
'MultiCompartmentCSDOptimizer'
'GeneralPurposeCSDOptimizer'
]


class MultiCompartmentCSDOptimizer:
class GeneralPurposeCSDOptimizer:
"""
General purpose optimizer for multi-compartment constrained spherical
deconvolution (MC-CSD) to estimate Fiber Orientation Distributions (FODs).
Expand Down

0 comments on commit 170e1e8

Please sign in to comment.