From 48074efee3f1ac4ed2860a3fb17b3a6a69e8f645 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Wed, 2 Mar 2022 13:43:08 -0500 Subject: [PATCH] refactor sum product tree --- docs/src/performancetips.md | 2 +- docs/src/ref.md | 1 + src/GraphTensorNetworks.jl | 3 +- src/arithematics.jl | 117 +++++++++++++++++++++--------------- 4 files changed, 71 insertions(+), 52 deletions(-) diff --git a/docs/src/performancetips.md b/docs/src/performancetips.md index 996225b6..21d54ea1 100644 --- a/docs/src/performancetips.md +++ b/docs/src/performancetips.md @@ -89,7 +89,7 @@ The GEMM routine can speed up the computation on CPU for one order, with multi-t Benchmark shows the performance of `TropicalGEMM` is close to the theoretical optimal value. ## Sum product representation for configurations -[`TreeConfigEnumerator`](@ref) can save a lot memory for you to store exponential number of configurations in polynomial space. +[`TreeConfigEnumerator`](@ref) (an alias of [`SumProductTree`](@ref) with [`StaticElementVector`](@ref) as its data type) can save a lot memory for you to store exponential number of configurations in polynomial space. It is a sum-product expression tree to store [`ConfigEnumerator`](@ref) in a lazy style, configurations can be extracted by depth first searching the tree with the `Base.collect` method. Although it is space efficient, it is in general not easy to extract information from it. This tree structure supports directed sampling so that one can get some statistic properties from it with an intermediate effort. diff --git a/docs/src/ref.md b/docs/src/ref.md index 5ea23d24..fde2eb04 100644 --- a/docs/src/ref.md +++ b/docs/src/ref.md @@ -86,6 +86,7 @@ Polynomials.Polynomial TruncatedPoly Max2Poly ConfigEnumerator +SumProductTree TreeConfigEnumerator ConfigSampler ``` diff --git a/src/GraphTensorNetworks.jl b/src/GraphTensorNetworks.jl index ff2a82c4..84429e6d 100644 --- a/src/GraphTensorNetworks.jl +++ b/src/GraphTensorNetworks.jl @@ -17,7 +17,8 @@ export estimate_memory # Algebras export StaticBitVector, StaticElementVector, @bv_str export is_commutative_semiring -export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator +export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod +export ConfigEnumerator, onehotv, ConfigSampler, SumProductTree, TreeConfigEnumerator export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32, ExtendedTropical export generate_samples diff --git a/src/arithematics.jl b/src/arithematics.jl index 9090cdc4..8b8975cd 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -408,30 +408,27 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C}) Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C}) -# tree config enumerator -# it must be mutable, otherwise the `IdDict` trick for computing the length does not work. """ - TreeConfigEnumerator{N,S,C} <: AbstractSetNumber + SumProductTree{ET} <: AbstractSetNumber Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network and is often more memory efficient than putting the configurations in a vector. One can use [`generate_samples`](@ref) to sample configurations from this tree structure efficiently. -`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}. Fields ----------------------- -* `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`. +* `tag` is one of `ZERO`, `ONE`, `LEAF`, `SUM`, `PROD`. * `data` is the element stored in a `LEAF` node. * `left` and `right` are two operands of a `SUM` or `PROD` node. Example ------------------------ ```jldoctest; setup=:(using GraphTensorNetworks) -julia> s = TreeConfigEnumerator(bv"00111") +julia> s = SumProductTree(bv"00111") 00111 -julia> q = TreeConfigEnumerator(bv"10000") +julia> q = SumProductTree(bv"10000") 10000 @@ -469,36 +466,55 @@ julia> one(s) ``` """ -mutable struct TreeConfigEnumerator{N,S,C} <: AbstractSetNumber +mutable struct SumProductTree{ET} <: AbstractSetNumber tag::TreeTag - data::StaticElementVector{N,S,C} - left::TreeConfigEnumerator{N,S,C} - right::TreeConfigEnumerator{N,S,C} - TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = new{N,S,C}(tag, zero(StaticElementVector{N,S,C}), left, right) - function TreeConfigEnumerator(data::StaticElementVector{N,S,C}) where {N,S,C} - new{N,S,C}(LEAF, data) - end - function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C} + data::ET + left::SumProductTree{ET} + right::SumProductTree{ET} + # zero(ET) can be undef + function SumProductTree(tag::TreeTag, left::SumProductTree{ET}, right::SumProductTree{ET}) where {ET} + res = new{ET}(tag) + res.left = left + res.right = right + return res + end + function SumProductTree(data::ET) where ET + return new{ET}(LEAF, data) + end + function SumProductTree{ET}(tag::TreeTag) where {ET} @assert tag === ZERO || tag === ONE - return new{N,S,C}(tag) + return new{ET}(tag) end end +# these two interfaces must be implemented in order to collect elements +_data_mul(x::StaticElementVector, y::StaticElementVector) = x | y +_data_one(::Type{T}) where T<:StaticElementVector = zero(T) # NOTE: might be optional + +""" + TreeConfigEnumerator{N,S,C} + +An alias for [`SumProductTree`](@ref)`{StaticElementVector{N, S, C}}`, +which is a useful element type for configuration enumeration. +""" +const TreeConfigEnumerator{N,S,C} = SumProductTree{StaticElementVector{N,S,C}} +TreeConfigEnumerator(data::StaticElementVector) = SumProductTree(data) +TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = SumProductTree(tag, left, right) # AbstractTree APIs -function children(t::TreeConfigEnumerator) +function children(t::SumProductTree) if t.tag == ZERO || t.tag == LEAF || t.tag == ONE return typeof(t)[] else return [t.left, t.right] end end -function printnode(io::IO, t::TreeConfigEnumerator{N,S,C}) where {N,S,C} +function printnode(io::IO, t::SumProductTree{ET}) where {ET} if t.tag === LEAF print(io, t.data) elseif t.tag === ZERO print(io, "∅") elseif t.tag === ONE - print(io, zero(StaticElementVector{N,S,C})) + print(io, _data_one(ET)) elseif t.tag === SUM print(io, "+") else # PROD @@ -506,7 +522,8 @@ function printnode(io::IO, t::TreeConfigEnumerator{N,S,C}) where {N,S,C} end end -Base.length(x::TreeConfigEnumerator) = _length!(x, IdDict{typeof(x), Int}()) +# it must be mutable, otherwise the `IdDict` trick for computing the length does not work. +Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}()) function _length!(x, d) haskey(d, x) && return d[x] @@ -525,7 +542,7 @@ function _length!(x, d) end end -num_nodes(x::TreeConfigEnumerator) = _num_nodes(x, IdDict{typeof(x), Int}()) +num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}()) function _num_nodes(x, d) haskey(d, x) && return 0 if x.tag == ZERO || x.tag == ONE @@ -539,37 +556,37 @@ function _num_nodes(x, d) return res end -function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C} +function Base.:(==)(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET} return Set(collect(x)) == Set(collect(y)) end -Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t) +Base.show(io::IO, t::SumProductTree) = print_tree(io, t) -function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C} +function Base.collect(x::SumProductTree{ET}) where {ET} if x.tag == ZERO - return StaticElementVector{N,S,C}[] + return ET[] elseif x.tag == ONE - return StaticElementVector{N,S,C}[zero(StaticElementVector{N,S,C})] + return [_data_one(ET)] elseif x.tag == LEAF - return StaticElementVector{N,S,C}[x.data] + return [x.data] elseif x.tag == SUM return vcat(collect(x.left), collect(x.right)) else # PROD - return vec([reduce((x,y)->x|y, si) for si in Iterators.product(collect(x.left), collect(x.right))]) + return vec([reduce(_data_mul, si) for si in Iterators.product(collect(x.left), collect(x.right))]) end end -function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C} +function Base.:+(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET} if x.tag == ZERO return y elseif y.tag == ZERO return x else - return TreeConfigEnumerator(SUM, x, y) + return SumProductTree(SUM, x, y) end end -function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C} +function Base.:*(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET} if x.tag == ONE return y elseif y.tag == ONE @@ -579,18 +596,18 @@ function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) elseif y.tag == ZERO return y elseif x.tag == LEAF && y.tag == LEAF - return TreeConfigEnumerator(x.data | y.data) + return SumProductTree(_data_mul(x.data, y.data)) else - return TreeConfigEnumerator(PROD, x, y) + return SumProductTree(PROD, x, y) end end -Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO) -Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ONE) -Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C}) -Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C}) +Base.zero(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ZERO) +Base.one(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ONE) +Base.zero(::SumProductTree{ET}) where {ET} = zero(SumProductTree{ET}) +Base.one(::SumProductTree{ET}) where {ET} = one(SumProductTree{ET}) # todo, check siblings too? -function Base.iszero(t::TreeConfigEnumerator) +function Base.iszero(t::SumProductTree) if t.tag == SUM iszero(t.left) && iszero(t.right) elseif t.tag == ZERO @@ -603,9 +620,9 @@ function Base.iszero(t::TreeConfigEnumerator) end """ - generate_samples(t::TreeConfigEnumerator, nsamples::Int) + generate_samples(t::SumProductTree, nsamples::Int) -Direct sampling configurations from a [`TreeConfigEnumerator`](@ref) instance. +Direct sampling configurations from a [`SumProductTree`](@ref) instance. Example ----------------------------- @@ -623,15 +640,15 @@ julia> all(s->is_independent_set(g, s), samples) true ``` """ -function generate_samples(t::TreeConfigEnumerator{N,S,C}, nsamples::Int) where {N,S,C} +function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET} # get length dict - res = fill(zero(StaticElementVector{N,S,C}), nsamples) + res = fill(_data_one(ET), nsamples) d = IdDict{typeof(t), Int}() sample_descend!(res, t, d) return res end -function sample_descend!(res::AbstractVector, t::TreeConfigEnumerator, d::IdDict) +function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict) length(res) == 0 && return res if t.tag == LEAF res .|= Ref(t.data) @@ -695,14 +712,14 @@ onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampl # just to make matrix transpose work Base.transpose(c::ConfigEnumerator) = c Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data)) -Base.transpose(c::TreeConfigEnumerator) = c -function Base.copy(c::TreeConfigEnumerator{N,S,C}) where {N,S,C} +Base.transpose(c::SumProductTree) = c +function Base.copy(c::SumProductTree{ET}) where {ET} if c.tag == LEAF - TreeConfigEnumerator(c.data) + SumProductTree(c.data) elseif c.tag == ZERO || c.tag == ONE - TreeConfigEnumerator{N,S,C}(c.tag) + SumProductTree{ET}(c.tag) else - TreeConfigEnumerator(c.tag, c.left, c.right) + SumProductTree(c.tag, c.left, c.right) end end @@ -713,7 +730,7 @@ for TYPE in [:AbstractSetNumber, :TruncatedPoly, :ExtendedTropical] end # to handle power of polynomials -function Base.:^(x::TreeConfigEnumerator, y::Real) +function Base.:^(x::SumProductTree, y::Real) if y == 0 return one(x) elseif x.tag == LEAF || x.tag == ONE || x.tag == ZERO