diff --git a/rfa_toolbox/utils/graph_utils.py b/rfa_toolbox/utils/graph_utils.py index 8e63e25..c346e7f 100644 --- a/rfa_toolbox/utils/graph_utils.py +++ b/rfa_toolbox/utils/graph_utils.py @@ -106,7 +106,16 @@ def filters_non_convolutional_node( are treated like a convolutional layer with a kernel and stride size of 1. """ - return [node for node in nodes if node.layer_info.kernel_size != np.inf] + result = [] + for node in nodes: + if isinstance(node.receptive_field_min, Sequence) or isinstance( + node.receptive_field_min, tuple + ): + if not np.any(np.isinf(node.receptive_field_min)): + result.append(node) + elif node.receptive_field_min != np.inf: + result.append(node) + return result def input_resolution_range(