-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
graph_merge.py
137 lines (108 loc) · 4.22 KB
/
graph_merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import heapq
def _revalidate_node_edges(rag, node, heap_list):
    """Handle validation and invalidation of edges incident to a node.

    Every edge incident on `node` that already owns a heap item has that
    item flagged invalid. A fresh, valid item carrying the edge's current
    weight is then pushed onto `heap_list` and stored on the edge under the
    ``'heap item'`` key.

    Parameters
    ----------
    rag : RAG
        The Region Adjacency Graph.
    node : int
        The id of the node whose incident edges are to be
        validated/invalidated.
    heap_list : list
        The list containing the existing heap of edges. It is modified in
        place.
    """
    # networkx updates the data dictionary if an edge already exists, which
    # would force us to reposition such edges in the heap whenever their
    # weight changes. Instead, the stale heap item is flagged invalid where
    # it sits and a new item is pushed.
    for nbr in rag.neighbors(node):
        data = rag[node][nbr]

        # Edges that did not exist in the graph before have no heap item
        # yet, so there is nothing to invalidate for them.
        stale_item = data.get('heap item')
        if stale_item is not None:
            stale_item[3] = False

        heap_item = [data['weight'], node, nbr, True]
        data['heap item'] = heap_item
        heapq.heappush(heap_list, heap_item)
def _rename_node(graph, node_id, copy_id):
    """Rename `node_id` in `graph` to `copy_id`.

    A node `copy_id` is added and receives the attributes of `node_id`.
    Each edge incident on `node_id` is then re-created on `copy_id`,
    carrying over only its ``'weight'``, and `node_id` is removed.

    Parameters
    ----------
    graph : RAG
        The graph to modify in place.
    node_id : hashable
        The id of the existing node to be renamed.
    copy_id : hashable
        The id the node is renamed to.
    """
    graph._add_node_silent(copy_id)
    graph.node[copy_id].update(graph.node[node_id])

    # Only the weight is carried over; any other edge attribute (such as a
    # heap reference) is intentionally not copied.
    old_adjacency = graph[node_id]
    for neighbor in graph.neighbors(node_id):
        edge_attrs = {'weight': old_adjacency[neighbor]['weight']}
        graph.add_edge(neighbor, copy_id, edge_attrs)

    graph.remove_node(node_id)
def _invalidate_edge(graph, n1, n2):
    """Mark the heap item of the edge (n1, n2) as invalid.

    The item stored under the edge's ``'heap item'`` key is mutated in
    place, so the same list object held elsewhere sees the change too.

    Parameters
    ----------
    graph : RAG
        The graph holding the edge.
    n1, n2 : hashable
        The ids of the two end points of the edge.

    Raises
    ------
    KeyError
        If the edge does not exist or carries no ``'heap item'``.
    """
    heap_item = graph[n1][n2]['heap item']
    # Slot 3 of a heap item is its validity flag.
    heap_item[3] = False
def merge_hierarchical(labels, rag, thresh, rag_copy, in_place_merge,
                       merge_func, weight_func):
    """Perform hierarchical merging of a RAG.

    The lowest-weight valid edge is repeatedly taken off a heap and its two
    end points are merged, until the smallest weight left on the heap is no
    longer below `thresh`.

    Parameters
    ----------
    labels : ndarray
        The array of labels.
    rag : RAG
        The Region Adjacency Graph.
    thresh : float
        Regions connected by an edge with weight smaller than `thresh` are
        merged.
    rag_copy : bool
        If set, the RAG is copied before it is modified.
    in_place_merge : bool
        If set, the nodes are merged in place. Otherwise, the second end
        point is first renamed to a new id obtained from ``rag.next_id()``,
        so that each merge yields a new node.
    merge_func : callable
        This function is called before merging two nodes. For the RAG
        `graph` while merging `src` and `dst`, it is called as follows
        ``merge_func(graph, src, dst)``.
    weight_func : callable
        The function to compute the new weights of the nodes adjacent to the
        merged node. This is directly supplied as the argument `weight_func`
        to `merge_nodes`.

    Returns
    -------
    out : ndarray
        The new labeled array, in which every original label is replaced by
        the enumeration index of the node listing it in its ``'labels'``
        attribute.
    """
    if rag_copy:
        rag = rag.copy()

    # Seed the heap with one valid item per edge. Each edge keeps a
    # reference to its item so the item can be invalidated later without
    # searching the heap.
    edge_heap = []
    for u, v, attrs in rag.edges(data=True):
        entry = [attrs['weight'], u, v, True]
        attrs['heap item'] = entry
        heapq.heappush(edge_heap, entry)

    while edge_heap and edge_heap[0][0] < thresh:
        _, n1, n2, valid = heapq.heappop(edge_heap)

        # Items invalidated by an earlier merge are simply discarded.
        if not valid:
            continue

        # Both end points are about to change, so every edge touching
        # either of them must be invalidated before the graph is modified.
        for endpoint in (n1, n2):
            for nbr in rag.neighbors(endpoint):
                _invalidate_edge(rag, endpoint, nbr)

        if in_place_merge:
            src, dst = n1, n2
        else:
            next_id = rag.next_id()
            _rename_node(rag, n2, next_id)
            src, dst = n1, next_id

        merge_func(rag, src, dst)
        new_id = rag.merge_nodes(src, dst, weight_func)
        _revalidate_node_edges(rag, new_id, edge_heap)

    # Map every original label to the position of the node that now owns it.
    label_map = np.arange(labels.max() + 1)
    for ix, (_, attrs) in enumerate(rag.nodes(data=True)):
        for label in attrs['labels']:
            label_map[label] = ix

    return label_map[labels]