/
layer_handlers.py
117 lines (92 loc) · 4.03 KB
/
layer_handlers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import Any, Dict

try:
    from typing import Protocol
except ImportError:  # typing.Protocol is unavailable on older Pythons
    from typing_extensions import Protocol

from attr import attrs

from rfa_toolbox.graphs import LayerDefinition
class LayerInfoHandler(Protocol):
    """Structural interface for converting one serialized layer node into a
    LayerDefinition.

    A node is the json-representation of a single compute node of a
    tensorflow-graph (the handlers in this module read ``node["class_name"]``
    and ``node["config"]``). A handler is first asked via ``can_handle``
    whether it can process a node; if so, calling the handler with that node
    produces the corresponding LayerDefinition.
    """

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Checks if this handler can process the
        node in the compute graph of the model.

        Args:
            node: the serialized compute node in question

        Returns:
            True if the node can be processed into a
            valid LayerDefinition by this handler.
        """
        ...

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Transform the json-representation of a compute node in the
        tensorflow-graph into a LayerDefinition.

        Args:
            node: the serialized compute node in question; callers are
                expected to have checked ``can_handle(node)`` first.

        Returns:
            A LayerDefinition that reflects the properties of the layer.
        """
        ...
@attrs(frozen=True, slots=True, auto_attribs=True)
class KernelBasedHandler(LayerInfoHandler):
    """Handles kernel-based layers: nodes whose config carries both a
    ``kernel_size`` and a ``filters`` entry."""

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Handles only layers featuring a kernel_size and filters"""
        config = node["config"]
        return "kernel_size" in config and "filters" in config

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Transform the json-representation of a compute node in the
        tensorflow-graph into a LayerDefinition.

        Args:
            node: the serialized layer; ``node["class_name"]`` and
                ``node["config"]["kernel_size"]``, ``["strides"]`` and
                ``["filters"]`` are read.

        Returns:
            A LayerDefinition named ``"<class_name> <k1>x<k2>... / <strides>"``
            that forwards kernel size, strides and filters unchanged.
        """
        config = node["config"]
        kernel_size = config["kernel_size"]
        strides = config["strides"]
        # NOTE(review): ``dilation_rate`` (if present in the config) is not
        # considered; a dilated kernel spans a wider input window than
        # ``kernel_size`` suggests. Confirm whether LayerDefinition should
        # receive an effective kernel size for dilated layers.
        name = (
            f"{node['class_name']} "
            f"{self._format_size(kernel_size)} "
            f"/ {strides}"
        )
        return LayerDefinition(
            name=name,
            kernel_size=kernel_size,
            stride_size=strides,
            filters=config["filters"],
        )

    @staticmethod
    def _format_size(size: Any) -> str:
        """Render a per-dimension size as ``"3x3"``; a bare int renders as
        ``"3"`` instead of failing on iteration."""
        if isinstance(size, int):
            return str(size)
        return "x".join(str(dim) for dim in size)
@attrs(frozen=True, slots=True, auto_attribs=True)
class PoolingBasedHandler(LayerInfoHandler):
    """Handles pooling layers: nodes whose config carries a ``pool_size``."""

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Handles only layers featuring a pool_size"""
        return "pool_size" in node["config"]

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Transform the json-representation of a pooling node in the
        tensorflow-graph into a LayerDefinition.

        The pooling window is forwarded as ``kernel_size`` and the stride as
        ``stride_size``.

        Args:
            node: the serialized layer; ``node["class_name"]`` and
                ``node["config"]["pool_size"]`` are required,
                ``node["config"]["strides"]`` is optional.

        Returns:
            A LayerDefinition named ``"<class_name> <p1>x<p2>... / <strides>"``.
        """
        config = node["config"]
        pool_size = config["pool_size"]
        # Keras defines ``strides=None`` for pooling layers as "default to
        # pool_size"; apply the same rule so a None/missing stride is neither
        # printed into the name nor forwarded as stride_size.
        strides = config.get("strides")
        if strides is None:
            strides = pool_size
        name = (
            f"{node['class_name']} "
            f"{self._format_size(pool_size)} "
            f"/ {strides}"
        )
        return LayerDefinition(
            name=name,
            kernel_size=pool_size,
            stride_size=strides,
        )

    @staticmethod
    def _format_size(size: Any) -> str:
        """Render a per-dimension size as ``"2x2"``; a bare int renders as
        ``"2"`` instead of failing on iteration."""
        if isinstance(size, int):
            return str(size)
        return "x".join(str(dim) for dim in size)
@attrs(frozen=True, slots=True, auto_attribs=True)
class DenseHandler(LayerInfoHandler):
    """Maps nodes whose config declares ``units`` to a LayerDefinition that
    records only the class name and the unit count.

    NOTE(review): every layer exposing a ``units`` config entry matches here,
    not only Dense layers (recurrent layers, for example, may also carry
    ``units``) — confirm this is intended.
    """

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Accept the node when its config contains a ``units`` entry."""
        layer_config = node["config"]
        return "units" in layer_config

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Build a LayerDefinition named after the node's class, carrying
        ``config["units"]``; no kernel size or stride is supplied."""
        layer_class = node["class_name"]
        unit_count = node["config"]["units"]
        return LayerDefinition(name=layer_class, units=unit_count)
@attrs(frozen=True, slots=True, auto_attribs=True)
class InputHandler(LayerInfoHandler):
    """Dedicated handler for the model's ``InputLayer`` nodes."""

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Accept a node only if its class name is exactly ``InputLayer``."""
        layer_class = node["class_name"]
        return layer_class == "InputLayer"

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Describe the input node as a layer with kernel size 1 and
        stride 1, named after the node's class."""
        return LayerDefinition(
            name=node["class_name"],
            kernel_size=1,
            stride_size=1,
        )
@attrs(frozen=True, slots=True, auto_attribs=True)
class AnyHandler(LayerInfoHandler):
    """Fallback handler that accepts every node.

    Since it never rejects a node, it presumably belongs after all more
    specific handlers in whatever list the caller consults — confirm with
    the caller.
    """

    def can_handle(self, node: Dict[str, Any]) -> bool:
        """Accept ``node`` unconditionally; its contents are not inspected."""
        return True

    def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
        """Describe an otherwise unhandled node as a layer with kernel size 1
        and stride 1, named after the node's class."""
        layer_class = node["class_name"]
        return LayerDefinition(name=layer_class, kernel_size=1, stride_size=1)