
[Feature] Support share MTP weights. #1672

Merged
HAOCHENYE merged 3 commits into InternLM:main from RangiLyu:lcq/share-mtp-head on Apr 14, 2026

Conversation

@RangiLyu
Collaborator

Refactors the MTP configuration to support weight sharing across layers. Updates the MoE and MTPBlock classes to handle shared weights and adjusts layer initialization accordingly. Adds a share_weights parameter to MTPConfig for finer control over layer behavior.
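
A minimal sketch of how the flag could drive construction and the forward loop (the `share_weights` and `num_layers` attributes on `MTPConfig`, and the simplified forward signature, are assumptions for illustration; the actual xtuner code differs in detail):

```python
import torch.nn as nn

class MTPBlock(nn.Module):
    """Sketch: one physical MTP layer reused across all prediction depths."""

    def __init__(self, mtp_config, mtp_layers):
        super().__init__()
        if not mtp_layers:
            raise ValueError("mtp_layers cannot be empty")
        self.mtp_config = mtp_config
        # With share_weights=True, mtp_layers holds a single module that
        # serves every prediction depth; otherwise one module per depth.
        self.layers = nn.ModuleList(mtp_layers)

    def forward(self, hidden_states):
        for step in range(self.mtp_config.num_layers):
            # Reuse layer 0 when weights are shared, step-specific otherwise.
            idx = 0 if self.mtp_config.share_weights else step
            hidden_states = self.layers[idx](hidden_states)
        return hidden_states
```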

@RangiLyu
Collaborator Author

@claude review

Comment thread: xtuner/v1/module/mtp/mtp_block.py

    super().__init__()
    if not mtp_layers:
        raise ValueError("mtp_layers cannot be empty")

Contributor


Claude: Nit: self.num_layers is now semantically ambiguous. It stores the physical layer count (len(mtp_layers)), but the forward loop iterates self.mtp_config.num_layers times (the logical step count). When share_weights=True, these differ (1 vs. N).

If self.num_layers is not used externally, consider removing it or renaming it to self.num_physical_layers to avoid confusion. Alternatively, you could update it to always reflect the logical count:

self.num_layers = mtp_config.num_layers
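
If the rename route is taken instead, a sketch of keeping both counts under unambiguous names (suggested, not the PR's actual code):

```python
# Physical modules actually instantiated: 1 when share_weights=True, else N.
self.num_physical_layers = len(mtp_layers)
# Logical prediction steps the forward loop runs: always N.
self.num_logical_steps = mtp_config.num_layers
```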

@claude
Contributor

claude Bot commented Apr 13, 2026

## Summary

Adds a share_weights parameter to MTPConfig so a single physical MTP layer can be reused across all prediction depths, reducing the memory footprint.

## Issues

Warning

  • mtp_block.py:64-65 — Missing validation for share_weights=False: the forward loop now indexes self.layers[step] using mtp_config.num_layers, but nothing enforces len(mtp_layers) == mtp_config.num_layers in the constructor. Could cause an IndexError if they diverge (see the sketch after this list).
  • moe.py:1237-1238 — the _should_recompute docstring example doesn't cover share_weights=True behavior.
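
As referenced above, a minimal sketch of the missing check in MTPBlock.__init__ (attribute names follow the review's description; the exact diff may differ):

```python
# Sketch: enforce physical/logical layer-count agreement when not sharing.
if not mtp_config.share_weights and len(mtp_layers) != mtp_config.num_layers:
    raise ValueError(
        f"share_weights=False requires {mtp_config.num_layers} MTP layers, "
        f"got {len(mtp_layers)}"
    )
```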

Nit

  • mtp_block.py:67 — self.num_layers is now ambiguous (physical count vs. logical step count).
  • moe.py:1018-1020 — Prefer an assert or explicit is not None check over # type: ignore for strict mypy compliance (a sketch follows this list).
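
For the `# type: ignore` nit, a hedged sketch of the explicit narrowing (`gate_weight` is a stand-in name, not the actual variable in moe.py):

```python
# Sketch: narrow the Optional explicitly so mypy tracks the type itself.
if gate_weight is None:
    raise RuntimeError("gate_weight must be initialized before use")
# From here on, mypy knows gate_weight is not None; no suppression needed.
```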

## Verdict

COMMENT — the logic is sound overall. The main thing to address is the missing len(mtp_layers) == mtp_config.num_layers validation in MTPBlock.__init__ for the non-sharing case.

@HAOCHENYE merged commit e35e80c into InternLM:main on Apr 14, 2026
5 of 6 checks passed
