In [1]:
using MacroTools
using BenchmarkTools 

import MacroTools: prewalk, postwalk

In [2]:
macro e(something)
    if something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form :: M
end

function Base.show(io::IO, specification::ConstraintsSpecification) 
    print(io, "Constraints:\n\tform: $(specification.form)\n")
    print(io, "\tfactorisation\n")
    foreach(specification.factorisation) do f
        print(io, "\t\t", f, "\n")
    end
end

In [46]:
struct FactorisationConstraintsEntry{N, T}
    indices :: NamedTuple{N, T}
end

Base.:(*)(left::NTuple{N1, <: FactorisationConstraintsEntry}, right::NTuple{N2, <: FactorisationConstraintsEntry}) where { N1, N2 } = (left..., right...)


__io_entry_pair(pair::Pair{Symbol, Nothing})::String = string(first(pair))

make_factorisation_constraints_entry(indices) = :(FactorisationConstraintsEntry($indices))

function Base.show(io::IO, entry::FactorisationConstraintsEntry) 
    print(io, "q(")
    entries = map(__io_entry_pair, collect(pairs(entry.indices)))
    join(io, entries, ", ")
    print(io, ")")
end

In [82]:
struct FactorisationConstraintsSpecification{N, E}
    entries :: E
end

Base.:(*)(left::FactorisationConstraintsSpecification, right::NTuple{N2, <: FactorisationConstraintsEntry}) where { N1, N2 } = (left, right...)
Base.:(*)(left::NTuple{N1, <: FactorisationConstraintsEntry}, right::FactorisationConstraintsSpecification) where { N1, N2 } = (left..., right)

function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Val{Names}}) where Names
    
    print(io, "q(")
    join(io, Names, ", ")
    print(io, ")")
    
    compact = get(io, :compact, false)
    
    if !compact 
        print(io, " = ")
        foreach(factorisation.entries) do e
            print(IOContext(io, compact => true), e)
        end
    end
end

FactorisationConstraintsSpecification(::Type{ Val{N} }, entries::E) where { N, E } = FactorisationConstraintsSpecification{Val{N}, E}(entries)
    
register_q_factorisation_specification(N, E) = :(FactorisationConstraintsSpecification($N, $E))

register_q_factorisation_specification (generic function with 1 method)

In [83]:
cs

MethodError: MethodError: [0mCannot `convert` an object of type [92mBool[39m[0m to an object of type [91mSymbol[39m
[0mClosest candidates are:
[0m  convert(::Type{T}, [91m::T[39m) where T at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/essentials.jl:218
[0m  Symbol(::Any...) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/strings/basic.jl:229

In [6]:
struct ReactiveMPBackend end

__get_current_backend() = ReactiveMPBackend()

macro constraints(constraints_specification)
    return generate_constraints_expression(__get_current_backend(), constraints_specification)
end

make_constraints_specification(factorisation, form) = :(ConstraintsSpecification($factorisation, $form))

make_constraints_specification (generic function with 1 method)

In [7]:
isexpr(expr::Expr) = true
isexpr()           = false

ishead(something, head) = isexpr(something) && something.head === head

isblock(something) = ishead(something, :block)

isblock (generic function with 1 method)

In [8]:
struct LHSMeta
    name :: String
    hash :: UInt
    varname :: Symbol
end

In [56]:
function generate_constraints_expression(backend, constraints_specification)

    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # First we record all lhs expression's hash ids and create unique variable names for them
    # q(x, y) = q(x)q(y) -> hash(q(x, y))
    # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
    cs_body = postwalk(cs_body) do expression
        # We also do a simple sanity check right now, names should be an array of Symbols only
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            lhs_hash = hash(lhs)
            lhs_meta = if haskey(lhs_dict, lhs_hash)
                lhs_dict[ lhs_hash ]
            else
                lhs_name = string("q(", join(names, ", "), ")")
                lhs_varname = gensym(lhs_name)
                lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname)
                lhs_dict[lhs_hash] = lhs_meta
            end
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            (length(names) !== 0 && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            next_factorisation_constraints_symbol = gensym(:factorisation_constraint)
            
            new_factorisation_specification = register_q_factorisation_specification(:(Val{ ($(map(QuoteNode, names)...),) }), rhs)
            
            result = quote 
                ($(lhs_varname) === nothing) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $next_factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            factorisation_constraints_symbol = next_factorisation_constraints_symbol
            
            return result
        end
        return expression
    end
    
    # This block write initial variables for factorisation specification
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_varname = lhs_meta.varname
        return quote 
            $(lhs_varname) = nothing
        end
    end
    
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            expr_hash = hash(expression)
            if haskey(lhs_dict, expr_hash)
                lhs_meta = lhs_dict[ expr_hash ]
                lhs_name = lhs_meta.name
                lhs_varname = lhs_meta.varname
                return lhs_varname
            else
                # TODO add check that entries from LHS and RHS has same names
                # TODO parse rhs here
                rhs_prod_entries_args = map(args) do arg
                    if @capture(arg, argname_Symbol)
                        return make_factorisation_constraints_entry(:($argname = nothing, ))
                    else
                        error("Not implemented yet")
                    end
                end
                
                return :(($(rhs_prod_entries_args...), ))
            end
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [59]:
cs = @e @constraints begin
    q(x, y) = q(x)q(y)
    q(x, y, θ) = q(x, y)q(θ)
    q(x, y, z) = q(x, y)q(z)
end

(function albatross(; )
    nightingale = ()
    lemur = ()
    lyrebird = nothing
    pheasant = nothing
    sheep = nothing
    lyrebird === nothing || error("Factorisation constraints specification q(x, y) = ... has been redefined.")
    lyrebird = FactorisationConstraintsSpecification(Val{(:x, :y)}, (FactorisationConstraintsEntry((x = nothing,)),) * (FactorisationConstraintsEntry((y = nothing,)),))
    oyster = (lemur..., lyrebird)
    sheep === nothing || error("Factorisation constraints specification q(x, y, θ) = ... has been redefined.")
    sheep = FactorisationConstraintsSpecification(Val{(:x, :y, :θ)}, lyrebird * (FactorisationConstraintsEntry((θ = nothing,)),))
    heron = (oyster..., sheep)
    pheasant === nothing || error("Factorisation constraints specification q(x, y, z) = ... has been redefined.")
    pheasant = FactorisationConstraintsSpecification(Val{(:x, :y, :z)}, lyrebird * (FactorisationConstraintsEntry((z = nothing,)),))
    panda = (heron..., pheasant)
    Cons

Constraints:
	form: ()
	factorisation
		q(x, y)q(x)q(y)
		q(x, y, θ)q(x, y)q(x)q(y)q(θ)
		q(x, y, z)q(x, y)q(x)q(y)q(z)


 =  =  =  =  = 

In [11]:
cs

LoadError: UndefVarError: cs not defined

In [12]:
cs = constraints(true)

LoadError: UndefVarError: constraints not defined

In [13]:
dump(cs)

LoadError: UndefVarError: cs not defined

In [14]:
@code_warntype constraints(true)

LoadError: UndefVarError: constraints not defined