In [1]:
using MacroTools
import MacroTools: prewalk, postwalk

In [2]:
macro e(something)
    if something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
struct FactorisationConstraintsEntry{N, T}
    indices :: NamedTuple{N, T}
end

In [4]:
struct FactorisationConstraintsEntries{T}
     prod :: T
end

In [5]:
struct FactorisationConstraintsSpecification{N, T}
    entries :: FactorisationConstraintsEntries{T}
end

In [6]:
struct ReactiveMPBackend end

__get_current_backend() = ReactiveMPBackend()

macro constraints(constraints_specification)
    return generate_constraints_expression(__get_current_backend(), constraints_specification)
end

@constraints (macro with 1 method)

In [7]:
isexpr(expr::Expr) = true
isexpr()           = false

ishead(something, head) = isexpr(something) && something.head === head

isblock(something) = ishead(something, :block)

isblock (generic function with 1 method)

In [8]:
struct LHSMeta
    name :: String
    hash :: UInt
    varname :: Symbol
    varname_used :: Symbol
end

In [19]:
function generate_constraints_expression(backend, constraints_specification)

    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # First we record all lhs expression's hash ids and create unique variable names for them
    # q(x, y) = q(x)q(y) -> hash(q(x, y))
    # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
    cs_body = postwalk(cs_body) do expression
        # We also do a simple sanity check right now, names should be an array of Symbols only
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            lhs_hash = hash(lhs)
            lhs_meta = if haskey(lhs_dict, lhs_hash)
                lhs_dict[ lhs_hash ]
            else
                lhs_name = string("q(", join(names, ", "), ")")
                lhs_varname = gensym(lhs_name)
                lhs_varname_used = Symbol(lhs_varname, :_used)
                lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname, lhs_varname_used)
                lhs_dict[lhs_hash] = lhs_meta
            end
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            (length(names) !== 0 && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            return quote 
                ($(lhs_varname) === nothing) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = register_q_factorisation_specification(TODO_REPLACE_WITH_RHS_VAL, $rhs)
            end
        end
        return expression
    end
    
    # This block write initial variables for factorisation specification
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_varname = lhs_meta.varname
        lhs_varname_used = lhs_meta.varname_used
        return quote 
            $(lhs_varname) = nothing
            $(lhs_varname_used) = false
        end
    end
    
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            reconstructed_expression = Expr(:call, :q, args...)
            reconstructed_hash = hash(reconstructed_expression)
            if haskey(lhs_dict, reconstructed_hash)
                lhs_meta = lhs_dict[ reconstructed_hash ]
                lhs_name = lhs_meta.name
                lhs_varname = lhs_meta.varname
                lhs_varname_used = lhs_meta.varname_used
                return quote
                    if $(lhs_varname_used)
                        error("Factorisation constraint $(lhs_name) has been used multiple times")
                    end
                    $(lhs_varname_used) = true
                    $(lhs_varname)
                end
            else
                # TODO parse rhs here
                return expression
            end
        end
        return expression
    end
    
    # This check happens at the very and and simply checks boolean flags of all factorisation specification
    # throws an error if some has not been used
    cs_lhs_used_check = map(collect(lhs_dict)) do pair 
        lhs_meta = last(pair)
        lhs_name = lhs_meta.name
        lhs_varname_used = lhs_meta.varname_used
        return quote 
            if !$(lhs_varname_used)
                error("Factorisation constraint $($(lhs_name)) has been defined but never been used.")
            end
        end
    end
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            # TODO let block
            $(cs_lhs_init_block...)
            $(cs_body)
            $(cs_lhs_used_check...)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [23]:
@e @constraints function constraints(flag)
    q(x, y) = q(x)q(y)
    
    if flag
        q(x, y, θ) = q(x, y)q(θ)
    else
        q(x, y, θ) = q(x, y)q(θ)
    end
    
    q(x, y)
end

function constraints(flag; )
    hyena = nothing
    weasel = false
    coati = nothing
    jay = false
    hyena === nothing || error("Factorisation constraints specification q(x, y) = ... has been redefined.")
    hyena = register_q_factorisation_specification(TODO_REPLACE_WITH_RHS_VAL, q(x) * q(y))
    if flag
        coati === nothing || error("Factorisation constraints specification q(x, y, θ) = ... has been redefined.")
        coati = register_q_factorisation_specification(TODO_REPLACE_WITH_RHS_VAL, begin
                        if weasel
                            error("Factorisation constraint $(lhs_name) has been used multiple times")
                        end
                        weasel = true
                        hyena
                    end * q(θ))
    else
        coati === nothing || error("Factorisation constraints specification q(x, y, θ) = ... has been redefined.")
        coati = register_q_factorisation_specification(TODO_REPLACE_WITH_RHS_VAL, begin
     

constraints (generic function with 2 methods)

In [24]:
q(a, b) = 2
q(a) = 1
x = 1
y = 2
register_q_factorisation_specification(a, b) = 1.0

register_q_factorisation_specification (generic function with 1 method)

In [25]:
constraints(true)

LoadError: UndefVarError: TODO_REPLACE_WITH_RHS_VAL not defined