Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/unit/onnx/quantization/test_quantize_zint4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# limitations under the License.

import os
import tempfile as _tempfile
from collections.abc import Sequence

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import pytest
from _test_utils.onnx.lib_test_models import find_init

import modelopt.onnx.quantization as moq
Expand Down Expand Up @@ -225,3 +227,57 @@ def is_quant_scale_with_right_shape(model, quant_axis, block_size):
)

# Ensure above tests pass.


@pytest.mark.parametrize("calibration_method", ["awq_lite", "awq_clip"])
@pytest.mark.parametrize("use_external_data_format", [True, False])
def test_awq_no_temp_file_leak(tmp_path, monkeypatch, calibration_method, use_external_data_format):
"""Test that tmp*.onnx and tmp*.onnx_data are written to the
system temp directory must be removed even when quantization fails mid-run.

Simulates the real-world failure window (OOM, bad EP, driver error) by injecting
a RuntimeError at ORT session creation — which happens after the augmented ONNX
has already been written to disk but before the original cleanup code was reached.

Thread-safe: tracks the exact paths created by mkstemp during this test rather
than glob-snapshotting the temp directory, so parallel test runs cannot interfere.
"""
onnx_path = _matmul_model(
w=np.random.rand(288, 16).astype(np.float32),
in_shape=(96, 288),
out_shape=(96, 16),
tmp_path=tmp_path,
)

# Intercept mkstemp to record the exact augmented-model temp path(s) created.
created_paths = []
real_mkstemp = _tempfile.mkstemp

def _tracking_mkstemp(*args, **kwargs):
fd, path = real_mkstemp(*args, **kwargs)
created_paths.append(path)
return fd, path

monkeypatch.setattr("modelopt.onnx.quantization.int4.tempfile.mkstemp", _tracking_mkstemp)

def _raise_session_error(*args, **kwargs):
raise RuntimeError("injected ORT session failure")

monkeypatch.setattr(
"modelopt.onnx.quantization.int4.create_inference_session",
_raise_session_error,
)

with pytest.raises(RuntimeError, match="injected ORT session failure"):
quantize_int4(
onnx_path,
calibration_method=calibration_method,
use_external_data_format=use_external_data_format,
block_size=8,
)

assert created_paths, "Expected mkstemp to be called but it was not"
for augmented_path in created_paths:
assert not os.path.exists(augmented_path), f"Leaked: {augmented_path}"
if use_external_data_format:
assert not os.path.exists(augmented_path + "_data"), f"Leaked: {augmented_path}_data"
Loading