Skip to content

Add Quantization Aware Distillation (QAD) to Megatron-Bridge example#1600

Merged
kevalmorabia97 merged 9 commits into
mainfrom
kmorabia/megatron-bridge-qad
Jun 5, 2026
Merged

Add Quantization Aware Distillation (QAD) to Megatron-Bridge example#1600
kevalmorabia97 merged 9 commits into
mainfrom
kmorabia/megatron-bridge-qad

Conversation

@kevalmorabia97
Copy link
Copy Markdown
Collaborator

@kevalmorabia97 kevalmorabia97 commented Jun 2, 2026

What does this PR do?

Type of change: new example

Note: This is part 2 of 4 (builds on #1589):

Extends examples/megatron_bridge/distill.py to initialize the student from a Megatron checkpoint (a quantized checkpoint from quantize.py, or a pruned one) via --student_megatron_path, enabling Quantization Aware Distillation (QAD):

  • --student_hf_path still builds the student architecture; --student_megatron_path supplies the (optionally quantized) weights.
  • For a quantized checkpoint, the ModelOpt quantize mode + base weights are restored onto the plain student before the knowledge-distillation conversion (restore_sharded_modelopt_state is a no-op once a model is already converted), so the distilled checkpoint stays exportable as a quantized model with export.py.

Upstream dependency / workaround: DistillationProvider.provide() has no seam to transform the student before the KD conversion, so this patches provide() at the class level (via an id()-keyed registry, because the provider proxies instance-attribute assignment to its teacher once the teacher is set). A companion Megatron-Bridge PR adds a first-class DistillationProvider.student_pre_conversion_hook; from nemo:26.06 onwards the workaround should be removed and replaced with that hook (a removal note in distill.py documents exactly how).

Usage

# 1) PTQ -> quantized Megatron checkpoint (part 1)
torchrun --nproc_per_node 2 quantize.py \
    --hf_model_name_or_path Qwen/Qwen3-8B --quant_cfg fp8 --tp_size 2 \
    --export_megatron_path /tmp/Qwen3-8B-FP8-megatron

# 2) QAD: distill the quantized student from the unquantized teacher
torchrun --nproc_per_node 8 distill.py \
    --teacher_hf_path Qwen/Qwen3-8B \
    --student_hf_path Qwen/Qwen3-8B \
    --student_megatron_path /tmp/Qwen3-8B-FP8-megatron \
    --data_paths 1.0 tokenized/data_text_document \
    --train_iters 1000 --output_dir /output/qwen3_8b_qad

# 3) export the distilled quantized checkpoint (part 1)
torchrun --nproc_per_node 1 export.py \
    --hf_model_name_or_path Qwen/Qwen3-8B \
    --megatron_path /output/qwen3_8b_qad/checkpoints \
    --export_unified_hf_path /tmp/qwen3_8b_qad_fp8_hf

Testing

tests/examples/megatron_bridge/test_qad.py (validated on a 2-GPU NeMo 26.04 container): quantize a tiny Qwen3 at TP=2 → QAD distill from the quantized student → export.py to a unified HF checkpoint, asserting hf_quant_config.json is written (proves the quantize mode survived QAD). Includes a commented-out vLLM deployment check, validated locally (full flow passes; vLLM loads the export as quantization=modelopt). Existing normal/Puzzletron distillation tests still pass.

Before your PR is "Ready for review"

  • Is this change backward compatible?: N/A (new example feature; default behavior unchanged when --student_megatron_path is not set)
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A (no new dependencies)
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ✅

Additional Information

Depends on a companion Megatron-Bridge PR adding DistillationProvider.student_pre_conversion_hook (the upstream replacement for the class-level provide() workaround). The Nemotron-3 tutorial NVFP4 + QAD experiments ship in part 3.

Summary by CodeRabbit

  • New Features

    • Quantization Aware Distillation (QAD) workflow to recover accuracy of quantized Megatron students and distill from quantized checkpoints.
    • CLI option to initialize a distillation student from a Megatron checkpoint and a structure-only load path for bridging.
  • Documentation

    • Expanded runnable quantize → QAD → export guidance and best-practice tips.
  • Tests

    • End-to-end test validating quantize → QAD → export artifacts.
  • Chores / UX

    • Clearer rank-aware messages, improved tokenizer padding handling, and more consistent export behavior (fixed export dtype).

@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner June 2, 2026 12:46
@kevalmorabia97 kevalmorabia97 requested review from yueshen2016 and removed request for a team June 2, 2026 12:46
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jun 2, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Extends Megatron-Bridge distillation to support Quantization Aware Distillation (QAD) by allowing distill.py to initialize the student model from a quantized Megatron checkpoint. Adds CLI argument, monkeypatch-based restoration workflow, documentation, quantize/export adjustments, and an end-to-end test.

Changes

QAD Feature Implementation

Layer / File(s) Summary
QAD Documentation
CHANGELOG.rst, examples/megatron_bridge/README.md
CHANGELOG updated to reference QAD capability via distill.py extension. README adds TIP recommending QAD for post-quantization accuracy recovery, extends distillation instructions with --student_megatron_path parameter, and introduces "Quantization Aware Distillation (QAD)" subsection with teacher/student roles and example command.
Megatron Student Loading Mechanism
examples/megatron_bridge/distill.py
Adds imports and comments, creates a provider-id registry, and monkeypatches DistillationProvider.provide to build the student structure, restore ModelOpt state/weights from a Megatron checkpoint, then perform ModelOpt KD conversion.
Distill Script CLI and Main Wiring
examples/megatron_bridge/distill.py
Adds --student_megatron_path CLI argument. Modifies _build_model_provider() to accept load_weights. Main() conditionally skips HuggingFace student weight loading when Megatron checkpoint provided, detects ModelOpt state presence, disables gradient_accumulation_fusion when necessary, and registers distill_provider with checkpoint metadata for runtime restoration.
Export wiring to mbridge plugin
examples/megatron_bridge/export.py
Switches export flow to load_mbridge_model_from_hf/load_modelopt_megatron_checkpoint and export_mcore_gpt_to_hf with a fixed torch.bfloat16 dtype; updates example launcher and provider overrides (tensor/expert forced, --pp_size support).
Quantize Script Updates
examples/megatron_bridge/quantize.py
Adjusts example usage and calibration defaults/help (supported datasets), disables gradient_accumulation_fusion in provider overrides, adds GC/CUDA cache clearing before generation, reworks checkpoint save logging, and consolidates generation sanity-print formatting.
Megatron Plugin Warning Helpers
modelopt/torch/quantization/plugins/megatron.py
Replaces logging/warnings with rank-aware warn_rank_0 at multiple warning sites while preserving fallback behaviors.
Dataset Utilities Rank-Aware Logging
modelopt/torch/utils/dataset_utils.py
Replaces prints/warnings with print_rank_0/warn_rank_0 across dataset loading, calibration, HF JSONL download, tokenizer warnings, memory probes, and split processing messages.
Args Printing Order
modelopt/torch/utils/logging.py
print_args now prints argparse.Namespace fields in sorted key order.
mbridge Tokenizer & checkpoint helpers
modelopt/torch/utils/plugins/mbridge.py
Adds load_weights option to load_mbridge_model_from_hf, ensures tokenizer pad_token and padding_side='left', exports load_modelopt_megatron_checkpoint to restore ModelOpt state then load Megatron weights, and updates __all__.
End-to-End QAD Test & calib tweak
tests/examples/megatron_bridge/test_qad.py, tests/examples/megatron_bridge/test_quantize_export.py
Adds test_qad() end-to-end PTQ→QAD→export test and updates quantize test calib batch-size from 1 to 4.

Sequence Diagram

sequenceDiagram
  participant Test as test_qad
  participant Quantize as quantize.py
  participant MegatronCKPT as Megatron checkpoint (dir)
  participant Distill as distill.py
  participant Loader as load_modelopt_megatron_checkpoint
  participant Export as export.py
  participant HF as HuggingFace unified checkpoint

  Test->>Quantize: run quantize.py -> produce MegatronCKPT
  Test->>Distill: run distill.py --student_megatron_path MegatronCKPT --teacher_hf_path ...
  Distill->>Loader: restore ModelOpt state & weights into student
  Distill->>MegatronCKPT: write distilled Megatron checkpoints (modelopt_state preserved)
  Test->>Export: run export.py on distilled MegatronCKPT
  Export->>HF: export_mcore_gpt_to_hf -> produce HF checkpoint artifacts
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1589: Modifies Megatron-Bridge PTQ/export scripts and is directly related to quantize/export flows touched here.

Suggested reviewers

  • ChenhanYu
  • jenchen13
  • yueshen2016
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 65.38% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main feature addition: Quantization Aware Distillation (QAD) support for the Megatron-Bridge example, which is the core purpose of this PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed All security practices from SECURITY.md adhered: no unsafe torch/numpy load, no hardcoded trust_remote_code=True (all configurable, default False), no eval/exec, no # nosec, no new dependencies.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kmorabia/megatron-bridge-qad

Comment @coderabbitai help to get the list of available commands and usage tips.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch 2 times, most recently from 9ba6385 to 7787f98 Compare June 2, 2026 13:00
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 79.48718% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.89%. Comparing base (433b549) to head (6ee7c83).
⚠️ Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/utils/dataset_utils.py 76.47% 4 Missing ⚠️
modelopt/torch/quantization/plugins/megatron.py 50.00% 3 Missing ⚠️
modelopt/torch/utils/plugins/mbridge.py 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1600      +/-   ##
==========================================
+ Coverage   76.31%   76.89%   +0.58%     
==========================================
  Files         488      489       +1     
  Lines       54386    54415      +29     
==========================================
+ Hits        41503    41844     +341     
+ Misses      12883    12571     -312     
Flag Coverage Δ
examples 42.76% <46.15%> (+0.76%) ⬆️
gpu 58.40% <41.02%> (-1.48%) ⬇️
regression 14.88% <2.56%> (-0.24%) ⬇️
unit 54.00% <12.82%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 7787f98 to 4e22a59 Compare June 2, 2026 17:28
@kevalmorabia97
Copy link
Copy Markdown
Collaborator Author

/claude review

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found. LGTM

Reviewed Part 2 of the QAD-on-Megatron-Bridge series (extends distill.py with --student_megatron_path to load a quantized Megatron checkpoint as the student before the KD conversion, plus an end-to-end QAD test).

Findings: 0 CRITICAL · 0 IMPORTANT · 3 SUGGESTION (all inline)

The class-level provide() monkey-patch is well-documented as a 26.04 workaround with a clear removal path once student_pre_conversion_hook lands in 26.06; the test exercises the full quantize → QAD-distill → unified-HF-export flow and confirms modelopt_state survives. Suggestions are about local clarity, not behavior:

  • _restore_megatron_student: the strict=False rationale referencing "in-memory teacher weights" doesn't match the only call site (teacher isn't built yet at that point).
  • student_is_quantized / quantized: bool actually mean "checkpoint has any ModelOpt mode state"; safe today (prune_minitron strips its state, only quantize.py emits any), but the QAD-specific log message and gradient_accumulation_fusion = False would fire incorrectly if any other mode starts shipping state.
  • id(self)-keyed registry silently falls back to vanilla distillation if the framework ever wraps/copies the provider before provide() is called — consider asserting the lookup hit when --student_megatron_path was set so the failure is loud rather than producing an uninitialized-student run.

Comment thread examples/megatron_bridge/distill.py Outdated
Comment thread examples/megatron_bridge/distill.py Outdated
Comment thread examples/megatron_bridge/distill.py Outdated
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch 2 times, most recently from 70b1610 to 46056fa Compare June 2, 2026 18:35
Base automatically changed from kmorabia/megatron-bridge-quantize-export to main June 2, 2026 19:36
@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner June 2, 2026 19:36
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 46056fa to 2ed6c0c Compare June 2, 2026 19:49
@kevalmorabia97 kevalmorabia97 removed the request for review from a team June 2, 2026 19:50
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/megatron_bridge/distill.py`:
- Line 104: The parameter name has_modelopt_state in the function
_restore_megatron_student shadows the imported function has_modelopt_state;
rename the parameter (e.g., to modelopt_present or has_modelopt_flag) in
_restore_megatron_student, update all references inside that function to the new
parameter name, and update all call sites of _restore_megatron_student to pass
the renamed parameter variable so the imported has_modelopt_state function
remains callable.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 7cf45685-2015-4140-bfc0-7c93ab8c17e4

📥 Commits

Reviewing files that changed from the base of the PR and between f21977a and 2ed6c0c.

📒 Files selected for processing (4)
  • CHANGELOG.rst
  • examples/megatron_bridge/README.md
  • examples/megatron_bridge/distill.py
  • tests/examples/megatron_bridge/test_qad.py

Comment thread examples/megatron_bridge/distill.py Outdated
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 2ed6c0c to f0d1988 Compare June 2, 2026 20:09
Comment thread examples/megatron_bridge/distill.py Outdated

# _load_model_weights_from_checkpoint is the (private) helper bridge.load_megatron_model uses to load
# a (quantized) Megatron checkpoint into an already-built model; reused here to initialize the student.
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't you just use bridge.load_megatron_model?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of some checks in it which doesnt cause load_modelopt_state to be invoked at the righ time in case of quantized ckpt, we need a temporary workaround. @AAnoosheh is looking into the correct fix in megatron bridge side. Alternatively as we discussed in Nemo-ModelOpt meeting, ideally we will have mbridge natively support quantized ckpt then we wont need this workaround

Comment thread examples/megatron_bridge/distill.py Outdated
if restore_modelopt_state:
load_modelopt_state([student_model], str(ckpt_root))
print_rank_0(f"Loading student weights from Megatron checkpoint {ckpt_dir}")
# strict=False because the bridge loader strips Transformer-Engine extra-state from the loaded
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the mbridge loader strips extra-state from the sharded checkpoint, how are the amax values restored?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its restored from load_modelopt_state call above

@kevalmorabia97 kevalmorabia97 requested a review from jenchen13 June 4, 2026 08:47
@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner June 4, 2026 11:06
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 9d94730 to 87c33b5 Compare June 4, 2026 19:59
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners June 4, 2026 19:59
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch 2 times, most recently from 3147d66 to 82982cb Compare June 4, 2026 20:18
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name this file train.py and support both QAT and QAD similar to HF llm_qat folder?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is we dont have a good story for using QAT and instead recommend QAD. But if needed, I can change it in a follow-up PR to also support QAD

@realAsma
Copy link
Copy Markdown
Contributor

realAsma commented Jun 4, 2026

Can we add a link to this from https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_qat as performant QAT/QAD backend?

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 72013f2 to 67bc3a5 Compare June 5, 2026 06:57
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners June 5, 2026 06:57
@kevalmorabia97 kevalmorabia97 requested a review from realAsma June 5, 2026 06:57
Comment thread examples/megatron_bridge/README.md Outdated
import vllm

DEFAULT_PROMPTS = [
"Hello!",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be addressed in later PR but: these prompts are rather short, would be good to have some more long generation/agentic task

Comment thread examples/megatron_bridge/quantize.py Outdated
)
parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size")
parser.add_argument("--seq_length", type=int, default=4096, help="Calibration sequence length")
parser.add_argument("--seq_length", type=int, default=512, help="Calibration sequence length")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea we need longer seq length than 512, like 2048 or 4096. and also suggest a longer seq length for higher quality calibration

Comment thread examples/megatron_bridge/README.md Outdated
--quant_cfg fp8 \
--tp_size 2 \
--calib_batch_size 16 \
--seq_length 512 \
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are examples in nemotron-post-training really under 2k seq length? nemotron is always post trained with long seq length like 256k

kevalmorabia97 and others added 8 commits June 5, 2026 11:51
Extend examples/megatron_bridge/distill.py with --student_megatron_path to
initialize the student from a Megatron checkpoint (a quantized checkpoint from
quantize.py, or a pruned one) instead of HuggingFace weights; --student_hf_path
still builds the architecture.

For a quantized checkpoint, the ModelOpt quantize mode + base weights are
restored onto the plain student before the knowledge-distillation conversion
(restore_sharded_modelopt_state is a no-op once a model is already converted),
so the distilled checkpoint stays exportable as a quantized model with export.py.

Until nemo:26.06 (which adds DistillationProvider.student_pre_conversion_hook
upstream), this is done by patching DistillationProvider.provide at the class
level via an id()-keyed registry, since the provider proxies instance attribute
assignment to its teacher once the teacher is set. A removal note documents the
upstream-hook replacement.

Add tests/examples/megatron_bridge/test_qad.py covering quantize -> QAD ->
export, asserting hf_quant_config.json is written so the distilled checkpoint
stays exportable as a quantized model.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/megatron-bridge-qad branch from 67bc3a5 to 2b5ad6a Compare June 5, 2026 18:52
Copy link
Copy Markdown
Contributor

@jenchen13 jenchen13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@kevalmorabia97 kevalmorabia97 enabled auto-merge (squash) June 5, 2026 19:17
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 merged commit 54ce4e0 into main Jun 5, 2026
52 checks passed
@kevalmorabia97 kevalmorabia97 deleted the kmorabia/megatron-bridge-qad branch June 5, 2026 21:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants