Skip to content

Commit

Permalink
Color parameter for Spatter (#1305)
Browse files Browse the repository at this point in the history
* Added color as Spatter configurable parameter

* added color to get_transform_init_args_names

* added support for dict as color param for Spatter aug

* removed unnecessary quotes after black fix
  • Loading branch information
Andredance committed Oct 9, 2022
1 parent 6c16e39 commit b3b9684
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
39 changes: 35 additions & 4 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,6 +2391,10 @@ class Spatter(ImageOnlyTransform):
If tuple of float intensity will be sampled from range `[intensity[0], intensity[1])`. Default: (0.6).
mode (string, or list of strings): Type of corruption. Currently, supported options are 'rain' and 'mud'.
If list is provided type of corruption will be sampled list. Default: ("rain").
color (list of (r, g, b) or dict or None): Corruption elements color.
If list uses provided list as color for specified mode.
If dict uses provided color for specified mode. Color for each specified mode should be provided in dict.
If None uses default colors (rain: (238, 238, 175), mud: (20, 42, 63)).
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -2412,6 +2416,7 @@ def __init__(
cutout_threshold: ScaleFloatType = 0.68,
intensity: ScaleFloatType = 0.6,
mode: Union[str, Sequence[str]] = "rain",
color: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None,
always_apply: bool = False,
p: float = 0.5,
):
Expand All @@ -2422,10 +2427,34 @@ def __init__(
self.gauss_sigma = to_tuple(gauss_sigma, gauss_sigma)
self.intensity = to_tuple(intensity, intensity)
self.cutout_threshold = to_tuple(cutout_threshold, cutout_threshold)
self.color = (
color
if color is not None
else {
"rain": [238, 238, 175],
"mud": [20, 42, 63],
}
)
self.mode = mode if isinstance(mode, (list, tuple)) else [mode]

if len(set(self.mode)) > 1 and not isinstance(self.color, dict):
raise ValueError(f"Unsupported color: {self.color}. Please specify color for each mode (use dict for it).")

for i in self.mode:
if i not in ["rain", "mud"]:
raise ValueError(f"Unsupported color mode: {mode}. Transform supports only `rain` and `mud` mods.")
if isinstance(self.color, dict):
if i not in self.color:
raise ValueError(f"Wrong color definition: {self.color}. Color for mode: {i} not specified.")
if len(self.color[i]) != 3:
raise ValueError(
f"Unsupported color: {self.color[i]} for mode {i}. Color should be presented in RGB format."
)

if isinstance(self.color, (list, tuple)):
if len(self.color) != 3:
raise ValueError(f"Unsupported color: {self.color}. Color should be presented in RGB format.")
self.color = {self.mode[0]: self.color}

def apply(
self,
Expand All @@ -2451,6 +2480,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
sigma = random.uniform(self.gauss_sigma[0], self.gauss_sigma[1])
mode = random.choice(self.mode)
intensity = random.uniform(self.intensity[0], self.intensity[1])
color = np.array(self.color[mode]) / 255.0

liquid_layer = random_utils.normal(size=(h, w), loc=mean, scale=std)
liquid_layer = gaussian_filter(liquid_layer, sigma=sigma, mode="nearest")
Expand All @@ -2471,15 +2501,16 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
m = liquid_layer * dist
m *= 1 / np.max(m, axis=(0, 1))

drops = m[:, :, None] * np.array([238 / 255.0, 238 / 255.0, 175 / 255.0]) * intensity
drops = m[:, :, None] * color * intensity
mud = None
non_mud = None
else:
m = np.where(liquid_layer > cutout_threshold, 1, 0)
m = gaussian_filter(m.astype(np.float32), sigma=sigma, mode="nearest")
m[m < 1.2 * cutout_threshold] = 0
m = m[..., np.newaxis]
mud = m * np.array([20 / 255.0, 42 / 255.0, 63 / 255.0])

mud = m * color
non_mud = 1 - m
drops = None

Expand All @@ -2490,5 +2521,5 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
"mode": mode,
}

def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]:
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode"
def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]:
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color"
28 changes: 28 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,3 +1293,31 @@ def test_spatter_incorrect_mode(image):

message = f"Unsupported color mode: {unsupported_mode}. Transform supports only `rain` and `mud` mods."
assert str(exc_info.value).startswith(message)


@pytest.mark.parametrize(
"unsupported_color,mode,message",
[
([255, 255], "rain", "Unsupported color: [255, 255]. Color should be presented in RGB format."),
(
{"rain": [255, 255, 255]},
"mud",
"Wrong color definition: {'rain': [255, 255, 255]}. Color for mode: mud not specified.",
),
(
{"rain": [255, 255]},
"rain",
"Unsupported color: [255, 255] for mode rain. Color should be presented in RGB format.",
),
(
[255, 255, 255],
["rain", "mud"],
"Unsupported color: [255, 255, 255]. Please specify color for each mode (use dict for it).",
),
],
)
def test_spatter_incorrect_color(unsupported_color, mode, message):
with pytest.raises(ValueError) as exc_info:
A.Spatter(mode=mode, color=unsupported_color)

assert str(exc_info.value).startswith(message)

0 comments on commit b3b9684

Please sign in to comment.