diff --git a/rfa_toolbox/encodings/pytorch/ingest_architecture.py b/rfa_toolbox/encodings/pytorch/ingest_architecture.py index d261cea..1e67cd3 100644 --- a/rfa_toolbox/encodings/pytorch/ingest_architecture.py +++ b/rfa_toolbox/encodings/pytorch/ingest_architecture.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import List, Optional, Tuple from rfa_toolbox.graphs import EnrichedNetworkNode @@ -19,6 +19,7 @@ def make_graph( input_preds=None, parent_dot=None, ref_mod=None, + classes_to_not_visit=None, ): """ This code was adapted from this blog article: @@ -145,20 +146,38 @@ def is_relevant_type(t): elem = "[" + elem + "]." pr += elem - if classes_found is not None: - classes_found.add(fq_submodule_name) - if ( - classes_to_visit is None - and ( - not fq_submodule_name.startswith("torch.nn") - or fq_submodule_name.startswith("torch.nn.modules.container") + def _check_white_list(submodule_type, fq_submodule_name, classes_to_visit): + return ( + classes_to_visit is None + and ( + not fq_submodule_name.startswith("torch.nn") + or fq_submodule_name.startswith("torch.nn.modules.container") + ) + ) or ( + classes_to_visit is not None + and ( + submodule_type in classes_to_visit + or fq_submodule_name in classes_to_visit + ) ) - ) or ( - classes_to_visit is not None - and ( - submodule_type in classes_to_visit - or fq_submodule_name in classes_to_visit + + def _check_black_list( + submodule_type, fq_submodule_name, classes_to_not_visit + ): + return (classes_to_not_visit is None) or ( + classes_to_not_visit is not None + and ( + submodule_type not in classes_to_not_visit + or fq_submodule_name not in classes_to_not_visit + ) ) + + if classes_found is not None: + classes_found.add(fq_submodule_name) + if _check_white_list( + submodule_type, fq_submodule_name, classes_to_visit + ) and _check_black_list( + submodule_type, fq_submodule_name, classes_to_not_visit ): # go into subgraph sub_prefix = prefix + submodule_name + "." @@ -177,6 +196,7 @@ def is_relevant_type(t): parent_dot=dot, classes_to_visit=classes_to_visit, classes_found=classes_found, + classes_to_not_visit=classes_to_not_visit, ) # creating a mapping from the c-values # to the output of the respective subgraph @@ -272,7 +292,9 @@ def is_relevant_type(t): def create_graph_from_model( - model: torch.nn.Module, input_res: Tuple[int, int, int, int] = (1, 3, 399, 399) + model: torch.nn.Module, + input_res: Tuple[int, int, int, int] = (1, 3, 399, 399), + custom_layers: Optional[List[str]] = None, ) -> EnrichedNetworkNode: """Create a graph of enriched network nodes from a PyTorch-Model. @@ -283,9 +305,15 @@ def create_graph_from_model( color_channels, height, width) for CNNs. Needs to be a 2-Tuple of shape (batch_size, num_features) for fully connected networks. + custom_layers: Class-names of custom layers, like DropPath + or Involutions, which are not part of + torch.nn. Keep in mind that unknown layers + will defaulted to have no effect on the + receptive field size. You may need to + implement some additional layer handlers. Returns: The EnrichedNetworkNodeGraph """ tm = torch.jit.trace(model, (torch.randn(*input_res),)) - return make_graph(tm, ref_mod=model).to_graph() + return make_graph(tm, ref_mod=model, classes_to_not_visit=custom_layers).to_graph() diff --git a/rfa_toolbox/encodings/pytorch/intermediate_graph.py b/rfa_toolbox/encodings/pytorch/intermediate_graph.py index 6cc5328..9caa371 100644 --- a/rfa_toolbox/encodings/pytorch/intermediate_graph.py +++ b/rfa_toolbox/encodings/pytorch/intermediate_graph.py @@ -160,10 +160,11 @@ def _check_for_lone_node(self, resolved_nodes: Dict[str, EnrichedNetworkNode]): for name, node in resolved_nodes.items(): if len(node.predecessors) == 0 and len(node.succecessors) == 0: warnings.warn( - f"Found a node with no predecessors and no successors: {name}," - f"this may be caused by some control-flow in " - f"this node disabling any processing" - f"within the node." + f"Found a node with no predecessors and no successors: " + f"'{node.layer_info.name}'," + f" this may be caused by some control-flow in " + f" this node disabling any processing" + f" within the node." ) def to_graph(self) -> EnrichedNetworkNode: