
feat: Add context parallel support for Qwen3.5 MoE#1560

Merged
akoumpa merged 18 commits into main from zpqiu/qwen35-cp on Mar 19, 2026

Conversation

@zpqiu
Contributor

zpqiu commented Mar 17, 2026

What does this PR do ?

Add CP support for Qwen3.5 MoE.

Changelog

  • Add CPAwareGatedDeltaNet, a CP-aware wrapper for the HF GatedDeltaNet linear attention (load-balancing undo/redo, conv1d boundary handling via FLA's causal_conv1d).
  • Extend the parallelizer's apply_cp to cover linear-attention blocks (cp_mesh attachment, warn-and-skip for unexpected attention modules).
  • Add unit tests for CPAwareGatedDeltaNet and the parallelizer, plus a 2-GPU functional test comparing CP=1 vs CP=2 forward outputs and gradients.
  • Pin flash-linear-attention to git commit 7cbe461b (includes fla.ops.cp).

Before your PR is "Ready for review"

Pre checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

Convergence Experiment

  • Pink: CP1
  • Orange: CP2
(Screenshot, 2026-03-13 12:49:40: convergence curves for CP1 vs CP2)
# Qwen3.5 MoE 35B LLM SFT — CP=2 experiment
# Usage:
#   torchrun --nproc-per-node=8 examples/llm_finetune/finetune.py \
#     --config examples/llm_finetune/qwen/qwen3_5_moe_35b_cp2.yaml

step_scheduler:
    global_batch_size: 128
    local_batch_size: 4
    ckpt_every_steps: 500
    val_every_steps: 10
    num_epochs: 1
    max_steps: 200

dist_env:
    backend: nccl
    timeout_minutes: 60

rng:
    _target_: nemo_automodel.components.training.rng.StatefulRNG
    seed: 1234
    ranked: true

model:
    _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
    pretrained_model_name_or_path: Qwen/Qwen3.5-35B-A3B-Base
    backend:
        _target_: nemo_automodel.components.models.common.BackendConfig
        attn: te
        linear: te
        rms_norm: torch_fp32
        experts: te
        dispatcher: no
        rope_fusion: false
        fake_balanced_gate: false
        enable_hf_state_dict_adapter: true

checkpoint:
    enabled: false

distributed:
    strategy: fsdp2
    tp_size: 1
    cp_size: 2
    pp_size: 1
    ep_size: 8

    sequence_parallel: false
    activation_checkpointing: true

loss_fn:
    _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
    _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset
    path_or_dataset_id: allenai/tulu-3-sft-mixture
    split: train
    shuffle_seed: 1234
    seq_length: 2048
    padding: max_length
    truncation: true

packed_sequence:
    packed_sequence_size: 0

dataloader:
    _target_: torchdata.stateful_dataloader.StatefulDataLoader
    collate_fn: nemo_automodel.components.datasets.utils.default_collater
    shuffle: true

validation_dataset:
    _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset
    path_or_dataset_id: allenai/tulu-3-sft-mixture
    split: "train[:1024]"
    shuffle_seed: 1234
    seq_length: 2048
    padding: max_length
    truncation: true

validation_dataloader:
    _target_: torchdata.stateful_dataloader.StatefulDataLoader
    collate_fn: nemo_automodel.components.datasets.utils.default_collater
    batch_size: 4

optimizer:
    _target_: torch.optim.Adam
    betas: [0.9, 0.999]
    eps: 1e-7
    lr: 1.0e-5
    weight_decay: 0
    foreach: false
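Context parallelism shards the sequence dimension across cp_size ranks; with seq_length: 2048 and cp_size: 2 above, each rank holds 1024 token positions. For causal attention a plain contiguous split is unbalanced (later tokens attend over longer prefixes), so CP implementations commonly use a zigzag layout. A minimal illustration of that idea — not the actual NeMo Automodel sharding code:

```python
# Zigzag sharding sketch: split the sequence into 2 * cp_size chunks; rank r
# takes chunk r from the front plus the mirrored chunk from the back, so every
# rank does a similar amount of causal-attention work.
def zigzag_shard(seq, cp_size, rank):
    chunk = len(seq) // (2 * cp_size)
    chunks = [seq[i * chunk:(i + 1) * chunk] for i in range(2 * cp_size)]
    return chunks[rank] + chunks[2 * cp_size - 1 - rank]

positions = list(range(8))               # token positions 0..7
print(zigzag_shard(positions, 2, 0))     # rank 0 -> [0, 1, 6, 7]
print(zigzag_shard(positions, 2, 1))     # rank 1 -> [2, 3, 4, 5]
```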

zpqiu and others added 11 commits March 17, 2026 01:28
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
- Unit tests for CPAwareGatedDeltaNet: extract_local_positions, undo/redo
  load balancing, conv1d boundary exchange, AllGatherConcatFn, forward
  fast path delegation, and init checks
- Unit tests for parallelizer: apply_cp linear_attention branch coverage
  (cp_mesh attachment, warning paths, mixed full+linear attention)
- Functional test: 2-GPU torchrun comparing CP=1 vs CP=2 forward outputs
  and gradients for GatedDeltaNet linear attention

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
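The "undo/redo load balancing" tested above refers to converting between the zigzag CP layout and the original token order, which the recurrent linear-attention path needs tokens in. A hypothetical sketch of the undo direction — names are illustrative, not the actual CPAwareGatedDeltaNet API:

```python
# Given each rank's zigzag shard (front chunk + mirrored back chunk),
# reassemble the sequence in its original order.
def undo_zigzag(shards, cp_size):
    half = len(shards[0]) // 2
    front = [s[:half] for s in shards]            # chunks 0 .. cp_size-1
    back = [s[half:] for s in reversed(shards)]   # chunks cp_size .. 2*cp_size-1
    out = []
    for c in front + back:
        out.extend(c)
    return out

shards = [[0, 1, 6, 7], [2, 3, 4, 5]]  # cp_size=2 zigzag shards of positions 0..7
print(undo_zigzag(shards, 2))          # -> [0, 1, 2, 3, 4, 5, 6, 7]
```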
CPAwareGatedDeltaNet inherits from HF GatedDeltaNet, which stores
hidden_size as a direct attribute rather than under a config object.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
- Replace P2P-based conv1d boundary exchange with FLA's causal_conv1d
  that natively supports CP context
- Refactor functional test baseline to use _forward_with_cp for
  consistent comparison with CP=2
- Simplify unit tests: consolidate edge cases into parametrized test,
  remove trivial TestInit and redundant TestAllGatherConcat

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
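The conv1d boundary exchange mentioned above exists because a causal convolution needs the last (kernel_size - 1) positions of the preceding shard to compute its first outputs correctly. A toy illustration with contiguous shards (the PR delegates the real handling to FLA's CP-aware causal_conv1d, and the zigzag layout makes the actual bookkeeping more involved):

```python
# Reference causal conv1d over the full sequence (left zero-padding).
def causal_conv1d(x, kernel):
    k = len(kernel)
    padded = [0.0] * (k - 1) + x
    return [sum(padded[i + j] * kernel[j] for j in range(k)) for i in range(len(x))]

# Per-shard version: `halo` holds the previous shard's last (k - 1) values
# (zeros on the first rank). With the halo in place, concatenated shard
# outputs match the full-sequence result exactly.
def causal_conv1d_shard(x_shard, halo, kernel):
    k = len(kernel)
    padded = list(halo) + x_shard
    return [sum(padded[i + j] * kernel[j] for j in range(k)) for i in range(len(x_shard))]

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.5, 0.25, 0.25]
full = causal_conv1d(x, w)
r0 = causal_conv1d_shard(x[:3], [0.0, 0.0], w)   # rank 0: zero halo
r1 = causal_conv1d_shard(x[3:], x[1:3], w)       # rank 1: halo from rank 0
assert r0 + r1 == full
```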
- Fix ruff format issues in cp_linear_attn.py and parallelizer.py
- Pin flash-linear-attention to git commit 7cbe461b (includes fla.ops.cp)
- Update uv.lock accordingly

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
The override-dependencies block was injected by update_pyproject_pytorch.sh
for container environments and should not be checked in.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Replace hard assert with warning + continue for blocks whose
attn_module is not DotProductAttention, consistent with the
linear_attention branch handling.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
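The warn-and-skip pattern described above can be sketched as follows; the class and attribute names mirror the commit message, but the surrounding code is hypothetical, not the PR's parallelizer:

```python
import warnings

class DotProductAttention:        # stand-in for the expected attention class
    pass

class OtherAttention:
    pass

class Block:
    def __init__(self, attn):
        self.attn_module = attn

def apply_cp_to_blocks(blocks, expected_type=DotProductAttention):
    """Attach CP state to each block's attention; warn and continue, rather
    than assert, when a block's attn_module is not the expected type."""
    applied = 0
    for i, block in enumerate(blocks):
        attn = getattr(block, "attn_module", None)
        if not isinstance(attn, expected_type):
            warnings.warn(
                f"block {i}: attn_module is {type(attn).__name__}, "
                f"expected {expected_type.__name__}; skipping CP setup"
            )
            continue
        # real code would attach the cp_mesh / CP hooks here
        applied += 1
    return applied

blocks = [Block(DotProductAttention()), Block(OtherAttention())]
print(apply_cp_to_blocks(blocks))  # applies CP to 1 block, warns for the other
```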
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Mar 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
@zpqiu
Contributor Author

zpqiu commented Mar 17, 2026

/ok to test 4807fb2

@copy-pr-bot

copy-pr-bot Bot commented Mar 17, 2026

/ok to test 4807fb2

@zpqiu, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@zpqiu
Contributor Author

zpqiu commented Mar 17, 2026

/ok to test 8978884

@akoumpa
Contributor

akoumpa commented Mar 18, 2026

/ok to test 8e70593

@akoumpa akoumpa linked an issue Mar 18, 2026 that may be closed by this pull request
Comment thread on tests/functional_tests/context_parallel/L2_CP_Qwen3_5MoE_LinearAttn_Test.sh (outdated)
@akoumpa
Contributor

akoumpa commented Mar 19, 2026

/ok to test b950cc9

@akoumpa
Contributor

akoumpa commented Mar 19, 2026

Thanks a lot @zpqiu for the contribution, very much appreciated. Please feel free to LMK if there's anything I can help with.



Development

Successfully merging this pull request may close these issues.

Add context parallel support for Qwen3.5

3 participants