
fix: Feature: Add validation for loaded modelopt state files (#1041)#1074

Open
ChenhanYu wants to merge 5 commits into main from pensieve/fix-issue-1041

Conversation

@ChenhanYu
Collaborator

@ChenhanYu ChenhanYu commented Mar 19, 2026

Fixes #1041

Summary

Add validation to the load_modelopt_state() function to ensure loaded files contain valid modelopt state objects with the expected schema. A TODO comment currently marks this validation as needed; without it, invalid files load silently and produce unclear downstream errors.

Root Cause

The load_modelopt_state() function in modelopt/torch/opt/conversion.py performs torch.load() without any validation of the loaded object's structure. This allows invalid state files to be loaded, causing cryptic errors downstream instead of clear validation errors at load time.

Agent Fix Summary

Successfully added validation for loaded modelopt state files in modelopt/torch/opt/conversion.py. The load_modelopt_state() function now validates:

  1. Loaded object is a dictionary
  2. Required keys 'modelopt_state_dict' and 'modelopt_version' are present
  3. modelopt_state_dict is a list of 2-element tuples
  4. Each tuple contains a string mode name and a dictionary mode state
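
The four checks above can be sketched roughly as follows (a hypothetical standalone function for illustration; per the walkthrough, the real implementation is a method on ModeloptStateManager in conversion.py, and the exact error wording may differ):

```python
from typing import Any


def validate_modelopt_state(modelopt_state: Any) -> None:
    """Sketch of the schema checks described above (hypothetical helper)."""
    # 1. Loaded object must be a dictionary.
    if not isinstance(modelopt_state, dict):
        raise TypeError(
            "Expected loaded modelopt state to be a dictionary, got "
            f"{type(modelopt_state).__name__}."
        )
    # 2. Required keys must be present.
    missing = {"modelopt_state_dict", "modelopt_version"} - modelopt_state.keys()
    if missing:
        raise ValueError(
            f"Loaded modelopt state is missing required keys: {sorted(missing)}"
        )
    # 3. modelopt_state_dict must be a list of 2-element tuples.
    state_dict = modelopt_state["modelopt_state_dict"]
    if not isinstance(state_dict, list):
        raise TypeError(
            f"Expected 'modelopt_state_dict' to be a list, got {type(state_dict).__name__}."
        )
    for i, entry in enumerate(state_dict):
        if not (isinstance(entry, tuple) and len(entry) == 2):
            raise ValueError(
                "Expected each entry in 'modelopt_state_dict' to be a tuple of "
                f"length 2; entry {i} is {type(entry).__name__}."
            )
        # 4. Each tuple is (str mode name, dict mode state).
        mode_name, mode_state = entry
        if not isinstance(mode_name, str) or not isinstance(mode_state, dict):
            raise ValueError(
                f"Entry {i} must be (str mode name, dict mode state), got "
                f"({type(mode_name).__name__}, {type(mode_state).__name__})."
            )
```

Failing fast here turns a cryptic downstream traceback into a one-line schema error at load time.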

All validation errors provide clear, informative messages. Validation was tested with:

  • 6 specific validation test cases (all passed)
  • 30 existing chaining tests (all passed)
  • No backward compatibility issues detected

Files Changed

  • modelopt/torch/opt/conversion.py

Reproduction

To reproduce the validation on a Slurm cluster, save these files in nmm-sandbox and run:

uv run slurm.py --yaml services/triage/test_validation_specific.yaml --yes
uv run slurm.py --yaml services/triage/test_load_modelopt_state_validation.yaml --yes
services/triage/test_validation_specific.yaml
job_name: test_validation_specific
pipeline:
  task_0:
    script: services/triage/test_validation_specific.sh
    slurm_config:
      _factory_: "computelab_slurm_factory"
      nodes: 1
services/triage/test_validation_specific.sh
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "${SCRIPT_DIR}/../service_utils.sh"
trap 'error_handler $0 $LINENO' ERR
trap 'exit_handler' EXIT

cd modules/Model-Optimizer

# Create a Python test script
cat > /tmp/test_validation.py << 'EOF'
import torch
import tempfile
import os
import sys
from pathlib import Path

# Add the modelopt module to the path
sys.path.insert(0, '/nemo_run/code/modules/Model-Optimizer')

import modelopt.torch.opt as mto

# Test 1: Invalid file - not a dictionary
print("Test 1: Testing with non-dictionary file...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save([1, 2, 3], temp_file)

try:
    mto.load_modelopt_state(temp_file)
    print("FAILED: Should have raised TypeError for non-dictionary file")
    sys.exit(1)
except TypeError as e:
    if "Expected loaded modelopt state to be a dictionary" in str(e):
        print(f"PASSED: Correctly raised TypeError: {e}")
    else:
        print(f"FAILED: Wrong error message: {e}")
        sys.exit(1)
finally:
    os.unlink(temp_file)

# Test 2: Invalid file - missing keys
print("\nTest 2: Testing with missing required keys...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save({"some_key": "value"}, temp_file)

try:
    mto.load_modelopt_state(temp_file)
    print("FAILED: Should have raised ValueError for missing keys")
    sys.exit(1)
except ValueError as e:
    if "missing required keys" in str(e):
        print(f"PASSED: Correctly raised ValueError: {e}")
    else:
        print(f"FAILED: Wrong error message: {e}")
        sys.exit(1)
finally:
    os.unlink(temp_file)

# Test 3: Invalid file - modelopt_state_dict is not a list
print("\nTest 3: Testing with invalid modelopt_state_dict type...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save({
        "modelopt_state_dict": "not_a_list",
        "modelopt_version": "0.1.0"
    }, temp_file)

try:
    mto.load_modelopt_state(temp_file)
    print("FAILED: Should have raised TypeError for non-list state_dict")
    sys.exit(1)
except TypeError as e:
    if "Expected 'modelopt_state_dict' to be a list" in str(e):
        print(f"PASSED: Correctly raised TypeError: {e}")
    else:
        print(f"FAILED: Wrong error message: {e}")
        sys.exit(1)
finally:
    os.unlink(temp_file)

# Test 4: Invalid file - state_dict entry is not a tuple
print("\nTest 4: Testing with invalid state_dict entry type...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save({
        "modelopt_state_dict": [["not", "a", "tuple"]],
        "modelopt_version": "0.1.0"
    }, temp_file)

try:
    mto.load_modelopt_state(temp_file)
    print("FAILED: Should have raised ValueError for non-tuple entry")
    sys.exit(1)
except ValueError as e:
    if "Expected each entry in 'modelopt_state_dict' to be a tuple of length 2" in str(e):
        print(f"PASSED: Correctly raised ValueError: {e}")
    else:
        print(f"FAILED: Wrong error message: {e}")
        sys.exit(1)
finally:
    os.unlink(temp_file)

# Test 5: Valid modelopt state file
print("\nTest 5: Testing with valid modelopt state file...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save({
        "modelopt_state_dict": [],
        "modelopt_version": "0.1.0"
    }, temp_file)

try:
    result = mto.load_modelopt_state(temp_file)
    if isinstance(result, dict) and "modelopt_state_dict" in result and "modelopt_version" in result:
        print("PASSED: Successfully loaded valid modelopt state file")
    else:
        print("FAILED: Loaded result doesn't have expected structure")
        sys.exit(1)
finally:
    os.unlink(temp_file)

# Test 6: Valid modelopt state file with actual mode data
print("\nTest 6: Testing with valid modelopt state file containing mode data...")
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    temp_file = f.name
    torch.save({
        "modelopt_state_dict": [
            ("quantize", {"config": {}, "metadata": {}})
        ],
        "modelopt_version": "0.1.0"
    }, temp_file)

try:
    result = mto.load_modelopt_state(temp_file)
    if isinstance(result, dict) and "modelopt_state_dict" in result and len(result["modelopt_state_dict"]) == 1:
        print("PASSED: Successfully loaded valid modelopt state file with mode data")
    else:
        print("FAILED: Loaded result doesn't have expected structure")
        sys.exit(1)
finally:
    os.unlink(temp_file)

print("\n=== All validation tests passed! ===")
EOF

python /tmp/test_validation.py
report_result "PASS: validation tests"
services/triage/test_load_modelopt_state_validation.yaml
job_name: test_load_modelopt_state_validation
pipeline:
  task_0:
    script: services/triage/test_load_modelopt_state_validation.sh
    slurm_config:
      _factory_: "computelab_slurm_factory"
      nodes: 1
services/triage/test_load_modelopt_state_validation.sh
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "${SCRIPT_DIR}/../service_utils.sh"
trap 'error_handler $0 $LINENO' ERR
trap 'exit_handler' EXIT

cd modules/Model-Optimizer

# Run the test
python -m pytest tests/unit/torch/opt/test_chaining.py -v
report_result "PASS: test_chaining.py"

Auto-generated by pensieve /magic-triage agentic fix — please review before merging.

Summary by CodeRabbit

  • Bug Fixes

    • Improved validation when loading model state files. Invalid or corrupted states are now detected and reported with clear, specific error messages for missing fields, wrong types, or malformed entries to prevent silent failures during load.
  • Tests

    • Added unit tests covering valid and various invalid state file formats and ensuring errors are raised and reported consistently.

@copy-pr-bot

copy-pr-bot bot commented Mar 19, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 19, 2026

📝 Walkthrough

Added explicit runtime validation for loaded modelopt state objects via ModeloptStateManager.validate_modelopt_state() and invoked it from load_modelopt_state() immediately after torch.load, with TypeError/ValueError documented as possible raises.

Changes

Cohort / File(s) | Summary

State validation logic
modelopt/torch/opt/conversion.py
Added ModeloptStateManager.validate_modelopt_state(modelopt_state: Any) -> None that enforces: top-level dict with required keys modelopt_state_dict (a list of 2-tuples) and modelopt_version (str); checks tuple shapes and element types; updated load_modelopt_state() to call this validator and document TypeError/ValueError.

Validation tests
tests/unit/torch/opt/test_modelopt_state_validation.py
New unit tests covering successful validation and multiple failure modes (non-dict top-level, missing keys, wrong types for modelopt_version/modelopt_state_dict, invalid tuple shapes, wrong tuple element types). Also tests load_modelopt_state() using temporary .pt files saved/loaded with torch.save/torch.load.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title mentions 'Feature' and 'fix' and is somewhat redundant, but accurately describes the main change: adding validation for loaded modelopt state files.
Linked Issues check | ✅ Passed | The PR fully addresses all coding requirements from issue #1041: validates the loaded object is a dict, verifies required keys (modelopt_state_dict and modelopt_version), validates modelopt_state_dict structure, checks mode_name/mode_state types, and raises clear TypeError/ValueError messages.
Out of Scope Changes check | ✅ Passed | All changes are scoped to the validation feature: conversion.py adds the validation method and updates load_modelopt_state(), and the test file comprehensively tests the new validation logic with no extraneous changes.
Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns | ✅ Passed | PR properly addresses security requirements: torch.load(..., weights_only=False) includes an inline comment justifying safety via internally-generated files; no other security anti-patterns are present; robust schema validation added.


f"The file may not be a valid modelopt state file."
)

# Validate that the dictionary has the expected keys

Can we move this validation logic inside the ModelOptStateManager class?

Collaborator Author
Let me see if I can make it address the comment and update the PR.

Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
@ChenhanYu
Collaborator Author

/ok to test 87e6e8f

@ChenhanYu ChenhanYu marked this pull request as ready for review March 20, 2026 00:11
@ChenhanYu ChenhanYu requested a review from a team as a code owner March 20, 2026 00:11
@ChenhanYu ChenhanYu requested a review from realAsma March 20, 2026 00:11
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/opt/conversion.py`:
- Around line 355-358: The long ValueError message in conversion.py (within the
function that validates modelopt_state_dict entries where the raise ValueError
is emitted) violates line-length rules; shorten or split the message into
multiple concatenated strings or build it in a temporary variable so no single
source line exceeds the limit (for example, create a message variable using
multiple shorter f-strings or implicit string concatenation and then raise
ValueError(message)). Ensure you still include the dynamic parts (i,
type(entry).__name__, and the conditional length expression) so the raised error
text remains informative.
- Around line 333-345: The loader currently checks for the presence of
"modelopt_version" but not its type; update the validation after the
required-keys check to verify that modelopt_state["modelopt_version"] is a
string (e.g., isinstance(version, str)) and raise a ValueError with a clear
message if it is not; reference the symbol modelopt_state["modelopt_version"]
and ensure this check occurs before any call to load_state_dict() or any code
that calls version.split(...) so downstream failures are avoided.
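
Both findings are mechanical to address. A sketch of the two fixes (hypothetical helper names for illustration, not the actual patch in conversion.py):

```python
def validate_version(modelopt_state: dict) -> None:
    # Suggested type check for 'modelopt_version', run after the required-keys
    # check and before any code that calls version.split(...).
    version = modelopt_state["modelopt_version"]
    if not isinstance(version, str):
        raise ValueError(
            f"Expected 'modelopt_version' to be a string, got {type(version).__name__}."
        )


def entry_error_message(i: int, entry: object) -> str:
    # Build the long error text from short concatenated pieces so no single
    # source line exceeds the line-length limit, while keeping the dynamic
    # parts (index, type name, conditional length) informative.
    length = len(entry) if hasattr(entry, "__len__") else "n/a"
    return (
        "Expected each entry in 'modelopt_state_dict' to be a tuple of "
        f"length 2, but entry {i} has type {type(entry).__name__} "
        f"and length {length}."
    )
```

Building the message in a variable (or via implicit string concatenation, as here) satisfies the line-length rule without losing any detail from the raised error.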

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8ff4605f-26fa-4c04-babf-9279b69ffc7d

📥 Commits

Reviewing files that changed from the base of the PR and between 839fa3d and 87e6e8f.

📒 Files selected for processing (2)
  • modelopt/torch/opt/conversion.py
  • tests/unit/torch/opt/test_modelopt_state_validation.py

Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
Signed-off-by: Pensieve Bot <pensieve-bot@nvidia.com>
@ChenhanYu
Copy link
Collaborator Author

/ok to test 7289799

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/opt/conversion.py (1)

607-614: Security comment present; consider clarifying the trust assumption in documentation.

The inline comment at line 607 justifies weights_only=False by stating the file is "ModelOpt-generated state_dict, not untrusted user input." While this comment satisfies the coding guideline requirement, the function signature accepts any file path, meaning callers could potentially pass untrusted files.

The post-load validation (line 613) provides defense-in-depth by failing fast on malformed data, but note that arbitrary code execution via pickle can occur during torch.load() itself, before validation runs.

Consider adding a note in the docstring warning callers that this function assumes trusted input, or alternatively exposing weights_only as a caller-configurable parameter defaulting to a safer value.

Optional: Document trust assumption in docstring
 def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]:
     """Load the modelopt state from a file.
 
+    .. warning::
+        This function uses ``weights_only=False`` by default for ``torch.load()``.
+        Only use with trusted, ModelOpt-generated state files.
+
     Args:
         modelopt_state_path: Target file location.
         **kwargs: additional args for ``torch.load()``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/opt/conversion.py` around lines 607 - 614, The comment
justifying kwargs.setdefault("weights_only", False) is insufficient because
torch.load(modelopt_state_path, **kwargs) can execute code before validation;
update the surrounding function's docstring to explicitly state the function
assumes the supplied modelopt_state_path is trusted and may execute code during
torch.load, and/or change the function signature to accept a caller-configurable
weights_only parameter (defaulting to True for safety) and use that parameter
instead of hardcoding kwargs.setdefault("weights_only", False); ensure
references to torch.load and ModeloptStateManager.validate_modelopt_state remain
intact so callers and reviewers can find the load/validate sequence.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cc3fc581-39af-4e91-a131-6d6b9ae213f5

📥 Commits

Reviewing files that changed from the base of the PR and between 87e6e8f and 7289799.

📒 Files selected for processing (2)
  • modelopt/torch/opt/conversion.py
  • tests/unit/torch/opt/test_modelopt_state_validation.py



Development

Successfully merging this pull request may close these issues.

Feature: Add validation for loaded modelopt state files

2 participants