In [1]:
using MacroTools
using BenchmarkTools 
using TupleTools
using Test

import MacroTools: prewalk, postwalk

In [2]:
"""
    @e expr

Debugging helper. If `expr` is a macro call, print its prettified macro expansion and then
evaluate it; otherwise evaluate `expr` unchanged. `expr` is evaluated in the caller's scope.
"""
macro e(something)
    # Only `Expr` nodes have a `head`; symbols and literals must fall through to the plain branch
    if something isa Expr && something.head === :macrocall
        return esc(:(println(MacroTools.prettify(@macroexpand $something)); $something))
    end
    return esc(something)
end

@e (macro with 1 method)

In [3]:
"""
    FunctionalIndex{R}(f)

A special type of index that represents a function which can only be resolved together with a collection.
Calling the index on a collection (`index(collection)`) yields an integer index.

`R` is the representation of the index. It is either a `Symbol` such as `:begin`/`:end`, or another
`FunctionalIndex` when this index offsets that one. `f` is the function applied to the collection (or the offset
applied to the inner index).

# Examples
`FunctionalIndex{:begin}(firstindex)` represents `begin`, and `FunctionalIndex{:begin}(firstindex) + 1`
represents `begin + 1`.

This structure is used to dispatch on and to replace `begin` or `end` (or more complex cases, e.g. `begin + 1`)
markers inside indexing expressions.
"""
struct FunctionalIndex{R, F}
    f :: F
    
    FunctionalIndex{R}(f::F) where { R, F } = new{R, F}(f)
end

# Resolve a functional index against `collection`; the resolved value must be an integer index.
function (index::FunctionalIndex{R, F})(collection) where {R, F}
    resolved = __functional_apply(R, index.f, collection)
    return resolved::Integer
end

# `R` is a plain representation symbol (e.g. `:begin`): apply the stored function to the collection.
__functional_apply(::Symbol, f, collection) = f(collection)

# `R` is itself a functional index: resolve it first, then apply the stored integer offset (`+ k` or `- k`).
function __functional_apply(subindex::FunctionalIndex, f::Base.Fix2{<:Union{typeof(+), typeof(-)}, <:Integer}, collection)
    return f(subindex(collection))
end

# Offsetting a functional index by an integer produces a new functional index whose representation
# parameter is the original index and whose function adds/subtracts the offset.
function Base.:(+)(left::FunctionalIndex, index::Integer)
    offset = Base.Fix2(+, index)
    return FunctionalIndex{left}(offset)
end

function Base.:(-)(left::FunctionalIndex, index::Integer)
    offset = Base.Fix2(-, index)
    return FunctionalIndex{left}(offset)
end

# Suffix printed after the representation `R` of a `FunctionalIndex`.
# Plain functions (e.g. `firstindex`, `lastindex`, or any user function) print nothing: the representation
# symbol already names them. Without this fallback `show` would throw a MethodError for arbitrary functions.
__functional_print(io::IO, f) = nothing
# Integer offsets print as ` + k` / ` - k`; `f.x` is the fixed second argument stored in `Base.Fix2`.
__functional_print(io::IO, f::Base.Fix2{typeof(+), <: Integer}) = print(io, " + ", f.x)
__functional_print(io::IO, f::Base.Fix2{typeof(-), <: Integer}) = print(io, " - ", f.x)

# Print a functional index as `(R<suffix>)`, e.g. `(begin)` or `((end) - 1)`; when `R` is itself a
# functional index it is printed recursively through this same method.
function Base.show(io::IO, index::FunctionalIndex{R, F}) where { R, F }
    print(io, "(")
    print(io, R)
    __functional_print(io, index.f)
    print(io, ")")
end

# Build the expression `FunctionalIndex{:reprsymbol}(fnsymbol)`; the DSL uses it to replace `begin`/`end` in indices.
function make_functional_index(reprsymbol::Symbol, fnsymbol::Symbol)
    tagged_type = Expr(:curly, :FunctionalIndex, QuoteNode(reprsymbol))
    return Expr(:call, tagged_type, fnsymbol)
end

# Unit tests for `FunctionalIndex`: evaluation against vectors, the textual form produced by `show`,
# and that offset indices remain `isbits` types.
@testset "FunctionalIndex" begin
    
    # Evaluation: plain and offset (including chained) indices must agree with direct
    # `firstindex`/`lastindex` arithmetic for vectors of length 1 to 5.
    for N in 1:5
        collection = ones(N)
        @test FunctionalIndex{:nothing}(firstindex)(collection) === firstindex(collection)
        @test FunctionalIndex{:nothing}(lastindex)(collection) === lastindex(collection)
        @test (FunctionalIndex{:nothing}(firstindex) + 1)(collection) === firstindex(collection) + 1
        @test (FunctionalIndex{:nothing}(lastindex) - 1)(collection) === lastindex(collection) - 1
        @test (FunctionalIndex{:nothing}(firstindex) + 1 - 2 + 3)(collection) === firstindex(collection) + 1 - 2 + 3
        @test (FunctionalIndex{:nothing}(lastindex) - 1 + 2 - 3)(collection) === lastindex(collection) - 1 + 2 - 3
    end
    
    # Printing: every nesting level (each offset) is wrapped in its own parentheses.
    @test repr(FunctionalIndex{:begin}(firstindex)) === "(begin)"
    @test repr(FunctionalIndex{:begin}(firstindex) + 1) === "((begin) + 1)"
    @test repr(FunctionalIndex{:begin}(firstindex) - 1) === "((begin) - 1)"
    @test repr(FunctionalIndex{:begin}(firstindex) - 1 + 1) === "(((begin) - 1) + 1)"
    
    @test repr(FunctionalIndex{:end}(lastindex)) === "(end)"
    @test repr(FunctionalIndex{:end}(lastindex) + 1) === "((end) + 1)"
    @test repr(FunctionalIndex{:end}(lastindex) - 1) === "((end) - 1)"
    @test repr(FunctionalIndex{:end}(lastindex) - 1 + 1) === "(((end) - 1) + 1)"
    
    # Offset indices are stored in type parameters (`Val`s) by the DSL, so they must stay `isbits`.
    @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1)))
    @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1)))
    @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) + 1 + 1)))
    @test isbitstype(typeof((FunctionalIndex{:begin}(firstindex) - 1 + 1)))
end

[0m[1mTest Summary:   | [22m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
FunctionalIndex | [32m  42  [39m[36m   42[39m


Test.DefaultTestSet("FunctionalIndex", Any[], 42, false, false)

In [4]:
# Value printed for one range endpoint: integers as-is, functional indices via their `repr` (e.g. "(begin)").
function __range_io_helper(value::Integer)
    return value
end

function __range_io_helper(value::FunctionalIndex)
    return repr(value)
end

"""
    CombinedRange(from, to)

Index range written as `from:to` inside a `q(x[from:to])` entry of the constraints DSL.
Endpoints may be integers or `FunctionalIndex`es.
"""
struct CombinedRange{L, R}
    from :: L
    to   :: R
end

# Printed as `from:to`
function Base.show(io::IO, range::CombinedRange)
    left  = __range_io_helper(range.from)
    right = __range_io_helper(range.to)
    print(io, left, ":", right)
end

# Expression `CombinedRange(l, r)` emitted by the DSL for an `l:r` index
make_combined_range(l, r) = Expr(:call, :CombinedRange, l, r)

"""
    SplittedRange(from, to)

Index range produced by a `q(x[from])..q(x[to])` split in the constraints DSL.
Endpoints may be integers or `FunctionalIndex`es.
"""
struct SplittedRange{L, R}
    from :: L
    to   :: R
end

# Printed as `from..to`
function Base.show(io::IO, range::SplittedRange)
    left  = __range_io_helper(range.from)
    right = __range_io_helper(range.to)
    print(io, left, "..", right)
end

# Expression `SplittedRange(l, r)`
make_splitted_range(l, r) = Expr(:call, :SplittedRange, l, r)

# Expression `factorisation_split(l, r)`, which the DSL substitutes for `l .. r`
make_factorisation_split(l, r) = Expr(:call, :factorisation_split, l, r)

make_factorisation_split (generic function with 1 method)

In [5]:
"""
    NotDefinedYet{S}

Placeholder value for a factorisation constraint variable whose specification has not been assigned yet;
`S` is the constraint's printable name, used in error messages.
"""
struct NotDefinedYet{S}
end

In [6]:
const ConstraintsSpecificationPreallocatedDefaultSize = 64

"""
    ConstraintsSpecificationPreallocated()

Reusable scratch buffers for computing local factorisations, each created with
`ConstraintsSpecificationPreallocatedDefaultSize` elements:

- `clusters_template`: flags, initially all `true`
- `clusters_usage`: flags, initially all `false`
- `clusters_set`: set of cluster tuples, initially empty
- `cluster_indices`: integer scratch space, initially uninitialised
"""
struct ConstraintsSpecificationPreallocated
    clusters_template :: BitVector
    clusters_usage    :: BitVector
    clusters_set      :: Set{Tuple}
    cluster_indices   :: Vector{Int}

    function ConstraintsSpecificationPreallocated()
        n = ConstraintsSpecificationPreallocatedDefaultSize
        return new(trues(n), falses(n), Set{Tuple}(), Vector{Int}(undef, n))
    end
end

"""
    __reset_preallocated!(preallocated, size)

Prepare the scratch buffers for `size` variables. Flag buffers are grown to at least `size^2` elements and
`cluster_indices` to at least `size` elements (they are never shrunk). Then every template flag is set to `true`,
every usage flag to `false`, and the cluster set is emptied. Returns the emptied cluster set.
"""
function __reset_preallocated!(preallocated::ConstraintsSpecificationPreallocated, size::Int)
    capacity = size * size
    template = preallocated.clusters_template
    usage    = preallocated.clusters_usage

    if capacity > length(template)
        resize!(template, capacity)
        # `usage` only needs `size` elements, but growing it together with `template` avoids a separate check
        resize!(usage, capacity)
    end

    indices = preallocated.cluster_indices
    size > length(indices) && resize!(indices, size)

    fill!(template, true)
    fill!(usage, false)

    return empty!(preallocated.clusters_set)
end

"""
    ConstraintsSpecification(factorisation, form)

Result of a constraints function generated by `@constraints`. It holds the `factorisation` constraints (as
generated, a tuple of factorisation specifications) and the `form` constraints, together with fresh scratch
buffers used when local factorisations are computed.
"""
struct ConstraintsSpecification{F, M}
    factorisation :: F
    form          :: M
    preallocated  :: ConstraintsSpecificationPreallocated
end

function ConstraintsSpecification(factorisation::F, form::M) where { F, M }
    buffers = ConstraintsSpecificationPreallocated()
    return ConstraintsSpecification{F, M}(factorisation, form, buffers)
end

# Forward buffer resets to the specification's scratch buffers
function __reset_preallocated!(specification::ConstraintsSpecification, size::Int)
    return __reset_preallocated!(specification.preallocated, size)
end

# Multi-line listing: the form constraints, then one tab-indented line per factorisation constraint.
function Base.show(io::IO, specification::ConstraintsSpecification) 
    print(io, "Constraints:\n\tform: $(specification.form)\n")
    print(io, "\tfactorisation\n")
    for constraint in specification.factorisation
        print(io, "\t\t", constraint, "\n")
    end
end

# Expression `ConstraintsSpecification(factorisation, form)` returned at the end of a generated constraints function
function make_constraints_specification(factorisation, form)
    return Expr(:call, :ConstraintsSpecification, factorisation, form)
end

make_constraints_specification (generic function with 1 method)

In [7]:
"""
    FactorisationConstraintsEntry{N, I}

One `q(...)` factor on the right-hand side of a factorisation constraint, stored entirely in type parameters:
`N` is a tuple of variable names and `I` the tuple of their indices (`nothing` for an unindexed variable).
"""
struct FactorisationConstraintsEntry{N, I} end

getnames(entry::FactorisationConstraintsEntry{N, I}) where { N, I } = N
getindices(entry::FactorisationConstraintsEntry{N, I}) where { N, I } = I

# Lazily pair every name with its index
getpairs(entry::FactorisationConstraintsEntry) = zip(getnames(entry), getindices(entry))

FactorisationConstraintsEntry(::Val{N}, ::Val{I}) where { N, I } = FactorisationConstraintsEntry{N, I}()

# Expression `FactorisationConstraintsEntry(N, T)`
make_factorisation_constraint_entry(N, T) = Expr(:call, :FactorisationConstraintsEntry, N, T)

# Text for one `(name, index)` pair of a factorisation entry: `x`, `x[1]`, `x[(begin)]`, `x[1:3]`, `x[1..3]`.
__io_entry_pair(pair::Tuple) = __io_entry_pair(first(pair), last(pair))

__io_entry_pair(symbol::Symbol, ::Nothing) = string(symbol)

__io_entry_pair(symbol::Symbol, index::Integer) = string(symbol, "[", index, "]")

__io_entry_pair(symbol::Symbol, index::FunctionalIndex) = string(symbol, "[", repr(index), "]")

__io_entry_pair(symbol::Symbol, range::Union{CombinedRange, SplittedRange}) = string(symbol, "[", range, "]")

# Printed as `q(x, y[1], ...)`
function Base.show(io::IO, entry::FactorisationConstraintsEntry)
    print(io, "q(")
    join(io, map(__io_entry_pair, getpairs(entry)), ", ")
    print(io, ")")
end

In [8]:
"""
    FactorisationConstraintsSpecification{N, E}

One factorisation constraint `q(names...) = rhs`, stored in type parameters: `N` holds the left-hand side names
and `E` the right-hand side entries.
"""
struct FactorisationConstraintsSpecification{N, E} end

getnames(specification::FactorisationConstraintsSpecification{N, E}) where { N, E } = N
getentries(specification::FactorisationConstraintsSpecification{N, E}) where { N, E } = E

# Any factor that may appear in a product on the right-hand side of a factorisation constraint
const __FactorisationTerm = Union{FactorisationConstraintsSpecification, FactorisationConstraintsEntry}

# Multiplying factors (or tuples of factors) concatenates them into one flat tuple,
# e.g. `(a,) * (b,) == (a, b)` and `a * b == (a, b)`.
Base.:(*)(left::Tuple{Vararg{__FactorisationTerm}}, right::Tuple{Vararg{__FactorisationTerm}}) = (left..., right...)
Base.:(*)(left::__FactorisationTerm, right::Tuple{Vararg{__FactorisationTerm}}) = (left, right...)
Base.:(*)(left::Tuple{Vararg{__FactorisationTerm}}, right::__FactorisationTerm) = (left..., right)
Base.:(*)(left::__FactorisationTerm, right::__FactorisationTerm) = (left, right)

# Using a constraint before it has been defined is an error; operands are reported in the order they were written.
Base.:(*)(::NotDefinedYet{S}, something::Any) where S = error("Cannot multiply $S and $something. $S has not been defined yet.")
Base.:(*)(something::Any, ::NotDefinedYet{S}) where S = error("Cannot multiply $something and $S. $S has not been defined yet.")
# Resolves the ambiguity between the two methods above
Base.:(*)(::NotDefinedYet{S1}, ::NotDefinedYet{S2}) where { S1, S2 } = error("Cannot multiply $S1 and $S2. Both $S1 and $S2 have not been defined yet.")

# Printed as `q(x, y) = q(x)q(y)`; in a `:compact` IO context only the left-hand side `q(x, y)` is printed.
function Base.show(io::IO, factorisation::FactorisationConstraintsSpecification{Names}) where Names
    print(io, "q(")
    join(io, getnames(factorisation), ", ")
    print(io, ")")

    if get(io, :compact, false)
        return nothing
    end

    print(io, " = ")
    # Nested factors are printed compactly
    compact_io = IOContext(io, :compact => true)
    for factor in getentries(factorisation)
        print(compact_io, factor)
    end
end

# Build a specification from `Val`-wrapped names and entries
function FactorisationConstraintsSpecification(::Val{N}, ::Val{E}) where { N, E }
    return FactorisationConstraintsSpecification{N, E}()
end

# A `nothing` right-hand side cannot form a specification
function FactorisationConstraintsSpecification(::Val{N}, ::Val{nothing}) where { N }
    error("Cannot create q(", join(N, ","), ") factorisation constraints specification")
end

# Expression `FactorisationConstraintsSpecification(N, E)`
make_factorisation_constraint(N, E) = Expr(:call, :FactorisationConstraintsSpecification, N, E)

make_factorisation_constraint (generic function with 1 method)

In [9]:
using Unrolled

In [10]:
# Split related functions

# Concatenate two tuples of factorisation entries into one flat tuple
function Base.:(*)(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}})
    return tuple(left..., right...)
end

# Only these combinations are allowed to be merged
# Merge the indices on both sides of `..` into a `SplittedRange`.
# Only `Int`s and `FunctionalIndex`es (in any combination) can be merged; anything else is an error.
__factorisation_split_merge_range(a::Union{Int, FunctionalIndex}, b::Union{Int, FunctionalIndex}) = SplittedRange(a, b)
__factorisation_split_merge_range(a::Any, b::Any) = error("Cannot merge $(a) and $(b) indexes in `factorisation_split`")

"""
    factorisation_split(left, right)

Implement `left .. right` on the right-hand side of a factorisation constraint.

`left` and `right` are tuples of `FactorisationConstraintsEntry`s. The last entry of `left` and the first entry of
`right` must have the same variable names. Their indices are merged pairwise into `SplittedRange`s, which must be
identical for all names. The merged entry replaces both entries in the returned tuple. Throws an `ErrorException`
otherwise.
"""
function factorisation_split(left::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}}, right::Tuple{Vararg{T where T <: FactorisationConstraintsEntry}})
    left_last   = last(left)
    right_first = first(right)
    (getnames(left_last) === getnames(right_first)) || error("Cannot split $(left_last) and $(right_first).")
    lindices = getindices(left_last)
    rindices = getindices(right_first)
    split_merged = unrolled_map(__factorisation_split_merge_range, lindices, rindices)
    
    # All names within one split must share exactly the same range
    first_split = first(split_merged)
    unrolled_all(e -> e === first_split, split_merged) || error("Inconsistent indices within factorisation split. Check $(split_merged) indices for $(getnames(left_last)) variables.")
    
    merged_entry = FactorisationConstraintsEntry(Val(getnames(left_last)), Val(split_merged))
    # `Base.front`/`Base.tail` keep tuple lengths inferable, unlike range indexing (`left[1:end - 1]`)
    return (Base.front(left)..., merged_entry, Base.tail(right)...)
end

factorisation_split (generic function with 1 method)

In [11]:
# Marker type for the ReactiveMP code generation backend, passed to `generate_constraints_expression`
struct ReactiveMPBackend end

function __get_current_backend()
    return ReactiveMPBackend()
end

"""
    @constraints function name(args...) ... end
    @constraints begin ... end

Expand a constraints specification written in the constraints DSL using the current backend.
"""
macro constraints(constraints_specification)
    backend = __get_current_backend()
    return generate_constraints_expression(backend, constraints_specification)
end

@constraints (macro with 1 method)

In [12]:
# Syntax predicates used during macro expansion. `isexpr` is `true` only for `Expr` nodes, so symbols and
# literals are rejected instead of throwing a MethodError (the previous fallback was a zero-argument typo).
isexpr(expr::Expr) = true
isexpr(something)  = false

# `true` if `something` is an `Expr` with the given `head`
ishead(something, head) = isexpr(something) && something.head === head

isblock(something) = ishead(something, :block)
isref(something)   = ishead(something, :ref)

isref (generic function with 1 method)

In [13]:
"""
    LHSMeta(name, hash, varname)

Bookkeeping for one left-hand side `q(...)` of the constraints DSL: its printable `name` (e.g. `"q(x, y)"`),
the `hash` of its expression, and the generated variable name `varname` that holds its specification.
"""
struct LHSMeta
    name    :: String
    hash    :: UInt
    varname :: Symbol
end

In [14]:
# TODO: check for intersections of ranges during an actual execution

In [15]:
"""
    generate_constraints_expression(backend, constraints_specification)

Expand a constraints specification written in the constraints DSL into Julia code.

`constraints_specification` must be either a full function definition (`function name(args...; kwargs...) ... end`)
or a `begin ... end` block. A block is wrapped into a function with a generated name, and the returned expression
defines that function and immediately calls it. For a function definition, the returned (escaped) expression
defines a function with the same name and arguments whose body:

- rewrites every `q(names...) = rhs` assignment into the construction of a `FactorisationConstraintsSpecification`,
  appended to a tuple of factorisation constraints (defining the same left-hand side twice raises an error at
  run time);
- rewrites `a .. b` inside a right-hand side into `factorisation_split(a, b)`;
- rewrites every `q(...)` call into a one-element tuple holding a `FactorisationConstraintsEntry`, where `begin`/`end`
  inside indices become `FunctionalIndex`es and `a:b` becomes a `CombinedRange`;
- returns a `ConstraintsSpecification` built from the collected factorisation constraints and an empty tuple of
  form constraints.

Malformed specifications raise an `ErrorException` during expansion. `backend` is only forwarded to the recursive
call for the block form.
"""
function generate_constraints_expression(backend, constraints_specification)

    # A bare `begin ... end` block is wrapped into a zero-argument function that is called immediately
    if isblock(constraints_specification)
        generatedfname = gensym(:constraints)
        generatedfbody = :(function $(generatedfname)() $constraints_specification end)
        return :($(generate_constraints_expression(backend, generatedfbody))())
    end

    @capture(constraints_specification, (function cs_name_(cs_args__; cs_kwargs__) cs_body_ end) | (function cs_name_(cs_args__) cs_body_ end)) || 
        error("Constraints specification language requires full function definition")
    
    cs_args   = cs_args === nothing ? [] : cs_args
    cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
    
    lhs_dict = Dict{UInt, LHSMeta}()
    
    # Tuple of form constraints; initialised to `()` and currently never extended
    form_constraints_symbol      = gensym(:form_constraint)
    form_constraints_symbol_init = :($form_constraints_symbol = ())
    
    # Tuple of factorisation constraints; re-bound with one more element after every `q(...) = ...` definition
    factorisation_constraints_symbol      = gensym(:factorisation_constraint)
    factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
    
    # First pass: validate every `q(names...) = rhs` definition, record a unique variable name per LHS
    # (keyed by the hash of the LHS expression) and replace the definition with code that builds its specification.
    # The same LHS may appear several times (e.g. in `if` branches); an actual redefinition is rejected at run time.
    cs_body = postwalk(cs_body) do expression
        if @capture(expression, lhs_ = rhs_) && @capture(lhs, q(names__))
            
            # Sanity check: the LHS must list only plain variable names
            (!isempty(names) && all(name -> name isa Symbol, names)) || 
                error("""Error in factorisation constraints specification $(lhs) = ...\nLeft hand side of the equality expression should have only variable identifiers.""")
            
            # Replace '..' in the RHS expression with `factorisation_split`
            rhs = postwalk(rhs) do rexpr
                if @capture(rexpr, a_ .. b_)
                    return make_factorisation_split(a, b)
                end
                return rexpr
            end
            
            lhs_names = Set{Symbol}(names)
            rhs_names = Set{Symbol}()
            
            # The LHS and RHS must use exactly the same set of names, and every RHS index
            # must be either a plain Symbol or an indexing expression
            rhs = postwalk(MacroTools.prettify(rhs, alias = false)) do entry
                if @capture(entry, q(indices__))
                    for index in indices
                        if index isa Symbol
                            (index ∉ rhs_names) || error("RHS of the $(expression) expression used $(index) without indexing twice, which is not allowed. Try to decompose factorisation constraint expression into several subexpression.")
                            push!(rhs_names, index)
                            (index ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(index) variable, but is used in RHS.")
                        elseif isref(index)
                            push!(rhs_names, first(index.args))
                            (first(index.args) ∉ lhs_names) && error("LHS of the $(expression) expression does not have $(first(index.args)) variable, but is used in RHS.")
                        else
                           error("Cannot parse expression $(index) in the RHS $(rhs) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                        end
                    end
                end
                return entry
            end
            
            (lhs_names == rhs_names) || error("LHS and RHS of the $(expression) expression has different set of variables.")
            
            lhs_hash = hash(lhs)
            lhs_meta = get!(lhs_dict, lhs_hash) do
                lhs_string = string("q(", join(names, ", "), ")")
                return LHSMeta(lhs_string, lhs_hash, gensym(lhs_string))
            end
            
            lhs_name    = lhs_meta.name
            lhs_varname = lhs_meta.varname
            
            new_factorisation_specification = make_factorisation_constraint(:(Val(($(map(QuoteNode, names)...),))), :(Val($(rhs))))
            
            result = quote 
                ($(lhs_varname) isa NotDefinedYet) || error("Factorisation constraints specification $($lhs_name) = ... has been redefined.")
                $(lhs_varname) = $(new_factorisation_specification)
                $factorisation_constraints_symbol = ($factorisation_constraints_symbol..., $(lhs_varname))
            end
            
            return result
        end
        return expression
    end
    
    # Initialise every LHS variable with a `NotDefinedYet` placeholder
    cs_lhs_init_block = map(collect(lhs_dict)) do pair
        lhs_meta = last(pair)
        lhs_name = lhs_meta.name
        lhs_varname = lhs_meta.varname
        lhs_symbol = Symbol(lhs_name)
        return quote 
            $(lhs_varname) = NotDefinedYet{$(QuoteNode(lhs_symbol))}()
        end
    end
    
    # Second pass: turn every remaining `q(...)` call into a one-element tuple with a `FactorisationConstraintsEntry`
    cs_body = prewalk(cs_body) do expression
        if @capture(expression, q(args__))
            rhs_prod_names = Symbol[]
            rhs_prod_entries_args = map(args) do arg
                if arg isa Symbol
                    push!(rhs_prod_names, arg)
                    return :(nothing)
                elseif isref(arg)
                    (length(arg.args) === 2) || error("Indexing expression $(expression) is too difficult to parse and is not supported (yet?).")
                    push!(rhs_prod_names, first(arg.args))

                    index = last(arg.args)

                    # Replace `begin` and `end` with functional indices over `firstindex` and `lastindex`
                    index = postwalk(index) do iexpr
                        if iexpr isa Symbol && iexpr === :begin
                            return make_functional_index(:begin, :firstindex)
                        elseif iexpr isa Symbol && iexpr === :end
                            return make_functional_index(:end, :lastindex)
                        else
                            return iexpr
                        end
                    end

                    if @capture(index, a_:b_)
                        return make_combined_range(a, b)
                    else
                        return index
                    end
                else
                    # `arg` is neither a name nor an indexing expression (e.g. a literal)
                    error("Cannot parse expression $(arg) in the RHS $(expression) expression. Index expression should be either a single variable symbol or an indexing expression.") 
                end
            end

            entry = make_factorisation_constraint_entry(:(Val(($(map(QuoteNode, rhs_prod_names)...), ))), :(Val(($(rhs_prod_entries_args...), ))))

            return :(($entry, ))
        end
        return expression
    end
    
    return_specification = make_constraints_specification(factorisation_constraints_symbol, form_constraints_symbol)
    
    res = quote
         function $cs_name($(cs_args...); $(cs_kwargs...))
            $(form_constraints_symbol_init)
            $(factorisation_constraints_symbol_init)
            $(cs_lhs_init_block...)
            $(cs_body)
            $(return_specification)
        end 
    end
    
    return esc(res)
end

generate_constraints_expression (generic function with 1 method)

In [16]:
# node with (x, t, y, w) pure indices

In [17]:
# Didn't work out
# using DataStructures
# struct FirstForwardOrdering <: Base.Ordering end
# Base.lt(::FirstForwardOrdering, a, b) = Base.Order.lt(Base.Order.Forward, first(a), first(b))

In [18]:
# struct ConstraintsSpecificationPreallocated
#     clusters_template :: BitVector
#     clusters_set      :: Set{Tuple}
#     cluster_indices   :: Vector{Int}
# end

In [78]:
"""
    ClusterIntersectionError(expression, varrefs, clusters, constraints)

Thrown by `__throw_intersection_error` when two of the computed local clusters share a variable.
Subtypes `Exception` so it can be handled like any other Julia error.
"""
struct ClusterIntersectionError <: Exception
    expression
    varrefs
    clusters
    constraints
end

__throw_intersection_error(expression, varrefs, clusters, constraints) = throw(ClusterIntersectionError(expression, varrefs, clusters, constraints))

# Text for one `(name, index)` variable reference in error messages: `x` when unindexed, `x[i]` otherwise
__cluster_entry_string(ref::Tuple)           = __cluster_entry_string(first(ref), last(ref))
__cluster_entry_string(var, index)           = string(var, "[", index, "]")
__cluster_entry_string(var, index::Nothing)  = string(var)

# Explain which local constraint has intersecting clusters, listing each cluster as a `q(...)` factor
function Base.showerror(io::IO, err::ClusterIntersectionError)
    print(io, "Cluster intersection error in the expression `$(err.expression)`.\n")
    print(io, "Based on factorisation constraints the resulting local constraint ")
    print(io, "q(")
    join(io, map(first, err.varrefs), ", ")
    print(io, ") = ")
    for cluster in err.clusters
        print(io, "q(")
        entries = map(clusterindex -> __cluster_entry_string(err.varrefs[clusterindex]), cluster)
        join(io, entries, ", ")
        print(io, ")")
    end
    print(io, " has cluster intersections, which is disallowed by default.")
    print(io, "\nTechnical info: clusters = ", err.clusters)
    print(io, "\n", err.constraints)
end

In [85]:
"""
    create_factorisation(expr::Expr, variables, constraints)

Compute the local factorisation clusters of `variables` under the factorisation constraints in `constraints`.

Each variable is converted to a `(name, index)` reference with `get_factorisation_reference`. Every variable owns a
row of `length(variables)` flags in `constraints.preallocated.clusters_template`, all initially `true`. Every
factorisation specification whose left-hand side names the variable clears the flags of the variables it factorises
away. Each row becomes a tuple of the remaining variable positions. Duplicate tuples are dropped, and the clusters
are returned as a tuple sorted by their first position.

`expr` is only used in error messages. Throws `ClusterIntersectionError` if two clusters share a variable.
NOTE: prototype — indices, ranges and split ranges inside entries are not supported yet.
"""
function create_factorisation(expr::Expr, variables, constraints) 
    N = length(variables)
    # `(name, index)` reference of every variable, and the names alone
    var_refs       = map(get_factorisation_reference, variables)
    var_refs_names = map(first, var_refs)
    
    # `factorisation` is a tuple of `FactorisationConstraintsSpecification`s;
    # each one has the names of its LHS and the entries of its RHS
    factorisation = constraints.factorisation
    
    preallocated = constraints.preallocated
    
    # Row `i` of the template occupies flags `(i - 1) * N + 1` to `i * N`
    __reset_preallocated!(preallocated, N)
    
    clusters_template = preallocated.clusters_template
    clusters_usage    = preallocated.clusters_usage
    clusters_set      = preallocated.clusters_set
    cluster_indices   = preallocated.cluster_indices
    
    # Apply every factorisation specification to the variable `symbol` (with reference index `index`), clearing
    # flags in its template row, which starts right after position `shift`
    function __process_factorisation_entry!(symbol::Symbol, index, shift::Int)
        
        function __filter_template!(spec::FactorisationConstraintsSpecification, factorisation_entries::Tuple)
            # Applies `spec` with RHS = `factorisation_entries` and checks that `symbol` appears in at most one entry.
            # This rejects e.g. q(x) = q(x[1])q(x[1]), which is valid syntax but not allowed at runtime
            found_once = false
            for entry in factorisation_entries
                is_found = __filter_template!(Val(true), entry)
                if is_found && found_once
                    error("Found variable $(symbol) twice in the factorisation specification $(spec).")
                end
                found_once = found_once | is_found
            end
            return found_once
        end
        
        # The first argument is a compile-time flag: `Val(false)` applies `spec` only if `symbol` is among its LHS names
        __filter_template!(force::Val{true}, spec::FactorisationConstraintsSpecification)  = __filter_template!(spec, getentries(spec))
        __filter_template!(force::Val{false}, spec::FactorisationConstraintsSpecification) = symbol ∈ getnames(spec) ? __filter_template!(spec, getentries(spec)) : false
        
        # Returns `true` if `symbol` is one of the entry's names; otherwise clears the flags of the entry's variables.
        # Prototype logic: indices, ranges and split ranges are not supported yet
        function __filter_template!(force::Val{true}, csentry::FactorisationConstraintsEntry)
            entry_names   = getnames(csentry)
            entry_indices = getindices(csentry)
            entry_pairs   = getpairs(csentry)
            
            if symbol ∉ entry_names
                unrolled_foreach(entry_pairs) do qpair
                    
                    name_position = if last(qpair) === nothing 
                        # The entry has no index: exclude the variable based purely on its name
                        findnext(==(first(qpair)), var_refs_names, 1)
                    else
                        # TODO: this check does not support functional indices (and regular indices too)
                        # `var_refs` is a tuple of `(varname, index)` elements
                        findnext(==(qpair), var_refs, 1)
                    end
                        
                    if name_position !== nothing
                        clusters_template[ shift + name_position ] = false
                    end
                end
                return false
            else
                symbol_position = findnext(==(symbol), entry_names, 1)::Integer # cannot be `nothing` here
                symbol_index    = entry_indices[symbol_position] 
                
                # Case 1: the variable cannot be indexed (index = nothing), but the factorisation assumes indexing,
                # e.g. q(x) = q(x[begin])..q(x[end])
                if index === nothing && symbol_index !== nothing
                    error("Factorisation specification entry $(csentry) assumes variable $(symbol) can be indexed, however variable $(symbol) created in the model cannot be indexed.")
                end
                
                # Case 2: the variable is indexed (index = some integer) and the factorisation acts on the whole set,
                # e.g. q(x, y) = q(x)q(y)
                if index isa Integer && symbol_index === nothing
                     return true
                end
                
                return true
            end
        end
        
        unrolled_foreach(factorisation) do spec
            __filter_template!(Val(false), spec)
        end
        
    end
    
    shift = 0
    for varref in var_refs
        __process_factorisation_entry!(first(varref), last(varref), shift)
        shift += N
    end
    
    # Convert every template row into a tuple of the variable positions still flagged in it
    @inbounds for row in 1:N
        range_left  = (row - 1) * N + 1
        range_right = range_left + N - 1
        
        ki = 0
        @inbounds for (column, flag) in enumerate(view(clusters_template, range_left:range_right))
            if flag
                ki += 1
                cluster_indices[ki] = column
            end
        end
        
        output = Tuple(view(cluster_indices, 1:ki))
        
        push!(clusters_set, output)
    end
    
    sorted_clusters = sort!(collect(clusters_set); by = first, alg = QuickSort)
    
    # Every variable position may belong to at most one cluster
    for cluster in sorted_clusters
        for varposition in cluster
            if clusters_usage[varposition] === true
                __throw_intersection_error(expr, var_refs, sorted_clusters, constraints)
            end 
            clusters_usage[varposition] = true
        end
    end
    
    return Tuple(sorted_clusters)
end

create_factorisation (generic function with 1 method)

In [80]:
# Define the constraints function `test` with the `@constraints` DSL; each `q(...) = ...` line declares one
# factorisation constraint. The macro call evaluates to the generated function itself.
cs = @constraints function test()
    q(x, y) = q(x)q(y)
    q(x, y, t, r) = q(x, y)q(t)q(r)
    q(x, w) = q(x)q(w)
    q(y, w) = q(y)q(w)
    # q(x) = q(x[begin])q(x[begin + 1])..q(x[end])
end

# Calling the generated function returns the `ConstraintsSpecification`
cs = test()

Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t, r) = q(x, y)q(t)q(r)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [81]:
"""
    DummyVariable(name, index, proxy)

Minimal stand-in for a model variable, used to exercise `create_factorisation`. `index` is `nothing` for an
unindexed variable. `proxy` is either `nothing` or a tuple of the variables this one stands in for.
"""
struct DummyVariable
    name   :: Symbol
    index
    proxy
end

# `(name, index)` reference used by factorisation constraints. A variable with exactly one proxy resolves
# (recursively) to that proxy's reference.
get_factorisation_reference(variable::DummyVariable) = get_factorisation_reference(variable, variable.proxy, variable.name, variable.index)

get_factorisation_reference(::DummyVariable, ::Nothing, name::Symbol, index) = (name, index)
get_factorisation_reference(::DummyVariable, proxies::Tuple{T}, name, index) where T = get_factorisation_reference(first(proxies))
get_factorisation_reference(::DummyVariable, proxies::Tuple, name, index) = error(
    "Variable $(name) refers to $(length(proxies)) proxy variables, but factorisation constraints support at most one. " *
    "If a variable has been auto-created and refers to several different variables in the model, explicitly create " *
    "a new named variable in the model and use it in the constraints specification instead."
)

# struct DummyVariable
#     name :: Symbol
# end

# getnames(variable::DummyVariable) = (variable.name, )

get_factorisation_reference (generic function with 4 methods)

In [82]:
# Test variables: `xvar` is the indexed reference `x[1]`; `yvar` is an auto-created variable `tmpy` whose single
# proxy is `y`, so its factorisation reference resolves to `(:y, nothing)`.
xvar = DummyVariable(:x, 1, nothing)
wvar = DummyVariable(:w, nothing, nothing)
yvar = DummyVariable(:tmpy, nothing, (DummyVariable(:y, nothing, nothing), ))
# yvar = DummyVariable(:tmpy, nothing, nothing)
tvar = DummyVariable(:t, nothing, nothing)

# `wvar` is currently left out of the variable set
# vars = (xvar, wvar, yvar, tvar)
vars = (xvar, yvar, tvar)

(DummyVariable(:x, 1, nothing), DummyVariable(:tmpy, nothing, (DummyVariable(:y, nothing, nothing),)), DummyVariable(:t, nothing, nothing))

In [83]:
# Time the generated constraints function twice (the first call includes compilation), then keep the result
@time test()
@time test()

cs = test()

  0.000010 seconds (9 allocations: 1.125 KiB)
  0.000006 seconds (9 allocations: 1.125 KiB)


Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t, r) = q(x, y)q(t)q(r)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [84]:
# Time the local factorisation of `vars` twice (the first call includes compilation);
# the expression `:(1 + 1)` is only used in error messages
@time create_factorisation(:(1 + 1), vars, cs)
@time create_factorisation(:(1 + 1), vars, cs)

LoadError: Cluster intersection error in the expression `1 + 1`.
Based on factorisation constraints the resulting local constraint q(x, y, t) = q(x[1], y)q(x[1], t)q(x[1]) has cluster intersections, which is disallowed by default.
Technical info: clusters = Tuple[(1, 2), (1, 3), (1,)]
Constraints:
	form: ()
	factorisation
		q(x, y) = q(x)q(y)
		q(x, y, t, r) = q(x, y)q(t)q(r)
		q(x, w) = q(x)q(w)
		q(y, w) = q(y)q(w)


In [26]:
# Re-run the timing of the local factorisation of `vars` (the first call may include compilation)
@time create_factorisation(:(1 + 1), vars, cs)
@time create_factorisation(:(1 + 1), vars, cs)

  0.000041 seconds (39 allocations: 1.250 KiB)
  0.000023 seconds (39 allocations: 1.250 KiB)


((1,), (2,), (3,))

In [27]:
# Benchmark with `$`-interpolated globals so that only the call itself is measured
@btime create_factorisation(:(1 + 1), $vars, $cs)

  2.463 μs (33 allocations: 1.06 KiB)


((1,), (2,), (3,))

In [28]:
# Benchmark helper: build the constraints once, then compute the local factorisation `n` times for four
# unindexed variables `x`, `w`, `y` and `t`. Uses the current 3-argument
# `create_factorisation(expr, variables, constraints)` API; the old 4-argument call no longer exists.
function foo(n)
    variables = (
        DummyVariable(:x, nothing, nothing),
        DummyVariable(:w, nothing, nothing),
        DummyVariable(:y, nothing, nothing),
        DummyVariable(:t, nothing, nothing),
    )
    cs = test()
    for _ in 1:n
        create_factorisation(:(1 + 1), variables, cs)
    end
end

foo (generic function with 1 method)

In [29]:
# @btime foo(1000)