Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
  • Loading branch information
stiepan committed Mar 9, 2023
1 parent b6bbd46 commit 16ec1a3
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 73 deletions.
50 changes: 34 additions & 16 deletions dali/python/nvidia/dali/auto_aug/core/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __init__(self, name: str, num_magnitude_bins: int,
"""
self.name = name
self.num_magnitude_bins = num_magnitude_bins
if not isinstance(num_magnitude_bins, int) or num_magnitude_bins < 1:
raise Exception(
f"The `num_magnitude_bins` must be a positive integer, got {num_magnitude_bins}")
if not isinstance(sub_policies, (list, tuple)):
raise Exception(f"The `sub_policies` must be a list or tuple of sub policies, "
f"got {type(sub_policies)}.")
Expand All @@ -49,12 +52,21 @@ def __init__(self, name: str, num_magnitude_bins: int,
if not isinstance(op_desc, (list, tuple)) or len(op_desc) != 3:
raise Exception(f"Each operation in sub policy must be specified as a triple: "
f"(augmentation, probability, magnitude). Got {op_desc}.")
if not isinstance(op_desc[0], Augmentation):
aug, p, mag = op_desc
if not isinstance(aug, Augmentation):
raise Exception(
f"Each augmentation in sub policies must be an instance of "
f"Augmentation. Got {op_desc[0]}. Did you forget to use `@augmentation` "
f"Augmentation. Got {aug}. Did you forget to use `@augmentation` "
f"decorator?")
self.sub_policies = sub_policy_with_unique_names(sub_policies)
if not isinstance(p, (float, int)) or not 0 <= p <= 1:
raise Exception(
f"Probability of applying the augmentation must be a number from "
f"`[0, 1]` range. Got {p} for augmentation `{aug.name}`.")
if not isinstance(mag, int) or not 0 <= mag < self.num_magnitude_bins:
raise Exception(f"Magnitude of the augmentation must be an integer from "
f"`[0, {num_magnitude_bins - 1}]` range. "
f"Got {mag} for augmentation `{aug.name}`.")
self.sub_policies = _sub_policy_with_unique_names(sub_policies)

@property
def augmentations(self):
Expand All @@ -66,31 +78,37 @@ def __repr__(self):
sub_policies_repr = ",\n\t".join(
repr([(augment.name, p, mag) for augment, p, mag in sub_policy])
for sub_policy in self.sub_policies)
sub_policies_repr_sep = "" if not sub_policies_repr else "\n\t"
augmentations_repr = ",\n\t".join(f"'{name}': {repr(augment)}"
for name, augment in self.augmentations.items())
augmentations_repr_sep = "" if not augmentations_repr else "\n\t"
return (
f"Policy(name={repr(self.name)}, num_magnitude_bins={repr(self.num_magnitude_bins)}, "
f"sub_policies=[\n\t{sub_policies_repr}], augmentations={{\n\t{augmentations_repr}}})")
f"sub_policies=[{sub_policies_repr_sep}{sub_policies_repr}], "
f"augmentations={{{augmentations_repr_sep}{augmentations_repr}}})")


def sub_policy_with_unique_names(
def _sub_policy_with_unique_names(
sub_policies: Sequence[Sequence[Tuple[Augmentation, float, int]]]
) -> Tuple[Tuple[Tuple[Augmentation, float, int]]]:
augments = set(aug for sub_policy in sub_policies for aug, p, mag in sub_policy)
"""
Check if the augmentations used in the sub-policies have unique names.
If not, rename them by adding enumeration to the names.
The aim is to have non-ambiguous presentation.
"""
all_augments = [aug for sub_policy in sub_policies for aug, p, mag in sub_policy]
augments = set(all_augments)
names = set(aug.name for aug in augments)
if len(names) == len(augments):
return tuple(tuple(sub_policy) for sub_policy in sub_policies)
aug_by_name = {name: [] for name in names}
for aug in augments:
aug_by_name[aug.name].append(aug)
num_digits = len(str(len(augments) - 1))
remap_aug = {}
for aug_name, augs in aug_by_name.items():
if len(augs) == 1:
[aug] = augs
remap_aug[aug] = aug
else:
for i, aug in enumerate(augs):
remap_aug[aug] = aug.augmentation(name=f"{aug_name}__{i}")
i = 0
for augment in all_augments:
if augment not in remap_aug:
remap_aug[augment] = augment.augmentation(
name=f"{str(i).zfill(num_digits)}__{augment.name}")
i += 1
return tuple(
tuple((remap_aug[aug], p, mag) for aug, p, mag in sub_policy)
for sub_policy in sub_policies)
Loading

0 comments on commit 16ec1a3

Please sign in to comment.