[Serialization]: remove explicit weights_only default from safe_load to allow user to bypass if needed #1279
Allow torch>=2.6's built-in default (weights_only=True) to take effect naturally, so users can override via TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 when they trust a checkpoint but hit pickle.UnpicklingError. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/utils/serialization.py (1)
56-65: ⚠️ Potential issue | 🟠 Major

Strengthen `safe_load` with explicit `weights_only=True` for defensive clarity.

The codebase enforces `torch>=2.8` globally (`pyproject.toml` line 41), which means `torch.load()` defaults to `weights_only=True` when unspecified. However, per SECURITY.md best practices, prefer explicit safe defaults over implicit reliance on PyTorch's version-dependent behavior. Additionally, the current implementation allows callers to override or bypass this safety via environment variables without an explicit guard in the function itself.

🔧 Recommended hardening

     def safe_load(f: str | os.PathLike | BinaryIO | bytes, **kwargs) -> Any:
         """Load a checkpoint securely using ``weights_only=True`` by default.

         NOTE: We dont set default ``weights_only`` (interpret as True for torch>=2.6) so you can override it
         with ``export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` if you see ``pickle.UnpicklingError`` and trust the checkpoint.
         """
         if isinstance(f, (bytes, bytearray)):
             f = BytesIO(f)

    +    # Preserve secure default unless user explicitly overrides or opts out via env var.
    +    if "weights_only" not in kwargs and os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") != "1":
    +        kwargs["weights_only"] = True
    +
         return torch.load(f, **kwargs)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/serialization.py` around lines 56 - 65, The safe_load function should explicitly set weights_only=True by default to avoid relying on torch's version defaults; update safe_load (the function using BytesIO and calling torch.load) to call kwargs.setdefault('weights_only', True) (so callers can still explicitly override by passing weights_only) before invoking torch.load(f, **kwargs), preserving the existing bytes/BytesIO handling.
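The guard in the recommended diff can be looked at in isolation. The following is a minimal sketch, assuming a hypothetical helper named `resolve_weights_only` that is not part of modelopt or torch; it only reproduces the keyword-argument logic from the suggestion and takes the environment as a parameter so it can be exercised without PyTorch installed.

```python
from __future__ import annotations

import os
from typing import Any, Mapping


def resolve_weights_only(
    kwargs: dict[str, Any], env: Mapping[str, str] | None = None
) -> dict[str, Any]:
    """Hypothetical helper mirroring the guard in the suggested diff.

    Returns a copy of ``kwargs`` in which ``weights_only=True`` is added
    unless the caller already passed ``weights_only`` or the environment
    sets ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1``.
    """
    env = os.environ if env is None else env
    resolved = dict(kwargs)  # do not mutate the caller's dict
    opted_out = env.get("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") == "1"
    if "weights_only" not in resolved and not opted_out:
        resolved["weights_only"] = True
    return resolved


if __name__ == "__main__":
    print(resolve_weights_only({}, env={}))
    print(resolve_weights_only({}, env={"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}))
    print(resolve_weights_only({"weights_only": False}, env={}))
```

Under this sketch an explicit `weights_only` argument always wins, and the env var only matters when the caller left the flag unset.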
📒 Files selected for processing (2)
modelopt/torch/utils/serialization.py
tests/unit/torch/utils/test_serialization.py
Codecov Report
✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

    @@            Coverage Diff             @@
    ##             main    #1279      +/-   ##
    ==========================================
    + Coverage   75.58%   76.56%   +0.97%
    ==========================================
      Files         459      459
      Lines       48613    48612       -1
    ==========================================
    + Hits        36745    37219     +474
    + Misses      11868    11393     -475
Summary
- Remove the explicit `kwargs.setdefault("weights_only", True)` call from `safe_load`, deferring to torch's built-in default (which is `True` for torch>=2.6)
- Users can override via the `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1` env var when they trust a checkpoint but hit `pickle.UnpicklingError`

Test plan
- `python -m pytest tests/unit/torch/utils/test_serialization.py -v`

🤖 Generated with Claude Code
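The rationale in the summary can be modeled without PyTorch. The sketch below is a simplified illustration, not torch's actual code: `stub_load` is a hypothetical stand-in that assumes the env var is consulted only when `weights_only` is left unset, which is the behaviour this PR relies on. `safe_load_before` and `safe_load_after` are likewise illustrative names for the function before and after the change.

```python
from __future__ import annotations

from typing import Any, Mapping


def stub_load(env: Mapping[str, str], weights_only: bool | None = None) -> bool:
    """Hypothetical stand-in for torch.load's flag resolution (assumed model).

    Returns the effective ``weights_only`` value.
    """
    if weights_only is None:
        # Caller did not choose: default to True unless the env var opts out.
        return env.get("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") != "1"
    # An explicit argument is used as given.
    return weights_only


def safe_load_before(env: Mapping[str, str], **kwargs: Any) -> bool:
    """Behaviour before the PR: always passes an explicit default."""
    kwargs.setdefault("weights_only", True)
    return stub_load(env, **kwargs)


def safe_load_after(env: Mapping[str, str], **kwargs: Any) -> bool:
    """Behaviour after the PR: leaves the flag unset unless the caller passes it."""
    return stub_load(env, **kwargs)


if __name__ == "__main__":
    trusted_env = {"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}
    print("before:", safe_load_before(trusted_env))
    print("after: ", safe_load_after(trusted_env))
```

In this model the explicit `setdefault` hides the env var from the loader, while the post-PR version lets it take effect. The real resolution logic lives in `torch.load` and should be checked against the installed PyTorch version.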
Summary by CodeRabbit
Bug Fixes
Tests