fix: Feature: Add validation for loaded modelopt state files (#1041) #1074
Conversation
📝 Walkthrough: Added explicit runtime validation for loaded modelopt state objects.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes. Pre-merge checks: ✅ 6 passed.
modelopt/torch/opt/conversion.py (Outdated)

```python
    f"The file may not be a valid modelopt state file."
)

# Validate that the dictionary has the expected keys
```
Can we move this validation logic inside the ModelOptStateManager class?
Let me see if I can make it address the comment and update the PR.
Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
/ok to test 87e6e8f
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 355-358: The long ValueError message in conversion.py (within the
function that validates modelopt_state_dict entries where the raise ValueError
is emitted) violates line-length rules; shorten or split the message into
multiple concatenated strings or build it in a temporary variable so no single
source line exceeds the limit (for example, create a message variable using
multiple shorter f-strings or implicit string concatenation and then raise
ValueError(message)). Ensure you still include the dynamic parts (i,
type(entry).__name__, and the conditional length expression) so the raised error
text remains informative.
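One way to satisfy this, sketched with a hypothetical helper and message text (the function name, wording, and entry shape are illustrative, not ModelOpt's actual code): build the message from shorter adjacent f-strings via implicit concatenation, then raise it.

```python
def check_state_entries(modelopt_state_dict):
    """Hypothetical helper illustrating the message-splitting fix."""
    for i, entry in enumerate(modelopt_state_dict):
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            continue
        # Implicit concatenation of short f-strings keeps every source
        # line under the length limit while preserving the dynamic parts.
        length = f"length {len(entry)}" if hasattr(entry, "__len__") else "no length"
        message = (
            f"entry {i} of modelopt_state_dict has unexpected "
            f"type {type(entry).__name__} ({length}); "
            "expected a (name, config) pair"
        )
        raise ValueError(message)
```

The raised text still carries `i`, `type(entry).__name__`, and the conditional length expression, so the error stays as informative as the original one-liner.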
- Around line 333-345: The loader currently checks for the presence of
"modelopt_version" but not its type; update the validation after the
required-keys check to verify that modelopt_state["modelopt_version"] is a
string (e.g., isinstance(version, str)) and raise a ValueError with a clear
message if it is not; reference the symbol modelopt_state["modelopt_version"]
and ensure this check occurs before any call to load_state_dict() or any code
that calls version.split(...) so downstream failures are avoided.
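A minimal sketch of the suggested type check (function name and exact message are hypothetical; the key name follows the review comment):

```python
def check_modelopt_version(modelopt_state: dict):
    """Hypothetical sketch: validate the version field's type before use."""
    version = modelopt_state.get("modelopt_version")
    if not isinstance(version, str):
        raise ValueError(
            "modelopt_state['modelopt_version'] must be a string, got "
            f"{type(version).__name__}; the file may not be a valid "
            "modelopt state file."
        )
    # Only safe after the isinstance check: downstream code splits the version.
    return version.split(".")
```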
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8ff4605f-26fa-4c04-babf-9279b69ffc7d
📒 Files selected for processing (2)
modelopt/torch/opt/conversion.py
tests/unit/torch/opt/test_modelopt_state_validation.py
Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
/ok to test 7289799
🧹 Nitpick comments (1)
modelopt/torch/opt/conversion.py (1)
607-614: Security comment present; consider clarifying the trust assumption in documentation.

The inline comment at line 607 justifies `weights_only=False` by stating the file is a "ModelOpt-generated state_dict, not untrusted user input." While this comment satisfies the coding guideline requirement, the function signature accepts any file path, meaning callers could potentially pass untrusted files.

The post-load validation (line 613) provides defense-in-depth by failing fast on malformed data, but note that arbitrary code execution via pickle can occur during `torch.load()` itself, before validation runs.

Consider adding a note in the docstring warning callers that this function assumes trusted input, or alternatively exposing `weights_only` as a caller-configurable parameter defaulting to a safer value.

Optional: document the trust assumption in the docstring:

```diff
 def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]:
     """Load the modelopt state from a file.

+    .. warning::
+        This function uses ``weights_only=False`` by default for ``torch.load()``.
+        Only use with trusted, ModelOpt-generated state files.
+
     Args:
         modelopt_state_path: Target file location.
         **kwargs: additional args for ``torch.load()``.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/opt/conversion.py` around lines 607 - 614, The comment justifying kwargs.setdefault("weights_only", False) is insufficient because torch.load(modelopt_state_path, **kwargs) can execute code before validation; update the surrounding function's docstring to explicitly state the function assumes the supplied modelopt_state_path is trusted and may execute code during torch.load, and/or change the function signature to accept a caller-configurable weights_only parameter (defaulting to True for safety) and use that parameter instead of hardcoding kwargs.setdefault("weights_only", False); ensure references to torch.load and ModeloptStateManager.validate_modelopt_state remain intact so callers and reviewers can find the load/validate sequence.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 607-614: The comment justifying kwargs.setdefault("weights_only",
False) is insufficient because torch.load(modelopt_state_path, **kwargs) can
execute code before validation; update the surrounding function's docstring to
explicitly state the function assumes the supplied modelopt_state_path is
trusted and may execute code during torch.load, and/or change the function
signature to accept a caller-configurable weights_only parameter (defaulting to
True for safety) and use that parameter instead of hardcoding
kwargs.setdefault("weights_only", False); ensure references to torch.load and
ModeloptStateManager.validate_modelopt_state remain intact so callers and
reviewers can find the load/validate sequence.
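The alternative the reviewer suggests could be realized roughly as follows; this is a sketch with a safer default, and the real signature in conversion.py may differ:

```python
from __future__ import annotations

import os
from typing import Any


def load_modelopt_state(
    modelopt_state_path: str | os.PathLike,
    weights_only: bool = True,  # safer default; callers opt in to full unpickling
    **kwargs,
) -> dict[str, Any]:
    """Load the modelopt state from a file.

    Sketch only. Passing ``weights_only=False`` lets ``torch.load`` unpickle
    arbitrary objects, which can execute code before any validation runs, so
    it should be reserved for trusted, ModelOpt-generated files.
    """
    import torch  # deferred so the signature can be inspected without torch

    kwargs.setdefault("weights_only", weights_only)
    return torch.load(modelopt_state_path, **kwargs)
```

With this shape, `load_modelopt_state(path, weights_only=False)` makes the trust decision explicit at each call site instead of hardcoding it.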
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cc3fc581-39af-4e91-a131-6d6b9ae213f5
📒 Files selected for processing (2)
modelopt/torch/opt/conversion.py
tests/unit/torch/opt/test_modelopt_state_validation.py
Fixes #1041
Summary
Add validation to the `load_modelopt_state()` function to ensure loaded files contain valid modelopt state objects with the expected schema. Currently there is a TODO comment indicating this validation is needed, and the missing validation can lead to unclear downstream errors when invalid files are loaded.

Root Cause
The `load_modelopt_state()` function in `modelopt/torch/opt/conversion.py` performs `torch.load()` without any validation of the loaded object's structure. This allows invalid state files to be loaded, causing cryptic errors downstream instead of clear validation errors at load time.

Agent Fix Summary
Successfully added validation for loaded modelopt state files in `modelopt/torch/opt/conversion.py`. The `load_modelopt_state()` function now validates the structure of the loaded state, and all validation errors provide clear, informative messages. The validation is covered by `tests/unit/torch/opt/test_modelopt_state_validation.py`.
Files Changed

modelopt/torch/opt/conversion.py

Reproduction
To reproduce the validation on a Slurm cluster, save these files in nmm-sandbox and run:

services/triage/test_validation_specific.yaml
services/triage/test_validation_specific.sh
services/triage/test_load_modelopt_state_validation.yaml
services/triage/test_load_modelopt_state_validation.sh

Auto-generated by pensieve
`/magic-triage` agentic fix, please review before merging.

Summary by CodeRabbit
Bug Fixes
Tests