Skip to content

Commit

Permalink
Fix abs diff computation in check_batch test utility (#4957)
Browse files Browse the repository at this point in the history
* Fix check_batch abs diff computation
* Add one-time sanity check
---------

Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
  • Loading branch information
stiepan committed Aug 21, 2023
1 parent 8336af0 commit caf257a
Showing 1 changed file with 51 additions and 4 deletions.
55 changes: 51 additions & 4 deletions dali/test/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_dali_extra_path():
assert_array_equal = None
assert_allclose = None
cp = None
absdiff_checked = False


def import_numpy():
Expand Down Expand Up @@ -128,6 +129,55 @@ def get_gpu_num():
return len(out_list)


def _get_absdiff(left, right):

def make_unsigned(dtype):
if not np.issubdtype(dtype, np.signedinteger):
return dtype
return {
np.dtype(np.int8): np.uint8,
np.dtype(np.int16): np.uint16,
np.dtype(np.int32): np.uint32,
np.dtype(np.int64): np.uint64,
}[dtype]

# np.abs of diff doesn't handle overflow for unsigned types
absdiff = np.maximum(left, right) - np.minimum(left, right)
# max - min can overflow for signed types, wrap them up
absdiff = absdiff.astype(make_unsigned(absdiff.dtype))
return absdiff


def _check_absdiff():
"""
In principle, overflow on signed int is UB (that we relied on so far anyway).
The following one-time check aims to verify the overflow wraps as expected.
"""
for i in range(-128, 127):
for j in range(-128, 127):
left = np.array([i, i], dtype=np.int8)
right = np.array([j, j], dtype=np.int8)
diff = _get_absdiff(left, right)
expected_diff = np.array([abs(i - j), abs(i - j)], dtype=np.uint8)
assert np.array_equal(diff, expected_diff), f"{diff} {expected_diff} {i} {j}"
for i in range(0, 255):
for j in range(0, 255):
left = np.array([i, i], dtype=np.uint8)
right = np.array([j, j], dtype=np.uint8)
diff = _get_absdiff(left, right)
expected_diff = np.array([abs(i - j), abs(i - j)], dtype=np.uint8)
assert np.array_equal(diff, expected_diff), f"{diff} {expected_diff} {i} {j}"


def get_absdiff(left, right):
# Make sanity checks, in particular, if wrapping signed integers works as expected
global absdiff_checked
if not absdiff_checked:
absdiff_checked = True
_check_absdiff()
return _get_absdiff(left, right)


# If the `max_allowed_error` is not None, it's checked instead of comparing mean error with `eps`.
def check_batch(batch1, batch2, batch_size=None,
eps=1e-07, max_allowed_error=None,
Expand Down Expand Up @@ -203,10 +253,7 @@ def _verify_batch_size(batch):
left = left.astype(int)
if right.dtype == bool:
right = right.astype(int)
# abs doesn't handle overflow for uint8, so get minimal value of a-b and b-a
diff1 = np.abs(left - right)
diff2 = np.abs(right - left)
absdiff = np.minimum(diff2, diff1)
absdiff = get_absdiff(left, right)
err = np.mean(absdiff)
max_err = np.max(absdiff)
min_err = np.min(absdiff)
Expand Down

0 comments on commit caf257a

Please sign in to comment.