Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/performancetips.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ The GEMM routine can speed up the computation on CPU for one order, with multi-t
Benchmark shows the performance of `TropicalGEMM` is close to the theoretical optimal value.

## Sum product representation for configurations
[`TreeConfigEnumerator`](@ref) can save a lot memory for you to store exponential number of configurations in polynomial space.
[`TreeConfigEnumerator`](@ref) (an alias of [`SumProductTree`](@ref) with [`StaticElementVector`](@ref) as its data type) can save a lot memory for you to store exponential number of configurations in polynomial space.
It is a sum-product expression tree to store [`ConfigEnumerator`](@ref) in a lazy style, configurations can be extracted by depth first searching the tree with the `Base.collect` method. Although it is space efficient, it is in general not easy to extract information from it.
This tree structure supports directed sampling so that one can get some statistic properties from it with an intermediate effort.

Expand Down
1 change: 1 addition & 0 deletions docs/src/ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Polynomials.Polynomial
TruncatedPoly
Max2Poly
ConfigEnumerator
SumProductTree
TreeConfigEnumerator
ConfigSampler
```
Expand Down
3 changes: 2 additions & 1 deletion src/GraphTensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ export estimate_memory
# Algebras
export StaticBitVector, StaticElementVector, @bv_str
export is_commutative_semiring
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod
export ConfigEnumerator, onehotv, ConfigSampler, SumProductTree, TreeConfigEnumerator
export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32, ExtendedTropical
export generate_samples

Expand Down
117 changes: 67 additions & 50 deletions src/arithematics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,30 +408,27 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C})
Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})

# tree config enumerator
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
"""
TreeConfigEnumerator{N,S,C} <: AbstractSetNumber
SumProductTree{ET} <: AbstractSetNumber

Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
and is often more memory efficient than putting the configurations in a vector.
One can use [`generate_samples`](@ref) to sample configurations from this tree structure efficiently.
`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.

Fields
-----------------------
* `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
* `tag` is one of `ZERO`, `ONE`, `LEAF`, `SUM`, `PROD`.
* `data` is the element stored in a `LEAF` node.
* `left` and `right` are two operands of a `SUM` or `PROD` node.

Example
------------------------
```jldoctest; setup=:(using GraphTensorNetworks)
julia> s = TreeConfigEnumerator(bv"00111")
julia> s = SumProductTree(bv"00111")
00111


julia> q = TreeConfigEnumerator(bv"10000")
julia> q = SumProductTree(bv"10000")
10000


Expand Down Expand Up @@ -469,44 +466,64 @@ julia> one(s)

```
"""
mutable struct TreeConfigEnumerator{N,S,C} <: AbstractSetNumber
mutable struct SumProductTree{ET} <: AbstractSetNumber
tag::TreeTag
data::StaticElementVector{N,S,C}
left::TreeConfigEnumerator{N,S,C}
right::TreeConfigEnumerator{N,S,C}
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = new{N,S,C}(tag, zero(StaticElementVector{N,S,C}), left, right)
function TreeConfigEnumerator(data::StaticElementVector{N,S,C}) where {N,S,C}
new{N,S,C}(LEAF, data)
end
function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C}
data::ET
left::SumProductTree{ET}
right::SumProductTree{ET}
# zero(ET) can be undef
function SumProductTree(tag::TreeTag, left::SumProductTree{ET}, right::SumProductTree{ET}) where {ET}
res = new{ET}(tag)
res.left = left
res.right = right
return res
end
function SumProductTree(data::ET) where ET
return new{ET}(LEAF, data)
end
function SumProductTree{ET}(tag::TreeTag) where {ET}
@assert tag === ZERO || tag === ONE
return new{N,S,C}(tag)
return new{ET}(tag)
end
end
# these two interfaces must be implemented in order to collect elements
_data_mul(x::StaticElementVector, y::StaticElementVector) = x | y
_data_one(::Type{T}) where T<:StaticElementVector = zero(T) # NOTE: might be optional

"""
TreeConfigEnumerator{N,S,C}

An alias for [`SumProductTree`](@ref)`{StaticElementVector{N, S, C}}`,
which is a useful element type for configuration enumeration.
"""
const TreeConfigEnumerator{N,S,C} = SumProductTree{StaticElementVector{N,S,C}}
TreeConfigEnumerator(data::StaticElementVector) = SumProductTree(data)
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = SumProductTree(tag, left, right)

# AbstractTree APIs
function children(t::TreeConfigEnumerator)
function children(t::SumProductTree)
if t.tag == ZERO || t.tag == LEAF || t.tag == ONE
return typeof(t)[]
else
return [t.left, t.right]
end
end
function printnode(io::IO, t::TreeConfigEnumerator{N,S,C}) where {N,S,C}
function printnode(io::IO, t::SumProductTree{ET}) where {ET}
if t.tag === LEAF
print(io, t.data)
elseif t.tag === ZERO
print(io, "∅")
elseif t.tag === ONE
print(io, zero(StaticElementVector{N,S,C}))
print(io, _data_one(ET))
elseif t.tag === SUM
print(io, "+")
else # PROD
print(io, "*")
end
end

Base.length(x::TreeConfigEnumerator) = _length!(x, IdDict{typeof(x), Int}())
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}())

function _length!(x, d)
haskey(d, x) && return d[x]
Expand All @@ -525,7 +542,7 @@ function _length!(x, d)
end
end

num_nodes(x::TreeConfigEnumerator) = _num_nodes(x, IdDict{typeof(x), Int}())
num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}())
function _num_nodes(x, d)
haskey(d, x) && return 0
if x.tag == ZERO || x.tag == ONE
Expand All @@ -539,37 +556,37 @@ function _num_nodes(x, d)
return res
end

function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
function Base.:(==)(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
return Set(collect(x)) == Set(collect(y))
end

Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t)
Base.show(io::IO, t::SumProductTree) = print_tree(io, t)

function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
function Base.collect(x::SumProductTree{ET}) where {ET}
if x.tag == ZERO
return StaticElementVector{N,S,C}[]
return ET[]
elseif x.tag == ONE
return StaticElementVector{N,S,C}[zero(StaticElementVector{N,S,C})]
return [_data_one(ET)]
elseif x.tag == LEAF
return StaticElementVector{N,S,C}[x.data]
return [x.data]
elseif x.tag == SUM
return vcat(collect(x.left), collect(x.right))
else # PROD
return vec([reduce((x,y)->x|y, si) for si in Iterators.product(collect(x.left), collect(x.right))])
return vec([reduce(_data_mul, si) for si in Iterators.product(collect(x.left), collect(x.right))])
end
end

function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
function Base.:+(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
if x.tag == ZERO
return y
elseif y.tag == ZERO
return x
else
return TreeConfigEnumerator(SUM, x, y)
return SumProductTree(SUM, x, y)
end
end

function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C}
function Base.:*(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
if x.tag == ONE
return y
elseif y.tag == ONE
Expand All @@ -579,18 +596,18 @@ function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C})
elseif y.tag == ZERO
return y
elseif x.tag == LEAF && y.tag == LEAF
return TreeConfigEnumerator(x.data | y.data)
return SumProductTree(_data_mul(x.data, y.data))
else
return TreeConfigEnumerator(PROD, x, y)
return SumProductTree(PROD, x, y)
end
end

Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO)
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ONE)
Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C})
Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C})
Base.zero(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ZERO)
Base.one(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ONE)
Base.zero(::SumProductTree{ET}) where {ET} = zero(SumProductTree{ET})
Base.one(::SumProductTree{ET}) where {ET} = one(SumProductTree{ET})
# todo, check siblings too?
function Base.iszero(t::TreeConfigEnumerator)
function Base.iszero(t::SumProductTree)
if t.tag == SUM
iszero(t.left) && iszero(t.right)
elseif t.tag == ZERO
Expand All @@ -603,9 +620,9 @@ function Base.iszero(t::TreeConfigEnumerator)
end

"""
generate_samples(t::TreeConfigEnumerator, nsamples::Int)
generate_samples(t::SumProductTree, nsamples::Int)

Direct sampling configurations from a [`TreeConfigEnumerator`](@ref) instance.
Direct sampling configurations from a [`SumProductTree`](@ref) instance.

Example
-----------------------------
Expand All @@ -623,15 +640,15 @@ julia> all(s->is_independent_set(g, s), samples)
true
```
"""
function generate_samples(t::TreeConfigEnumerator{N,S,C}, nsamples::Int) where {N,S,C}
function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET}
# get length dict
res = fill(zero(StaticElementVector{N,S,C}), nsamples)
res = fill(_data_one(ET), nsamples)
d = IdDict{typeof(t), Int}()
sample_descend!(res, t, d)
return res
end

function sample_descend!(res::AbstractVector, t::TreeConfigEnumerator, d::IdDict)
function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict)
length(res) == 0 && return res
if t.tag == LEAF
res .|= Ref(t.data)
Expand Down Expand Up @@ -695,14 +712,14 @@ onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampl
# just to make matrix transpose work
Base.transpose(c::ConfigEnumerator) = c
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
Base.transpose(c::TreeConfigEnumerator) = c
function Base.copy(c::TreeConfigEnumerator{N,S,C}) where {N,S,C}
Base.transpose(c::SumProductTree) = c
function Base.copy(c::SumProductTree{ET}) where {ET}
if c.tag == LEAF
TreeConfigEnumerator(c.data)
SumProductTree(c.data)
elseif c.tag == ZERO || c.tag == ONE
TreeConfigEnumerator{N,S,C}(c.tag)
SumProductTree{ET}(c.tag)
else
TreeConfigEnumerator(c.tag, c.left, c.right)
SumProductTree(c.tag, c.left, c.right)
end
end

Expand All @@ -713,7 +730,7 @@ for TYPE in [:AbstractSetNumber, :TruncatedPoly, :ExtendedTropical]
end

# to handle power of polynomials
function Base.:^(x::TreeConfigEnumerator, y::Real)
function Base.:^(x::SumProductTree, y::Real)
if y == 0
return one(x)
elseif x.tag == LEAF || x.tag == ONE || x.tag == ZERO
Expand Down