fix: added more elaborate filtering of modules for PyTorch
MLRichter committed Mar 19, 2022
1 parent ab230f1 commit 1516bb5
Showing 5 changed files with 117 additions and 3 deletions.
8 changes: 8 additions & 0 deletions rfa_toolbox/encodings/pytorch/ingest_architecture.py
@@ -383,6 +383,7 @@ def create_graph_from_model(
] = None,
input_res: Tuple[int, int, int, int] = (1, 3, 399, 399),
custom_layers: Optional[List[str]] = None,
display_se_modules: bool = False,
) -> EnrichedNetworkNode:
"""Create a graph of enriched network nodes from a PyTorch-Model.
@@ -405,6 +406,13 @@
Returns:
The EnrichedNetworkNodeGraph
"""
custom_layers = (
["ConvNormActivation"]
if custom_layers is None
else custom_layers + ["ConvNormActivation"]
)
if not display_se_modules:
custom_layers.append("SqueezeExcitation")
filter_func = (
filter_rf
if (not isinstance(filter_rf, str) and filter_rf is not None)
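A minimal usage sketch of the new flag, assuming rfa_toolbox's public create_graph_from_pytorch_model wrapper (the entry point used in the tests below) forwards these keyword arguments to create_graph_from_model:

from torchvision.models import efficientnet_b0

from rfa_toolbox import create_graph_from_pytorch_model

model = efficientnet_b0()

# Default after this commit: ConvNormActivation blocks are treated as fused
# custom layers and SqueezeExcitation modules are hidden from the graph.
graph = create_graph_from_pytorch_model(model)

# Opt back into the raw view, keeping SE modules visible.
raw_graph = create_graph_from_pytorch_model(model, display_se_modules=True)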
4 changes: 4 additions & 0 deletions rfa_toolbox/encodings/pytorch/intermediate_graph.py
@@ -11,9 +11,11 @@
AnyConv,
AnyHandler,
AnyPool,
ConvNormActivationHandler,
FlattenHandler,
FunctionalKernelHandler,
LinearHandler,
SqueezeExcitationHandler,
)
from rfa_toolbox.encodings.pytorch.substitutors import (
input_substitutor,
@@ -28,6 +30,8 @@
)

RESOLVING_STRATEGY = [
ConvNormActivationHandler(),
SqueezeExcitationHandler(),
AnyConv(),
AnyPool(),
AnyAdaptivePool(),
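Placement is significant here: assuming the tracer walks RESOLVING_STRATEGY in order and dispatches to the first handler whose can_handle returns True, the two new handlers must precede AnyConv, otherwise a fused ConvNormActivation block would be picked up by the generic convolution handler. A sketch of that first-match dispatch (resolve_layer is a hypothetical helper, not the module's actual API):

def resolve_layer(model, resolvable_string, name, handlers=RESOLVING_STRATEGY):
    # Handlers are ordered from most to least specific; first match wins.
    for handler in handlers:
        if handler.can_handle(name):
            return handler(model, resolvable_string, name)
    raise ValueError(f"no handler can process layer: {name}")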
60 changes: 60 additions & 0 deletions rfa_toolbox/encodings/pytorch/layer_handlers.py
@@ -78,6 +78,49 @@ def __call__(
)


@attrs(auto_attribs=True, frozen=True, slots=True)
class ConvNormActivationHandler(LayerInfoHandler):
"""This handler explicitly operated on the torch.nn.Conv2d-Layer."""

    def can_handle(self, name: str) -> bool:
        return "ConvNormActivation" in name.split(".")[-1]

def __call__(
self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
) -> LayerDefinition:
        module = obtain_module_with_resolvable_string(resolvable_string, model)
        # modules() yields the ConvNormActivation container itself first, so
        # the convolution sits at index 1 and the optional activation at 3.
        submodules = list(module.modules())
        conv_layer = submodules[1]
        activation = "" if len(submodules) < 4 else f"-{type(submodules[3]).__name__}"
        kernel_size = conv_layer.kernel_size
        stride_size = conv_layer.stride
filters = conv_layer.out_channels
if not isinstance(kernel_size, Sequence) and not isinstance(
kernel_size, np.ndarray
):
kernel_size_name = f"{kernel_size}x{kernel_size}"
else:
kernel_size_name = "x".join([str(k) for k in kernel_size])
final_name = f"Conv-Norm{activation} {kernel_size_name} / {stride_size}"
return LayerDefinition(
            name=final_name,
kernel_size=kernel_size,
stride_size=stride_size,
filters=filters,
)
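The submodules[1] / submodules[3] indexing relies on two facts: torchvision's ConvNormActivation is a Sequential of a convolution, a norm layer, and an optional activation, and Module.modules() yields the container itself before its children. A quick check, assuming the torchvision of this era (where the block lives in torchvision.ops.misc):

from torchvision.ops.misc import ConvNormActivation

block = ConvNormActivation(in_channels=3, out_channels=16, kernel_size=3)
print([type(m).__name__ for m in block.modules()])
# ['ConvNormActivation', 'Conv2d', 'BatchNorm2d', 'ReLU']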


@attrs(auto_attribs=True, frozen=True, slots=True)
class AnyConv(Conv2d):
"""Extract layer information in convolutional layers."""
@@ -135,6 +178,23 @@ def __call__(
)


@attrs(auto_attribs=True, frozen=True, slots=True)
class SqueezeExcitationHandler(Conv2d):
"""Extract information from adaptive pooling layers."""

def can_handle(self, name: str) -> bool:
return "SqueezeExcitation" in name

    def __call__(
        self, model: torch.nn.Module, resolvable_string: str, name: str, **kwargs
    ) -> LayerDefinition:
        # Model the SE block as a pointwise (1x1, stride-1) operation so that
        # its global-pooling attention branch does not blow up the receptive
        # field of the main path.
        kernel_size = 1
        stride_size = 1
        return LayerDefinition(
            name=name, kernel_size=kernel_size, stride_size=stride_size
        )
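The rationale: an SE block's attention branch global-pools the whole feature map, which would otherwise register an effectively unbounded receptive field, so the module is collapsed into a single pointwise node. Matching is a plain substring test on the layer name; the dotted names below are illustrative, not taken from the codebase:

handler = SqueezeExcitationHandler()
assert handler.can_handle("EfficientNet.features.2.block.SqueezeExcitation")
assert not handler.can_handle("EfficientNet.features.0.Conv2d")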


@attrs(auto_attribs=True, frozen=True, slots=True)
class GenericLayerTypeHandler(LayerInfoHandler):
"""Extracts information from linear (fully connected) layers."""
35 changes: 32 additions & 3 deletions tests/test_encodings/test_pytorch.py
@@ -87,24 +87,53 @@ def test_make_graph_vgg19(self):
def test_make_graph_efficientnetb0_with_inf_filter(self):
model = efficientnet_b0
m = model()
graph = create_graph_from_pytorch_model(
m, filter_rf="inf", custom_layers=[], display_se_modules=True
)
infs = 0
        for node in graph.all_layers:
            for rf in node.receptive_field_sizes:
                if isinstance(rf, Sequence) and np.inf in rf:
                    infs += 1
graph = create_graph_from_pytorch_model(
m, filter_rf=None, custom_layers=[], display_se_modules=True
)
infs_nf = 0
        for node in graph.all_layers:
            for rf in node.receptive_field_sizes:
                if isinstance(rf, Sequence) and np.inf in rf:
                    infs_nf += 1
assert infs_nf > infs
        assert infs_nf == 371
assert infs == 64

def test_make_graph_efficientnetb0_with_inf_filter_with_ops_modules(self):
model = efficientnet_b0
m = model()
graph = create_graph_from_pytorch_model(
m, filter_rf="inf", custom_layers=None, display_se_modules=False
)
infs = 0
        for node in graph.all_layers:
            for rf in node.receptive_field_sizes:
                if isinstance(rf, Sequence) and np.inf in rf:
                    infs += 1
graph = create_graph_from_pytorch_model(
m, filter_rf=None, custom_layers=None, display_se_modules=False
)
infs_nf = 0
        for node in graph.all_layers:
            for rf in node.receptive_field_sizes:
                if isinstance(rf, Sequence) and np.inf in rf:
                    infs_nf += 1
        assert infs_nf == 0
assert infs == 0
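The same counting loop now appears four times across these two tests; a hypothetical helper (reusing the test module's Sequence and np imports) could factor it out:

def count_inf_receptive_fields(graph) -> int:
    """Count receptive-field entries containing an infinite extent."""
    count = 0
    for node in graph.all_layers:
        for rf in node.receptive_field_sizes:
            if isinstance(rf, Sequence) and np.inf in rf:
                count += 1
    return count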

def test_make_graph_efficientnetb0_custom_func(self):
def func(x):
return x
13 changes: 13 additions & 0 deletions tests/test_graph/test_utils.py
@@ -14,6 +14,19 @@
)


@pytest.fixture()
def example():
    import os

    import pandas as pd

    # Write a small semicolon-separated CSV for the test, then clean it up.
    pd.DataFrame.from_dict(
        {"columnInt": [1, 2, 3], "colStr": ["A", "1", "B"], "colFloat": [1.0, 2.0, 3.0]}
    ).to_csv("float.csv", sep=";")
    yield "float.csv"
    os.remove("float.csv")
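For context, a hypothetical test consuming this fixture would read the file back with the same separator:

def test_reads_example_csv(example):
    import pandas as pd

    df = pd.read_csv(example, sep=";")
    assert list(df["colStr"]) == ["A", "1", "B"]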


@pytest.fixture()
def single_node():
node0 = EnrichedNetworkNode(
