diff --git a/examples/IndependentSet.jl b/examples/IndependentSet.jl index 65a5dcf9..3d488553 100644 --- a/examples/IndependentSet.jl +++ b/examples/IndependentSet.jl @@ -32,7 +32,7 @@ show_graph(graph; locs=locations) # x_i^{w_i} # \end{matrix}\right), # ``` -# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and `w_i` is the weight of vertex ``i``. +# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and ``w_i`` is the weight of vertex ``i``. # Similarly, on each edge ``(u, v)``, we define a matrix ``B`` indexed by ``s_u`` and ``s_v`` as # ```math # B = \left(\begin{matrix} diff --git a/src/arithematics.jl b/src/arithematics.jl index 35b7d25f..15924622 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -221,42 +221,106 @@ function Base.:*(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K, return ExtendedTropical{K,TO}(sorted_sum_combination!(res, a.orders, b.orders)) end +# 1. bisect over summed value and find the critical value `c`, +# 2. collect the values with sum combination `≥ c`, +# 3. sort the collected values function sorted_sum_combination!(res::AbstractVector{TO}, A::AbstractVector{TO}, B::AbstractVector{TO}) where TO K = length(res) @assert length(B) == length(A) == K - @inbounds maxval = A[K] * B[K] - ptr = K - @inbounds res[ptr] = maxval - @inbounds queue = [(K,K-1,A[K]*B[K-1]), (K-1,K,A[K-1]*B[K])] - for k = 1:K-1 - @inbounds (i, j, res[K-k]) = _pop_max_sum!(queue) # TODO: do not enumerate, use better data structures - _push_if_not_exists!(queue, i, j-1, A, B) - _push_if_not_exists!(queue, i-1, j, A, B) + @inbounds high = A[K] * B[K] + + mA = findfirst(!iszero, A) + mB = findfirst(!iszero, B) + if mA === nothing || mB === nothing + res .= Ref(zero(TO)) + return res + end + @inbounds low = A[mA] * B[mB] + # count number bigger than x + c, _ = count_geq(A, B, mB, low, true) + @inbounds if c <= K # return + res[K-c+1:K] .= sort!(collect_geq!(view(res,1:c), A, B, mB, low)) + if c < K + res[1:K-c] .= zero(TO) + end + return res + end + # calculate by bisection for at most 30 times. + @inbounds for _ = 1:30 + mid = mid_point(high, low) + c, nB = count_geq(A, B, mB, mid, true) + if c > K + low = mid + mB = nB + elseif c == K # return + # NOTE: this is the bottleneck + return sort!(collect_geq!(res, A, B, mB, mid)) + else + high = mid + end end + clow, _ = count_geq(A, B, mB, low, false) + @inbounds res .= sort!(collect_geq!(similar(res, clow), A, B, mB, low))[end-K+1:end] return res end -function _push_if_not_exists!(queue, i, j, A, B) - @inbounds if j>=1 && i>=1 && !any(x->x[1] >= i && x[2] >= j, queue) - push!(queue, (i, j, A[i]*B[j])) +# count the number of sum-combinations with the sum >= low +function count_geq(A, B, mB, low, earlybreak) + K = length(A) + k = 1 # TODO: we should tighten mA, mB later! + @inbounds Ak = A[K-k+1] + @inbounds Bq = B[K-mB+1] + c = 0 + nB = mB + @inbounds for q = K-mB+1:-1:1 + Bq = B[K-q+1] + while k < K && Ak * Bq >= low + k += 1 + Ak = A[K-k+1] + end + if Ak * Bq >= low + c += k + else + c += (k-1) + if k==1 + nB += 1 + end + end + if earlybreak && c > K + return c, nB + end end + return c, nB end -function _pop_max_sum!(queue) - maxsum = first(queue)[3] - maxloc = 1 - @inbounds for i=2:length(queue) - m = queue[i][3] - if m > maxsum - maxsum = m - maxloc = i +function collect_geq!(res, A, B, mB, low) + K = length(A) + k = 1 # TODO: we should tighten mA, mB later! + Ak = A[K-k+1] + Bq = B[K-mB+1] + l = 0 + for q = K-mB+1:-1:1 + Bq = B[K-q+1] + while k < K && Ak * Bq >= low + k += 1 + Ak = A[K-k+1] + end + # push data + ck = Ak * Bq >= low ? k : k-1 + for j=1:ck + l += 1 + res[l] = Bq * A[end-j+1] end end - @inbounds data = queue[maxloc] - deleteat!(queue, maxloc) - return data + return res end +# for bisection +mid_point(a::Tropical{T}, b::Tropical{T}) where T = Tropical{T}((a.n + b.n) / 2) +mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T,CT} = CountingTropical{T,CT}((a.n + b.n) / 2, a.c) +mid_point(a::Tropical{T}, b::Tropical{T}) where T<:Integer = Tropical{T}((a.n + b.n) ÷ 2) +mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T<:Integer,CT} = CountingTropical{T,CT}((a.n + b.n) ÷ 2, a.c) + function Base.:+(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,TO} res = Vector{TO}(undef, K) ptr1, ptr2 = K, K diff --git a/src/bounding.jl b/src/bounding.jl index 99426924..be66b553 100644 --- a/src/bounding.jl +++ b/src/bounding.jl @@ -60,7 +60,7 @@ struct CacheTree{T} end function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict) if length(se.slicing) != 0 - @warn "Slicing is not supported for caching! Fallback to `NestedEinsum`." + @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`." end return cached_einsum(se.eins, xs, size_dict) end diff --git a/test/arithematics.jl b/test/arithematics.jl index b7059fee..036f413f 100644 --- a/test/arithematics.jl +++ b/test/arithematics.jl @@ -153,6 +153,17 @@ end end end +@testset "count geq" begin + A = collect(1:10) + B = collect(2:2:20) + low = 20 + c, _ = GraphTensorNetworks.count_geq(A, B, 1, low, false) + @test c == count(x->x>=low, vec([a*b for a in A, b in B])) + res = similar(A, c) + @test sort!(GraphTensorNetworks.collect_geq!(res, A, B, 1, low)) == sort!(filter(x->x>=low, vec([a*b for a in A, b in B]))) +end + + # check the correctness of sampling @testset "generate samples" begin Random.seed!(2)