In [None]:
# | default_exp transforms/spatial

# Imports

In [None]:
# | export


import torch
from monai.data import MetaTensor
from monai.transforms.spatial.dictionary import RandFlipd, Resized

# Transforms

In [None]:
# | export


class RandFlipWithCropTrackingd(RandFlipd):
    """`RandFlipd` that also mirrors a tracked crop offset when a flip is applied.

    When the random flip fires and `crop_offset_key` is present in the input,
    the offset stored under that key is mirrored along every flipped spatial
    axis: ``new = original_shape - old_offset - crop_size``. Nothing else in
    the dictionary is modified beyond what `RandFlipd` itself does.

    Args:
        crop_offset_key: dictionary key holding the crop offset, one entry per
            spatial axis.
        original_shape_key: dictionary key holding the spatial shape the crop
            was taken from.
        *args, **kwargs: forwarded unchanged to `RandFlipd` (e.g. `keys`,
            `prob`, `spatial_axis`).
    """

    def __init__(
        self,
        crop_offset_key: str = "crop_offset",
        original_shape_key: str = "original_shape",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.crop_offset_key = crop_offset_key
        self.original_shape_key = original_shape_key

    def __call__(self, data, *args, **kwargs):
        """Apply the random flip and update the tracked crop offset.

        Returns:
            The dictionary returned by `RandFlipd.__call__`, with
            `crop_offset_key` replaced by a new float tensor when a flip
            happened and the key was present in `data`.

        Raises:
            ValueError: if the transformed entries do not all share one shape.
            KeyError: if a crop offset is tracked but `original_shape_key`
                is missing from `data`.
        """
        output = super().__call__(data, *args, **kwargs)

        # Nothing to track unless a flip happened on previously cropped data.
        if not (self._do_transform and self.crop_offset_key in data):
            return output

        # key_iterator honours allow_missing_keys, unlike indexing self.keys directly.
        shapes = [data[key].shape for key in self.key_iterator(data)]
        if not shapes:
            # No configured key was present, so nothing was flipped.
            return output
        if not all(shape == shapes[0] for shape in shapes):
            raise ValueError("All inputs must have the same shape for consistent crop value")
        if self.original_shape_key not in data:
            raise KeyError(
                f"'{self.original_shape_key}' is required to update '{self.crop_offset_key}' after a flip"
            )

        # The offset update below treats dim 0 as non-spatial and skips it.
        spatial_shape = tuple(shapes[0][1:])
        n_spatial = len(spatial_shape)

        # Mirror exactly the axes the underlying flipper used; `None` means all
        # spatial axes. Negative axes are normalised to their positive index.
        spatial_axis = self.flipper.spatial_axis
        if spatial_axis is None:
            flipped_axes = list(range(n_spatial))
        elif isinstance(spatial_axis, int):
            flipped_axes = [spatial_axis % n_spatial]
        else:
            flipped_axes = sorted({int(axis) % n_spatial for axis in spatial_axis})

        # as_tensor accepts tuples as well as tensors of any dtype (torch.Tensor
        # rejects integer tensors); clone so the caller's value is never mutated.
        orig_shape = torch.as_tensor(data[self.original_shape_key], dtype=torch.float32)
        crop_offset = torch.as_tensor(data[self.crop_offset_key], dtype=torch.float32).clone()

        for axis in flipped_axes:
            crop_offset[axis] = orig_shape[axis] - crop_offset[axis] - spatial_shape[axis]

        output[self.crop_offset_key] = crop_offset
        return output

In [None]:
# Smoke check: flip a random (4, 100, 100, 100) tensor with probability 1.0
# while carrying a crop offset and the original shape alongside it.
img = torch.rand(4, 100, 100, 100)
x = dict(image=img, crop_offset=(10, 20, 30), original_shape=(200, 200, 200))

flip_transform = RandFlipWithCropTrackingd(keys=["image"], prob=1.0)
flip_transform(x)


[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m6.8545e-01[0m, [1;36m6.0399e-01[0m, [1;36m3.4056e-01[0m,  [33m...[0m, [1;36m8.5693e-01[0m,
           [1;36m3.3220e-01[0m, [1;36m7.3319e-01[0m[1m][0m,
          [1m[[0m[1;36m9.4987e-01[0m, [1;36m5.7073e-01[0m, [1;36m4.8854e-01[0m,  [33m...[0m, [1;36m3.8559e-01[0m,
           [1;36m8.2196e-01[0m, [1;36m8.8976e-01[0m[1m][0m,
          [1m[[0m[1;36m1.6620e-01[0m, [1;36m2.8657e-01[0m, [1;36m1.0281e-01[0m,  [33m...[0m, [1;36m3.9847e-01[0m,
           [1;36m4.7287e-01[0m, [1;36m6.4543e-01[0m[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m2.5086e-01[0m, [1;36m6.9413e-01[0m, [1;36m3.3679e-02[0m,  [33m...[0m, [1;36m9.3587e-01[0m,
           [1;36m1.2995e-01[0m, [1;36m7.0720e-01[0m[1m][0m,
          [1m[[0m[1;36m5.8775e-01[0m, [1;36m3.6922e-01[0m, [1;36m9.1223e-01[0m,  [33m...[0m, [1;36m7.4780e-02[0m,
  

In [None]:
# | export


class ResizedWithCropTrackingd(Resized):
    """`Resized` that rescales a tracked crop offset (and original shape) to match.

    When `crop_offset_key` is present in the input, the per-axis resize factor
    is measured from the first transformed entry (shape after / shape before,
    excluding dim 0) and applied to the crop offset. If `original_shape_key`
    is also present it is rescaled by the same factor. Both are written back
    as integer tensors (fractional parts are truncated by `.int()`).

    Args:
        crop_offset_key: dictionary key holding the crop offset, one entry per
            spatial axis.
        original_shape_key: dictionary key holding the spatial shape the crop
            was taken from.
        *args, **kwargs: forwarded unchanged to `Resized` (e.g. `keys`,
            `spatial_size`).
    """

    def __init__(
        self,
        crop_offset_key: str = "crop_offset",
        original_shape_key: str = "original_shape",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.crop_offset_key = crop_offset_key
        self.original_shape_key = original_shape_key

    def __call__(self, data, *args, **kwargs):
        """Resize the configured entries and rescale the crop-tracking values.

        Returns:
            The dictionary returned by `Resized.__call__`, with
            `crop_offset_key` (and `original_shape_key`, when present)
            replaced by rescaled integer tensors if a crop offset was tracked.

        Raises:
            ValueError: if a crop offset is tracked and the entries to be
                resized do not all share one shape.
        """
        track_crop = self.crop_offset_key in data
        first_key = None
        old_shape = None

        if track_crop:
            # key_iterator honours allow_missing_keys, unlike indexing self.keys directly.
            present_keys = list(self.key_iterator(data))
            if present_keys:
                shapes = [data[key].shape for key in present_keys]
                if not all(shape == shapes[0] for shape in shapes):
                    raise ValueError("All inputs must have the same shape for consistent crop value")
                first_key = present_keys[0]
                # Shape before resizing, excluding dim 0.
                old_shape = tuple(shapes[0][1:])
            else:
                # No configured key is present, so nothing will be resized.
                track_crop = False

        output = super().__call__(data, *args, **kwargs)

        if track_crop:
            # Measure the factor from the actual output shape rather than from
            # the requested spatial_size.
            new_shape = tuple(output[first_key].shape[1:])
            scale = torch.tensor(new_shape) / torch.tensor(old_shape)

            # as_tensor accepts tuples and tensors of any dtype; torch.Tensor
            # rejects the integer tensors this transform itself writes back,
            # which made applying the transform twice fail.
            crop_offset = torch.as_tensor(data[self.crop_offset_key], dtype=torch.float32)
            output[self.crop_offset_key] = (crop_offset * scale).int()

            if self.original_shape_key in data:
                original_shape = torch.as_tensor(data[self.original_shape_key], dtype=torch.float32)
                output[self.original_shape_key] = (original_shape * scale).int()

        return output

In [None]:
# Smoke check: resize a random (4, 100, 100, 100) tensor to 60 voxels per
# spatial axis while carrying a crop offset and the original shape alongside it.
img = torch.rand(4, 100, 100, 100)
x = dict(image=img, crop_offset=(10, 20, 30), original_shape=(200, 200, 200))

resize_transform = ResizedWithCropTrackingd(keys=["image"], spatial_size=(60, 60, 60))
resize_transform(x)


[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m0.3897[0m, [1;36m0.4213[0m, [1;36m0.4188[0m,  [33m...[0m, [1;36m0.5719[0m, [1;36m0.3877[0m, [1;36m0.2938[0m[1m][0m,
          [1m[[0m[1;36m0.5080[0m, [1;36m0.3272[0m, [1;36m0.3928[0m,  [33m...[0m, [1;36m0.5251[0m, [1;36m0.4607[0m, [1;36m0.3944[0m[1m][0m,
          [1m[[0m[1;36m0.5273[0m, [1;36m0.4151[0m, [1;36m0.4860[0m,  [33m...[0m, [1;36m0.2980[0m, [1;36m0.4643[0m, [1;36m0.4867[0m[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m0.3354[0m, [1;36m0.4427[0m, [1;36m0.4709[0m,  [33m...[0m, [1;36m0.6162[0m, [1;36m0.5052[0m, [1;36m0.5085[0m[1m][0m,
          [1m[[0m[1;36m0.3427[0m, [1;36m0.4864[0m, [1;36m0.5543[0m,  [33m...[0m, [1;36m0.5740[0m, [1;36m0.5100[0m, [1;36m0.5218[0m[1m][0m,
          [1m[[0m[1;36m0.4705[0m, [1;36m0.5747[0m, [1;36m0.7474[0m,  [33m...[0m, [1;36m0.5422[0m, 

# nbdev

In [None]:
# Shell escape: runs the `nbdev_export` command-line tool from inside the notebook.
# NOTE(review): presumably this writes the cells marked `# | export` to the module
# named by the `default_exp` directive at the top — confirm against the nbdev docs.
!nbdev_export