Skip to content

Commit

Permalink
spherical mean models can now generate CSD models. tournier07 optimiz…
Browse files Browse the repository at this point in the history
…er can now take voxel-varying convolution kernels - multiple if volume fractions are fixed. tournier07 optimizer now has laplace-beltrami regularization.
  • Loading branch information
rutgerfick committed Jun 26, 2018
1 parent 8c9f1df commit b1a0e74
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 52 deletions.
76 changes: 76 additions & 0 deletions dmipy/core/fitted_modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,82 @@ def return_parametric_fod_model(
self.fitted_parameters[smt_parameter_name])
return mc_bundles_model

def return_spherical_harmonics_fod_model(self, sh_order=8):
"""
Retuns spherical harmonics FOD model using the rotational harmonics of
the fitted spherical mean model as the convolution kernel.
Internally, the input models to the spherical mean model are given to
a MultiCompartmentSphericalHarmonicsModel where the parameter links are
replayed such that the new model has the same parameter constraints as
the spherical mean model. The FittedSphericalMeanModel parameters
are given as fixed parameters for the kernel (the kernel will not be
fitted while the FOD's coefficients are being optimized).
The function returns a MultiCompartmentSphericalHarmonicsModel instance
that can be interacted with as usual to fit dMRI data.
Parameters
----------
sh_order: even, positive integer,
Spherical harmonics order of the FODs.
Returns
-------
mc_bundles_model: Dmipy MultiCompartmentModel instance,
MultiCompartmentModel instance that can be used to estimate
parametric FODs using the fitted spherical mean model as a kernel.
"""
from .modeling_framework import MultiCompartmentSphericalHarmonicsModel

if sh_order < 0 or sh_order % 2 != 0:
msg = 'sh_order must be an even, positive integer.'
raise ValueError(msg)

sh_model = MultiCompartmentSphericalHarmonicsModel(
self.model.models, sh_order=sh_order)

for link in self.model.parameter_links:
param_to_delete = self.model._inverted_parameter_map[link[0],
link[1]]
if link[2] is T1_tortuosity:
sh_model.parameter_links.append(
[link[0], link[1], link[2], link[3][:-1]])
elif link[2] is fractional_parameter:
new_parameter_name = param_to_delete + '_fraction'
sh_model.parameter_ranges.update(
{new_parameter_name: [0., 1.]})
sh_model.parameter_scales.update({new_parameter_name: 1.})
sh_model.parameter_cardinality.update({new_parameter_name: 1})
sh_model.parameter_types.update({new_parameter_name: 'normal'})

sh_model._parameter_map.update(
{new_parameter_name: (None, 'fraction')})
sh_model._inverted_parameter_map.update(
{(None, 'fraction'): new_parameter_name})

# add parmeter link to fractional parameter
param_larger_than = self.model._inverted_parameter_map[
link[3][1][0], link[3][1][1]]

model, name = sh_model._parameter_map[param_to_delete]
sh_model.parameter_links.append(
[model, name, fractional_parameter, [
sh_model._parameter_map[new_parameter_name],
sh_model._parameter_map[param_larger_than]]])
else:
sh_model.parameter_links.append(link)
del sh_model.parameter_ranges[param_to_delete]
del sh_model.parameter_cardinality[param_to_delete]
del sh_model.parameter_scales[param_to_delete]
del sh_model.parameter_types[param_to_delete]
del sh_model.parameter_optimization_flags[param_to_delete]

for smt_par_name in self.model.parameter_names:
sh_model.set_fixed_parameter(
smt_par_name, self.fitted_parameters[smt_par_name])
return sh_model


class FittedMultiCompartmentSphericalHarmonicsModel:
"""
Expand Down
67 changes: 40 additions & 27 deletions dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def parameters_to_parameter_vector(self, **parameters):
value = np.atleast_1d(parameters[parameter])
if card == 1 and not np.all(value.shape == np.r_[1]):
parameter_shapes.append(value.shape)
if card == 2 and not np.all(value.shape == np.r_[2]):
elif card > 1 and not np.all(value.shape == np.r_[card]):
parameter_shapes.append(value.shape[:-1])

if len(set(parameter_shapes)) > 1:
Expand All @@ -187,7 +187,7 @@ def parameters_to_parameter_vector(self, **parameters):
np.tile(value[0], np.r_[parameter_shapes[0], 1]))
elif card == 1 and not np.all(value.shape == np.r_[1]):
parameter_vector.append(value[..., None])
elif card == 2 and np.all(value.shape == np.r_[2]):
elif card > 1 and np.all(value.shape == np.r_[card]):
parameter_vector.append(
np.tile(value, np.r_[parameter_shapes[0], 1])
)
Expand Down Expand Up @@ -1572,29 +1572,41 @@ def _add_spherical_harmonics_parameters(self, sh_order):
self.parameter_scales['sh_coeff'] = np.ones(N_coef, dtype=float)
self.parameter_cardinality['sh_coeff'] = N_coef
self.parameter_types['sh_coeff'] = 'sh_coefficients'
self.parameter_optimization_flags['sh_coeff'] = True

def _check_if_kernel_parameters_are_fixed(self):
"checks if only volume fraction and sh_coeff parameters are optimized."
self.volume_fractions_fixed = True
for name, flag in self.parameter_optimization_flags.items():
if flag is True:
if (not name == 'sh_coeff' and
not name.startswith('partial_volume')):
msg = 'kernel parameter {} is not fixed.'.format(name)
msg = 'Kernel parameter {} is not fixed.'.format(name)
raise ValueError(msg)
if name.startswith('partial_volume'):
self.volume_fractions_fixed = False
if (not self.volume_fractions_fixed and
self.multiple_anisotropic_kernels):
msg = 'Cannot have multiple anisotropic kernels without having '
msg += 'all volume fractions fixed.'
raise ValueError(msg)

def _check_that_one_anisotropic_kernel_is_present(self):
"checks if one anisotropic kernel is given."
orientation_counter = 0
self.multiple_anisotropic_kernels = False
for model in self.models:
if 'orientation' in model.parameter_types.values():
orientation_counter += 1
if orientation_counter != 1:
if orientation_counter == 0:
msg = 'MultiCompartmentSphericalHarmonicsModel must at least have '
msg += 'one anisotropic kernel input model.'
raise ValueError(msg)
if orientation_counter > 1:
self.multiple_anisotropic_kernels = True

def fit(self, acquisition_scheme, data, mask=None,
solver='cvxpy', unity_constraint=True,
solver='tournier07', unity_constraint=True,
use_parallel_processing=have_pathos,
number_of_processors=None):
""" The main data fitting function of a
Expand Down Expand Up @@ -1680,9 +1692,28 @@ def fit(self, acquisition_scheme, data, mask=None,
**self.x0_parameters)
x0_ = homogenize_x0_to_data(
data_, x0_)
x0_bool = np.all(
np.isnan(x0_), axis=tuple(np.arange(x0_.ndim - 1)))
x0_[..., ~x0_bool] /= self.scales_for_optimization[~x0_bool]

start = time()
if solver == 'tournier07':
fit_func = CsdTournierOptimizer(
acquisition_scheme, self, x0_, self.sh_order,
unity_constraint=unity_constraint)
if use_parallel_processing:
msg = 'Parallel processing turned off for tournier07 optimizer'
msg += ' because it does not improve fitting speed.'
print(msg)
use_parallel_processing = False
print('Setup Tournier07 FOD optimizer in {} seconds'.format(
time() - start))
elif solver == 'cvxpy':
fit_func = GeneralPurposeCSDOptimizer(
acquisition_scheme, self, x0_, self.sh_order, unity_constraint)
else:
msg = "Unknown solver name {}".format(solver)
raise ValueError(msg)
print('Setup CVXPY FOD optimizer in {} seconds'.format(
time() - start))
self.optimizer = fit_func

if use_parallel_processing and not have_pathos:
msg = 'Cannot use parallel processing without pathos.'
Expand All @@ -1698,23 +1729,6 @@ def fit(self, acquisition_scheme, data, mask=None,
fitted_parameters_lin = np.empty(
np.r_[N_voxels, N_parameters], dtype=float)

start = time()
if solver == 'cvxpy':
fit_func = GeneralPurposeCSDOptimizer(
acquisition_scheme, self, x0_, self.sh_order, unity_constraint)
elif solver == 'tournier07':
fit_func = CsdTournierOptimizer(
acquisition_scheme, self, x0_, self.sh_order,
unity_constraint=unity_constraint)
print('Setup Tournier07 FOD optimizer in {} seconds'.format(
time() - start))
else:
msg = "Unknown solver name {}".format(solver)
raise ValueError(msg)
print('Setup CVXPY FOD optimizer in {} seconds'.format(
time() - start))
self.optimizer = fit_func

start = time()
for idx, pos in enumerate(zip(*mask_pos)):
voxel_E = data_[pos] / S0[pos]
Expand All @@ -1735,8 +1749,7 @@ def fit(self, acquisition_scheme, data, mask=None,
print('Average of {} seconds per voxel.'.format(
fitting_time / N_voxels))
fitted_parameters = np.zeros_like(x0_, dtype=float)
fitted_parameters[mask_pos] = (
fitted_parameters_lin * self.scales_for_optimization)
fitted_parameters[mask_pos] = fitted_parameters_lin

return FittedMultiCompartmentSphericalHarmonicsModel(
self, S0, mask, fitted_parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_equivalence_csd_and_parametric_fod(
[stick])
sh_mod.set_fixed_parameter('C1Stick_1_lambda_par', lambda_par)

sh_fit = sh_mod.fit(scheme, data)
sh_fit = sh_mod.fit(scheme, data, solver='cvxpy')
fod = sh_fit.fod(sphere.vertices)

watson = distributions.SD1Watson(mu=[0., 0.], odi=0.15)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_multi_compartment_fod_with_parametric_model(
partial_volume_1=1 - vf_intra)
data = mc_mod.simulate_signal(scheme, simulation_parameters)

sh_fit = sh_mod.fit(scheme, data)
sh_fit = sh_mod.fit(scheme, data, solver='cvxpy')

vf_intra_estimated = sh_fit.fitted_parameters['partial_volume_0']
assert_almost_equal(vf_intra, vf_intra_estimated)
Expand Down Expand Up @@ -123,4 +123,4 @@ def test_spherical_harmonics_model_raises(
sh_mod = modeling_framework.MultiCompartmentSphericalHarmonicsModel(
[stick])

assert_raises(ValueError, sh_mod.fit, scheme, data)
assert_raises(ValueError, sh_mod.fit, scheme, data, solver='cvxpy')
3 changes: 3 additions & 0 deletions dmipy/optimizers_fod/construct_observation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def construct_model_based_A_matrix(acquisition_scheme, model_rh, lmax):
# prepare the rotational harmonics of the kernel
counter = 0
for n_ in range(0, lmax + 1, 2):
if n_ // 2 > model_rh.shape[1]:
# in case an isotropic kernel is given
break
coef_in_order = 2 * n_ + 1
sh_eigenvalues[:, counter: counter + coef_in_order] = (
np.sqrt((4 * np.pi) / (2 * n_ + 1)) * # sh eigenvalues
Expand Down
64 changes: 42 additions & 22 deletions dmipy/optimizers_fod/csd_tournier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dipy.data import get_sphere, HemiSphere
from dipy.reconst.shm import real_sym_sh_mrtrix
from dipy.utils.optpkg import optional_package
from dipy.reconst.shm import sph_harm_ind_list
sphere = get_sphere('symmetric724')
numba, have_numba, _ = optional_package("numba")

Expand All @@ -14,8 +15,8 @@

class CsdTournierOptimizer:
def __init__(self, acquisition_scheme, model, x0_vector=None, sh_order=8,
lambda_reg=1., tau=0.1, max_iter=50, unity_constraint=True,
init_sh_order=4):
lambda_pos=1., lambda_lb=5e-4, tau=0.1, max_iter=50,
unity_constraint=True, init_sh_order=4):
"""
The classical Constrained Spherical Deconvolution (CSD) optimizer as
proposed by Tournier et al. (2007) [1]_.
Expand All @@ -37,8 +38,11 @@ def __init__(self, acquisition_scheme, model, x0_vector=None, sh_order=8,
Possible parameters for model kernels.
sh_order: positive even integer,
Spherical harmonics order for deconvolution.
lambda_reg: positive float,
lambda_pos: positive float,
Positivity regularization parameter.
lambda_lb: positive float,
Laplace-Belrami regularization weight to impose smoothness in the
FOD. Same as is done in [2]_.
tau: positive float,
Scales positivity threshold relative to maximum FOD amplitude.
max_iter: positive integer,
Expand All @@ -57,14 +61,18 @@ def __init__(self, acquisition_scheme, model, x0_vector=None, sh_order=8,
"Robust determination of the fibre orientation distribution in
diffusion MRI: non-negativity constrained super-resolved spherical
deconvolution." Neuroimage 35.4 (2007): 1459-1472.
.. [2] Fick, Rutger. "An Optimized Processing Framework for Fiber
Tracking on DW-MRI Applied to the Optic Radiation", Master Thesis
(2013).
"""
self.model = model
self.acquisition_scheme = acquisition_scheme
self.sh_order = sh_order
self.Ncoef = int((sh_order + 2) * (sh_order + 1) // 2)
self.Ncoef4 = int((init_sh_order + 2) * (init_sh_order + 1) // 2)
self.Nmodels = len(self.model.models)
self.lambda_reg = lambda_reg
self.lambda_pos = lambda_pos
self.lambda_lb = lambda_lb
self.tau = tau
self.max_iter = max_iter
self.unity_constraint = unity_constraint
Expand All @@ -76,8 +84,11 @@ def __init__(self, acquisition_scheme, model, x0_vector=None, sh_order=8,
self.L_positivity = real_sym_sh_mrtrix(
self.sh_order, hemisphere.theta, hemisphere.phi)[0]

sh_l = sph_harm_ind_list(sh_order)[1]
self.R_smoothness = np.diag(sh_l * (sh_l + 1))

# check if there is only one model. If so, precompute rh array.
if self.Nmodels == 1:
if self.model.volume_fractions_fixed:
x0_single_voxel = np.reshape(
x0_vector, (-1, x0_vector.shape[-1]))[0]
if np.all(np.isnan(x0_single_voxel)):
Expand All @@ -88,7 +99,7 @@ def __init__(self, acquisition_scheme, model, x0_vector=None, sh_order=8,
else:
self.single_convolution_kernel = False
else:
msg = "This CSD optimizer does not support multiple models."
msg = "This CSD optimizer cannot estimate volume fractions."
raise ValueError(msg)

def __call__(self, data, x0_vector):
Expand Down Expand Up @@ -116,10 +127,10 @@ def __call__(self, data, x0_vector):

if self.single_convolution_kernel:
A = self.A
AT_A = self.AT_A
AT_A = self.AT_A + self.lambda_lb * self.R_smoothness
else:
A = self._construct_convolution_kernel(x0_vector)
AT_A = np.dot(A.T, A)
AT_A = np.dot(A.T, A) + self.lambda_lb * self.R_smoothness

if self.unity_constraint:
return self._optimize_with_unity_constraint(
Expand Down Expand Up @@ -161,7 +172,7 @@ def _optimize_without_unity_constraint(self, A, AT_A, data, x0_vector):

for iteration in range(self.max_iter):
L = self.L_positivity[negative_fod_check]
Q = AT_A + self.lambda_reg * np.dot(L.T, L)
Q = AT_A + self.lambda_pos * np.dot(L.T, L)
f_sh = np.dot(np.dot(np.linalg.inv(Q), A.T), data)
negative_fod_check_old = negative_fod_check
negative_fod_check = np.dot(self.L_positivity, f_sh) < threshold
Expand Down Expand Up @@ -222,7 +233,7 @@ def _optimize_with_unity_constraint(self, A, AT_A, data, x0_vector):

for iteration in range(self.max_iter):
L = self.L_positivity[negative_fod_check]
Q = AT_A + self.lambda_reg * np.dot(L.T, L)
Q = AT_A + self.lambda_pos * np.dot(L.T, L)
f_sh[1:] = np.dot(np.dot(np.linalg.inv(Q), A.T), data)[1:]
negative_fod_check_old = negative_fod_check
negative_fod_check = np.dot(self.L_positivity, f_sh) < threshold
Expand Down Expand Up @@ -262,17 +273,26 @@ def _construct_convolution_kernel(self, x0_vector):
parameters_dict = self.model.add_linked_parameters_to_parameters(
parameters_dict)

parameters = {}
for parameter in self.model.models[0].parameter_ranges:
parameter_name = self.model._inverted_parameter_map[
(self.model.models[0], parameter)
if len(self.model.models) > 1:
partial_volumes = [
parameters_dict[p] for p in self.model.partial_volume_names
]
parameters[parameter] = parameters_dict.get(
parameter_name
)
model_rh = (
self.model.models[0].rotational_harmonics_representation(
self.acquisition_scheme, **parameters))
kernel = construct_model_based_A_matrix(
self.acquisition_scheme, model_rh, self.sh_order)
else:
partial_volumes = [1.]

kernel = 0.
for model, partial_volume in zip(self.model.models, partial_volumes):
parameters = {}
for parameter in model.parameter_ranges:
parameter_name = self.model._inverted_parameter_map[
(model, parameter)
]
parameters[parameter] = parameters_dict.get(
parameter_name
)
model_rh = (
model.rotational_harmonics_representation(
self.acquisition_scheme, **parameters))
kernel += partial_volume * construct_model_based_A_matrix(
self.acquisition_scheme, model_rh, self.sh_order)
return kernel

0 comments on commit b1a0e74

Please sign in to comment.