v2.14
Transformer Engine v2.14 Release Notes
Key Features and Enhancements
- [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
- [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
- [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
- [PyTorch] Added support for a single-parameter `GroupedLinear` configuration, where the weights of all experts are stored in a single parameter, which reduces CPU overheads. (#2731)
- [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter `GroupedLinear`. (#2761)
- [PyTorch] Extended the fused attention API to optionally return softmax `Stats` always and `Max` when `return_max_logit=True`, exposing more cuDNN intermediates to users. (#2677)
- [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
- [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weight in `FusedAdam`. (#2753)
- [PyTorch] Added CUDA graph-compatible `multi_tensor_scale_tensor` API in the optimizer. (#2594)
- [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)
- [PyTorch] Added support for non-FP32 `params_dtype` when using QK-normalization. (#2718)
- [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
- [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
- [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
- [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
- [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
- [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
- [C] Enabled the fused RMSNorm `dLN + add` backward path through cuDNN for faster fused-residual normalization. (#2778)
- [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
- [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
- [C/PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)
- [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
- [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
- [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
- [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
- [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
- [Documentation] Added documentation for the operator fuser API. (#2447)
- [PyTorch, Documentation] Added end-to-end examples for `fused_adam`, `quantized_model_init`, and FSDP2 usage. (#2698) (#2662)
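
For readers unfamiliar with the grouped GEMM entries above, the following is a minimal pure-Python sketch of the semantics only: the rows of a token matrix are split into consecutive groups by `group_sizes`, and each group is multiplied by its own expert weight. The names `matmul`, `grouped_gemm`, and `group_sizes` are hypothetical and are not Transformer Engine APIs; the actual kernels run on the GPU and, per the notes above, can read the group sizes from device memory.

```python
def matmul(a, b):
    """Multiply an (m x k) list-of-lists by a (k x n) list-of-lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_gemm(tokens, weights, group_sizes):
    """Reference semantics of a grouped GEMM (hypothetical helper, not a TE API).

    tokens:      (sum(group_sizes) x k) rows, already sorted by expert
    weights:     one (k x n) matrix per expert
    group_sizes: number of consecutive token rows routed to each expert
    """
    if len(weights) != len(group_sizes):
        raise ValueError("need one weight matrix per group")
    if sum(group_sizes) != len(tokens):
        raise ValueError("group sizes must cover every token row")
    out, start = [], 0
    for w, size in zip(weights, group_sizes):
        # An empty group contributes no output rows.
        out.extend(matmul(tokens[start:start + size], w))
        start += size
    return out


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: doubles its input
]
result = grouped_gemm(tokens, weights, group_sizes=[1, 2])
```

Here the first token goes through expert 0 and the remaining two through expert 1. Because the group sizes differ from step to step in MoE training, keeping them on the device avoids a host synchronization before the GEMM is launched.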
Fixed Issues
- [PyTorch] FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are `DTensor`s, ensure optimizer states are also `DTensor`s for correct sharded checkpoints. (#2795)
- [PyTorch] Fixed async DCP checkpointing for `Float8Tensor` parameters. (#2721)
- [PyTorch] Fixed the issue with `cross_entropy_forward` producing wrong answers for non-contiguous logits. (#2746)
- [PyTorch] Fixed the excessive memory usage issue when using operator fuser. (#2750)
- [PyTorch] Fixed a precision-debug-tools crash when `tp_group=None`. (#2733)
- [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
- [PyTorch] Fixed the initialization of the learnable `softmax_offset` parameter in `DotProductAttention` to zero-initialization. (#2694)
- [PyTorch] Fixed the error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
- [PyTorch] Added a clear error when constructing `LayerNormLinear` with row-wise tensor parallelism (an unsupported configuration). Previously this configuration would fail with a CUDA error. (#2688)
- [JAX] Fixed the performance issue with THD/BSHD segment-position generation. (#2823)
- [JAX] Fixed the assertion error when using `from_segment_ids_and_pos()` with `vmap`. (#2692)
- [JAX] Fixed the performance issue for models using both FSDP and EP. (#2649)
- [JAX] Changed the dtype of the intermediate-result aval in `fused_topk_and_score_function_fwd` to `fp32` to avoid precision loss. (#2752)
- [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)
- [C/PyTorch] Fixed score normalization in `fused_score_for_moe_aux_loss` when `topk == 1`. (#2720)
- [PyTorch] Fixed the possible precision loss when copying from the quantized tensor to the high precision tensor. (#2120, #2673)
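
To illustrate the class of bug behind the non-contiguous logits fix above, here is a generic pure-Python sketch of how a cross-entropy routine can return wrong values when it assumes rows are packed back to back. This is not Transformer Engine code, and the release notes do not describe the actual cause in #2746; the helper `cross_entropy_row` and the example buffer are assumptions made for illustration.

```python
import math


def cross_entropy_row(buffer, offset, stride, num_classes, target):
    """Cross-entropy of one row of logits read from a flat buffer."""
    logits = [buffer[offset + j * stride] for j in range(num_classes)]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# A 2 x 4 row-major buffer. The logits of interest are the view [:, :2]:
# each row has 2 classes, but consecutive rows are 4 elements apart.
buffer = [2.0, 0.0, 9.0, 9.0,
          0.0, 2.0, 9.0, 9.0]
num_rows, num_classes, row_stride = 2, 2, 4
targets = [0, 1]

# Correct: honor the row stride of the view.
strided = [cross_entropy_row(buffer, i * row_stride, 1, num_classes, targets[i])
           for i in range(num_rows)]

# Incorrect: assume rows are contiguous (row stride == num_classes).
packed = [cross_entropy_row(buffer, i * num_classes, 1, num_classes, targets[i])
          for i in range(num_rows)]
```

The two lists agree on the first row, but the packed read of the second row picks up the padding values instead of the real logits, so its loss differs. On releases without the fix, calling `.contiguous()` on the logits before the loss is the usual way to sidestep this kind of issue.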
Breaking Changes in This Release
- [JAX] GSPMD partitioning rules are no longer tested and will now warn on use; users on JAX with GSPMD should migrate to Shardy. (#2702)
Deprecated Features
There are no deprecated features in this release.