In [None]:
# | default_exp transforms/resize

# Imports

In [None]:
# | export


import torch
from monai.data import MetaTensor
from monai.transforms.spatial.array import Resize, TraceKeys

# Transforms

In [None]:
# | export


class ResizeWithSpacing(Resize):
    """MONAI `Resize` that also rescales a ``"spacing"`` entry in the image metadata.

    The physical extent ``spacing * size`` is preserved along every spatial axis,
    i.e. ``new_spacing = old_spacing * old_size / new_size``.

    Note: it turns out that ``MetaTensor.pixdim`` (derived from the affine, which
    `Resize` already updates) tracks the same information, so this class is largely
    redundant; it only matters for code that reads ``meta["spacing"]`` directly.

    Assumes channel-first input (dim 0 is the channel) and a ``meta["spacing"]``
    entry with one value per spatial axis — TODO confirm for all callers.
    """

    @staticmethod
    def _rescale_spacing(spacing, from_size, to_size) -> torch.Tensor:
        """Return ``spacing * from_size / to_size`` (true division, so a float tensor)."""
        # as_tensor does not copy-construct (and warn) when `spacing` is already a tensor.
        spacing_t = torch.as_tensor(spacing)
        from_t = torch.as_tensor(tuple(from_size))
        to_t = torch.as_tensor(tuple(to_size))
        return spacing_t * from_t / to_t

    def __call__(
        self,
        img: MetaTensor,
        *args,
        **kwargs,
    ) -> MetaTensor:
        """Resize ``img`` and set ``meta["spacing"]`` on the result to match its new size.

        Extra positional/keyword arguments are forwarded to `Resize.__call__`.
        """
        old_size = img.shape[1:]  # channel is the first dim
        old_spacing = img.meta["spacing"]

        # Only the output's metadata is written, so the input is not mutated before resizing.
        out = super().__call__(img, *args, **kwargs)
        if not isinstance(out, MetaTensor):  # meta tracking disabled: nothing to update
            return out

        # Use the size actually produced rather than self.spatial_size, which may contain
        # -1 ("keep this axis") or be a single int when size_mode="longest". In lazy mode
        # the data is not resampled yet, so prefer the pending shape when MONAI provides it.
        peek = getattr(out, "peek_pending_shape", None)
        new_size = peek() if callable(peek) else out.shape[1:]
        out.meta["spacing"] = self._rescale_spacing(old_spacing, old_size, new_size)
        return out

    def inverse_transform(self, data: MetaTensor, transform) -> MetaTensor:
        """Undo the resize and restore ``meta["spacing"]`` to its pre-resize value.

        ``transform`` is the traced record of the forward op; its ``TraceKeys.ORIG_SIZE``
        entry is the spatial size before resizing.
        """
        resized_size = data.shape[1:]
        resized_spacing = data.meta["spacing"]
        orig_size = transform[TraceKeys.ORIG_SIZE]

        out = super().inverse_transform(data, transform)
        if isinstance(out, MetaTensor):
            # Invert the forward relation: old_spacing = new_spacing * new_size / old_size.
            out.meta["spacing"] = self._rescale_spacing(resized_spacing, resized_size, orig_size)
        return out

In [None]:
img = MetaTensor(torch.randn(1, 5, 10, 15), meta={"spacing": torch.tensor([2, 2, 2])})
display(img.shape, img.meta)

resize = ResizeWithSpacing(spatial_size=(3, 10, 30))
resized = resize(img)
display(resized.shape, resized.meta)

# Physical extent (spacing * size) is preserved per axis: 2*5/3, 2*10/10, 2*15/30.
assert resized.shape == (1, 3, 10, 30)
assert torch.allclose(
    torch.as_tensor(resized.meta["spacing"], dtype=torch.float64),
    torch.tensor([10 / 3, 2.0, 1.0], dtype=torch.float64),
)

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m5[0m, [1;36m10[0m, [1;36m15[0m[1m][0m[1m)[0m


[1m{[0m
    [32m'spacing'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m][0m[1m)[0m,
    affine: [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m1[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m1[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m1[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m1[0m.[1m][0m[1m][0m, [33mdtype[0m=[35mtorch[0m.float64[1m)[0m,
    space: RAS
[1m}[0m

  old_spacing = torch.tensor(img.meta["spacing"])


[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m3[0m, [1;36m10[0m, [1;36m30[0m[1m][0m[1m)[0m


[1m{[0m
    [32m'spacing'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1;36m3.3333[0m, [1;36m2.0000[0m, [1;36m1.0000[0m[1m][0m[1m)[0m,
    affine: [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m1.6667[0m,  [1;36m0.0000[0m,  [1;36m0.0000[0m,  [1;36m0.3333[0m[1m][0m,
        [1m[[0m [1;36m0.0000[0m,  [1;36m1.0000[0m,  [1;36m0.0000[0m,  [1;36m0.0000[0m[1m][0m,
        [1m[[0m [1;36m0.0000[0m,  [1;36m0.0000[0m,  [1;36m0.5000[0m, [1;36m-0.2500[0m[1m][0m,
        [1m[[0m [1;36m0.0000[0m,  [1;36m0.0000[0m,  [1;36m0.0000[0m,  [1;36m1.0000[0m[1m][0m[1m][0m, [33mdtype[0m=[35mtorch[0m.float64[1m)[0m,
    space: RAS
[1m}[0m

# nbdev

In [None]:
# | hide
# Export via the Python API so it runs in the kernel's own environment (not whatever
# `nbdev_export` is first on the shell PATH); `hide` keeps this cell out of the docs.
import nbdev

nbdev.nbdev_export()