Skip to content

Commit

Permalink
Fix in CoarseDropout (#1689)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Apr 25, 2024
1 parent 4a72cc8 commit bd9c970
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ repos:
types: [python]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.1
rev: v0.4.2
hooks:
# Run the linter.
- id: ruff
Expand Down Expand Up @@ -74,7 +74,7 @@ repos:
hooks:
- id: pyproject-fmt
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.10.0
hooks:
- id: mypy
files: ^albumentations/
Expand Down
73 changes: 38 additions & 35 deletions albumentations/augmentations/dropout/coarse_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,47 +97,50 @@ class InitSchema(BaseTransformInitSchema):
fill_value: Union[ColorType, Literal["random"]] = Field(default=0, description="Value for dropped pixels.")
mask_fill_value: Optional[ColorType] = Field(default=None, description="Fill value for dropped pixels in mask.")

@staticmethod
def update_range(
min_value: Optional[ScalarType],
max_value: Optional[ScalarType],
default_range: Tuple[ScalarType, ScalarType],
) -> Tuple[ScalarType, ScalarType]:
if max_value is not None:
return (min_value or max_value, max_value)

return default_range

@staticmethod
# Validation for hole dimensions ranges
def validate_range(range_value: Tuple[ScalarType, ScalarType], range_name: str, minimum: float = 0) -> None:
if not minimum <= range_value[0] <= range_value[1]:
raise ValueError(
f"First value in {range_name} should be less or equal than the second value "
f"and at least {minimum}. Got: {range_value}",
)
if isinstance(range_value[0], float) and not all(0 <= x <= 1 for x in range_value):
raise ValueError(f"All values in {range_name} should be in [0, 1] range. Got: {range_value}")

@model_validator(mode="after")
def check_holes_and_dimensions(self) -> Self:
# Helper to update ranges and reset max values

def update_range(
min_value: Optional[ScalarType],
max_value: Optional[ScalarType],
default_range: Tuple[ScalarType, ScalarType],
) -> Tuple[ScalarType, ScalarType]:
if max_value is not None:
return (min_value or max_value, max_value)

return default_range

# Update ranges for holes, heights, and widths
self.num_holes_range = update_range(self.min_holes, self.max_holes, self.num_holes_range)
self.hole_height_range = update_range(self.min_height, self.max_height, self.hole_height_range)
self.hole_width_range = update_range(self.min_width, self.max_width, self.hole_width_range)

# Validation for hole dimensions ranges
def validate_range(range_value: Tuple[ScalarType, ScalarType], range_name: str, minimum: float = 0) -> None:
if not minimum <= range_value[0] <= range_value[1]:
raise ValueError(
f"First value in {range_name} should be less or equal than the second value "
f"and at least {minimum}. Got: {range_value}",
)
if isinstance(range_value[0], float) and not all(0 <= x <= 1 for x in range_value):
raise ValueError(f"All values in {range_name} should be in [0, 1] range. Got: {range_value}")

# Validate each range
validate_range(self.num_holes_range, "num_holes_range", minimum=1)
validate_range(self.hole_height_range, "hole_height_range")
validate_range(self.hole_width_range, "hole_width_range")
def check_num_holes_and_dimensions(self) -> Self:
if self.max_holes is not None:
# Update ranges for holes, heights, and widths
self.num_holes_range = self.update_range(self.min_holes, self.max_holes, self.num_holes_range)
self.validate_range(self.num_holes_range, "num_holes_range", minimum=1)

if self.max_height is not None:
self.hole_height_range = self.update_range(self.min_height, self.max_height, self.hole_height_range)
self.validate_range(self.hole_height_range, "hole_height_range")

if self.max_width is not None:
self.hole_width_range = self.update_range(self.min_width, self.max_width, self.hole_width_range)
self.validate_range(self.hole_width_range, "hole_width_range")

return self

def __init__(
self,
max_holes: Optional[int] = 8,
max_height: Optional[ScalarType] = 8,
max_width: Optional[ScalarType] = 8,
max_holes: Optional[int] = None,
max_height: Optional[ScalarType] = None,
max_width: Optional[ScalarType] = None,
min_holes: Optional[int] = None,
min_height: Optional[ScalarType] = None,
min_width: Optional[ScalarType] = None,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ ignore = [
"D417",
"D106",
"EM101",
"COM812",
]

# Allow fix for all enabled rules (when `--fix`) is provided.
Expand Down
4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
deepdiff>=6.7.1
mypy>=1.9.0
mypy>=1.10.0
pre_commit>=3.5.0
pytest>=8.0.2
pytest_cov>=4.1.0
requests>=2.31.0
ruff>=0.4.1
ruff>=0.4.2
tomli>=2.0.1
types-pkg-resources
types-PyYAML
Expand Down
31 changes: 31 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,3 +1423,34 @@ def test_random_crop_from_borders(image, bboxes, keypoints, crop_left, crop_righ
keypoint_params=A.KeypointParams("xy"))

assert aug(image=image, mask=image, bboxes=bboxes, keypoints=keypoints)
@pytest.mark.parametrize("params, expected", [
# Default values
({}, {"num_holes_range": (1, 1), "hole_height_range": (8, 8), "hole_width_range": (8, 8)}),
# Boundary values
({"num_holes_range": (2, 3)}, {"num_holes_range": (2, 3)}),
({"hole_height_range": (0.1, 0.1)}, {"hole_height_range": (0.1, 0.1)}),
({"hole_width_range": (0.1, 0.1)}, {"hole_width_range": (0.1, 0.1)}),
# Random fill value
({"fill_value": 'random'}, {"fill_value": 'random'}),
({"fill_value": (255, 255, 255)}, {"fill_value": (255, 255, 255)}),
# Deprecated values handling
({"min_holes": 1, "max_holes": 5}, {"num_holes_range": (1, 5)}),
({"min_height": 2, "max_height": 6}, {"hole_height_range": (2, 6)}),
({"min_width": 3, "max_width": 7}, {"hole_width_range": (3, 7)}),
])
def test_coarse_dropout_functionality(params, expected):
aug = A.CoarseDropout(**params, p=1)
aug_dict = aug.to_dict()["transform"]
for key, value in expected.items():
assert aug_dict[key] == value, f"Failed on {key} with value {value}"


@pytest.mark.parametrize("params", [
({"num_holes_range": (5, 1)}), # Invalid range
({"num_holes_range": (0, 3)}), # Invalid range
({"hole_height_range": (2.1, 3)}), # Invalid type
({"hole_height_range": ('a', 'b')}), # Invalid type
])
def test_coarse_dropout_invalid_input(params):
with pytest.raises(Exception):
aug = A.CoarseDropout(**params, p=1)

0 comments on commit bd9c970

Please sign in to comment.