From 8ed72509d67f4c1304930f39a2c6dad649de377b Mon Sep 17 00:00:00 2001
From: vipandya
Date: Sun, 3 May 2026 15:09:11 +0000
Subject: [PATCH] add unit test for checking any leak of temporary augmented onnx files during onnx int4 awq quantization

Signed-off-by: vipandya
---
 .../onnx/quantization/test_quantize_zint4.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/tests/unit/onnx/quantization/test_quantize_zint4.py b/tests/unit/onnx/quantization/test_quantize_zint4.py
index f60533a829..7ba883b0ce 100644
--- a/tests/unit/onnx/quantization/test_quantize_zint4.py
+++ b/tests/unit/onnx/quantization/test_quantize_zint4.py
@@ -14,11 +14,13 @@
 # limitations under the License.
 
 import os
+import tempfile as _tempfile
 from collections.abc import Sequence
 
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
+import pytest
 from _test_utils.onnx.lib_test_models import find_init
 
 import modelopt.onnx.quantization as moq
@@ -225,3 +227,57 @@ def is_quant_scale_with_right_shape(model, quant_axis, block_size):
     )
 
     # Ensure above tests pass.
+
+
+@pytest.mark.parametrize("calibration_method", ["awq_lite", "awq_clip"])
+@pytest.mark.parametrize("use_external_data_format", [True, False])
+def test_awq_no_temp_file_leak(tmp_path, monkeypatch, calibration_method, use_external_data_format):
+    """Test that tmp*.onnx and tmp*.onnx_data files written to the system temp
+    directory are removed even when quantization fails mid-run.
+
+    Simulates the real-world failure window (OOM, bad EP, driver error) by injecting
+    a RuntimeError at ORT session creation — which happens after the augmented ONNX
+    has already been written to disk but before the original cleanup code was reached.
+
+    Parallel-safe: tracks the exact paths created by mkstemp during this test rather
+    than glob-snapshotting the temp directory, so parallel test runs cannot interfere.
+    """
+    onnx_path = _matmul_model(
+        w=np.random.rand(288, 16).astype(np.float32),
+        in_shape=(96, 288),
+        out_shape=(96, 16),
+        tmp_path=tmp_path,
+    )
+
+    # Intercept mkstemp to record the exact augmented-model temp path(s) created.
+    created_paths = []
+    real_mkstemp = _tempfile.mkstemp
+
+    def _tracking_mkstemp(*args, **kwargs):
+        fd, path = real_mkstemp(*args, **kwargs)
+        created_paths.append(path)
+        return fd, path
+
+    monkeypatch.setattr("modelopt.onnx.quantization.int4.tempfile.mkstemp", _tracking_mkstemp)
+
+    def _raise_session_error(*args, **kwargs):
+        raise RuntimeError("injected ORT session failure")
+
+    monkeypatch.setattr(
+        "modelopt.onnx.quantization.int4.create_inference_session",
+        _raise_session_error,
+    )
+
+    with pytest.raises(RuntimeError, match="injected ORT session failure"):
+        quantize_int4(
+            onnx_path,
+            calibration_method=calibration_method,
+            use_external_data_format=use_external_data_format,
+            block_size=8,
+        )
+
+    assert created_paths, "Expected mkstemp to be called but it was not"
+    for augmented_path in created_paths:
+        assert not os.path.exists(augmented_path), f"Leaked: {augmented_path}"
+        if use_external_data_format:
+            assert not os.path.exists(augmented_path + "_data"), f"Leaked: {augmented_path}_data"