-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
intermediate_graph.py
269 lines (236 loc) · 8.95 KB
/
intermediate_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from attr import attrib, attrs
from graphviz import Digraph as GraphVizDigraph
from rfa_toolbox.encodings.pytorch.domain import LayerInfoHandler, NodeSubstitutor
from rfa_toolbox.encodings.pytorch.layer_handlers import (
AnyAdaptivePool,
AnyConv,
AnyHandler,
AnyPool,
ConvNormActivationHandler,
FlattenHandler,
FunctionalKernelHandler,
LinearHandler,
SqueezeExcitationHandler,
)
from rfa_toolbox.encodings.pytorch.substitutors import (
input_substitutor,
numeric_substitutor,
output_substitutor,
)
from rfa_toolbox.graphs import (
KNOWN_FILTER_MAPPING,
EnrichedNetworkNode,
LayerDefinition,
ReceptiveFieldInfo,
)
# Default layer-info handlers. ``Digraph._get_layer_definition`` walks this
# list in order and uses the FIRST handler whose ``can_handle`` accepts the
# node label, so the order below is significant: more specific handlers come
# first and ``AnyHandler`` comes last (presumably a catch-all fallback —
# TODO confirm against ``layer_handlers``).
RESOLVING_STRATEGY = [
    ConvNormActivationHandler(),
    SqueezeExcitationHandler(),
    AnyConv(),
    AnyPool(),
    AnyAdaptivePool(),
    FlattenHandler(),
    LinearHandler(),
    FunctionalKernelHandler(),
    AnyHandler(),
]

# Default node substitutors. ``Digraph._substitute`` applies them in this
# order, each one to every layer of the finished graph that it can handle.
SUBSTITUTION_STRATEGY = [
    numeric_substitutor(),
    input_substitutor(),
    output_substitutor(),
]
@attrs(auto_attribs=True, slots=True)
class Digraph:
    """Record a traced compute graph and convert it into this library's
    graph representation.

    The object mimics the subset of the ``graphviz.Digraph`` interface
    (``node``, ``edge``, ``attr``, ``subgraph`` and the context-manager
    protocol) that is used to record the nodes and edges of a
    jit-compiled model. :meth:`to_graph` then turns the recorded nodes and
    edges into linked ``EnrichedNetworkNode`` objects.

    Args:
        ref_mod: the neural network model in a non-jit-compiled
            variant; it is passed to the layer-info handlers as ``model``.
    """

    ref_mod: torch.nn.Module
    # Not read anywhere in this class; presumably kept for GraphViz
    # interface compatibility (TODO confirm).
    format: str = ""
    # Not read anywhere in this class (see ``format``).
    graph_attr: Dict[str, str] = attrib(factory=dict)
    # Directed edges recorded by ``edge`` as (source_id, target_id) pairs.
    edge_collection: List[Tuple[str, str]] = attrib(factory=list)
    # Not read anywhere in this class.
    raw_nodes: Dict[str, Tuple[str, str]] = attrib(factory=dict)
    # Layer definitions recorded by ``node``, keyed by node id.
    layer_definitions: Dict[str, LayerDefinition] = attrib(factory=dict)
    # Every instance gets its own shallow copy of the default strategy lists,
    # so customizing one Digraph's handlers cannot leak into the module-level
    # defaults (and from there into every other Digraph).
    layer_info_handlers: List[LayerInfoHandler] = attrib(
        factory=lambda: list(RESOLVING_STRATEGY)
    )
    layer_substitutors: List[NodeSubstitutor] = attrib(
        factory=lambda: list(SUBSTITUTION_STRATEGY)
    )
    # Passed to every created EnrichedNetworkNode as its
    # ``receptive_field_info_filter``.
    filter_rf: Callable[
        [Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
    ] = KNOWN_FILTER_MAPPING[None]

    def _find_predecessors(self, name: str) -> List[str]:
        """Return the ids of all nodes with an edge into ``name``.

        The ids are returned in edge-insertion order; duplicates are kept.
        """
        return [source for source, target in self.edge_collection if target == name]

    def _get_layer_definition(
        self,
        label: str,
        kernel_size: Optional[Union[Tuple[int, ...], int]] = None,
        stride_size: Optional[Union[Tuple[int, ...], int]] = None,
    ) -> LayerDefinition:
        """Build the layer definition for a node label.

        The handlers in ``layer_info_handlers`` are tried in order; the first
        one whose ``can_handle`` accepts ``label`` produces the result.

        Raises:
            ValueError: if no handler accepts the label.
        """
        resolvable = self._get_resolvable(label)
        name = self._get_name(label)
        for handler in self.layer_info_handlers:
            if handler.can_handle(label):
                return handler(
                    model=self.ref_mod,
                    resolvable_string=resolvable,
                    name=name,
                    kernel_size=kernel_size,
                    stride_size=stride_size,
                )
        raise ValueError(f"Did not find a way to handle the following layer: {name}")

    def attr(self, label: Optional[str] = None, **kwargs: object) -> None:
        """No-op stand-in for ``graphviz.Digraph.attr``.

        Any arguments are accepted and ignored, so code written against
        the GraphViz API can call this unchanged.
        """
        return None

    def edge(self, node_id1: str, node_id2: str) -> None:
        """Record a directed edge in the compute graph.

        Args:
            node_id1: the id of the start node
            node_id2: the id of the target node
        Returns:
            Nothing.
        """
        self.edge_collection.append((node_id1, node_id2))

    def node(
        self,
        name: str,
        label: Optional[str] = None,
        shape: str = "box",
        style: Optional[str] = None,
        kernel_size: Optional[Union[Tuple[int, ...], int]] = None,
        stride_size: Optional[Union[Tuple[int, ...], int]] = None,
        units: Optional[int] = None,
        filters: Optional[int] = None,
    ) -> None:
        """Record a node and resolve its layer definition.

        Args:
            name: the name of the node; it must be unique to properly
                identify the node (it is used as the dictionary key).
            label: descriptive label of the node's functionality; defaults
                to ``name``. It is what the layer-info handlers inspect.
            shape: unused; accepted for compatibility with GraphViz.
            style: unused; accepted for compatibility with GraphViz.
            kernel_size: forwarded to the layer-info handler.
            stride_size: forwarded to the layer-info handler.
            units: unused in this method.
            filters: unused in this method.
        Returns:
            Nothing.
        Raises:
            ValueError: if no layer-info handler accepts the label.
        """
        label = name if label is None else label
        layer_definition = self._get_layer_definition(
            label, kernel_size=kernel_size, stride_size=stride_size
        )
        self.layer_definitions[name] = layer_definition

    def subgraph(self, name: str) -> "Digraph":
        """No-op stand-in for ``graphviz.Digraph.subgraph``.

        Returns this instance, so nodes and edges added to the "subgraph"
        end up in this graph.
        """
        return self

    def __enter__(self) -> "Digraph":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returning None (falsy) lets any exception propagate.
        return None

    def _is_resolvable(
        self, predecessors: List[str], resolved_nodes: Dict[str, EnrichedNetworkNode]
    ) -> bool:
        """Return True when every predecessor has already been resolved.

        A node without predecessors is always resolvable.
        """
        return all(pred in resolved_nodes for pred in predecessors)

    def _find_resolvable_node(
        self,
        node_to_pred_map: Dict[str, List[str]],
        resolved_nodes: Dict[str, EnrichedNetworkNode],
    ) -> Optional[str]:
        """Return the first not-yet-resolved node whose predecessors are all
        resolved, in the iteration order of ``node_to_pred_map``.

        Returns None if no such node exists.
        """
        for name, preds in node_to_pred_map.items():
            if name not in resolved_nodes and self._is_resolvable(
                preds, resolved_nodes
            ):
                return name
        return None

    def _substitute(self, node: EnrichedNetworkNode) -> None:
        """Apply every substitutor, in order, to each layer it can handle."""
        # Snapshot the layer list before any substitutor touches the graph.
        all_layers = node.all_layers[:]
        for substitutor in self.layer_substitutors:
            for layer_node in all_layers:
                if substitutor.can_handle(layer_node.layer_info.name):
                    substitutor(layer_node)

    def _check_for_lone_node(
        self, resolved_nodes: Dict[str, EnrichedNetworkNode]
    ) -> None:
        """Warn about every node that has neither predecessors nor successors."""
        for node in resolved_nodes.values():
            if not node.predecessors and not node.succecessors:
                warnings.warn(
                    f"Found a node with no predecessors and no successors: "
                    f"'{node.layer_info.name}',"
                    f" this may be caused by some control-flow in"
                    f" this node disabling any processing"
                    f" within the node.",
                    UserWarning,
                )

    def to_graph(self) -> EnrichedNetworkNode:
        """Transform the recorded graph into EnrichedNetworkNode objects.

        The result allows the computation of border layers and the
        visualization of the graph using the visualize module.

        Nodes are resolved one at a time: in each step, the first recorded
        node whose predecessors are all resolved is created. The node
        resolved last is returned as the output node.

        Returns:
            The output node of the EnrichedNetworkNode-based graph.
        Raises:
            ValueError: if no node could be resolved at all (for example
                because nothing was recorded).
        """
        node_to_pred_map: Dict[str, List[str]] = {
            name: self._find_predecessors(name) for name in self.layer_definitions
        }
        resolved_nodes: Dict[str, EnrichedNetworkNode] = {}
        resolved_node: Optional[EnrichedNetworkNode] = None
        while len(resolved_nodes) != len(node_to_pred_map):
            resolvable_node_name = self._find_resolvable_node(
                node_to_pred_map, resolved_nodes
            )
            if resolvable_node_name is None:
                # Remaining nodes depend on something that never resolves
                # (an unrecorded node or a cycle); make the loss visible.
                unresolved = [
                    name for name in node_to_pred_map if name not in resolved_nodes
                ]
                warnings.warn(
                    f"Could not resolve {len(unresolved)} node(s), their "
                    f"predecessors never became resolvable: {unresolved}. "
                    f"These nodes are excluded from the resulting graph.",
                    UserWarning,
                )
                break
            resolved_node = self.create_enriched_node(
                resolved_nodes,
                node_to_pred_map[resolvable_node_name],
                self.layer_definitions[resolvable_node_name],
                resolvable_node_name,
            )
            resolved_nodes[resolvable_node_name] = resolved_node
        if resolved_node is None:
            raise ValueError(
                "Could not build a graph: no node could be resolved. "
                "Either no nodes were recorded or every node depends on "
                "an unresolvable predecessor."
            )
        self._check_for_lone_node(resolved_nodes)
        self._substitute(resolved_node)
        return resolved_node

    def _get_resolvable(self, name: str) -> str:
        """Return the part of ``name`` before the first space."""
        return name.split(" ")[0]

    def _get_name(self, label: str) -> str:
        """Extract the layer name from a label.

        If the label contains ``(``, return the text after the first ``(``
        (up to a second ``(``, if any) with every ``)`` removed. Otherwise
        return the label unchanged.
        """
        if "(" in label:
            return label.split("(")[1].replace(")", "")
        return label

    def create_enriched_node(
        self,
        resolved_nodes: Dict[str, EnrichedNetworkNode],
        preds: List[str],
        layer_def: LayerDefinition,
        name: str,
    ) -> EnrichedNetworkNode:
        """Create the EnrichedNetworkNode for one recorded graph node.

        Args:
            resolved_nodes: a dictionary mapping node ids to their already
                created EnrichedNetworkNode instances; it must contain every
                id in ``preds``.
            preds: the ids of the direct predecessors.
            layer_def: the layer definition instance for this node.
            name: the name of the node, used as its id.
        Returns:
            The EnrichedNetworkNode instance of the same node.
        Raises:
            KeyError: if a predecessor id is missing from ``resolved_nodes``.
        """
        pred_nodes: List[EnrichedNetworkNode] = [resolved_nodes[p] for p in preds]
        return EnrichedNetworkNode(
            name=name,
            layer_info=layer_def,
            predecessors=pred_nodes,
            receptive_field_info_filter=self.filter_rf,
        )