Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/IndependentSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ show_graph(graph; locs=locations)
# x_i^{w_i}
# \end{matrix}\right),
# ```
# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and `w_i` is the weight of vertex ``i``.
# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and ``w_i`` is the weight of vertex ``i``.
# Similarly, on each edge ``(u, v)``, we define a matrix ``B`` indexed by ``s_u`` and ``s_v`` as
# ```math
# B = \left(\begin{matrix}
Expand Down
108 changes: 86 additions & 22 deletions src/arithematics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,42 +221,106 @@ function Base.:*(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,
return ExtendedTropical{K,TO}(sorted_sum_combination!(res, a.orders, b.orders))
end

# 1. bisect over summed value and find the critical value `c`,
# 2. collect the values with sum combination `≥ c`,
# 3. sort the collected values
function sorted_sum_combination!(res::AbstractVector{TO}, A::AbstractVector{TO}, B::AbstractVector{TO}) where TO
K = length(res)
@assert length(B) == length(A) == K
@inbounds maxval = A[K] * B[K]
ptr = K
@inbounds res[ptr] = maxval
@inbounds queue = [(K,K-1,A[K]*B[K-1]), (K-1,K,A[K-1]*B[K])]
for k = 1:K-1
@inbounds (i, j, res[K-k]) = _pop_max_sum!(queue) # TODO: do not enumerate, use better data structures
_push_if_not_exists!(queue, i, j-1, A, B)
_push_if_not_exists!(queue, i-1, j, A, B)
@inbounds high = A[K] * B[K]

mA = findfirst(!iszero, A)
mB = findfirst(!iszero, B)
if mA === nothing || mB === nothing
res .= Ref(zero(TO))
return res
end
@inbounds low = A[mA] * B[mB]
# count number bigger than x
c, _ = count_geq(A, B, mB, low, true)
@inbounds if c <= K # return
res[K-c+1:K] .= sort!(collect_geq!(view(res,1:c), A, B, mB, low))
if c < K
res[1:K-c] .= zero(TO)
end
return res
end
# calculate by bisection for at most 30 times.
@inbounds for _ = 1:30
mid = mid_point(high, low)
c, nB = count_geq(A, B, mB, mid, true)
if c > K
low = mid
mB = nB
elseif c == K # return
# NOTE: this is the bottleneck
return sort!(collect_geq!(res, A, B, mB, mid))
else
high = mid
end
end
clow, _ = count_geq(A, B, mB, low, false)
@inbounds res .= sort!(collect_geq!(similar(res, clow), A, B, mB, low))[end-K+1:end]
return res
end

function _push_if_not_exists!(queue, i, j, A, B)
@inbounds if j>=1 && i>=1 && !any(x->x[1] >= i && x[2] >= j, queue)
push!(queue, (i, j, A[i]*B[j]))
# count the number of sum-combinations with the sum >= low
function count_geq(A, B, mB, low, earlybreak)
K = length(A)
k = 1 # TODO: we should tighten mA, mB later!
@inbounds Ak = A[K-k+1]
@inbounds Bq = B[K-mB+1]
c = 0
nB = mB
@inbounds for q = K-mB+1:-1:1
Bq = B[K-q+1]
while k < K && Ak * Bq >= low
k += 1
Ak = A[K-k+1]
end
if Ak * Bq >= low
c += k
else
c += (k-1)
if k==1
nB += 1
end
end
if earlybreak && c > K
return c, nB
end
end
return c, nB
end

function _pop_max_sum!(queue)
maxsum = first(queue)[3]
maxloc = 1
@inbounds for i=2:length(queue)
m = queue[i][3]
if m > maxsum
maxsum = m
maxloc = i
function collect_geq!(res, A, B, mB, low)
K = length(A)
k = 1 # TODO: we should tighten mA, mB later!
Ak = A[K-k+1]
Bq = B[K-mB+1]
l = 0
for q = K-mB+1:-1:1
Bq = B[K-q+1]
while k < K && Ak * Bq >= low
k += 1
Ak = A[K-k+1]
end
# push data
ck = Ak * Bq >= low ? k : k-1
for j=1:ck
l += 1
res[l] = Bq * A[end-j+1]
end
end
@inbounds data = queue[maxloc]
deleteat!(queue, maxloc)
return data
return res
end

# for bisection
mid_point(a::Tropical{T}, b::Tropical{T}) where T = Tropical{T}((a.n + b.n) / 2)
mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T,CT} = CountingTropical{T,CT}((a.n + b.n) / 2, a.c)
mid_point(a::Tropical{T}, b::Tropical{T}) where T<:Integer = Tropical{T}((a.n + b.n) ÷ 2)
mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T<:Integer,CT} = CountingTropical{T,CT}((a.n + b.n) ÷ 2, a.c)

function Base.:+(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,TO}
res = Vector{TO}(undef, K)
ptr1, ptr2 = K, K
Expand Down
2 changes: 1 addition & 1 deletion src/bounding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct CacheTree{T}
end
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
if length(se.slicing) != 0
@warn "Slicing is not supported for caching! Fallback to `NestedEinsum`."
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
end
return cached_einsum(se.eins, xs, size_dict)
end
Expand Down
11 changes: 11 additions & 0 deletions test/arithematics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ end
end
end

@testset "count geq" begin
A = collect(1:10)
B = collect(2:2:20)
low = 20
c, _ = GraphTensorNetworks.count_geq(A, B, 1, low, false)
@test c == count(x->x>=low, vec([a*b for a in A, b in B]))
res = similar(A, c)
@test sort!(GraphTensorNetworks.collect_geq!(res, A, B, 1, low)) == sort!(filter(x->x>=low, vec([a*b for a in A, b in B])))
end


# check the correctness of sampling
@testset "generate samples" begin
Random.seed!(2)
Expand Down