Skip to content

Commit

Permalink
prevent bad keywords getting into optimization, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Feb 28, 2018
1 parent 2074c32 commit e653108
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion thecannon/fitting.py
Expand Up @@ -314,6 +314,36 @@ def _scatter_objective_function(scatter, residuals_squared, ivar):
return (np.mean(chi_sq) - 1.0)**2


def _remove_forbidden_op_kwds(op_method, op_kwds):
"""
Remove forbidden optimization keywords.
:param op_method:
The optimization algorithm to use.
:param op_kwds:
Optimization keywords.
:returns:
`None`. The dictionary of `op_kwds` will be updated.
"""
all_allowed_keys = dict(
l_bfgs_b=("bounds", "m", "factr", "pgtol", "epsilon", "iprint",
"maxfun", "maxiter", "disp", "callback", "maxls"),
powell=("xtol", "ftol", "maxiter", "maxfun", "full_output", "disp",
"retall", "callback", "initial_simplex"))

forbidden_keys = set(op_kwds).difference(all_allowed_keys[op_method])
if forbidden_keys:
logger.warn("Ignoring forbidden optimization keywords for {}: {}"\
.format(op_method, ", ".join(forbidden_keys)))
for key in forbidden_keys:
del op_kwds[key]

return None



def fit_pixel_fixed_scatter(flux, ivar, initial_thetas, design_matrix,
regularization, censoring_mask, **kwargs):
"""
Expand Down Expand Up @@ -397,6 +427,9 @@ def fit_pixel_fixed_scatter(flux, ivar, initial_thetas, design_matrix,
op_kwds["bounds"] = [b for b, is_censored in \
zip(op_kwds["bounds"], censored_theta) if not is_censored]

# Only include allowable keywords.
_remove_forbidden_op_kwds(op_method, op_kwds)

op_params, fopt, metadata = op.fmin_l_bfgs_b(
_pixel_objective_function_fixed_scatter,
fprime=None, approx_grad=None, **op_kwds)
Expand All @@ -410,7 +443,8 @@ def fit_pixel_fixed_scatter(flux, ivar, initial_thetas, design_matrix,
logger.warn("Optimization warning (l_bfgs_b): {}".format(reason))

# Do optimization again.
base_op_kwds.update(op_method="powell", x0=op_params)
op_method = "powell"
base_op_kwds.update(x0=op_params)

else:
break
Expand All @@ -420,6 +454,9 @@ def fit_pixel_fixed_scatter(flux, ivar, initial_thetas, design_matrix,
op_kwds.update(xtol=1e-6, ftol=1e-6)
op_kwds.update((kwargs.get("op_kwds", {}) or {}))

# Only include allowable keywords.
_remove_forbidden_op_kwds(op_method, op_kwds)

# Set 'False' in args so that we don't return the gradient,
# because fmin doesn't want it.
args = list(op_kwds["args"])
Expand Down

0 comments on commit e653108

Please sign in to comment.