In [1]:
using MacroTools
using BenchmarkTools 
using TupleTools
using Test
using ReactiveMP

import MacroTools: prewalk, postwalk

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423
[33m[1m│ [22m[39m- If you have ReactiveMP checked out for development and have
[33m[1m│ [22m[39m  added Unrolled as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with ReactiveMP


In [2]:
"""
    @e expr

Debugging helper. If `expr` is itself a macro call, print its prettified expansion
(via `MacroTools.prettify(@macroexpand ...)`) and then evaluate it; otherwise just
evaluate `expr`. Arguments that are not `Expr` objects (symbols, literals) are passed
through unchanged instead of failing on a missing `head` field.
"""
macro e(something)
    # Only `Expr` objects have a `head`; symbols and literals fall through to plain evaluation
    if something isa Expr && something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
# TODO: copy to GraphPPL
"""
    make_functional_index(reprsymbol::Symbol, fnsymbol::Symbol)

Return the expression `ReactiveMP.FunctionalIndex{:reprsymbol}(fnsymbol)`.
`reprsymbol` is embedded as a quoted-symbol type parameter. `fnsymbol` is passed as
the constructor argument; in this file it is `firstindex` or `lastindex`.
"""
function make_functional_index(reprsymbol::Symbol, fnsymbol::Symbol)
    indextype = Expr(:curly, :(ReactiveMP.FunctionalIndex), QuoteNode(reprsymbol))
    return Expr(:call, indextype, fnsymbol)
end

make_functional_index (generic function with 1 method)

In [4]:
# TODO: copy to GraphPPL
"""
    make_combined_range(l, r)

Return the expression `ReactiveMP.CombinedRange(l, r)`.
"""
function make_combined_range(l, r)
    return Expr(:call, :(ReactiveMP.CombinedRange), l, r)
end
# TODO: copy to GraphPPL
"""
    make_splitted_range(l, r)

Return the expression `ReactiveMP.SplittedRange(l, r)`.
"""
function make_splitted_range(l, r)
    return Expr(:call, :(ReactiveMP.SplittedRange), l, r)
end
"""
    make_factorisation_split(l, r)

Return the expression `ReactiveMP.factorisation_split(l, r)`; used to rewrite `l .. r`.
"""
function make_factorisation_split(l, r)
    return Expr(:call, :(ReactiveMP.factorisation_split), l, r)
end

make_factorisation_split (generic function with 1 method)

In [5]:
# TODO: copy to GraphPPL
"""
    make_constraints_specification(factorisation, form)

Return the expression `ReactiveMP.ConstraintsSpecification(factorisation, form)`.
"""
function make_constraints_specification(factorisation, form)
    constructor = :(ReactiveMP.ConstraintsSpecification)
    return Expr(:call, constructor, factorisation, form)
end

make_constraints_specification (generic function with 1 method)

In [6]:
# Copied to ReactiveMP, but not GraphPPL

In [7]:
# TODO: Copy to GraphPPL
"""
    make_factorisation_constraint_entry(N, T)

Return the expression `ReactiveMP.FactorisationConstraintsEntry(N, T)`.
In this file `N` is a `Val` of variable names and `T` a `Val` of per-variable index expressions.
"""
function make_factorisation_constraint_entry(N, T)
    return Expr(:call, :(ReactiveMP.FactorisationConstraintsEntry), N, T)
end

make_factorisation_constraint_entry (generic function with 1 method)

In [8]:
# TODO: Copy to GraphPPL.jl
"""
    make_factorisation_constraint(N, E)

Return the expression `ReactiveMP.FactorisationConstraintsSpecification(N, E)`.
In this file `N` is a `Val` of left-hand-side names and `E` a `Val` of the rewritten right-hand side.
"""
function make_factorisation_constraint(N, E)
    constructor = :(ReactiveMP.FactorisationConstraintsSpecification)
    return Expr(:call, constructor, N, E)
end

make_factorisation_constraint (generic function with 1 method)

In [9]:
"""
    init_factorisation_not_defined(V, S)

Return the assignment expression `V = ReactiveMP.FactorisationSpecificationNotDefinedYet{:S}()`.
"""
function init_factorisation_not_defined(V, S)
    marker = Expr(:curly, :(ReactiveMP.FactorisationSpecificationNotDefinedYet), QuoteNode(S))
    return Expr(:(=), V, Expr(:call, marker))
end

"""
    check_factorisation_is_not_defined(V)

Return the expression `V isa ReactiveMP.FactorisationSpecificationNotDefinedYet`.
"""
function check_factorisation_is_not_defined(V)
    return Expr(:call, :isa, V, :(ReactiveMP.FactorisationSpecificationNotDefinedYet))
end

check_factorisation_is_not_defined (generic function with 1 method)

In [10]:
"""
    ReactiveMPBackend

Marker type passed as the `backend` argument of `generate_constraints_expression`.
"""
struct ReactiveMPBackend end

"""
    __get_current_backend()

Return the backend used by `@constraints`; currently always `ReactiveMPBackend()`.
"""
function __get_current_backend()
    return ReactiveMPBackend()
end

"""
    @constraints spec

Expand a constraints specification (a `begin ... end` block or a full function
definition) via `generate_constraints_expression` with the current backend.
"""
macro constraints(spec)
    backend = __get_current_backend()
    return generate_constraints_expression(backend, spec)
end

@constraints (macro with 1 method)

In [11]:
"""
    isexpr(something)

Return `true` if `something` is an `Expr`, and `false` for any other value
(symbols, literals, `LineNumberNode`s, ...). The fallback previously had zero
arguments (`isexpr()`), so non-`Expr` inputs raised a `MethodError`.
"""
isexpr(::Expr) = true
isexpr(::Any)  = false

"""
    ishead(something, head)

Return `true` if `something` is an `Expr` whose `head` is `head`.
"""
ishead(something, head) = isexpr(something) && something.head === head

"""
    isblock(something)

Return `true` if `something` is a `:block` expression (e.g. `begin ... end`).
"""
isblock(something) = ishead(something, :block)

"""
    isref(something)

Return `true` if `something` is an indexing expression such as `x[i]`.
"""
isref(something) = ishead(something, :ref)

isref (generic function with 1 method)

In [12]:
"""
    LHSMeta

Bookkeeping record for one distinct left-hand side `q(names...)` seen by
`generate_constraints_expression`.

- `name`: printable form of the left-hand side, e.g. `"q(x, y)"`.
- `hash`: `hash` of the left-hand side expression; also used as its key in the lookup `Dict`.
- `varname`: `gensym`'d variable name that holds this constraint in the generated code.
"""
struct LHSMeta
    name :: String
    hash :: UInt
    varname :: Symbol
end

In [13]:
# TODO: check for intersections of ranges during an actual execution

In [14]:
"""
    generate_constraints_expression(backend, constraints_specification)

Generate the code behind `@constraints`.

`constraints_specification` must be one of:
- a `begin ... end` block. It is wrapped into a zero-argument function with a
  `gensym`'d name, and the returned expression defines that function and calls it.
- a full `function name(args...; kwargs...) ... end` definition. Any other form raises an error.

Inside the body, every statement of the form `q(names...) = rhs` is rewritten into code
that, when executed:
1. errors if a constraint with the same left-hand side was already assigned. The same
   left-hand side may still appear in mutually exclusive branches such as `if`/`else`.
2. stores a `ReactiveMP.FactorisationConstraintsSpecification` built from `names` and
   `rhs` in a `gensym`'d variable.
3. appends that variable to the factorisation-constraints tuple.

Right-hand sides are rewritten as follows:
- `a .. b` becomes `ReactiveMP.factorisation_split(a, b)`.
- `begin`/`end` inside indices become `ReactiveMP.FunctionalIndex` values.
- `a:b` indices become `ReactiveMP.CombinedRange(a, b)`.
- every `q(...)` becomes a one-element tuple holding a `ReactiveMP.FactorisationConstraintsEntry`.

The generated function returns `ReactiveMP.ConstraintsSpecification(factorisation, form)`.
The form-constraints tuple is currently always empty. `backend` is only forwarded to
the recursive call made for block input.

Malformed statements raise an error while this function runs (i.e. at macro-expansion
time). These are:
- non-symbol left-hand-side arguments;
- variable sets that differ between the left- and right-hand sides;
- a plain variable used twice on the right-hand side;
- unsupported index expressions.
"""
function generate_constraints_expression(backend, constraints_specification)

    # A bare `begin ... end` block is wrapped into a uniquely named zero-argument
    # function; the generated code defines that function and calls it right away.
    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    # Pattern variables not bound by the matching alternative may be `nothing`
    # (e.g. `cs_kwargs` for a definition without keyword arguments)
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    # Maps `hash(lhs)` of every distinct `q(...)` left-hand side to its metadata
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # Tuple of form constraints. Nothing is appended to it yet, so it stays `()`.
    # The name is gensym'd so that it cannot clash with user-defined variables.
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # Tuple of factorisation constraints. The generated code rebuilds it by reassignment
    # every time a `q(...) = ...` statement is executed.
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # Pass 1: rewrite every `q(names...) = rhs` statement, e.g. `q(x, y) = q(x)q(y)`.
    # The same LHS may appear several times (e.g. in `if`/`else` branches); the generated
    # runtime check rejects assigning it twice.
    cs_body = postwalk(cs_body) do expression
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            
            # Sanity check: LHS arguments must be plain variable names (Symbols)
            (!isempty(names) && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            # Replace `a .. b` in the RHS expression with `ReactiveMP.factorisation_split(a, b)`
            rhs = postwalk(rhs) do rexpr
                if @capture(rexpr, a_ .. b_)
                    return make_factorisation_split(a, b)
                end
                return rexpr
            end
            
            lhs_names = Set{Symbol}(names)
            rhs_names = Set{Symbol}()
            
            # Check that LHS and RHS use exactly the same set of variable names and that
            # every RHS argument is either a plain Symbol or an indexing expression
            rhs = postwalk(MacroTools.prettify(rhs, alias = false)) do entry
                if @capture(entry, q(indices__))
                    for index in indices
                        if index isa Symbol
                            (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.")
                            push!(rhs_names, index)
                            (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.")
                        elseif isref(index)
                            push!(rhs_names, first(index.args))
                            (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.")
                        else
                           error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                        end
                    end
                end
                return entry
            end
            
            (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
            
            # Statements with the same LHS share one metadata entry and thus one variable
            lhs_hash = hash(lhs)
            lhs_meta = get!(lhs_dict, lhs_hash) do
                meta_name = string("q(", join(names, ", "), ")")
                return LHSMeta(meta_name, lhs_hash, gensym(meta_name))
            end
            
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            
            new_factorisation_specification = make_factorisation_constraint(:(Val(($(map(QuoteNode, names)...),))), :(Val($(rhs))))
            check_is_not_defined            = check_factorisation_is_not_defined(lhs_varname)
            
            result = quote 
                $(check_is_not_defined) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            return result
        end
        return expression
    end
    
    # Every LHS variable starts as a `FactorisationSpecificationNotDefinedYet` marker;
    # the redefinition check generated in pass 1 tests for that marker
    cs_lhs_init_block = map(collect(values(lhs_dict))) do lhs_meta
        return init_factorisation_not_defined(lhs_meta.varname, Symbol(lhs_meta.name))
    end
    
    # Pass 2: replace each remaining `q(...)` call with a one-element tuple holding a
    # `ReactiveMP.FactorisationConstraintsEntry`.
    # - A plain variable `x` contributes `nothing` as its index.
    # - An indexed variable `x[i]` contributes its index. `begin`/`end` become
    #   functional indices, and `a:b` becomes a `CombinedRange`.
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            rhs_prod_names = Symbol[]
            rhs_prod_entries_args = map(args) do arg
                if arg isa Symbol
                    push!(rhs_prod_names, arg)
                    return :(nothing)
                elseif isref(arg)
                    # Only single-index expressions like `x[i]` are supported
                    (length(arg.args) === 2) || error("Indexing expression $(expression) is too difficult to parse and is not supported (yet?).")
                    push!(rhs_prod_names, first(arg.args))

                    index = last(arg.args)

                    # Replace `begin` / `end` with functional indices wrapping `firstindex` / `lastindex`
                    index = postwalk(index) do iexpr
                        if iexpr isa Symbol && iexpr === :begin
                            return make_functional_index(:begin, :firstindex)
                        elseif iexpr isa Symbol && iexpr === :end
                            return make_functional_index(:end, :lastindex)
                        else
                            return iexpr
                        end
                    end

                    if @capture(index, a_:b_)
                        return make_combined_range(a, b)
                    else
                        return index
                    end
                else
                    error("Cannot parse expression $(arg) in the RHS $(expression) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                end
            end

            entry = make_factorisation_constraint_entry(:(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(Val(($(rhs_prod_entries_args...), ))))

            return :(($entry, ))
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    # Escape the definition so the user's names resolve in the caller's scope
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [15]:
# Define a named constraints function `test` with several factorisation statements,
# time two calls to it, and keep the resulting specification in the global `cs`,
# which the cells below reuse.
cs = @constraints function test()
    q(x, y) = q(x)q(y)
    q(x, y, t, r) = q(x, y)q(t)q(r)
    q(x, w) = q(x)q(w)
    q(y, w) = q(y)q(w)
    q(x) = q(x[begin:begin+2])q(x[begin+3])..q(x[end])
end
        
@time test()
@time test()

cs = test()

  0.000007 seconds (14 allocations: 2.062 KiB)
  0.000010 seconds (14 allocations: 2.062 KiB)


Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t, r) = q(x, y)q(t)q(r)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)
		q(x) = q(x[(begin):((begin) + 2)])q(x[((begin) + 3)..(end)])


In [16]:
# Build a model with random variables `x`, `y`, `t` (created with a size argument of 10)
# and `tmp`, `r` (created without one). Then pick the tuple `vars` whose factorisation
# is resolved in the cells below.
model = Model()

x = randomvar(model, :x, 10)
y = randomvar(model, :y, 10)
tmp = randomvar(model, :tmp)
t = randomvar(model, :t, 10)
r = randomvar(model, :r)

vars = (x[3], x[4], y[1], t[1]);
# vars = (xvar, yvar, tvar)
# vars = (xvar, yvar, tvar)

In [17]:
# Resolve the factorisation of `vars` under `cs` twice: the first timing includes
# compilation, the second shows the steady-state cost.
@time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
@time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)

  1.020368 seconds (2.67 M allocations: 146.782 MiB, 3.68% gc time, 99.91% compilation time)
  0.000057 seconds (187 allocations: 6.609 KiB)


((1,), (2,), (3,), (4,))

In [18]:
# Repeat the timing now that `resolve_factorisation` has already been compiled.
@time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
@time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)

  0.000091 seconds (187 allocations: 6.609 KiB)
  0.000064 seconds (187 allocations: 6.609 KiB)


((1,), (2,), (3,), (4,))

In [19]:
# Benchmark with BenchmarkTools; globals are `$`-interpolated so they are not benchmarked as untyped globals.
@btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)

  6.841 μs (181 allocations: 6.42 KiB)


((1,), (2,), (3,), (4,))

In [20]:
# fast playground
# Chain (`..`) over joint entries q(x[i], y[i]). With vars ordered as
# (y[1], y[2], x[1], x[2], t, r), the output below groups y[i] together with x[i].
let 
    cs = @constraints begin
        q(x, y) = q(x[begin], y[begin])..q(x[end], y[end])
        # alternative: split x and y independently (see the next cell)
        # q(x, y) = (q(x[begin])..q(x[end]))*(q(y[begin])..q(y[end]))        
        q(x, y, t) = q(x, y)q(t)
        q(x, y, r) = q(x, y)q(r)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 10)
    x = randomvar(model, :x, 10)
    t = randomvar(model, :t)
    r = randomvar(model, :r)
    
    vars = (y[1], y[2], x[1], x[2], t, r)
    
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    
    @btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)
end

  0.666028 seconds (1.65 M allocations: 89.499 MiB, 3.30% gc time, 99.90% compilation time)
  0.000097 seconds (331 allocations: 13.547 KiB)
  13.143 μs (325 allocations: 13.36 KiB)


((1, 3), (2, 4), (5, 6))

In [21]:
# fast playground
# Same setup as above, but x and y are each chained independently. In the output below
# every y[i] and x[i] is its own factor, and t and r stay together.
let 
    cs = @constraints begin
        q(x, y) = (q(x[begin])..q(x[end]))*(q(y[begin])..q(y[end]))        
        q(x, y, t) = q(x, y)q(t)
        q(x, y, r) = q(x, y)q(r)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 10)
    x = randomvar(model, :x, 10)
    t = randomvar(model, :t)
    r = randomvar(model, :r)
    
    vars = (y[1], y[2], x[1], x[2], t, r)
    
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    
    @btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)
end

  0.219370 seconds (607.40 k allocations: 32.503 MiB, 99.88% compilation time)
  0.000058 seconds (307 allocations: 12.562 KiB)
  11.403 μs (301 allocations: 12.38 KiB)


((1,), (2,), (3,), (4,), (5, 6))

In [22]:
# fast playground
# Keep y[1] separate and chain joint entries q(x[i], y[i+1]); y has one more element
# than x. The output below groups y[2] with x[1].
let 
    cs = @constraints begin
        q(x, y) = q(y[1])q(x[begin], y[begin+1])..q(x[end], y[end])       
        q(x, y, t) = q(x, y)q(t)
        q(x, y, r) = q(x, y)q(r)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    t = randomvar(model, :t)
    r = randomvar(model, :r)
    
    vars = (y[1], y[2], x[1], x[2], t, r)
    
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    
    @btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)
end

  0.408522 seconds (1.10 M allocations: 59.181 MiB, 3.02% gc time, 99.92% compilation time)
  0.000062 seconds (395 allocations: 15.141 KiB)
  15.064 μs (389 allocations: 14.95 KiB)


((1,), (2, 3), (4,), (5, 6))

In [23]:
# fast playground
# Split between x and y only. In the output below, entries of the same array stay together.
let 
    cs = @constraints begin
        q(x, y) = q(x)q(y)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    
    vars = (y[1], y[2], x[1], x[2])
    
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    
    @btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)
end

  0.101990 seconds (300.36 k allocations: 15.434 MiB, 99.90% compilation time)
  0.000024 seconds (89 allocations: 3.750 KiB)
  3.887 μs (83 allocations: 3.56 KiB)


((1, 2), (3, 4))

In [24]:
# fast playground
# Same constraint with a single element from each array; the output shows two separate factors.
let 
    cs = @constraints begin
        q(x, y) = q(x)q(y)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    
    vars = (y[1], x[1])
    
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    @time ReactiveMP.resolve_factorisation(:(1 + 1), vars, cs, model)
    
    @btime ReactiveMP.resolve_factorisation(:(1 + 1), $vars, $cs, $model)
end

  0.184243 seconds (449.99 k allocations: 23.848 MiB, 6.52% gc time, 99.89% compilation time)
  0.000024 seconds (49 allocations: 1.844 KiB)
  2.045 μs (43 allocations: 1.66 KiB)


((1,), (2,))

In [25]:
# fast playground
# Constraints defined as a function of a runtime argument. Both `if`/`else` branches
# assign the same LHS q(x, y); this is allowed because only one branch runs per call.
let 
    @constraints function withflag(flag)
        if flag
            q(x, y) = q(x)q(y)
        else
            q(x, y) = q(x, y)
        end
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    
    vars = (y[1], x[1])
    
    @show ReactiveMP.resolve_factorisation(:(1 + 1), vars, withflag(true), model)
    @show ReactiveMP.resolve_factorisation(:(1 + 1), vars, withflag(false), model)
end

ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), vars, withflag(true), model) = ((1,), (2,))
ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), vars, withflag(false), model) = ((1, 2),)


((1, 2),)

In [26]:
# fast playground
# The split point depends on the argument `n`: x[1:n] forms one group, and the entries
# from x[n+1] to x[end] are chained. The output shows how x[2] and x[3] are grouped
# for n = 2 and n = 4.
let 
    @constraints function withflag(n)
        q(x, y) = q(x[1:n])*(q(x[n+1])..q(x[end]))*q(y)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    
    vars = (y[1], x[2], x[3])
    
    @show ReactiveMP.resolve_factorisation(:(1 + 1), vars, withflag(2), model)
    @show ReactiveMP.resolve_factorisation(:(1 + 1), vars, withflag(4), model)
end

ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), vars, withflag(2), model) = ((1,), (2,), (3,))
ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), vars, withflag(4), model) = ((1,), (2, 3))


((1,), (2, 3))

In [30]:
# fast playground
# x[begin] forms its own group, and x[begin+1:end] is one joint group.
let 
    cs = @constraints begin
       q(x, y) = q(x[begin])*q(x[begin+1:end])*q(y)
    end
            
    model = Model()
            
    y = randomvar(model, :y, 11)
    x = randomvar(model, :x, 10)
    
    @show ReactiveMP.resolve_factorisation(:(1 + 1), (y[1], x[1], x[2]), cs, model)
    @show ReactiveMP.resolve_factorisation(:(1 + 1), (y[1], x[2], x[3]), cs, model)
end

ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), (y[1], x[1], x[2]), cs, model) = ((1,), (2,), (3,))
ReactiveMP.resolve_factorisation($(Expr(:quote, :(1 + 1))), (y[1], x[2], x[3]), cs, model) = ((1,), (2, 3))


((1,), (2, 3))

In [27]:
# Should error
# The same variable appears twice inside each q(...) of the `..` split; this is rejected
# with the error shown below.
@constraints begin
    q(x, y) = q(x[begin], x[begin])..q(x[end], x[end])q(y)
end

LoadError: Cannot split q(x[(begin)], x[(begin)]) and q(x[(end)], x[(end)]). Names should be unique.