Skip to content

Commit

Permalink
Channel subset (#1690)
Browse files Browse the repository at this point in the history
* Added SelectiveChannelTransform

* Added SelectiveChannelTransform

* Added SelectiveChannelTransform

* Fix in tests

* Clean in tranforms
  • Loading branch information
ternaus committed Apr 25, 2024
1 parent bd9c970 commit d47389c
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 49 deletions.
16 changes: 8 additions & 8 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,9 +1375,9 @@ def apply(self, img: np.ndarray, r_shift: int = 0, g_shift: int = 0, b_shift: in

def get_params(self) -> Dict[str, Any]:
return {
"r_shift": random.uniform(self.r_shift_limit[0], self.r_shift_limit[1]),
"g_shift": random.uniform(self.g_shift_limit[0], self.g_shift_limit[1]),
"b_shift": random.uniform(self.b_shift_limit[0], self.b_shift_limit[1]),
"r_shift": random_utils.uniform(self.r_shift_limit[0], self.r_shift_limit[1]),
"g_shift": random_utils.uniform(self.g_shift_limit[0], self.g_shift_limit[1]),
"b_shift": random_utils.uniform(self.b_shift_limit[0], self.b_shift_limit[1]),
}

def get_transform_init_args_names(self) -> Tuple[str, str, str]:
Expand Down Expand Up @@ -1606,7 +1606,7 @@ def get_transform_init_args_names(self) -> Tuple[str, str]:


class ChannelShuffle(ImageOnlyTransform):
"""Randomly rearrange channels of the input RGB image.
"""Randomly rearrange channels of the image.
Args:
p: probability of applying the transform. Default: 0.5.
Expand All @@ -1623,7 +1623,7 @@ class ChannelShuffle(ImageOnlyTransform):
def targets_as_params(self) -> List[str]:
return ["image"]

def apply(self, img: np.ndarray, channels_shuffled: Tuple[int, int, int] = (0, 1, 2), **params: Any) -> np.ndarray:
def apply(self, img: np.ndarray, channels_shuffled: Tuple[int, ...] = (0, 1, 2), **params: Any) -> np.ndarray:
return F.channel_shuffle(img, channels_shuffled)

def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -2141,7 +2141,7 @@ class FancyPCA(ImageOnlyTransform):
"ImageNet Classification with Deep Convolutional Neural Networks"
Args:
alpha: how much to perturb/scale the eigen vecs and vals.
alpha: how much to perturb/scale the eigen vectors and eigenvalues.
scale is samples from gaussian distribution (mu=0, sigma=alpha)
Targets:
Expand Down Expand Up @@ -2415,9 +2415,9 @@ class Superpixels(ImageOnlyTransform):
* A probability of ``0.5`` would mean, that around half of all
segments are replaced by their average color.
* A probability of ``1.0`` would mean, that all segments are
replaced by their average color (resulting in a voronoi
replaced by their average color (resulting in a Voronoi
image).
Behaviour based on chosen data types for this parameter:
Behavior based on chosen data types for this parameter:
* If a ``float``, then that ``flat`` will always be used.
* If ``tuple`` ``(a, b)``, then a random probability will be
sampled from the interval ``[a, b]`` per image.
Expand Down
65 changes: 44 additions & 21 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast

import cv2
import numpy as np

from albumentations import random_utils
Expand Down Expand Up @@ -30,12 +31,12 @@
"Sequential",
"TransformType",
"TransformsSeqType",
"SelectiveChannelTransform",
]

TWO = 2


NUM_ONEOF_TRANSFORMS = 2
REPR_INDENT_STEP = 2

TransformType = Union[BasicTransform, "BaseCompose"]
TransformsSeqType = List[TransformType]

Expand Down Expand Up @@ -406,7 +407,7 @@ def __init__(
raise ValueError(msg)
transforms = [first, second]
super().__init__(transforms, p)
if len(self.transforms) != TWO:
if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
warnings.warn("Length of transforms is not equal to 2.")

def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[str, Any]:
Expand All @@ -421,37 +422,59 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[s
return self.transforms[-1](force_apply=True, **data)


class PerChannel(BaseCompose):
"""Apply transformations per-channel
class SelectiveChannelTransform(BaseCompose):
"""A transformation class to apply specified transforms to selected channels of an image.
Args:
transforms (list): list of transformations to compose.
channels (sequence): channels to apply the transform to. Pass None to apply to all.
Default: None (apply to all)
p (float): probability of applying the transform. Default: 0.5.
This class extends BaseCompose to allow selective application of transformations to
specified image channels. It extracts the selected channels, applies the transformations,
and then reinserts the transformed channels back into their original positions in the image.
Parameters:
transforms (TransformsSeqType):
A sequence of transformations (from Albumentations) to be applied to the specified channels.
channels (Sequence[int]):
A sequence of integers specifying the indices of the channels to which the transforms should be applied.
always_apply (bool):
If True, the transform will always be applied, ignoring the probability `p`.
p (float):
Probability that the transform will be applied; the default is 1.0 (always apply).
Methods:
__call__(*args, **kwargs):
Applies the transforms to the image according to the specified channels.
The input data should include 'image' key with the image array.
Returns:
Dict[str, Any]: The transformed data dictionary, which includes the transformed 'image' key.
"""

def __init__(self, transforms: TransformsSeqType, channels: Optional[Sequence[int]] = None, p: float = 0.5):
def __init__(
self,
transforms: TransformsSeqType,
channels: Sequence[int] = (0, 1, 2),
always_apply: bool = False,
p: float = 1.0,
) -> None:
super().__init__(transforms, p)
self.channels = channels

def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[str, Any]:
if force_apply or random.random() < self.p:
image = data["image"]

# Expand mono images to have a single channel
if len(image.shape) == TWO:
image = np.expand_dims(image, -1)
selected_channels = image[:, :, self.channels]
sub_image = np.ascontiguousarray(selected_channels)

for t in self.transforms:
sub_image = t(image=sub_image)["image"]

if self.channels is None:
self.channels = range(image.shape[2])
transformed_channels = cv2.split(sub_image)
output_img = image.copy()

for c in self.channels:
for t in self.transforms:
image[:, :, c] = t(image=image[:, :, c])["image"]
for idx, channel in zip(self.channels, transformed_channels):
output_img[:, :, idx] = channel

data["image"] = image
data["image"] = np.ascontiguousarray(output_img)

return data

Expand Down
2 changes: 1 addition & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def targets(self) -> Dict[str, Callable[..., Any]]:


class NoOp(DualTransform):
"""Identical transform (does nothing).
"""Identity transform (does nothing).
Targets:
image, mask, bboxes, keypoints, global_label
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,5 @@ def mp_pool():

SQUARE_IMAGES = [SQUARE_UINT8_IMAGE, SQUARE_FLOAT_IMAGE]
RECTANGULAR_IMAGES = [RECTANGULAR_UINT8_IMAGE, RECTANGULAR_FLOAT_IMAGE]

SQUARE_MULTI_UINT8_IMAGE = np.random.randint(low=0, high=256, size=(100, 100, 7), dtype=np.uint8)
19 changes: 1 addition & 18 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
KeypointParams,
OneOf,
OneOrOther,
PerChannel,
ReplayCompose,
Sequential,
SomeOf,
Expand Down Expand Up @@ -123,7 +122,7 @@ def test_to_tuple(input, kwargs, expected):


@pytest.mark.parametrize("image", IMAGES)
def test_image_only_transform(request, image):
def test_image_only_transform(image):
mask = image.copy()
height, width = image.shape[:2]
with mock.patch.object(ImageOnlyTransform, "apply") as mocked_apply:
Expand Down Expand Up @@ -193,22 +192,6 @@ def test_check_bboxes_with_end_greater_that_start():
assert str(exc_info.value) == message


def test_per_channel_mono():
transforms = [Blur(), Rotate()]
augmentation = PerChannel(transforms, p=1)
image = np.ones((8, 8))
data = augmentation(image=image)
assert data


def test_per_channel_multi():
transforms = [Blur(), Rotate()]
augmentation = PerChannel(transforms, p=1)
image = np.ones((8, 8, 5))
data = augmentation(image=image)
assert data


def test_deterministic_oneof():
aug = ReplayCompose([OneOf([HorizontalFlip(), Blur()])], p=1)
for _ in range(10):
Expand Down
126 changes: 125 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import albumentations.augmentations.functional as F
import albumentations.augmentations.geometric.functional as FGeometric
from albumentations.augmentations.blur.functional import gaussian_blur
from tests.conftest import IMAGES
from tests.conftest import IMAGES, SQUARE_MULTI_UINT8_IMAGE, SQUARE_UINT8_IMAGE

from .utils import get_dual_transforms, get_image_only_transforms, get_transforms, set_seed

Expand Down Expand Up @@ -1454,3 +1454,127 @@ def test_coarse_dropout_functionality(params, expected):
def test_coarse_dropout_invalid_input(params):
with pytest.raises(Exception):
aug = A.CoarseDropout(**params, p=1)


@pytest.mark.parametrize(
["augmentation_cls", "params"],
get_transforms(
custom_arguments={
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
A.Resize: {"height": 10, "width": 10},
A.TemplateTransform: {
"templates": np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8),
},
A.XYMasking: {
"num_masks_x": (1, 3),
"num_masks_y": (1, 3),
"mask_x_length": 10,
"mask_y_length": 10,
"mask_fill_value": 1,
"fill_value": 0,
},
A.Superpixels: {"p_replace": (1, 1),
"n_segments": (10, 10),
"max_size": 10
},
},
except_augmentations={
A.RandomCropNearBBox,
A.RandomSizedBBoxSafeCrop,
A.BBoxSafeRandomCrop,
A.CropNonEmptyMaskIfExists,
A.FDA,
A.HistogramMatching,
A.PixelDistributionAdaptation,
A.MaskDropout,
A.MixUp,
A.NoOp,
A.Lambda,
A.ToRGB,
A.RandomRotate90,
A.FancyPCA
},
),
)
def test_change_image(augmentation_cls, params):
"""Checks whether transform performs changes to the image."""
aug = A.Compose([augmentation_cls(p=1, **params)])
image = SQUARE_UINT8_IMAGE
assert not np.array_equal(aug(image=image)["image"], image)

@pytest.mark.parametrize(
["augmentation_cls", "params"],
get_transforms(
custom_arguments={
A.XYMasking: {
"num_masks_x": (1, 3),
"num_masks_y": (1, 3),
"mask_x_length": 10,
"mask_y_length": 10,
"mask_fill_value": 1,
"fill_value": 0,
},
A.Superpixels: {"p_replace": (1, 1),
"n_segments": (10, 10),
"max_size": 10
},
A.FancyPCA: {"alpha":1}
},
except_augmentations={
A.Crop,
A.CenterCrop,
A.CropNonEmptyMaskIfExists,
A.RandomCrop,
A.RandomResizedCrop,
A.RandomSizedCrop,
A.CropAndPad,
A.Resize,
A.TemplateTransform,
A.RandomCropNearBBox,
A.RandomSizedBBoxSafeCrop,
A.BBoxSafeRandomCrop,
A.CropNonEmptyMaskIfExists,
A.FDA,
A.HistogramMatching,
A.PixelDistributionAdaptation,
A.MaskDropout,
A.MixUp,
A.NoOp,
A.Lambda,
A.ToRGB,
A.ChannelDropout,
A.LongestMaxSize,
A.PadIfNeeded,
A.RandomCropFromBorders,
A.SmallestMaxSize,
A.RandomScale,
A.ChannelShuffle,
A.ChromaticAberration,
A.RandomRotate90,
A.FancyPCA
},
),
)
def test_selective_channel(augmentation_cls, params):
set_seed(0)

image = SQUARE_MULTI_UINT8_IMAGE
channels = [3, 2, 4]

aug = A.Compose(
[A.SelectiveChannelTransform(transforms=[augmentation_cls(**params, always_apply=True, p=1)], channels=channels, always_apply=True, p=1)],
)

transformed_image = aug(image=image)["image"]

for channel in range(image.shape[-1]):
if channel in channels:
assert not np.array_equal(image[..., channel], transformed_image[..., channel])
else:
assert np.array_equal(image[..., channel], transformed_image[..., channel])

0 comments on commit d47389c

Please sign in to comment.