diff --git a/modelopt/torch/utils/serialization.py b/modelopt/torch/utils/serialization.py index da16f7514c..dc880b86b8 100644 --- a/modelopt/torch/utils/serialization.py +++ b/modelopt/torch/utils/serialization.py @@ -54,9 +54,11 @@ def safe_save(obj: Any, f: str | os.PathLike | BinaryIO, **kwargs) -> None: def safe_load(f: str | os.PathLike | BinaryIO | bytes, **kwargs) -> Any: - """Load a checkpoint securely using weights_only=True by default.""" - kwargs.setdefault("weights_only", True) + """Load a checkpoint securely using ``weights_only=True`` by default. + NOTE: We dont set default ``weights_only`` (interpret as True for torch>=2.6) so you can override it with + ``export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` if you see ``pickle.UnpicklingError`` and trust the checkpoint. + """ if isinstance(f, (bytes, bytearray)): f = BytesIO(f) diff --git a/tests/unit/torch/utils/test_serialization.py b/tests/unit/torch/utils/test_serialization.py index 32851d3a09..cb3233739c 100644 --- a/tests/unit/torch/utils/test_serialization.py +++ b/tests/unit/torch/utils/test_serialization.py @@ -16,7 +16,9 @@ """Tests for Modelopt's serialization utilities.""" from io import BytesIO +from pickle import UnpicklingError +import pytest import torch from modelopt.torch.opt.config import ModeloptBaseConfig @@ -70,3 +72,25 @@ def test_safe_load_with_path(tmp_path): loaded_state = safe_load(file_path) assert loaded_state["data"] == 42 + + +class _UnsafeObj: + """Not registered in torch safe globals — unpickling fails with weights_only=True.""" + + def __init__(self, v): + self.v = v + + +def test_safe_load_env_var_bypasses_weights_only(tmp_path, monkeypatch): + """Verify TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 allows safe_load to load objects unsafe for weights_only.""" + file_path = tmp_path / "unsafe.pt" + torch.save({"obj": _UnsafeObj(42)}, file_path) + + # Always fails when weights_only is not set (default=True) + with pytest.raises(UnpicklingError): + safe_load(file_path) + + # With the env var, safe_load (no explicit weights_only) defers to torch's default=False + monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + loaded = safe_load(file_path) + assert loaded["obj"].v == 42