Skip to content

Commit

Permalink
fix: added more elaborate filtering of modules for PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
MLRichter committed Mar 19, 2022
1 parent 2f87346 commit 52bf78d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions rfa_toolbox/encodings/pytorch/layer_handlers.py
Expand Up @@ -106,12 +106,12 @@ def __call__(
# else conv_layer.stride[0]
)
filters = conv_layer.out_channels
if not isinstance(kernel_size, Sequence) and not isinstance(
kernel_size, np.ndarray
):
kernel_size_name = f"{kernel_size}x{kernel_size}"
else:
kernel_size_name = "x".join([str(k) for k in kernel_size])
kernel_size_name = (
f"{kernel_size}x{kernel_size}"
if not isinstance(kernel_size, Sequence)
and not isinstance(kernel_size, np.ndarray)
else "x".join([str(k) for k in kernel_size])
)
final_name = f"Conv-Norm{activation} {kernel_size_name} / {stride_size}"
return LayerDefinition(
name=final_name, # f"{name} {kernel_size}x{kernel_size}",
Expand Down

0 comments on commit 52bf78d

Please sign in to comment.