Skip to content

Commit 8ef905b

Browse files
committed
Adds a transform to generate heatmap from landmarks
Adds a `GenerateHeatmap` transform to create gaussian response maps from landmark coordinates. This transform is implemented for both array and dictionary-based workflows. It enables the generation of heatmaps from landmark data, facilitating tasks like landmark localization and visualization. The transform supports 2D and 3D coordinates and offers options for controlling the gaussian standard deviation, spatial shape, truncation, normalization, and data type.
1 parent cf5790d commit 8ef905b

File tree

3 files changed

+375
-1
lines changed

3 files changed

+375
-1
lines changed

monai/transforms/post/array.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,14 @@
3838
remove_small_objects,
3939
)
4040
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
41-
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
41+
from monai.utils import (
42+
TransformBackends,
43+
convert_data_type,
44+
convert_to_tensor,
45+
ensure_tuple,
46+
get_equivalent_dtype,
47+
look_up_option,
48+
)
4249
from monai.utils.type_conversion import convert_to_dst_type
4350

4451
__all__ = [
@@ -54,6 +61,7 @@
5461
"SobelGradients",
5562
"VoteEnsemble",
5663
"Invert",
64+
"GenerateHeatmap",
5765
"DistanceTransformEDT",
5866
]
5967

@@ -742,6 +750,146 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
742750
return self.post_convert(out_pt, img)
743751

744752

753+
class GenerateHeatmap(Transform):
754+
"""
755+
Generate per-landmark gaussian response maps for 2D or 3D coordinates.
756+
757+
Args:
758+
sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
759+
spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
760+
truncate: extent, in multiples of ``sigma``, used to crop the gaussian support window.
761+
normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
762+
dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).
763+
764+
Raises:
765+
ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved.
766+
767+
"""
768+
769+
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
770+
771+
def __init__(
772+
self,
773+
sigma: Sequence[float] | float = 5.0,
774+
spatial_shape: Sequence[int] | None = None,
775+
truncate: float = 3.0,
776+
normalize: bool = True,
777+
dtype: np.dtype | torch.dtype | type = np.float32,
778+
) -> None:
779+
if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
780+
if any(s <= 0 for s in sigma):
781+
raise ValueError("sigma values must be positive.")
782+
self._sigma = tuple(float(s) for s in sigma)
783+
else:
784+
if float(sigma) <= 0:
785+
raise ValueError("sigma must be positive.")
786+
self._sigma = float(sigma)
787+
if truncate <= 0:
788+
raise ValueError("truncate must be positive.")
789+
self.truncate = float(truncate)
790+
self.normalize = normalize
791+
self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
792+
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
793+
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)
794+
795+
def __call__(
796+
self,
797+
points: NdarrayOrTensor,
798+
spatial_shape: Sequence[int] | None = None,
799+
) -> NdarrayOrTensor:
800+
original_points = points
801+
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
802+
if points_t.ndim != 2:
803+
raise ValueError("points must be a 2D array with shape (num_points, spatial_dims).")
804+
device = points_t.device
805+
num_points, spatial_dims = points_t.shape
806+
if spatial_dims not in (2, 3):
807+
raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")
808+
809+
target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
810+
sigma = self._resolve_sigma(spatial_dims)
811+
radius = tuple(int(np.ceil(self.truncate * s)) for s in sigma)
812+
813+
heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
814+
image_bounds = tuple(int(s) for s in target_shape)
815+
for idx, center in enumerate(points_t):
816+
center_vals = center.tolist()
817+
if not np.all(np.isfinite(center_vals)):
818+
continue
819+
if not self._is_inside(center_vals, image_bounds):
820+
continue
821+
window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
822+
if window_slices is None:
823+
continue
824+
region = heatmap[(idx, *window_slices)]
825+
gaussian = self._evaluate_gaussian(coord_shifts, sigma)
826+
torch.maximum(region, gaussian, out=region)
827+
if self.normalize:
828+
max_val = heatmap[idx].max()
829+
if max_val.item() > 0:
830+
heatmap[idx] /= max_val
831+
832+
target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
833+
converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
834+
return converted
835+
836+
def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
837+
shape = call_shape if call_shape is not None else self.spatial_shape
838+
if shape is None:
839+
raise ValueError("spatial_shape must be provided either at construction time or call time.")
840+
shape_tuple = ensure_tuple(shape)
841+
if len(shape_tuple) != spatial_dims:
842+
if len(shape_tuple) == 1:
843+
shape_tuple = shape_tuple * spatial_dims # type: ignore
844+
else:
845+
raise ValueError("spatial_shape length must match spatial dimension of the landmarks.")
846+
return tuple(int(s) for s in shape_tuple)
847+
848+
def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
849+
if isinstance(self._sigma, tuple):
850+
if len(self._sigma) == spatial_dims:
851+
return self._sigma
852+
if len(self._sigma) == 1:
853+
return self._sigma * spatial_dims
854+
raise ValueError("sigma sequence length must equal the number of spatial dimensions.")
855+
return (self._sigma,) * spatial_dims
856+
857+
@staticmethod
858+
def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
859+
return all(0 <= c < size for c, size in zip(center, bounds))
860+
861+
def _make_window(
862+
self,
863+
center: Sequence[float],
864+
radius: tuple[int, ...],
865+
bounds: tuple[int, ...],
866+
device: torch.device,
867+
) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
868+
slices: list[slice] = []
869+
coord_shifts: list[torch.Tensor] = []
870+
for dim, (c, r, size) in enumerate(zip(center, radius, bounds)):
871+
start = max(int(np.floor(c - r)), 0)
872+
stop = min(int(np.ceil(c + r)) + 1, size)
873+
if start >= stop:
874+
return None, ()
875+
slices.append(slice(start, stop))
876+
coord_shifts.append(torch.arange(start, stop, device=device, dtype=self.torch_dtype) - float(c))
877+
return tuple(slices), tuple(coord_shifts)
878+
879+
def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
880+
device = coord_shifts[0].device
881+
shape = tuple(len(axis) for axis in coord_shifts)
882+
if 0 in shape:
883+
return torch.zeros(shape, dtype=self.torch_dtype, device=device)
884+
exponent = torch.zeros(shape, dtype=self.torch_dtype, device=device)
885+
for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
886+
scaled = (shift / float(sig)) ** 2
887+
reshape_shape = [1] * len(coord_shifts)
888+
reshape_shape[dim] = shift.numel()
889+
exponent += scaled.reshape(reshape_shape)
890+
return torch.exp(-0.5 * exponent)
891+
892+
745893
class ProbNMS(Transform):
746894
"""
747895
Performs probability based non-maximum suppression (NMS) on the probabilities map via

monai/transforms/post/dictionary.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AsDiscrete,
3636
DistanceTransformEDT,
3737
FillHoles,
38+
GenerateHeatmap,
3839
KeepLargestConnectedComponent,
3940
LabelFilter,
4041
LabelToContour,
@@ -48,6 +49,7 @@
4849
from monai.transforms.utility.array import ToTensor
4950
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
5051
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
52+
from monai.utils.type_conversion import convert_to_dst_type
5153

5254
__all__ = [
5355
"ActivationsD",
@@ -95,6 +97,9 @@
9597
"DistanceTransformEDTd",
9698
"DistanceTransformEDTD",
9799
"DistanceTransformEDTDict",
100+
"GenerateHeatmapd",
101+
"GenerateHeatmapD",
102+
"GenerateHeatmapDict",
98103
]
99104

100105
DEFAULT_POST_FIX = PostFix.meta()
@@ -508,6 +513,137 @@ def __init__(self, keys: KeysCollection, output_key: str | None = None, num_clas
508513
super().__init__(keys, ensemble, output_key)
509514

510515

516+
class GenerateHeatmapd(MapTransform):
517+
"""
518+
Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
519+
Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520+
"""
521+
522+
backend = GenerateHeatmap.backend
523+
524+
def __init__(
525+
self,
526+
keys: KeysCollection,
527+
sigma: Sequence[float] | float = 5.0,
528+
heatmap_keys: KeysCollection | None = None,
529+
ref_image_keys: KeysCollection | None = None,
530+
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
531+
truncate: float = 3.0,
532+
normalize: bool = True,
533+
dtype: np.dtype | type = np.float32,
534+
allow_missing_keys: bool = False,
535+
) -> None:
536+
super().__init__(keys, allow_missing_keys)
537+
self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys)
538+
self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)
539+
self.static_shapes = self._prepare_shapes(spatial_shape)
540+
self.generator = GenerateHeatmap(
541+
sigma=sigma,
542+
spatial_shape=None,
543+
truncate=truncate,
544+
normalize=normalize,
545+
dtype=dtype,
546+
)
547+
548+
def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
549+
d = dict(data)
550+
for key, out_key, ref_key, static_shape in self.key_iterator(
551+
d, self.heatmap_keys, self.ref_image_keys, self.static_shapes
552+
):
553+
points = d[key]
554+
shape = self._determine_shape(points, static_shape, d, ref_key)
555+
heatmap = self.generator(points, spatial_shape=shape)
556+
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
557+
d[out_key] = self._prepare_output(heatmap, reference)
558+
return d
559+
560+
def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
561+
if heatmap_keys is None:
562+
return tuple(f"{key}_heatmap" for key in self.keys)
563+
keys_tuple = ensure_tuple(heatmap_keys)
564+
if len(keys_tuple) == 1 and len(self.keys) > 1:
565+
keys_tuple = keys_tuple * len(self.keys)
566+
if len(keys_tuple) != len(self.keys):
567+
raise ValueError("heatmap_keys length must match keys length.")
568+
return keys_tuple
569+
570+
def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:
571+
if maybe_keys is None:
572+
return (None,) * len(self.keys)
573+
keys_tuple = ensure_tuple(maybe_keys)
574+
if len(keys_tuple) == 1 and len(self.keys) > 1:
575+
keys_tuple = keys_tuple * len(self.keys)
576+
if len(keys_tuple) != len(self.keys):
577+
raise ValueError("ref_image_keys length must match keys length when provided.")
578+
return tuple(keys_tuple)
579+
580+
def _prepare_shapes(
581+
self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None
582+
) -> tuple[tuple[int, ...] | None, ...]:
583+
if spatial_shape is None:
584+
return (None,) * len(self.keys)
585+
shape_tuple = ensure_tuple(spatial_shape)
586+
if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple):
587+
shape = tuple(int(v) for v in shape_tuple)
588+
return (shape,) * len(self.keys)
589+
if len(shape_tuple) == 1 and len(self.keys) > 1:
590+
shape_tuple = shape_tuple * len(self.keys)
591+
if len(shape_tuple) != len(self.keys):
592+
raise ValueError("spatial_shape length must match keys length when providing per-key shapes.")
593+
prepared: list[tuple[int, ...] | None] = []
594+
for item in shape_tuple:
595+
if item is None:
596+
prepared.append(None)
597+
else:
598+
dims = ensure_tuple(item)
599+
prepared.append(tuple(int(v) for v in dims))
600+
return tuple(prepared)
601+
602+
def _determine_shape(
603+
self,
604+
points: Any,
605+
static_shape: tuple[int, ...] | None,
606+
data: Mapping[Hashable, Any],
607+
ref_key: Hashable | None,
608+
) -> tuple[int, ...]:
609+
if static_shape is not None:
610+
return static_shape
611+
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
612+
if points_t.ndim != 2:
613+
raise ValueError("landmark arrays must be 2D with shape (num_points, spatial_dims).")
614+
spatial_dims = int(points_t.shape[1])
615+
if ref_key is not None and ref_key in data:
616+
return self._shape_from_reference(data[ref_key], spatial_dims)
617+
raise ValueError(
618+
"Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
619+
)
620+
621+
def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]:
622+
if isinstance(reference, MetaTensor):
623+
meta_shape = reference.meta.get("spatial_shape")
624+
if meta_shape is not None:
625+
dims = ensure_tuple(meta_shape)
626+
if len(dims) == spatial_dims:
627+
return tuple(int(v) for v in dims)
628+
return tuple(int(v) for v in reference.shape[-spatial_dims:])
629+
if hasattr(reference, "shape"):
630+
return tuple(int(v) for v in reference.shape[-spatial_dims:])
631+
raise ValueError("Reference data must define a shape attribute.")
632+
633+
def _prepare_output(self, heatmap: NdarrayOrTensor, reference: Any) -> Any:
634+
if isinstance(reference, MetaTensor):
635+
converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device)
636+
converted.meta["spatial_shape"] = tuple(int(v) for v in heatmap.shape[1:])
637+
return converted
638+
if isinstance(reference, torch.Tensor):
639+
converted, _, _ = convert_to_dst_type(heatmap, reference, dtype=reference.dtype, device=reference.device)
640+
return converted
641+
return heatmap
642+
643+
644+
GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
645+
646+
511647
class ProbNMSd(MapTransform):
512648
"""
513649
Performs probability based non-maximum suppression (NMS) on the probabilities map via

0 commit comments

Comments
 (0)