Skip to content

Commit

Permalink
Merge 7a1842d into 2201c60
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed May 28, 2019
2 parents 2201c60 + 7a1842d commit a79c9a4
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
36 changes: 31 additions & 5 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ Helper function for macros to construct container objects. Takes an `Expr` that
"""
_build_ref_sets(c) = _build_ref_sets(c, _get_name(c))

function _expr_is_splat(ex::Expr)
if ex.head == :(...)
return true
elseif ex.head == :escape
return _expr_is_splat(ex.args[1])
end
return false
end
_expr_is_splat(::Any) = false

"""
JuMP._get_looped_code(varname, code, condition, idxvars, idxsets, sym, requestedcontainer::Symbol; lowertri=false)
Expand Down Expand Up @@ -592,6 +602,9 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function)
# Strategy: build up the code for add_constraint, and if needed
# we will wrap in loops to assign to the ConstraintRefs
refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_is_splat.(idxsets))
_error("cannot use splatting operator `...`.")
end

vectorized, parsecode, buildcall = parsefun(_error, x.args...)
_add_kw_args(buildcall, kw_args)
Expand Down Expand Up @@ -1008,7 +1021,7 @@ expr = @expression(m, [i=1:3], i*sum(x[j] for j=1:3))
```
"""
macro expression(args...)

macro_error(str...) = _macro_error(:expression, args, str...)
args, kw_args, requestedcontainer = _extract_kw_args(args)
if length(args) == 3
m = esc(args[1])
Expand All @@ -1019,14 +1032,17 @@ macro expression(args...)
c = gensym()
x = args[2]
else
error("@expression: needs at least two arguments.")
macro_error("needs at least two arguments.")
end
length(kw_args) == 0 || error("@expression: unrecognized keyword argument")
length(kw_args) == 0 || macro_error("unrecognized keyword argument")

anonvar = isexpr(c, :vect) || isexpr(c, :vcat) || length(args) == 2
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_is_splat.(idxsets))
macro_error("cannot use splatting operator `...`.")
end
newaff, parsecode = _parse_expr_toplevel(x, :q)
code = quote
q = Val{false}()
Expand Down Expand Up @@ -1393,10 +1409,12 @@ macro variable(args...)
final_variable = variable
else
isa(var,Expr) || _error("Expected $var to be a variable name")

# We now build the code to generate the variables (and possibly the
# SparseAxisArray to contain them)
refcall, idxvars, idxsets, condition = _build_ref_sets(var, variable)
if any(_expr_is_splat.(idxsets))
_error("cannot use splatting operator `...`.")
end
clear_dependencies(i) = (Containers.is_dependent(idxvars,idxsets[i],i) ? () : idxsets[i])

# Code to be used to create each variable of the container.
Expand Down Expand Up @@ -1510,6 +1528,9 @@ macro NLconstraint(m, x, extra...)
# Strategy: build up the code for non-macro add_constraint, and if needed
# we will wrap in loops to assign to the ConstraintRefs
refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_is_splat.(idxsets))
error("@NLconstraint: cannot use splatting operator `...`.")
end
# Build the constraint
if isexpr(x, :call) # one-sided constraint
# Simple comparison - move everything to the LHS
Expand Down Expand Up @@ -1606,6 +1627,9 @@ macro NLexpression(args...)
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_is_splat.(idxsets))
error("@NLexpression: cannot use splatting operator `...`.")
end
code = quote
$(refcall) = NonlinearExpression($(esc(m)), $(_process_NL_expr(m, x)))
end
Expand Down Expand Up @@ -1663,7 +1687,6 @@ macro NLparameter(m, ex, extra...)
end
c = ex.args[2]
x = ex.args[3]

anonvar = isexpr(c, :vect) || isexpr(c, :vcat)
if anonvar
error("In @NLparameter($m, $ex): Anonymous nonlinear parameter syntax is not currently supported")
Expand All @@ -1672,6 +1695,9 @@ macro NLparameter(m, ex, extra...)
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_is_splat.(idxsets))
error("@NLparameter: cannot use splatting operator `...`.")
end
code = quote
if !isa($(esc(x)), Number)
error(string("in @NLparameter (", $(string(ex)), "): expected ",
Expand Down
31 changes: 31 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,37 @@ end
c = @NLconstraint(model, x == sum(1.0 for i in 1:0))
@test sprint(show, c) == "x - 0 = 0" || sprint(show, c) == "x - 0 == 0"
end

@testset "Splatting error" begin
model = Model()
A = [1 0; 0 1]
@variable(model, x)

@test_macro_throws ErrorException(
"In `@variable(model, y[axes(A)...])`: cannot use splatting operator `...`."
) @variable(model, y[axes(A)...])

f(a, b) = [a, b]
@variable(model, z[f((1, 2)...)])
@test length(z) == 2

@test_macro_throws ErrorException(
"In `@constraint(model, [axes(A)...], x >= 1)`: cannot use splatting operator `...`."
) @constraint(model, [axes(A)...], x >= 1)

@test_macro_throws ErrorException(
"@NLconstraint: cannot use splatting operator `...`."
) @NLconstraint(model, [axes(A)...], x >= 1)

@test_macro_throws ErrorException(
"In `@expression(model, [axes(A)...], x)`: cannot use splatting operator `...`."
) @expression(model, [axes(A)...], x)

@test_macro_throws ErrorException(
"@NLexpression: cannot use splatting operator `...`."
) @NLexpression(model, [axes(A)...], x)
end

end

@testset "Macros for JuMPExtension.MyModel" begin
Expand Down

0 comments on commit a79c9a4

Please sign in to comment.