@@ -33,13 +33,7 @@ def __init__(self, module: nn.Module) -> None:
33
33
self .module = module
34
34
35
35
def forward (self , x : torch .Tensor ) -> torch .Tensor :
36
- if self .training and torch .is_grad_enabled () and x .requires_grad :
37
- try :
38
- return cast (torch .Tensor , checkpoint (self .module , x , use_reentrant = False ))
39
- except TypeError :
40
- # Fallback for older PyTorch without `use_reentrant`
41
- return cast (torch .Tensor , checkpoint (self .module , x ))
42
- return cast (torch .Tensor , self .module (x ))
36
+ return cast (torch .Tensor , checkpoint (self .module , x , use_reentrant = False ))
43
37
44
38
45
39
class UNet (nn .Module ):
@@ -138,7 +132,6 @@ def __init__(
138
132
dropout : float = 0.0 ,
139
133
bias : bool = True ,
140
134
adn_ordering : str = "NDA" ,
141
- use_checkpointing : bool = False ,
142
135
) -> None :
143
136
super ().__init__ ()
144
137
@@ -167,7 +160,6 @@ def __init__(
167
160
self .dropout = dropout
168
161
self .bias = bias
169
162
self .adn_ordering = adn_ordering
170
- self .use_checkpointing = use_checkpointing
171
163
172
164
def _create_block (
173
165
inc : int , outc : int , channels : Sequence [int ], strides : Sequence [int ], is_top : bool
@@ -214,8 +206,6 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
214
206
subblock: block defining the next layer in the network.
215
207
Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
216
208
"""
217
- if self .use_checkpointing :
218
- subblock = _ActivationCheckpointWrapper (subblock )
219
209
return nn .Sequential (down_path , SkipConnection (subblock ), up_path )
220
210
221
211
def _get_down_layer (self , in_channels : int , out_channels : int , strides : int , is_top : bool ) -> nn .Module :
@@ -321,5 +311,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
321
311
x = self .model (x )
322
312
return x
323
313
314
+ class CheckpointUNet (UNet ):
315
+ def _get_connection_block (self , down_path : nn .Module , up_path : nn .Module , subblock : nn .Module ) -> nn .Module :
316
+ subblock = _ActivationCheckpointWrapper (subblock )
317
+ return super ()._get_connection_block (down_path , up_path , subblock )
324
318
325
319
Unet = UNet
0 commit comments