Merged · Changes from all 25 commits
6 changes: 4 additions & 2 deletions monai/apps/detection/networks/retinanet_detector.py
@@ -485,7 +485,9 @@ def forward(
"""
# 1. Check if input arguments are valid
if self.training:
-            check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
+            targets = check_training_targets(
+                input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
+            )
self._check_detector_training_components()

# 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -877,7 +879,7 @@ def get_cls_train_sample_per_image(

foreground_idxs_per_image = matched_idxs_per_image >= 0

-        num_foreground = foreground_idxs_per_image.sum()
+        num_foreground = int(foreground_idxs_per_image.sum())
num_gt_box = targets_per_image[self.target_box_key].shape[0]

if self.debug:
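For context on the `int(...)` cast above: `Tensor.sum()` returns a 0-dim tensor rather than a Python `int`, which matters when the count is used downstream. A minimal sketch:

```python
import torch

mask = torch.tensor([True, False, True])
count = mask.sum()       # tensor(2): a 0-dim tensor, not a Python int
assert not isinstance(count, int)
assert int(count) == 2   # the explicit cast yields a plain int
```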
23 changes: 23 additions & 0 deletions monai/apps/detection/transforms/array.py
@@ -28,6 +28,7 @@
convert_box_to_standard_mode,
get_spatial_dims,
spatial_crop_boxes,
standardize_empty_box,
)
from monai.transforms import Rotate90, SpatialCrop
from monai.transforms.transform import Transform
@@ -46,6 +47,7 @@
)

__all__ = [
"StandardizeEmptyBox",
"ConvertBoxToStandardMode",
"ConvertBoxMode",
"AffineBox",
@@ -60,6 +62,27 @@
]


class StandardizeEmptyBox(Transform):
"""
    When boxes are empty, this transform standardizes them to a shape of (0, 4) or (0, 6).

Args:
spatial_dims: number of spatial dimensions of the bounding boxes.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, spatial_dims: int) -> None:
self.spatial_dims = spatial_dims

def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.
"""
return standardize_empty_box(boxes, spatial_dims=self.spatial_dims)


class ConvertBoxMode(Transform):
"""
This transform converts the boxes in src_mode to the dst_mode.
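A minimal usage sketch of the `StandardizeEmptyBox` transform added above (import path taken from this file):

```python
import torch
from monai.apps.detection.transforms.array import StandardizeEmptyBox

boxes = torch.ones(0)                  # empty boxes with a non-standard shape
transform = StandardizeEmptyBox(spatial_dims=3)
print(transform(boxes).shape)          # torch.Size([0, 6])
```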
49 changes: 49 additions & 0 deletions monai/apps/detection/transforms/dictionary.py
@@ -34,6 +34,7 @@
MaskToBox,
RotateBox90,
SpatialCropBox,
StandardizeEmptyBox,
ZoomBox,
)
from monai.apps.detection.transforms.box_ops import convert_box_to_mask
@@ -51,6 +52,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_tensor

__all__ = [
"StandardizeEmptyBoxd",
"StandardizeEmptyBoxD",
"StandardizeEmptyBoxDict",
"ConvertBoxModed",
"ConvertBoxModeD",
"ConvertBoxModeDict",
@@ -95,6 +99,50 @@
DEFAULT_POST_FIX = PostFix.meta()


class StandardizeEmptyBoxd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.

    When boxes are empty, this transform standardizes them to a shape of (0, 4) or (0, 6).

Example:
.. code-block:: python

data = {"boxes": torch.ones(0,), "image": torch.ones(1, 128, 128, 128)}
box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
box_converter(data)
"""

def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:
"""
Args:
box_keys: Keys to pick data for transformation.
box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.
allow_missing_keys: don't raise exception if key is missing.

        See also :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`
"""
super().__init__(box_keys, allow_missing_keys)
box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
if len(box_ref_image_keys_tuple) > 1:
raise ValueError(
"Please provide a single key for box_ref_image_keys.\
All boxes of box_keys are attached to box_ref_image_keys."
)
self.box_ref_image_keys = box_ref_image_keys

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
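        # infer spatial_dims from the channel-first reference image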
spatial_dims = len(d[self.box_ref_image_keys].shape) - 1
self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
return dict(data)


class ConvertBoxModed(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.
@@ -1353,3 +1401,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld
RotateBox90D = RotateBox90Dict = RotateBox90d
RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d
StandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd
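A usage sketch for the dictionary wrapper, following the docstring example above; the spatial dims are inferred from the channel-first reference image:

```python
import torch
from monai.apps.detection.transforms.dictionary import StandardizeEmptyBoxd

data = {"boxes": torch.ones(0), "image": torch.ones(1, 128, 128, 128)}
converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
out = converter(data)
print(out["boxes"].shape)  # torch.Size([0, 6])
```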
29 changes: 23 additions & 6 deletions monai/apps/detection/utils/detector_utils.py
@@ -11,13 +11,15 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.data.box_utils import standardize_empty_box
from monai.transforms.croppad.array import SpatialPad
from monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode
from monai.utils import PytorchPadMode, ensure_tuple_rep
@@ -56,7 +58,7 @@ def check_training_targets(
spatial_dims: int,
target_label_key: str,
target_box_key: str,
-) -> None:
+) -> list[dict[str, Tensor]]:
"""
Validate the input images/targets during training (raise a `ValueError` if invalid).

@@ -75,7 +77,8 @@
if len(input_images) != len(targets):
raise ValueError(f"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.")

-    for target in targets:
+    for i in range(len(targets)):
+        target = targets[i]
if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
raise ValueError(
f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
@@ -85,10 +88,24 @@
if not isinstance(boxes, torch.Tensor):
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:
-            raise ValueError(
-                f"Expected target boxes to be a tensor " f"of shape [N, {2* spatial_dims}], got {boxes.shape}."
-            )
-    return
+            if boxes.numel() == 0:
+                warnings.warn(
+                    f"Given target boxes have shape {boxes.shape}. "
+                    f"The detector reshaped them with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
+                )
+            else:
+                raise ValueError(
+                    f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}."
+                )
+        if not torch.is_floating_point(boxes):
+            raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
+        targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims)  # type: ignore
+
+        labels = target[target_label_key]
+        if torch.is_floating_point(labels):
+            warnings.warn(f"Given target labels are {labels.dtype}. The detector converted them to torch.long.")
+            targets[i][target_label_key] = labels.long()
+    return targets


def pad_images(
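A minimal sketch of the new `check_training_targets` behavior: empty boxes are reshaped with a warning instead of raising, and float labels are cast to `torch.long`:

```python
import torch
from monai.apps.detection.utils.detector_utils import check_training_targets

images = [torch.rand(1, 32, 32, 32)]
targets = [{"boxes": torch.zeros(0), "labels": torch.zeros(0)}]
# warns about the (0,) box shape and the float labels, then returns standardized targets
targets = check_training_targets(images, targets, 3, "labels", "boxes")
print(targets[0]["boxes"].shape, targets[0]["labels"].dtype)  # torch.Size([0, 6]) torch.int64
```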
59 changes: 56 additions & 3 deletions monai/data/box_utils.py
@@ -395,19 +395,41 @@ def get_spatial_dims(

# Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set
if boxes is not None:
-        if int(boxes.shape[1]) not in [4, 6]:
+        if len(boxes.shape) != 2:
+            if boxes.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], "
+                    f"got boxes with shape {boxes.shape}. "
+                    f"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
+                )
+        if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
)
spatial_dims_set.add(int(boxes.shape[1] / 2))
if points is not None:
+        if len(points.shape) != 2:
+            if points.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], "
+                    f"got points with shape {points.shape}. "
+                    f"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
+                )
if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], got boxes with shape {points.shape}."
f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
)
spatial_dims_set.add(int(points.shape[1]))
if corners is not None:
-        if len(corners) not in [4, 6]:
+        if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}."
)
@@ -494,6 +516,33 @@ def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwar
return StandardMode(*args, **kwargs)


def standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
"""
    When boxes are empty, this function standardizes them to a shape of (0, 4) or (0, 6).

Args:
boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray
spatial_dims: number of spatial dimensions of the bounding boxes.

Returns:
bounding boxes with shape (N,4) or (N,6), N can be 0.

Example:
.. code-block:: python

boxes = torch.ones(0,)
standardize_empty_box(boxes, 3)
"""
# convert numpy to tensor if needed
boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
# handle empty box
if boxes_t.shape[0] == 0:
boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])
# convert tensor back to numpy if needed
boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
return boxes_dst


def convert_box_mode(
boxes: NdarrayOrTensor,
src_mode: str | BoxMode | type[BoxMode] | None = None,
Expand Down Expand Up @@ -522,6 +571,10 @@ def convert_box_mode(
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
"""
# handle empty box
if boxes.shape[0] == 0:
return boxes

src_boxmode = get_boxmode(src_mode)
dst_boxmode = get_boxmode(dst_mode)

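A sketch of the two `box_utils` changes together: `standardize_empty_box` normalizes an empty box tensor, and `convert_box_mode` now passes empty boxes through unchanged:

```python
import torch
from monai.data import box_utils

empty = box_utils.standardize_empty_box(torch.ones(0), spatial_dims=2)
print(empty.shape)  # torch.Size([0, 4])
out = box_utils.convert_box_mode(empty, src_mode="xyxy", dst_mode=box_utils.CenterSizeMode)
print(out.shape)    # torch.Size([0, 4]): a no-op for empty boxes
```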
29 changes: 15 additions & 14 deletions tests/test_retinanet_detector.py
@@ -134,20 +134,21 @@ def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):

detector.set_atss_matcher()
detector.set_hard_negative_sampler(10, 0.5)
-        gt_box_start = torch.randint(2, (3, input_param["spatial_dims"])).to(torch.float16)
-        gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param["spatial_dims"]))
-        one_target = {
-            "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
-            "labels": torch.randint(input_param["num_classes"], (3,)),
-        }
-        with train_mode(detector):
-            input_data = torch.randn(input_shape)
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
-
-            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
+        for num_gt_box in [0, 3]:  # test for both empty and non-empty boxes
+            gt_box_start = torch.randint(2, (num_gt_box, input_param["spatial_dims"])).to(torch.float16)
+            gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param["spatial_dims"]))
+            one_target = {
+                "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
+                "labels": torch.randint(input_param["num_classes"], (num_gt_box,)),
+            }
+            with train_mode(detector):
+                input_data = torch.randn(input_shape)
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)
+
+                input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)

@parameterized.expand(TEST_CASES)
def test_naive_retina_detector_shape(self, input_param, input_shape):