Skip to content

Commit

Permalink
fix: added the ability to customize filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
MLRichter committed Jan 31, 2022
1 parent 8f2cefd commit ce3aeb4
Show file tree
Hide file tree
Showing 7 changed files with 502 additions and 22 deletions.
30 changes: 26 additions & 4 deletions rfa_toolbox/encodings/pytorch/ingest_architecture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from rfa_toolbox.graphs import EnrichedNetworkNode
from rfa_toolbox.graphs import (
KNOWN_FILTER_MAPPING,
EnrichedNetworkNode,
ReceptiveFieldInfo,
)

try:
import torch
Expand Down Expand Up @@ -101,6 +105,7 @@ def make_graph(
parent_dot=None,
ref_mod=None,
classes_to_not_visit=None,
filter_rf=KNOWN_FILTER_MAPPING[None],
):
"""
This code was adapted from this blog article:
Expand Down Expand Up @@ -131,6 +136,7 @@ def find_name(i, self_input, suffix=None):
ref_mod=ref_mod,
format="svg",
graph_attr={"label": self_type, "labelloc": "t"},
filter_rf=filter_rf,
)
# dot.attr('node', shape='box')

Expand Down Expand Up @@ -254,6 +260,7 @@ def is_relevant_type(t):
classes_to_visit=classes_to_visit,
classes_found=classes_found,
classes_to_not_visit=classes_to_not_visit,
filter_rf=filter_rf,
)
# creating a mapping from the c-values
# to the output of the respective subgraph
Expand Down Expand Up @@ -368,20 +375,28 @@ def is_relevant_type(t):

def create_graph_from_model(
model: torch.nn.Module,
filter_rf: Optional[
Union[
Callable[[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]],
str,
]
] = None,
input_res: Tuple[int, int, int, int] = (1, 3, 399, 399),
custom_layers: Optional[List[str]] = None,
) -> EnrichedNetworkNode:
"""Create a graph of enriched network nodes from a PyTorch-Model.
Args:
model: a PyTorch-Model.
filter_rf: a function that filters receptive field sizes.
Disabled by default.
input_res: input-tuple shape that can be processed by the model.
Needs to be a 4-Tuple of shape (batch_size,
color_channels, height, width) for CNNs.
Needs to be a 2-Tuple of shape (batch_size,
num_features) for fully connected networks.
custom_layers: Class-names of custom layers, like DropPath
or Involutions, which are not part of
or Involutions, which are not part of
torch.nn. Keep in mind that unknown layers
will defaulted to have no effect on the
receptive field size. You may need to
Expand All @@ -390,5 +405,12 @@ def create_graph_from_model(
Returns:
The EnrichedNetworkNodeGraph
"""
filter_func = (
filter_rf
if (not isinstance(filter_rf, str) and filter_rf is not None)
else KNOWN_FILTER_MAPPING[filter_rf]
)
tm = torch.jit.trace(model, (torch.randn(*input_res),))
return make_graph(tm, ref_mod=model, classes_to_not_visit=custom_layers).to_graph()
return make_graph(
tm, filter_rf=filter_func, ref_mod=model, classes_to_not_visit=custom_layers
).to_graph()
13 changes: 11 additions & 2 deletions rfa_toolbox/encodings/pytorch/intermediate_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from attr import attrib, attrs
Expand All @@ -19,7 +19,12 @@
numeric_substitutor,
output_substitutor,
)
from rfa_toolbox.graphs import EnrichedNetworkNode, LayerDefinition
from rfa_toolbox.graphs import (
KNOWN_FILTER_MAPPING,
EnrichedNetworkNode,
LayerDefinition,
ReceptiveFieldInfo,
)

RESOLVING_STRATEGY = [
AnyConv(),
Expand Down Expand Up @@ -59,6 +64,9 @@ class Digraph:
layer_substitutors: List[NodeSubstitutor] = attrib(
factory=lambda: SUBSTITUTION_STRATEGY
)
filter_rf: Callable[
[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
] = KNOWN_FILTER_MAPPING[None]

def _find_predecessors(self, name: str) -> List[str]:
return [e[0] for e in self.edge_collection if e[1] == name]
Expand Down Expand Up @@ -250,5 +258,6 @@ def create_enriched_node(
name=name,
layer_info=layer_def,
predecessors=pred_nodes,
receptive_field_info_filter=self.filter_rf,
)
return node
65 changes: 54 additions & 11 deletions rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from json import loads
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tensorflow.keras.models import Model

Expand All @@ -10,7 +10,12 @@
KernelBasedHandler,
PoolingBasedHandler,
)
from rfa_toolbox.graphs import EnrichedNetworkNode, LayerDefinition
from rfa_toolbox.graphs import (
KNOWN_FILTER_MAPPING,
EnrichedNetworkNode,
LayerDefinition,
ReceptiveFieldInfo,
)

PARSERS = [
InputHandler(),
Expand Down Expand Up @@ -63,7 +68,11 @@ def obtain_layer_definition(node_dict: Dict[str, Any]) -> LayerDefinition:


def create_node_from_dict(
node_dict: Dict[str, Any], processed_nodes: Dict[str, EnrichedNetworkNode]
node_dict: Dict[str, Any],
processed_nodes: Dict[str, EnrichedNetworkNode],
filter_rf: Callable[
[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
],
) -> EnrichedNetworkNode:
"""Create the node-representation of a layer.
Args:
Expand All @@ -72,6 +81,7 @@ def create_node_from_dict(
processed_nodes: a dictionary, which maps already processed nodes
to their EnrichedNetworkNode-instances, used for
obtaining predecessors
filter_rf: a function, which filters the receptive fields
"""
predecessors = (
[]
Expand All @@ -83,11 +93,19 @@ def create_node_from_dict(
)
layer_info: LayerDefinition = obtain_layer_definition(node_dict)
return EnrichedNetworkNode(
name=node_dict["name"], layer_info=layer_info, predecessors=predecessors
name=node_dict["name"],
layer_info=layer_info,
predecessors=predecessors,
receptive_field_info_filter=filter_rf,
)


def create_graph(layers: List[Dict[str, Any]]) -> EnrichedNetworkNode:
def create_graph(
layers: List[Dict[str, Any]],
filter_rf: Callable[
[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
],
) -> EnrichedNetworkNode:
"""Create a graph of the model from a list of layers"""
processed_nodes: Dict[str, EnrichedNetworkNode] = {}
working_layers: List[Dict[str, Any]] = layers[:]
Expand All @@ -96,7 +114,9 @@ def create_graph(layers: List[Dict[str, Any]]) -> EnrichedNetworkNode:
working_layers, processed_nodes
)
node = create_node_from_dict(
node_dict=processable_node_dict, processed_nodes=processed_nodes
node_dict=processable_node_dict,
processed_nodes=processed_nodes,
filter_rf=filter_rf,
)
processed_nodes[node.name] = node
working_layers.remove(processable_node_dict)
Expand All @@ -105,12 +125,19 @@ def create_graph(layers: List[Dict[str, Any]]) -> EnrichedNetworkNode:
raise ValueError(f"Some nodes were left unprocessed: {working_layers}")


def model_dict_to_enriched_graph(model_dict: Dict[str, Any]) -> EnrichedNetworkNode:
def model_dict_to_enriched_graph(
model_dict: Dict[str, Any],
filter_rf: Callable[
[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
],
) -> EnrichedNetworkNode:
"""Turn a dictionary extracted from the json-representation of a Keras
model into the rfa-toolbox specific graph representation.
Args:
model_dict: the json-representation of the model
model_dict: the json-representation of the model
filter_rf: a function, which filters the receptive fields in the input of a
layer.
Returns:
a node of the graph
Expand All @@ -121,7 +148,7 @@ def model_dict_to_enriched_graph(model_dict: Dict[str, Any]) -> EnrichedNetworkN
if "layers" not in layer_config:
raise AttributeError("Model-json export has no layers")
layers = layer_config["layers"]
graph: EnrichedNetworkNode = create_graph(layers)
graph: EnrichedNetworkNode = create_graph(layers, filter_rf)
return graph


Expand All @@ -130,10 +157,26 @@ def keras_model_to_dict(model: Model) -> Dict[str, Any]:
return loads(model.to_json())


def create_graph_from_model(model: Model) -> EnrichedNetworkNode:
def create_graph_from_model(
model: Model,
filter_rf: Optional[
Union[
Callable[[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]],
str,
]
] = None,
) -> EnrichedNetworkNode:
"""Create a graph model from tensorflow
Args:
model: the model, thus must be a Keras-model.
filter_rf: a function, which filters the receptive fields, which should be
considered for computing minimum and maximum receptive field sizes.
By default, not filtering is done.
"""
model_dict = keras_model_to_dict(model)
return model_dict_to_enriched_graph(model_dict)
callable_filter = (
filter_rf
if (not isinstance(filter_rf, str) and filter_rf is not None)
else KNOWN_FILTER_MAPPING[filter_rf]
)
return model_dict_to_enriched_graph(model_dict, filter_rf=callable_filter)
29 changes: 26 additions & 3 deletions rfa_toolbox/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,34 @@ def naive_minmax_filter(


def noop_filter(
info: Tuple["ReceptiveFieldInfo"],
) -> Tuple["ReceptiveFieldInfo"]:
info: Tuple["ReceptiveFieldInfo", ...],
) -> Tuple["ReceptiveFieldInfo", ...]:
return info


def filter_all_rf_info_with_infinite_receptive_field(
info: Tuple["ReceptiveFieldInfo", ...],
) -> Tuple["ReceptiveFieldInfo", ...]:
result = list()
for rf_info in info:
if isinstance(rf_info.receptive_field, Sequence):
if not np.isinf(rf_info.receptive_field).any():
result.append(rf_info)
else:
if not np.isinf(rf_info.receptive_field):
result.append(rf_info)
if result:
return tuple(result)
else:
return info


KNOWN_FILTER_MAPPING = {
"inf": filter_all_rf_info_with_infinite_receptive_field,
None: noop_filter,
}


@attrs(auto_attribs=True, frozen=True, slots=True)
class ReceptiveFieldInfo:
"""The container holding information for the successive receptive
Expand Down Expand Up @@ -267,7 +290,7 @@ class EnrichedNetworkNode(Node):
receptive_field_max: int = attrib(init=False)

receptive_field_info_filter: Callable[
[Tuple[ReceptiveFieldInfo]], Tuple[ReceptiveFieldInfo]
[Tuple[ReceptiveFieldInfo, ...]], Tuple[ReceptiveFieldInfo, ...]
] = noop_filter
all_layers: List["EnrichedNetworkNode"] = attrib(init=False)

Expand Down
36 changes: 36 additions & 0 deletions tests/test_encodings/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Sequence

import numpy as np
import pytest
import torch
from torchvision.models import efficientnet_b0
from torchvision.models.alexnet import alexnet
from torchvision.models.inception import inception_v3
from torchvision.models.mnasnet import mnasnet1_3
from torchvision.models.resnet import resnet18, resnet152
from torchvision.models.vgg import vgg19

from rfa_toolbox import create_graph_from_pytorch_model
from rfa_toolbox.encodings.pytorch.ingest_architecture import make_graph
from rfa_toolbox.encodings.pytorch.intermediate_graph import Digraph
from rfa_toolbox.graphs import EnrichedNetworkNode, LayerDefinition
Expand Down Expand Up @@ -79,6 +84,37 @@ def test_make_graph_vgg19(self):
assert len(output_node.all_layers) == 46
assert isinstance(output_node, EnrichedNetworkNode)

def test_make_graph_efficientnetb0_with_inf_filter(self):
model = efficientnet_b0
m = model()
graph = create_graph_from_pytorch_model(m, filter_rf="inf")
infs = 0
for node in graph.all_layers:
for rf in node.receptive_field_sizes:
if isinstance(rf, Sequence):
if rf == np.inf or (isinstance(rf, Sequence) and np.inf in rf):
infs += 1
graph = create_graph_from_pytorch_model(m, filter_rf=None)
infs_nf = 0
for node in graph.all_layers:
for rf in node.receptive_field_sizes:
if isinstance(rf, Sequence):
if rf == np.inf or (isinstance(rf, Sequence) and np.inf in rf):
infs_nf += 1
assert infs_nf > infs
assert infs_nf == 661
assert infs == 64

def test_make_graph_efficientnetb0_custom_func(self):
def func(x):
return x

model = efficientnet_b0
m = model()
graph = create_graph_from_pytorch_model(m, filter_rf=func)
for node in graph.all_layers:
assert node.receptive_field_info_filter == func


class SomeModule(torch.nn.Module):
def __init__(self):
Expand Down
21 changes: 20 additions & 1 deletion tests/test_encodings/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rfa_toolbox.encodings.tensorflow_keras.ingest_architecture import (
create_graph_from_model,
)
from rfa_toolbox.graphs import EnrichedNetworkNode
from rfa_toolbox.graphs import KNOWN_FILTER_MAPPING, EnrichedNetworkNode


class TestKerasEncoding:
Expand All @@ -28,3 +28,22 @@ def test_inceptionv3(self):
graph = create_graph_from_model(model)
assert isinstance(graph, EnrichedNetworkNode)
assert len(graph.all_layers) == len(loads(model.to_json())["config"]["layers"])

def test_inceptionv3_with_infinite_rf_filter(self):
model = InceptionV3(weights=None)
graph = create_graph_from_model(model, filter_rf="inf")
for node in graph.all_layers:
assert node.receptive_field_info_filter == KNOWN_FILTER_MAPPING["inf"]
graph = create_graph_from_model(model, filter_rf=None)
for node in graph.all_layers:
assert node.receptive_field_info_filter == KNOWN_FILTER_MAPPING[None]

def test_inceptionv3_with_custom_rf_filter(self):
model = InceptionV3(weights=None)

def func(x):
return x

graph = create_graph_from_model(model, filter_rf=func)
for node in graph.all_layers:
assert node.receptive_field_info_filter == func
Loading

0 comments on commit ce3aeb4

Please sign in to comment.