
fix: read rope config from rope_parameters across all models #1400

Merged
akoumpa merged 3 commits into main from hemil/fix-qwen3-rope on Feb 27, 2026
Conversation


hemildesai (Contributor) commented Feb 27, 2026

Summary

  • Adds a shared get_rope_config(config) helper in components/models/common/utils.py that extracts rope_theta, rope_scaling, and partial_rotary_factor directly from config.rope_parameters (a sketch of the idea follows this list)
  • Updates 10 models to use the helper and pass YaRN scaling fields (factor, beta_slow, beta_fast, original_max_position_embeddings) from rope_parameters to RotaryEmbedding instead of hardcoding defaults
  • Adds unit tests for get_rope_config and updates mock configs in existing tests to include rope_parameters
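
A minimal sketch of what such a helper might look like, assuming rope_parameters is a plain dict; the defaults and exact return shape here are assumptions, not the merged implementation:

    # Hypothetical sketch of get_rope_config; defaults are assumed.
    def get_rope_config(config):
        """Read all rope settings from config.rope_parameters (transformers v5 layout)."""
        params = dict(config.rope_parameters or {})
        rope_theta = params.pop("rope_theta", 10000.0)  # tolerate a missing key
        partial_rotary_factor = params.pop("partial_rotary_factor", 1.0)
        # Whatever remains (rope_type, factor, beta_fast, beta_slow,
        # original_max_position_embeddings, ...) is the scaling config.
        rope_scaling = params or None
        return rope_theta, rope_scaling, partial_rotary_factor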

Models updated

qwen3_moe, qwen3_next, gpt_oss, glm4_moe, minimax_m2, step3p5, glm4_moe_lite, deepseek_v3, deepseek_v32, mistral3

What was wrong

  1. YaRN scaling silently ignored — RotaryEmbedding was instantiated with scaling_factor=1.0, ntk_alpha=1.0, etc., so configured rope scaling never activated (illustrated after this list).
  2. Duplicated logic — every model had its own rope config extraction, making it easy for some models to miss fixes applied to others.
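
For illustration, the shape of the bug and of the fix; scaling_factor and ntk_alpha are named above, but the rest of RotaryEmbedding's signature is assumed here:

    # Before (the bug): scaling kwargs hardcoded, so the YaRN fields in
    # config.rope_parameters never reached the embedding.
    rotary_emb = RotaryEmbedding(dim, base=rope_theta, scaling_factor=1.0, ntk_alpha=1.0)

    # After (roughly): forward the configured fields instead.
    rope_theta, rope_scaling, partial_rotary_factor = get_rope_config(config)
    scaling = rope_scaling or {}
    rotary_emb = RotaryEmbedding(
        dim,
        base=rope_theta,
        scaling_factor=scaling.get("factor", 1.0),
        ntk_alpha=scaling.get("ntk_alpha", 1.0),
    )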

Closes #1398

Test plan

  • uv run pytest tests/unit_tests/models/ -q — 1100 passed, 7 pre-existing failures (TE/flash_attn not installed), 0 regressions
  • Functional test with Qwen/Qwen3-30B-A3B + rope_scaling to confirm YaRN activates end-to-end (an illustrative override is sketched below)
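
For reference, a YaRN override in the standard HuggingFace rope_scaling shape; these values are illustrative only, not the ones used in the test:

    # Illustrative YaRN config for Qwen/Qwen3-30B-A3B (example values only).
    rope_scaling = {
        "rope_type": "yarn",
        "factor": 4.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "original_max_position_embeddings": 32768,
    }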

🤖 Generated with Claude Code

hemildesai and others added 2 commits February 26, 2026 17:05

fix: read rope config from rope_parameters across all models

Extract rope_theta, rope_scaling, and partial_rotary_factor from
config.rope_parameters (the newer HuggingFace format) via a shared
get_rope_config helper. This fixes Qwen3MoE crashing with KeyError
when rope_parameters exists but lacks rope_theta, and ensures YaRN
scaling parameters are propagated to RotaryEmbedding instead of being
silently hardcoded to defaults.

Models updated: qwen3_moe, qwen3_next, gpt_oss, glm4_moe, minimax_m2,
step3p5, glm4_moe_lite, deepseek_v3, deepseek_v32, mistral3.

Closes #1398

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

refactor: simplify get_rope_config to read only from rope_parameters

In transformers v5 all rope configuration lives in
config.rope_parameters; config.rope_theta no longer exists as a
top-level attribute. Remove all fallback paths and read rope_theta,
partial_rotary_factor, and scaling fields directly from
rope_parameters. Update mock configs in glm4_moe_lite, minimax_m2,
and step3p5 tests to include rope_parameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
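
A mock config in that v5 shape might look like the following sketch; the field values are placeholders, not the fixtures actually used in the tests:

    from types import SimpleNamespace

    # Hypothetical mock: every rope field lives under rope_parameters, and
    # there is deliberately no top-level rope_theta attribute (v5 layout).
    mock_config = SimpleNamespace(
        rope_parameters={"rope_type": "default", "rope_theta": 1_000_000.0},
    )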
copy-pr-bot (bot) commented Feb 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

revert: restore mistral3/model.py to main version

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
hemildesai (Contributor, Author) commented
/ok to test 277c9c7

akoumpa added the r0.3.0 label (Add for cherry-pick into release branch r0.3.0) on Feb 27, 2026
akoumpa merged commit bb22625 into main on Feb 27, 2026
51 checks passed
akoumpa deleted the hemil/fix-qwen3-rope branch on February 27, 2026 12:40
linnanwang pushed a commit that referenced this pull request Apr 24, 2026
* fix: read rope config from rope_parameters across all models

Extract rope_theta, rope_scaling, and partial_rotary_factor from
config.rope_parameters (the newer HuggingFace format) via a shared
get_rope_config helper. This fixes Qwen3MoE crashing with KeyError
when rope_parameters exists but lacks rope_theta, and ensures YaRN
scaling parameters are propagated to RotaryEmbedding instead of being
silently hardcoded to defaults.

Models updated: qwen3_moe, qwen3_next, gpt_oss, glm4_moe, minimax_m2,
step3p5, glm4_moe_lite, deepseek_v3, deepseek_v32, mistral3.

Closes #1398

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* refactor: simplify get_rope_config to read only from rope_parameters

In transformers v5 all rope configuration lives in
config.rope_parameters; config.rope_theta no longer exists as a
top-level attribute. Remove all fallback paths and read rope_theta,
partial_rotary_factor, and scaling fields directly from
rope_parameters. Update mock configs in glm4_moe_lite, minimax_m2,
and step3p5 tests to include rope_parameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* revert: restore mistral3/model.py to main version

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

---------

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

Labels

r0.3.0 (Add for cherry-pick into release branch r0.3.0)


Development

Successfully merging this pull request may close these issues.

[Qwen3MoE] rope_scaling is silently ignored and crashes with KeyError

2 participants