In [1]:
using MacroTools
using BenchmarkTools 
using TupleTools

import MacroTools: prewalk, postwalk

In [2]:
macro e(something)
    if something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
__range_io_helper(value::Integer) = value
__range_io_helper(value::typeof(firstindex)) = :begin
__range_io_helper(value::typeof(lastindex)) = :end

struct CombinedRange{L, R}
    from :: L
    to   :: R
end

Base.show(io::IO, range::CombinedRange) = print(io, __range_io_helper(range.from), ":", __range_io_helper(range.to))

make_combined_range(l, r) = :(CombinedRange($l, $r))

struct SplittedRange{L, R}
    from :: L
    to   :: R
end

Base.show(io::IO, range::SplittedRange) = print(io, __range_io_helper(range.from), "..", __range_io_helper(range.to))

make_splitted_range(l, r) = :(SplittedRange($l, $r))
make_factorisation_split(l, r) = :(factorisation_split($l, $r))

make_factorisation_split (generic function with 1 method)

In [4]:
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form :: M
end

function Base.show(io::IO, specification::ConstraintsSpecification) 
    print(io, "Constraints:\n\tform: $(specification.form)\n")
    print(io, "\tfactorisation\n")
    foreach(specification.factorisation) do f
        print(io, "\t\t", f, "\n")
    end
end

In [5]:
a = (1, 2, 3)
b = (2, 3, 4)

(2, 3, 4)

In [6]:
map(e -> +(e...), zip(a, b))

3-element Vector{Int64}:
 3
 5
 7

In [7]:
struct FactorisationConstraintsEntry{N, T}
    indices :: T
end

getnames(entry::FactorisationConstraintsEntry{N}) where N = N
getindices(entry::FactorisationConstraintsEntry) = entry.indices

Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}) = (left..., right...)

FactorisationConstraintsEntry(::Val{N}, indices::T) where { N, T } = FactorisationConstraintsEntry{N, T}(indices)

make_factorisation_constraint_entry(N, T) = :(FactorisationConstraintsEntry($N, $T))

__io_entry_pair(pair::Tuple)                    = __io_entry_pair(first(pair), last(pair))
__io_entry_pair(symbol::Symbol, ::Nothing)      = string(symbol)
__io_entry_pair(symbol::Symbol, index::Integer) = string(symbol, "[", index, "]")
__io_entry_pair(symbol::Symbol, index::typeof(firstindex)) = string(symbol, "[begin]")
__io_entry_pair(symbol::Symbol, index::typeof(lastindex)) = string(symbol, "[end]")
__io_entry_pair(symbol::Symbol, range::CombinedRange) = string(symbol, "[", range, "]")
__io_entry_pair(symbol::Symbol, range::SplittedRange) = string(symbol, "[", range, "]")

function Base.show(io::IO, entry::FactorisationConstraintsEntry{Names}) where Names
    print(io, "q(")
    entries = map(__io_entry_pair, zip(Names, entry.indices))
    join(io, entries, ", ")
    print(io, ")")
end

In [8]:
struct FactorisationConstraintsSpecification{N, E}
    entries :: E
end

getnames(specification::FactorisationConstraintsSpecification{N}) where N = N
getentries(specification::FactorisationConstraintsSpecification)          = specification.entries

Base.:(*)(left::FactorisationConstraintsSpecification, right::NTuple{N2, <: FactorisationConstraintsEntry}) where { N1, N2 } = (left, right...)
Base.:(*)(left::NTuple{N1, <: FactorisationConstraintsEntry}, right::FactorisationConstraintsSpecification) where { N1, N2 } = (left..., right)

function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Names}) where Names
    
    print(io, "q(")
    join(io, Names, ", ")
    print(io, ")")
    
    compact = get(io, :compact, false)
    
    if !compact 
        print(io, " = ")
        foreach(factorisation.entries) do e
            print(IOContext(io, :compact => true), e)
        end
    end
    
end

FactorisationConstraintsSpecification(::Val{N}, entries::E)       where { N, E <: Tuple } = FactorisationConstraintsSpecification{N, E}(entries)
FactorisationConstraintsSpecification(::Val{N}, entries::Nothing) where { N          }    = error("Cannot create q(", join(N, ","), ") factorisation constraints specification")
    
make_factorisation_constraint(N, E) = :(FactorisationConstraintsSpecification($N, $E))

make_factorisation_constraint (generic function with 1 method)

In [9]:
using Unrolled

In [10]:
# Split related functions

Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}) = (left..., right...)

# Only these combinations are allowed to be merged
__factorisation_split_merge_range(a::Int, b::Int)                              = SplittedRange(a, b)
__factorisation_split_merge_range(a::typeof(firstindex), b::Int)               = SplittedRange(a, b)
__factorisation_split_merge_range(a::Int, b::typeof(lastindex))                = SplittedRange(a, b)
__factorisation_split_merge_range(a::typeof(firstindex), b::typeof(lastindex)) = SplittedRange(a, b)
__factorisation_split_merge_range(a::Any, b::Any) = error("Cannot merge $(a) and $(b) indexes in `factorisation_split`")

function factorisation_split(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}})
    left_last   = last(left)
    right_first = first(right)
    (getnames(left_last) === getnames(right_first)) || error("Cannot split $(left_last) and $(right_first).")
    lindices = getindices(left_last)
    rindices = getindices(right_first)
    split_merged = unrolled_map(__factorisation_split_merge_range, lindices, rindices)
    
    first_split = first(split_merged)
    unrolled_all(e -> e === first_split, split_merged) || error("Inconsistent indices within factorisation split. Check $(split_merged) indices for $(getnames(left_last)) variables.")
    
    return (left[1:end - 1]..., FactorisationConstraintsEntry(Val(getnames(left_last)), split_merged), right[begin+1:end]...)
end

factorisation_split (generic function with 1 method)

In [11]:
struct ReactiveMPBackend end

__get_current_backend() = ReactiveMPBackend()

macro constraints(constraints_specification)
    return generate_constraints_expression(__get_current_backend(), constraints_specification)
end

make_constraints_specification(factorisation, form) = :(ConstraintsSpecification($factorisation, $form))

make_constraints_specification (generic function with 1 method)

In [12]:
isexpr(expr::Expr) = true
isexpr()           = false

ishead(something, head) = isexpr(something) && something.head === head

isblock(something) = ishead(something, :block)
isref(something) = ishead(something, :ref)

isref (generic function with 1 method)

In [13]:
struct LHSMeta
    name :: String
    hash :: UInt
    varname :: Symbol
end

In [14]:
# TODO: check for intersections of ranges during an actual execution

In [15]:
function generate_constraints_expression(backend, constraints_specification)

    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # First we record all lhs expression's hash ids and create unique variable names for them
    # q(x, y) = q(x)q(y) -> hash(q(x, y))
    # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
    cs_body = postwalk(cs_body) do expression
        # We also do a simple sanity check right now, names should be an array of Symbols only
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            
            (length(names) !== 0 && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs_name) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            # We replace '..' in RHS expression with `make_factorisation_split`
            rhs = postwalk(rhs) do rexpr
                if @capture(rexpr, a_ .. b_)
                    return make_factorisation_split(a, b)
                end
                return rexpr
            end
            
            lhs_names = Set{Symbol}(names)
            rhs_names = Set{Symbol}()
            
            # We do a simple check to be sure that LHS and RHS has the exact same set of names
            # We also check here that all indices are either a simple Symbol or an indexing expression here
            rhs = postwalk(MacroTools.prettify(rhs, alias = false)) do entry
                if @capture(entry, q(indices__))
                    for index in indices
                        if index isa Symbol
                            (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.")
                            push!(rhs_names, index)
                            (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.")
                        elseif isref(index)
                            push!(rhs_names, first(index.args))
                            (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.")
                        else
                           error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                        end
                    end
                end
                return entry
            end
            
            (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
            
            lhs_hash = hash(lhs)
            lhs_meta = if haskey(lhs_dict, lhs_hash)
                lhs_dict[ lhs_hash ]
            else
                lhs_name = string("q(", join(names, ", "), ")")
                lhs_varname = gensym(lhs_name)
                lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname)
                lhs_dict[lhs_hash] = lhs_meta
            end
            
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            
            new_factorisation_specification = make_factorisation_constraint(:(Val(($(map(QuoteNode, names)...),))), rhs)
            
            result = quote 
                ($(lhs_varname) === nothing) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            return result
        end
        return expression
    end
    
    # This block write initial variables for factorisation specification
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_varname = lhs_meta.varname
        return quote 
            $(lhs_varname) = nothing
        end
    end
    
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            expr_hash = hash(expression)
            if haskey(lhs_dict, expr_hash)
                lhs_meta = lhs_dict[ expr_hash ]
                lhs_name = lhs_meta.name
                lhs_varname = lhs_meta.varname
                return lhs_varname
            else
                rhs_prod_names = Symbol[]
                rhs_prod_entries_args = map(args) do arg
                    if arg isa Symbol
                        push!(rhs_prod_names, arg)
                        return :(nothing)
                    elseif isref(arg)
                        (length(arg.args) === 2) || error("Indexing expression $(expression) is too difficult to parse and is not supported (yet?).")
                        push!(rhs_prod_names, first(arg.args))
                        
                        index = last(arg.args)
                        
                        # First we replace all `begin` and `end` with `firstindex` and `lastindex` functions
                        index = postwalk(index) do iexpr
                            if iexpr isa Symbol && iexpr === :begin
                                return :(firstindex)
                            elseif iexpr isa Symbol && iexpr === :end
                                return :(lastindex)
                            else
                                return iexpr
                            end
                        end
                        
                        if @capture(index, a_:b_)
                            return make_combined_range(a, b)
                        elseif @capture(index, firstindex)
                            return index
                        elseif @capture(index, lastindex)
                            return index
                        else
                            return :(convert(Integer, $index))
                        end
                    else
                        error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                    end
                end
                
                entry = make_factorisation_constraint_entry(:(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(($(rhs_prod_entries_args...), )))
                
                return :(($entry, ))
            end
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [62]:
cs = @constraints begin
    q(x, y) = q(x)q(y)
    q(x, y, t) = q(x, y)q(t)
    q(x, w) = q(x)q(w)
    q(y, w) = q(y)q(w)
end

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t) = q(x, y)q(t)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [63]:
# node with (x, t, y, w) pure indices

In [64]:
ns = (:x, :t, :y, :w)
is = (nothing, nothing, nothing, nothing)

(nothing, nothing, nothing, nothing)

In [65]:
dump(cs.factorisation, maxdepth = 1)

Tuple{FactorisationConstraintsSpecification{(:x, :y), Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}, FactorisationConstraintsSpecification{(:x, :y, :t), Tuple{FactorisationConstraintsSpecification{(:x, :y), Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}, FactorisationConstraintsEntry{(:t,), Tuple{Nothing}}}}, FactorisationConstraintsSpecification{(:x, :w), Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:w,), Tuple{Nothing}}}}, FactorisationConstraintsSpecification{(:y, :w), Tuple{FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:w,), Tuple{Nothing}}}}}
  1: FactorisationConstraintsSpecification{(:x, :y), Tuple{FactorisationConstraintsEntry{(:x,), Tuple{Nothing}}, FactorisationConstraintsEntry{(:y,), Tuple{Nothing}}}}
  2: FactorisationConstraintsSpecification{(:x, :y, :t),

In [100]:
function create_factorisation(ns::Val{T}, is, constraints) where { T }
    # Indexmap creates simple mapping for each symbol in `ns`
    # E.g. ns = (:x, y)
    # indexmap = (x = 1, y = 2)
    indexmap = NamedTuple{T}(ntuple(identity, length(T)))
    
    # `factorisation` is a tuple of `FactorisationConstraintsSpecification`s
    # FactorisationConstraintsSpecification has names of LHS and specs of RHS
    factorisation = constraints.factorisation
    
    function __create_factorisation_entry(::Val{symbol}) where symbol
        
        # This function unrolles and reduces (1, 2, 3, 4) for each product entry in factorisation specification
        # E.g if q(x, y) = q(x) * q(y) this function iterates over [ q(x), q(y) ]
        function __filter_entry(factorisation_entries::Tuple, entry::NTuple{N, Int}) where { N }
             return unrolled_reduce(__filter_entry, entry, factorisation_entries)
        end
        
        function __filter_entry(spec::FactorisationConstraintsSpecification, entry::NTuple{N, Int}) where N 
            println(symbol)
            if symbol ∈ getnames(spec)
                return __filter_entry(getentries(spec), entry)
            end
            return entry
        end
        
        # This is totally wrong but just to test an idea
        function __filter_entry(csentry::FactorisationConstraintsEntry, entry::NTuple{N, Int}) where N
            if symbol ∉ getnames(csentry)
                unrolled_foreach(getnames(csentry)) do name
                    name_index = indexmap[name]
                    entry = unrolled_filter(!=(name_index), entry)
                end
            end
            return entry
        end

        # Here we create an entry template, e.g. (1, 2, 3, 4)
        template = ntuple(identity, length(T))
        
        unrolled_foreach(factorisation) do spec 
            template = __filter_entry(spec, template)
        end
        
        template
    end
    
    entries = unrolled_map(s -> __create_factorisation_entry(Val(s)), T)

end

create_factorisation (generic function with 1 method)

In [101]:
ns

(:x, :t, :y, :w)

In [102]:
cs

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t) = q(x, y)q(t)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [103]:
create_factorisation(Val(ns), is, cs)

x
x
x
x
x
t
t
t
t
t
y
y
y
y
y
w
w
w
w


((1,), (1, 2, 3, 4), (3,), (2, 4))

In [70]:
@btime create_factorisation(Val($ns), $is, $cs)

  4.970 μs (137 allocations: 3.95 KiB)


((1,), (1, 2, 3, 4), (3,), (2, 4))

In [24]:
# Assume we have a node

In [71]:
4.970 * 1000 

4970.0