Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/quantization/algorithm/test_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_model(self):
)

# Load data
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids

# base
Expand Down
2 changes: 1 addition & 1 deletion test/quantization/algorithm/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_value(self):
)

# Load data
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="train")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
sample_input = tokenizer(dataset[0]["text"], return_tensors="pt").input_ids

# base
Expand Down
72 changes: 66 additions & 6 deletions test/quantization/recipes/optional_dependency_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,69 @@
make the tests heavier than necessary.
"""

import importlib
import sys
import types

_STUB_MARKER = "__tico_optional_dependency_stub__"


def install_optional_dependency_stubs() -> None:
"""Install lightweight stubs for optional recipe dependencies."""
_install_datasets_stub()
_install_lm_eval_stub()


def _has_attrs(module: types.ModuleType, names: tuple[str, ...]) -> bool:
"""Return True if a module has all required attributes."""
return all(hasattr(module, name) for name in names)


def _is_our_stub(module: types.ModuleType | None) -> bool:
"""Return True if a module was installed by this stub helper."""
return bool(module is not None and getattr(module, _STUB_MARKER, False))


def _try_import_optional_module(module_name: str) -> types.ModuleType | None:
"""
Import an optional module if it is available.

Missing optional modules are tolerated. Import failures caused by missing
transitive dependencies are re-raised so broken real installations are not
silently hidden by test stubs.
"""
try:
return importlib.import_module(module_name)
except ModuleNotFoundError as exc:
top_level_name = module_name.partition(".")[0]
if exc.name in {module_name, top_level_name}:
return None
raise


def _install_datasets_stub() -> None:
"""Install a minimal datasets module when the real package is unavailable."""
if "datasets" in sys.modules and all(
hasattr(sys.modules["datasets"], name)
for name in ("Dataset", "IterableDataset", "load_dataset")
required_attrs = ("Dataset", "IterableDataset", "load_dataset")
existing_module = sys.modules.get("datasets")

if (
existing_module is not None
and not _is_our_stub(existing_module)
and _has_attrs(existing_module, required_attrs)
):
return

real_module = _try_import_optional_module("datasets")
if real_module is not None and _has_attrs(real_module, required_attrs):
return

module = sys.modules.get("datasets")
if module is None:
if module is None or not _is_our_stub(module):
module = types.ModuleType("datasets")
sys.modules["datasets"] = module

setattr(module, _STUB_MARKER, True)

class Dataset:
"""Minimal datasets.Dataset stub for import-time compatibility."""

Expand Down Expand Up @@ -74,15 +114,31 @@ def load_dataset(*args, **kwargs):

def _install_lm_eval_stub() -> None:
"""Install minimal lm_eval modules when the real package is unavailable."""
if "lm_eval" in sys.modules and hasattr(sys.modules["lm_eval"], "evaluator"):
existing_module = sys.modules.get("lm_eval")

if (
existing_module is not None
and not _is_our_stub(existing_module)
and hasattr(existing_module, "evaluator")
):
return

real_module = _try_import_optional_module("lm_eval")
if real_module is not None and not _is_our_stub(real_module):
real_evaluator_module = _try_import_optional_module("lm_eval.evaluator")
if real_evaluator_module is not None:
setattr(real_module, "evaluator", real_evaluator_module)
return

lm_eval_module = sys.modules.get("lm_eval")
if lm_eval_module is None:
if lm_eval_module is None or not _is_our_stub(lm_eval_module):
lm_eval_module = types.ModuleType("lm_eval")
sys.modules["lm_eval"] = lm_eval_module

setattr(lm_eval_module, _STUB_MARKER, True)

evaluator_module = types.ModuleType("lm_eval.evaluator")
setattr(evaluator_module, _STUB_MARKER, True)

def simple_evaluate(*args, **kwargs):
"""Fail clearly if a test accidentally runs real lm-eval."""
Expand All @@ -96,6 +152,7 @@ def simple_evaluate(*args, **kwargs):
setattr(lm_eval_module, "evaluator", evaluator_module)

utils_module = types.ModuleType("lm_eval.utils")
setattr(utils_module, _STUB_MARKER, True)

def make_table(results):
"""Return a stable string representation for patched evaluation results."""
Expand All @@ -106,7 +163,10 @@ def make_table(results):
setattr(lm_eval_module, "utils", utils_module)

models_module = types.ModuleType("lm_eval.models")
setattr(models_module, _STUB_MARKER, True)

huggingface_module = types.ModuleType("lm_eval.models.huggingface")
setattr(huggingface_module, _STUB_MARKER, True)

class HFLM:
"""Minimal HFLM stub for import-time compatibility."""
Expand Down
Loading