Skip to content

Commit

Permalink
fix: added the ability to handle custom layers
Browse files Browse the repository at this point in the history
  • Loading branch information
MLRichter committed Jan 27, 2022
1 parent 5b17670 commit 1a8a3db
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
58 changes: 43 additions & 15 deletions rfa_toolbox/encodings/pytorch/ingest_architecture.py
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import List, Optional, Tuple

from rfa_toolbox.graphs import EnrichedNetworkNode

Expand All @@ -19,6 +19,7 @@ def make_graph(
input_preds=None,
parent_dot=None,
ref_mod=None,
classes_to_not_visit=None,
):
"""
This code was adapted from this blog article:
Expand Down Expand Up @@ -145,20 +146,38 @@ def is_relevant_type(t):
elem = "[" + elem + "]."
pr += elem

if classes_found is not None:
classes_found.add(fq_submodule_name)
if (
classes_to_visit is None
and (
not fq_submodule_name.startswith("torch.nn")
or fq_submodule_name.startswith("torch.nn.modules.container")
def _check_white_list(submodule_type, fq_submodule_name, classes_to_visit):
return (
classes_to_visit is None
and (
not fq_submodule_name.startswith("torch.nn")
or fq_submodule_name.startswith("torch.nn.modules.container")
)
) or (
classes_to_visit is not None
and (
submodule_type in classes_to_visit
or fq_submodule_name in classes_to_visit
)
)
) or (
classes_to_visit is not None
and (
submodule_type in classes_to_visit
or fq_submodule_name in classes_to_visit

def _check_black_list(
submodule_type, fq_submodule_name, classes_to_not_visit
):
return (classes_to_not_visit is None) or (
classes_to_not_visit is not None
and (
submodule_type not in classes_to_not_visit
or fq_submodule_name not in classes_to_not_visit
)
)

if classes_found is not None:
classes_found.add(fq_submodule_name)
if _check_white_list(
submodule_type, fq_submodule_name, classes_to_visit
) and _check_black_list(
submodule_type, fq_submodule_name, classes_to_not_visit
):
# go into subgraph
sub_prefix = prefix + submodule_name + "."
Expand All @@ -177,6 +196,7 @@ def is_relevant_type(t):
parent_dot=dot,
classes_to_visit=classes_to_visit,
classes_found=classes_found,
classes_to_not_visit=classes_to_not_visit,
)
# creating a mapping from the c-values
# to the output of the respective subgraph
Expand Down Expand Up @@ -272,7 +292,9 @@ def is_relevant_type(t):


def create_graph_from_model(
model: torch.nn.Module, input_res: Tuple[int, int, int, int] = (1, 3, 399, 399)
model: torch.nn.Module,
input_res: Tuple[int, int, int, int] = (1, 3, 399, 399),
custom_layers: Optional[List[str]] = None,
) -> EnrichedNetworkNode:
"""Create a graph of enriched network nodes from a PyTorch-Model.
Expand All @@ -283,9 +305,15 @@ def create_graph_from_model(
color_channels, height, width) for CNNs.
Needs to be a 2-Tuple of shape (batch_size,
num_features) for fully connected networks.
custom_layers: Class-names of custom layers, like DropPath
or Involutions, which are not part of
torch.nn. Keep in mind that unknown layers
will defaulted to have no effect on the
receptive field size. You may need to
implement some additional layer handlers.
Returns:
The EnrichedNetworkNodeGraph
"""
tm = torch.jit.trace(model, (torch.randn(*input_res),))
return make_graph(tm, ref_mod=model).to_graph()
return make_graph(tm, ref_mod=model, classes_to_not_visit=custom_layers).to_graph()
9 changes: 5 additions & 4 deletions rfa_toolbox/encodings/pytorch/intermediate_graph.py
Expand Up @@ -160,10 +160,11 @@ def _check_for_lone_node(self, resolved_nodes: Dict[str, EnrichedNetworkNode]):
for name, node in resolved_nodes.items():
if len(node.predecessors) == 0 and len(node.succecessors) == 0:
warnings.warn(
f"Found a node with no predecessors and no successors: {name},"
f"this may be caused by some control-flow in "
f"this node disabling any processing"
f"within the node."
f"Found a node with no predecessors and no successors: "
f"'{node.layer_info.name}',"
f" this may be caused by some control-flow in "
f" this node disabling any processing"
f" within the node."
)

def to_graph(self) -> EnrichedNetworkNode:
Expand Down

0 comments on commit 1a8a3db

Please sign in to comment.