feat(model): add Ling 2.0 / BailingMoeV2 (mini, flash, 1T) (#2242) #2255

Merged
HuiyingLi merged 9 commits into NVIDIA-NeMo:main from Hayden727:hayden727/feat/ling-v2
May 19, 2026

@Hayden727
Contributor

What does this PR do?

Adds support for the inclusionAI/Ling-2.0 MoE family (BailingMoeV2ForCausalLM):
Ling-mini-2.0 (16B-A1.4B), Ling-flash-2.0 (100B-A6B), and Ling-1T (1T-A50B).
All three variants share the same architecture and are handled by a single implementation.

Closes #2242.

Changelog

  • Add nemo_automodel/components/models/ling_v2/ with BailingMoeV2Config, BailingMoeV2ForCausalLM, attention layer (GQA + per-head QK-RMSNorm + half RoPE), and state-dict adapter (HF↔native).
  • Register BailingMoeV2ForCausalLM arch and bailing_moe custom config in nemo_automodel/_transformers/registry.py.
  • Reuse the framework's existing sigmoid + grouped-topk + per-expert-bias Gate (DeepSeek-V3-style) — no new MoE routing code.
  • Reuse gpt_oss/rope_utils.py's partial_rotary_factor support — no new RoPE code.
  • Derive partial_rotary_factor from rotary_dim when present (Ling-1T uses rotary_dim=64; mini/flash use partial_rotary_factor=0.5).
  • Add LoRA SFT example recipe: examples/llm_finetune/ling/ling_mini_2_0_squad.yaml.
  • Add model-coverage page: docs/model-coverage/llm/inclusionai/ling-2.md.
  • Add 16 unit tests including HF↔native state-dict round-trip and per-variant GPU smoke tests for mini / flash / 1T shapes.
  • Add manual parity script (tests/unit_tests/models/ling_v2/parity_ling_v2.py) implementing the 3 levels from the parity-testing skill, plus a single-GPU real-checkpoint smoke (_real_forward_smoke.py).
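
The rotary_dim handling from the changelog can be sketched roughly as below. This is a minimal illustration, not the code in this PR: the helper name resolve_partial_rotary_factor and the head_dim=128 value are assumptions made for the example; only rotary_dim=64 (Ling-1T) and partial_rotary_factor=0.5 (mini/flash) come from the description above.

```python
from typing import Optional


def resolve_partial_rotary_factor(
    head_dim: int,
    rotary_dim: Optional[int] = None,
    partial_rotary_factor: Optional[float] = None,
) -> float:
    """Return the fraction of each head's dimensions that receive RoPE.

    Hypothetical helper: prefers an explicit rotary_dim, falls back to
    partial_rotary_factor, and defaults to full rotation (1.0).
    """
    if rotary_dim is not None:
        if not 0 < rotary_dim <= head_dim:
            raise ValueError(f"rotary_dim={rotary_dim} must be in (0, {head_dim}]")
        return rotary_dim / head_dim
    if partial_rotary_factor is not None:
        return partial_rotary_factor
    return 1.0


# Ling-1T style config: rotary_dim is given (head_dim=128 is an assumed value).
factor_1t = resolve_partial_rotary_factor(head_dim=128, rotary_dim=64)
# mini/flash style config: partial_rotary_factor is given directly.
factor_mini = resolve_partial_rotary_factor(head_dim=128, partial_rotary_factor=0.5)
```

Under the assumed head_dim, both config styles resolve to the same factor, which is why a single RoPE code path can serve all three variants.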

Validation

Run inside nvcr.io/nvidia/nemo-automodel:26.04 on a single H100 80GB:

  • 16 unit tests pass (CPU + GPU; covers all three variant shapes incl. Ling-1T's first_k_dense_replace=4 and the rotary_dim → partial_rotary_factor derivation).
  • Level 1 state-dict round-trip on the real 16B inclusionAI/Ling-mini-2.0 checkpoint: 14813 HF tensors → 299 grouped native → 14813 HF, zero diff.
  • Level 2 half-RoPE component parity vs a GPT-NeoX-style reference: max_diff=0, cos=1.0.
  • Real-checkpoint forward on H100: 16.26B params loaded, load_state_dict reports missing=0 unexpected=0, GPU memory 32.5 GB (matches bf16), logits finite, top-1 argmax shows 22 distinct tokens over 32 positions (non-degenerate).
  • Adjacent tests/unit_tests/_transformers/ suite (366 tests): no regression.
  • Recipe YAML accepted by tools/lint_example_yamls.py (the CI validator that runs in validate-recipe-configs.yml).
  • Level 3 end-to-end HF logits parity is deferred: HF's bundled modeling_bailing_moe_v2.py imports is_torch_fx_available which was removed in transformers>=5. To be revisited once upstream Ling patches that or via a transformers<5 shim.
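
For readers unfamiliar with the "half RoPE" checked at Level 2, here is a small pure-Python sketch of a GPT-NeoX-style partial rotation on a single head vector. It is an illustration under assumptions, not the reference used in parity_ling_v2.py: the function name half_rope, the base of 10000, and the toy head_dim of 8 are all made up for the example.

```python
import math
from typing import List


def half_rope(x: List[float], position: int, rotary_dim: int, base: float = 10000.0) -> List[float]:
    """Apply GPT-NeoX-style RoPE to the first rotary_dim entries of one head vector."""
    if rotary_dim % 2 != 0 or rotary_dim > len(x):
        raise ValueError("rotary_dim must be even and no larger than the head dimension")
    half = rotary_dim // 2
    out = list(x)  # dimensions >= rotary_dim pass through unchanged
    for i in range(half):
        angle = position * base ** (-2.0 * i / rotary_dim)
        cos, sin = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]  # NeoX layout pairs dim i with dim i + half
        out[i] = x1 * cos - x2 * sin
        out[i + half] = x2 * cos + x1 * sin
    return out


head = [0.5, -1.0, 2.0, 0.25, 3.0, -2.0, 1.5, 0.75]  # toy head_dim = 8
rotated = half_rope(head, position=3, rotary_dim=4)   # partial_rotary_factor = 0.5
```

Only the first half of the vector is rotated; the second half is copied through, and the rotation preserves the norm of the rotated part.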

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

@copy-pr-bot

copy-pr-bot Bot commented May 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor

/ok to test a0e7913

@HuiyingLi
Contributor

Thank you @Hayden727 !
Could you please share the training log/curve for the yaml examples/llm_finetune/ling/ling_mini_2_0_squad.yaml ?
Appreciate it!

@Hayden727
Contributor Author

Hi @HuiyingLi — here is the training curve for examples/llm_finetune/ling/ling_mini_2_0_squad.yaml.

Setup

  • Container: nvcr.io/nvidia/nemo-automodel:26.04
  • Hardware: 2 × H100 80GB (CUDA_VISIBLE_DEVICES=0,1, NCCL_NVLS_ENABLE=0)
  • Run command: automodel examples/llm_finetune/ling/ling_mini_2_0_squad.yaml --nproc-per-node 2 --distributed.ep_size=2 --checkpoint.enabled=false
  • All 500 steps + 5 validation passes finished in ~9 minutes of wall-clock training.

Curve

[Figure: Ling-mini-2.0 LoRA SFT on SQuAD, 500 steps]

Numbers

| Metric | Value |
| --- | --- |
| Initial train loss | 12.21 |
| Final train loss (step 500) | 0.1004 |
| Final val loss (step 500) | 0.1091 |
| LR schedule | linear warmup 50 steps → cosine decay |
| Peak LR | 1.0e-04 |
| Min LR | 1.0e-06 |
| Per-rank peak GPU memory | ~21 GiB |
| Throughput (global) | 5,000 – 8,000 tokens/sec |

Convergence is clean: loss drops from ~12 to <0.5 by step 40, plateaus around 0.1, and val loss tracks train loss with no overfitting through 500 steps.
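
To make the schedule in the table concrete, here is a rough sketch of a linear-warmup plus cosine-decay curve with the same numbers (50 warmup steps, peak 1.0e-04, min 1.0e-06, 500 steps). The function lr_at_step is hypothetical, and the exact formula the framework's lr_scheduler uses (for example, whether warmup starts from 0) is an assumption here.

```python
import math


def lr_at_step(
    step: int,
    total_steps: int = 500,
    warmup_steps: int = 50,
    peak_lr: float = 1.0e-4,
    min_lr: float = 1.0e-6,
) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_step(s) for s in range(0, 501)]
```

In this sketch the learning rate reaches its peak at step 50, around where the loss has already dropped below 0.5, and decays smoothly for the remaining 450 steps.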

Full raw stdout log (213 KB, has every step's metrics) is at:
https://github.com/Hayden727/Automodel/blob/ling-v2-assets/mini_500.log

YAML changes I made before this run

While preparing this run I caught three issues in the example YAML, all pushed to the PR branch in commit a0e7913… (will follow up with a separate commit):

  1. validation_dataset.num_samples_limit: 128 → renamed to limit_dataset_samples: 128 to match make_squad_dataset's signature.
  2. distributed.ep_size: 1 → 4 (the new default, since the framework's HF state-dict adapter for grouped MoE experts requires an ep mesh dimension; users with a different GPU count override via --distributed.ep_size=N).
  3. Switched recipe default to attn: sdpa (works without TE installed) and added a lr_scheduler block with the standard cosine + warmup schedule.

Note on Ling-flash-2.0 (100 B) and Ling-1T

Same architecture, same code path — both arch-validated via per-variant tiny-config GPU smoke tests already in this PR. I was not able to run real-checkpoint training on this 8×H100 80GB box for either:

  • Ling-flash-2.0: needs ~200 GB checkpoint download + 4-8 GPUs free simultaneously; the host's other tenants kept grabbing GPUs mid-run. Feasible on a dedicated 8-GPU node.
  • Ling-1T: total bf16 weights ≈ 2 TB > 640 GB total GPU memory on this box — the model physically cannot be held even sharded. A real run would need ≥32 H100s.
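
The memory figures quoted in this thread follow from simple bf16 arithmetic (2 bytes per parameter). A quick sketch, treating Ling-1T as exactly 1e12 parameters, which is an approximation, and counting weights only (no activations, optimizer state, or framework overhead):

```python
import math


def bf16_weight_bytes(num_params: float) -> float:
    """bf16 stores each parameter in 2 bytes."""
    return num_params * 2


GB = 1e9
mini_gb = bf16_weight_bytes(16.26e9) / GB   # Ling-mini-2.0: 16.26B params loaded
ling_1t_gb = bf16_weight_bytes(1e12) / GB   # Ling-1T, approximated as 1e12 params
box_gb = 8 * 80                             # 8 x H100 80GB

fits_on_box = ling_1t_gb <= box_gb
# Lower bound on 80 GB GPUs needed for the weights alone.
min_gpus_for_weights = math.ceil(ling_1t_gb / 80)
```

The mini figure lines up with the 32.5 GB observed in the real-checkpoint smoke test, and the weights-only lower bound for Ling-1T sits below the ≥32 H100 estimate, which leaves room for everything else a real run needs.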

Happy to run them on dedicated hardware if NVIDIA has a slot, or wait for follow-up PRs that focus on the larger variants.

Hayden727 added a commit to Hayden727/Automodel that referenced this pull request May 17, 2026
While running the recipe end-to-end on H100 to produce the training curve
requested in the PR review, three issues surfaced in the example YAML:

- validation_dataset.num_samples_limit -> limit_dataset_samples
  (matches make_squad_dataset's actual signature; the prior name was a typo).
- distributed.ep_size: 1 -> 4 (with a CLI override note).
  The framework's HF state-dict adapter for grouped MoE experts requires an
  `ep` mesh dimension; without ep_size > 1 the to_hf path asserts.
- backend.attn: te -> sdpa as the default so the recipe runs in containers
  without TransformerEngine installed; TE remains an option via override.

Also add a lr_scheduler block with the standard cosine + warmup schedule,
tidy header comments to reflect tested 4-GPU layout, and gitignore local
training_logs/ so plots don't leak into commits.

Validation curve for this recipe is in the PR comment:
NVIDIA-NeMo#2255 (comment)

Signed-off-by: Hayden727 <hayden.cc@zohomail.com>
Hayden727 force-pushed the hayden727/feat/ling-v2 branch from 9617bbd to 67ca563 on May 18, 2026 01:35
@HuiyingLi
Contributor

/ok to test 67ca563

Contributor

@jgerh jgerh left a comment

Completed tech pubs review of .md files. No copyedits needed. LGTM.

@HuiyingLi
Contributor

We observe DeepEP timeout on Ling-1T SFT and LoRA runs

@Hayden727 Root cause for the Ling-1T DeepEP failure:

Ling-1T has hidden_size=8192, which crosses a DeepEP intranode TMA shared-memory limit. The failing DeepEP path checks:

half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp

For bf16 hidden size 8192, half_hidden_bytes is already 8192 bytes, so the extra 8-byte metadata pushes it to 8200 bytes. That exceeds the 8192-byte per-warp TMA budget. This is why smaller MoE models with hidden sizes like 4096 or 7168 work, but Ling-1T does not with dispatcher: deepep.
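
The arithmetic behind that check can be reproduced in a few lines. This mirrors only the inequality quoted above; the constant and function names are made up for the sketch, and it is not DeepEP code.

```python
K_NUM_TMA_BYTES_PER_WARP = 8192  # per-warp TMA budget quoted above
SIZEOF_UINT64 = 8                # extra metadata bytes
BF16_BYTES = 2                   # bytes per bf16 element


def fits_deepep_tma_budget(hidden_size: int) -> bool:
    """Evaluate half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp."""
    half_hidden_bytes = hidden_size * BF16_BYTES // 2
    return half_hidden_bytes + SIZEOF_UINT64 <= K_NUM_TMA_BYTES_PER_WARP


results = {h: fits_deepep_tma_budget(h) for h in (4096, 7168, 8192)}
```

Hidden sizes 4096 and 7168 pass the check, while 8192 misses it by exactly the 8 metadata bytes.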

The practical workaround is to use dispatcher: hybridep.

One additional tuning step was needed for the full SFT topology. With ep_size=64, pp_size=4, and NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=8, each EP group spans 8 nodes (64 EP ranks / 8 ranks per NVLink domain). HybridEP’s default multi-node backward combine config requests slightly more dynamic shared memory than H100 allows:

  • H100 opt-in shared memory limit: 232448 bytes
  • default HybridEP backward combine request: 233984 bytes

Setting:

NUM_OF_STAGES_G2S_COMBINE_API=4

reduces the HybridEP backward combine shared-memory request to about 200832 bytes, which fits under the H100 limit.
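
As a sanity check on those byte counts, the headroom against the H100 limit works out as follows (all three numbers are taken from this comment; the tuned figure is approximate, and the variable names are only for illustration):

```python
H100_OPT_IN_SMEM_LIMIT = 232448  # bytes

smem_requests = {
    "default": 233984,                          # default backward combine request
    "NUM_OF_STAGES_G2S_COMBINE_API=4": 200832,  # approximate request after tuning
}

# Positive headroom means the request fits; negative means it exceeds the limit.
headroom = {name: H100_OPT_IN_SMEM_LIMIT - req for name, req in smem_requests.items()}
```

The default request overshoots the limit by 1536 bytes, and the tuned setting leaves roughly 31 KB of slack.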

So the recommended Ling-1T config is:

  • use dispatcher: hybridep
  • keep rope_fusion: false
  • set NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=8
  • for multi-node EP64 SFT, also set NUM_OF_STAGES_G2S_COMBINE_API=4
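
One way to encode the environment-variable part of this recommendation is a small helper like the one below. This is a hypothetical sketch, not part of the framework: the function name hybridep_env is made up, and the rule "add NUM_OF_STAGES_G2S_COMBINE_API=4 whenever an EP group spans more than one NVLink domain" is an assumed generalization of the single EP64 case reported here.

```python
from typing import Dict


def hybridep_env(ep_size: int, ranks_per_nvlink_domain: int = 8) -> Dict[str, str]:
    """Build HybridEP env vars for a given expert-parallel size (illustrative only)."""
    if ep_size <= 0 or ranks_per_nvlink_domain <= 0:
        raise ValueError("ep_size and ranks_per_nvlink_domain must be positive")
    # Ceiling division: number of NVLink domains one EP group covers.
    domains_per_ep_group = -(-ep_size // ranks_per_nvlink_domain)
    env = {"NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN": str(ranks_per_nvlink_domain)}
    if domains_per_ep_group > 1:
        # Assumed rule: shrink the backward-combine shared-memory request
        # whenever the EP group crosses NVLink domains.
        env["NUM_OF_STAGES_G2S_COMBINE_API"] = "4"
    return env


single_domain_env = hybridep_env(ep_size=8)   # EP group stays inside one NVLink domain
multi_domain_env = hybridep_env(ep_size=64)   # EP group spans 8 NVLink domains
```

The dispatcher: hybridep and rope_fusion: false settings live in the recipe YAML and are not covered by this helper.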

@HuiyingLi
Contributor

@Hayden727
1T, 32 nodes, SFT: [image]

1T, ep8 pp8, 8 nodes: [image]

Hayden727 and others added 2 commits May 19, 2026 18:21
Per Huiying's runs (PR NVIDIA-NeMo#2255), DeepEP times out on Ling-1T because
hidden_size=8192 crosses the intranode TMA per-warp shared-memory budget
(half_hidden_bytes + 8 > 8192).  Switch the two 1T recipes to
dispatcher=hybridep and stamp the required HybridEP env vars on the
ci.env_vars block:

- ling_1t_lora_pp.yaml (64 GPUs, ep=8 pp=8):
    NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=8

- ling_1t_sft.yaml (256 GPUs, ep=64 pp=4):
    NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN=8
    NUM_OF_STAGES_G2S_COMBINE_API=4   # ep-group spans nodes; shrinks
                                       # backward-combine dyn-shared-mem
                                       # request under H100's 232448-byte (227 KiB) limit

Refs NVIDIA-NeMo#2255

Signed-off-by: Hayden727 <hayden.cc@zohomail.com>
@HuiyingLi
Contributor

/ok to test 97cf2b9

Contributor

@HuiyingLi HuiyingLi left a comment

LGTM, thank you so much!
