Skip to content

Commit

Permalink
Normalize using cv2 (#563)
Browse files Browse the repository at this point in the history
* Use OpenCV for normalization.

* OpenCV normalize by condition

* Rename scale to scale_image and clahe to clahe_mat to avoid Re-defined variable from outer scope

* OpenCV faster only for rgb images

* Update test_normalize_np_cv_equal

* Contiguous input for normalize.

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Mikhail Druzhinin <dipet@gmail.com>
  • Loading branch information
3 people committed Jul 16, 2022
1 parent ed7626f commit 2eac060
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
28 changes: 24 additions & 4 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,27 @@ def rot90(img, factor):
return np.ascontiguousarray(img)


def normalize_cv2(img, mean, denominator):
if mean.shape and len(mean) != 4 and mean.shape != img.shape:
mean = np.array(mean.tolist() + [0] * (4 - len(mean)), dtype=np.float64)
if not denominator.shape:
denominator = np.array([denominator.tolist()] * 4, dtype=np.float64)
elif len(denominator) != 4 and denominator.shape != img.shape:
denominator = np.array(denominator.tolist() + [1] * (4 - len(denominator)), dtype=np.float64)

img = np.ascontiguousarray(img.astype("float32"))
cv2.subtract(img, mean.astype(np.float64), img)
cv2.multiply(img, denominator.astype(np.float64), img)
return img


def normalize_numpy(img, mean, denominator):
img = img.astype(np.float32)
img -= mean
img *= denominator
return img


def normalize(img, mean, std, max_pixel_value=255.0):
mean = np.array(mean, dtype=np.float32)
mean *= max_pixel_value
Expand All @@ -253,10 +274,9 @@ def normalize(img, mean, std, max_pixel_value=255.0):

denominator = np.reciprocal(std, dtype=np.float32)

img = img.astype(np.float32)
img -= mean
img *= denominator
return img
if img.ndim == 3 and img.shape[-1] == 3:
return normalize_cv2(img, mean, denominator)
return normalize_numpy(img, mean, denominator)


def _maybe_process_in_chunks(process_fn, **kwargs):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,3 +969,20 @@ def test_cv_dtype_from_np():
assert F.get_opencv_dtype_from_numpy(np.dtype("float32")) == cv2.CV_32F
assert F.get_opencv_dtype_from_numpy(np.dtype("float64")) == cv2.CV_64F
assert F.get_opencv_dtype_from_numpy(np.dtype("int32")) == cv2.CV_32S


@pytest.mark.parametrize(
["image", "mean", "std"],
[
[np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
[np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8), 0.5, 0.5],
[np.random.randint(0, 256, [100, 100], dtype=np.uint8), 0.5, 0.5],
],
)
def test_normalize_np_cv_equal(image, mean, std):
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)

res1 = F.normalize_cv2(image, mean, std)
res2 = F.normalize_numpy(image, mean, std)
assert np.allclose(res1, res2)

0 comments on commit 2eac060

Please sign in to comment.