Skip to content

Commit

Permalink
switch to run-time lookups for user-defined functions. closes #705
Browse files Browse the repository at this point in the history
  • Loading branch information
mlubin committed Mar 17, 2016
1 parent 84dd43f commit 4c514b9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
30 changes: 24 additions & 6 deletions src/nlpmacros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,18 @@ function parseNLExpr(m, x, tapevar, parent, values)
code = :(let; end)
block = code.args[1]
@assert isexpr(block, :block)
haskey(univariate_operator_to_id,x.args[1]) || error("Unrecognized function $(x.args[1]) used in nonlinear expression.")
operatorid = univariate_operator_to_id[x.args[1]]
push!(block.args, :(push!($tapevar, NodeData(CALLUNIVAR, $operatorid, $parent))))
if haskey(univariate_operator_to_id,x.args[1])
operatorid = univariate_operator_to_id[x.args[1]]
push!(block.args, :(push!($tapevar, NodeData(CALLUNIVAR, $operatorid, $parent))))
else
opname = quot(x.args[1])
errorstring = "Unrecognized function $opname used in nonlinear expression."
lookupcode = quote
haskey(univariate_operator_to_id,$opname) || error($errorstring)
operatorid = univariate_operator_to_id[$opname]
end
push!(block.args, :($lookupcode; push!($tapevar, NodeData(CALLUNIVAR, operatorid, $parent))))
end
parentvar = gensym()
push!(block.args, :($parentvar = length($tapevar)))
push!(block.args, parseNLExpr(m, x.args[2], tapevar, parentvar, values))
Expand All @@ -27,10 +36,19 @@ function parseNLExpr(m, x, tapevar, parent, values)
code = :(let; end)
block = code.args[1]
@assert isexpr(block, :block)
haskey(operator_to_id,x.args[1]) || error("Unrecognized function $(x.args[1]) used in nonlinear expression.")
operatorid = operator_to_id[x.args[1]]
if haskey(operator_to_id,x.args[1]) # fast compile-time lookup
operatorid = operator_to_id[x.args[1]]
push!(block.args, :(push!($tapevar, NodeData(CALL, $operatorid, $parent))))
else # could be user defined
opname = quot(x.args[1])
errorstring = "Unrecognized function $opname used in nonlinear expression."
lookupcode = quote
haskey(operator_to_id,$opname) || error($errorstring)
operatorid = operator_to_id[$opname]
end
push!(block.args, :($lookupcode; push!($tapevar, NodeData(CALL, operatorid, $parent))))
end
parentvar = gensym()
push!(block.args, :(push!($tapevar, NodeData(CALL, $operatorid, $parent))))
push!(block.args, :($parentvar = length($tapevar)))
for i in 1:length(x.args)-1
push!(block.args, parseNLExpr(m, x.args[i+1], tapevar, parentvar, values))
Expand Down
12 changes: 7 additions & 5 deletions test/nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,14 +646,16 @@ mysquare(x) = x^2
function myf(x,y)
return (x-1)^2+(y-2)^2
end
registerNLFunction(:myf, 2, myf, autodiff=true)
registerNLFunction(:myf_2, 2, myf, (g,x,y) -> (g[1] = 2(x-1); g[2] = 2(y-2)))
registerNLFunction(:mysquare, 1, mysquare, autodiff=true)
registerNLFunction(:mysquare_2, 1, mysquare, x-> 2x, autodiff=true)
registerNLFunction(:mysquare_3, 1, mysquare, x-> 2x, x -> 2.0)


if length(convex_nlp_solvers) > 0
facts("[nonlinear] User-defined functions") do
registerNLFunction(:myf, 2, myf, autodiff=true)
registerNLFunction(:myf_2, 2, myf, (g,x,y) -> (g[1] = 2(x-1); g[2] = 2(y-2)))
registerNLFunction(:mysquare, 1, mysquare, autodiff=true)
registerNLFunction(:mysquare_2, 1, mysquare, x-> 2x, autodiff=true)
registerNLFunction(:mysquare_3, 1, mysquare, x-> 2x, x -> 2.0)

m = Model(solver=convex_nlp_solvers[1])

@defVar(m, x[1:2] >= 0.5)
Expand Down

0 comments on commit 4c514b9

Please sign in to comment.