Skip to content

Commit

Permalink
added tests for set_fixed_parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
rutgerfick committed Feb 26, 2018
1 parent cd5359c commit 84355d9
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
17 changes: 16 additions & 1 deletion dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,20 @@ def set_fixed_parameter(self, parameter_name, value):
the value to fix the parameter at in SI units.
"""
if parameter_name in self.parameter_ranges.keys():
card = self.parameter_cardinality[parameter_name]
if card == 1:
if isinstance(value, int):
value = float(value)
if not isinstance(value, float):
msg = '{} can only be fixed to a float value.'.format(
parameter_name)
raise ValueError(msg)
elif card == 2:
value = np.array(value, dtype=float)
if value.shape != (2,):
msg = '{} can only be fixed '.format(parameter_name)
msg += 'to an array or list of length 2.'
raise ValueError(msg)
model, name = self._parameter_map[parameter_name]
parameter_link = (model, name, ReturnFixedValue(value), [])
self.parameter_links.append(parameter_link)
Expand All @@ -445,8 +459,9 @@ def set_fixed_parameter(self, parameter_name, value):
del self.parameter_types[parameter_name]
del self.parameter_optimization_flags[parameter_name]
else:
print('"{}" does not exist or has already been fixed.').format(
msg = '{} does not exist or has already been fixed.'.format(
parameter_name)
raise ValueError(msg)

def set_tortuous_parameter(self, lambda_perp_parameter_name,
lambda_par_parameter_name,
Expand Down
11 changes: 11 additions & 0 deletions dmipy/core/tests/test_raises_multicompartment_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,14 @@ def test_raise_mix_with_tortuosity_in_mcmodel():
data = stick(scheme, lambda_par=1.7e-9, mu=[0., 0.])

assert_raises(ValueError, mc.fit, scheme, data, solver='mix')


def test_set_fixed_parameter_raises():
cyl = cylinder_models.C1Stick()
mod = modeling_framework.MultiCompartmentModel([cyl])
assert_raises(ValueError, mod.set_fixed_parameter,
'C1Stick_1_lambda_par', [1])
assert_raises(ValueError, mod.set_fixed_parameter,
'C1Stick_1_mu', [1])
assert_raises(ValueError, mod.set_fixed_parameter,
'blabla', [1])
17 changes: 16 additions & 1 deletion dmipy/distributions/distribute_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,20 @@ def set_fixed_parameter(self, parameter_name, value):
the value to fix the parameter at in SI units.
"""
if parameter_name in self.parameter_ranges.keys():
card = self.parameter_cardinality[parameter_name]
if card == 1:
if isinstance(value, int):
value = float(value)
if not isinstance(value, float):
msg = '{} can only be fixed to a float value.'.format(
parameter_name)
raise ValueError(msg)
elif card == 2:
value = np.array(value, dtype=float)
if value.shape != (2,):
msg = '{} can only be fixed '.format(parameter_name)
msg += 'to an array or list of length 2.'
raise ValueError(msg)
model, name = self._parameter_map[parameter_name]
parameter_link = (model, name, ReturnFixedValue(value), [])
self.parameter_links.append(parameter_link)
Expand All @@ -254,8 +268,9 @@ def set_fixed_parameter(self, parameter_name, value):
del self.parameter_cardinality[parameter_name]
del self.parameter_types[parameter_name]
else:
print('{} does not exist or has already been fixed.').format(
msg = '{} does not exist or has already been fixed.'.format(
parameter_name)
raise ValueError(msg)

def set_tortuous_parameter(self, lambda_perp,
lambda_par,
Expand Down
11 changes: 11 additions & 0 deletions dmipy/distributions/tests/test_distributed_model_raises.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,14 @@ def test_raise_mixed_parameter_types():
assert_raises(AttributeError,
distribute_models.DD1GammaDistributed,
[sphere, cylinder])


def test_set_fixed_parameter_raises():
cyl = cylinder_models.C1Stick()
distcyl = distribute_models.SD1WatsonDistributed([cyl])
assert_raises(ValueError, distcyl.set_fixed_parameter,
'SD1Watson_1_odi', [1])
assert_raises(ValueError, distcyl.set_fixed_parameter,
'SD1Watson_1_mu', [1])
assert_raises(ValueError, distcyl.set_fixed_parameter,
'blabla', [1])

0 comments on commit 84355d9

Please sign in to comment.