Skip to content

Commit

Permalink
Adds setting parameter optimization bounds functionality. (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
rutgerfick committed May 1, 2019
1 parent cf885b2 commit 81ebad5
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 29 deletions.
39 changes: 39 additions & 0 deletions dmipy/core/modeling_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,45 @@ def _construct_convolution_kernel(self, **kwargs):
kernel = np.hstack(kernel)
return kernel

def set_parameter_optimization_bounds(self, parameter_name, bounds):
"""
Sets the parameter optimization bounds for a given parameter.
Parameters
----------
parameter_name: string,
name of the parameter whose bounds should be changed.
bounds: array or size(card, 2),
upper and lower bound for each optimized value for the given
parameter, where card is
self.parameter_cardinality[parameter_name]).
Raises
------
ValueError: parameter name not in model parameters
ValueError: input bounds are not of correct shape [card, 2]
ValueError: input higher bound is lower than lower bound
"""
if parameter_name not in self.parameter_names:
raise ValueError(
'{} not in model parameters'.format(parameter_name))
card = self.parameter_cardinality[parameter_name]
bounds_array = np.atleast_2d(bounds)
input_card, N_bounds = bounds_array.shape[:2]
if bounds_array.ndim > 2 or input_card != card or N_bounds != 2:
msg = '{} bounds must be of shape ({}, 2), currently {}.'
raise ValueError(
msg.format(parameter_name, card, bounds_array.shape))
for lower, higher in bounds_array:
if higher < lower:
msg = 'given optimization bounds for {} are invalid: lower '\
'bound {} is higher than upper bound {}.'
raise ValueError(msg.format(parameter_name, lower, higher))
parameter_scale = np.max(bounds)
ranges = np.array(bounds) / parameter_scale
self.parameter_ranges[parameter_name] = ranges
self.parameter_scales[parameter_name] = parameter_scale


class MultiCompartmentModel(MultiCompartmentModelProperties):
r'''
Expand Down
15 changes: 15 additions & 0 deletions dmipy/core/tests/test_raises_multicompartment_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,18 @@ def test_fitting_without_b0_raises():
[gaussian_models.G1Ball()])
data = np.atleast_1d(1.)
assert_raises(ValueError, mc.fit, scheme, data)


def test_set_parameter_optimization_bounds_raises():
ball = gaussian_models.G1Ball()
mc = modeling_framework.MultiCompartmentModel([ball])
assert_raises(ValueError, mc.set_parameter_optimization_bounds,
'not a valid name', [1, 2])
assert_raises(ValueError, mc.set_parameter_optimization_bounds,
'G1Ball_1_lambda_iso', 1)
assert_raises(ValueError, mc.set_parameter_optimization_bounds,
'G1Ball_1_lambda_iso', [[1, 2], [1, 2]])
assert_raises(ValueError, mc.set_parameter_optimization_bounds,
'G1Ball_1_lambda_iso', [1, 2, 3])
assert_raises(ValueError, mc.set_parameter_optimization_bounds,
'G1Ball_1_lambda_iso', [2, 1])
61 changes: 32 additions & 29 deletions examples/example_verdict.ipynb

Large diffs are not rendered by default.

0 comments on commit 81ebad5

Please sign in to comment.