Merge pull request #151 from ZettaAI/sergiy/fix_convblock
fix: make skips `Dict[str, int]` to allow deserialized dicts
dodamih committed Dec 7, 2022
2 parents ef77b6e + d5522c1 commit 2b75dce
Showing 4 changed files with 12 additions and 10 deletions.
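
Context for the change: the commit message points at deserialized dicts. One common way integer keys become strings is a round trip through JSON-style serialization (an assumption here about how such specs are transported), since JSON object keys are always strings. A minimal standalone sketch of that behavior, not part of the PR:

import json

# JSON object keys are always strings, so integer dict keys come back
# as strings after a serialize/deserialize round trip.
original = {0: 2, 1: 2, 2: 3}
restored = json.loads(json.dumps(original))

print(original)  # {0: 2, 1: 2, 2: 3}
print(restored)  # {'0': 2, '1': 2, '2': 3}
assert restored == {"0": 2, "1": 2, "2": 3}

Typing skips as Dict[str, int] lets such deserialized dicts pass through unchanged; the cost, visible in convblock.py below, is that integer positions are looked up via str(...).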
4 changes: 3 additions & 1 deletion tests/unit/convnet/test_convblock.py
@@ -100,7 +100,9 @@ def not_test_forward_naive(mocker):
 
 def test_forward_skips(mocker):
     mocker.patch("torch.nn.Conv2d.forward", lambda _, x: x)
-    block = convnet.architecture.ConvBlock(num_channels=[1, 2, 3, 4, 5], skips={0: 2, 1: 2, 2: 3})
+    block = convnet.architecture.ConvBlock(
+        num_channels=[1, 2, 3, 4, 5], skips={"0": 2, "1": 2, "2": 3}
+    )
     result = block(torch.ones([1, 1, 1, 1]))
     assert_array_equal(
         result.cpu().detach().numpy(), 6 * torch.ones([1, 1, 1, 1]).cpu().detach().numpy()
2 changes: 1 addition & 1 deletion tests/unit/convnet/test_unet.py
@@ -190,7 +190,7 @@ def test_forward_skips(mocker):
         list_num_channels=[[1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
         downsample=partial(torch.nn.AvgPool2d, kernel_size=2),
         upsample=partial(torch.nn.Upsample, scale_factor=2),
-        skips=[{0: 2}, {0: 2, 1: 3}, {0: 2}],
+        skips=[{"0": 2}, {"0": 2, "1": 3}, {"0": 2}],
     )
     result = unet.forward(torch.ones([1, 1, 2, 2]))
     assert_array_equal(
12 changes: 6 additions & 6 deletions zetta_utils/convnet/architecture/convblock.py
@@ -45,7 +45,7 @@ class ConvBlock(nn.Module):
         corresponding convolution in order. The list length must match the number of
         convolutions.
     :param skips: Specification for residual skip connection. For example,
-        ``skips={1: 3}`` specifies a single residual skip connection from the output of the
+        ``skips={"1": 3}`` specifies a single residual skip connection from the output of the
         first convolution (index 1) to the input of third convolution (index 3).
         0 specifies the input to the first layer.
     :param normalize_last: Whether to apply normalization after the last layer.
@@ -61,7 +61,7 @@ def __init__(
         kernel_sizes: Union[int, Tuple[int, ...], List[Union[int, Tuple[int, ...]]]] = 3,
         strides: Union[int, Tuple[int, ...], List[Union[int, Tuple[int, ...]]]] = 1,
         paddings: Union[Padding, List[Padding]] = "same",
-        skips: Optional[Dict[int, int]] = None,
+        skips: Optional[Dict[str, int]] = None,
         normalize_last: bool = False,
         activate_last: bool = False,
     ):  # pylint: disable=too-many-locals
@@ -111,8 +111,8 @@ def __init__(
     def forward(self, data: torch.Tensor) -> torch.Tensor:
         skip_data_for = {}  # type: Dict[int, torch.Tensor]
         conv_count = 1
-        if 0 in self.skips:
-            skip_dest = self.skips[0]
+        if "0" in self.skips:
+            skip_dest = self.skips["0"]
             skip_data_for[skip_dest] = data
         result = data
         for this_layer, next_layer in zip(self.layers, self.layers[1:] + [None]):
@@ -123,8 +123,8 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
                 result = this_layer(result)
 
             if isinstance(next_layer, torch.nn.modules.conv._ConvNd):
-                if conv_count in self.skips:
-                    skip_dest = self.skips[conv_count]
+                if str(conv_count) in self.skips:
+                    skip_dest = self.skips[str(conv_count)]
                     if skip_dest in skip_data_for:
                         skip_data_for[skip_dest] += result
                     else:
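
To make the new lookup pattern concrete, here is a self-contained sketch of residual skips driven by a string-keyed spec. This is a simplified illustration, not the library code: it assumes a plain list of convolutions, ignores the activation/normalization layers that ConvBlock interleaves, and the helper name is made up.

from typing import Dict, List

import torch


def apply_convs_with_skips(
    convs: List[torch.nn.Module], skips: Dict[str, int], data: torch.Tensor
) -> torch.Tensor:
    # Convolutions are counted from 1, and "0" denotes the block input; each value
    # names the convolution whose input receives the saved tensor. The spec keys
    # are strings (as a deserialized dict would provide), so lookups go through str(...).
    skip_data_for = {}  # type: Dict[int, torch.Tensor]
    if "0" in skips:
        skip_data_for[skips["0"]] = data

    result = data
    for conv_count, conv in enumerate(convs, start=1):
        if conv_count in skip_data_for:  # a recorded skip targets this conv's input
            result = result + skip_data_for[conv_count]
        result = conv(result)
        if str(conv_count) in skips:  # this conv's output feeds a later conv's input
            dest = skips[str(conv_count)]
            skip_data_for[dest] = skip_data_for.get(dest, 0) + result
    return result

With four identity convolutions and skips={"0": 2, "1": 2, "2": 3}, this sketch yields 6 times the input, which lines up with the 6 * torch.ones expectation in the updated test_forward_skips above.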
4 changes: 2 additions & 2 deletions zetta_utils/convnet/architecture/unet.py
@@ -80,7 +80,7 @@ def __init__(
             List[Union[int, Tuple[int, ...], List[Union[int, Tuple[int, ...]]]]],
         ] = 1,
         paddings: Union[Padding, List[Union[Padding, List[Padding]]]] = "same",
-        skips: Optional[List[Dict[int, int]]] = None,
+        skips: Optional[List[Dict[str, int]]] = None,
         normalize_last: bool = False,
         activate_last: bool = False,
     ):  # pylint: disable=too-many-locals, too-many-statements, too-many-branches
@@ -114,7 +114,7 @@ def __init__(
             paddings_ = [paddings for _ in range(len(list_num_channels))]
 
         if isinstance(skips, list):
-            skips_ = skips  # type: List[Dict[int, int]]
+            skips_ = skips  # type: List[Dict[str, int]]
             assert len(skips_) == len(list_num_channels)
         else:
             skips_ = [{} for _ in range(len(list_num_channels))]
