Skip to content

[1/2Refactor] speculative decoding: use mto config subsystem#1328

Merged
h-guo18 merged 2 commits into
mainfrom
haoguo/spec-mto-config
May 17, 2026
Merged

[1/2Refactor] speculative decoding: use mto config subsystem#1328
h-guo18 merged 2 commits into
mainfrom
haoguo/spec-mto-config

Conversation

@h-guo18
Copy link
Copy Markdown
Contributor

@h-guo18 h-guo18 commented Apr 23, 2026

What does this PR do?

Type of change: new feature

Port the speculative-decoding example to ModelOpt's recipe/config subsystem: model / data / training / <algo> now load from a single YAML with Pydantic validation and OmegaConf dotlist overrides. Adds built-in eagle3 / dflash recipes, drops the redundant training.mode field (inferred from recipe class), and shrinks main.py by ~145 lines (−208 / +63).

JIRA: OMNIML-3859

Usage

python main.py --config general/speculative_decoding/eagle3 \
    model.model_name_or_path=meta-llama/Llama-3.2-1B \
    data.data_path=train.jsonl \
    training.output_dir=ckpts/test

Testing

  • pytest tests/unit/recipe/test_loader.py — new coverage for Eagle / DFlash YAML loading, dotlist overrides, and field-level validation.
  • Smoke-trained both built-in eagle3 and dflash recipes end-to-end.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ — main.py CLI switched to --config <recipe> (+ dotlist overrides); the old argparse flags are removed.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A — no new deps (pydantic, omegaconf already in core).
  • Did you write any new necessary tests?: ✅ — tests/unit/recipe/test_loader.py.
  • Did you update Changelog?: ❌ — to be added.

Additional Information

Follow-up to the modelopt.recipe subsystem introduced for PTQ; this PR extends the same declarative-YAML pattern to speculative decoding (Eagle3 / DFlash / Medusa).

Summary by CodeRabbit

  • New Features

    • Added typed speculative-decoding recipe support for EAGLE, DFlash, and Medusa; CLI dotlist overrides supported for single-file recipes.
    • Trainer/config schema extended with speculative-training fields and draft-vocab cache loading for Eagle.
  • Bug Fixes

    • Offline training no longer mutates model configs; loader enforces required algorithm sections and prints recipe/config only on the primary process.
    • Reduced noisy per-rank logging by restricting status output to the primary process.
  • Tests

    • Expanded tests for recipe loading, dotlist overrides, validation strictness, and error cases.
  • Documentation

    • Recipe YAMLs updated with metadata and usage notes.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 23, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Typed single-file speculative-decoding recipes and loader dotlist overrides are added; loader dispatches on metadata.recipe_type to create validated EAGLE/DFlash/Medusa recipe objects with typed model, data, and HF training sections. HF training args schema, eagle draft-vocab loader, and the example script were updated to consume the typed recipe.

Changes

Cohort / File(s) Summary
Recipe Configuration
modelopt/recipe/config.py
Adds SPECULATIVE_EAGLE, SPECULATIVE_DFLASH, SPECULATIVE_MEDUSA and new recipe models: ModelOptSpeculativeRecipeBase, ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe with typed model, data, training sections and algorithm-specific fields; offline flags set at recipe level and some cross-field validation added.
Recipe Loading
modelopt/recipe/loader.py
load_recipe(..., overrides=None) added; supports CLI-style dotlist overrides for single-file recipes (applied via OmegaConf), rejects overrides for directory-format recipes, and instantiates concrete speculative recipe classes based on metadata.recipe_type, erroring when required top-level algorithm sections are missing.
HF Training Args Schema
modelopt/torch/speculative/plugins/hf_training_args.py
New Pydantic-only schema module providing ModelArguments, DataArguments, TrainingArguments (permits HF Trainer extras) and speculative fields (training_seq_len, estimate_ar, ar_validate_steps, answer_only_loss, cp_size, dp_shard_size, parallelism_config) with runtime validation to infer dp_shard_size and optionally populate parallelism_config.
Speculative Torch Config Changes
modelopt/torch/speculative/config.py
Removes context-based auto-derivation validators for dflash_offline/eagle_offline, defers mask-token required checks when tokenizer absent, removes rope-vs-training-seq warning; field docs updated to indicate offline flags are recipe-derived.
HF EAGLE Plugin
modelopt/torch/speculative/plugins/hf_eagle.py
Replaced unqualified prints with print_rank_0 and added HFEagleModel.load_draft_vocab_cache(model, d2t_path: str) to validate and conditionally load a draft-vocab cache tensor into model.eagle_module.d2t.
Example Script
examples/speculative_decoding/main.py
Refactors script to load typed recipe via load_recipe() (with overrides), build HfTrainingArguments from recipe.training.model_dump(), route conversion by concrete recipe type (Medusa/Eagle/DFlash), remove prior mode-based dispatch and auto-parallelism mutation, and print recipe/config only on master.
Recipe YAMLs
modelopt_recipes/general/speculative_decoding/dflash.yaml, modelopt_recipes/general/speculative_decoding/eagle3.yaml
Converted to full modelopt recipes: add metadata with recipe_type and description, update header comments, and remove explicit training.mode fields.
Tests
tests/unit/recipe/test_loader.py
Expanded tests: _apply_dotlist unit tests (parsing, scientific notation, nested creation, immutability, error cases), load_recipe(..., overrides=...) behavior, positive tests for loading typed EAGLE/DFlash recipes, and negative tests for missing algorithm sections and invalid typed fields.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Main as Speculative Main Script
    participant Loader as load_recipe()
    participant Parser as YAML Parser
    participant Validator as Pydantic Validator

    User->>Main: run with recipe_path + overrides
    Main->>Loader: load_recipe(path, overrides)
    Loader->>Parser: parse YAML file
    Parser-->>Loader: parsed dict
    alt overrides provided
        Loader->>Loader: apply dotlist overrides (OmegaConf)
    end
    Loader->>Validator: select recipe class by metadata.recipe_type
    Validator->>Validator: validate model / data / training sections
    Validator-->>Loader: validated recipe instance
    Loader-->>Main: return typed recipe
    Main->>Main: construct HfTrainingArguments from recipe.training
    Main->>Main: route conversion/execution by recipe type (eagle/dflash/medusa)
Loading
sequenceDiagram
    participant Script as Main Script
    participant Recipe as Recipe Config
    participant HfArgs as HfTrainingArguments
    participant Trainer as HF Trainer

    Script->>Recipe: load_recipe(path)
    Recipe-->>Script: typed recipe
    Script->>HfArgs: HfTrainingArguments.from_dict(recipe.training.model_dump())
    HfArgs->>HfArgs: infer dp_shard_size from WORLD_SIZE (if needed)
    HfArgs->>HfArgs: validate speculative fields (training_seq_len, estimate_ar, ...)
    HfArgs-->>Script: validated hf args
    rect rgba(100,150,200,0.5)
        Script->>Trainer: perform recipe-specific conversion & training
    end
    Trainer-->>Script: complete
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed Changes introduce only torch.load calls with weights_only=True, no unsafe pickle, numpy.load with allow_pickle, hardcoded trust_remote_code, eval/exec, nosec bypasses, or non-permissive dependencies.
Title check ✅ Passed The title accurately reflects the main objective of refactoring speculative decoding to use the ModelOpt config subsystem, which is the primary focus of the PR across all changed files.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch haoguo/spec-mto-config

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 23, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-17 00:29 UTC

@h-guo18 h-guo18 changed the title mto config subsystem speculative decoding: use mto config subsystem Apr 23, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 23, 2026

Codecov Report

❌ Patch coverage is 90.00000% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.89%. Comparing base (81c4fb2) to head (522a8c1).

Files with missing lines Patch % Lines
modelopt/torch/speculative/plugins/hf_eagle.py 43.75% 9 Missing ⚠️
modelopt/recipe/loader.py 93.75% 2 Missing ⚠️
...lopt/torch/speculative/plugins/hf_training_args.py 96.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1328      +/-   ##
==========================================
- Coverage   76.91%   76.89%   -0.03%     
==========================================
  Files         473      474       +1     
  Lines       51439    51506      +67     
==========================================
+ Hits        39566    39605      +39     
- Misses      11873    11901      +28     
Flag Coverage Δ
examples 41.73% <81.66%> (+1.02%) ⬆️
gpu 59.74% <60.00%> (-0.60%) ⬇️
regression 15.22% <76.66%> (+0.31%) ⬆️
unit 52.64% <88.33%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment thread examples/speculative_decoding/main.py
Comment thread examples/speculative_decoding/main.py
@h-guo18 h-guo18 self-assigned this Apr 23, 2026
@h-guo18 h-guo18 marked this pull request as ready for review April 23, 2026 22:57
@h-guo18 h-guo18 requested review from a team as code owners April 23, 2026 22:57
@h-guo18 h-guo18 requested review from sychen52 and yeyu-nvidia April 23, 2026 22:57
@h-guo18 h-guo18 changed the title speculative decoding: use mto config subsystem [Refactor] speculative decoding: use mto config subsystem Apr 23, 2026
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

153-186: ⚠️ Potential issue | 🟠 Major

Medusa recipe skips data_module creation, causing an undefined variable error.

When recipe is ModelOptMedusaRecipe, the code at line 154 calls mtsp.convert() but then falls through to line 203-211 where data_module is only created for ModelOptEagleRecipe or ModelOptDFlashRecipe. This leaves data_module undefined, causing a NameError at line 226 when passed to EagleTrainerWithAccLog.

Proposed fix: Add Medusa to data_module creation or handle separately
     print_rank_0("Loading dataset...")
     is_dflash = isinstance(recipe, ModelOptDFlashRecipe)
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 153 - 186, The Medusa
branch (when recipe is ModelOptMedusaRecipe) calls mtsp.convert but never
creates the data_module used later by EagleTrainerWithAccLog, causing a
NameError; update the ModelOptMedusaRecipe branch to either construct the same
data_module as done for ModelOptEagleRecipe/ModelOptDFlashRecipe (using
recipe.data and the existing data-module factory logic) or explicitly set
data_module to an appropriate value (or None) and handle that case before
calling EagleTrainerWithAccLog so data_module is always defined; refer to the
branches around ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe,
mtsp.convert, recipe.data, and EagleTrainerWithAccLog to place the fix.
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)

61-76: Field synchronization between HfTrainingArguments and SpecTrainingArgs should be explicit.

The docstring correctly notes these field sets "MUST stay in sync" with modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments. Consider adding a test or assertion to catch drift between these two definitions automatically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 61 - 76, Add an automated
check to ensure HfTrainingArguments and the plugin TrainingArguments stay in
sync: write a small unit test or an import-time assertion that imports
HfTrainingArguments (examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 96-130: The _fill_parallelism validator currently sets world_size
using int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields
0 on CPU-only machines and leads to dp_shard_size becoming 0; change that logic
to treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 153-186: The Medusa branch (when recipe is ModelOptMedusaRecipe)
calls mtsp.convert but never creates the data_module used later by
EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe
branch to either construct the same data_module as done for
ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing
data-module factory logic) or explicitly set data_module to an appropriate value
(or None) and handle that case before calling EagleTrainerWithAccLog so
data_module is always defined; refer to the branches around
ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert,
recipe.data, and EagleTrainerWithAccLog to place the fix.

---

Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 61-76: Add an automated check to ensure HfTrainingArguments and
the plugin TrainingArguments stay in sync: write a small unit test or an
import-time assertion that imports HfTrainingArguments
(examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 838ec6d9-91a7-46a6-a197-701d6268d76c

📥 Commits

Reviewing files that changed from the base of the PR and between c796611 and 69f3c40.

📒 Files selected for processing (7)
  • examples/speculative_decoding/main.py
  • modelopt/recipe/config.py
  • modelopt/recipe/loader.py
  • modelopt/torch/speculative/plugins/hf_training_args.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
  • tests/unit/recipe/test_loader.py

Comment thread modelopt/torch/speculative/plugins/hf_training_args.py Outdated
@shengliangxu
Copy link
Copy Markdown
Collaborator

As discussed in the meeting, let's strip off the config override part into a separate change.

Comment thread modelopt/torch/speculative/plugins/hf_training_args.py Outdated
@h-guo18 h-guo18 force-pushed the haoguo/spec-mto-config branch from 69f3c40 to 3bf010b Compare April 30, 2026 00:46
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 97-114: The nested config fields (model, data, training) currently
instantiate SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update
ModeloptField to accept a default_factory alternative (or remove the assertion
forcing default so callers can use pydantic.Field) so you can switch those
fields to lazy defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.

In `@modelopt/recipe/loader.py`:
- Around line 109-118: In _apply_dotlist, avoid eagerly resolving OmegaConf
interpolations by changing the OmegaConf.to_container call so it does not
resolve interpolations; specifically, update the call in function _apply_dotlist
to use resolve=False instead of resolve=True when converting the merged
OmegaConf to a plain dict so user-supplied overrides (e.g., ${...} patterns) are
preserved and not expanded before Pydantic validation.

In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 90-106: The _fill_parallelism model_validator currently divides by
cp_size and can accept non-positive cp_size/dp_shard_size from user configs;
update _fill_parallelism (on TrainingArguments) to first validate that cp_size
is an int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also
verify computed world_size is >=1), and raise a clear ValueError (so Pydantic
surfaces it) if these checks fail; only after these guards compute world_size
and set self.dp_shard_size = world_size // self.cp_size when dp_shard_size is
None to avoid ZeroDivisionError and invalid runtime state.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 50f5aa9c-3c49-42a0-be51-0ab116444297

📥 Commits

Reviewing files that changed from the base of the PR and between 69f3c40 and 3bf010b.

📒 Files selected for processing (7)
  • examples/speculative_decoding/main.py
  • modelopt/recipe/config.py
  • modelopt/recipe/loader.py
  • modelopt/torch/speculative/plugins/hf_training_args.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
  • tests/unit/recipe/test_loader.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • tests/unit/recipe/test_loader.py

Comment thread examples/speculative_decoding/main.py
Comment thread modelopt/recipe/config.py
Comment thread modelopt/recipe/loader.py Outdated
Comment thread modelopt/torch/speculative/plugins/hf_training_args.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/recipe/config.py (1)

98-115: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Use default_factory for nested config defaults to avoid import-time instantiation.

Line 99, Line 105, Line 111 (and similarly Line 124, Line 153, Line 171) instantiate nested config objects at import time via default=...(). This is the same issue already raised earlier and is still unresolved. Prefer lazy construction with default_factory (which may require extending ModeloptField to pass through default_factory).

Suggested direction
-from pydantic import field_validator, model_validator
+from pydantic import Field, field_validator, model_validator

-    model: SpecModelArgs = ModeloptField(default=SpecModelArgs(), ...)
+    model: SpecModelArgs = Field(default_factory=SpecModelArgs, ...)

-    data: SpecDataArgs = ModeloptField(default=SpecDataArgs(), ...)
+    data: SpecDataArgs = Field(default_factory=SpecDataArgs, ...)

-    training: SpecTrainingArgs = ModeloptField(default=SpecTrainingArgs(), ...)
+    training: SpecTrainingArgs = Field(default_factory=SpecTrainingArgs, ...)

If ModeloptField must be retained, update ModeloptField to accept default_factory as an alternative to default.

#!/bin/bash
# Verify eager constructor defaults in recipe config and whether ModeloptField supports default_factory.
python - <<'PY'
import ast, pathlib
p = pathlib.Path("modelopt/recipe/config.py")
tree = ast.parse(p.read_text())
for n in ast.walk(tree):
    if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "ModeloptField":
        for kw in n.keywords:
            if kw.arg == "default" and isinstance(kw.value, ast.Call):
                fn = kw.value.func
                name = getattr(fn, "id", getattr(fn, "attr", type(fn).__name__))
                print(f"Line {n.lineno}: eager default via {name}()")
PY

rg -n -C2 "def ModeloptField|default_factory|PydanticUndefined|assert .*default" modelopt/torch/opt/config.py

Also applies to: 123-124, 152-153, 170-171

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 98 - 115, The ModeloptField usages
for the nested config attributes model, data, and training (and other spots
noted) currently instantiate objects at import time via default=SpecModelArgs(),
default=SpecDataArgs(), default=SpecTrainingArgs(), which causes eager
construction; change these to use lazy construction by supporting and passing
default_factory=SpecModelArgs, default_factory=SpecDataArgs,
default_factory=SpecTrainingArgs instead of calling the constructors, and if
ModeloptField does not yet accept default_factory update the ModeloptField
implementation to accept a default_factory kwarg, store/forward it to the
underlying field construction logic (mirroring pydantic/dataclasses semantics),
and ensure validate_default behavior still works with the factory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 27-32: The module currently performs hard imports of
speculative/HF classes DFlashConfig, EagleConfig, MedusaConfig and
SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time; change this to
lazy/plugin loading by moving those imports behind typing.TYPE_CHECKING for type
annotations and using the project's import_plugin() (or an equivalent lazy
import helper) where the classes are actually needed at runtime (e.g., in
factory functions or recipe registration). Keep type-only references via from
typing import TYPE_CHECKING and string annotations or if TYPE_CHECKING: import
the speculative classes, and replace direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.

---

Duplicate comments:
In `@modelopt/recipe/config.py`:
- Around line 98-115: The ModeloptField usages for the nested config attributes
model, data, and training (and other spots noted) currently instantiate objects
at import time via default=SpecModelArgs(), default=SpecDataArgs(),
default=SpecTrainingArgs(), which causes eager construction; change these to use
lazy construction by supporting and passing default_factory=SpecModelArgs,
default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of
calling the constructors, and if ModeloptField does not yet accept
default_factory update the ModeloptField implementation to accept a
default_factory kwarg, store/forward it to the underlying field construction
logic (mirroring pydantic/dataclasses semantics), and ensure validate_default
behavior still works with the factory.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 143fdb4d-cd06-457c-b8b0-cbace1bf21cc

📥 Commits

Reviewing files that changed from the base of the PR and between 23094e4 and d3ffd94.

📒 Files selected for processing (1)
  • modelopt/recipe/config.py

Comment thread modelopt/recipe/config.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)

186-209: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

data_module is undefined for Medusa recipes and will crash trainer construction.

At Line 186, data_module is created only for Eagle/DFlash. For ModelOptMedusaRecipe, Line 208 expands **data_module before assignment (UnboundLocalError).

Suggested fix
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type for dataset setup: {type(recipe).__name__}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 186 - 209, The variable
data_module is only set inside the isinstance(recipe, (ModelOptEagleRecipe,
ModelOptDFlashRecipe)) branch but later always expanded into
EagleTrainerWithAccLog via **data_module, causing an UnboundLocalError for
ModelOptMedusaRecipe; fix by ensuring data_module is always defined before
trainer construction—either initialize data_module = {} prior to the conditional
or add an else branch that creates/returns the appropriate Medusa data module
(e.g., a make_medusa_data_module or equivalent) so that **data_module is safe
for all recipe types (references: recipe, ModelOptEagleRecipe,
ModelOptDFlashRecipe, ModelOptMedusaRecipe, make_speculative_data_module,
EagleTrainerWithAccLog, **data_module).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 196-201: The code references eagle_cfg when deciding to append
LoRAWarmupCallback but eagle_cfg may be undefined in the checkpoint-resume path;
initialize or load eagle_cfg before that conditional so the .get calls are safe.
Concretely, ensure eagle_cfg is defined (e.g., set eagle_cfg = {} or populate it
from the checkpoint metadata in the resume branch) prior to the callback gating
logic that checks isinstance(recipe, ModelOptEagleRecipe) and
eagle_cfg.get(...), so
LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"]) only executes when
eagle_cfg exists and contains the key.
- Around line 109-113: The code mutates training_args.parallelism_config when
training_args.cp_size > 1 but HfTrainingArguments lacks that field; either add a
parallelism_config field and an initializer/validator to HfTrainingArguments
(similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.

In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 257-264: load_draft_vocab_cache currently assigns torch.load
directly to model.eagle_module.d2t which can cause device/dtype/shape
mismatches; instead, load the saved tensor using map_location to the
model.eagle_module.d2t device, validate that the loaded tensor's dtype and numel
match model.eagle_module.d2t (and that model.eagle_config.draft_vocab_size <
model.eagle_config.vocab_size condition holds), then copy the data into the
existing buffer in-place (e.g., via copy_()) so the registered buffer keeps its
device and shape; update load_draft_vocab_cache and references to
model.eagle_module.d2t accordingly.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 186-209: The variable data_module is only set inside the
isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later
always expanded into EagleTrainerWithAccLog via **data_module, causing an
UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is
always defined before trainer construction—either initialize data_module = {}
prior to the conditional or add an else branch that creates/returns the
appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent)
so that **data_module is safe for all recipe types (references: recipe,
ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe,
make_speculative_data_module, EagleTrainerWithAccLog, **data_module).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 16de0a2a-dbe7-453e-a197-0049f864b524

📥 Commits

Reviewing files that changed from the base of the PR and between d3ffd94 and a2afa7e.

📒 Files selected for processing (2)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/plugins/hf_eagle.py

Comment thread examples/speculative_decoding/main.py
Comment thread examples/speculative_decoding/main.py Outdated
Comment thread modelopt/torch/speculative/plugins/hf_eagle.py Outdated
@h-guo18 h-guo18 force-pushed the haoguo/spec-mto-config branch from 0915687 to 8719802 Compare May 7, 2026 22:24
@h-guo18 h-guo18 changed the title [Refactor] speculative decoding: use mto config subsystem [1/2Refactor] speculative decoding: use mto config subsystem May 7, 2026
Copy link
Copy Markdown
Collaborator

@shengliangxu shengliangxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall LGTM

@kevalmorabia97
Copy link
Copy Markdown
Collaborator

/claude review

@claude
Copy link
Copy Markdown

claude Bot commented May 15, 2026

Claude review summary

Findings: CRITICAL: 1, IMPORTANT: 1, SUGGESTION: 1

Most impactful

  • eagle_cfg regression on resume-from-checkpoint (examples/speculative_decoding/main.py:236-239): eagle_cfg is only assigned inside the not-checkpoint branch (line 190) but is unconditionally read after the branch. EAGLE runs that resume from a checkpoint will hit NameError. Fix is to read recipe.eagle.eagle_base_lora{,_warmup_steps} directly. Adding a checkpoint-resume test for EAGLE would prevent this class of regression.
  • recipe_type field/metadata can diverge (modelopt/recipe/config.py:105-206): subclasses now declare recipe_type as a Pydantic field while the parent class still derives it as a property from metadata["recipe_type"]. With the default metadata carrying RecipeType.PTQ, instantiating any speculative subclass without an explicit metadata override produces inconsistent views. The loader avoids the problem only because it pre-dispatches via _peek_recipe_type; downstream/test usage will hit it.

Risk assessment

The recipe-subsystem refactor is well-scoped and the new test coverage is solid (dotlist parsing, type validation, missing-section errors). The CRITICAL finding above is a single-line regression that breaks a real user path (resume_from_checkpoint) and should block merge until fixed and covered by a test. The IMPORTANT finding is a latent inconsistency that will bite anyone constructing recipes in code; worth fixing in this PR or as an immediate follow-up.

The schema duplication between HfTrainingArguments (example) and SpecTrainingArgs (recipe) is a smaller maintenance concern flagged as a SUGGESTION.

Comment thread examples/speculative_decoding/main.py Outdated
Comment thread modelopt/recipe/config.py
Comment thread examples/speculative_decoding/main.py Outdated
@h-guo18 h-guo18 force-pushed the haoguo/spec-mto-config branch from 7ea8ad4 to 55b270c Compare May 16, 2026 04:05
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/spec-mto-config branch from 55b270c to 4b44e4c Compare May 16, 2026 04:06
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 merged commit 7038dec into main May 17, 2026
48 checks passed
@h-guo18 h-guo18 deleted the haoguo/spec-mto-config branch May 17, 2026 00:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants