In [None]:
# | default_exp transforms/spatial

# Imports

In [None]:
# | export


import torch
from monai.data import MetaTensor
from monai.transforms.spatial.dictionary import RandFlipd, Resized

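# Transforms

Dictionary transforms that wrap MONAI's `RandFlipd` and `Resized` and keep the crop-tracking metadata (`crop_offset` and `original_shape` by default) consistent with the transformed image.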

In [None]:
# | export


class RandFlipWithCropTrackingd(RandFlipd):
    """`RandFlipd` that also mirrors a tracked crop offset.

    Suppose an earlier step cropped the sample out of a larger volume. It stored the
    crop's start index under ``crop_offset_key`` and the spatial shape of the uncropped
    volume under ``original_shape_key``. Flipping the crop moves it to the mirrored
    position inside that volume. After a flip is applied, the offset of every flipped
    spatial axis is rewritten as ``original_shape - crop_offset - crop_shape``.

    Args:
        crop_offset_key: key holding the crop start offsets, one per spatial axis.
        original_shape_key: key holding the spatial shape of the uncropped volume.
        *args, **kwargs: forwarded to ``monai.transforms.RandFlipd``
            (``keys``, ``prob``, ``spatial_axis``, ...).

    Note:
        The two tracking keys come first positionally, so pass the ``RandFlipd``
        arguments (including ``keys``) by keyword.
    """

    def __init__(
        self,
        crop_offset_key: str = "crop_offset",
        original_shape_key: str = "original_shape",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.crop_offset_key = crop_offset_key
        self.original_shape_key = original_shape_key

    def _flipped_axes(self, n_spatial: int) -> list:
        """Return the spatial axes (channel excluded) flipped by ``self.flipper``."""
        spatial_axis = self.flipper.spatial_axis
        if spatial_axis is None:
            # Flip(spatial_axis=None) flips every spatial axis, not just the last one.
            return list(range(n_spatial))
        if isinstance(spatial_axis, (list, tuple)):
            return list(spatial_axis)
        return [spatial_axis]

    def __call__(self, data, *args, **kwargs):
        output = super().__call__(data, *args, **kwargs)
        # Nothing to update unless a flip happened and the sample carries a crop offset.
        if not self._do_transform or self.crop_offset_key not in data:
            return output

        # All flipped inputs must share one shape so a single offset stays valid.
        shapes = [data[key].shape for key in self.key_iterator(data)]
        if not shapes:
            return output
        if not all(shape == shapes[0] for shape in shapes):
            raise ValueError("All inputs must have the same shape for consistent crop value")
        crop_shape = shapes[0][1:]  # drop the channel dimension

        # as_tensor(..., dtype=float32) accepts tuples and int or float tensors alike;
        # the legacy torch.Tensor(int_tensor) constructor raises TypeError instead.
        orig_shape = torch.as_tensor(data[self.original_shape_key], dtype=torch.float32)
        # Copy so the caller's crop offset tensor is never modified in place.
        crop_offset = torch.as_tensor(data[self.crop_offset_key], dtype=torch.float32).clone()
        for axis in self._flipped_axes(len(crop_shape)):
            crop_offset[axis] = orig_shape[axis] - crop_offset[axis] - crop_shape[axis]
        output[self.crop_offset_key] = crop_offset
        return output

In [None]:
img = torch.rand(4, 100, 100, 100)
sample = {"image": img, "crop_offset": (10, 20, 30), "original_shape": (200, 200, 200)}

flipped = RandFlipWithCropTrackingd(keys=["image"], prob=1.0)(sample)

# Flipping keeps the image shape. With prob=1.0 the flip always happens, so the
# last-axis offset moves to 200 - 30 - 100 = 70.
assert flipped["image"].shape == img.shape
assert flipped["crop_offset"][-1].item() == 70

# Show only the tracking metadata instead of dumping the whole volume.
{key: value for key, value in flipped.items() if key != "image"}


[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m1.2556e-01[0m, [1;36m9.4766e-01[0m, [1;36m7.2156e-01[0m,  [33m...[0m, [1;36m1.8215e-01[0m,
           [1;36m1.3025e-01[0m, [1;36m6.1415e-01[0m[1m][0m,
          [1m[[0m[1;36m9.9686e-01[0m, [1;36m4.6192e-02[0m, [1;36m9.0233e-01[0m,  [33m...[0m, [1;36m7.5272e-02[0m,
           [1;36m3.5425e-01[0m, [1;36m1.0359e-01[0m[1m][0m,
          [1m[[0m[1;36m2.7953e-01[0m, [1;36m4.0598e-01[0m, [1;36m7.5143e-01[0m,  [33m...[0m, [1;36m3.9831e-01[0m,
           [1;36m2.6774e-01[0m, [1;36m8.4779e-01[0m[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m3.5447e-01[0m, [1;36m7.2037e-02[0m, [1;36m3.7588e-01[0m,  [33m...[0m, [1;36m7.2441e-01[0m,
           [1;36m7.0234e-02[0m, [1;36m5.2725e-01[0m[1m][0m,
          [1m[[0m[1;36m5.7474e-01[0m, [1;36m7.7326e-01[0m, [1;36m5.4428e-02[0m,  [33m...[0m, [1;36m8.0043e-01[0m,
  

In [None]:
# | export


class ResizedWithCropTrackingd(Resized):
    """`Resized` that rescales tracked crop metadata together with the image.

    If the sample carries ``crop_offset_key``, each spatial axis of the crop offset
    is multiplied by ``new_size / old_size`` of the resized inputs. When
    ``original_shape_key`` is also present, the uncropped shape gets the same
    scaling. Both results are truncated to int32.

    Args:
        crop_offset_key: key holding the crop start offsets, one per spatial axis.
        original_shape_key: key holding the spatial shape of the uncropped volume.
        *args, **kwargs: forwarded to ``monai.transforms.Resized``
            (``keys``, ``spatial_size``, ...).

    Note:
        The two tracking keys come first positionally, so pass the ``Resized``
        arguments (including ``keys``) by keyword.
    """

    def __init__(
        self,
        crop_offset_key: str = "crop_offset",
        original_shape_key: str = "original_shape",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.crop_offset_key = crop_offset_key
        self.original_shape_key = original_shape_key

    @staticmethod
    def _rescale(values, old_shape, new_shape) -> torch.Tensor:
        """Scale per-axis ``values`` by ``new_shape / old_shape`` and truncate to int32.

        Multiplying before dividing, in float64, keeps exact results exact. For example,
        an offset of 180 on a 20-voxel axis resized to 13 gives 117. The float32 form
        ``values * (13 / 20)`` evaluates to 116.99999... and truncates to 116.
        """
        values = torch.as_tensor(values, dtype=torch.float64)
        old = torch.tensor(old_shape, dtype=torch.float64)
        new = torch.tensor(new_shape, dtype=torch.float64)
        return (values * new / old).int()

    def __call__(self, data, *args, **kwargs):
        # Capture the pre-resize spatial shape only when there is crop metadata to update.
        ref_key = None
        old_shape = None
        if self.crop_offset_key in data:
            present_keys = list(self.key_iterator(data))
            shapes = [data[key].shape for key in present_keys]
            # All inputs must share one shape so a single offset stays valid.
            if not all(shape == shapes[0] for shape in shapes):
                raise ValueError("All inputs must have the same shape for consistent crop value")
            if present_keys:
                ref_key = present_keys[0]
                old_shape = shapes[0][1:]  # drop the channel dimension

        output = super().__call__(data, *args, **kwargs)
        if ref_key is None:
            return output

        new_shape = output[ref_key].shape[1:]
        output[self.crop_offset_key] = self._rescale(data[self.crop_offset_key], old_shape, new_shape)
        if self.original_shape_key in data:
            output[self.original_shape_key] = self._rescale(
                data[self.original_shape_key], old_shape, new_shape
            )
        return output

In [None]:
img = torch.rand(4, 100, 100, 100)
sample = {"image": img, "crop_offset": (10, 20, 30), "original_shape": (200, 200, 200)}

resized = ResizedWithCropTrackingd(keys=["image"], spatial_size=(60, 60, 60))(sample)

# Resizing 100 -> 60 voxels is a 0.6 scale, applied to the offset and the original shape.
assert resized["image"].shape == (4, 60, 60, 60)
assert resized["crop_offset"].tolist() == [6, 12, 18]
assert resized["original_shape"].tolist() == [120, 120, 120]

# Show only the tracking metadata instead of dumping the whole volume.
{key: value for key, value in resized.items() if key != "image"}


[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m0.2471[0m, [1;36m0.4995[0m, [1;36m0.5993[0m,  [33m...[0m, [1;36m0.6837[0m, [1;36m0.4922[0m, [1;36m0.3836[0m[1m][0m,
          [1m[[0m[1;36m0.4113[0m, [1;36m0.4792[0m, [1;36m0.4810[0m,  [33m...[0m, [1;36m0.5323[0m, [1;36m0.5230[0m, [1;36m0.4629[0m[1m][0m,
          [1m[[0m[1;36m0.6377[0m, [1;36m0.6658[0m, [1;36m0.4599[0m,  [33m...[0m, [1;36m0.5872[0m, [1;36m0.6815[0m, [1;36m0.5865[0m[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m0.5383[0m, [1;36m0.4991[0m, [1;36m0.4617[0m,  [33m...[0m, [1;36m0.5246[0m, [1;36m0.5131[0m, [1;36m0.5159[0m[1m][0m,
          [1m[[0m[1;36m0.5867[0m, [1;36m0.5324[0m, [1;36m0.5722[0m,  [33m...[0m, [1;36m0.4862[0m, [1;36m0.4494[0m, [1;36m0.6092[0m[1m][0m,
          [1m[[0m[1;36m0.5035[0m, [1;36m0.4671[0m, [1;36m0.5119[0m,  [33m...[0m, [1;36m0.4933[0m, 

# nbdev

In [None]:
# Run the `nbdev_export` command-line tool from the kernel. IPython rewrites a
# `!nbdev_export` shell line into exactly this call.
get_ipython().system("nbdev_export")