[1/2Refactor] speculative decoding: use mto config subsystem#1328
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughTyped single-file speculative-decoding recipes and loader dotlist overrides are added; loader dispatches on Changes
Sequence Diagram(s)sequenceDiagram
actor User
participant Main as Speculative Main Script
participant Loader as load_recipe()
participant Parser as YAML Parser
participant Validator as Pydantic Validator
User->>Main: run with recipe_path + overrides
Main->>Loader: load_recipe(path, overrides)
Loader->>Parser: parse YAML file
Parser-->>Loader: parsed dict
alt overrides provided
Loader->>Loader: apply dotlist overrides (OmegaConf)
end
Loader->>Validator: select recipe class by metadata.recipe_type
Validator->>Validator: validate model / data / training sections
Validator-->>Loader: validated recipe instance
Loader-->>Main: return typed recipe
Main->>Main: construct HfTrainingArguments from recipe.training
Main->>Main: route conversion/execution by recipe type (eagle/dflash/medusa)
sequenceDiagram
participant Script as Main Script
participant Recipe as Recipe Config
participant HfArgs as HfTrainingArguments
participant Trainer as HF Trainer
Script->>Recipe: load_recipe(path)
Recipe-->>Script: typed recipe
Script->>HfArgs: HfTrainingArguments.from_dict(recipe.training.model_dump())
HfArgs->>HfArgs: infer dp_shard_size from WORLD_SIZE (if needed)
HfArgs->>HfArgs: validate speculative fields (training_seq_len, estimate_ar, ...)
HfArgs-->>Script: validated hf args
rect rgba(100,150,200,0.5)
Script->>Trainer: perform recipe-specific conversion & training
end
Trainer-->>Script: complete
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes 🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Comment |
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1328 +/- ##
==========================================
- Coverage 76.91% 76.89% -0.03%
==========================================
Files 473 474 +1
Lines 51439 51506 +67
==========================================
+ Hits 39566 39605 +39
- Misses 11873 11901 +28
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
153-186:⚠️ Potential issue | 🟠 MajorMedusa recipe skips
data_modulecreation, causing an undefined variable error.When
recipeisModelOptMedusaRecipe, the code at line 154 callsmtsp.convert()but then falls through to line 203-211 wheredata_moduleis only created forModelOptEagleRecipeorModelOptDFlashRecipe. This leavesdata_moduleundefined, causing aNameErrorat line 226 when passed toEagleTrainerWithAccLog.Proposed fix: Add Medusa to data_module creation or handle separately
print_rank_0("Loading dataset...") is_dflash = isinstance(recipe, ModelOptDFlashRecipe) - if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)): + if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)): data_module = make_speculative_data_module( tokenizer, recipe.data, train_len=training_args.training_seq_len, answer_only_loss=training_args.answer_only_loss, shift_labels=not is_dflash, ) + else: + raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 153 - 186, The Medusa branch (when recipe is ModelOptMedusaRecipe) calls mtsp.convert but never creates the data_module used later by EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe branch to either construct the same data_module as done for ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing data-module factory logic) or explicitly set data_module to an appropriate value (or None) and handle that case before calling EagleTrainerWithAccLog so data_module is always defined; refer to the branches around ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert, recipe.data, and EagleTrainerWithAccLog to place the fix.
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)
61-76: Field synchronization betweenHfTrainingArgumentsandSpecTrainingArgsshould be explicit.The docstring correctly notes these field sets "MUST stay in sync" with
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments. Consider adding a test or assertion to catch drift between these two definitions automatically.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 61 - 76, Add an automated check to ensure HfTrainingArguments and the plugin TrainingArguments stay in sync: write a small unit test or an import-time assertion that imports HfTrainingArguments (examples.speculative_decoding.main.HfTrainingArguments) and modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts their dataclass/field names (and optionally defaults/types), and fails if any field names or default values differ; place the check in tests (preferred) or at module import to catch drift early.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 96-130: The _fill_parallelism validator currently sets world_size
using int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields
0 on CPU-only machines and leads to dp_shard_size becoming 0; change that logic
to treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 153-186: The Medusa branch (when recipe is ModelOptMedusaRecipe)
calls mtsp.convert but never creates the data_module used later by
EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe
branch to either construct the same data_module as done for
ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing
data-module factory logic) or explicitly set data_module to an appropriate value
(or None) and handle that case before calling EagleTrainerWithAccLog so
data_module is always defined; refer to the branches around
ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert,
recipe.data, and EagleTrainerWithAccLog to place the fix.
---
Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 61-76: Add an automated check to ensure HfTrainingArguments and
the plugin TrainingArguments stay in sync: write a small unit test or an
import-time assertion that imports HfTrainingArguments
(examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 838ec6d9-91a7-46a6-a197-701d6268d76c
📒 Files selected for processing (7)
examples/speculative_decoding/main.pymodelopt/recipe/config.pymodelopt/recipe/loader.pymodelopt/torch/speculative/plugins/hf_training_args.pymodelopt_recipes/general/speculative_decoding/dflash.yamlmodelopt_recipes/general/speculative_decoding/eagle3.yamltests/unit/recipe/test_loader.py
|
As discussed in the meeting, let's strip off the config override part into a separate change. |
69f3c40 to
3bf010b
Compare
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 97-114: The nested config fields (model, data, training) currently
instantiate SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update
ModeloptField to accept a default_factory alternative (or remove the assertion
forcing default so callers can use pydantic.Field) so you can switch those
fields to lazy defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.
In `@modelopt/recipe/loader.py`:
- Around line 109-118: In _apply_dotlist, avoid eagerly resolving OmegaConf
interpolations by changing the OmegaConf.to_container call so it does not
resolve interpolations; specifically, update the call in function _apply_dotlist
to use resolve=False instead of resolve=True when converting the merged
OmegaConf to a plain dict so user-supplied overrides (e.g., ${...} patterns) are
preserved and not expanded before Pydantic validation.
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 90-106: The _fill_parallelism model_validator currently divides by
cp_size and can accept non-positive cp_size/dp_shard_size from user configs;
update _fill_parallelism (on TrainingArguments) to first validate that cp_size
is an int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also
verify computed world_size is >=1), and raise a clear ValueError (so Pydantic
surfaces it) if these checks fail; only after these guards compute world_size
and set self.dp_shard_size = world_size // self.cp_size when dp_shard_size is
None to avoid ZeroDivisionError and invalid runtime state.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 50f5aa9c-3c49-42a0-be51-0ab116444297
📒 Files selected for processing (7)
examples/speculative_decoding/main.pymodelopt/recipe/config.pymodelopt/recipe/loader.pymodelopt/torch/speculative/plugins/hf_training_args.pymodelopt_recipes/general/speculative_decoding/dflash.yamlmodelopt_recipes/general/speculative_decoding/eagle3.yamltests/unit/recipe/test_loader.py
✅ Files skipped from review due to trivial changes (1)
- modelopt_recipes/general/speculative_decoding/eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- tests/unit/recipe/test_loader.py
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/recipe/config.py (1)
98-115:⚠️ Potential issue | 🟠 Major | 🏗️ Heavy liftUse
default_factoryfor nested config defaults to avoid import-time instantiation.Line 99, Line 105, Line 111 (and similarly Line 124, Line 153, Line 171) instantiate nested config objects at import time via
default=...(). This is the same issue already raised earlier and is still unresolved. Prefer lazy construction withdefault_factory(which may require extendingModeloptFieldto pass throughdefault_factory).Suggested direction
-from pydantic import field_validator, model_validator +from pydantic import Field, field_validator, model_validator - model: SpecModelArgs = ModeloptField(default=SpecModelArgs(), ...) + model: SpecModelArgs = Field(default_factory=SpecModelArgs, ...) - data: SpecDataArgs = ModeloptField(default=SpecDataArgs(), ...) + data: SpecDataArgs = Field(default_factory=SpecDataArgs, ...) - training: SpecTrainingArgs = ModeloptField(default=SpecTrainingArgs(), ...) + training: SpecTrainingArgs = Field(default_factory=SpecTrainingArgs, ...)If
ModeloptFieldmust be retained, updateModeloptFieldto acceptdefault_factoryas an alternative todefault.#!/bin/bash # Verify eager constructor defaults in recipe config and whether ModeloptField supports default_factory. python - <<'PY' import ast, pathlib p = pathlib.Path("modelopt/recipe/config.py") tree = ast.parse(p.read_text()) for n in ast.walk(tree): if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "ModeloptField": for kw in n.keywords: if kw.arg == "default" and isinstance(kw.value, ast.Call): fn = kw.value.func name = getattr(fn, "id", getattr(fn, "attr", type(fn).__name__)) print(f"Line {n.lineno}: eager default via {name}()") PY rg -n -C2 "def ModeloptField|default_factory|PydanticUndefined|assert .*default" modelopt/torch/opt/config.pyAlso applies to: 123-124, 152-153, 170-171
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/recipe/config.py` around lines 98 - 115, The ModeloptField usages for the nested config attributes model, data, and training (and other spots noted) currently instantiate objects at import time via default=SpecModelArgs(), default=SpecDataArgs(), default=SpecTrainingArgs(), which causes eager construction; change these to use lazy construction by supporting and passing default_factory=SpecModelArgs, default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of calling the constructors, and if ModeloptField does not yet accept default_factory update the ModeloptField implementation to accept a default_factory kwarg, store/forward it to the underlying field construction logic (mirroring pydantic/dataclasses semantics), and ensure validate_default behavior still works with the factory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 27-32: The module currently performs hard imports of
speculative/HF classes DFlashConfig, EagleConfig, MedusaConfig and
SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time; change this to
lazy/plugin loading by moving those imports behind typing.TYPE_CHECKING for type
annotations and using the project's import_plugin() (or an equivalent lazy
import helper) where the classes are actually needed at runtime (e.g., in
factory functions or recipe registration). Keep type-only references via from
typing import TYPE_CHECKING and string annotations or if TYPE_CHECKING: import
the speculative classes, and replace direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.
---
Duplicate comments:
In `@modelopt/recipe/config.py`:
- Around line 98-115: The ModeloptField usages for the nested config attributes
model, data, and training (and other spots noted) currently instantiate objects
at import time via default=SpecModelArgs(), default=SpecDataArgs(),
default=SpecTrainingArgs(), which causes eager construction; change these to use
lazy construction by supporting and passing default_factory=SpecModelArgs,
default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of
calling the constructors, and if ModeloptField does not yet accept
default_factory update the ModeloptField implementation to accept a
default_factory kwarg, store/forward it to the underlying field construction
logic (mirroring pydantic/dataclasses semantics), and ensure validate_default
behavior still works with the factory.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 143fdb4d-cd06-457c-b8b0-cbace1bf21cc
📒 Files selected for processing (1)
modelopt/recipe/config.py
There was a problem hiding this comment.
Actionable comments posted: 3
♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)
186-209:⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
data_moduleis undefined for Medusa recipes and will crash trainer construction.At Line 186,
data_moduleis created only for Eagle/DFlash. ForModelOptMedusaRecipe, Line 208 expands**data_modulebefore assignment (UnboundLocalError).Suggested fix
- if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)): + if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)): data_module = make_speculative_data_module( tokenizer, recipe.data, train_len=training_args.training_seq_len, answer_only_loss=training_args.answer_only_loss, shift_labels=not is_dflash, ) + else: + raise ValueError(f"Unsupported speculative recipe type for dataset setup: {type(recipe).__name__}")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 186 - 209, The variable data_module is only set inside the isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later always expanded into EagleTrainerWithAccLog via **data_module, causing an UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is always defined before trainer construction—either initialize data_module = {} prior to the conditional or add an else branch that creates/returns the appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent) so that **data_module is safe for all recipe types (references: recipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe, make_speculative_data_module, EagleTrainerWithAccLog, **data_module).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 196-201: The code references eagle_cfg when deciding to append
LoRAWarmupCallback but eagle_cfg may be undefined in the checkpoint-resume path;
initialize or load eagle_cfg before that conditional so the .get calls are safe.
Concretely, ensure eagle_cfg is defined (e.g., set eagle_cfg = {} or populate it
from the checkpoint metadata in the resume branch) prior to the callback gating
logic that checks isinstance(recipe, ModelOptEagleRecipe) and
eagle_cfg.get(...), so
LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"]) only executes when
eagle_cfg exists and contains the key.
- Around line 109-113: The code mutates training_args.parallelism_config when
training_args.cp_size > 1 but HfTrainingArguments lacks that field; either add a
parallelism_config field and an initializer/validator to HfTrainingArguments
(similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 257-264: load_draft_vocab_cache currently assigns torch.load
directly to model.eagle_module.d2t which can cause device/dtype/shape
mismatches; instead, load the saved tensor using map_location to the
model.eagle_module.d2t device, validate that the loaded tensor's dtype and numel
match model.eagle_module.d2t (and that model.eagle_config.draft_vocab_size <
model.eagle_config.vocab_size condition holds), then copy the data into the
existing buffer in-place (e.g., via copy_()) so the registered buffer keeps its
device and shape; update load_draft_vocab_cache and references to
model.eagle_module.d2t accordingly.
---
Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 186-209: The variable data_module is only set inside the
isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later
always expanded into EagleTrainerWithAccLog via **data_module, causing an
UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is
always defined before trainer construction—either initialize data_module = {}
prior to the conditional or add an else branch that creates/returns the
appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent)
so that **data_module is safe for all recipe types (references: recipe,
ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe,
make_speculative_data_module, EagleTrainerWithAccLog, **data_module).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 16de0a2a-dbe7-453e-a197-0049f864b524
📒 Files selected for processing (2)
examples/speculative_decoding/main.pymodelopt/torch/speculative/plugins/hf_eagle.py
0915687 to
8719802
Compare
|
/claude review |
Claude review summaryFindings: CRITICAL: 1, IMPORTANT: 1, SUGGESTION: 1 Most impactful
Risk assessmentThe recipe-subsystem refactor is well-scoped and the new test coverage is solid (dotlist parsing, type validation, missing-section errors). The CRITICAL finding above is a single-line regression that breaks a real user path ( The schema duplication between |
7ea8ad4 to
55b270c
Compare
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
55b270c to
4b44e4c
Compare
What does this PR do?
Type of change: new feature
Port the speculative-decoding example to ModelOpt's recipe/config subsystem:
model/data/training/<algo>now load from a single YAML with Pydantic validation and OmegaConf dotlist overrides. Adds built-ineagle3/dflashrecipes, drops the redundanttraining.modefield (inferred from recipe class), and shrinksmain.pyby ~145 lines (−208 / +63).JIRA: OMNIML-3859
Usage
python main.py --config general/speculative_decoding/eagle3 \ model.model_name_or_path=meta-llama/Llama-3.2-1B \ data.data_path=train.jsonl \ training.output_dir=ckpts/testTesting
pytest tests/unit/recipe/test_loader.py— new coverage for Eagle / DFlash YAML loading, dotlist overrides, and field-level validation.eagle3anddflashrecipes end-to-end.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).main.pyCLI switched to--config <recipe>(+ dotlist overrides); the old argparse flags are removed.CONTRIBUTING.md: N/A — no new deps (pydantic,omegaconfalready in core).tests/unit/recipe/test_loader.py.Additional Information
Follow-up to the
modelopt.recipesubsystem introduced for PTQ; this PR extends the same declarative-YAML pattern to speculative decoding (Eagle3 / DFlash / Medusa).Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation