Skip to content

Commit

Permalink
added classical tournier07 csd optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
rutgerfick committed Jun 23, 2018
1 parent 9a26d81 commit 8c9f1df
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 27 deletions.
2 changes: 1 addition & 1 deletion dmipy/core/fitted_modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def fod_sh(self):
spherical harmonics coefficients of the FODs, scaled by volume
fraction.
"""
return self.fitted_parameter['sh_coeff']
return self.fitted_parameters['sh_coeff']

def anisotropy_index(self):
"""
Expand Down
14 changes: 13 additions & 1 deletion dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..optimizers.brute2fine import (
GlobalBruteOptimizer, Brute2FineOptimizer)
from ..optimizers_fod.cvxpy_fod import GeneralPurposeCSDOptimizer
from ..optimizers_fod.csd_tournier import CsdTournierOptimizer
from ..optimizers.mix import MixOptimizer
from dipy.utils.optpkg import optional_package
from graphviz import Digraph
Expand Down Expand Up @@ -979,6 +980,7 @@ def fit(self, acquisition_scheme, data,
else:
msg = "Unknown solver name {}".format(solver)
raise ValueError(msg)
self.optimizer = fit_func

start = time()
for idx, pos in enumerate(zip(*mask_pos)):
Expand Down Expand Up @@ -1352,6 +1354,7 @@ def fit(self, acquisition_scheme, data,
else:
msg = "Unknown solver name {}".format(solver)
raise ValueError(msg)
self.optimizer = fit_func

start = time()
for idx, pos in enumerate(zip(*mask_pos)):
Expand Down Expand Up @@ -1699,9 +1702,19 @@ def fit(self, acquisition_scheme, data, mask=None,
if solver == 'cvxpy':
fit_func = GeneralPurposeCSDOptimizer(
acquisition_scheme, self, x0_, self.sh_order, unity_constraint)
elif solver == 'tournier07':
fit_func = CsdTournierOptimizer(
acquisition_scheme, self, x0_, self.sh_order,
unity_constraint=unity_constraint)
print('Setup Tournier07 FOD optimizer in {} seconds'.format(
time() - start))
else:
msg = "Unknown solver name {}".format(solver)
raise ValueError(msg)
print('Setup CVXPY FOD optimizer in {} seconds'.format(
time() - start))
self.optimizer = fit_func

start = time()
for idx, pos in enumerate(zip(*mask_pos)):
voxel_E = data_[pos] / S0[pos]
Expand All @@ -1721,7 +1734,6 @@ def fit(self, acquisition_scheme, data, mask=None,
len(fitted_parameters_lin), fitting_time))
print('Average of {} seconds per voxel.'.format(
fitting_time / N_voxels))

fitted_parameters = np.zeros_like(x0_, dtype=float)
fitted_parameters[mask_pos] = (
fitted_parameters_lin * self.scales_for_optimization)
Expand Down
27 changes: 27 additions & 0 deletions dmipy/core/tests/test_csd_equivalence_parametric_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,33 @@
sphere = get_sphere('symmetric724')


def test_equivalence_csd_and_parametric_fod_tournier07(
odi=0.15, mu=[0., 0.], lambda_par=1.7e-9):
stick = cylinder_models.C1Stick()
watsonstick = distribute_models.SD1WatsonDistributed(
[stick])

params = {'SD1Watson_1_odi': odi,
'SD1Watson_1_mu': mu,
'C1Stick_1_lambda_par': lambda_par}

data = watsonstick(scheme, **params)

sh_mod = modeling_framework.MultiCompartmentSphericalHarmonicsModel(
[stick])
sh_mod.set_fixed_parameter('C1Stick_1_lambda_par', lambda_par)

sh_fit = sh_mod.fit(scheme, data, solver='tournier07')
fod = sh_fit.fod(sphere.vertices)

watson = distributions.SD1Watson(mu=[0., 0.], odi=0.15)
sf = watson(sphere.vertices)
assert_array_almost_equal(fod[0], sf, 1)

fitted_signal = sh_fit.predict()
assert_array_almost_equal(data, fitted_signal[0], 2)


@np.testing.dec.skipif(not have_cvxpy)
def test_equivalence_csd_and_parametric_fod(
odi=0.15, mu=[0., 0.], lambda_par=1.7e-9):
Expand Down
8 changes: 3 additions & 5 deletions dmipy/optimizers_fod/construct_observation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
]


def construct_model_based_A_matrix(acquisition_scheme, model, lmax):
def construct_model_based_A_matrix(acquisition_scheme, model_rh, lmax):
"""Constructs the multi-shell observation matrix from spherical_harmonics
to DWIs. Follows the notation of Eq. (2) in [1]_.
Expand All @@ -21,8 +21,8 @@ def construct_model_based_A_matrix(acquisition_scheme, model, lmax):
----------
acquisition_scheme : DmipyAcquisitionScheme instance,
An acquisition scheme that has been instantiated using dmipy.
model: dmipy signal model,
dmipy model with all parameters fixed.
model_rh: array of size (N_shells, N_rh_coeffs for that shell),
rotational harmonics for every shell.
lmax: even positive integer,
even maximum spherical harmonics order of the to-be-estimated FOD.
Expand All @@ -41,8 +41,6 @@ def construct_model_based_A_matrix(acquisition_scheme, model, lmax):
Ams = np.zeros([acquisition_scheme.number_of_measurements, Ncoef])
Ams[acquisition_scheme.b0_mask, 0] = 2 * np.sqrt(np.pi)

model_rh = model.rotational_harmonics_representation(acquisition_scheme)

sh_eigenvalues = np.zeros([len(model_rh), Ncoef])

# prepare the rotational harmonics of the kernel
Expand Down

0 comments on commit 8c9f1df

Please sign in to comment.