
TE EP MXFP8 fails with last_dim % MXFP8_BLOCK_SIZE == 0 assertion #2966

@faradawn

Description


Summary

transformer_engine.pytorch.ops.GroupedLinear fails with MXFP8BlockScaling in a Mixtral MoE expert-parallel run when per-expert token splits are not divisible by 32.

Repro context

  • Model: Mixtral-8x7B
  • GPUs: 8x B300
  • EP size: 2
  • Batch size: 8
  • Sequence length: 8192
  • Precision recipe: MXFP8BlockScaling
  • Expert FFN path: transformer_engine.pytorch.ops.GroupedLinear
  • Dispatcher: NCCL all-to-all
  • Tutorial/example PR: Add examples for MoE models - Mixtral in TE #2642

Repro command:

cd /lustre/fsw/coreai_prod_infbench/faradawny/TransformerEngine/docs/examples/te_mixtral

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 8 \
  --ep-size 2 \
  --batch-size 8 \
  --max-seq-length 8192 \
  --warmup-steps 5 \
  --train-steps 10 \
  2>&1 | tee logs/sweep_seq8k_ep2_8gpus_sequential_ops/seq8k_batch8_ep2_tier8_sequential_ops_mxfp8.log

Error

[rank0]:   File "/lustre/fsw/coreai_prod_infbench/faradawny/TransformerEngine/docs/examples/te_mixtral/te_mixtral.py", line 687, in _expert_ffn
[rank0]:     gate_up_output = self.experts_gate_up(tokens, split_sizes)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/op.py", line 522, in forward
[rank0]:     return OperationFuser([self])(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/basic/grouped_linear.py", line 739, in fuser_forward
[rank0]:     xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /workspace/TransformerEngine/transformer_engine/pytorch/csrc/quantizer.cpp:1668 in function get_scale_shape: Assertion failed: last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0. MXFP8 requires tensor dims that are divisible by 32 (got shape=(2283,4096))

Other ranks fail similarly with shapes like (1441,4096), (1178,4096), and (1225,4096).
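For reference, the failing check is the MXFP8 block-scaling requirement that both dimensions of the tensor handed to tex.split_quantize be divisible by 32. The hidden dimension (4096) satisfies this; it is the data-dependent per-expert token counts produced by the router that do not. A minimal sketch of the check, using the shapes reported above and MXFP8_BLOCK_SIZE = 32 taken from the assertion message:

MXFP8_BLOCK_SIZE = 32  # value checked by get_scale_shape in quantizer.cpp

hidden = 4096
for rows in (2283, 1441, 1178, 1225):  # per-expert token counts from the failing ranks
    ok = rows % MXFP8_BLOCK_SIZE == 0 and hidden % MXFP8_BLOCK_SIZE == 0
    print(f"shape=({rows},{hidden}) passes divisibility check: {ok}")

# hidden (4096) is divisible by 32, so only the router-dependent token dimension fails.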

Expected

The Sequential Ops grouped path should either pad each split to the MXFP8 block size internally, or clearly document the requirement and a recommended workaround for MoE routing where per-expert token counts are not multiples of 32 (a possible padding workaround is sketched below).
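As a possible interim workaround on the user side (not something TE does today, just a sketch under the assumption that zero rows appended to an expert's chunk can be discarded after the GEMM): pad each expert's chunk up to the next multiple of 32 before calling GroupedLinear, then strip the padded rows from the output. The helper names pad_splits_to_block and unpad_output below are hypothetical; the extra GEMM work on padded rows and the interaction with the backward pass are not addressed here.

import torch

MXFP8_BLOCK_SIZE = 32

def pad_splits_to_block(tokens, split_sizes, block=MXFP8_BLOCK_SIZE):
    # Pad each expert's token chunk with zero rows so its row count is a multiple of `block`.
    chunks = torch.split(tokens, split_sizes, dim=0)
    padded_chunks, padded_sizes = [], []
    for chunk, size in zip(chunks, split_sizes):
        pad_rows = (-size) % block
        if pad_rows:
            chunk = torch.cat([chunk, chunk.new_zeros(pad_rows, chunk.shape[-1])], dim=0)
        padded_chunks.append(chunk)
        padded_sizes.append(size + pad_rows)
    return torch.cat(padded_chunks, dim=0), padded_sizes

def unpad_output(output, padded_sizes, orig_sizes):
    # Drop the zero-padded rows from each expert's output chunk.
    chunks = torch.split(output, padded_sizes, dim=0)
    return torch.cat([c[:n] for c, n in zip(chunks, orig_sizes)], dim=0)

In _expert_ffn this would look roughly like:

padded_tokens, padded_sizes = pad_splits_to_block(tokens, split_sizes)
gate_up_output = self.experts_gate_up(padded_tokens, padded_sizes)
gate_up_output = unpad_output(gate_up_output, padded_sizes, split_sizes)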
