Add a nice error message for user-defined functions
odow committed Nov 12, 2018
1 parent 67133a8 commit c78369c
Showing 2 changed files with 36 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/Derivatives/forward.jl
@@ -1,3 +1,15 @@
+# Internal function: evaluates `func` at `args` and checks that the return type
+# is `RetType`. This is a catch to provide the user with a nice error message if
+# their user-defined function returns unexpected results.
+function eval_and_check_return_type(func::Function, RetType, args...)
+    ret = func(args...)
+    if !isa(ret, RetType)
+        error("Expected return type of $(RetType), but got $(typeof(ret)). " *
+              "Make sure your user-defined function only depends on variables" *
+              " passed as arguments.")
+    end
+    return ret
+end
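As a quick aside (not part of the diff): with the helper above in scope, a well-behaved function passes through untouched, while one returning the wrong type triggers the new message. `square` and `bad` here are made-up example functions.

    square(x) = x^2
    eval_and_check_return_type(square, Float64, 2.0)  # returns 4.0

    bad(x) = "oops"                                   # returns a String
    eval_and_check_return_type(bad, Float64, 2.0)
    # ERROR: Expected return type of Float64, but got String. Make sure your
    # user-defined function only depends on variables passed as arguments.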


# forward-mode evaluation of an expression tree
@@ -135,7 +147,8 @@ function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVect
             end
             # TODO: The function names are confusing here. This just
             # evaluates the function value and gradient.
-            fval = MOI.eval_objective(evaluator, f_input)::T
+            fval = eval_and_check_return_type(
+                MOI.eval_objective, T, evaluator, f_input)::T
             MOI.eval_objective_gradient(evaluator, grad_output, f_input)
             storage[k] = fval
             r = 1
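This hunk is the multivariate path: a registered multivariate function is evaluated through MOI.eval_objective on an inner evaluator, and that value now goes through the check. A hypothetical registration that exercises this path (the names `g` and `y` are illustrative; the JuMP.register keyword form is the one used in the test below):

    using JuMP
    model = Model()
    @variable(model, y[1:2])
    g(a, b) = a * b + sin(a)  # returns a plain Float64 for Float64 inputs
    JuMP.register(model, :g, 2, g; autodiff = true)
    @NLobjective(model, Min, g(y[1], y[2]))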
@@ -156,8 +169,8 @@ function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVect
                 userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
                 f = user_operators.univariate_operator_f[userop]
                 fprime = user_operators.univariate_operator_fprime[userop]
-                fval = f(child_val)::T
-                fprimeval = fprime(child_val)::T
+                fval = eval_and_check_return_type(f, T, child_val)::T
+                fprimeval = eval_and_check_return_type(fprime, T, child_val)::T
             else
                 fval, fprimeval = eval_univariate(op, child_val)
             end
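The univariate path checks both the function value `f` and its derivative `fprime`, so hand-written derivative callbacks get the same friendly error as values do. A sketch, assuming the JuMP.register method that accepts explicit derivative callbacks (all names here are illustrative):

    using JuMP
    model = Model()
    myexp(x) = exp(x)
    myexp_prime(x) = exp(x)        # first derivative, checked in forward_eval
    myexp_primeprime(x) = exp(x)   # second derivative, checked in forward_eval_ϵ below
    JuMP.register(model, :myexp, 1, myexp, myexp_prime, myexp_primeprime)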
@@ -341,18 +354,17 @@ function forward_eval_ϵ(storage::AbstractVector{T},
                 child_val = storage[child_idx]
                 if op >= USER_UNIVAR_OPERATOR_ID_START
                     userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
-                    fprimeprime = user_operators.univariate_operator_fprimeprime[userop](child_val)::T
+                    fprimeprime = eval_and_check_return_type(
+                        user_operators.univariate_operator_fprimeprime[userop],
+                        T, child_val)::T
                 else
                     fprimeprime = eval_univariate_2nd_deriv(op, child_val,storage[k])
                 end
                 partials_storage_ϵ[child_idx] = fprimeprime*storage_ϵ[child_idx]
             end
         end
-
     end
-
     return storage_ϵ[1]
-
 end

export forward_eval_ϵ
17 changes: 17 additions & 0 deletions test/nlp.jl
@@ -528,4 +528,21 @@
         MOI.eval_objective_gradient(d, grad, [2.0])
         @test grad == [1.0]
     end
+
+    @testset "Nice error message for user-defined functions" begin
+        model = Model()
+        @variable(model, x[1:2])
+        f(x1) = x1 + x[2]
+        JuMP.register(model, :f, 1, f; autodiff = true)
+        @NLobjective(model, Min, f(x[1]))
+        d = JuMP.NLPEvaluator(model)
+        MOI.initialize(d, [:Grad])
+        expected_exception = ErrorException(
+            "Expected return type of Float64, but got " *
+            "JuMP.GenericAffExpr{Float64,VariableRef}. Make sure your " *
+            "user-defined function only depends on variables passed as " *
+            "arguments."
+        )
+        @test_throws expected_exception MOI.eval_objective(d, [1.0, 1.0])
+    end
 end
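The registered f closes over the model variable x[2], so evaluating it returns a JuMP.GenericAffExpr rather than a Float64 — exactly the mistake the new message describes. The fix the message points users toward, sketched here, is to pass every variable as an argument:

    f(x1, x2) = x1 + x2                              # now returns a plain Float64
    JuMP.register(model, :f, 2, f; autodiff = true)
    @NLobjective(model, Min, f(x[1], x[2]))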
