In [None]:
# | default_exp transforms/clipping

# Imports

In [None]:
# | export


from collections.abc import Hashable, Iterable

import torch
from monai.config import KeysCollection
from monai.transforms import Transform

# Transforms

In [None]:
# | export


class Clip(Transform):
    """
    Clamp every element of a tensor into the closed range `[min_value, max_value]`.
    """

    def __init__(self, min_value: float, max_value: float):
        """
        Args:
            min_value: Lower bound of the clipping range.
            max_value: Upper bound of the clipping range.
        """
        super().__init__()

        self.min_value, self.max_value = min_value, max_value

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        Apply the clipping to `data`.

        Args:
            data: Tensor whose values should be limited to the configured range.

        Returns:
            The result of `torch.clamp` on `data` with the configured bounds.
        """
        clipped = torch.clamp(data, min=self.min_value, max=self.max_value)
        return clipped

    def __repr__(self) -> str:
        return f"Clip(min_value={self.min_value}, max_value={self.max_value})"

In [None]:
# Draw the sample from a locally seeded generator so the printed ranges are
# reproducible without touching the global RNG state.
generator = torch.Generator().manual_seed(0)
data = torch.randn(3, 4, 5, generator=generator) * 10
transform = Clip(0, 1)
display(transform)

print(data.min(), data.max())

clipped_data = transform(data)

print(clipped_data.min(), clipped_data.max())

# The output must keep the input shape and lie entirely inside the requested range.
assert clipped_data.shape == data.shape
assert clipped_data.min() >= 0
assert clipped_data.max() <= 1

[1;35mClip[0m[1m([0m[33mmin_value[0m=[1;36m0[0m, [33mmax_value[0m=[1;36m1[0m[1m)[0m

tensor(-22.1990) tensor(22.3186)
tensor(0.) tensor(1.)


In [None]:
# | export


class Clipd(Transform):
    """
    Dictionary-based counterpart of `Clip`.

    Applies one shared clipping range to every dictionary entry selected by `keys`.
    """

    def __init__(self, keys: KeysCollection, min_value: float, max_value: float):
        """
        Args:
            keys: A single key, or a collection of keys, naming the entries to clip.
            min_value: Lower bound of the clipping range.
            max_value: Upper bound of the clipping range.
        """
        super().__init__()

        # `KeysCollection` also admits a single key. Wrap it so that iterating over
        # `self.keys` yields the key itself rather than, e.g., the characters of a string.
        if isinstance(keys, (str, bytes)) or not isinstance(keys, Iterable):
            keys = (keys,)

        self.keys = keys
        self.transform = Clip(min_value, max_value)

    def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
        """
        Clip the selected entries of the input dictionary to the specified range.

        Args:
            data: Dictionary mapping keys to tensors. It is updated in place.

        Returns:
            The same dictionary object, with each entry named in `self.keys`
            replaced by its clipped version.

        Raises:
            KeyError: If one of the configured keys is not present in `data`.
        """
        for key in self.keys:
            if key not in data:
                raise KeyError(f"Key {key} not found in input data.")
            data[key] = self.transform(data[key])

        return data

    def __repr__(self) -> str:
        return f"Clipd(keys={self.keys}, min_value={self.transform.min_value}, max_value={self.transform.max_value})"

In [None]:
# Draw the sample from a locally seeded generator so the printed ranges are
# reproducible without touching the global RNG state.
generator = torch.Generator().manual_seed(0)
data = {"images": torch.randn(3, 4, 5, generator=generator) * 10}
transform = Clipd(["images"], 0, 1)
display(transform)

print(data["images"].min(), data["images"].max())

clipped_data = transform(data)

print(clipped_data["images"].min(), clipped_data["images"].max())

# The clipped entry must lie entirely inside the requested range.
assert clipped_data["images"].min() >= 0
assert clipped_data["images"].max() <= 1

[1;35mClipd[0m[1m([0m[33mkeys[0m=[1m[[0m[32m'images'[0m[1m][0m, [33mmin_value[0m=[1;36m0[0m, [33mmax_value[0m=[1;36m1[0m[1m)[0m

tensor(-19.7777) tensor(18.1412)
tensor(0.) tensor(1.)


# nbdev

In [None]:
# Run the nbdev export CLI as a shell command; presumably this writes the cells
# marked `# | export` to the module named by `default_exp` — confirm against the nbdev docs.
!nbdev_export