diff --git a/test/quantization/algorithm/test_gptq.py b/test/quantization/algorithm/test_gptq.py index 9747a297..0a5c6327 100644 --- a/test/quantization/algorithm/test_gptq.py +++ b/test/quantization/algorithm/test_gptq.py @@ -391,7 +391,7 @@ def test_model(self): ) # Load data - dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train") + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids # base diff --git a/test/quantization/algorithm/test_smooth_quant.py b/test/quantization/algorithm/test_smooth_quant.py index bdfab5b5..6aa3ac01 100644 --- a/test/quantization/algorithm/test_smooth_quant.py +++ b/test/quantization/algorithm/test_smooth_quant.py @@ -49,7 +49,7 @@ def test_value(self): ) # Load data - dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train") + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids # base diff --git a/test/quantization/recipes/optional_dependency_stubs.py b/test/quantization/recipes/optional_dependency_stubs.py index 2f7d6d41..b3061862 100644 --- a/test/quantization/recipes/optional_dependency_stubs.py +++ b/test/quantization/recipes/optional_dependency_stubs.py @@ -21,9 +21,12 @@ make the tests heavier than necessary. """ +import importlib import sys import types +_STUB_MARKER = "__tico_optional_dependency_stub__" + def install_optional_dependency_stubs() -> None: """Install lightweight stubs for optional recipe dependencies.""" @@ -31,19 +34,56 @@ def install_optional_dependency_stubs() -> None: _install_lm_eval_stub() +def _has_attrs(module: types.ModuleType, names: tuple[str, ...]) -> bool: + """Return True if a module has all required attributes.""" + return all(hasattr(module, name) for name in names) + + +def _is_our_stub(module: types.ModuleType | None) -> bool: + """Return True if a module was installed by this stub helper.""" + return bool(module is not None and getattr(module, _STUB_MARKER, False)) + + +def _try_import_optional_module(module_name: str) -> types.ModuleType | None: + """ + Import an optional module if it is available. + + Missing optional modules are tolerated. Import failures caused by missing + transitive dependencies are re-raised so broken real installations are not + silently hidden by test stubs. + """ + try: + return importlib.import_module(module_name) + except ModuleNotFoundError as exc: + top_level_name = module_name.partition(".")[0] + if exc.name in {module_name, top_level_name}: + return None + raise + + def _install_datasets_stub() -> None: """Install a minimal datasets module when the real package is unavailable.""" - if "datasets" in sys.modules and all( - hasattr(sys.modules["datasets"], name) - for name in ("Dataset", "IterableDataset", "load_dataset") + required_attrs = ("Dataset", "IterableDataset", "load_dataset") + existing_module = sys.modules.get("datasets") + + if ( + existing_module is not None + and not _is_our_stub(existing_module) + and _has_attrs(existing_module, required_attrs) ): return + real_module = _try_import_optional_module("datasets") + if real_module is not None and _has_attrs(real_module, required_attrs): + return + module = sys.modules.get("datasets") - if module is None: + if module is None or not _is_our_stub(module): module = types.ModuleType("datasets") sys.modules["datasets"] = module + setattr(module, _STUB_MARKER, True) + class Dataset: """Minimal datasets.Dataset stub for import-time compatibility.""" @@ -74,15 +114,31 @@ def load_dataset(*args, **kwargs): def _install_lm_eval_stub() -> None: """Install minimal lm_eval modules when the real package is unavailable.""" - if "lm_eval" in sys.modules and hasattr(sys.modules["lm_eval"], "evaluator"): + existing_module = sys.modules.get("lm_eval") + + if ( + existing_module is not None + and not _is_our_stub(existing_module) + and hasattr(existing_module, "evaluator") + ): return + real_module = _try_import_optional_module("lm_eval") + if real_module is not None and not _is_our_stub(real_module): + real_evaluator_module = _try_import_optional_module("lm_eval.evaluator") + if real_evaluator_module is not None: + setattr(real_module, "evaluator", real_evaluator_module) + return + lm_eval_module = sys.modules.get("lm_eval") - if lm_eval_module is None: + if lm_eval_module is None or not _is_our_stub(lm_eval_module): lm_eval_module = types.ModuleType("lm_eval") sys.modules["lm_eval"] = lm_eval_module + setattr(lm_eval_module, _STUB_MARKER, True) + evaluator_module = types.ModuleType("lm_eval.evaluator") + setattr(evaluator_module, _STUB_MARKER, True) def simple_evaluate(*args, **kwargs): """Fail clearly if a test accidentally runs real lm-eval.""" @@ -96,6 +152,7 @@ def simple_evaluate(*args, **kwargs): setattr(lm_eval_module, "evaluator", evaluator_module) utils_module = types.ModuleType("lm_eval.utils") + setattr(utils_module, _STUB_MARKER, True) def make_table(results): """Return a stable string representation for patched evaluation results.""" @@ -106,7 +163,10 @@ def make_table(results): setattr(lm_eval_module, "utils", utils_module) models_module = types.ModuleType("lm_eval.models") + setattr(models_module, _STUB_MARKER, True) + huggingface_module = types.ModuleType("lm_eval.models.huggingface") + setattr(huggingface_module, _STUB_MARKER, True) class HFLM: """Minimal HFLM stub for import-time compatibility."""