In [None]:
# | default_exp transforms/croppad

# Imports

In [None]:
# | export


from collections.abc import Sequence

import numpy as np
import torch
from monai.data import MetaTensor
from monai.data.meta_obj import get_track_meta
from monai.transforms.croppad.array import RandSpatialCropSamples
from monai.transforms.croppad.dictionary import CropForegroundd, RandSpatialCropSamplesd
from monai.utils import ImageMetaKey as Key

# Utility functions

In [None]:
# | export


def get_updated_crop_start(current_crop_start, new_crop_start):
    """Accumulate a new crop start onto an existing crop start.

    Args:
        current_crop_start: Offset accumulated so far, or ``None`` when no
            crop has been recorded yet. Non-tensor values are converted with
            ``torch.tensor``.
        new_crop_start: Start coordinates of the latest crop. Non-tensor
            values are converted with ``torch.tensor``.

    Returns:
        ``new_crop_start`` as a tensor when ``current_crop_start`` is ``None``;
        otherwise the element-wise sum of both offsets as a tensor.
    """
    # Tensors are passed through untouched; anything else is copied into a new tensor.
    latest = new_crop_start if torch.is_tensor(new_crop_start) else torch.tensor(new_crop_start)

    # Nothing accumulated yet: the latest crop start is the total offset.
    if current_crop_start is None:
        return latest

    accumulated = current_crop_start if torch.is_tensor(current_crop_start) else torch.tensor(current_crop_start)
    return accumulated + latest

In [None]:
# Example: the original image is 256x256.
# First crop: a 100x101 region starting at (50, 60) in the original image.

original_size = np.array((256, 256))  # for illustration only; not referenced again in this cell

# No crop has been applied yet, so there is no accumulated offset.
current_crop_start = None
new_crop_start = np.array((50, 60))

# With no prior offset, the result is simply the new crop start, returned as a tensor.
updated_crop_start = get_updated_crop_start(current_crop_start, new_crop_start)
updated_crop_start

[1;35mtensor[0m[1m([0m[1m[[0m[1;36m50[0m, [1;36m60[0m[1m][0m[1m)[0m

In [None]:
# Second crop: a 50x50 region of the already-cropped image, starting at (10, 11)
# in the cropped image's own coordinates.

# Carry over the offset computed by the previous cell.
current_crop_start = updated_crop_start
new_crop_start = np.array((10, 11))

# The two offsets are summed element-wise: (50, 60) + (10, 11) = (60, 71).
updated_crop_start = get_updated_crop_start(current_crop_start, new_crop_start)
updated_crop_start

[1;35mtensor[0m[1m([0m[1m[[0m[1;36m60[0m, [1;36m71[0m[1m][0m[1m)[0m

# Transforms

In [None]:
# | export


class CropForegroundWithCropTrackingd(CropForegroundd):
    """`CropForegroundd` variant that keeps a running crop offset in the data dict.

    After the parent transform has run, the value it left under
    ``self.start_coord_key`` is added to whatever offset is already stored under
    ``crop_offset_key`` (or stored as-is when no offset is present yet).

    Note:
        ``crop_offset_key`` is the first positional parameter of ``__init__``,
        so the parent's arguments (``keys``, ``source_key``, ...) should be
        passed by keyword to avoid binding them to ``crop_offset_key``.
    """

    def __init__(
        self,
        crop_offset_key: str = "crop_offset",
        *args,
        **kwargs,
    ) -> None:
        """
        Args:
            crop_offset_key: Dictionary key under which the accumulated crop
                offset is read and written.
            *args: Forwarded unchanged to ``CropForegroundd.__init__``.
            **kwargs: Forwarded unchanged to ``CropForegroundd.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.crop_offset_key = crop_offset_key

    def __call__(self, data, *args, **kwargs):
        """Run the parent crop, then fold its start coordinate into the tracked offset.

        Args:
            data: Input dictionary handed to ``CropForegroundd.__call__``.
            *args: Forwarded unchanged to the parent ``__call__``.
            **kwargs: Forwarded unchanged to the parent ``__call__``.

        Returns:
            The parent's output dictionary with ``crop_offset_key`` updated.
        """
        output = super().__call__(data, *args, **kwargs)
        # `start_coord_key` is not set in this class; it is expected to be provided by the
        # parent, which is assumed to store the crop start under it — TODO confirm for the
        # installed MONAI version.
        crop_offset = output[self.start_coord_key]
        output[self.crop_offset_key] = get_updated_crop_start(output.get(self.crop_offset_key), crop_offset)
        return output

In [None]:
# Build a 4-channel 100x100x100 volume of zeros with one 10x10x10 cube of ones per channel.
img = torch.zeros(4, 100, 100, 100)
img[0, 20:30, 30:40, 40:50] = 1
img[1, 10:20, 30:40, 40:50] = 1
img[2, 20:30, 20:30, 40:50] = 1
img[3, 20:30, 30:40, 50:60] = 1
# Start from an existing offset of (1, 1, 1) so the transform has something to accumulate onto.
x = {"image": img, "crop_offset": (1, 1, 1)}

# Arguments are passed by keyword because the first positional parameter is `crop_offset_key`.
CropForegroundWithCropTrackingd(keys=["image"], source_key="image")(x)




[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m[1m][0m,

         [1m[[0m[1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, 

In [None]:
# | export


class RandSpatialCropSamplesWithCropTracking(RandSpatialCropSamples):
    """`RandSpatialCropSamples` variant that returns the crops along with the crop offset.

    When meta tracking is enabled, each returned crop carries the start index of
    every slice in ``self.cropper._slices`` in its ``meta`` under ``crop_offset_key``.

    Note:
        ``crop_offset_key`` is the first positional parameter of ``__init__``;
        all remaining arguments are forwarded to ``RandSpatialCropSamples``.
    """

    def __init__(self, crop_offset_key: str = "crop_offset", *args, **kwargs):
        """
        Args:
            crop_offset_key: Meta key under which each crop's start offsets are stored.
            *args: Forwarded unchanged to ``RandSpatialCropSamples.__init__``.
            **kwargs: Forwarded unchanged to ``RandSpatialCropSamples.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.crop_key = crop_offset_key

    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> list[torch.Tensor]:
        """
        Produce ``self.num_samples`` crops of `img`, assuming `img` is channel-first
        and cropping doesn't change the channel dim.

        Args:
            img: Image to crop.
            lazy: Per-call override of ``self.lazy``; ``None`` keeps the instance setting.

        Returns:
            A list with one cropped result per sample.
        """
        use_lazy = lazy if lazy is not None else self.lazy
        samples = []
        for sample_index in range(self.num_samples):
            patch = self.cropper(img, lazy=use_lazy)
            if get_track_meta():
                # Start index of each slice held by the cropper after this call.
                slice_starts = tuple(axis_slice.start for axis_slice in self.cropper._slices)
                patch.meta[self.crop_key] = slice_starts
                patch.meta[Key.PATCH_INDEX] = sample_index  # type: ignore
                # track as this class instead of RandSpatialCrop
                self.push_transform(patch, replace=True, lazy=use_lazy)
            samples.append(patch)
        return samples


class RandSpatialCropSamplesWithCropTrackingd(RandSpatialCropSamplesd):
    """`RandSpatialCropSamplesd` variant that keeps a running crop offset per sample.

    The parent's cropper is replaced by `RandSpatialCropSamplesWithCropTracking`,
    which records each crop's start offsets in the crop's ``meta``. After the
    parent transform has run, that offset is added to whatever is already stored
    under ``crop_offset_key`` in each output dictionary.
    """

    def __init__(
        self,
        keys,
        roi_size: Sequence[int] | int,
        num_samples: int,
        max_roi_size: Sequence[int] | int | None = None,
        random_center: bool = True,
        random_size: bool = False,
        allow_missing_keys: bool = False,
        lazy: bool = False,
        crop_offset_key: str = "crop_offset",
    ) -> None:
        """
        Args:
            keys: Keys of the items to crop; forwarded to the parent.
            roi_size: Forwarded to the parent and to the tracking cropper.
            num_samples: Forwarded to the parent and to the tracking cropper.
            max_roi_size: Forwarded to the parent and to the tracking cropper.
            random_center: Forwarded to the parent and to the tracking cropper.
            random_size: Forwarded to the parent and to the tracking cropper.
            allow_missing_keys: Forwarded to the parent.
            lazy: Forwarded to the parent and to the tracking cropper.
            crop_offset_key: Key used both in each crop's ``meta`` and in the
                output dictionaries for the crop offset.
        """
        super().__init__(
            keys, roi_size, num_samples, max_roi_size, random_center, random_size, allow_missing_keys, lazy
        )
        self.crop_offset_key = crop_offset_key
        # Replace the cropper created by the parent with the offset-recording variant.
        self.cropper = RandSpatialCropSamplesWithCropTracking(
            crop_offset_key, roi_size, num_samples, max_roi_size, random_center, random_size, lazy=lazy
        )

    def __call__(self, data, *args, **kwargs):
        """Run the parent transform, then update the tracked offset of every sample.

        Args:
            data: Input dictionary handed to ``RandSpatialCropSamplesd.__call__``.
            *args: Forwarded unchanged to the parent ``__call__``.
            **kwargs: Forwarded unchanged to the parent ``__call__``.

        Returns:
            The parent's output (iterated here as a sequence of dictionaries), each
            with ``crop_offset_key`` updated.
        """
        output = super().__call__(data, *args, **kwargs)
        for sample in output:
            # Accumulate exactly once per sample. Looping over every key would add the
            # offset once per key and over-count it when several keys are cropped.
            # Assumes the parent applies the same crop to every key — TODO confirm.
            tracked_key = next((key for key in self.keys if key in sample), None)
            if tracked_key is None:
                # None of the configured keys is present, so there is no crop to record.
                continue
            crop_offset = sample[tracked_key].meta.get(self.crop_offset_key)
            sample[self.crop_offset_key] = get_updated_crop_start(sample.get(self.crop_offset_key), crop_offset)
        return output

In [None]:
# Reuses `img` from the earlier demo cell; start again from an existing offset of (1, 1, 1).
x = {"image": img, "crop_offset": (1, 1, 1)}

# Request 2 random crops and display only the first returned sample.
# NOTE(review): no random seed is set in the visible cells, so the crop may differ between runs.
RandSpatialCropSamplesWithCropTrackingd(
    keys=["image"],
    roi_size=(50, 50, 50),
    max_roi_size=(60, 60, 60),
    num_samples=2,
    random_size=True,
)(x)[0]


[1m{[0m
    [32m'image'[0m: [1;35mmetatensor[0m[1m([0m[1m[[0m[1m[[0m[1m[[0m[1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [33m...[0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
          [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m[1m][0m,

         [1m[[0m[1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m.,  [33m...[0m, 

# nbdev

In [None]:
# Run nbdev's export command from the shell; presumably this writes the `# | export`
# cells to the module named by `default_exp` — confirm against the nbdev docs.
!nbdev_export

Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/bin/nbdev_export", line 8, in <module>
    sys.exit(nbdev_export())
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/fastcore-1.7.13-py3.10.egg/fastcore/script.py", line 121, in _f
    return tfunc(**merge(args, args_from_prog(func, xtra)))
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/nbdev/doclinks.py", line 157, in nbdev_export
    _build_modidx()
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/nbdev/doclinks.py", line 114, in _build_modidx
    res['syms'].update(_get_modidx((dest.parent/file).resolve(), code_root, nbs_path=nbs_path))
  File "/home/ubuntu/miniconda3/lib/python3.10/site-packages/nbdev/doclinks.py", line 91, in _get_modidx
    for tree in ast.parse(cell.code).body:
  File "/home/ubuntu/miniconda3/lib/python3.10/ast.py", line 50, in parse
    return compile(source, filename, mode, flags,
  File "<unknown>", line 17
    return output
    ^^^^^^
IndentationError: ex