In [1]:
using MacroTools
using BenchmarkTools 

import MacroTools: prewalk, postwalk

In [2]:
macro e(something)
    if something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form :: M
end

function Base.show(io::IO, specification::ConstraintsSpecification) 
    print(io, "Constraints:\n\tform: $(specification.form)\n")
    print(io, "\tfactorisation\n")
    foreach(specification.factorisation) do f
        print(io, "\t\t", f, "\n")
    end
end

In [14]:
struct FactorisationConstraintsEntry{N, T}
    indices :: T
end

Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}) = (left..., right...)

__io_entry_pair(pair::Pair)                     = __io_entry_pair(first(pair), last(pair))
__io_entry_pair(symbol::Symbol, ::Nothing)      = string(symbol)
__io_entry_pair(symbol::Symbol, index::Integer) = string(symbol, "[", index, "]")
__io_entry_pair(symbol::Symbol, range::AbstractRange) = string(symbol, "[", range, "]")

FactorisationConstraintsEntry(::Val{N}, indices::T) where { N, T } = FactorisationConstraintsEntry{N, T}(indices)

make_factorisation_constraint_entry(N, T) = :(FactorisationConstraintsEntry($N, $T))

function Base.show(io::IO, entry::FactorisationConstraintsEntry) 
    print(io, "q(")
    entries = map(__io_entry_pair, collect(pairs(entry.indices)))
    join(io, entries, ", ")
    print(io, ")")
end

In [5]:
struct FactorisationConstraintsSpecification{N, E}
    entries :: E
end

Base.:(*)(left::FactorisationConstraintsSpecification, right::NTuple{N2, <: FactorisationConstraintsEntry}) where { N1, N2 } = (left, right...)
Base.:(*)(left::NTuple{N1, <: FactorisationConstraintsEntry}, right::FactorisationConstraintsSpecification) where { N1, N2 } = (left..., right)

function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Names}) where Names
    
    print(io, "q(")
    join(io, Names, ", ")
    print(io, ")")
    
    compact = get(io, :compact, false)
    
    if !compact 
        print(io, " = ")
        foreach(factorisation.entries) do e
            print(IOContext(io, :compact => true), e)
        end
    end
    
end

FactorisationConstraintsSpecification(::Val{N}, entries::E) where { N, E } = FactorisationConstraintsSpecification{N, E}(entries)
    
make_factorisation_constraint(N, E) = :(FactorisationConstraintsSpecification($N, $E))

make_factorisation_constraint (generic function with 1 method)

In [6]:
struct ReactiveMPBackend end

__get_current_backend() = ReactiveMPBackend()

macro constraints(constraints_specification)
    return generate_constraints_expression(__get_current_backend(), constraints_specification)
end

make_constraints_specification(factorisation, form) = :(ConstraintsSpecification($factorisation, $form))

make_constraints_specification (generic function with 1 method)

In [7]:
isexpr(expr::Expr) = true
isexpr()           = false

ishead(something, head) = isexpr(something) && something.head === head

isblock(something) = ishead(something, :block)
isref(something) = ishead(something, :ref)

isref (generic function with 1 method)

In [8]:
struct LHSMeta
    name :: String
    hash :: UInt
    varname :: Symbol
end

In [27]:
function generate_constraints_expression(backend, constraints_specification)

    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # First we record all lhs expression's hash ids and create unique variable names for them
    # q(x, y) = q(x)q(y) -> hash(q(x, y))
    # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
    cs_body = postwalk(cs_body) do expression
        # We also do a simple sanity check right now, names should be an array of Symbols only
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            
            (length(names) !== 0 && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            @capture(rhs, *(rhs_prod_entries__)) || error("Invalid RHS $(rhs) of the factorisation specification expression $(expression)")
            
            lhs_names = Set(names)
            rhs_names = Set{Symbol}()
            
            # We do a simple check to be sure that LHS and RHS has the exact same set of names
            # We also check here that all indices are either a simple Symbol or an indexing expression here
            for entry in rhs_prod_entries
                if @capture(entry, q(indices__))
                    for index in indices
                        if index isa Symbol
                            (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.")
                            push!(rhs_names, index)
                            (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.")
                        elseif isref(index)
                            push!(rhs_names, first(index.args))
                            (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.")
                        else
                           error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                        end
                    end
                end
            end
            
            (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
            
            lhs_hash = hash(lhs)
            lhs_meta = if haskey(lhs_dict, lhs_hash)
                lhs_dict[ lhs_hash ]
            else
                lhs_name = string("q(", join(names, ", "), ")")
                lhs_varname = gensym(lhs_name)
                lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname)
                lhs_dict[lhs_hash] = lhs_meta
            end
            
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            
            new_factorisation_specification = make_factorisation_constraint(:(Val(($(map(QuoteNode, names)...),))), rhs)
            
            result = quote 
                ($(lhs_varname) === nothing) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            return result
        end
        return expression
    end
    
    # This block write initial variables for factorisation specification
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_varname = lhs_meta.varname
        return quote 
            $(lhs_varname) = nothing
        end
    end
    
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            expr_hash = hash(expression)
            if haskey(lhs_dict, expr_hash)
                lhs_meta = lhs_dict[ expr_hash ]
                lhs_name = lhs_meta.name
                lhs_varname = lhs_meta.varname
                return lhs_varname
            else
                rhs_prod_names = Symbol[]
                rhs_prod_entries_args = map(args) do arg
                    if arg isa Symbol
                        push!(rhs_prod_names, arg)
                        return :($(QuoteNode(arg)) => nothing)
                    elseif isref(arg)
                        push!(rhs_prod_names, first(arg.args))
                        return :($(QuoteNode(first(arg.args))) => $(last(arg.args)))
                    else
                        error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                    end
                end
                
                entry = make_factorisation_constraint_entry(:(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(($(rhs_prod_entries_args...), )))
                
                return :(($entry, ))
            end
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [28]:
# TODO: check for intersections of ranges during an actual execution

In [29]:
@e @constraints function forbench(flag, n)
    # q(x) = q(x[1])..q(x[n-2])q(x[n-1], x[n])
    q(x) = q(x[n-2])q(x[n-1], x[n])
    q(x, y, z, r) = q(x, y)q(z, r)
end

@time forbench(true, 2)
@time forbench(true, 2)

cs = forbench(true, 2)

function forbench(flag, n; )
    anteater = ()
    caterpillar = ()
    otter = nothing
    salmon = nothing
    otter === nothing || error("Factorisation constraints specification q(x) = ... has been redefined.")
    otter = FactorisationConstraintsSpecification(Val((:x,)), (FactorisationConstraintsEntry(Val((:x,)), (:x => n - 2,)),) * (FactorisationConstraintsEntry(Val((:x, :x)), (:x => n - 1, :x => n)),))
    caterpillar = (caterpillar..., otter)
    salmon === nothing || error("Factorisation constraints specification q(x, y, z, r) = ... has been redefined.")
    salmon = FactorisationConstraintsSpecification(Val((:x, :y, :z, :r)), (FactorisationConstraintsEntry(Val((:x, :y)), (:x => nothing, :y => nothing)),) * (FactorisationConstraintsEntry(Val((:z, :r)), (:z => nothing, :r => nothing)),))
    caterpillar = (caterpillar..., salmon)
    ConstraintsSpecification(caterpillar, anteater)
end
  0.000005 seconds
  0.000004 seconds


MethodError: MethodError: no method matching __io_entry_pair(::Int64, ::Pair{Symbol, Int64})

In [12]:
..(a::Int, b::Int) = a + b

.. (generic function with 1 method)

In [13]:
@code_warntype forbench(true, 2)

MethodInstance for forbench(::Bool, ::Int64)
  from forbench(flag, n) in Main at In[9]:129
Arguments
  #self#[36m::Core.Const(forbench)[39m
  flag[36m::Bool[39m
  n[36m::Int64[39m
Locals
  q(x, y, z, r)#294[36m::Nothing[39m
  q(x)#293[36m::Nothing[39m
  factorisation_constraint#292[36m::Tuple{}[39m
  form_constraint#291[36m::Tuple{}[39m
Body[36m::Union{}[39m
[90m1 ─[39m       (form_constraint#291 = ())
[90m│  [39m       (factorisation_constraint#292 = ())
[90m│  [39m       (q(x)#293 = Main.nothing)
[90m│  [39m       (q(x, y, z, r)#294 = Main.nothing)
[90m│  [39m %5  = (q(x)#293 === Main.nothing)[36m::Core.Const(true)[39m
[90m│  [39m       Core.typeassert(%5, Core.Bool)
[90m└──[39m       goto #3
[90m2 ─[39m       Core.Const(:(Base.string("Factorisation constraints specification ", "q(x)", " = ... has been redefined.")))
[90m└──[39m       Core.Const(:(Main.error(%8)))
[90m3 ┄[39m %10 = (:x,)[36m::Core.Const((:x,))[39m
[90m│  [39m       Main.Val(%