v2.16

@ksivaman released this 09 Jun 01:15

Transformer Engine v2.16 Release Notes

Key Features and Enhancements

  • [Common] Improved the performance of the split-overlap reduce-scatter GEMMs. (#2056)
  • [Common] Improved the performance of the fused MoE auxiliary loss kernel for models with a large number of experts. (#2758)
  • [Common] Optimized MXFP8 and NVFP4 dequantize kernels for improved performance. (#2865)
  • [Common] Improved performance of the MXFP8 quantization kernels. (#2958)
  • [PyTorch] Added pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) attention. (#2596)
  • [PyTorch] Added role-based custom quantization control, enabling recipes to target specific modules and tensor types. (#2620)
  • [PyTorch] Added end-to-end Mixtral MoE examples showing TE GroupedLinear integration with HuggingFace models for BF16 and FP8 training. (#2642)
  • [PyTorch] Improved the performance of the CPU activation offloading path in some cases. (#2793)
  • [PyTorch] Reduced the CPU overhead of the GroupedLinear module and operation. (#2900) (#2957) (#2666)
  • [PyTorch] Added CUDA Graph capture support for GroupedLinear and grouped MoE operations on supported configurations. (#2923)
  • [PyTorch] Added FlashAttention 4 support for attention head dimension 256. (#2932)
  • [JAX] Improved MoE permutation kernel performance. (#2975)
  • [JAX] Improved JAX tutorial documentation with updated examples and guidance. (#2976)
  • [Common, PyTorch] Added bias and dbias support for GroupedLinear layers. (#2885)
  • [Common, PyTorch] Added variable grouped swizzle support for flexible grouped tensor memory layouts. (#2914)
  • [Common, PyTorch] Implemented a row-scaled NVFP4 forward propagation recipe. (#2931)
  • [Common, PyTorch] Expanded grouped GEMM support with NVFP4 on Blackwell and FP8 block scaling on Hopper. (#2971)
  • [Common, JAX] Added a top-k operation for faster MoE routing. (#2890)
  • [Common, JAX] Enabled the cuDNN fused attention backend for no-mask bidirectional sliding-window attention. (#2961)
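
For readers unfamiliar with the routing step that the top-k operation (#2890) accelerates, the following is a minimal pure-Python sketch of what top-k MoE routing computes for a single token. It is an illustration only, not the Transformer Engine kernel or its API: the function name `topk_route`, the tie-breaking rule, and the choice to apply softmax over the selected logits are all assumptions made for this example.

```python
import heapq
import math


def topk_route(logits, k):
    """Return (expert_indices, weights) for one token's router logits.

    Hypothetical helper for illustration; not part of Transformer Engine.
    """
    if not 0 < k <= len(logits):
        raise ValueError("k must be in [1, num_experts]")
    # Pick the k largest logits; ties resolve to the lower expert index.
    top = heapq.nlargest(k, range(len(logits)), key=lambda i: (logits[i], -i))
    # Softmax over the selected logits only, shifted for numerical stability.
    shift = max(logits[i] for i in top)
    exps = [math.exp(logits[i] - shift) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]


if __name__ == "__main__":
    experts, weights = topk_route([0.1, 2.0, -1.0, 2.0], k=2)
    print(experts, weights)
```

A fused kernel performs the same selection for every token in a batch at once, which is where the routing speedup would come from.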

Fixed Issues

  • [PyTorch] Fixed variable-length attention cache reuse across devices and inference/training modes. (#2728)
  • [PyTorch] Fixed FSDP2 memory leaks for FP8 weight workspaces and transpose caches. (#2805)
  • [PyTorch] Fixed TE fuser behavior in torch.no_grad() paths by avoiding invalid gradient-flag updates on non-leaf tensors. (#2919)
  • [PyTorch] Fixed distributed checkpoint loading for FSDP2 for models initialized with QuantizedModelInit. (#2974)
  • [Common, PyTorch] Fixed cuBLAS grouped GEMM when weight dimensions are not divisible by 128. (#2954)
  • [Common, PyTorch] Fixed int32 overflow and -1 sentinel value handling in moe_permute. (#2907)
  • [Common, PyTorch] Fixed context-parallel FlashAttention output handling when FA3 is installed without FA2. (#2825)
  • [Common, PyTorch] Disabled RHT quantization fusion on unsupported GPU architectures to avoid launch failures. (#2968)
  • [PyTorch] Fixed a crash caused by GroupedLinear weight-gradient allocation. (#3049)
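
To make the moe_permute fix (#2907) more concrete, here is a small pure-Python sketch of the two failure classes it names: flat offsets that no longer fit in a signed 32-bit integer, and `-1` entries that must be skipped rather than used as indices. This is not the Transformer Engine implementation. The helper names are hypothetical, and the sketch assumes that `-1` marks a token with no destination row and that the remaining entries form a permutation of the output rows.

```python
INT32_MAX = 2**31 - 1


def needs_int64_indexing(num_tokens, hidden_size):
    """True if the largest flat element offset exceeds the int32 range."""
    return num_tokens * hidden_size - 1 > INT32_MAX


def permute_rows(rows, dest):
    """Scatter rows to their destination slots, skipping -1 entries.

    dest[i] is the output row for input row i, or -1 if the token is dropped
    (an assumption for this sketch). Non-negative entries must be a
    permutation of 0..num_out-1.
    """
    num_out = sum(1 for d in dest if d >= 0)
    out = [None] * num_out
    for row, d in zip(rows, dest):
        if d < 0:  # sentinel: token was not routed, so do not write it
            continue
        out[d] = row
    return out


if __name__ == "__main__":
    # 1,000,000 tokens x 4096 hidden elements is about 4.1e9 offsets.
    print(needs_int64_indexing(1_000_000, 4096))
    print(permute_rows(["a", "b", "c"], [1, -1, 0]))
```

Without the sentinel check, a `-1` index in Python would silently write to the last row; in a GPU kernel the equivalent mistake is an out-of-bounds access.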

Breaking Changes in This Release

  • [Common, PyTorch] The original FP8 delayed-scaling fused attention path has been removed. FP8 attention now uses the current cuDNN-backed implementation. (#2959)
  • [Common, PyTorch, JAX] Removed the legacy f16_max512 fused-attention backend. BF16/FP16 attention is routed through the maintained arbitrary-sequence backend, but explicit selections of the old backend must be updated. (#2949)
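
As a migration aid for the f16_max512 removal (#2949), the sketch below scans an environment for an explicit selection of the old backend. It rests on an assumption that should be verified against the Transformer Engine documentation for your version: that earlier releases selected the fused-attention sub-backend through the `NVTE_FUSED_ATTN_BACKEND` environment variable, with the value `0` meaning the max-512 backend. The function name is hypothetical, and the behavior of v2.16 when that value is still set is not stated in these notes.

```python
import os

# Assumption: earlier releases selected fused-attention sub-backends through
# this environment variable, with "0" meaning the max-512 backend.
ENV_VAR = "NVTE_FUSED_ATTN_BACKEND"
LEGACY_VALUE = "0"


def find_legacy_backend_selection(environ=None):
    """Return a warning string if the legacy backend is pinned, else None.

    Hypothetical helper for illustration; not part of Transformer Engine.
    """
    env = os.environ if environ is None else environ
    if env.get(ENV_VAR, "").strip() == LEGACY_VALUE:
        return (
            f"{ENV_VAR}={LEGACY_VALUE} pins the removed f16_max512 backend; "
            "unset it or choose a maintained backend."
        )
    return None


if __name__ == "__main__":
    message = find_legacy_backend_selection()
    if message:
        print(message)
```

Running a check like this in launch scripts or CI before upgrading can surface stale settings early; selections made through code or configuration files would need a separate search.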

Deprecated Features

There are no deprecated features in this release.