GroupedLinear performance on Mixtral-8x7B at EP=2 #2939

@faradawn

Description

Summary

`transformer_engine.pytorch.GroupedLinear` is ~2× slower than a Python loop of `F.linear` calls against the same stacked weight tensor on Mixtral-8x7B at EP=2, seq=8192, across batch sizes 1-8 (batch 16 OOMs on the grouped path). nsys traces show the cuBLAS grouped-matmul path firing ~6× more `nvjet_*` GEMM launches, plus `splitKreduce` and buffer-fill scaffolding. Filing as an investigation thread.

Context: tutorial work in #2642 (docs/examples/te_mixtral/).

Step-time gap

8× B300, BF16, EP=2 (DP=4, 4 experts/rank), seq=8192, 200 timed steps + 10 warmup. Same model, same stacked weights; only the per-expert FFN dispatch differs.

| Batch | Loop (`F.linear` × 4) | `GroupedLinear` | Ratio |
|------:|----------------------:|----------------:|------:|
| 1     | 240 ms                | 449 ms          | 1.87× |
| 2     | 242 ms                | 474 ms          | 1.96× |
| 4     | 242 ms                | 511 ms          | 2.11× |
| 8     | 306 ms                | 617 ms          | 2.02× |
| 16    | 506 ms                | OOM             | n/a   |

Profile (nsys, batch=2, BF16, 30 timed steps)

`nsys stats --report cuda_gpu_kern_sum` on `nsys_tier{2,3}_batch2.nsys-rep`:

| Kernel | Loop instances | GroupedLinear instances |
|---|---:|---:|
| `nvjet_sm103_tst_*` BF16 GEMM (sum) | 10,693 | 66,083 (6.2×) |
| `cublasLt::splitKreduce_kernel` | 0 | 48,823 |
| `FillFunctor<bf16>` (zero-init) | (not in top 15) | 138,296 |
| `CUDAFunctor_add` | (not in top 15) | 98,560 |
| Total GPU kernel time (capture) | ~70 s | ~138 s (1.99×) |

NCCL SendRecv is actually faster on the grouped path (37 s vs 46 s); the 2× wall-time gap comes entirely from the additional GPU compute kernels.
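
For anyone wanting to pull these counts out of a trace programmatically rather than eyeballing the report, a short script over nsys's CSV export works. This is a sketch: it assumes the `Instances` and `Name` column names of the `cuda_gpu_kern_sum` CSV output and a file name derived from `-o`; adjust to your export.

```python
# Sketch: aggregate per-kernel instance counts from an nsys CSV export, e.g.
#   nsys stats --report cuda_gpu_kern_sum --format csv -o tier3 nsys_tier3_batch2.nsys-rep
# (file/column names assumed from nsys's CSV output; adjust to your export).
import csv
from collections import Counter

PATTERNS = ("nvjet_", "splitKreduce", "FillFunctor", "CUDAFunctor_add")

def kernel_instances(path: str) -> Counter:
    counts = Counter()
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            for pat in PATTERNS:
                if pat in row["Name"]:
                    counts[pat] += int(row["Instances"].replace(",", ""))
    return counts

print(kernel_instances("tier3_cuda_gpu_kern_sum.csv"))
```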

Working hypothesis

cuBLAS Lt selects a split-K plan for (M ≈ 4096, K = 4096, N = 14336) grouped GEMMs and decomposes one grouped matmul into many sub-tile launches plus reduction + buffer fills. Plain F.linear at the same shape runs in a single non-split-K nvjet launch.
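
For a standalone comparison outside the tutorial, something like the following microbenchmark isolates the two dispatch strategies at the hypothesized shapes. It's a sketch, not the tutorial code: it assumes `GroupedLinear`'s `forward(inp, m_splits)` interface and its `weight{i}` parameter naming, and times forward only.

```python
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

E, M, K, N = 4, 4096, 4096, 14336  # experts/rank, tokens/expert, in_features, out_features
x = torch.randn(E * M, K, device="cuda", dtype=torch.bfloat16)

# TE modules place parameters on CUDA by default.
grouped = te.GroupedLinear(E, K, N, bias=False, params_dtype=torch.bfloat16)
# Reuse the grouped module's weights in the loop so both paths see identical
# parameters (weight{i} naming is an assumption about TE's parameter registration).
weights = [getattr(grouped, f"weight{i}") for i in range(E)]

def run_loop():
    return torch.cat([F.linear(x[i * M : (i + 1) * M], weights[i]) for i in range(E)])

def run_grouped():
    return grouped(x, m_splits=[M] * E)

with torch.no_grad():
    for name, fn in (("loop", run_loop), ("grouped", run_grouped)):
        for _ in range(10):  # warmup
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(100):
            fn()
        end.record()
        torch.cuda.synchronize()
        print(f"{name}: {start.elapsed_time(end) / 100:.2f} ms/iter")
```

An nsys capture around just this snippet should show whether the splitKreduce/FillFunctor scaffolding reproduces outside the full training step.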

Repro (in PR #2642)

```bash
cd docs/examples/te_mixtral

# Loop (Tier 2)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
    --improvement 2 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
    --warmup-steps 10 --train-steps 200

# GroupedLinear (Tier 3)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
    --improvement 3 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
    --warmup-steps 10 --train-steps 200
```

Code path: `te_mixtral.py:551-580` selects `expert_ffn_mode = "grouped" | "loop"`. Both paths use the same stacked weight tensor.
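
For readers without the PR checked out, the dispatch roughly amounts to the following. This is an illustrative paraphrase, not the tutorial's actual code; all names here are placeholders.

```python
from typing import Callable, List
import torch
import torch.nn.functional as F

def expert_ffn(sorted_tokens: torch.Tensor,       # tokens grouped by expert, [sum(m), K]
               tokens_per_expert: List[int],      # per-expert token counts
               stacked_weight: torch.Tensor,      # [num_experts, N, K]
               grouped_fc: Callable,              # e.g. a te.GroupedLinear instance
               mode: str) -> torch.Tensor:
    if mode == "grouped":
        # Tier 3: one fused grouped GEMM over all local experts' tokens.
        return grouped_fc(sorted_tokens, m_splits=tokens_per_expert)
    # Tier 2: one F.linear per local expert against its slice of the stacked weights.
    outs, offset = [], 0
    for i, n_tok in enumerate(tokens_per_expert):
        outs.append(F.linear(sorted_tokens[offset : offset + n_tok], stacked_weight[i]))
        offset += n_tok
    return torch.cat(outs)
```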

Environment

- 8× NVIDIA B300 SXM
- NGC pytorch-25.12-py3 container (CUDA 13.1)
- torch 2.10, transformer_engine 2.10
- Mixtral-8x7B v0.1, BF16 mixed precision

Questions

1. Is split-K the intended cuBLAS plan for (M ≈ 4k, K = 4k, N = 14k) grouped GEMMs?
2. Could `GroupedLinear` set a `cublasLtMatmulPreference` to discourage split-K when per-expert M is already large?
3. Has anyone benchmarked `GroupedLinear` against a per-expert `F.linear` loop at Mixtral FFN shapes?

Happy to share full nsys traces (~500 MB each) on request. Tagging for visibility.
