Skip to content

Commit

Permalink
Check for image and masks shape equality (#1310)
Browse files Browse the repository at this point in the history
* Added check for image and masks shape equality

* added way to disable shapes check for Compose class

* made is_check_shapes parameter of class Compose

Co-authored-by: Mikhail Druzhinin <dipetm@gmail.com>
  • Loading branch information
Andredance and Dipet committed Oct 18, 2022
1 parent 7029e80 commit 3904898
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
22 changes: 21 additions & 1 deletion albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class Compose(BaseCompose):
keypoint_params (KeypointParams): Parameters for keypoints transforms
additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
p (float): probability of applying all list of transforms. Default: 1.0.
is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you
would like to disable this check - pass False (do it only if you are sure in your data consistency).
"""

def __init__(
Expand All @@ -137,6 +139,7 @@ def __init__(
keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
additional_targets: typing.Optional[typing.Dict[str, str]] = None,
p: float = 1.0,
is_check_shapes: bool = True,
):
super(Compose, self).__init__(transforms, p)

Expand Down Expand Up @@ -172,6 +175,8 @@ def __init__(
self.is_check_args = True
self._disable_check_args_for_transforms(self.transforms)

self.is_check_shapes = is_check_shapes

@staticmethod
def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
for transform in transforms:
Expand Down Expand Up @@ -235,6 +240,7 @@ def _to_dict(self) -> typing.Dict[str, typing.Any]:
if keypoints_processor
else None,
"additional_targets": self.additional_targets,
"is_check_shapes": self.is_check_shapes,
}
)
return dictionary
Expand All @@ -251,6 +257,7 @@ def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
else None,
"additional_targets": self.additional_targets,
"params": None,
"is_check_shapes": self.is_check_shapes,
}
)
return dictionary
Expand All @@ -260,18 +267,28 @@ def _check_args(self, **kwargs) -> None:
checked_multi = ["masks"]
check_bbox_param = ["bboxes"]
# ["bboxes", "keypoints"] could be almost any type, no need to check them
shapes = []
for data_name, data in kwargs.items():
internal_data_name = self.additional_targets.get(data_name, data_name)
if internal_data_name in checked_single:
if not isinstance(data, np.ndarray):
raise TypeError("{} must be numpy array type".format(data_name))
shapes.append(data.shape[:2])
if internal_data_name in checked_multi:
if data:
if not isinstance(data[0], np.ndarray):
raise TypeError("{} must be list of numpy arrays".format(data_name))
shapes.append(data[0].shape[:2])
if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
raise ValueError("bbox_params must be specified for bbox transformations")

if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
raise ValueError(
"Height and Width of image, mask or masks should be equal. You can disable shapes check "
"by calling disable_shapes_check method of Compose class (do it only if you are sure "
"about your data consistency)."
)

@staticmethod
def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
result = {}
Expand Down Expand Up @@ -423,9 +440,12 @@ def __init__(
keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
additional_targets: typing.Optional[typing.Dict[str, str]] = None,
p: float = 1.0,
is_check_shapes: bool = True,
save_key: str = "replay",
):
super(ReplayCompose, self).__init__(transforms, bbox_params, keypoint_params, additional_targets, p)
super(ReplayCompose, self).__init__(
transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes
)
self.set_deterministic(True, save_key=save_key)
self.save_key = save_key

Expand Down
23 changes: 23 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,26 @@ def test_contiguous_output(transforms):
# confirm output contiguous
assert data["image"].flags["C_CONTIGUOUS"]
assert data["mask"].flags["C_CONTIGUOUS"]


@pytest.mark.parametrize(
"targets",
[
{"image": np.ones((20, 20, 3), dtype=np.uint8), "mask": np.ones((30, 20))},
{"image": np.ones((20, 20, 3), dtype=np.uint8), "masks": [np.ones((30, 20))]},
],
)
def test_compose_image_mask_equal_size(targets):
transforms = Compose([])

with pytest.raises(ValueError) as exc_info:
transforms(**targets)

assert str(exc_info.value).startswith(
"Height and Width of image, mask or masks should be equal. "
"You can disable shapes check by calling disable_shapes_check method "
"of Compose class (do it only if you are sure about your data consistency)."
)
# test after disabling shapes check
transforms = Compose([], is_check_shapes=False)
transforms(**targets)
1 change: 1 addition & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def test_serialization_v2_to_dict():
"bbox_params": None,
"keypoint_params": None,
"additional_targets": {},
"is_check_shapes": True,
}


Expand Down

0 comments on commit 3904898

Please sign in to comment.