Skip to content

Commit

Permalink
add multi tissue correction to tortuosity constraint (#98)
Browse files Browse the repository at this point in the history
* Add multi tissue correction to tortuosity constraint.

This commit adds the possibility to correct the tortuosity constraint
by taking into account the multi-tissue properties of the employed
model. To do this, T1_tortuosity is now a class and the new syntax of the
`set_tortuous_parameter` method of the `MultiCompartmentModel` class is
extended in such a way that it can take as input the S0 of the tissues
modelled by the intra-cellular and the extra-cellular compartments.
Backward compatibility is maintained.
  • Loading branch information
matteofrigo committed Jun 14, 2020
1 parent ff5067f commit d959bf0
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 128 deletions.
2 changes: 1 addition & 1 deletion dmipy/core/acquisition_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def visualise_acquisition_G_Delta_rainbow(
plt.xlim(Delta_start, Delta_end)
plt.ylim(G_start, G_end)
cb.set_label('b-value ($s$/$mm^2$)', fontsize=18)
plt.xlabel('Pulse Separation $\Delta$ [sec]', fontsize=18)
plt.xlabel(r'Pulse Separation $\Delta$ [sec]', fontsize=18)
plt.ylabel('Gradient Strength [T/m]', fontsize=18)

def return_pruned_acquisition_scheme(self, shell_indices, data=None):
Expand Down
4 changes: 2 additions & 2 deletions dmipy/core/fitted_modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def return_parametric_fod_model(
for link in self.model.parameter_links:
param_to_delete = self.model._inverted_parameter_map[link[0],
link[1]]
if link[2] is T1_tortuosity:
if isinstance(link[2], T1_tortuosity):
bundle.parameter_links.append(
[link[0], link[1], link[2], link[3][:-1]])
elif link[2] is fractional_parameter:
Expand Down Expand Up @@ -590,7 +590,7 @@ def return_spherical_harmonics_fod_model(self, sh_order=8):
for link in self.model.parameter_links:
param_to_delete = self.model._inverted_parameter_map[link[0],
link[1]]
if link[2] is T1_tortuosity:
if isinstance(link[2], T1_tortuosity):
sh_model.parameter_links.append(
[link[0], link[1], link[2], link[3][:-1]])
elif link[2] is fractional_parameter:
Expand Down
119 changes: 64 additions & 55 deletions dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,33 @@
Document Module
'''
from __future__ import division
import pkg_resources

from collections import OrderedDict
from time import time
from uuid import uuid4

import numpy as np
from time import time
import pkg_resources
from dipy.utils.optpkg import optional_package

from ..utils.spherical_mean import (
estimate_spherical_mean_multi_shell)
from ..utils.utils import (
T1_tortuosity,
parameter_equality,
fractional_parameter)
from .fitted_modeling_framework import (
FittedMultiCompartmentModel,
FittedMultiCompartmentSphericalMeanModel,
FittedMultiCompartmentSphericalHarmonicsModel)
from ..optimizers.brute2fine import (
GlobalBruteOptimizer, Brute2FineOptimizer)
from ..optimizers_fod.csd_tournier import CsdTournierOptimizer
from ..optimizers_fod.csd_cvxpy import CsdCvxpyOptimizer
from ..optimizers.mix import MixOptimizer
from ..optimizers.multi_tissue_convex_optimizer import (
MultiTissueConvexOptimizer)
from dipy.utils.optpkg import optional_package
from uuid import uuid4
from ..optimizers_fod.csd_cvxpy import CsdCvxpyOptimizer
from ..optimizers_fod.csd_tournier import CsdTournierOptimizer
from ..utils.spherical_mean import (
estimate_spherical_mean_multi_shell)
from ..utils.utils import (
T1_tortuosity,
parameter_equality,
fractional_parameter)

pathos, have_pathos, _ = optional_package("pathos")
numba, have_numba, _ = optional_package("numba")
graphviz, have_graphviz, _ = optional_package("graphviz")
Expand Down Expand Up @@ -93,10 +95,9 @@ def parameter_names(self):
@property
def parameter_cardinality(self):
"Returns the cardinality of model parameters"
return OrderedDict([
(k, len(np.atleast_2d(self.parameter_ranges[k])))
for k in self.parameter_ranges
])
return OrderedDict(
[(k, len(np.atleast_2d(self.parameter_ranges[k]))) for k in
self.parameter_ranges])


class MultiCompartmentModelProperties:
Expand Down Expand Up @@ -128,16 +129,14 @@ def parameter_vector_to_parameters(self, parameter_vector):
if parameter_vector.ndim == 1:
for parameter, card in self.parameter_cardinality.items():
parameters[parameter] = parameter_vector[
current_pos: current_pos + card
]
current_pos: current_pos + card]
if card == 1:
parameters[parameter] = parameters[parameter][0]
current_pos += card
else:
for parameter, card in self.parameter_cardinality.items():
parameters[parameter] = parameter_vector[
..., current_pos: current_pos + card
]
..., current_pos: current_pos + card]
if card == 1:
parameters[parameter] = parameters[parameter][..., 0]
current_pos += card
Expand Down Expand Up @@ -320,8 +319,8 @@ def _prepare_parameter_links(self):
parameter_function

if (
(parameter_model, parameter_name)
not in self._inverted_parameter_map
(parameter_model, parameter_name)
not in self._inverted_parameter_map
):
raise ValueError(
"Parameter function {} doesn't exist".format(i)
Expand Down Expand Up @@ -436,7 +435,7 @@ def scales_for_optimization(self):

def _check_for_tortuosity_constraint(self):
for link in self.parameter_links:
if link[2] is T1_tortuosity:
if isinstance(link[2], T1_tortuosity):
msg = "Cannot use MIX optimization when the Tortuosity "
msg += "constraint is set in the MultiCompartmentModel. To "
msg += "use MIX while imposing Tortuosity, set the constraint "
Expand Down Expand Up @@ -467,7 +466,7 @@ def set_initial_guess_parameter(self, parameter_name, value):
elif card >= 2:
value = np.array(value, dtype=float)
if value.shape[-1] != card:
msg = '{} can only be fixed to an array or list with '\
msg = '{} can only be fixed to an array or list with ' \
'last dimension {}.'
raise ValueError(msg.format(parameter_name, type(value)))
if value.ndim == 1:
Expand Down Expand Up @@ -515,13 +514,13 @@ def set_fixed_parameter(self, parameter_name, value):
elif isinstance(value, np.ndarray):
self._add_fixed_parameter_array(parameter_name, value)
else:
msg = 'fixed value for {} must be number or np.array, '\
msg = 'fixed value for {} must be number or np.array, ' \
'currently {}'
raise ValueError(msg.format(parameter_name, type(value)))
elif card >= 2:
value = np.array(value, dtype=float)
if value.shape[-1] != card:
msg = '{} can only be fixed to an array or list with '\
msg = '{} can only be fixed to an array or list with ' \
'last dimension {}.'
raise ValueError(msg.format(parameter_name, card))
if value.ndim == 1:
Expand Down Expand Up @@ -559,7 +558,8 @@ def _add_fixed_parameter_array(self, parameter_name, parameter_array):
def set_tortuous_parameter(self, lambda_perp_parameter_name,
lambda_par_parameter_name,
volume_fraction_intra_parameter_name,
volume_fraction_extra_parameter_name):
volume_fraction_extra_parameter_name,
S0_correction=False):
"""
Allows the user to set a tortuosity constraint on the perpendicular
diffusivity of the extra-axonal compartment, which depends on the
Expand All @@ -568,6 +568,9 @@ def set_tortuous_parameter(self, lambda_perp_parameter_name,
The perpendicular diffusivity parameter will be removed from the
optimized parameters and added as a linked parameter.
To employ the multi-tissue correction of tortuosity it is sufficient to
pass the S0_intra and S0_extra parameters.
Parameters
----------
lambda_perp_parameter_name: string
Expand All @@ -582,6 +585,9 @@ def set_tortuous_parameter(self, lambda_perp_parameter_name,
volume_fraction_extra_parameter_name: string
name of the extra-axonal volume fraction parameter, see
self.parameter_names.
S0_correction: bool
If True, it uses the S0 of the intra-axonal and extra-axonal
compartments to define the tortuosity constraint. Default: False.
"""
params = [lambda_perp_parameter_name, lambda_par_parameter_name,
volume_fraction_intra_parameter_name,
Expand All @@ -595,11 +601,23 @@ def set_tortuous_parameter(self, lambda_perp_parameter_name,
raise ValueError(msg)

model, name = self._parameter_map[lambda_perp_parameter_name]
self.parameter_links.append([model, name, T1_tortuosity, [
if S0_correction and self.S0_tissue_responses is not None:
s0intra_tag = volume_fraction_intra_parameter_name.split('_')[-1]
s0extra_tag = volume_fraction_extra_parameter_name.split('_')[-1]
S0_intra = self.S0_tissue_responses[int(s0intra_tag)]
S0_extra = self.S0_tissue_responses[int(s0extra_tag)]
print('Employing S0 correction of tortuosity constraint with:')
print('S0_intra: {}'.format(S0_intra))
print('S0_extra: {}'.format(S0_extra))
else:
S0_intra = 1.
S0_extra = 1.
tortuosity = T1_tortuosity(S0_intra, S0_extra)

self.parameter_links.append([model, name, tortuosity, [
self._parameter_map[lambda_par_parameter_name],
self._parameter_map[volume_fraction_intra_parameter_name],
self._parameter_map[volume_fraction_extra_parameter_name]]
])
self._parameter_map[volume_fraction_extra_parameter_name]]])
del self.parameter_ranges[lambda_perp_parameter_name]
del self.parameter_cardinality[lambda_perp_parameter_name]
del self.parameter_scales[lambda_perp_parameter_name]
Expand Down Expand Up @@ -974,7 +992,7 @@ def set_parameter_optimization_bounds(self, parameter_name, bounds):
msg.format(parameter_name, card, bounds_array.shape))
for lower, higher in bounds_array:
if higher < lower:
msg = 'given optimization bounds for {} are invalid: lower '\
msg = 'given optimization bounds for {} are invalid: lower ' \
'bound {} is higher than upper bound {}.'
raise ValueError(msg.format(parameter_name, lower, higher))
parameter_scale = np.max(bounds)
Expand Down Expand Up @@ -1003,7 +1021,7 @@ def __init__(self, models, S0_tissue_responses=None, parameter_links=None):
self.N_models = len(models)
if S0_tissue_responses is not None:
if len(S0_tissue_responses) != self.N_models:
msg = 'Number of S0_tissue responses {} must be same as '\
msg = 'Number of S0_tissue responses {} must be same as ' \
'number of input models {}.'
raise ValueError(
msg.format(len(S0_tissue_responses), self.N_models))
Expand Down Expand Up @@ -1039,7 +1057,7 @@ def _check_for_NMR_and_other_models(self):
raise ValueError(msg)

def _check_if_sh_coeff_fixed_if_present(self):
msg = 'sh_coeff parameter {} must be fixed in standard MC models '\
msg = 'sh_coeff parameter {} must be fixed in standard MC models ' \
'to estimate the kernel parameters.'
for name, par_type in self.parameter_types.items():
if par_type == 'sh_coefficients':
Expand Down Expand Up @@ -1243,8 +1261,8 @@ def fit(self, acquisition_scheme, data,
fitted_parameters[mask_pos] = (
fitted_parameters_lin * self.scales_for_optimization)

return FittedMultiCompartmentModel(
self, S0, mask, fitted_parameters, fitted_mt_fractions)
return FittedMultiCompartmentModel(self, S0, mask, fitted_parameters,
fitted_mt_fractions)

def simulate_signal(self, acquisition_scheme, parameters_array_or_dict):
"""
Expand Down Expand Up @@ -1335,7 +1353,7 @@ def __call__(self, acquisition_scheme_or_vertices,
partial_volumes = [1.]

for model_name, model, partial_volume in zip(
self.model_names, self.models, partial_volumes
self.model_names, self.models, partial_volumes
):
parameters = {}
for parameter in model.parameter_ranges:
Expand All @@ -1348,18 +1366,12 @@ def __call__(self, acquisition_scheme_or_vertices,
)

if quantity == "signal":
values = (
values +
partial_volume * model(
acquisition_scheme_or_vertices, **parameters)
)
values = (values + partial_volume * model(
acquisition_scheme_or_vertices, **parameters))
elif quantity == "FOD":
try:
values = (
values +
partial_volume * model.fod(
acquisition_scheme_or_vertices, **parameters)
)
values = (values + partial_volume * model.fod(
acquisition_scheme_or_vertices, **parameters))
except AttributeError:
continue
elif quantity == "stochastic cost function":
Expand Down Expand Up @@ -1389,7 +1401,7 @@ def __init__(self, models, S0_tissue_responses=None, parameter_links=None):
self.N_models = len(models)
if S0_tissue_responses is not None:
if len(S0_tissue_responses) != self.N_models:
msg = 'Number of S0_tissue responses {} must be same as '\
msg = 'Number of S0_tissue responses {} must be same as ' \
'number of input models {}.'
raise ValueError(
msg.format(len(S0_tissue_responses), self.N_models))
Expand Down Expand Up @@ -1737,7 +1749,7 @@ def __call__(self, acquisition_scheme_or_vertices,
partial_volumes = [1.]

for model_name, model, partial_volume in zip(
self.model_names, self.models, partial_volumes
self.model_names, self.models, partial_volumes
):
parameters = {}
for parameter in model.parameter_ranges:
Expand All @@ -1750,11 +1762,8 @@ def __call__(self, acquisition_scheme_or_vertices,
)

if quantity == "signal":
values = (
values +
partial_volume * model.spherical_mean(
acquisition_scheme_or_vertices, **parameters)
)
values = (values + partial_volume * model.spherical_mean(
acquisition_scheme_or_vertices, **parameters))
elif quantity == "stochastic cost function":
values[:, counter] = model.spherical_mean(
acquisition_scheme_or_vertices,
Expand All @@ -1781,7 +1790,7 @@ def __init__(self, models, S0_tissue_responses=None, sh_order=8):
if S0_tissue_responses is not None:
self.fit_S0_response = True
if len(S0_tissue_responses) != self.N_models:
msg = 'Number of S0_tissue responses {} must be same as '\
msg = 'Number of S0_tissue responses {} must be same as ' \
'number of input models {}.'
raise ValueError(
msg.format(len(S0_tissue_responses), self.N_models))
Expand Down Expand Up @@ -2205,7 +2214,7 @@ def homogenize_x0_to_data(data, x0):
else:
x0_as_data = x0.copy()
if not np.all(
x0_as_data.shape[:-1] == data.shape[:-1]
x0_as_data.shape[:-1] == data.shape[:-1]
):
# if x0 and data are both N-dimensional but have different shapes.
msg = "data and x0 both N-dimensional but have different shapes. "
Expand Down
5 changes: 0 additions & 5 deletions dmipy/core/tests/test_fitted_model_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,3 @@ def test_parametric_fod_spherical_mean_model():
distribution=distribution_name, Ncompartments=1)
fitted_fod_model = fod_model.fit(scheme, data)
assert_(isinstance(fitted_fod_model.fitted_parameters, dict))

# fod_model = smt_fit.return_parametric_fod_model(
# distribution='watson', Ncompartments=2)
# fitted_fod_model = fod_model(scheme, data)
# assert_(isinstance(fitted_fod_model.fitted_parameters, dict))

0 comments on commit d959bf0

Please sign in to comment.