diff --git a/rfa_toolbox/encodings/pytorch/layer_handlers.py b/rfa_toolbox/encodings/pytorch/layer_handlers.py index c04da6a..af90537 100644 --- a/rfa_toolbox/encodings/pytorch/layer_handlers.py +++ b/rfa_toolbox/encodings/pytorch/layer_handlers.py @@ -106,12 +106,12 @@ def __call__( # else conv_layer.stride[0] ) filters = conv_layer.out_channels - if not isinstance(kernel_size, Sequence) and not isinstance( - kernel_size, np.ndarray - ): - kernel_size_name = f"{kernel_size}x{kernel_size}" - else: - kernel_size_name = "x".join([str(k) for k in kernel_size]) + kernel_size_name = ( + f"{kernel_size}x{kernel_size}" + if not isinstance(kernel_size, Sequence) + and not isinstance(kernel_size, np.ndarray) + else "x".join([str(k) for k in kernel_size]) + ) final_name = f"Conv-Norm{activation} {kernel_size_name} / {stride_size}" return LayerDefinition( name=final_name, # f"{name} {kernel_size}x{kernel_size}",