fix: kernel size one can be ignored. Keras now recognizes depthwise-separable convolutions correctly as convolutions
MLRichter committed Feb 5, 2022
1 parent ce3aeb4 commit 6b06852
Showing 6 changed files with 95 additions and 16 deletions.
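The relaxed check in KernelBasedHandler (below) is what makes this work: a depthwise convolution in Keras carries a kernel_size but, to my knowledge, no "filters" entry in its serialized config, so a handler requiring both keys never matched it. A minimal sketch of that difference, assuming standard tf.keras layers (the comparison is illustrative and not taken from this commit):

from tensorflow.keras import layers

# Depthwise convolutions use a depth_multiplier instead of a filter count, so
# "filters" is presumably absent from their get_config() output.
for layer in (layers.DepthwiseConv2D(kernel_size=3), layers.Conv2D(filters=64, kernel_size=3)):
    config = layer.get_config()
    old_rule = "kernel_size" in config and "filters" in config  # check before this commit
    new_rule = "kernel_size" in config                          # check after this commit
    print(layer.__class__.__name__, old_rule, new_rule)
# Expected: DepthwiseConv2D passes only the new rule; Conv2D passes both.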
5 changes: 5 additions & 0 deletions rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py
@@ -6,6 +6,7 @@
from rfa_toolbox.encodings.tensorflow_keras.layer_handlers import (
AnyHandler,
DenseHandler,
GlobalPoolingHandler,
InputHandler,
KernelBasedHandler,
PoolingBasedHandler,
@@ -22,6 +23,7 @@
KernelBasedHandler(),
PoolingBasedHandler(),
DenseHandler(),
GlobalPoolingHandler(),
AnyHandler(),
]

@@ -174,6 +176,9 @@ def create_graph_from_model(
By default, no filtering is done.
"""
model_dict = keras_model_to_dict(model)
from pprint import pprint

pprint(model_dict)
callable_filter = (
filter_rf
if (not isinstance(filter_rf, str) and filter_rf is not None)
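Registering GlobalPoolingHandler ahead of the catch-all AnyHandler matters because the handler list above appears to be consumed as a first-match chain: each node is presumably passed down the list until some handler's can_handle returns True. A rough sketch of that dispatch pattern (the function name resolve_node is made up for illustration and is not part of the toolbox API):

from typing import Any, Dict, List

def resolve_node(node: Dict[str, Any], handlers: List[Any]) -> Any:
    # First handler that claims the node wins; AnyHandler at the end acts as the fallback.
    for handler in handlers:
        if handler.can_handle(node):
            return handler(node)
    raise ValueError(f"no handler for {node.get('class_name')}")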
18 changes: 16 additions & 2 deletions rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py
@@ -41,7 +41,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
class KernelBasedHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Handles only layers featuring a kernel_size and filters"""
return "kernel_size" in node["config"] and "filters" in node["config"]
return "kernel_size" in node["config"]

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
"""Transform the json-representation
@@ -51,11 +51,12 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
f"{'x'.join(str(x) for x in node['config']['kernel_size'])} "
f"/ {node['config']['strides']}"
)
filters = None if "filters" not in node["config"] else node["config"]["filters"]
return LayerDefinition(
name=name,
kernel_size=node["config"]["kernel_size"],
stride_size=node["config"]["strides"],
filters=node["config"]["filters"],
filters=filters,
)


@@ -93,6 +94,19 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
return LayerDefinition(name=name, units=node["config"]["units"])


@attrs(frozen=True, slots=True, auto_attribs=True)
class GlobalPoolingHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
"""Handles only layers feature units as attribute"""
return "Global" in node["class_name"] and "Pooling" in node["class_name"]

def __call__(self, node: Dict[str, Any]) -> LayerDefinition:
"""Transform the json-representation
of a compute node in the tensorflow-graph"""
name = node["class_name"]
return LayerDefinition(name=name)


@attrs(frozen=True, slots=True, auto_attribs=True)
class InputHandler(LayerInfoHandler):
def can_handle(self, node: Dict[str, Any]) -> bool:
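The new GlobalPoolingHandler keys purely on the Keras class name, so it should claim layers such as GlobalAveragePooling2D or GlobalMaxPooling1D while plain pooling layers still fall through to PoolingBasedHandler. A quick illustration of that matching rule (layer names chosen only for the example):

for class_name in ("GlobalAveragePooling2D", "GlobalMaxPooling1D", "MaxPooling2D", "Dense"):
    print(class_name, "Global" in class_name and "Pooling" in class_name)
# GlobalAveragePooling2D True, GlobalMaxPooling1D True, MaxPooling2D False, Dense False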
18 changes: 17 additions & 1 deletion rfa_toolbox/graphs.py
@@ -439,6 +439,7 @@ def is_border(
receptive_field_provider: Callable[
["EnrichedNetworkNode"], Union[float, int]
] = receptive_field_provider,
filter_kernel_size_1: bool = False,
) -> bool:
"""Checks if this layer is a border layer.
A border layer is predicted not to advance the
@@ -455,6 +456,8 @@
receptive field sizes, which is currently
the most reliable way of predicting
unproductive layers.
filter_kernel_size_1: if set, a layer with a kernel size of 1 is
never considered a border layer.
Returns:
True if this layer is predicted to be unproductive
@@ -487,7 +490,20 @@ def is_border(
# the layer itself and all following layers have a receptive field size
# GREATER than the input resolution
# return all(direct_predecessors) and own and all(successors)
return all(direct_predecessors) and own # and all(successors)
can_be_border = not (
filter_kernel_size_1
and (
(
isinstance(self.layer_info.kernel_size, Sequence)
and all(np.asarray(self.layer_info.kernel_size) == 1)
)
or (
isinstance(self.layer_info.kernel_size, int)
and self.kernel_size == 1
)
)
)
return all(direct_predecessors) and own and can_be_border # and all(successors)

def is_in(self, container: Union[List[Node], Dict[Node, Any]]) -> bool:
"""Checks if this node is inside a an iterable collection.
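The can_be_border guard added to is_border only vetoes a border verdict when filter_kernel_size_1 is set and every kernel dimension equals 1. The same predicate, pulled out as a standalone sketch (a free function written for illustration; it mirrors, rather than reuses, the code in graphs.py):

from collections.abc import Sequence
from typing import Tuple, Union

import numpy as np

def vetoed_by_kernel_size_1(kernel_size: Union[int, Tuple[int, ...]],
                            filter_kernel_size_1: bool) -> bool:
    """True when a kernel size of 1 should prevent the border classification."""
    if not filter_kernel_size_1:
        return False
    if isinstance(kernel_size, (Sequence, np.ndarray)):
        return bool(np.all(np.asarray(kernel_size) == 1))
    return kernel_size == 1

# vetoed_by_kernel_size_1((1, 1), True)  -> True   (1x1 conv is never a border layer)
# vetoed_by_kernel_size_1((3, 3), True)  -> False
# vetoed_by_kernel_size_1((1, 1), False) -> False  (flag disabled, behaviour unchanged)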
44 changes: 39 additions & 5 deletions rfa_toolbox/utils/graph_utils.py
@@ -145,9 +145,33 @@ def filters_non_infinite_rf_sizes(
return result


def _x1_larger_x2(
x1: Union[Tuple[int, ...], int], x2: Union[Tuple[int, ...], int]
) -> bool:
"""Compare two receptive field sizes.
Args:
x1: the first receptive field size
x2: the second receptive field size
Returns:
True if the first receptive field size is smaller than the second.
"""
if (
isinstance(x1, Sequence)
or isinstance(x1, np.ndarray)
or isinstance(x2, Sequence)
or isinstance(x2, np.ndarray)
):
return all(np.asarray(x1) < np.asarray(x2))
else:
return x1 < x2


def input_resolution_range(
graph: EnrichedNetworkNode,
filter_all_inf_rf: bool = False,
filter_kernel_size_1: bool = False,
cardinality: int = 2,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""Obtain the smallest and largest feasible input resolution.
@@ -182,8 +206,10 @@ def input_resolution_range(
"""
all_nodes = obtain_all_nodes(graph)
all_nodes = filters_non_convolutional_node(all_nodes)
if filter_kernel_size_1:
all_nodes = [node for node in all_nodes if node.kernel_size > 1]
if not filter_all_inf_rf:
rf_min = [x.receptive_field_min for x in all_nodes]
rf_min = [xi.receptive_field_min for x in all_nodes for xi in x.predecessors]
rf_max = [x.receptive_field_max for x in all_nodes]
else:
rf_min = [
@@ -199,7 +225,11 @@ def find_max(rf: List[Union[Tuple[int, ...], int]], axis: int = 0) -> int:
for x in all_nodes
]

def find_max(rf: List[Union[Tuple[int, ...], int]], axis: int = 0) -> int:
def find_max(
rf: List[Union[Tuple[int, ...], int]],
axis: int = 0,
second_largest: bool = False,
) -> int:
"""Find the maximum value of a list of tuples or integers.
Args:
@@ -209,11 +239,15 @@ def find_max(rf: List[Union[Tuple[int, ...], int]], axis: int = 0) -> int:
Returns:
The maximum value of the list.
"""
rf_no_tuples = [
rf_no_tuples = {
x[axis] if isinstance(x, Sequence) or isinstance(x, np.ndarray) else x
for x in rf
]
return max(rf_no_tuples)
}
if not second_largest:
return max(rf_no_tuples)
else:
rf_no_tuples.remove(max(rf_no_tuples))
return max(rf_no_tuples)

r_max = tuple(find_max(rf_max, i) for i in range(cardinality))
r_min = tuple(find_max(rf_min, i) for i in range(cardinality))
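Two behavioural details of the reworked find_max are easy to miss: building rf_no_tuples as a set collapses duplicate receptive-field values before the maximum is taken, and second_largest=True then returns the largest remaining value after discarding that maximum. A small sketch of just that logic (values are made up; it assumes at least two distinct values when second_largest is requested):

def find_max_sketch(values, second_largest=False):
    unique = set(values)          # duplicates collapse, mirroring the set comprehension
    if not second_largest:
        return max(unique)
    unique.remove(max(unique))    # drop the maximum once, not once per duplicate
    return max(unique)

print(find_max_sketch([13, 13, 11]))                       # 13
print(find_max_sketch([13, 13, 11], second_largest=True))  # 11 (not 13)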
18 changes: 15 additions & 3 deletions rfa_toolbox/vizualize.py
@@ -25,6 +25,7 @@ def visualize_node(
color_border: bool,
color_critical: bool,
include_rf_info: bool = True,
filter_kernel_size_1: bool = False,
) -> None:
"""Create a node in a graphviz-graph based on an EnrichedNetworkNode instance.
Also creates all edges that lead from predecessor nodes to this node.
@@ -44,18 +45,27 @@
"""
color = "white"
if node.is_border(input_resolution=input_res) and color_border:
if (
node.is_border(
input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1
)
and color_border
):
color = "red"
elif (
np.all(np.asarray(node.receptive_field_min) > np.asarray(input_res))
and color_critical
and not node.is_border(input_resolution=input_res)
and not node.is_border(
input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1
)
):
color = "orange"
elif (
np.any(np.asarray(node.receptive_field_min) > np.asarray(input_res))
and color_critical
and not node.is_border(input_resolution=input_res)
and not node.is_border(
input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1
)
):
color = "yellow"
l_name = node.layer_info.name
@@ -91,6 +101,7 @@ def visualize_architecture(
color_critical: bool = True,
color_border: bool = True,
include_rf_info: bool = True,
filter_kernel_size_1: bool = False,
) -> graphviz.Digraph:
"""Visualize an architecture using graphviz
and mark critical and border layers in the graph visualization.
@@ -124,6 +135,7 @@
color_border=color_border,
color_critical=color_critical,
include_rf_info=include_rf_info,
filter_kernel_size_1=filter_kernel_size_1,
)
return f

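With the flag threaded through visualize_node and visualize_architecture, a caller can presumably keep 1x1 convolutions from being flagged red as border layers when rendering a model. A usage sketch under the assumption that visualize_architecture also accepts model_name and input_res arguments (as in the project's README) and that the top-level imports below exist; only filter_kernel_size_1 is confirmed by this diff:

import torchvision
from rfa_toolbox import create_graph_from_pytorch_model, visualize_architecture

model = torchvision.models.resnet50()
graph = create_graph_from_pytorch_model(model)
# Suppress border markings for 1x1 convolutions in the rendered graph.
dot = visualize_architecture(graph, model_name="resnet50", input_res=224,
                             filter_kernel_size_1=True)
dot.render("resnet50_receptive_fields")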
8 changes: 3 additions & 5 deletions tests/test_graph/test_utils.py
@@ -336,7 +336,7 @@ def test_with_scalar_receptive_field_sizes(self, sequential_network):
r_min, r_max = input_resolution_range(sequential_network)
assert len(r_max) == 2
assert len(r_min) == 2
assert r_min == (13, 13)
assert r_min == (11, 11)
assert r_max == (13, 13)

def test_with_non_sequential(self, nonsequential_network2):
@@ -350,17 +350,15 @@ def test_with_non_square_receptive_field_sizes(self, sequential_network_non_square):
r_min, r_max = input_resolution_range(sequential_network_non_square)
assert len(r_max) == 2
assert len(r_min) == 2
assert r_min == (13, 25)
assert r_min == (11, 21)
assert r_max == (13, 25)

def test_with_non_square_receptive_field_sizes_without_se(
self, sequential_network_non_square
):
model = torchvision.models.resnet50()
graph = create_graph_from_pytorch_model(model)
min_res, max_res = input_resolution_range(
graph, filter_all_inf_rf=False
) # (75, 75), (427, 427)
min_res, max_res = input_resolution_range(graph) # (75, 75), (427, 427)
assert len(min_res) == 2
assert len(max_res) == 2
assert min_res == (75, 75)
