Skip to content

Commit

Permalink
Torchscript fixes (#1259)
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
  • Loading branch information
ericspod committed Nov 18, 2020
1 parent 09f39dc commit 71b6794
Show file tree
Hide file tree
Showing 19 changed files with 332 additions and 173 deletions.
48 changes: 28 additions & 20 deletions monai/networks/blocks/squeeze_and_excitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
r: int = 2,
acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}),
acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid",
add_residual: bool = False,
) -> None:
"""
Args:
Expand All @@ -51,6 +52,8 @@ def __init__(
"""
super(ChannelSELayer, self).__init__()

self.add_residual = add_residual

pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
self.avg_pool = pool_type(1) # spatial size (1, 1, ...)

Expand All @@ -74,8 +77,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
b, c = x.shape[:2]
y: torch.Tensor = self.avg_pool(x).view(b, c)
y = self.fc(y).view([b, c] + [1] * (x.ndimension() - 2))
return x * y
y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
result = x * y

# Residual connection is moved here instead of providing an override of forward in ResidualSELayer since
# Torchscript has an issue with using super().
if self.add_residual:
result += x

return result


class ResidualSELayer(ChannelSELayer):
Expand All @@ -85,7 +95,6 @@ class ResidualSELayer(ChannelSELayer):
--+-- SE --o--
| |
+--------+
"""

def __init__(
Expand All @@ -105,21 +114,17 @@ def __init__(
acti_type_2: defaults to "relu".
See also:
:py:class:`monai.networks.blocks.ChannelSELayer`
"""
super().__init__(
spatial_dims=spatial_dims, in_channels=in_channels, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
spatial_dims=spatial_dims,
in_channels=in_channels,
r=r,
acti_type_1=acti_type_1,
acti_type_2=acti_type_2,
add_residual=True,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
"""
return x + super().forward(x)


class SEBlock(nn.Module):
"""
Expand Down Expand Up @@ -196,28 +201,31 @@ def __init__(
spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
)

self.project = project
if self.project is None and in_channels != n_chns_3:
if project is None and in_channels != n_chns_3:
self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)
elif project is None:
self.project = nn.Identity()
else:
self.project = project

self.act = None
if acti_type_final is not None:
act_final, act_final_args = split_args(acti_type_final)
self.act = Act[act_final](**act_final_args)
else:
self.act = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
"""
residual = x if self.project is None else self.project(x)
residual = self.project(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.se_layer(x)
x += residual
if self.act is not None:
x = self.act(x)
x = self.act(x)
return x


Expand Down Expand Up @@ -358,7 +366,7 @@ def __init__(
conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
width = math.floor(planes * (base_width / 64)) * groups

super(SEResNeXtBottleneck, self).__init__(
super().__init__(
spatial_dims=spatial_dims,
in_channels=inplanes,
n_chns_1=width,
Expand Down
8 changes: 2 additions & 6 deletions monai/networks/nets/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union
from typing import Any, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -194,11 +194,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i

return decode

def forward(
self, x: torch.Tensor
) -> Union[
torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
]: # big tuple return necessary for VAE, which inherits
def forward(self, x: torch.Tensor) -> Any:
x = self.encode(x)
x = self.intermediate(x)
x = self.decode(x)
Expand Down
22 changes: 12 additions & 10 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.networks.layers.factories import Conv, Dropout, Norm, Pool


class _DenseLayer(nn.Sequential):
class _DenseLayer(nn.Module):
def __init__(
self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
) -> None:
Expand All @@ -38,21 +38,23 @@ def __init__(
out_channels = bn_size * growth_rate
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]
dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

self.add_module("norm1", norm_type(in_channels))
self.add_module("relu1", nn.ReLU(inplace=True))
self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
self.layers = nn.Sequential()

self.add_module("norm2", norm_type(out_channels))
self.add_module("relu2", nn.ReLU(inplace=True))
self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))
self.layers.add_module("norm1", norm_type(in_channels))
self.layers.add_module("relu1", nn.ReLU(inplace=True))
self.layers.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))

self.layers.add_module("norm2", norm_type(out_channels))
self.layers.add_module("relu2", nn.ReLU(inplace=True))
self.layers.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))

if dropout_prob > 0:
self.add_module("dropout", dropout_type(dropout_prob))
self.layers.add_module("dropout", dropout_type(dropout_prob))

def forward(self, x: torch.Tensor) -> torch.Tensor:
new_features = super(_DenseLayer, self).forward(x)
new_features = self.layers(x)
return torch.cat([x, new_features], 1)


Expand Down
11 changes: 9 additions & 2 deletions monai/networks/nets/dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
error_msg = "length of kernel_size and strides should be the same, and no less than 3."
assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg

for idx in range(len(kernels)):
kernel, stride = kernels[idx], strides[idx]
if not isinstance(kernel, int):
Expand All @@ -115,20 +116,26 @@ def check_deep_supr_num(self):
def forward(self, x):
out = self.input_block(x)
outputs = [out]

for downsample in self.downsamples:
out = downsample(out)
outputs.append(out)
outputs.insert(0, out)

out = self.bottleneck(out)
upsample_outs = []
for upsample, skip in zip(self.upsamples, reversed(outputs)):

for upsample, skip in zip(self.upsamples, outputs):
out = upsample(out, skip)
upsample_outs.append(out)

out = self.output_block(out)

if self.training and self.deep_supervision:
start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num
upsample_outs = upsample_outs[start_output_idx:-1][::-1]
preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)]
return [out] + preds

return out

def get_input_block(self):
Expand Down
24 changes: 21 additions & 3 deletions monai/networks/nets/highresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(self.layers(x))


class ChannelPad(nn.Module):
def __init__(self, pad):
super().__init__()
self.pad = tuple(pad)

def forward(self, x):
return F.pad(x, self.pad)


class HighResBlock(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -124,21 +133,26 @@ def __init__(
norm_type = Normalisation(norm_type)
acti_type = Activation(acti_type)

self.project, self.pad = None, None
self.project = None
self.pad = None

if in_channels != out_channels:
channel_matching = ChannelMatching(channel_matching)

if channel_matching == ChannelMatching.PROJECT:
self.project = conv_type(in_channels, out_channels, kernel_size=1)

if channel_matching == ChannelMatching.PAD:
if in_channels > out_channels:
raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.')
pad_1 = (out_channels - in_channels) // 2
pad_2 = out_channels - in_channels - pad_1
pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]
self.pad = lambda input: F.pad(input, pad)
self.pad = ChannelPad(pad)

layers = nn.ModuleList()
_in_chns, _out_chns = in_channels, out_channels

for kernel_size in kernels:
layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns))
layers.append(SUPPORTED_ACTI[acti_type](inplace=True))
Expand All @@ -148,14 +162,18 @@ def __init__(
)
)
_in_chns = _out_chns

self.layers = nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_conv: torch.Tensor = self.layers(x)

if self.project is not None:
return x_conv + torch.as_tensor(self.project(x))
return x_conv + torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug

if self.pad is not None:
return x_conv + torch.as_tensor(self.pad(x))

return x_conv + x


Expand Down
Loading

0 comments on commit 71b6794

Please sign in to comment.