Skip to content

fix: preserve inlined MTP layers for GLM5#1532

Merged
Fridah-nv merged 5 commits into
mainfrom
fridah/glm5.1-mtp
May 22, 2026
Merged

fix: preserve inlined MTP layers for GLM5#1532
Fridah-nv merged 5 commits into
mainfrom
fridah/glm5.1-mtp

Conversation

@Fridah-nv
Copy link
Copy Markdown
Contributor

@Fridah-nv Fridah-nv commented May 22, 2026

What does this PR do?

Type of change: Bug fix

Extends load_mtp_weights to detect inlined MTP layers — keys model.layers.{i}.* for i in [num_hidden, num_hidden + num_nextn_predict_layers) — in addition to the existing mtp.* separate-file convention.

Bug. load_mtp_weights() only matched the substring "mtp" in safetensors keys. GLM-5.1 (GlmMoeDsaForCausalLM) stores MTP at model.layers.78.* with no mtp substring, so detection returned ([], {}), _mtp_layer_prefixes was never set, and MTP tensors were silently dropped from the exported safetensors (had to be re-added manually).

Detection.

  1. Detect via config.num_nextn_predict_layers (the model's own declaration of how many MTP layers exist).
  2. Compute the inlined layer indices: model.layers.{i} for i in range(num_hidden, num_hidden + num_nextn).
  3. Load matching tensors from the on-disk shards via safe_open (walks model.safetensors.index.json if present,else falls back to the single shard).
  4. Split the loaded tensors by whether model.state_dict() has a slot for them:
    • keys present in model.state_dict()model.load_state_dict(..., strict=False) (DeepSeek-V3 case: HF instantiates the extra layers).
    • keys absent from model.state_dict() → returned as not_in_state_dict so the exporter routes them through extra_state_dict (GLM-5.1
      case: GlmMoeDsaModel in transformers ≥5.7 only builds num_hidden decoders, leaving MTP keys orphaned at from_pretrained time).
      The returned prefixes flow into the existing plumbing — _mtp_layer_prefixesquant_cfg disable +quantization_config.exclude_modules — unchanged.

Usage

# Add a code snippet demonstrating how to use this

Testing

Verified end-to-end on a mini GLM-5.1 fixture (4 hidden layers + 1 inlined MTP at model.layers.4, 7 synthesized MTP tensors mirroring the full GLM-5.1 layout)
To be verified with full model

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Improvements
    • Quantization utilities now stream safetensors and unify loading of multi-token-prediction (MTP) weights from both inline and separate/sharded conventions, reporting detected MTP prefixes and counts of loaded vs orphaned tensors.
  • Tests
    • Added unit tests and a test import helper covering MTP discovery, loading behaviors (inlined vs standalone/indexed shards), orphan reporting, and non‑MTP checkpoints.

Review Change Stack

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv requested a review from a team as a code owner May 22, 2026 00:42
@Fridah-nv Fridah-nv requested review from Edwardf0t1 and meenchen May 22, 2026 00:42
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds streaming safetensors support and helpers to detect and load MTP tensors stored either inline in model shards or in separate shard files, splits loaded tensors into model state-dict entries and orphaned tensors, and includes unit tests and a test utility module.

Changes

MTP Weight Loading Enhancement

Layer / File(s) Summary
Imports, inlined-prefix computation, and streaming loader
examples/llm_ptq/example_utils.py
Replaced safetensors loading with safetensors.safe_open (CPU streaming); added get_inlined_mtp_prefixes(config) and _load_tensors_matching() to scan indexed shards and standalone *.safetensors files using a predicate to collect matching tensors.
State-dict application helper
examples/llm_ptq/example_utils.py
Added _apply_to_model_state_dict() to partition streamed tensors into those present in model.state_dict() (loaded with strict=False) and orphaned tensors returned for later export merging.
Refactored load_mtp_weights orchestrator
examples/llm_ptq/example_utils.py
Rewrote load_mtp_weights() to unify inlined and separate-file MTP handling via the predicate-based loader, compute exclusion prefixes from config and loaded keys, apply matching tensors to the model, and return (sorted_prefixes, orphan_tensors) while updating detection logging with counts.
Test utilities and unit tests
tests/_test_utils/examples/llm_ptq_example_utils.py, tests/examples/llm_ptq/test_example_utils.py
Added a test helper module that re-exports example_utils for tests and unit tests verifying MTP detection/loading across inlined-orphaned, inlined-in-state, standalone mtp.safetensors, indexed shard layouts, and non-MTP checkpoints (including regression where num_nextn_predict_layers=None).

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 35.29% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'fix: preserve inlined MTP layers for GLM5' accurately describes the main change—extending load_mtp_weights to detect and preserve MTP tensors inlined in checkpoints (model.layers.) rather than only standalone mtp. files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed Uses safetensors.safe_open for secure deserialization, no torch.load/pickle, no hardcoded trust_remote_code, no eval/exec, no # nosec comments, no new unsafe dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch fridah/glm5.1-mtp

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Line 415: The line using json.load(open(index_file)) leaves the file handle
open; change it to use a context manager so the file is closed automatically
(e.g., open index_file with a with statement and pass the file object to
json.load), mirroring the fix used in _load_inlined_mtp_tensors; update the code
that assigns to the variable index to read via the with-block and remove the raw
open(...) call.
- Line 333: Replace the direct open(index_file) call with a context manager to
ensure the file handle is closed: use a with open(index_file, "r") as f: and
call json.load(f) to populate weight_map (the expression that currently assigns
to weight_map should read from the file handle `f`), updating the code around
the weight_map assignment in example_utils.py accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 8d53a7fd-806c-49a8-bdf1-476fbd1570e1

📥 Commits

Reviewing files that changed from the base of the PR and between b02e888 and 6e7d53c.

📒 Files selected for processing (1)
  • examples/llm_ptq/example_utils.py

Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
@codecov
Copy link
Copy Markdown

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.03%. Comparing base (b02e888) to head (03687a3).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1532      +/-   ##
==========================================
+ Coverage   76.63%   77.03%   +0.39%     
==========================================
  Files         476      476              
  Lines       51813    52208     +395     
==========================================
+ Hits        39707    40217     +510     
+ Misses      12106    11991     -115     
Flag Coverage Δ
examples 41.63% <ø> (+2.59%) ⬆️
unit 52.72% <ø> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv
Copy link
Copy Markdown
Contributor Author

/claude review

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 382-384: The early return when index_file (model_dir /
"model.safetensors.index.json") does not exist drops standalone MTP weights;
instead, change the logic in the block that computes index_file so that when the
index is missing you fall back to the legacy layout: scan model_dir for
standalone safetensors files (e.g., "mtp.safetensors" and "model.safetensors")
and build the returned (set, dict) from those files rather than returning
(set(), {}); update the code that references index_file to only parse the index
if it exists, otherwise construct the weight map from discovered safetensors so
MTP weights are preserved.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 89d68edb-f100-4032-9d77-cb23031ac458

📥 Commits

Reviewing files that changed from the base of the PR and between 6e7d53c and cb2d3c4.

📒 Files selected for processing (3)
  • examples/llm_ptq/example_utils.py
  • tests/_test_utils/examples/llm_ptq_example_utils.py
  • tests/examples/llm_ptq/test_example_utils.py

Comment thread examples/llm_ptq/example_utils.py Outdated
Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found. LGTM

Summary: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 2

The fix correctly extends load_mtp_weights to detect inlined-MTP layouts (DeepSeek-V3, GLM-5.1 GlmMoeDsa, GLM-4.7) by computing the layer-index window from config.num_nextn_predict_layers + config.num_hidden_layers and streaming matching tensors from on-disk shards via safe_open. Logic verified end-to-end:

  • get_inlined_mtp_prefixes returns model.layers.{N..N+K-1} prefixes; the prefix + "." startswith check correctly avoids false matches against neighboring layer indices (e.g. model.layers.78. won't match model.layers.781.x).
  • The split into in-state-dict vs orphan tensors handles both DeepSeek-V3 (HF instantiates the extra layers) and GLM-5.1 (orphaned, routed via extra_state_dict).
  • The legacy mtp-substring path is preserved unchanged in _scan_separate_file_mtp.
  • Downstream call sites (hf_ptq.py:1137 quant exclusion and unified_export_hf.py:802 exclude_modules) work transparently with the new model.layers.N-style prefixes.
  • Return signature is unchanged; ordering is now deterministic (sorted) — minor improvement.
  • Tests cover the pure prefix-derivation contract.

Two non-blocking SUGGESTIONs left as inline comments (a docstring reference to a non-existent MTP_DETECTION.md, and minor defensive-asymmetry in get_inlined_mtp_prefixes). CodeRabbit already flagged json.load(open(...)) file-handle leaks at lines 337 and 386 — not duplicating.

Risk: low. Scoped to an example utility, with a fallback path that preserves prior behavior for non-MTP configs (num_nextn_predict_layers absent or 0).

Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Bug fix is reasonable and well-structured: the existing "mtp" in key substring detection genuinely misses GLM-5.1's inlined model.layers.{i} convention, and splitting into get_inlined_mtp_prefixes + _load_inlined_mtp_tensors + _scan_separate_file_mtp + _apply_to_model_state_dict is cleaner than the prior monolithic function. The single-file model.safetensors fallback in _load_inlined_mtp_tensors is also a nice secondary improvement (legacy path only handled the sharded case).

Flagging for human sign-off because:

  • Test coverage is thin. The only new test exercises the trivial 4-line get_inlined_mtp_prefixes (config → list of strings). The actual on-disk loader (_load_inlined_mtp_tensors), the orphan-vs-state-dict split (_apply_to_model_state_dict), and the integrated load_mtp_weights flow are untested. The PR body says the author verified on "a mini GLM-5.1 fixture (4 hidden layers + 1 inlined MTP at model.layers.4, 7 synthesized MTP tensors)" — that fixture would make a solid unit test for the loader path; it's odd not to include it. The bug being fixed (silent tensor drop) is exactly the kind of regression a fixture-based test would catch.
  • End-to-end verification is still pending. Author explicitly says "To be verified with full model" — i.e. no full GLM-5.1 export run yet.
  • Stale docstring reference. load_mtp_weights now points readers to examples/llm_ptq/MTP_DETECTION.md ("See Also") but that file does not exist in the PR or the repo. Either add the doc or drop the reference before merge.
  • Minor: int(getattr(config, "num_nextn_predict_layers", 0)) will raise TypeError if the attribute is present but None (some HF configs do this); getattr(..., 0) or 0 would be safer.

Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread tests/examples/llm_ptq/test_example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Per Slack request: posted simplification ideas. The inlined- and separate-file MTP paths duplicate the "walk index → stream matching tensors" logic with different I/O APIs and different fallback behavior, and the prefix-extraction concern is mixed into the separate-file scanner. Unifying both around a single _load_tensors_matching(model_dir, predicate) helper (using safe_open consistently, with single-file fallback) plus a pure _keys_to_prefixes(keys) extractor would shrink the diff meaningfully and close the no-index regression on the legacy path that CodeRabbit flagged. Also noting unresolved prior comments: stale MTP_DETECTION.md "See Also" pointer, int(None) TypeError if num_nextn_predict_layers is present-but-None, and the loader/orphan-split paths still have no fixture-based test (the PR body's mini-GLM-5.1 fixture would be a natural unit test).

Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread examples/llm_ptq/example_utils.py Outdated
Comment thread tests/examples/llm_ptq/test_example_utils.py Outdated
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/examples/llm_ptq/test_example_utils.py (1)

112-133: ⚡ Quick win

Add one test for the sharded-index (model.safetensors.index.json) path.

Line 112 currently validates the no-index fallback/standalone-file flow well, but the PR also adds index-walk behavior and this branch is still unexercised. A focused test with two shards + model.safetensors.index.json would lock down the regression surface.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/examples/llm_ptq/test_example_utils.py` around lines 112 - 133, Add a
new test mirroring test_load_mtp_weights_separate_standalone_file but exercising
the sharded-index path: use _write_safetensors to create two shard files (e.g.,
model.safetensors.0000 and model.safetensors.0001) containing distinct keys
(e.g., "mtp.fc.weight" in one shard and "mtp.layers.0.q_proj.weight" in the
other), write a corresponding model.safetensors.index.json that maps those
tensor names to the appropriate shard filenames, instantiate _FakeModel (as in
the existing test) and call example_utils.load_mtp_weights(model,
str(tmp_path)), then assert the returned prefixes and orphans include the
expected "mtp" prefixes and the two orphan keys; reference
example_utils.load_mtp_weights, _write_safetensors, _FakeModel and the index
file name model.safetensors.index.json when locating code to modify.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/examples/llm_ptq/test_example_utils.py`:
- Around line 112-133: Add a new test mirroring
test_load_mtp_weights_separate_standalone_file but exercising the sharded-index
path: use _write_safetensors to create two shard files (e.g.,
model.safetensors.0000 and model.safetensors.0001) containing distinct keys
(e.g., "mtp.fc.weight" in one shard and "mtp.layers.0.q_proj.weight" in the
other), write a corresponding model.safetensors.index.json that maps those
tensor names to the appropriate shard filenames, instantiate _FakeModel (as in
the existing test) and call example_utils.load_mtp_weights(model,
str(tmp_path)), then assert the returned prefixes and orphans include the
expected "mtp" prefixes and the two orphan keys; reference
example_utils.load_mtp_weights, _write_safetensors, _FakeModel and the index
file name model.safetensors.index.json when locating code to modify.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: f28d789e-29ef-4ffc-ba7a-a8a8d95f164b

📥 Commits

Reviewing files that changed from the base of the PR and between cb2d3c4 and a397d31.

📒 Files selected for processing (2)
  • examples/llm_ptq/example_utils.py
  • tests/examples/llm_ptq/test_example_utils.py

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/_test_utils/examples/llm_ptq_example_utils.py (1)

30-30: ⚡ Quick win

Move __all__ near the top-level declarations.

Line 30 defines the module’s public surface, but this repo convention expects __all__ at the top of the module for discoverability and consistency.

Proposed adjustment
 import sys
 
 from _test_utils.examples.run_command import MODELOPT_ROOT
+__all__ = ["example_utils"]
 
 _LLM_PTQ_DIR = MODELOPT_ROOT / "examples" / "llm_ptq"
 if str(_LLM_PTQ_DIR) not in sys.path:
     sys.path.insert(0, str(_LLM_PTQ_DIR))
 
 import example_utils
-
-__all__ = ["example_utils"]

As per coding guidelines: "Define the public API with __all__ at the top of each module and re-export via from .module import * in __init__.py files".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/_test_utils/examples/llm_ptq_example_utils.py` at line 30, Move the
module-level __all__ declaration from line 30 up to the top of the module
immediately after the import block and before any class/function/variable
declarations so the public API is declared at top-level; keep the symbol list
unchanged (["example_utils"]) and ensure the module still exports that name via
__all__.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/_test_utils/examples/llm_ptq_example_utils.py`:
- Line 30: Move the module-level __all__ declaration from line 30 up to the top
of the module immediately after the import block and before any
class/function/variable declarations so the public API is declared at top-level;
keep the symbol list unchanged (["example_utils"]) and ensure the module still
exports that name via __all__.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 577a7e47-01bd-44c4-961a-81f564ce59d6

📥 Commits

Reviewing files that changed from the base of the PR and between a397d31 and 53291fa.

📒 Files selected for processing (3)
  • examples/llm_ptq/example_utils.py
  • tests/_test_utils/examples/llm_ptq_example_utils.py
  • tests/examples/llm_ptq/test_example_utils.py

Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Re-review pass — the previous round of simplification work has all landed and is in good shape:

  • 💬 Author replied "Done" on the unified scanner — _load_tensors_matching + _keys_to_prefixes cleanly separate I/O from prefix derivation, and the safe_open migration kills the legacy load_file whole-shard read.
  • 💬 Author replied "No, just removed" on MTP_DETECTION.md — confirmed gone from the docstring.
  • 💬 int(None) fix (or 0) applied; file handles wrapped in with; no-index fallback restored via model_dir.glob("*.safetensors").
  • 💬 Test coverage substantially expanded — all four conventions (inlined-orphaned, inlined-in-state-dict, separate-standalone, separate-indexed) plus the num_nextn_predict_layers=None regression now have fixture tests. This addresses the prior "test coverage is thin" concern.

Flagging for human sign-off only because of the additional simplifications the operator asked about — none are correctness issues, all are nits:

  1. predicate(key, shard_name) second arg looks dead. The only consumer of shard_name is "mtp" in shard_name in load_mtp_weights.predicate. For every real MTP convention enumerated in the docstring, the "mtp" in key branch already matches (Qwen3-Next, GLM-4.7) or the inlined-prefix branch matches (GLM-5.1, DeepSeek-V3). It's hard to construct a checkpoint where MTP weights would only be detected via shard filename. Consider dropping the shard_name parameter from the predicate signature — _load_tensors_matching becomes Callable[[str], bool] and the predicate body collapses to two clauses. If you want to keep the defensive shard-name fallback, document the case it covers; right now it reads as belt-and-suspenders.

  2. Redundant if inlined_tuple guard in predicate. key.startswith(()) already returns False, so if inlined_tuple and key.startswith(inlined_tuple) simplifies to key.startswith(inlined_tuple).

  3. Awkward conditional + operator-precedence reading hazard on the prefixes line. prefixes = inlined_prefixes | _keys_to_prefixes(separate_keys) if tensors else set() parses as (A | B) if tensors else set() (correct), but the |-vs-ternary precedence is non-obvious. Cleaner as an early return:

    if not tensors:
        return [], {}
    prefixes = inlined_prefixes | _keys_to_prefixes(separate_keys)

    This also drops the now-unreachable if prefixes: guard around the print.

  4. Minor: _keys_to_prefixes is only safe for separate-file keys (it would emit "model" as a prefix if fed an inlined key like "model.layers.78.eh_proj"). The current call site filters via separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)] so it's fine, but a one-line docstring note ("caller must filter out inlined keys") would prevent future misuse.

  5. Docstring style: get_inlined_mtp_prefixes and _keys_to_prefixes open with lowercase verbs ("turn an HF config…", "invert separate-file MTP keys…") — convention is capitalized imperative.

None of these block correctness; all are tighten-up suggestions per the operator's "additional simplification" ask.

Comment thread examples/llm_ptq/example_utils.py
Comment thread examples/llm_ptq/example_utils.py

def _load_tensors_matching(
model_dir: Path, predicate: Callable[[str, str | None], bool]
) -> dict[str, torch.Tensor]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Minor: _keys_to_prefixes is only correct for separate-file keys — if fed an inlined key like "model.layers.78.eh_proj.weight" it would emit "model" as a top-level prefix, which is not what the exporter wants. The current caller filters via separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)] so this is safe today; worth a one-line note in the docstring ("caller must filter inlined keys first") to prevent future misuse.

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Re-review: previous round's findings are addressed.

  • 💬 File-handle leaks at lines 333/415 — both now use with open(...) as f: context managers.
  • 💬 No-index fallback (CodeRabbit major) — _load_tensors_matching now globs *.safetensors independently of the index, restoring the standalone-shard path.
  • 💬 Author replied "No, just removed" — MTP_DETECTION.md "See Also" reference is gone.
  • 💬 int(None) regression — fixed via or 0, plus an explicit num_nextn_predict_layers=None test.
  • 💬 Author replied "Done" on unifying the two scanners — _load_tensors_matching + _keys_to_prefixes cleanly split I/O from prefix derivation; safe_open replaces load_file.
  • 💬 Test coverage substantially expanded — 5 fixture-based tests across all four conventions (inlined-orphaned, inlined-in-state-dict, separate-standalone, separate-indexed) plus the None regression. The "thin coverage" concern is resolved.
  • The prior round's tighten-up nits (dead shard_name predicate arg, redundant if inlined_tuple guard, ternary precedence on the prefixes line, caller-must-filter docstring note for _keys_to_prefixes) are all reflected in the current code. Only remaining items are docstring-capitalization nits, which are non-blocking.

End-to-end full-model verification is still pending per the PR body, but that's an operational gate the author has flagged explicitly and not something a unit-test review can substitute for.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
examples/llm_ptq/example_utils.py (2)

433-437: ⚡ Quick win

Use rank-aware logging for this new status message.

A raw print() here will fire on every rank during distributed runs. Please route it through print_rank_0/warn_rank_0 instead.

As per coding guidelines, "Use print_rank_0 or warn_rank_0 when possible to avoid noisy logs and guard shared side effects against race conditions between ranks in distributed processing".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 433 - 437, Replace the raw
print(...) that reports detected MTP tensors with a rank-aware logger: call
print_rank_0 (or warn_rank_0 if this should be a warning) instead, passing the
same formatted string that uses tensors, prefixes, and not_in_state_dict to
build the message; locate the print call in example_utils.py (the f-string
referencing len(tensors), sorted(prefixes), and len(not_in_state_dict)) and
simply swap print(...) for print_rank_0(...) so only rank 0 emits the message
during distributed runs.

421-429: ⚡ Quick win

Tighten separate-file detection to the documented mtp.* namespace.

_keys_to_prefixes() and the support matrix both assume separate-file tensors are keyed under top-level mtp.*. Matching any "mtp" substring can still pull in unrelated keys and derive overly broad exclusions from them.

Suggested change
     def predicate(key: str) -> bool:
-        return key.startswith(inlined_tuple) or "mtp" in key
+        return key.startswith(inlined_tuple) or key.startswith("mtp.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 421 - 429, The current
predicate and separate_keys logic treat any key containing "mtp" as a
separate-file tensor, which is too broad; change the predicate in the
_load_tensors_matching call to test top-level mtp namespace (e.g., use
key.startswith(inlined_tuple) or key.startswith("mtp.") ) and update
separate_keys to exclude keys that start with the inlined_tuple or start with
"mtp." (so separate_keys = [k for k in tensors if not
k.startswith(inlined_tuple) and not k.startswith("mtp.")]); keep prefixes =
inlined_prefixes | _keys_to_prefixes(separate_keys).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 433-437: Replace the raw print(...) that reports detected MTP
tensors with a rank-aware logger: call print_rank_0 (or warn_rank_0 if this
should be a warning) instead, passing the same formatted string that uses
tensors, prefixes, and not_in_state_dict to build the message; locate the print
call in example_utils.py (the f-string referencing len(tensors),
sorted(prefixes), and len(not_in_state_dict)) and simply swap print(...) for
print_rank_0(...) so only rank 0 emits the message during distributed runs.
- Around line 421-429: The current predicate and separate_keys logic treat any
key containing "mtp" as a separate-file tensor, which is too broad; change the
predicate in the _load_tensors_matching call to test top-level mtp namespace
(e.g., use key.startswith(inlined_tuple) or key.startswith("mtp.") ) and update
separate_keys to exclude keys that start with the inlined_tuple or start with
"mtp." (so separate_keys = [k for k in tensors if not
k.startswith(inlined_tuple) and not k.startswith("mtp.")]); keep prefixes =
inlined_prefixes | _keys_to_prefixes(separate_keys).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3d5f6dd1-e38a-4de5-8786-843faa15e577

📥 Commits

Reviewing files that changed from the base of the PR and between 53291fa and 03687a3.

📒 Files selected for processing (1)
  • examples/llm_ptq/example_utils.py

Copy link
Copy Markdown
Contributor

@Edwardf0t1 Edwardf0t1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Fridah-nv Fridah-nv enabled auto-merge (squash) May 22, 2026 22:43
@Fridah-nv Fridah-nv merged commit 16a0130 into main May 22, 2026
41 checks passed
@Fridah-nv Fridah-nv deleted the fridah/glm5.1-mtp branch May 22, 2026 22:47
@github-actions
Copy link
Copy Markdown
Contributor

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-22 22:47 UTC

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants