Skip to content

Commit

Permalink
Add nice error message for splatting in macros
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed May 26, 2019
1 parent 2201c60 commit 6caa7a0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
39 changes: 36 additions & 3 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ Helper function for macros to construct container objects. Takes an `Expr` that
"""
_build_ref_sets(c) = _build_ref_sets(c, _get_name(c))

function _expr_contains_splat(ex::Expr)
if ex.head == :(...)
return true
end
return any(_expr_contains_splat.(ex.args))
end
_expr_contains_splat(::Any) = false

"""
JuMP._get_looped_code(varname, code, condition, idxvars, idxsets, sym, requestedcontainer::Symbol; lowertri=false)
Expand Down Expand Up @@ -593,6 +601,10 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function)
# we will wrap in loops to assign to the ConstraintRefs
refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)

if any(_expr_contains_splat.(idxsets))
_error("cannot use splatting operator `...`.")
end

vectorized, parsecode, buildcall = parsefun(_error, x.args...)
_add_kw_args(buildcall, kw_args)
if vectorized
Expand Down Expand Up @@ -1008,7 +1020,7 @@ expr = @expression(m, [i=1:3], i*sum(x[j] for j=1:3))
```
"""
macro expression(args...)

macro_error(str...) = _macro_error(:expression, args, str...)
args, kw_args, requestedcontainer = _extract_kw_args(args)
if length(args) == 3
m = esc(args[1])
Expand All @@ -1019,14 +1031,19 @@ macro expression(args...)
c = gensym()
x = args[2]
else
error("@expression: needs at least two arguments.")
macro_error("needs at least two arguments.")
end
length(kw_args) == 0 || error("@expression: unrecognized keyword argument")
length(kw_args) == 0 || macro_error("unrecognized keyword argument")

anonvar = isexpr(c, :vect) || isexpr(c, :vcat) || length(args) == 2
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)

if any(_expr_contains_splat.(idxsets))
macro_error("cannot use splatting operator `...`.")
end

newaff, parsecode = _parse_expr_toplevel(x, :q)
code = quote
q = Val{false}()
Expand Down Expand Up @@ -1397,6 +1414,10 @@ macro variable(args...)
# We now build the code to generate the variables (and possibly the
# SparseAxisArray to contain them)
refcall, idxvars, idxsets, condition = _build_ref_sets(var, variable)
if any(_expr_contains_splat.(idxsets))
_error("cannot use splatting operator `...`.")
end

clear_dependencies(i) = (Containers.is_dependent(idxvars,idxsets[i],i) ? () : idxsets[i])

# Code to be used to create each variable of the container.
Expand Down Expand Up @@ -1510,6 +1531,10 @@ macro NLconstraint(m, x, extra...)
# Strategy: build up the code for non-macro add_constraint, and if needed
# we will wrap in loops to assign to the ConstraintRefs
refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_contains_splat.(idxsets))
error("@NLconstraint: cannot use splatting operator `...`.")
end

# Build the constraint
if isexpr(x, :call) # one-sided constraint
# Simple comparison - move everything to the LHS
Expand Down Expand Up @@ -1606,6 +1631,10 @@ macro NLexpression(args...)
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_contains_splat.(idxsets))
error("@NLexpression: cannot use splatting operator `...`.")
end

code = quote
$(refcall) = NonlinearExpression($(esc(m)), $(_process_NL_expr(m, x)))
end
Expand Down Expand Up @@ -1672,6 +1701,10 @@ macro NLparameter(m, ex, extra...)
variable = gensym()

refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable)
if any(_expr_contains_splat.(idxsets))
error("@NLparameter: cannot use splatting operator `...`.")
end

code = quote
if !isa($(esc(x)), Number)
error(string("in @NLparameter (", $(string(ex)), "): expected ",
Expand Down
27 changes: 27 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,33 @@ end
c = @NLconstraint(model, x == sum(1.0 for i in 1:0))
@test sprint(show, c) == "x - 0 = 0" || sprint(show, c) == "x - 0 == 0"
end

@testset "Splatting error" begin
model = Model()
A = [1 0; 0 1]
@variable(model, x)

@test_macro_throws ErrorException(
"In `@variable(model, y[axes(A)...])`: cannot use splatting operator `...`."
) @variable(model, y[axes(A)...])

@test_macro_throws ErrorException(
"In `@constraint(model, [i = [axes(A)...]], x >= i)`: cannot use splatting operator `...`."
) @constraint(model, [i=[axes(A)...]], x >= i)

@test_macro_throws ErrorException(
"@NLconstraint: cannot use splatting operator `...`."
) @NLconstraint(model, [i=[axes(A)...]], x >= i)

@test_macro_throws ErrorException(
"In `@expression(model, [i = [axes(A)...]], i * x)`: cannot use splatting operator `...`."
) @expression(model, [i=[axes(A)...]], i * x)

@test_macro_throws ErrorException(
"@NLexpression: cannot use splatting operator `...`."
) @NLexpression(model, [i=[axes(A)...]], i * x)
end

end

@testset "Macros for JuMPExtension.MyModel" begin
Expand Down

0 comments on commit 6caa7a0

Please sign in to comment.