[Feature] Support sharing MTP weights. #1672
Conversation
Refactor MTP configuration to support weight sharing across layers. Updated MoE and MTPBlock classes to handle shared weights and adjusted layer initialization accordingly. Added share_weights parameter to MTPConfig for better control over layer behavior.
…computed as necessary.
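A minimal sketch of what the weight-sharing construction could look like. Only `share_weights`, `num_layers`, `MTPConfig`, and `MTPBlock` come from the PR description; the remaining fields, the block body, and the builder function are illustrative assumptions:

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class MTPConfig:
    num_layers: int              # logical number of MTP prediction steps
    hidden_size: int = 1024      # illustrative field, not necessarily in the real config
    share_weights: bool = False  # new flag: reuse one physical block for every step


class MTPBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # Placeholder body; the real block wraps attention/MoE sublayers.
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


def build_mtp_layers(cfg: MTPConfig) -> nn.ModuleList:
    """Build one shared block, or num_layers independent blocks."""
    if cfg.share_weights:
        # A single physical block is reused for every logical step.
        return nn.ModuleList([MTPBlock(cfg.hidden_size)])
    return nn.ModuleList([MTPBlock(cfg.hidden_size) for _ in range(cfg.num_layers)])
```

With `share_weights=True` the module list holds a single block that every prediction step reuses, so the parameter count stays flat as `num_layers` grows.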
@claude review
```python
super().__init__()
if not mtp_layers:
    raise ValueError("mtp_layers cannot be empty")
```
Claude: Nit: `self.num_layers` is now semantically ambiguous. It stores the physical layer count (`len(mtp_layers)`), but the forward loop iterates `self.mtp_config.num_layers` times (the logical step count). When `share_weights=True`, these differ (1 vs. N).
If `self.num_layers` is not used externally, consider removing it or renaming it to `self.num_physical_layers` to avoid confusion. Alternatively, you could update it to always reflect the logical count:

```python
self.num_layers = mtp_config.num_layers
```
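To make the physical/logical distinction concrete, here is a minimal sketch of how the constructor and forward loop could fit together; the class name, argument names, and modulo indexing are assumptions for illustration, not the PR's actual code:

```python
import torch
import torch.nn as nn


class MTP(nn.Module):
    def __init__(self, mtp_config, mtp_layers: nn.ModuleList):
        super().__init__()
        if not mtp_layers:
            raise ValueError("mtp_layers cannot be empty")
        self.mtp_config = mtp_config
        self.mtp_layers = mtp_layers
        # Physical count: 1 when share_weights=True, otherwise num_layers.
        self.num_physical_layers = len(mtp_layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Iterate the logical step count; with shared weights every step
        # resolves to the same single physical block.
        for step in range(self.mtp_config.num_layers):
            layer = self.mtp_layers[step % self.num_physical_layers]
            hidden_states = layer(hidden_states)
        return hidden_states
```

Under this sketch, renaming the attribute to `num_physical_layers` keeps it truthful while the loop bound stays on the config's logical `num_layers`.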
Claude:

## Summary
Adds support for sharing MTP weights across layers via a `share_weights` parameter on `MTPConfig`.

## Issues
- **Warning**: …
- **Nit**: `self.num_layers` is semantically ambiguous (see the inline comment above).

## Verdict
COMMENT — the logic is sound overall. The main thing to address is the missing …