Skip to content

Commit

Permalink
[DLMED] update inverse and spatial_pad
Browse files Browse the repository at this point in the history
Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma committed Jun 22, 2022
1 parent 000f035 commit 63e36b6
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 98 deletions.
55 changes: 24 additions & 31 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.config import IndexSelection
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Randomizable, Transform
Expand Down Expand Up @@ -89,6 +90,8 @@ class Pad(InvertibleTransform):
"""

backend = [TransformBackends.TORCH]

def __init__(
self,
to_pad: Optional[List[Tuple[int, int]]] = None,
Expand Down Expand Up @@ -137,13 +140,16 @@ def __call__(
img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0)

if get_track_meta():
spatial_rank = max(len(img_t.affine) - 1, 1)
to_shift = [-s[0] for s in to_pad_[1:]] # skipping the channel pad
mat = create_translate(spatial_rank, to_shift)
img_t.meta["affine"] = img_t.affine @ convert_to_dst_type(mat, img_t.affine)[0]
self.push_transform(img_t, extra_info={"padded": to_pad})
self._update_meta(tensor=img_t, to_pad=to_pad_)
self.push_transform(img_t, extra_info={"padded": to_pad_})
return img_t

def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]):
    """
    Rewrite ``tensor.meta["affine"]`` so that it accounts for the applied padding.

    A translation matrix is built from the negated leading pad amount of every
    entry in `to_pad` except the first, and the current affine is multiplied by it.

    Args:
        tensor: the padded tensor; its ``meta["affine"]`` entry is replaced in place.
        to_pad: pad amounts per dimension, one pair per dimension. The first pair is
            skipped (it is the channel pad), and only element ``0`` of each remaining
            pair is used.
    """
    affine = tensor.affine
    # the translation helper needs a rank of at least 1
    rank = max(len(affine) - 1, 1)
    # skipping the channel pad; only the first value of each pair contributes
    offsets = [-pair[0] for pair in to_pad[1:]]
    translation = create_translate(rank, offsets)
    # convert the translation to the affine's type before composing the two
    translation = convert_to_dst_type(translation, affine)[0]
    tensor.meta["affine"] = affine @ translation

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
padded = transform[TraceKeys.EXTRA_INFO]["padded"]
Expand All @@ -158,16 +164,10 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
return cropper(data)


class SpatialPad(Transform):
class SpatialPad(Pad):
"""
Performs padding to the data, symmetric for all sides or all on one side for each dimension.
If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).
Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad
for additional details.
Args:
spatial_size: the spatial size of output data after padding, if a dimension of the input
data size is bigger than the pad size, will not pad that dimension.
Expand All @@ -176,30 +176,24 @@ class SpatialPad(Transform):
`spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30].
method: {``"symmetric"``, ``"end"``}
Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``.
mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.
default to `self.mode`.
kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`.
"""

backend = Pad.backend

def __init__(
self,
spatial_size: Union[Sequence[int], int],
method: Union[Method, str] = Method.SYMMETRIC,
mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT,
mode: Union[PytorchPadMode, str] = NumpyPadMode.CONSTANT,
**kwargs,
) -> None:
self.spatial_size = spatial_size
self.method: Method = look_up_option(method, Method)
self.mode = mode
self.kwargs = kwargs
super().__init__(mode=mode, **kwargs)

def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]:
spatial_size = fall_back_tuple(self.spatial_size, data_shape)
Expand All @@ -212,8 +206,11 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int
return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]

def __call__(
self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None
) -> NdarrayOrTensor:
self,
img: torch.Tensor,
mode: Optional[Union[PytorchPadMode, str]] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
img: data to be transformed, assuming `img` is channel-first and
Expand All @@ -228,12 +225,8 @@ def __call__(
"""
data_pad_width = self._determine_data_pad_width(img.shape[1:])
all_pad_width = [(0, 0)] + data_pad_width
if not np.asarray(all_pad_width).any():
# all zeros, skip padding
return img

padder = Pad(to_pad=all_pad_width, mode=mode or self.mode, **self.kwargs)
return padder(img)
return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs)


class BorderPad(Transform):
Expand Down
Loading

0 comments on commit 63e36b6

Please sign in to comment.