Skip to content

Commit

Permalink
Available keys at BasicTransform and BaseCompose (#1692)
Browse files Browse the repository at this point in the history
* docstring for BasicTransform

* targets2func base

* _targets2func

* transform fix from comments

* ToTensorV2 fix

* fix set_target

* remove target_dependence

* Compose - move calculation always_apply to init

* Revert "docstring for BasicTransform"

This reverts commit b3076e4.

* fix merge

* available_keys at compose

* fix compose, basic_transform, tests

* fix _set_keys at compose, fix tests

* _check_args, fix tests
  • Loading branch information
ayasyrev committed May 2, 2024
1 parent ad8ce3f commit ee3c634
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 42 deletions.
36 changes: 31 additions & 5 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Union, cast

import cv2
import numpy as np
Expand Down Expand Up @@ -65,7 +65,9 @@ def __init__(self, transforms: TransformsSeqType, p: float):
self.replay_mode = False
self.applied_in_replay = False
self._additional_targets: Dict[str, str] = {}
self._available_keys: Set[str] = set()
self.processors: Dict[str, Union[BboxProcessor, KeypointsProcessor]] = {}
self._set_keys()

def __iter__(self) -> Iterator[TransformType]:
return iter(self.transforms)
Expand All @@ -86,6 +88,10 @@ def __repr__(self) -> str:
def additional_targets(self) -> Dict[str, str]:
return self._additional_targets

@property
def available_keys(self) -> Set[str]:
return self._available_keys

def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")}
repr_string = self.__class__.__name__ + "(["
Expand Down Expand Up @@ -127,11 +133,22 @@ def add_targets(self, additional_targets: Optional[Dict[str, str]]) -> None:
f"Trying to overwrite existed additional targets. "
f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
)
self._additional_targets.update(additional_targets)
self._additional_targets.update(additional_targets)
for t in self.transforms:
t.add_targets(additional_targets)
for proc in self.processors.values():
proc.add_targets(additional_targets)
self._set_keys()

def _set_keys(self) -> None:
"""Set _available_keys"""
for t in self.transforms:
self._available_keys.update(t.available_keys)
if self.processors:
self._available_keys.update(["labels"])
for proc in self.processors.values():
if proc.params.label_fields:
self._available_keys.update(proc.params.label_fields)

def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
for t in self.transforms:
Expand Down Expand Up @@ -192,6 +209,7 @@ def __init__(
self._disable_check_args_for_transforms(self.transforms)

self.is_check_shapes = is_check_shapes
self._always_apply = get_always_apply(self.transforms) # transforms list that always apply
self._check_each_transform = tuple( # processors that checks after each transform
proc for proc in self.processors.values() if getattr(proc.params, "check_each_transform", False)
)
Expand All @@ -211,18 +229,22 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[s
if args:
msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
raise KeyError(msg)
if self.is_check_args:
self._check_args(**data)

if not isinstance(force_apply, (bool, int)):
msg = "force_apply must have bool or int type"
raise TypeError(msg)

need_to_run = force_apply or random.random() < self.p
if not need_to_run and not self._always_apply:
return data

transforms = self.transforms if need_to_run else self._always_apply

if self.is_check_args:
self._check_args(**data)

for p in self.processors.values():
p.ensure_data_valid(data)
transforms = self.transforms if need_to_run else get_always_apply(self.transforms)

for p in self.processors.values():
p.preprocess(data)
Expand Down Expand Up @@ -286,6 +308,9 @@ def _check_args(self, **kwargs: Any) -> None:
check_keypoints_param = ["keypoints"]
shapes = []
for data_name, data in kwargs.items():
if data_name not in self._available_keys and data_name not in ["mask", "masks"]:
msg = f"Key {data_name} is not in available keys."
raise ValueError(msg)
internal_data_name = self._additional_targets.get(data_name, data_name)
if internal_data_name in checked_single:
if not isinstance(data, np.ndarray):
Expand Down Expand Up @@ -493,6 +518,7 @@ def __init__(
super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes)
self.set_deterministic(True, save_key=save_key)
self.save_key = save_key
self._available_keys.add(save_key)

def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Dict[str, Any]:
kwargs[self.save_key] = defaultdict(dict)
Expand Down
61 changes: 40 additions & 21 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
from warnings import warn

import cv2
Expand Down Expand Up @@ -43,8 +43,12 @@ class CombinedMeta(SerializableMeta, ValidatedTransformMeta):


class BasicTransform(Serializable, metaclass=CombinedMeta):
# `_targets` defines the types of targets (e.g., image, mask) that the transform can be applied to.
_targets: Union[Tuple[Targets, ...], Targets]
_targets: Union[Tuple[Targets, ...], Targets] # targets that this transform can work on
_available_keys: Set[str] # targets that this transform, as string, lower-cased
_key2func: Dict[
str,
Callable[..., Any],
] # mapping for targets (plus additional targets) and methods for which they depend
call_backup = None
interpolation: Union[int, Interpolation]
fill_value: ColorType
Expand All @@ -64,6 +68,8 @@ def __init__(self, always_apply: bool = False, p: float = 0.5):
self._additional_targets: Dict[str, str] = {}
# replay mode params
self.params: Dict[Any, Any] = {}
self._key2func = {}
self._set_keys()

def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
if args:
Expand Down Expand Up @@ -97,12 +103,11 @@ def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -
params = self.update_params(params, **kwargs)
res = {}
for key, arg in kwargs.items():
if arg is not None:
target_function = self._get_target_function(key)
target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
res[key] = target_function(arg, **dict(params, **target_dependencies))
if key in self._key2func and arg is not None:
target_function = self._key2func[key]
res[key] = target_function(arg, **params)
else:
res[key] = None
res[key] = arg
return res

def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
Expand All @@ -125,14 +130,6 @@ def __repr__(self) -> str:
state.update(self.get_transform_init_args())
return f"{self.__class__.__name__}({format_args(state)})"

def _get_target_function(self, key: str) -> Callable[..., Any]:
"""Returns function to process target"""
transform_key = key
if key in self._additional_targets:
transform_key = self._additional_targets.get(key, key)

return self.targets.get(transform_key, lambda x, **p: x)

def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
"""Apply transform on image."""
raise NotImplementedError
Expand All @@ -149,6 +146,23 @@ def targets(self) -> Dict[str, Callable[..., Any]]:
# >> {"masks": self.apply_to_masks}
raise NotImplementedError

def _set_keys(self) -> None:
"""Set _available_keys"""
if not hasattr(self, "_targets"):
self._available_keys = set()
else:
self._available_keys = {
target.value.lower()
for target in (self._targets if isinstance(self._targets, tuple) else [self._targets])
}
self._available_keys.update(self.targets.keys())
self._key2func = {key: self.targets[key] for key in self._available_keys if key in self.targets}

@property
def available_keys(self) -> Set[str]:
"""Returns set of available keys"""
return self._available_keys

def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
"""Update parameters with transform specific params"""
if hasattr(self, "interpolation"):
Expand All @@ -160,10 +174,6 @@ def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]
params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
return params

@property
def target_dependence(self) -> Dict[str, Any]:
return {}

def add_targets(self, additional_targets: Dict[str, str]) -> None:
"""Add targets to transform them the same way as one of existing targets
ex: {'target_image': 'image'}
Expand All @@ -174,7 +184,16 @@ def add_targets(self, additional_targets: Dict[str, str]) -> None:
additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
"""
self._additional_targets = {**self._additional_targets, **additional_targets}
for k, v in additional_targets.items():
if k in self._additional_targets and v != self._additional_targets[k]:
raise ValueError(
f"Trying to overwrite existed additional targets. "
f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
)
if v in self._available_keys:
self._additional_targets[k] = v
self._key2func[k] = self.targets[v]
self._available_keys.add(k)

@property
def targets_as_params(self) -> List[str]:
Expand Down
6 changes: 3 additions & 3 deletions albumentations/pytorch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from albumentations.core.transforms_interface import BasicTransform
from albumentations.core.types import Targets

__all__ = ["ToTensorV2"]

Expand All @@ -23,6 +24,8 @@ class ToTensorV2(BasicTransform):
"""

_targets = (Targets.IMAGE, Targets.MASK)

def __init__(self, transpose_mask: bool = False, always_apply: bool = True, p: float = 1.0):
super().__init__(always_apply=always_apply, p=p)
self.transpose_mask = transpose_mask
Expand Down Expand Up @@ -51,6 +54,3 @@ def apply_to_masks(self, masks: List[np.ndarray], **params: Any) -> List[torch.T

def get_transform_init_args_names(self) -> Tuple[str, ...]:
return ("transpose_mask",)

def get_params_dependent_on_targets(self, params: Any) -> Dict[str, Any]:
return {}
42 changes: 29 additions & 13 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_one_or_other():


def test_compose():
first = MagicMock()
second = MagicMock()
first = MagicMock(available_keys={"image"})
second = MagicMock(available_keys={"image"})
augmentation = Compose([first, second], p=1)
image = np.ones((8, 8))
augmentation(image=image)
Expand All @@ -70,8 +70,8 @@ def oneof_always_apply_crash():


def test_always_apply():
first = MagicMock(always_apply=True)
second = MagicMock(always_apply=False)
first = MagicMock(always_apply=True, available_keys={"image"})
second = MagicMock(always_apply=False, available_keys={"image"})
augmentation = Compose([first, second], p=0)
image = np.ones((8, 8))
augmentation(image=image)
Expand All @@ -80,7 +80,7 @@ def test_always_apply():


def test_one_of():
transforms = [Mock(p=1) for _ in range(10)]
transforms = [Mock(p=1, available_keys={"image"}) for _ in range(10)]
augmentation = OneOf(transforms, p=1)
image = np.ones((8, 8))
augmentation(image=image)
Expand All @@ -90,7 +90,7 @@ def test_one_of():
@pytest.mark.parametrize("N", [1, 2, 5, 10])
@pytest.mark.parametrize("replace", [True, False])
def test_n_of(N, replace):
transforms = [Mock(p=1, side_effect=lambda **kw: {"image": kw["image"]}) for _ in range(10)]
transforms = [Mock(p=1, side_effect=lambda **kw: {"image": kw["image"]}, available_keys={"image"}) for _ in range(10)]
augmentation = SomeOf(transforms, N, p=1, replace=replace)
image = np.ones((8, 8))
augmentation(image=image)
Expand All @@ -100,7 +100,7 @@ def test_n_of(N, replace):


def test_sequential():
transforms = [Mock(side_effect=lambda **kw: kw) for _ in range(10)]
transforms = [Mock(side_effect=lambda **kw: kw, available_keys={"image"}) for _ in range(10)]
augmentation = Sequential(transforms, p=1)
image = np.ones((8, 8))
augmentation(image=image)
Expand Down Expand Up @@ -254,13 +254,13 @@ def test_named_args():
],
)
def test_targets_type_check(targets, additional_targets, err_message):
aug = Compose([], additional_targets=additional_targets)
aug = Compose([A.NoOp()], additional_targets=additional_targets)

with pytest.raises(TypeError) as exc_info:
aug(**targets)
assert str(exc_info.value) == err_message

aug = Compose([])
aug = Compose([A.NoOp()])
aug.add_targets(additional_targets)
with pytest.raises(TypeError) as exc_info:
aug(**targets)
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_check_each_transform(targets, bbox_params, keypoint_params, expected):

@pytest.mark.parametrize("image", IMAGES)
def test_bbox_params_is_not_set(image, bboxes):
t = Compose([])
t = Compose([A.NoOp(p=1.0)])
with pytest.raises(ValueError) as exc_info:
t(image=image, bboxes=bboxes)
assert str(exc_info.value) == "bbox_params must be specified for bbox transformations"
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_choice_inner_compositions(transforms):
"transforms",
[
Compose([ChannelShuffle(p=1)], p=1),
Compose([ChannelShuffle(p=0)], p=0),
# Compose([ChannelShuffle(p=0)], p=0), # p=0, never calls, no process for data
],
)
def test_contiguous_output(transforms):
Expand All @@ -421,7 +421,7 @@ def test_contiguous_output(transforms):
],
)
def test_compose_image_mask_equal_size(targets):
transforms = Compose([])
transforms = Compose([A.NoOp()])

with pytest.raises(ValueError) as exc_info:
transforms(**targets)
Expand All @@ -432,7 +432,7 @@ def test_compose_image_mask_equal_size(targets):
"of Compose class (do it only if you are sure about your data consistency)."
)
# test after disabling shapes check
transforms = Compose([], is_check_shapes=False)
transforms = Compose([A.NoOp()], is_check_shapes=False)
transforms(**targets)


Expand Down Expand Up @@ -491,3 +491,19 @@ def test_sequential_multiple_transformations(image, aug):
# Since HorizontalFlip, VerticalFlip, and Transpose are all applied twice, the image should be the same
assert np.array_equal(result['image'], image)
assert np.array_equal(result['mask'], mask)


def test_compose_non_available_keys() -> None:
"""Check that non available keys raises error, except `mask` and `masks`"""
transform = A.Compose(
[MagicMock(available_keys={"image"}),],
)
image = np.empty([10, 10, 3], dtype=np.uint8)
mask = np.empty([10, 10], dtype=np.uint8)
_res = transform(image=image, mask=mask)
_res = transform(image=image, masks=[mask])
with pytest.raises(ValueError) as exc_info:
_res = transform(image=image, image_2=mask)

expected_msg = "Key image_2 is not in available keys."
assert str(exc_info.value) == expected_msg

0 comments on commit ee3c634

Please sign in to comment.