In [1]:
using MacroTools
using BenchmarkTools 
using TupleTools

import MacroTools: prewalk, postwalk

In [2]:
# Debugging helper: evaluates `something` in the caller's scope.
# If `something` is itself a macro call, the prettified macro expansion is
# printed first and the macro call is evaluated afterwards.
macro e(something)
    # Only `Expr` nodes have a `head` field; symbols and literals (e.g. `@e x`
    # or `@e 1`) are simply passed through to the caller.
    if something isa Expr && something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
# Converts a range boundary into the value printed by the range `show` methods:
# integers are returned unchanged, while the `firstindex` / `lastindex`
# functions are rendered as the symbols `:begin` / `:end`.
function __range_io_helper(value::Integer)
    return value
end

function __range_io_helper(value::typeof(firstindex))
    return :begin
end

function __range_io_helper(value::typeof(lastindex))
    return :end
end

# Pair of range boundaries that is displayed as `from:to`.
struct CombinedRange{L, R}
    from :: L
    to   :: R
end

# Prints the range as `from:to`; each boundary goes through `__range_io_helper`.
function Base.show(io::IO, range::CombinedRange)
    lower = __range_io_helper(range.from)
    upper = __range_io_helper(range.to)
    return print(io, lower, ":", upper)
end

# Builds the unevaluated call expression `CombinedRange(l, r)`.
make_combined_range(l, r) = Expr(:call, :CombinedRange, l, r)

# Pair of range boundaries that is displayed as `from..to`.
struct SplittedRange{L, R}
    from :: L
    to   :: R
end

# Prints the range as `from..to`; each boundary goes through `__range_io_helper`.
function Base.show(io::IO, range::SplittedRange)
    lower = __range_io_helper(range.from)
    upper = __range_io_helper(range.to)
    return print(io, lower, "..", upper)
end

# Builds the unevaluated call expression `SplittedRange(l, r)`.
make_splitted_range(l, r) = Expr(:call, :SplittedRange, l, r)
# Builds the unevaluated call expression `factorisation_split(l, r)`.
function make_factorisation_split(l, r)
    return Expr(:call, :factorisation_split, l, r)
end

make_factorisation_split (generic function with 1 method)

In [4]:
# Field-less placeholder parametrised by `S`.
# The code generated by `generate_constraints_expression` initialises every
# left-hand-side variable with `NotDefinedYet{name}()`, and the `*` methods
# defined for this type only raise an error that mentions `S`.
struct NotDefinedYet{S} end

In [5]:
# Initial length of the buffers created by `ConstraintsSpecificationPreallocated()`.
const ConstraintsSpecificationPreallocatedDefaultSize = 64

# Scratch buffers that are stored inside a `ConstraintsSpecification` and are
# reset (and grown if needed) by `__reset_preallocated!`.
struct ConstraintsSpecificationPreallocated
    clusters_template :: BitVector
    clusters_usage    :: BitVector
    clusters_set      :: Set{Tuple}
    cluster_indices   :: Vector{Int}

    # The template starts all-`true`, the usage mask all-`false`, the set empty
    # and the index buffer uninitialised.
    function ConstraintsSpecificationPreallocated()
        n        = ConstraintsSpecificationPreallocatedDefaultSize
        template = trues(n)
        usage    = falses(n)
        clusters = Set{Tuple}()
        indices  = Vector{Int}(undef, n)
        return new(template, usage, clusters, indices)
    end
end

# Prepares the scratch buffers for a problem with `size` variables:
# grows the bit masks to at least `size^2` elements and the index buffer to at
# least `size` elements, then resets the template to `true`, the usage mask to
# `false` and empties the set of clusters. Returns the emptied set.
function __reset_preallocated!(preallocated::ConstraintsSpecificationPreallocated, size::Int)
    template = preallocated.clusters_template
    usage    = preallocated.clusters_usage
    indices  = preallocated.cluster_indices

    required = abs2(size)
    if required > length(template)
        resize!(template, required)
        # `size` elements would be enough for the usage mask; it is kept at the
        # same length as the template so that a single check covers both.
        resize!(usage, required)
    end

    if size > length(indices)
        resize!(indices, size)
    end

    fill!(template, true)
    fill!(usage, false)

    return empty!(preallocated.clusters_set)
end

# Result of a constraints specification: the factorisation constraints, the
# form constraints and the scratch buffers used by `create_factorisation`.
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form :: M
    preallocated :: ConstraintsSpecificationPreallocated
end

# Convenience constructor that allocates a fresh set of scratch buffers.
function ConstraintsSpecification(factorisation::F, form::M) where { F, M }
    buffers = ConstraintsSpecificationPreallocated()
    return ConstraintsSpecification{F, M}(factorisation, form, buffers)
end

# Resets the scratch buffers owned by `specification`.
function __reset_preallocated!(specification::ConstraintsSpecification, size::Int)
    return __reset_preallocated!(specification.preallocated, size)
end

# Multi-line display: the form constraints first, then one line per
# factorisation constraint.
function Base.show(io::IO, specification::ConstraintsSpecification)
    print(io, "Constraints:\n\tform: $(specification.form)\n")
    print(io, "\tfactorisation\n")
    for constraint in specification.factorisation
        print(io, "\t\t", constraint, "\n")
    end
    return nothing
end

# Builds the unevaluated call expression `ConstraintsSpecification(factorisation, form)`.
function make_constraints_specification(factorisation, form)
    return Expr(:call, :ConstraintsSpecification, factorisation, form)
end

make_constraints_specification (generic function with 1 method)

In [6]:
# Single `q(...)` factor of a factorisation constraint. All information lives
# in the type parameters: `N` holds the names and `I` the matching indices.
struct FactorisationConstraintsEntry{N, I} end

# Returns the `N` type parameter (names) of the entry.
function getnames(entry::FactorisationConstraintsEntry{N}) where N
    return N
end

# Returns the `I` type parameter (indices) of the entry.
function getindices(entry::FactorisationConstraintsEntry{N, I}) where { N, I }
    return I
end

# Multiplying two tuples of entries concatenates them.
function Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}})
    return (left..., right...)
end

# Lifts the names and indices wrapped in `Val`s into the type parameters.
function FactorisationConstraintsEntry(::Val{N}, ::Val{I}) where { N, I }
    return FactorisationConstraintsEntry{N, I}()
end

# Builds the unevaluated call expression `FactorisationConstraintsEntry(N, T)`.
make_factorisation_constraint_entry(N, T) = Expr(:call, :FactorisationConstraintsEntry, N, T)

# Renders one `(name, index)` pair of a factorisation entry as text.
# The tuple method unpacks the pair; the remaining methods dispatch on the
# kind of index.
function __io_entry_pair(pair::Tuple)
    return __io_entry_pair(first(pair), last(pair))
end

# No index: just the variable name.
function __io_entry_pair(symbol::Symbol, ::Nothing)
    return string(symbol)
end

# Integer index: `name[index]`.
function __io_entry_pair(symbol::Symbol, index::Integer)
    return string(symbol, "[", index, "]")
end

# `firstindex` / `lastindex` functions are rendered as `begin` / `end`.
function __io_entry_pair(symbol::Symbol, index::typeof(firstindex))
    return string(symbol, "[begin]")
end

function __io_entry_pair(symbol::Symbol, index::typeof(lastindex))
    return string(symbol, "[end]")
end

# Ranges are rendered with their own `show` methods inside the brackets.
function __io_entry_pair(symbol::Symbol, range::CombinedRange)
    return string(symbol, "[", range, "]")
end

function __io_entry_pair(symbol::Symbol, range::SplittedRange)
    return string(symbol, "[", range, "]")
end

# Prints the entry as `q(name[index], ...)`, pairing every name with its index.
function Base.show(io::IO, entry::FactorisationConstraintsEntry)
    rendered = [ __io_entry_pair(pair) for pair in zip(getnames(entry), getindices(entry)) ]
    print(io, "q(")
    join(io, rendered, ", ")
    return print(io, ")")
end

In [7]:
# One factorisation constraint `q(names...) = entries...`. All information
# lives in the type parameters: `N` holds the names and `E` the entries.
struct FactorisationConstraintsSpecification{N, E} end

# Returns the `N` type parameter (names) of the specification.
function getnames(specification::FactorisationConstraintsSpecification{N}) where N
    return N
end

# Returns the `E` type parameter (entries) of the specification.
function getentries(specification::FactorisationConstraintsSpecification{N, E}) where { N, E }
    return E
end

# Products of factorisation specifications / entries are collected into a flat tuple:
#   item  * tuple -> (item, tuple...)
#   tuple * item  -> (tuple..., item)
#   item  * item  -> (item, item)
# Each `where` clause declares only the static parameter that its signature actually uses.
Base.:(*)(left::Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }, right::NTuple{N2, <: Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }}) where { N2 } = (left, right...)
Base.:(*)(left::NTuple{N1, <: Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }}, right::Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }) where { N1 } = (left..., right)
Base.:(*)(left::Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }, right::Union{ <:FactorisationConstraintsSpecification, <:FactorisationConstraintsEntry }) = (left, right)

# Multiplying with a `NotDefinedYet` placeholder is always an error.
# The messages list the operands in the order in which they were multiplied.
# The third method covers the case where both operands are placeholders.
Base.:(*)(::NotDefinedYet{S}, something::Any) where S = error("Cannot multiply $S and $something. $S has not been defined yet.")
Base.:(*)(something::Any, ::NotDefinedYet{S}) where S = error("Cannot multiply $something and $S. $S has not been defined yet.")
Base.:(*)(::NotDefinedYet{S1}, ::NotDefinedYet{S2}) where { S1, S2 } = error("Cannot multiply $S1 and $S2. Both $S1 and $S2 have not been defined yet.")

# Prints the constraint as `q(names...) = <entries>`.
# When the `:compact` IO property is set only the `q(names...)` part is printed;
# the entries themselves are always printed in compact mode.
function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Names}) where Names
    print(io, "q(")
    join(io, getnames(factorisation), ", ")
    print(io, ")")

    if get(io, :compact, false)
        return nothing
    end

    print(io, " = ")
    for entry in getentries(factorisation)
        print(IOContext(io, :compact => true), entry)
    end
    return nothing
end

# Lifts the names and entries wrapped in `Val`s into the type parameters.
function FactorisationConstraintsSpecification(::Val{N}, ::Val{E}) where { N, E }
    return FactorisationConstraintsSpecification{N, E}()
end

# A specification whose entries are `nothing` cannot be created.
function FactorisationConstraintsSpecification(::Val{N}, ::Val{nothing}) where { N }
    return error("Cannot create q(", join(N, ","), ") factorisation constraints specification")
end
    
# Builds the unevaluated call expression `FactorisationConstraintsSpecification(N, E)`.
function make_factorisation_constraint(N, E)
    return Expr(:call, :FactorisationConstraintsSpecification, N, E)
end

make_factorisation_constraint (generic function with 1 method)

In [8]:
using Unrolled

In [9]:
# Split related functions

# Multiplying two tuples of entries concatenates them.
# NOTE(review): this method has the same signature and body as the `*` method
# defined next to `FactorisationConstraintsEntry`, so it only re-defines it.
Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}) = (left..., right...)

# Merges the indices on both sides of a `..` split into a `SplittedRange`.
# The left index must be an `Int` or `firstindex`, the right index must be an
# `Int` or `lastindex`; every other combination is rejected by the fallback.
function __factorisation_split_merge_range(a::Union{Int, typeof(firstindex)}, b::Union{Int, typeof(lastindex)})
    return SplittedRange(a, b)
end

function __factorisation_split_merge_range(a::Any, b::Any)
    return error("Cannot merge $(a) and $(b) indexes in `factorisation_split`")
end

# Implements the `..` operator of the constraints language.
# The last entry of `left` and the first entry of `right` must refer to the
# same names; they are merged into a single entry whose indices are
# `SplittedRange`s, and that entry replaces both boundary entries in the
# concatenated result.
function factorisation_split(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}})
    boundary_left  = last(left)
    boundary_right = first(right)

    names = getnames(boundary_left)
    if names !== getnames(boundary_right)
        error("Cannot split $(boundary_left) and $(boundary_right).")
    end

    # Merge the indices pairwise; unsupported combinations raise an error.
    merged = unrolled_map(__factorisation_split_merge_range, getindices(boundary_left), getindices(boundary_right))

    # Every variable of the entry has to be split over the very same range.
    reference = first(merged)
    if !unrolled_all(e -> e === reference, merged)
        error("Inconsistent indices within factorisation split. Check $(merged) indices for $(names) variables.")
    end

    merged_entry = FactorisationConstraintsEntry(Val(names), Val(merged))
    return (left[1:end - 1]..., merged_entry, right[begin+1:end]...)
end

factorisation_split (generic function with 1 method)

In [10]:
# Marker type for the backend handed to `generate_constraints_expression`.
struct ReactiveMPBackend end

# Returns the backend used by the `@constraints` macro.
function __get_current_backend()
    return ReactiveMPBackend()
end

# Entry point of the constraints specification language: forwards the
# expression to `generate_constraints_expression` together with the current
# backend and returns the generated code.
macro constraints(constraints_specification)
    backend = __get_current_backend()
    return generate_constraints_expression(backend, constraints_specification)
end

@constraints (macro with 1 method)

In [11]:
# `isexpr(x)` is `true` only for `Expr` objects. The one-argument fallback
# makes symbols, literals and any other value return `false` instead of
# raising a `MethodError`.
isexpr(expr::Expr) = true
isexpr(::Any)      = false
isexpr()           = false # kept for backward compatibility

# `true` if `something` is an `Expr` with the given `head`.
ishead(something, head) = isexpr(something) && something.head === head

# Shortcuts for `begin ... end` blocks and `a[i]` indexing expressions.
isblock(something) = ishead(something, :block)
isref(something) = ishead(something, :ref)

isref (generic function with 1 method)

In [12]:
# Bookkeeping record for one left-hand side `q(...)` of a factorisation
# constraint, created by `generate_constraints_expression`.
struct LHSMeta
    name :: String    # textual form, e.g. "q(x, y)"
    hash :: UInt      # `hash` of the left-hand side expression, used as the dictionary key
    varname :: Symbol # `gensym`-generated variable name that replaces the expression
end

In [13]:
# TODO: check for intersections of ranges during an actual execution

In [14]:
# Transforms a constraints specification (a full function definition or a
# `begin ... end` block) into the code of a function that returns a
# `ConstraintsSpecification`.
#
# - `backend` is only forwarded to the recursive call used for blocks.
# - `constraints_specification` is the expression passed to `@constraints`.
#
# Returns an escaped expression. Raises an error if the specification is not a
# function definition, or if a factorisation constraint fails one of the
# sanity checks below.
function generate_constraints_expression(backend, constraints_specification)

    # A bare block is wrapped into an anonymous (gensym-named) function that is called immediately
    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # First we record all lhs expression's hash ids and create unique variable names for them
    # q(x, y) = q(x)q(y) -> hash(q(x, y))
    # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
    cs_body = postwalk(cs_body) do expression
        # We also do a simple sanity check right now, names should be an array of Symbols only
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            
            # `lhs_name` is not assigned yet at this point, so the message refers to the `lhs` expression itself
            (length(names) !== 0 && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            # We replace '..' in RHS expression with `make_factorisation_split`
            rhs = postwalk(rhs) do rexpr
                if @capture(rexpr, a_ .. b_)
                    return make_factorisation_split(a, b)
                end
                return rexpr
            end
            
            lhs_names = Set{Symbol}(names)
            rhs_names = Set{Symbol}()
            
            # We do a simple check to be sure that LHS and RHS has the exact same set of names
            # We also check here that all indices are either a simple Symbol or an indexing expression here
            rhs = postwalk(MacroTools.prettify(rhs, alias = false)) do entry
                if @capture(entry, q(indices__))
                    for index in indices
                        if index isa Symbol
                            (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.")
                            push!(rhs_names, index)
                            (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.")
                        elseif isref(index)
                            push!(rhs_names, first(index.args))
                            (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.")
                        else
                           error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                        end
                    end
                end
                return entry
            end
            
            (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
            
            # Look up (or create) the bookkeeping record for this left-hand side
            lhs_hash = hash(lhs)
            lhs_meta = if haskey(lhs_dict, lhs_hash)
                lhs_dict[ lhs_hash ]
            else
                lhs_name = string("q(", join(names, ", "), ")")
                lhs_varname = gensym(lhs_name)
                lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname)
                lhs_dict[lhs_hash] = lhs_meta
            end
            
            lhs_name = lhs_meta.name
            lhs_varname = lhs_meta.varname
            
            new_factorisation_specification = make_factorisation_constraint(:(Val(($(map(QuoteNode, names)...),))), :(Val($(rhs))))
            
            # The generated code refuses to redefine an already defined left-hand side
            result = quote 
                ($(lhs_varname) isa NotDefinedYet) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            return result
        end
        return expression
    end
    
    # This block write initial variables for factorisation specification
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_name = lhs_meta.name
        lhs_varname = lhs_meta.varname
        lhs_symbol = Symbol(lhs_name)
        return quote 
            $(lhs_varname) = NotDefinedYet{$(QuoteNode(lhs_symbol))}()
        end
    end
    
    # Every remaining `q(...)` call is either a reference to a recorded left-hand side
    # (replaced by its variable) or a new factorisation entry (replaced by a one-element tuple)
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            expr_hash = hash(expression)
            if haskey(lhs_dict, expr_hash)
                lhs_meta = lhs_dict[ expr_hash ]
                lhs_name = lhs_meta.name
                lhs_varname = lhs_meta.varname
                return lhs_varname
            else
                rhs_prod_names = Symbol[]
                rhs_prod_entries_args = map(args) do arg
                    if arg isa Symbol
                        push!(rhs_prod_names, arg)
                        return :(nothing)
                    elseif isref(arg)
                        (length(arg.args) === 2) || error("Indexing expression $(expression) is too difficult to parse and is not supported (yet?).")
                        push!(rhs_prod_names, first(arg.args))
                        
                        index = last(arg.args)
                        
                        # First we replace all `begin` and `end` with `firstindex` and `lastindex` functions
                        index = postwalk(index) do iexpr
                            if iexpr isa Symbol && iexpr === :begin
                                return :(firstindex)
                            elseif iexpr isa Symbol && iexpr === :end
                                return :(lastindex)
                            else
                                return iexpr
                            end
                        end
                        
                        if @capture(index, a_:b_)
                            return make_combined_range(a, b)
                        elseif @capture(index, firstindex)
                            return index
                        elseif @capture(index, lastindex)
                            return index
                        else
                            return :(convert(Integer, $index))
                        end
                    else
                        # Neither `index` nor `rhs` is defined in this branch, so the message
                        # refers to the offending argument and the enclosing `q(...)` expression
                        error("Cannot parse expression $(arg) in the RHS $(expression) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                    end
                end
                
                entry = make_factorisation_constraint_entry(:(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(Val(($(rhs_prod_entries_args...), ))))
                
                return :(($entry, ))
            end
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    # Final function: initialise the accumulators and placeholders, run the rewritten body
    # and return the resulting specification
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [15]:
# node with (x, t, y, w) pure indices

In [16]:
# Didn't work out
# using DataStructures
# struct FirstForwardOrdering <: Base.Ordering end
# Base.lt(::FirstForwardOrdering, a, b) = Base.Order.lt(Base.Order.Forward, first(a), first(b))

In [17]:
# struct ConstraintsSpecificationPreallocated
#     clusters_template :: BitVector
#     clusters_set      :: Set{Tuple}
#     cluster_indices   :: Vector{Int}
# end

In [41]:
# Exception thrown by `create_factorisation` when the resulting clusters share
# an index. The fields only carry the context needed by `showerror`.
struct ClusterIntersectionError <: Exception
    names       # names of the variables that were processed
    expression  # expression reported in the error message
    indexmap    # mapping from variable names to their positions
    clusters    # clusters (collections of positions) that were found
    constraints # constraints specification, printed at the end of the message
end

# Small helper that constructs and throws a `ClusterIntersectionError`.
__throw_intersection_error(names, expression, indexmap, clusters, constraints) = throw(ClusterIntersectionError(names, expression, indexmap, clusters, constraints))

# Prints the offending expression, then every cluster as `q(names...)` (the
# positions stored in the cluster are translated back to variable names via
# `indexmap`), and finally the constraints themselves.
function Base.showerror(io::IO, error::ClusterIntersectionError)
    print(io, "Cluster intersection error in the expression `$(error.expression)`.\n")
    print(io, "Based on factorisation constraints the resulting clusters are: ")
    name_to_position = collect(pairs(error.indexmap))
    for cluster in error.clusters
        print(io, "q(")
        cluster_names = map(cluster) do position
            # The first pair whose value equals the position provides the name
            found = findfirst(entry -> last(entry) === position, name_to_position)
            return first(name_to_position[found])
        end
        join(io, cluster_names, ",")
        print(io, ")")
    end
    return print(io, "\n", error.constraints)
end

In [58]:
# Computes the clusters of variables `ns` implied by the factorisation constraints.
#
# - `expr` is only used for error reporting.
# - `ns` is a tuple of variable names wrapped in `Val`.
# - `is` is not used by the current implementation.
# - `constraints` provides the `factorisation` tuple and the `preallocated` buffers.
#
# Returns a tuple of clusters, each cluster being a tuple of positions in `ns`,
# sorted by their first element. Throws a `ClusterIntersectionError` (through
# `__throw_intersection_error`) if a position appears in more than one cluster.
# The preallocated buffers stored in `constraints` are overwritten on every call.
function create_factorisation(expr::Expr, ::Val{ns}, is, constraints) where { ns }
    # Indexmap creates simple mapping for each symbol in `ns`
    # E.g. ns = (:x, :y)
    # indexmap = (x = 1, y = 2)
    N = length(ns)
    indexmap = NamedTuple{ns}(ntuple(identity, N))
    
    # `factorisation` is a tuple of `FactorisationConstraintsSpecification`s
    # FactorisationConstraintsSpecification has names of LHS and specs of RHS
    factorisation = constraints.factorisation
    
    preallocated = constraints.preallocated
    
    # Template is reset to all `true`, usage mask to all `false`, set of clusters is emptied
    __reset_preallocated!(preallocated, N)
    
    clusters_template = preallocated.clusters_template
    clusters_usage    = preallocated.clusters_usage
    clusters_set      = preallocated.clusters_set
    cluster_indices   = preallocated.cluster_indices
    
    # Clears the template bits (in the row that starts right after `shift`) of all variables
    # that the constraints separate from `symbol`
    function __process_factorisation_entry!(symbol::Symbol, shift::Int)
        
        # This function unrolls and reduces (1, 2, 3, 4) for each product entry in factorisation specification
        # E.g if q(x, y) = q(x) * q(y) this function iterates over [ q(x), q(y) ]
        # `Val(false)` marks a top-level specification, which is only visited if it mentions `symbol`;
        # `Val(true)` marks nested items, which are always visited
        __filter_template!(factorisation_entries::Tuple)                              = unrolled_foreach((rentry) -> __filter_template!(Val(true), rentry), factorisation_entries)
        __filter_template!(::Val{true}, spec::FactorisationConstraintsSpecification)  = __filter_template!(getentries(spec))
        __filter_template!(::Val{false}, spec::FactorisationConstraintsSpecification) = symbol ∈ getnames(spec) ? __filter_template!(getentries(spec)) : nothing
        
        # This is totally wrong but just to test an idea
        # Does not support indices
        # Does not support ranges
        # Does not support split ranges
        function __filter_template!(::Val{true}, csentry::FactorisationConstraintsEntry)
            # An entry without `symbol` lists variables that belong to a different factor
            if symbol ∉ getnames(csentry)
                unrolled_foreach(getnames(csentry)) do name
                    if haskey(indexmap, name)
                        name_index = indexmap[name]
                        clusters_template[ shift + name_index ] = false
                    end
                end
            end
        end
        
        for spec in factorisation
            __filter_template!(Val(false), spec)
        end

    end
    
    # Every name in `ns` owns a row of `N` consecutive bits in the template
    shift = 0
    for spec in ns
        __process_factorisation_entry!(spec, shift)
        shift += N
    end
    
    # Convert every row of the template into a tuple of the positions that are still `true`
    @inbounds for index in 1:N
        range_left  = (index - 1) * N + 1
        range_right = range_left + N - 1
        
        ki = 0
        # NOTE(review): this loop re-uses the name `index` of the outer loop variable
        @inbounds for (index, flag) in enumerate(view(clusters_template, range_left:range_right))
            if flag
                ki += 1
                cluster_indices[ki] = index
            end
        end
        
        output = Tuple(view(cluster_indices, 1:ki))
        
        # Identical clusters are stored only once
        push!(clusters_set, output)
    end
    
    sorted_clusters = sort!(collect(clusters_set); by = first, alg = QuickSort)
    
    # Check if clusters do intersect
    for cluster in sorted_clusters
        for index in cluster
            if clusters_usage[index] === true
                __throw_intersection_error(ns, expr, indexmap, sorted_clusters, constraints)
            end 
            clusters_usage[index] = true
        end
    end
    
    return Tuple(sorted_clusters)
    # return template
end

create_factorisation (generic function with 1 method)

In [59]:
# Example specification: `@constraints` rewrites the function definition below
# so that calling `test()` returns a `ConstraintsSpecification` holding the
# four factorisation constraints.
cs = @constraints function test()
    q(x, y) = q(x)q(y)
    q(x, y, t) = q(x, y)q(t)
    q(x, w) = q(x)q(w)
    q(y, w) = q(y)q(w)
end

# `cs` is re-bound to the specification returned by the generated function
cs = test()

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t) = q(x, y)q(t)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [60]:
# Variable names and their indices (`nothing` for every variable) used as
# inputs of `create_factorisation` in the cells below
@show ns = (:x, :w, :y, :t)
@show is = (nothing, nothing, nothing, nothing)

ns = (:x, :w, :y, :t) = (:x, :w, :y, :t)
is = (nothing, nothing, nothing, nothing) = (nothing, nothing, nothing, nothing)


(nothing, nothing, nothing, nothing)

In [61]:
# Build the specification again; every call to `test()` constructs a new
# `ConstraintsSpecification` with its own preallocated buffers
cs = test()

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t) = q(x, y)q(t)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [62]:
# Time the same call twice: the first measurement is dominated by compilation
# (see the reported compilation time in the output), the second one is not
@time create_factorisation(:(1 + 1), Val(ns), is, cs)
@time create_factorisation(:(1 + 1), Val(ns), is, cs)

  0.166264 seconds (439.70 k allocations: 23.467 MiB, 3.76% gc time, 99.84% compilation time)
  0.000028 seconds (35 allocations: 1.375 KiB)


((1,), (2, 4), (3,))

In [63]:
# Repeat the timing now that the method has already been compiled
@time create_factorisation(:(1 + 1), Val(ns), is, cs)
@time create_factorisation(:(1 + 1), Val(ns), is, cs)

  0.000043 seconds (35 allocations: 1.375 KiB)
  0.000023 seconds (35 allocations: 1.375 KiB)


((1,), (2, 4), (3,))

In [66]:
# Benchmark a single call; the global inputs are interpolated with `$`
@btime create_factorisation(:(1 + 1), $(Val(ns)), $is, $cs)

  2.351 μs (28 allocations: 1.14 KiB)


((1,), (2, 4), (3,))

In [65]:
# Benchmark driver: builds the inputs once and then calls
# `create_factorisation` `n` times with the same specification.
function foo(n)
    names         = Val((:x, :w, :y, :t))
    indices       = (nothing, nothing, nothing, nothing)
    specification = test()
    for _ in 1:n
        create_factorisation(:(1 + 1), names, indices, specification)
    end
end

foo (generic function with 1 method)

In [68]:
# Benchmark 1000 consecutive `create_factorisation` calls that share one specification
@btime foo(1000)

  2.359 ms (28009 allocations: 1.11 MiB)
