feat: added compatibility for PyTorch functional
MLRichter committed Jan 28, 2022
1 parent b7cb73e commit 8e68160
Showing 5 changed files with 129 additions and 63 deletions.
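
In short: this commit makes pooling layers called through PyTorch's functional API parseable directly from the traced TorchScript graph, including their actual kernel and stride sizes, where previously such calls raised a RuntimeError unless coercion was toggled on and placeholder sizes were assumed. A rough usage sketch of the new behavior (the toy module is illustrative; create_graph_from_pytorch_model is assumed to be the project's README-level entry point):

    import torch
    import torch.nn.functional as F
    from rfa_toolbox import create_graph_from_pytorch_model

    class ToyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            x = self.conv(x)
            # Functional pooling call: before this commit it required
            # toggle_coerce_torch_functional(True) and fell back to assumed
            # kernel/stride values; now both sizes are extracted from the graph.
            return F.max_pool2d(x, kernel_size=2, stride=2)

    graph = create_graph_from_pytorch_model(ToyNet())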
78 changes: 77 additions & 1 deletion rfa_toolbox/encodings/pytorch/ingest_architecture.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from rfa_toolbox.graphs import EnrichedNetworkNode

@@ -35,6 +35,62 @@ def _check_black_list(submodule_type, fq_submodule_name, classes_to_not_visit):
)


def _obtain_variable_names(graph: torch._C.Graph) -> Dict[str, str]:
result = {}
for node in graph.nodes():
x, y = str(node).split(" : ")
key, value = x, y
result[key] = value
return result


def _obtain_node_key(node: str) -> str:
return str(node).split(" : ")[0]


def _obtain_variable_from_node_string_atten(
line: str, graph_row_dict: Dict[str, str]
) -> List[str]:
row = graph_row_dict[line]
layer_call = row.split("aten::")[-1]
variable_pruned_left = layer_call.split("(")[1]
variable_pruned_right = variable_pruned_left.split("),")[0]
variable_name = variable_pruned_right.split(",")
return [x.replace(" ", "") for x in variable_name]


def _resolve_variable(
variable: str, graph_row_dict: Dict[str, str]
) -> Union[List[int], int]:
if "prim::ListConstruct" in variable:
variable_pruned_left = variable.split("prim::ListConstruct(")[1]
variable_pruned_right = variable_pruned_left.split("),")[0]
variable_name = variable_pruned_right.split(",")
result = [
_resolve_variable(graph_row_dict[x.replace(" ", "")], graph_row_dict)
for x in variable_name
]
elif "prim::Constant" in variable and "prim::Constant()" not in variable:
variable_pruned_left = variable.split("prim::Constant[value=")[1]
variable_pruned_right = variable_pruned_left.split("]")[0]
variable = int(variable_pruned_right)
result = variable
else:
result = None
return result


def _resolve_variables(variables: List[str], graph_row_dict: Dict[str, str]):
result = []
for var in variables:
if var in graph_row_dict:
var_val = _resolve_variable(graph_row_dict[var], graph_row_dict)
result.append(var_val)
else:
result.append(None)
return result


def make_graph(
mod,
classes_to_visit=None,
@@ -129,6 +185,8 @@ def is_relevant_type(t):
return any([is_relevant_type(tt) for tt in t.elements()])
return False

variable_names = _obtain_variable_names(graph=gr)

for n in gr.nodes():
# this seems to be uninteresting for resnet-style models
only_first_ops = {"aten::expand_as"}
@@ -224,6 +282,24 @@ def is_relevant_type(t):
make_edges(pr, prefix + i.debugName(), name, op)
for o in n.outputs():
preds[o] = {name}, set()
elif "aten::" in n.kind() and "pool" in n.kind() and "adaptive" not in n.kind():
name = prefix + "." + n.output().debugName()
label = n.kind().split("::")[-1]
key = _obtain_node_key(str(n))
vars = _obtain_variable_from_node_string_atten(key, variable_names)
resolved = _resolve_variables(vars, variable_names)
dot.node(
name,
label=label,
shape="box",
kernel_size=resolved[1],
stride_size=resolved[2],
)
for i in relevant_inputs:
pr, op = preds[i]
make_edges(pr, prefix + i.debugName(), name, op)
for o in n.outputs():
preds[o] = {name}, set()
else:
unseen_ops = {
"prim::ListConstruct",
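
The new helpers above operate on the textual form of TorchScript graph nodes: _obtain_variable_names maps each node's variable name to the remainder of its row, and _resolve_variable turns prim::Constant and prim::ListConstruct rows into Python ints and lists. A self-contained sketch of that resolution idea, using hand-written strings that approximate the IR of a traced pooling call (the exact format may differ):

    # Hand-written approximations of TorchScript IR rows, keyed the way
    # _obtain_variable_names keys them (variable name -> rest of the row).
    graph_rows = {
        "%4": "int = prim::Constant[value=2]()",
        "%5": "int[] = prim::ListConstruct(%4, %4), scope: __module.pool",
    }

    def resolve(row, rows):
        # prim::ListConstruct(...) -> resolve each referenced variable in turn;
        # prim::Constant[value=N] -> the literal N; anything else -> None.
        if "prim::ListConstruct" in row:
            inner = row.split("prim::ListConstruct(")[1].split("),")[0]
            return [resolve(rows[v.strip()], rows) for v in inner.split(",")]
        if "prim::Constant[value=" in row:
            return int(row.split("prim::Constant[value=")[1].split("]")[0])
        return None

    print(resolve(graph_rows["%5"], graph_rows))  # -> [2, 2]

In the pooling branch above, the resolved argument list is then used positionally: resolved[1] supplies kernel_size and resolved[2] supplies stride_size.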
23 changes: 19 additions & 4 deletions rfa_toolbox/encodings/pytorch/intermediate_graph.py
@@ -1,5 +1,5 @@
import warnings
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
from attr import attrib, attrs
@@ -63,13 +63,22 @@ class Digraph:
def _find_predecessors(self, name: str) -> List[str]:
return [e[0] for e in self.edge_collection if e[1] == name]

def _get_layer_definition(self, label: str) -> LayerDefinition:
def _get_layer_definition(
self,
label: str,
kernel_size: Optional[Union[Tuple[int, ...], int]] = None,
stride_size: Optional[Union[Tuple[int, ...], int]] = None,
) -> LayerDefinition:
resolvable = self._get_resolvable(label)
name = self._get_name(label)
for handler in self.layer_info_handlers:
if handler.can_handle(label):
return handler(
model=self.ref_mod, resolvable_string=resolvable, name=name
model=self.ref_mod,
resolvable_string=resolvable,
name=name,
kernel_size=kernel_size,
stride_size=stride_size,
)
raise ValueError(f"Did not find a way to handle the following layer: {name}")

@@ -98,6 +107,10 @@ def node(
label: Optional[str] = None,
shape: str = "box",
style: Optional[str] = None,
kernel_size: Optional[Union[Tuple[int, ...], int]] = None,
stride_size: Optional[Union[Tuple[int, ...], int]] = None,
units: Optional[int] = None,
filters: Optional[int] = None,
) -> None:
"""Creates a node in the digraph-instance.
@@ -113,7 +126,9 @@ def node(
Nothing.
"""
label = name if label is None else label
layer_definition = self._get_layer_definition(label)
layer_definition = self._get_layer_definition(
label, kernel_size=kernel_size, stride_size=stride_size
)
self.layer_definitions[name] = layer_definition

def subgraph(self, name: str) -> GraphVizDigraph:
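
Digraph.node now forwards kernel_size and stride_size to whichever handler matches the label, which is why every handler's __call__ in the next file gains **kwargs: module-based handlers can simply ignore keyword arguments that only functional calls provide. A toy illustration of that design choice (stand-in functions, not the library's classes):

    # Module-based handlers ignore the extra keywords; the functional handler
    # consumes them. This is the reason **kwargs was added to every signature.
    def linear_handler(model, resolvable_string, name, **kwargs):
        return {"name": name, "kernel_size": None, "stride_size": 1}

    def functional_pool_handler(model, resolvable_string, name, **kwargs):
        return {"name": name, "kernel_size": kwargs["kernel_size"],
                "stride_size": kwargs["stride_size"]}

    common = dict(model=None, resolvable_string="",
                  kernel_size=[2, 2], stride_size=[2, 2])
    print(linear_handler(name="fc", **common))
    print(functional_pool_handler(name="max_pool2d", **common))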
65 changes: 23 additions & 42 deletions rfa_toolbox/encodings/pytorch/layer_handlers.py
@@ -1,4 +1,3 @@
import warnings
from collections import Sequence

import numpy as np
@@ -49,7 +48,7 @@ def can_handle(self, name: str) -> bool:
return False

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
conv_layer = obtain_module_with_resolvable_string(resolvable_string, model)
kernel_size = (
@@ -98,7 +97,7 @@ def can_handle(self, name: str) -> bool:
return "Pool" in working_name and "Adaptive" not in working_name

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
conv_layer = obtain_module_with_resolvable_string(resolvable_string, model)
kernel_size = conv_layer.kernel_size
@@ -126,7 +125,7 @@ def can_handle(self, name: str) -> bool:
return "pool" in name.lower() and "adaptive" in name.lower()

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
kernel_size = None
stride_size = 1
@@ -143,7 +142,7 @@ def can_handle(self, name: str) -> bool:
return "Linear" in name

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
kernel_size = None
stride_size = 1
@@ -168,54 +167,37 @@ class FunctionalKernelHandler(LayerInfoHandler):
"""

coerce: bool = False
default_kernel_size: int = 1
default_stride_size: int = 1

def can_handle(self, name: str) -> bool:
return "pool" in name.split(".")[-1] or "conv" in name.split(".")[-1]

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
if not self.coerce:
raise RuntimeError(
"Using the functional API of PyTorch is not "
"directly supported by this library."
"Usage of torch.function may corrupt the "
"reconstruction of the network topology."
"If you want to continue anyway use the "
"following code snipped before calling RFA-toolbox:\n"
"from rfa_toolbox.encodings.pytorch"
" import toggle_coerce_torch_functional\n"
"toggle_coerce_torch_functional(True)"
"\n\n You can also modify the same handler to"
"adjust a correct kernel and stride sizes"
)

if "(" in resolvable_string and ")" in name:
# print(result)
result = name.split("(")[-1].replace(")", "")
else:
result = f"{name.split('.')[-1]}"

warnings.warn(
"Detected a call of a kernel based layer from "
f"the functional library of PyTorch: {name}!"
" The kernel and stride size of this layer "
"cannot be correctly extracted, "
f"defaulting to kernel_size: {self.default_kernel_size} and "
f"stride_size: {self.default_stride_size}."
" Please avoid functional calls from kernel-based "
"operations, they may also corrupt the"
"compute graph use the corresponding modules from torch.nn instead!"
kernel_size = kwargs["kernel_size"]
stride_size = kwargs["stride_size"]

if not isinstance(kernel_size, Sequence) and not isinstance(
kernel_size, np.ndarray
):
kernel_size_name = f"{kernel_size}x{kernel_size}"
else:
kernel_size_name = "x".join([str(k) for k in kernel_size])
final_name = (
f"{result} {kernel_size_name} / "
f"{stride_size if isinstance(stride_size, int) else tuple(stride_size)}"
)

return LayerDefinition(
name=result
+ (
f" {self.default_kernel_size}x{self.default_kernel_size} / "
f"{self.default_stride_size} \n(functional, values assumed)"
),
kernel_size=self.default_kernel_size,
stride_size=self.default_stride_size,
name=final_name,
kernel_size=kernel_size,
stride_size=stride_size,
)


@@ -234,12 +216,11 @@ def can_handle(self, name: str) -> bool:
return True

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
kernel_size = 1
stride_size = 1
if "(" in resolvable_string and ")" in name:
# print(result)
result = name.split("(")[-1].replace(")", "")
else:
result = f"{name.split('.')[-1]}"
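
The reworked FunctionalKernelHandler no longer assumes default sizes; it formats the layer name from the kernel and stride values passed in. Its naming logic can be replayed in isolation, as in this minimal sketch (note that the diff imports Sequence from collections, an alias Python 3.10 removed, so collections.abc is used here):

    from collections.abc import Sequence

    import numpy as np

    # Replaying the name-formatting logic above for a functional pooling call.
    result, kernel_size, stride_size = "max_pool2d", [2, 2], [2, 2]
    if not isinstance(kernel_size, Sequence) and not isinstance(kernel_size, np.ndarray):
        kernel_size_name = f"{kernel_size}x{kernel_size}"  # scalar kernel, e.g. 2 -> "2x2"
    else:
        kernel_size_name = "x".join(str(k) for k in kernel_size)
    final_name = (f"{result} {kernel_size_name} / "
                  f"{stride_size if isinstance(stride_size, int) else tuple(stride_size)}")
    print(final_name)  # -> max_pool2d 2x2 / (2, 2)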
10 changes: 6 additions & 4 deletions rfa_toolbox/encodings/pytorch/utils.py
@@ -1,4 +1,4 @@
import rfa_toolbox.encodings.pytorch.intermediate_graph as ig
import warnings


def toggle_coerce_torch_functional(
@@ -24,6 +24,8 @@ def toggle_coerce_torch_functional(
the receptive field expansion.
handler_idx: index of the handler in the list of handlers
"""
ig.RESOLVING_STRATEGY[handler_idx].coerce = coerce
ig.RESOLVING_STRATEGY[handler_idx].kernel_size = kernel_size
ig.RESOLVING_STRATEGY[handler_idx].stride_size = stride_size
warnings.warn(
"This function is deprecated and is no longer needed, "
"you may remove it from your code",
DeprecationWarning,
)
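
With real kernel and stride sizes now flowing in from the traced graph, toggle_coerce_torch_functional has nothing left to configure and is kept only for backward compatibility. A small sketch of the deprecated behavior (assuming the first positional argument is still the coerce flag and the function is otherwise a no-op, as the diff suggests):

    import warnings

    from rfa_toolbox.encodings.pytorch import toggle_coerce_torch_functional

    # DeprecationWarning is filtered out by default, so surface it explicitly.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        toggle_coerce_torch_functional(True)  # no-op apart from the warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)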
16 changes: 4 additions & 12 deletions tests/test_encodings/test_pytorch.py
@@ -6,7 +6,6 @@
from torchvision.models.resnet import resnet18, resnet152
from torchvision.models.vgg import vgg19

from rfa_toolbox.encodings.pytorch import toggle_coerce_torch_functional
from rfa_toolbox.encodings.pytorch.ingest_architecture import make_graph
from rfa_toolbox.encodings.pytorch.intermediate_graph import Digraph
from rfa_toolbox.graphs import EnrichedNetworkNode, LayerDefinition
@@ -62,21 +61,14 @@ def test_make_graph_resnet152(self):
assert len(output_node.all_layers) == 515
assert isinstance(output_node, EnrichedNetworkNode)

def test_make_graph_inception_v3(self):
def test_inceptionv3(self):
model = inception_v3
m = model()
tm = torch.jit.trace(m, [torch.randn(1, 3, 399, 399)])
with pytest.raises(RuntimeError):
d = make_graph(tm, ref_mod=m)
d.to_graph()

def test_inceptionv3_no_raise(self):
model = inception_v3
m = model()
toggle_coerce_torch_functional(True)
tm = torch.jit.trace(m, [torch.randn(1, 3, 399, 399)])
d = make_graph(tm, ref_mod=m)
d.to_graph()
output_node = d.to_graph()
assert len(output_node.all_layers) == 319
assert isinstance(output_node, EnrichedNetworkNode)

def test_make_graph_vgg19(self):
model = vgg19
