Skip to content

Commit

Permalink
fix merge_vertices (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennedeg committed Nov 12, 2021
1 parent ab437a6 commit 761912f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 35 deletions.
59 changes: 24 additions & 35 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ Determine how many elements of `x` are less than `i` for all `i` in `1:n`.
"""
function compute_shifts(n::Integer, x::AbstractArray)
tmp = zeros(eltype(x), n)
tmp[x[2:end]] .= 1
tmp[x] .= 1
return cumsum!(tmp, tmp)
end

Expand Down Expand Up @@ -743,29 +743,25 @@ julia> collect(edges(h))
Edge 3 => 4
```
"""
function merge_vertices(g::AbstractGraph, vs)
labels = collect(1:nv(g))
function merge_vertices(g::AbstractSimpleGraph, vs)
# Use lowest value as new vertex id.
sort!(vs)
nvnew = nv(g) - length(unique(vs)) + 1
vs = unique!(sort(vs))
merged_vertex = popfirst!(vs)

nvnew = nv(g) - length(vs)
nvnew <= nv(g) || return g
(v0, vm) = extrema(vs)
v0 > 0 || throw(ArgumentError("invalid vertex ID: $v0 in list of vertices to be merged"))
vm <= nv(g) || throw(ArgumentError("vertex $vm not found in graph")) # TODO 0.7: change to DomainError?
labels[vs] .= v0
shifts = compute_shifts(nv(g), vs[2:end])
for v in vertices(g)
if labels[v] != v0
labels[v] -= shifts[v]
end
end
merged_vertex > 0 || throw(ArgumentError("invalid vertex ID: $merged_vertex in list of vertices to be merged"))
vs[end] <= nv(g) || throw(ArgumentError("vertex $(vs[end]) not found in graph")) # TODO 0.7: change to DomainError?

new_vertex_ids = collect(vertices(g)) .- compute_shifts(nv(g), vs)
new_vertex_ids[vs] .= merged_vertex

#if v in vs then labels[v] == v0 else labels[v] == v
newg = SimpleGraph(nvnew)
for e in edges(g)
u, w = src(e), dst(e)
if labels[u] != labels[w] #not a new self loop
add_edge!(newg, labels[u], labels[w])
if new_vertex_ids[u] != new_vertex_ids[w] #not a new self loop
add_edge!(newg, new_vertex_ids[u], new_vertex_ids[w])
end
end
return newg
Expand Down Expand Up @@ -812,35 +808,27 @@ julia> collect(edges(g))
```
"""
function merge_vertices!(g::Graph{T}, vs::Vector{U} where U <: Integer) where T
vs = sort!(unique(vs))
merged_vertex = popfirst!(vs)
vs = unique!(sort(vs))
(merged_vertex, vm) = extrema(vs)

x = zeros(Int, nv(g))
x[vs] .= 1
new_vertex_ids = collect(1:nv(g)) .- cumsum(x)
merged_vertex > 0 || throw(ArgumentError("invalid vertex ID: $merged_vertex in list of vertices to be merged"))
vm <= nv(g) || throw(ArgumentError("vertex $vm not found in graph")) # TODO 0.7: change to DomainError?

new_vertex_ids = collect(vertices(g)) .- compute_shifts(nv(g), vs[2:end])
new_vertex_ids[vs] .= merged_vertex

for i in vertices(g)
# Adjust connections to merged vertices
if (i != merged_vertex) && !insorted(i, vs)
if new_vertex_ids[i] != merged_vertex
nbrs_to_rewire = Set{T}()
for j in outneighbors(g, i)
if insorted(j, vs)
push!(nbrs_to_rewire, merged_vertex)
else
push!(nbrs_to_rewire, new_vertex_ids[j])
end
push!(nbrs_to_rewire, new_vertex_ids[j])
end
g.fadjlist[new_vertex_ids[i]] = sort(collect(nbrs_to_rewire))

g.fadjlist[new_vertex_ids[i]] = sort!(collect(nbrs_to_rewire))

# Collect connections to new merged vertex
else
nbrs_to_merge = Set{T}()
for element in filter(x -> !(insorted(x, vs)) && (x != merged_vertex), g.fadjlist[i])
push!(nbrs_to_merge, new_vertex_ids[element])
end

for j in vs, e in outneighbors(g, j)
if new_vertex_ids[e] != merged_vertex
push!(nbrs_to_merge, new_vertex_ids[e])
Expand All @@ -850,8 +838,9 @@ function merge_vertices!(g::Graph{T}, vs::Vector{U} where U <: Integer) where T
end
end


# Drop excess vertices
g.fadjlist = g.fadjlist[1:(end - length(vs))]
g.fadjlist = g.fadjlist[begin:(end - length(vs)+1)]

# Correct edge counts
g.ne = sum(degree(g, i) for i in vertices(g)) / 2
Expand Down
7 changes: 7 additions & 0 deletions test/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@
@test neighbors(h2, 5) == [2]
@test ne(h2) == 3
@test nv(h2) == 5

h3 = star_graph(5)
h3merged = merge_vertices(h3, [1,2])
@test neighbors(h3merged, 1) == [2,3,4]
@test neighbors(h3merged, 2) == [1]
@test neighbors(h3merged, 3) == [1]
@test neighbors(h3merged, 4) == [1]
end
end

Expand Down

0 comments on commit 761912f

Please sign in to comment.