Skip to content

Commit

Permalink
Merge b91debd into 10878c9
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 12, 2018
2 parents 10878c9 + b91debd commit abd2d5c
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/Derivatives/Derivatives.jl
Expand Up @@ -9,6 +9,7 @@ using ForwardDiff
import Calculus
import MathOptInterface
const MOI = MathOptInterface
using ..JuMP

const TAG = :rds_tag

Expand Down
30 changes: 23 additions & 7 deletions src/Derivatives/forward.jl
@@ -1,3 +1,19 @@
# Internal function: evaluates `func` at `args` and checks that the return type
# is `RetType`. This is a catch to provide the user with a nice error message if
# their user-defined function returns unexpected results.
function eval_and_check_return_type(func::Function, RetType, args...)
ret = func(args...)
if !isa(ret, RetType)
message = "Expected return type of $(RetType) from a user-defined " *
"function, but got $(typeof(ret))."
if isa(ret, JuMP.AbstractJuMPScalar)
message *= " Make sure your user-defined function only depends on" *
" variables passed as arguments."
end
error(message)
end
return ret
end


# forward-mode evaluation of an expression tree
Expand Down Expand Up @@ -135,7 +151,8 @@ function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVect
end
# TODO: The function names are confusing here. This just
# evaluates the function value and gradient.
fval = MOI.eval_objective(evaluator, f_input)::T
fval = eval_and_check_return_type(
MOI.eval_objective, T, evaluator, f_input)::T
MOI.eval_objective_gradient(evaluator, grad_output, f_input)
storage[k] = fval
r = 1
Expand All @@ -156,8 +173,8 @@ function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVect
userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
f = user_operators.univariate_operator_f[userop]
fprime = user_operators.univariate_operator_fprime[userop]
fval = f(child_val)::T
fprimeval = fprime(child_val)::T
fval = eval_and_check_return_type(f, T, child_val)::T
fprimeval = eval_and_check_return_type(fprime, T, child_val)::T
else
fval, fprimeval = eval_univariate(op, child_val)
end
Expand Down Expand Up @@ -341,18 +358,17 @@ function forward_eval_ϵ(storage::AbstractVector{T},
child_val = storage[child_idx]
if op >= USER_UNIVAR_OPERATOR_ID_START
userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
fprimeprime = user_operators.univariate_operator_fprimeprime[userop](child_val)::T
fprimeprime = eval_and_check_return_type(
user_operators.univariate_operator_fprimeprime[userop],
T, child_val)::T
else
fprimeprime = eval_univariate_2nd_deriv(op, child_val,storage[k])
end
partials_storage_ϵ[child_idx] = fprimeprime*storage_ϵ[child_idx]
end
end

end

return storage_ϵ[1]

end

export forward_eval_ϵ
Expand Down
40 changes: 40 additions & 0 deletions test/nlp.jl
Expand Up @@ -528,4 +528,44 @@
MOI.eval_objective_gradient(d, grad, [2.0])
@test grad == [1.0]
end

@testset "User-defined function with variable closure" begin
model = Model()
@variable(model, x[1:2])
f(x1) = x1 + x[2]
JuMP.register(model, :f, 1, f; autodiff = true)
@NLobjective(model, Min, f(x[1]))
d = JuMP.NLPEvaluator(model)
MOI.initialize(d, [:Grad])
expected_exception = ErrorException(
"Expected return type of Float64 from a user-defined function, " *
"but got JuMP.GenericAffExpr{Float64,VariableRef}. Make sure your" *
" user-defined function only depends on variables passed as " *
"arguments."
)
if VERSION < v"0.7"
@test_throws ErrorException MOI.eval_objective(d, [1.0, 1.0])
else
@test_throws expected_exception MOI.eval_objective(d, [1.0, 1.0])
end
end

@testset "User-defined function returning bad type" begin
model = Model()
@variable(model, x)
f(x) = string(x)
JuMP.register(model, :f, 1, f; autodiff = true)
@NLobjective(model, Min, f(x))
d = JuMP.NLPEvaluator(model)
MOI.initialize(d, [:Grad])
expected_exception = ErrorException(
"Expected return type of Float64 from a user-defined function, " *
"but got String."
)
if VERSION < v"0.7"
@test_throws ErrorException MOI.eval_objective(d, [1.0, 1.0])
else
@test_throws expected_exception MOI.eval_objective(d, [1.0, 1.0])
end
end
end

0 comments on commit abd2d5c

Please sign in to comment.