|
38 | 38 | remove_small_objects,
|
39 | 39 | )
|
40 | 40 | from monai.transforms.utils_pytorch_numpy_unification import unravel_index
|
41 |
| -from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option |
| 41 | +from monai.utils import ( |
| 42 | + TransformBackends, |
| 43 | + convert_data_type, |
| 44 | + convert_to_tensor, |
| 45 | + ensure_tuple, |
| 46 | + get_equivalent_dtype, |
| 47 | + look_up_option, |
| 48 | +) |
42 | 49 | from monai.utils.type_conversion import convert_to_dst_type
|
43 | 50 |
|
44 | 51 | __all__ = [
|
|
54 | 61 | "SobelGradients",
|
55 | 62 | "VoteEnsemble",
|
56 | 63 | "Invert",
|
| 64 | + "GenerateHeatmap", |
57 | 65 | "DistanceTransformEDT",
|
58 | 66 | ]
|
59 | 67 |
|
@@ -742,6 +750,146 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
|
742 | 750 | return self.post_convert(out_pt, img)
|
743 | 751 |
|
744 | 752 |
|
class GenerateHeatmap(Transform):
    """
    Generate per-landmark gaussian response maps for 2D or 3D coordinates.

    Every input point produces one output channel of shape ``spatial_shape`` that contains a gaussian
    centred on that point. Coordinates are given in index order, matching the axes of ``spatial_shape``,
    and may be fractional. A point that is non-finite, or that lies outside ``[0, size)`` on any axis,
    yields an all-zero channel.

    The gaussian is always evaluated in floating point (``float64`` when ``dtype`` is 64-bit float,
    ``float32`` otherwise) and is converted to ``dtype`` only when the result is returned. This keeps the
    coordinate arithmetic exact for large volumes and lets non-floating ``dtype`` values be requested;
    note that casting values in ``[0, 1]`` to an integer dtype discards the fractional part.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
        truncate: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: when ``sigma`` or ``truncate`` is non-positive, or (at call time) when
            ``spatial_shape`` cannot be resolved or does not match the landmark dimensionality.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | None = None,
        truncate: float = 3.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        self._sigma: tuple[float, ...] | float
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            if any(s <= 0 for s in sigma):
                raise ValueError("sigma values must be positive.")
            self._sigma = tuple(float(s) for s in sigma)
        else:
            if float(sigma) <= 0:
                raise ValueError("sigma must be positive.")
            self._sigma = float(sigma)
        if truncate <= 0:
            raise ValueError("truncate must be positive.")
        self.truncate = float(truncate)
        self.normalize = normalize
        # Keep both flavours of the target dtype so the output can follow the input container type.
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)

    @property
    def _compute_dtype(self) -> torch.dtype:
        """Floating dtype used for all internal arithmetic, independent of the requested output dtype."""
        return torch.float64 if self.torch_dtype == torch.float64 else torch.float32

    def __call__(
        self,
        points: NdarrayOrTensor,
        spatial_shape: Sequence[int] | None = None,
    ) -> NdarrayOrTensor:
        """
        Args:
            points: landmark coordinates with shape ``(num_points, spatial_dims)``, ``spatial_dims`` in ``(2, 3)``.
            spatial_shape: spatial shape of the output; overrides the value given at construction time.

        Returns:
            Heatmaps with shape ``(num_points, *spatial_shape)``, converted to the container type of ``points``.

        Raises:
            ValueError: when ``points`` is not 2D, is not 2- or 3-dimensional spatially, or the spatial
                shape / sigma cannot be matched to the landmark dimensionality.
        """
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
        if points_t.ndim != 2:
            raise ValueError("points must be a 2D array with shape (num_points, spatial_dims).")
        device = points_t.device
        num_points, spatial_dims = points_t.shape
        if spatial_dims not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)
        radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma)

        # Accumulate in a floating dtype; the cast to the requested dtype happens once, on output.
        heatmap = torch.zeros((num_points, *target_shape), dtype=self._compute_dtype, device=device)
        # A single ``tolist`` moves every coordinate to the host at once instead of once per landmark.
        for idx, center_vals in enumerate(points_t.tolist()):
            if not np.all(np.isfinite(center_vals)):
                continue
            if not self._is_inside(center_vals, target_shape):
                continue
            window_slices, coord_shifts = self._make_window(center_vals, radius, target_shape, device)
            if window_slices is None:
                continue
            gaussian = self._evaluate_gaussian(coord_shifts, sigma)
            if self.normalize:
                # Everything outside the window is zero, so the window maximum is the channel maximum.
                peak = gaussian.max()
                if peak.item() > 0:
                    gaussian = gaussian / peak
            # Each channel is written exactly once, so plain assignment is sufficient.
            heatmap[(idx, *window_slices)] = gaussian

        target_dtype = self.torch_dtype if isinstance(points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
        """Pick the call-time shape if given, else the constructor one, broadcasting a length-1 shape."""
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("spatial_shape must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) == 1:
                shape_tuple = shape_tuple * spatial_dims  # type: ignore
            else:
                raise ValueError("spatial_shape length must match spatial dimension of the landmarks.")
        return tuple(int(s) for s in shape_tuple)

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """Expand the stored sigma to one value per spatial dimension."""
        if isinstance(self._sigma, tuple):
            if len(self._sigma) == spatial_dims:
                return self._sigma
            if len(self._sigma) == 1:
                return self._sigma * spatial_dims
            raise ValueError("sigma sequence length must equal the number of spatial dimensions.")
        return (self._sigma,) * spatial_dims

    @staticmethod
    def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
        """Return ``True`` when every coordinate lies in ``[0, size)`` of its axis."""
        return all(0 <= c < size for c, size in zip(center, bounds))

    def _make_window(
        self,
        center: Sequence[float],
        radius: tuple[int, ...],
        bounds: tuple[int, ...],
        device: torch.device,
    ) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
        """
        Build the clipped support window around ``center``.

        Returns the per-axis slices into the image and, for each axis, the signed offsets of the
        window voxels from ``center``. Returns ``(None, ())`` if the window is empty on any axis.
        """
        slices: list[slice] = []
        coord_shifts: list[torch.Tensor] = []
        compute_dtype = self._compute_dtype
        for c, r, size in zip(center, radius, bounds):
            start = max(int(np.floor(c - r)), 0)
            stop = min(int(np.ceil(c + r)) + 1, size)
            if start >= stop:
                return None, ()
            slices.append(slice(start, stop))
            # Offsets are built in floating point so large indices and fractional centres stay exact.
            coord_shifts.append(torch.arange(start, stop, device=device, dtype=compute_dtype) - float(c))
        return tuple(slices), tuple(coord_shifts)

    def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
        """Evaluate ``exp(-0.5 * sum((shift / sigma) ** 2))`` on the grid spanned by the per-axis offsets."""
        device = coord_shifts[0].device
        # Follow the dtype of the offsets so the in-place accumulation below never changes dtype.
        work_dtype = coord_shifts[0].dtype
        shape = tuple(len(axis) for axis in coord_shifts)
        if 0 in shape:
            return torch.zeros(shape, dtype=work_dtype, device=device)
        exponent = torch.zeros(shape, dtype=work_dtype, device=device)
        for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
            scaled = (shift / float(sig)) ** 2
            # Reshape to broadcast this axis' term across all the other axes.
            reshape_shape = [1] * len(coord_shifts)
            reshape_shape[dim] = shift.numel()
            exponent += scaled.reshape(reshape_shape)
        return torch.exp(-0.5 * exponent)
| 891 | + |
| 892 | + |
745 | 893 | class ProbNMS(Transform):
|
746 | 894 | """
|
747 | 895 | Performs probability based non-maximum suppression (NMS) on the probabilities map via
|
|
0 commit comments