In [26]:
using MacroTools
import MacroTools: prewalk, postwalk

In [68]:
"""
    @e expr

Debugging helper. If `expr` is itself a macro call, print its prettified macro expansion
and then evaluate `expr`; any other expression (including plain symbols and literals) is
simply evaluated unchanged.

The generated code is fully escaped and refers to `MacroTools.prettify`, so `MacroTools`
must be loaded in the calling module when `expr` is a macro call.
"""
macro e(something)
    # Only `Expr` values have a `head` field; symbols and literals (`@e x`, `@e 1`) must
    # fall through to the pass-through branch instead of failing during expansion.
    if something isa Expr && something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [61]:
"""
    FactorisationConstraintsEntry{N, T}

Immutable wrapper around a single `NamedTuple` stored in the field `indices`; `N` is the
tuple of field names and `T` the tuple type of its values.
"""
struct FactorisationConstraintsEntry{N, T}
    indices::NamedTuple{N, T}
end

In [62]:
"""
    FactorisationConstraintsEntries{T}

Immutable container with a single field `prod` of type `T`.
NOTE(review): presumably `prod` holds the product of factorisation entries — confirm.
"""
struct FactorisationConstraintsEntries{T}
    prod::T
end

In [63]:
"""
    FactorisationConstraintsSpecification{N, T}

Immutable holder for a `FactorisationConstraintsEntries{T}` in the field `entries`.

The parameter `N` does not occur in any field type, so it cannot be deduced from the
constructor argument and has to be supplied explicitly, e.g.
`FactorisationConstraintsSpecification{N, T}(entries)`.
"""
struct FactorisationConstraintsSpecification{N, T}
    entries::FactorisationConstraintsEntries{T}
end

In [64]:
"""
    ReactiveMPBackend

Field-less (singleton) type used as the `backend` argument of
`generate_constraints_expression`.
"""
struct ReactiveMPBackend end

"""
    __get_current_backend()

Return the backend passed to `generate_constraints_expression` by `@constraints`;
at present this is always `ReactiveMPBackend()`.
"""
function __get_current_backend()
    return ReactiveMPBackend()
end

"""
    @constraints constraints_specification

Expand a constraints specification by delegating to
`generate_constraints_expression` with the current backend.
"""
macro constraints(constraints_specification)
    backend = __get_current_backend()
    return generate_constraints_expression(backend, constraints_specification)
end

@constraints (macro with 1 method)

In [65]:
"""
    isexpr(x)

Return `true` when `x` is an `Expr` and `false` for anything else (symbols, literals,
line-number nodes, ...).
"""
isexpr(::Expr) = true
# Catch-all fallback: the previous zero-argument `isexpr()` never matched a non-`Expr`
# argument, so `isexpr(:x)` threw a `MethodError` instead of returning `false`.
isexpr(::Any) = false

"""
    ishead(x, head)

Return `true` when `x` is an `Expr` whose `head` is identical (`===`) to `head`.
"""
ishead(something, head) = isexpr(something) && something.head === head

"""
    isblock(x)

Return `true` when `x` is an `Expr` with head `:block` (e.g. a `begin ... end` block).
"""
isblock(something) = ishead(something, :block)

isblock (generic function with 1 method)

In [210]:
"""
    generate_constraints_expression(backend, constraints_specification)

Build the (escaped) expression that `@constraints` expands to.

`constraints_specification` must be either a `begin ... end` block or a full long-form
function definition `function name(args...; kwargs...) body end`. A block is wrapped into a
zero-argument function with a `gensym`-generated name, and the returned expression both
defines and immediately calls that function.

Inside the function body, every assignment of the form `q(names...) = rhs` is rewritten to

    lhs_id === nothing || error("Factorisation constraints specification q(...) = ... has been redefined.")
    lhs_id = register_q_factorisation_specification(q(names...), rhs)

where `lhs_id` is a generated variable shared by all structurally equal left-hand sides and
initialised to `nothing` at the top of the generated function. The same `q(...)` may
therefore appear in different branches of an `if`, but assigning it twice along one
execution path raises an error at run time.

`backend` is currently only forwarded to the recursive call for the block case.

Throws (via `error`) when the specification is not a full function definition, or when a
left-hand side `q(...)` is empty or contains anything other than plain identifiers.
Because the whole result is escaped, names such as `register_q_factorisation_specification`
resolve in the caller's scope.
"""
function generate_constraints_expression(backend, constraints_specification)

    # A bare `begin ... end` block: wrap it into a uniquely named function and call it.
    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")

    # Captures belonging to the alternative that did not match may be `nothing`
    # (e.g. `cs_kwargs` for a definition without `;`); normalise them to empty lists.
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs

    # Maps each distinct left-hand side `q(names...)` to the generated variable that tracks
    # whether it has already been assigned. Keyed by the expression itself (structural
    # `==`/`hash`) rather than by `hash(lhs)` alone, so distinct left-hand sides can never
    # be merged by a hash collision.
    lhs_ids = Dict{Expr, Symbol}()

    cs_body = postwalk(cs_body) do expression
        # Only assignments whose left-hand side is a call to `q` are rewritten.
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            names_str = join(names, ", ")
            # Sanity check: `q(...)` must list at least one plain variable identifier.
            (!isempty(names) && all(name -> name isa Symbol, names)) ||
                error("""Error in factorisation constraints specification q($(names_str)) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            # Reuse the id for a left-hand side seen before (e.g. in another `if` branch).
            lhs_id = get!(() -> gensym(string(lhs)), lhs_ids, lhs)
            return quote
                ($lhs_id === nothing) || error("Factorisation constraints specification q($($names_str)) = ... has been redefined.")
                $lhs_id = register_q_factorisation_specification($lhs, $rhs)
            end
        end
        return expression
    end

    # Every tracking variable starts as `nothing` so the redefinition check can detect a
    # second assignment along the same execution path.
    cs_lhs_init_block = Expr(:block, [:($id = nothing) for id in values(lhs_ids)]...)

    res = quote
        function $cs_name($(cs_args...); $(cs_kwargs...))
            # TODO let block
            $(cs_lhs_init_block)
            $(cs_body)
        end
    end

    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [248]:
# Demo: `@e` prints the prettified expansion of `@constraints` (the output recorded
# below) and then evaluates it, defining `aab`. The assignment `q(x, y) = q(x)q(y)` is
# rewritten into a redefinition guard followed by a call to
# `register_q_factorisation_specification(q(x, y), q(x) * q(y))`.
@e @constraints function aab(flag)
    q(x, y) = q(x)q(y)
end

function aab(flag; )
    albatross = nothing
    albatross === nothing || error("Factorisation constraints specification q(x, y) = ... has been redefined.")
    albatross = register_q_factorisation_specification(q(x, y), q(x) * q(y))
end


aab (generic function with 1 method)