v2.14

@ksivaman ksivaman released this 21 Apr 21:57
· 2 commits to release_v2.14 since this release

Transformer Engine v2.14 Release Notes

Key Features and Enhancements

  • [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
  • [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
  • [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
  • [PyTorch] Added support for a single-parameter GroupedLinear configuration, where the weights of all experts are stored in a single parameter, which reduces CPU overheads. (#2731)
  • [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)
  • [PyTorch] Extended the fused attention API to always return the softmax Stats and to also return Max when return_max_logit=True, exposing more cuDNN intermediates to users. (#2677)
  • [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
  • [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weight in FusedAdam. (#2753)
  • [PyTorch] Added CUDA graph-compatible multi_tensor_scale_tensor API in the optimizer. (#2594)
  • [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)
  • [PyTorch] Added support for non-FP32 params_dtype when using QK-normalization. (#2718)
  • [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
  • [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
  • [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
  • [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
  • [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
  • [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
  • [C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)
  • [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
  • [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
  • [C, PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)
  • [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
  • [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
  • [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
  • [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
  • [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
  • [Documentation] Added documentation for the operator fuser API. (#2447)
  • [PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
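
For readers unfamiliar with the grouped GEMM operation that several of the MoE items above refer to, here is a minimal pure-Python sketch of its reference semantics: the rows of the input are split into consecutive groups according to `group_sizes`, and each group is multiplied by its own weight matrix. The names `matmul` and `grouped_gemm_reference` are hypothetical and exist only for this illustration. This is not the Transformer Engine API or kernel, and it keeps the group sizes in a host-side list rather than on the device.

```python
def matmul(a, b):
    """Plain matrix product of two row-major lists of lists."""
    inner = len(b)
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def grouped_gemm_reference(x, weights, group_sizes):
    """Multiply consecutive row groups of x by per-group weight matrices.

    x:           list of rows (tokens), already sorted by group (expert)
    weights:     one [in_features][out_features] matrix per group
    group_sizes: number of rows of x that belong to each group
    """
    if len(weights) != len(group_sizes):
        raise ValueError("need exactly one weight matrix per group")
    if sum(group_sizes) != len(x):
        raise ValueError("group sizes must add up to the number of rows")

    out = []
    start = 0
    for w, size in zip(weights, group_sizes):
        if size:  # an empty group contributes no output rows
            out.extend(matmul(x[start:start + size], w))
        start += size
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # group 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # group 1: scale by 2
]
result = grouped_gemm_reference(x, weights, group_sizes=[1, 2])
print(result)  # [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

An optimized implementation would typically run all groups in a single kernel launch instead of looping on the host. The loop here only pins down what the result should be.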

Fixed Issues

  • [PyTorch] Fixed sharded checkpointing with FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are DTensors, optimizer states are now also DTensors. (#2795)
  • [PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)
  • [PyTorch] Fixed cross_entropy_forward producing incorrect results for non-contiguous logits. (#2746)
  • [PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)
  • [PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)
  • [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
  • [PyTorch] Fixed the initialization of the learnable softmax_offset parameter in DotProductAttention so that it is zero-initialized. (#2694)
  • [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
  • [PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism (an unsupported configuration). Previously, this configuration failed with a CUDA error. (#2688)
  • [JAX] Fixed a performance issue with THD/BSHD segment-position generation. (#2823)
  • [JAX] Fixed an assertion error when using from_segment_ids_and_pos() with vmap. (#2692)
  • [JAX] Fixed a performance issue for models using both FSDP and EP. (#2649)
  • [JAX] Changed the dtype of the intermediate-result aval in fused_topk_and_score_function_fwd to fp32 to avoid precision loss. (#2752)
  • [C] Fixed an incorrect MNNVL fabric-availability check that misreported support on some systems. (#2626)
  • [C, PyTorch] Fixed score normalization in fused_score_for_moe_aux_loss when topk == 1. (#2720)
  • [PyTorch] Fixed possible precision loss when copying from a quantized tensor to a high-precision tensor. (#2120) (#2673)
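
As background for the non-contiguous logits fix above, the following pure-Python sketch illustrates the general class of bug: reading a strided view as if it were contiguous yields different logits and therefore a different loss. It does not reproduce the actual issue in cross_entropy_forward (see #2746 for that). The helpers `element` and `cross_entropy_row` and the sample data are hypothetical.

```python
import math


def element(buffer, strides, i, j):
    """Read logical element (i, j) of a 2-D view described by strides."""
    return buffer[i * strides[0] + j * strides[1]]


def cross_entropy_row(logits, target):
    """Numerically stable cross-entropy for one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


# The storage holds a 3x2 row-major matrix. The logits are its transpose,
# i.e. a 2x3 view with strides (1, 2) rather than the contiguous (3, 1).
buffer = [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
shape = (2, 3)
view_strides = (1, 2)

# Correct: honour the view's strides when gathering row 0.
row_strided = [element(buffer, view_strides, 0, j) for j in range(shape[1])]

# Incorrect: assume the 2x3 view is contiguous, i.e. strides (3, 1).
row_assumed = [element(buffer, (shape[1], 1), 0, j) for j in range(shape[1])]

loss_strided = cross_entropy_row(row_strided, target=2)
loss_assumed = cross_entropy_row(row_assumed, target=2)

print(row_strided, row_assumed)    # [0.0, 1.0, 2.0] [0.0, 3.0, 1.0]
print(loss_strided, loss_assumed)  # the two losses differ
```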

Breaking Changes in This Release

  • [JAX] GSPMD partitioning rules are no longer tested and will now warn on use; users on JAX with GSPMD should migrate to Shardy. (#2702)
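
For the GSPMD to Shardy migration above, a minimal sketch of opting in to Shardy is shown here. It assumes that the installed JAX version exposes the `jax_use_shardy_partitioner` config option; recent JAX releases may already enable Shardy by default, so check the documentation for your version. The wrapper name `enable_shardy_partitioner` is hypothetical and is not part of Transformer Engine.

```python
def enable_shardy_partitioner():
    """Ask JAX to use the Shardy partitioner instead of GSPMD (sketch)."""
    # Imported lazily so that merely defining this helper does not
    # require JAX to be installed.
    import jax

    # Assumption: this JAX version supports the
    # `jax_use_shardy_partitioner` config option.
    jax.config.update("jax_use_shardy_partitioner", True)
```

Such a helper would typically be called once at program start, before any function is traced or compiled.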

Deprecated Features

There are no deprecated features in this release.