Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix abs diff computation in check_batch test utility #4957

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions dali/test/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_dali_extra_path():
assert_array_equal = None
assert_allclose = None
cp = None
absdiff_checked = False


def import_numpy():
Expand Down Expand Up @@ -128,6 +129,55 @@ def get_gpu_num():
return len(out_list)


def _get_absdiff(left, right):

def make_unsigned(dtype):
if not np.issubdtype(dtype, np.signedinteger):
return dtype
return {
np.dtype(np.int8): np.uint8,
np.dtype(np.int16): np.uint16,
np.dtype(np.int32): np.uint32,
np.dtype(np.int64): np.uint64,
}[dtype]

# np.abs of diff doesn't handle overflow for unsigned types
absdiff = np.maximum(left, right) - np.minimum(left, right)
# max - min can overflow for signed types, wrap them up
absdiff = absdiff.astype(make_unsigned(absdiff.dtype))
return absdiff


def _check_absdiff():
"""
In principle, overflow on signed int is UB (that we relied on so far anyway).
The following one-time check aims to verify the overflow wraps as expected.
"""
for i in range(-128, 127):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle, overflow on signed int is UB (that we rely on anyway without this PR). The following one-time check aims to verify it works as expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should explain it in a comment in the code rather than in the PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for j in range(-128, 127):
left = np.array([i, i], dtype=np.int8)
right = np.array([j, j], dtype=np.int8)
diff = _get_absdiff(left, right)
expected_diff = np.array([abs(i - j), abs(i - j)], dtype=np.uint8)
assert np.array_equal(diff, expected_diff), f"{diff} {expected_diff} {i} {j}"
for i in range(0, 255):
for j in range(0, 255):
left = np.array([i, i], dtype=np.uint8)
right = np.array([j, j], dtype=np.uint8)
diff = _get_absdiff(left, right)
expected_diff = np.array([abs(i - j), abs(i - j)], dtype=np.uint8)
assert np.array_equal(diff, expected_diff), f"{diff} {expected_diff} {i} {j}"


def get_absdiff(left, right):
# Make sanity checks, in particular, if wrapping signed integers works as expected
global absdiff_checked
if not absdiff_checked:
absdiff_checked = True
_check_absdiff()
return _get_absdiff(left, right)


# If the `max_allowed_error` is not None, it's checked instead of comparing mean error with `eps`.
def check_batch(batch1, batch2, batch_size=None,
eps=1e-07, max_allowed_error=None,
Expand Down Expand Up @@ -203,10 +253,7 @@ def _verify_batch_size(batch):
left = left.astype(int)
if right.dtype == bool:
right = right.astype(int)
# abs doesn't handle overflow for uint8, so get minimal value of a-b and b-a
diff1 = np.abs(left - right)
diff2 = np.abs(right - left)
absdiff = np.minimum(diff2, diff1)
absdiff = get_absdiff(left, right)
err = np.mean(absdiff)
max_err = np.max(absdiff)
min_err = np.min(absdiff)
Expand Down