
perf(qwen3-next): set expandable_segments on GB300 BF16/FP8_MX to fix OOM #3767

Merged: ko3n1g merged 1 commit into r0.4.0 from ko3n1g/r0.4.0/fix-qwen3-next-gb300-oom on May 11, 2026

Conversation

@ko3n1g ko3n1g commented May 10, 2026

Claude summary

Context

The qwen3_next_80b_a3b_gb300_bf16_50steps_perf test (and its FP8_MX twin) in the nemo-ci pipeline OOMs at the first forward pass on the lyris cluster: RuntimeError: Triton Error [CUDA]: out of memory, raised from rank 52 during DDP setup. Reference failing job: https://gitlab-master.nvidia.com/dl/JoC/nemo-ci/-/jobs/314607989.

Cause hypothesis

The GB300 BF16/FP8_MX config is the only GB-series qwen3-next variant with micro_batch_size=4 (the GB200 sister config uses MBS=2; B200/B300/H100 use MBS=1). With HybridEP plus TE-scoped CUDA graphs (attn, moe_router, moe_preprocess), the allocator fragments when captured-graph memory pools and HybridEP scratch buffers interleave, exactly the same pattern as the existing H100 FP8_CS branch in perf_plugins.py.

Fix

Extend the existing qwen3-next H100 FP8_CS branch to also cover GB300 BF16 and FP8_MX with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. No model-config changes, no recompute, no MBS reduction — TFLOPs preserved.

 elif (
     model_family_name in ["qwen"]
     and model_recipe_name in ["qwen3_next_80b_a3b"]
     and train_task == "pretrain"
-    and gpu in ["h100"]
-    and compute_dtype == "fp8_cs"
+    and (
+        (gpu == "h100" and compute_dtype == "fp8_cs")
+        or (gpu == "gb300" and compute_dtype in ["bf16", "fp8_mx"])
+    )
 ):
     executor.env_vars["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
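
For context on why this is an environment variable rather than a config change: the PyTorch caching allocator parses PYTORCH_CUDA_ALLOC_CONF when CUDA is initialized, so the setting has to be in the process environment up front. A minimal standalone sketch (not the plugin code) of the required ordering:

    import os

    # Must be set before torch initializes CUDA, which is why the plugin
    # injects it via executor.env_vars instead of setting it at runtime.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch

    # With expandable segments the allocator backs its pools with growable
    # virtual-memory mappings instead of fixed-size cudaMalloc segments, so
    # interleaved captured-graph and scratch allocations can reuse freed
    # physical pages instead of fragmenting the heap.
    x = torch.empty(1 << 28, dtype=torch.bfloat16, device="cuda")  # ~0.5 GiB
    print(torch.cuda.memory_stats()["reserved_bytes.all.current"])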

Verification plan

  • Triage container nvcr.io/nvidian/nemo:dgxctestingtemp-nemofw-nightly.50761024 (the exact image the failing pipeline used).
  • nemo-ci pipeline TBD (run on lyris, single test case).
  • Pass criterion: training step 0 completes (i.e. no OOM at first forward).
  • If still OOM: fall through to recompute_modules=["core_attn"], then ["core_attn","moe_act"] (see the hypothetical sketch after this list).
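
For reference, the recompute fallback would look roughly like the following. This is a hypothetical sketch only, assuming Megatron-style selective-recompute knobs (recompute_granularity, recompute_modules) on the model config; it is not code from this PR.

    # Hypothetical fallback, NOT part of this PR: escalate selective recompute
    # if expandable_segments alone does not clear the OOM.
    def apply_recompute_fallback(model_cfg, attempt: int) -> None:
        model_cfg.recompute_granularity = "selective"  # assumed knob name
        if attempt == 1:
            model_cfg.recompute_modules = ["core_attn"]
        else:
            model_cfg.recompute_modules = ["core_attn", "moe_act"]

Either stage recomputes activations in the backward pass and costs throughput, which is why the allocator setting is tried first.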

… OOM

The qwen3_next_80b_a3b GB300 BF16/FP8_MX perf config OOMs at first forward
under MBS=4 with HybridEP and TE-scoped CUDA graphs (attn/moe_router/moe_preprocess).
The pattern matches the existing H100 FP8_CS path: NCCL/HybridEP buffer
allocations fragment the heap, and expandable_segments lets the allocator
reclaim physical memory without disabling any NCCL algorithms or reducing
the micro-batch size.

This preserves MBS=4 (and therefore TFLOPs) instead of falling back to
selective recompute or matching the GB200 BF16 sister config (MBS=2).

Verified hypothesis (CI run TBD): triage container nightly.50761024 +
anchor MCore + this MBridge HEAD on test
qwen3_next_80b_a3b_gb300_bf16_50steps_perf, cluster lyris.

Signed-off-by: oliver könig <okoenig@nvidia.com>

ko3n1g commented May 10, 2026

✅ Verified on triage container.

Probe pipeline (nemo-ci): pipeline 50829366 (job 314920873, cluster lyris, container dgxctestingtemp-nemofw-nightly.50761024, MBridge 1aeff19f7… = this PR HEAD, MCore 9f07e41148… from r26.04 pin).

Check | Result
--- | ---
OOM at first forward | ✅ gone — step 1/50 ran in 282s warmup; steady-state ~5.35s/iter
All 50 training steps | ✅ completed, EXIT_CODE_TRAINING=0
Perf | signed_diff +0.49%, within ±5% (current_avg_gpu_util 299.77 vs golden 298.29) — TFLOPs preserved
Memory | current_max_alloc 216.97 == golden 216.97 (identical)
Convergence |
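
Quick arithmetic check on the Perf row (which baseline the CI harness divides by is an assumption; either convention lands on the reported value):

    current, golden = 299.77, 298.29  # current_avg_gpu_util vs golden baseline
    print(f"{(current - golden) / golden:+.2%}")   # +0.50% relative to golden
    print(f"{(current - golden) / current:+.2%}")  # +0.49% relative to current

Both are far inside the ±5% gate.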

The leaf job is reported as failed only because the test config still carries allow_failure: true plus KNOWN_ISSUE_ID: 4034, so the reconciliation hook fires "Test is passing, but set to 'allowed to fail'." Once this PR is merged into r0.4.0 and the nemo-ci r26.04 MEGATRON_BRIDGE_COMMIT pin advances to include it, a follow-up nemo-ci MR will remove the KNOWN_ISSUE block and close issue dl/joc/nemo-ci#4034.

@ko3n1g ko3n1g requested a review from malay-nagda May 10, 2026 11:23
@ko3n1g ko3n1g merged commit eab9e5a into r0.4.0 May 11, 2026
53 checks passed
@ko3n1g ko3n1g deleted the ko3n1g/r0.4.0/fix-qwen3-next-gb300-oom branch May 11, 2026 07:13
