Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions cuda_bindings/cuda/bindings/_internal/strdecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
"""Decode C strings returned by CUDA libraries with actionable failure context."""

# Cap sized for the #2118 mojibake without flooding exception text.
_PREVIEW_MAX_BYTES = 64


def _bounded_hex_preview(data: bytes, max_bytes: int = _PREVIEW_MAX_BYTES) -> str:
# Bytes after the first NUL are not part of the returned C string. The
# marker is explicit so truncation cannot be misread as an empty value.
nul = data.find(b"\x00")
nul_stopped = nul != -1
visible_end = len(data) if not nul_stopped else nul
snippet_end = min(visible_end, max_bytes)
snippet = data[:snippet_end]
body = snippet.hex(" ") if snippet else ""
parts = []
if snippet_end < visible_end:
parts.append(f"+{visible_end - snippet_end} more")
if nul_stopped:
parts.append(f"stopped at NUL@{nul}")
suffix = f" ...({'; '.join(parts)})" if parts else ""
return f"<{visible_end} bytes; hex='{body}'{suffix}>"


def decode_c_str(data: bytes, api_name: str) -> str:
"""Decode ``data`` as UTF-8, or raise ``UnicodeDecodeError`` with ``api_name`` and a bounded hex preview in ``reason``.

Internal API. ``api_name`` is trusted caller input and embedded verbatim.
"""
try:
return data.decode("utf-8")
except UnicodeDecodeError as e:
# Same exception type, not a subclass, so existing handlers still catch.
preview = _bounded_hex_preview(data)
reason = f"{e.reason} (returned by {api_name}; bytes={preview})"
raise UnicodeDecodeError(e.encoding, e.object, e.start, e.end, reason) from e
80 changes: 80 additions & 0 deletions cuda_bindings/tests/test_strdecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import pytest

from cuda.bindings._internal.strdecode import _bounded_hex_preview, decode_c_str

WSL_MOJIBAKE_PREFIX = b"\xf8\x9a\x80\x80\xaf"


def test_valid_utf8_passthrough():
assert decode_c_str(b"hello world", "fakeApi") == "hello world"


def test_invalid_bytes_raise_unicode_decode_error():
with pytest.raises(UnicodeDecodeError):
decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")


def test_failure_reason_includes_api_name():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")
assert "nvmlSystemGetProcessName" in excinfo.value.reason


def test_failure_reason_includes_hex_preview():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")
assert "f8 9a 80 80 af" in excinfo.value.reason


def test_failure_chains_original_error():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(b"\xf8", "fakeApi")
assert isinstance(excinfo.value.__cause__, UnicodeDecodeError)


def test_failure_preserves_codec_and_position():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(b"\xf8\x9a", "fakeApi")
assert excinfo.value.encoding == "utf-8"
assert excinfo.value.start == 0
assert excinfo.value.end == 1


def test_preview_stops_at_first_nul():
preview = _bounded_hex_preview(b"\xf8\xf8\x00trailing junk")
assert "f8 f8" in preview
assert "trailing" not in preview
assert "<2 bytes;" in preview
assert "stopped at NUL@2" in preview


def test_preview_caps_long_buffers():
preview = _bounded_hex_preview(b"\xf8" * 200, max_bytes=8)
assert "f8 f8 f8 f8 f8 f8 f8 f8" in preview
assert "+192 more" in preview
assert "stopped at NUL" not in preview


def test_preview_combines_truncation_and_nul_markers():
preview = _bounded_hex_preview(b"\xf8" * 20 + b"\x00rest", max_bytes=8)
assert "+12 more" in preview
assert "stopped at NUL@20" in preview


def test_failure_preview_stops_at_embedded_nul_even_with_bad_bytes_before():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(b"\xf8\x9a\x00ignored_after_nul", "fakeApi")
reason = excinfo.value.reason
assert "f8 9a" in reason
assert "ignored_after_nul" not in reason


def test_failure_message_stays_bounded_for_long_garbage():
with pytest.raises(UnicodeDecodeError) as excinfo:
decode_c_str(b"\xf8" * 1024, "fakeApi")
reason = excinfo.value.reason
assert "+960 more" in reason
assert len(reason) < 500
Loading