In [43]:
using MacroTools
using BenchmarkTools 

import MacroTools: prewalk, postwalk

In [2]:
"""
    @e expr

Debugging helper. If `expr` is itself a macro call, print its prettified macro
expansion (via `MacroTools.prettify`) and then evaluate it. Any other expression,
including plain symbols and literals, is evaluated unchanged.
"""
macro e(something)
    # Symbols and literals are not `Expr`s and have no `head` field, so check the type
    # before inspecting the head.
    if something isa Expr && something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [210]:
"""
    ConstraintsSpecification(factorisation, form)

Top-level result of a constraints specification. It pairs the collection of
factorisation constraints with the form constraints.
"""
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form :: M
end

# Multi-line textual summary: the form constraints first, then one line per
# factorisation constraint.
function Base.show(io::IO, specification::ConstraintsSpecification)
    println(io, "Constraints:")
    println(io, "\tform: ", specification.form)
    println(io, "\tfactorisation")
    for constraint in specification.factorisation
        println(io, "\t\t", constraint)
    end
    return nothing
end

In [203]:
"""
    FactorisationConstraintsEntry(indices::NamedTuple)

One factor `q(...)` on the right-hand side of a factorisation constraint. The keys of
`indices` are the variable names. Only entries whose values are `nothing` can be
displayed at the moment.
"""
struct FactorisationConstraintsEntry{N, T}
    indices :: NamedTuple{N, T}
end

# Display label for a single `name => nothing` index; other value types are not supported yet.
__io_entry_pair(pair::Pair{Symbol, Nothing})::String = string(pair.first)

# Expression `FactorisationConstraintsEntry(<indices>)`, spliced into generated code.
make_factorisation_constraints_entry(indices) = Expr(:call, :FactorisationConstraintsEntry, indices)

# Prints the entry as `q(name1, name2, ...)`.
function Base.show(io::IO, entry::FactorisationConstraintsEntry)
    print(io, "q(")
    labels = [__io_entry_pair(p) for p in pairs(entry.indices)]
    join(io, labels, ", ")
    print(io, ")")
end

In [214]:
"""
    FactorisationConstraintsSpecification{N, E}

A single factorisation constraint `q(names...) = q(...)q(...)...`. The type parameter
`N` is `Val{names}`, which carries the left-hand side variable names. `entries` holds
the right-hand side factors.
"""
struct FactorisationConstraintsSpecification{N, E}
    entries :: E
end

# Prints `q(names...) = ` followed by every right-hand side entry, with no separator.
function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Val{Names}}) where {Names}
    print(io, "q(", join(Names, ", "), ") = ")
    for entry in factorisation.entries
        print(io, entry)
    end
end

# Convenience constructor taking the `Val{names}` type explicitly.
FactorisationConstraintsSpecification(::Type{Val{N}}, entries::E) where {N, E} =
    FactorisationConstraintsSpecification{Val{N}, E}(entries)

# Expression `FactorisationConstraintsSpecification(N, (E...,))`, spliced into generated code.
register_q_factorisation_specification(N, E) =
    Expr(:call, :FactorisationConstraintsSpecification, N, Expr(:tuple, E...))

register_q_factorisation_specification (generic function with 1 method)

In [138]:
"""Marker type selecting the ReactiveMP code-generation backend."""
struct ReactiveMPBackend end

# Backend used by `@constraints`; currently always ReactiveMP.
__get_current_backend() = ReactiveMPBackend()

"""
    @constraints function name(args...; kwargs...) ... end
    @constraints begin ... end

Expand a constraints specification into code by calling
`generate_constraints_expression` for the current backend.
"""
macro constraints(specification)
    backend = __get_current_backend()
    return generate_constraints_expression(backend, specification)
end

# Expression `ConstraintsSpecification(<factorisation>, <form>)`, spliced into generated code.
make_constraints_specification(factorisation, form) = Expr(:call, :ConstraintsSpecification, factorisation, form)

register_q_factorisation_specification (generic function with 1 method)

In [139]:
# `isexpr(x)`: true only for `Expr` objects. The fallback must accept one argument of
# any type, so that symbols, literals and line-number nodes return `false` instead of
# throwing a `MethodError`.
isexpr(::Expr) = true
isexpr(::Any)  = false

# `ishead(x, head)`: true when `x` is an `Expr` whose head is `head`.
ishead(something, head) = isexpr(something) && something.head === head

# `isblock(x)`: true for block expressions (`begin ... end` / `quote ... end`).
isblock(something) = ishead(something, :block)

isblock (generic function with 1 method)

In [140]:
"""
    LHSMeta(name, hash, varname, varname_used)

Bookkeeping for one left-hand side `q(...)` of a factorisation constraint, collected by
`generate_constraints_expression`.

- `name`: printable form of the left-hand side, e.g. `"q(x, y)"`
- `hash`: `hash` of the left-hand side expression, used as the lookup key
- `varname`: generated variable that holds the constraint specification
- `varname_used`: generated variable holding a `Bool` flag that records whether the
  constraint has been referenced
"""
struct LHSMeta
    name::String
    hash::UInt
    varname::Symbol
    varname_used::Symbol
end

In [170]:
"""
    generate_constraints_expression(backend, constraints_specification)

Translate a constraints-specification expression into an escaped Julia expression.

`constraints_specification` must be one of two forms:
- A `begin ... end` block. It is wrapped into a function with a `gensym`-ed name, and
  the returned expression defines that function and calls it immediately.
- A full `function name(args...; kwargs...) ... end` definition (keyword arguments are
  optional). The returned expression defines that function.

Inside the body, every statement `q(a, b, ...) = q(a)q(b)...` is rewritten into code
that builds a `FactorisationConstraintsSpecification` and appends it to a tuple of
factorisation constraints. The generated function returns
`ConstraintsSpecification(<factorisation tuple>, <form tuple>)`; the form tuple is
always empty at the moment.

The generated code raises an `ErrorException` at runtime when:
- a left-hand side `q(...)` is assigned more than once;
- a reference to an already defined `q(...)` is evaluated more than once.

Errors are raised at expansion time for:
- anything that is not one of the two accepted forms;
- left-hand sides that are not `q(<symbols>...)`;
- right-hand sides that are not a product `q(a)q(b)...` of `q(<symbol>)` factors.

`backend` is only forwarded to the recursive call used for the `begin ... end` form.
"""
function generate_constraints_expression(backend, constraints_specification)

    # A bare `begin ... end` block is wrapped into a uniquely named zero-argument function
    # that is called right away, so both accepted forms share the code path below.
    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) ||
        error("Constraints specification language requires full function definition")

    # Pattern variables that were not matched (e.g. `cs_kwargs` for the second alternative) may be `nothing`
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs

    # Metadata for every distinct left-hand side `q(...)`, keyed by the hash of its expression
    lhs_dict = Dict{UInt, LHSMeta}()

    # Form-constraints tuple: initialised to `()` and, for now, never extended
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())

    # Factorisation-constraints tuple. It is extended one definition at a time, and each
    # extension binds a fresh variable name instead of reassigning the previous one, so
    # that every variable in the generated function keeps a single tuple type (type stability).
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())

    # Rewrite every definition `q(names...) = q(a)q(b)...` (bottom-up walk). Each distinct
    # left-hand side gets its own generated variable, initialised to `nothing`; the
    # rewritten code checks that variable so a runtime redefinition raises an error.
    # NOTE(review): the chained tuple names follow the order of the syntax walk, not the
    # runtime control flow. If a definition sits in an `if` branch that is not taken, its
    # tuple variable is never assigned, and the next definition (or the final return)
    # appears to hit an `UndefVarError`. Confirm whether branching is meant to be supported.
    cs_body = postwalk(cs_body) do expression
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            lhs_name = string("q(", join(names, ", "), ")")
            # Sanity check: the left-hand side may only list plain variable names
            (!isempty(names) && all(name -> name isa Symbol, names)) ||
                error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")

            # Reuse the metadata when the same left-hand side appears more than once
            lhs_hash = hash(lhs)
            lhs_meta = get!(lhs_dict, lhs_hash) do
                varname = gensym(lhs_name)
                return LHSMeta(lhs_name, lhs_hash, varname, Symbol(varname, :_used))
            end
            lhs_varname = lhs_meta.varname

            # The right-hand side `q(a)q(b)...` parses as `*(q(a), q(b), ...)`
            # TODO add special case with `..`
            @capture(rhs, *(prod_entries__, )) || error("Cannot parse RHS of the $(expression) expression.")

            rhs_prod_entries = map(prod_entries) do entry
                @capture(entry, q(b_)) || error("Cannot parse a single RHS entry $(entry) of the expression $(expression)")
                return b
            end

            rhs_prod_entries = map(rhs_prod_entries) do prod_entry
                # Only a raw symbol, as in `q(x)`, is supported for now
                if @capture(prod_entry, entryname_Symbol)
                    return make_factorisation_constraints_entry(:($entryname = nothing, ))
                else
                    # TODO parse single/multiple indices
                    error("Not yet implemented")
                end
            end

            # TODO add check that entries from LHS and RHS has same names

            next_factorisation_constraints_symbol = gensym(:factorisation_constraint)

            new_factorisation_specification = register_q_factorisation_specification(:(Val{ ($(map(QuoteNode, names)...),) }), rhs_prod_entries)

            # The message text is fixed here, at expansion time, and spliced in as a literal
            redefined_message = string("Factorisation constraints specification ", lhs_name, " = ... has been redefined.")

            result = quote
                ($(lhs_varname) === nothing) || error($redefined_message)
                $(lhs_varname) = $(new_factorisation_specification)
                $next_factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end

            factorisation_constraints_symbol = next_factorisation_constraints_symbol

            return result
        end
        return expression
    end

    # Initial values for each generated left-hand side variable and its "used" flag
    cs_lhs_init_block = map(collect(values(lhs_dict))) do lhs_meta
        return quote
            $(lhs_meta.varname) = nothing
            $(lhs_meta.varname_used) = false
        end
    end

    # Replace any remaining reference `q(names...)` to a defined left-hand side with its
    # variable. The reference is guarded by a flag so it can be evaluated only once.
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            reconstructed_hash = hash(Expr(:call, :q, args...))
            if haskey(lhs_dict, reconstructed_hash)
                lhs_meta = lhs_dict[reconstructed_hash]
                lhs_varname_used = lhs_meta.varname_used
                # Build the message now and splice it in as a literal. A `$(lhs_name)` written
                # inside a string literal in the quote would be string interpolation done by
                # the (escaped) generated code, looking up `lhs_name` in the caller's scope.
                used_twice_message = string("Factorisation constraint ", lhs_meta.name, " has been used multiple times")
                return quote
                    if $(lhs_varname_used)
                        error($used_twice_message)
                    end
                    $(lhs_varname_used) = true
                    $(lhs_meta.varname)
                end
            end
            # TODO parse rhs here
        end
        return expression
    end

    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)

    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end
    end

    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [224]:
# Build constraints from a `begin ... end` block. `@e` also prints the macro expansion
# (the generated function shown in the cell output below) before evaluating it.
cs = @e @constraints begin
    q(x, y) = q(x)q(y)
    q(x, θ) = q(θ)q(x)
end

(function albatross(; )
    nightingale = ()
    lemur = ()
    lyrebird = nothing
    pheasant = false
    sheep = nothing
    oyster = false
    lyrebird === nothing || error("Factorisation constraints specification q(x, y) = ... has been redefined.")
    lyrebird = FactorisationConstraintsSpecification(Val{(:x, :y)}, (FactorisationConstraintsEntry((x = nothing,)), FactorisationConstraintsEntry((y = nothing,))))
    heron = (lemur..., lyrebird)
    sheep === nothing || error("Factorisation constraints specification q(x, θ) = ... has been redefined.")
    sheep = FactorisationConstraintsSpecification(Val{(:x, :θ)}, (FactorisationConstraintsEntry((θ = nothing,)), FactorisationConstraintsEntry((x = nothing,))))
    panda = (heron..., sheep)
    ConstraintsSpecification(panda, nightingale)
end)()


Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, θ) = q(θ)q(x)


In [225]:
# Display the constraints object bound to `cs`
cs

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, θ) = q(θ)q(x)


In [221]:
# NOTE(review): `constraints` is not defined in any visible cell (a later output attributes
# it to an earlier cell In[96]); confirm it still exists before re-running this cell.
cs = constraints(true)

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, θ) = q(θ)q(x)


In [160]:
# Show the full nested type structure of `cs`.
# NOTE(review): the recorded output mentions `FactorisationConstraintsEntries`, which is not
# defined in the visible cells, so it likely comes from an older version of this code.
dump(cs)

ConstraintsSpecification{Tuple{FactorisationConstraintsSpecification{Val{(:x, :y)}, FactorisationConstraintsEntries{Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}}}, Tuple{}}
  factorisation: Tuple{FactorisationConstraintsSpecification{Val{(:x, :y)}, FactorisationConstraintsEntries{Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}}}
    1: FactorisationConstraintsSpecification{Val{(:x, :y)}, FactorisationConstraintsEntries{Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}}
      entries: FactorisationConstraintsEntries{Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}
        prod: Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}
          1: FactorisationConstrai

In [99]:
# Inspect type inference of the generated constraints function (type-stability check).
# NOTE(review): `constraints` is defined in an earlier cell (In[96]) that is not shown here.
@code_warntype constraints(true)

MethodInstance for constraints(::Bool)
  from constraints(flag) in Main at In[96]:101
Arguments
  #self#[36m::Core.Const(constraints)[39m
  flag[36m::Bool[39m
Locals
  factorisation_constraint#455[36m::Tuple{FactorisationConstraintsSpecification{Val{(:x, :y)}, Int64}}[39m
  q(x, y)#454_used[36m::Bool[39m
  q(x, y)#454[33m[1m::Union{Nothing, FactorisationConstraintsSpecification{Val{(:x, :y)}, Int64}}[22m[39m
  factorisation_constraint#453[36m::Tuple{}[39m
  form_constraint#452[36m::Tuple{}[39m
Body[36m::Expr[39m
[90m1 ─[39m       Core.NewvarNode(:(factorisation_constraint#455))
[90m│  [39m       (form_constraint#452 = ())
[90m│  [39m       (factorisation_constraint#453 = ())
[90m│  [39m       (q(x, y)#454 = Main.nothing)
[90m│  [39m       (q(x, y)#454_used = false)
[90m│  [39m %6  = (q(x, y)#454::Core.Const(nothing) === Main.nothing)[36m::Core.Const(true)[39m
[90m│  [39m       Core.typeassert(%6, Core.Bool)
[90m└──[39m       goto #3
[90m2 ─[39m    