From 6b068522e2b23d7110ac1e886bcd441a7baf4fbd Mon Sep 17 00:00:00 2001 From: MLRichter Date: Sat, 5 Feb 2022 20:14:15 +0100 Subject: [PATCH] fix: kernel size one can be ignored. Keras now recognizes DepthWiseSeparable convolutions correctly as convolutions --- .../tensorflow_keras/ingest_architecture.py | 5 +++ .../tensorflow_keras/layer_handlers.py | 18 +++++++- rfa_toolbox/graphs.py | 18 +++++++- rfa_toolbox/utils/graph_utils.py | 44 ++++++++++++++++--- rfa_toolbox/vizualize.py | 18 ++++++-- tests/test_graph/test_utils.py | 8 ++-- 6 files changed, 95 insertions(+), 16 deletions(-) diff --git a/rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py b/rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py index daecf3c..7211437 100644 --- a/rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py +++ b/rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py @@ -6,6 +6,7 @@ from rfa_toolbox.encodings.tensorflow_keras.layer_handlers import ( AnyHandler, DenseHandler, + GlobalPoolingHandler, InputHandler, KernelBasedHandler, PoolingBasedHandler, @@ -22,6 +23,7 @@ KernelBasedHandler(), PoolingBasedHandler(), DenseHandler(), + GlobalPoolingHandler(), AnyHandler(), ] @@ -174,6 +176,9 @@ def create_graph_from_model( By default, not filtering is done. 
""" model_dict = keras_model_to_dict(model) + from pprint import pprint + + pprint(model_dict) callable_filter = ( filter_rf if (not isinstance(filter_rf, str) and filter_rf is not None) diff --git a/rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py b/rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py index d596379..93b9e81 100644 --- a/rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py +++ b/rfa_toolbox/encodings/tensorflow_keras/layer_handlers.py @@ -41,7 +41,7 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition: class KernelBasedHandler(LayerInfoHandler): def can_handle(self, node: Dict[str, Any]) -> bool: """Handles only layers featuring a kernel_size and filters""" - return "kernel_size" in node["config"] and "filters" in node["config"] + return "kernel_size" in node["config"] def __call__(self, node: Dict[str, Any]) -> LayerDefinition: """Transform the json-representation @@ -51,11 +51,12 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition: f"{'x'.join(str(x) for x in node['config']['kernel_size'])} " f"/ {node['config']['strides']}" ) + filters = None if "filters" not in node["config"] else node["config"]["filters"] return LayerDefinition( name=name, kernel_size=node["config"]["kernel_size"], stride_size=node["config"]["strides"], - filters=node["config"]["filters"], + filters=filters, ) @@ -93,6 +94,19 @@ def __call__(self, node: Dict[str, Any]) -> LayerDefinition: return LayerDefinition(name=name, units=node["config"]["units"]) +@attrs(frozen=True, slots=True, auto_attribs=True) +class GlobalPoolingHandler(LayerInfoHandler): + def can_handle(self, node: Dict[str, Any]) -> bool: + """Handles only layers feature units as attribute""" + return "Global" in node["class_name"] and "Pooling" in node["class_name"] + + def __call__(self, node: Dict[str, Any]) -> LayerDefinition: + """Transform the json-representation + of a compute node in the tensorflow-graph""" + name = node["class_name"] + return 
LayerDefinition(name=name) + + @attrs(frozen=True, slots=True, auto_attribs=True) class InputHandler(LayerInfoHandler): def can_handle(self, node: Dict[str, Any]) -> bool: diff --git a/rfa_toolbox/graphs.py b/rfa_toolbox/graphs.py index 8e63e82..9610ea2 100644 --- a/rfa_toolbox/graphs.py +++ b/rfa_toolbox/graphs.py @@ -439,6 +439,7 @@ def is_border( receptive_field_provider: Callable[ ["EnrichedNetworkNode"], Union[float, int] ] = receptive_field_provider, + filter_kernel_size_1: bool = False, ) -> bool: """Checks if this layer is a border layer. A border layer is predicted not advance the @@ -455,6 +456,8 @@ def is_border( receptive field sizes, which is currently the most reliable way of predicting unproductive layers. + filter_kernel_size_1: any layer with a kernel size of 1 may not + be a border layer if this is set. Returns: True if this layer is predicted to be unproductive @@ -487,7 +490,20 @@ def is_border( # the layer itself and all following layers have a receptive field size # GREATER than the input resolution # return all(direct_predecessors) and own and all(successors) - return all(direct_predecessors) and own # and all(successors) + can_be_border = not ( + filter_kernel_size_1 + and ( + ( + isinstance(self.layer_info.kernel_size, Sequence) + and all(np.asarray(self.layer_info.kernel_size) == 1) + ) + or ( + isinstance(self.layer_info.kernel_size, int) + and self.kernel_size == 1 + ) + ) + ) + return all(direct_predecessors) and own and can_be_border # and all(successors) def is_in(self, container: Union[List[Node], Dict[Node, Any]]) -> bool: """Checks if this node is inside a an iterable collection. 
diff --git a/rfa_toolbox/utils/graph_utils.py b/rfa_toolbox/utils/graph_utils.py index 68a6c1b..e0571a3 100644 --- a/rfa_toolbox/utils/graph_utils.py +++ b/rfa_toolbox/utils/graph_utils.py @@ -145,9 +145,33 @@ def filters_non_infinite_rf_sizes( return result +def _x1_larger_x2( + x1: Union[Tuple[int, ...], int], x2: Union[Tuple[int, ...], int] +) -> bool: + """Compare two receptive field sizes. + + Args: + x1: the first receptive field size + x2: the second receptive field size + + Returns: + True if the first receptive field size is smaller than the second. + """ + if ( + isinstance(x1, Sequence) + or isinstance(x1, np.ndarray) + or isinstance(x2, Sequence) + or isinstance(x2, np.ndarray) + ): + return all(np.asarray(x1) < np.asarray(x2)) + else: + return x1 < x2 + + def input_resolution_range( graph: EnrichedNetworkNode, filter_all_inf_rf: bool = False, + filter_kernel_size_1: bool = False, cardinality: int = 2, ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """Obtain the smallest and largest feasible input resolution. @@ -182,8 +206,10 @@ def input_resolution_range( """ all_nodes = obtain_all_nodes(graph) all_nodes = filters_non_convolutional_node(all_nodes) + if filter_kernel_size_1: + all_nodes = [node for node in all_nodes if node.kernel_size > 1] if not filter_all_inf_rf: - rf_min = [x.receptive_field_min for x in all_nodes] + rf_min = [xi.receptive_field_min for x in all_nodes for xi in x.predecessors] rf_max = [x.receptive_field_max for x in all_nodes] else: rf_min = [ @@ -199,7 +225,11 @@ def input_resolution_range( for x in all_nodes ] - def find_max(rf: List[Union[Tuple[int, ...], int]], axis: int = 0) -> int: + def find_max( + rf: List[Union[Tuple[int, ...], int]], + axis: int = 0, + second_largest: bool = False, + ) -> int: """Find the maximum value of a list of tuples or integers. Args: @@ -209,11 +239,15 @@ def find_max(rf: List[Union[Tuple[int, ...], int]], axis: int = 0) -> int: Returns: The maximum value of the list. 
""" - rf_no_tuples = [ + rf_no_tuples = { x[axis] if isinstance(x, Sequence) or isinstance(x, np.ndarray) else x for x in rf - ] - return max(rf_no_tuples) + } + if not second_largest: + return max(rf_no_tuples) + else: + rf_no_tuples.remove(max(rf_no_tuples)) + return max(rf_no_tuples) r_max = tuple(find_max(rf_max, i) for i in range(cardinality)) r_min = tuple(find_max(rf_min, i) for i in range(cardinality)) diff --git a/rfa_toolbox/vizualize.py b/rfa_toolbox/vizualize.py index 4619909..113b1c3 100644 --- a/rfa_toolbox/vizualize.py +++ b/rfa_toolbox/vizualize.py @@ -25,6 +25,7 @@ def visualize_node( color_border: bool, color_critical: bool, include_rf_info: bool = True, + filter_kernel_size_1: bool = False, ) -> None: """Create a node in a graphviz-graph based on an EnrichedNetworkNode instance. Also creates all edges that lead from predecessor nodes to this node. @@ -44,18 +45,27 @@ def visualize_node( """ color = "white" - if node.is_border(input_resolution=input_res) and color_border: + if ( + node.is_border( + input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1 + ) + and color_border + ): color = "red" elif ( np.all(np.asarray(node.receptive_field_min) > np.asarray(input_res)) and color_critical - and not node.is_border(input_resolution=input_res) + and not node.is_border( + input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1 + ) ): color = "orange" elif ( np.any(np.asarray(node.receptive_field_min) > np.asarray(input_res)) and color_critical - and not node.is_border(input_resolution=input_res) + and not node.is_border( + input_resolution=input_res, filter_kernel_size_1=filter_kernel_size_1 + ) ): color = "yellow" l_name = node.layer_info.name @@ -91,6 +101,7 @@ def visualize_architecture( color_critical: bool = True, color_border: bool = True, include_rf_info: bool = True, + filter_kernel_size_1: bool = False, ) -> graphviz.Digraph: """Visualize an architecture using graphviz and mark critical and border layers in the 
graph visualization. @@ -124,6 +135,7 @@ def visualize_architecture( color_border=color_border, color_critical=color_critical, include_rf_info=include_rf_info, + filter_kernel_size_1=filter_kernel_size_1, ) return f diff --git a/tests/test_graph/test_utils.py b/tests/test_graph/test_utils.py index 217ebdf..0b070b1 100644 --- a/tests/test_graph/test_utils.py +++ b/tests/test_graph/test_utils.py @@ -336,7 +336,7 @@ def test_with_scalar_receptive_field_sizes(self, sequential_network): r_min, r_max = input_resolution_range(sequential_network) assert len(r_max) == 2 assert len(r_min) == 2 - assert r_min == (13, 13) + assert r_min == (11, 11) assert r_max == (13, 13) def test_with_non_sequential(self, nonsequential_network2): @@ -350,7 +350,7 @@ def test_with_non_square_receptive_field_sizes(self, sequential_network_non_squa r_min, r_max = input_resolution_range(sequential_network_non_square) assert len(r_max) == 2 assert len(r_min) == 2 - assert r_min == (13, 25) + assert r_min == (11, 21) assert r_max == (13, 25) def test_with_non_square_receptive_field_sizes_without_se( @@ -358,9 +358,7 @@ def test_with_non_square_receptive_field_sizes_without_se( ): model = torchvision.models.resnet50() graph = create_graph_from_pytorch_model(model) - min_res, max_res = input_resolution_range( - graph, filter_all_inf_rf=False - ) # (75, 75), (427, 427) + min_res, max_res = input_resolution_range(graph) # (75, 75), (427, 427) assert len(min_res) == 2 assert len(max_res) == 2 assert min_res == (75, 75)