Skip to content

Commit 42ec757

Browse files
author
Fabio Ferreira
committed
refactor: creates a subclass of UNet and overrides the get connection block method
1 parent 69540ff commit 42ec757

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

monai/networks/nets/unet.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,7 @@ def __init__(self, module: nn.Module) -> None:
3333
self.module = module
3434

3535
def forward(self, x: torch.Tensor) -> torch.Tensor:
36-
if self.training and torch.is_grad_enabled() and x.requires_grad:
37-
try:
38-
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
39-
except TypeError:
40-
# Fallback for older PyTorch without `use_reentrant`
41-
return cast(torch.Tensor, checkpoint(self.module, x))
42-
return cast(torch.Tensor, self.module(x))
36+
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
4337

4438

4539
class UNet(nn.Module):
@@ -138,7 +132,6 @@ def __init__(
138132
dropout: float = 0.0,
139133
bias: bool = True,
140134
adn_ordering: str = "NDA",
141-
use_checkpointing: bool = False,
142135
) -> None:
143136
super().__init__()
144137

@@ -167,7 +160,6 @@ def __init__(
167160
self.dropout = dropout
168161
self.bias = bias
169162
self.adn_ordering = adn_ordering
170-
self.use_checkpointing = use_checkpointing
171163

172164
def _create_block(
173165
inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
@@ -214,8 +206,6 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
214206
subblock: block defining the next layer in the network.
215207
Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
216208
"""
217-
if self.use_checkpointing:
218-
subblock = _ActivationCheckpointWrapper(subblock)
219209
return nn.Sequential(down_path, SkipConnection(subblock), up_path)
220210

221211
def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
@@ -321,5 +311,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
321311
x = self.model(x)
322312
return x
323313

314+
class CheckpointUNet(UNet):
315+
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
316+
subblock = _ActivationCheckpointWrapper(subblock)
317+
return super()._get_connection_block(down_path, up_path, subblock)
324318

325319
Unet = UNet

0 commit comments

Comments
 (0)