Skip to content

Commit

Permalink
Fix a bug related to ConfigsMax (#62)
Browse files Browse the repository at this point in the history
* Fix a bug related to ConfigsMax

* fix test and bump version
  • Loading branch information
GiggleLiu committed May 2, 2023
1 parent 71694e1 commit b1febf5
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GenericTensorNetworks"
uuid = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
version = "1.3.2"
version = "1.3.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
6 changes: 5 additions & 1 deletion src/bounding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ The backward rule for tropical einsum.
* `size_dict` is a key-value map from tensor label to dimension size.
"""
function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ymask), size_dict)
y .= inv.(y) .* ymask
# remove float to improve the stability of the algorithm
removeinf(x::CountingTropical{<:AbstractFloat}) = isinf(x.n) ? typeof(x)(prevfloat(x.n), x.c) : x
removeinf(x::Tropical{<:AbstractFloat}) = isinf(x.n) ? typeof(x)(prevfloat(x.n)) : x
removeinf(x) = x
y .= removeinf.(inv.(y) .* ymask)
masks = []
for i=1:length(ixs)
nixs = OMEinsum._insertat(ixs, i, iy)
Expand Down
3 changes: 2 additions & 1 deletion src/networks/IndependentSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ function generate_tensors(x::T, gp::IndependentSet) where T
nv(gp.graph) == 0 && return []
ixs = getixsv(gp.code)
# we only add labels at vertex tensors
return select_dims([
tensors = select_dims([
add_labels!(Array{T}[misv(_pow.(Ref(x), get_weights(gp, i))) for i=1:nv(gp.graph)], ixs[1:nv(gp.graph)], labels(gp))...,
Array{T}[misb(T, length(ix)) for ix in ixs[nv(gp.graph)+1:end]]... # if n!=2, it corresponds to set packing problem.
], ixs, fixedvertices(gp))
return tensors
end

function misb(::Type{T}, n::Integer=2) where T
Expand Down
18 changes: 18 additions & 0 deletions test/configurations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,22 @@ end
@test all(x->count_ones(x)==(i-1), s.data)
end
end
end

@testset "configs bug fix" begin
subgraph = let
g = SimpleGraph(5)
vertices = "cdefg"
for (w, v) in ["cd", "ce", "cf", "de", "ef", "fg"]
add_edge!(g, findfirst(==(w), vertices), findfirst(==(v), vertices))
end
g
end
problem = IndependentSet(subgraph, openvertices=[1,4,5])
res1 = solve(problem, SizeMax(), T=Float64)
@test res1 == Tropical.(reshape([1, 1, 2, -Inf, 2, 2, -Inf, -Inf], 2, 2, 2))
res2 = solve(problem, CountingMax(); T=Float64)
res3 = solve(problem, ConfigsMax(; bounded=true); T=Float64)
@test getfield.(res2, :n) == getfield.(res1, :n)
@test getfield.(res3, :n) == getfield.(res1, :n)
end

0 comments on commit b1febf5

Please sign in to comment.