Transformer Engine v2.16 Release Notes
Key Features and Enhancements
- [Common] Improved the performance of the split-overlap reduce-scatter GEMMs. (#2056)
- [Common] Improved the fused MoE auxiliary loss kernel performance for models with a large number of experts. (#2758)
- [Common] Optimized MXFP8 and NVFP4 dequantize kernels for improved performance. (#2865)
- [Common] Improved performance of the MXFP8 quantization kernels. (#2958)
- [PyTorch] Added pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) attention. (#2596)
- [PyTorch] Added role-based custom quantization control, enabling recipes to target specific modules and tensor types. (#2620)
- [PyTorch] Added end-to-end Mixtral MoE examples showing TE GroupedLinear integration with HuggingFace models for BF16 and FP8 training. (#2642)
- [PyTorch] Improved the performance of the CPU activation offloading path in some cases. (#2793)
- [PyTorch] Reduced the CPU overhead in the GroupedLinear module and operation. (#2900) (#2957) (#2666)
- [PyTorch] Added CUDA Graph capture support for GroupedLinear and grouped MoE operations on supported configurations. (#2923)
- [PyTorch] Added FlashAttention 4 support for attention head dimension 256. (#2932)
- [JAX] Improved MoE permutation kernel performance. (#2975)
- [JAX] Improved JAX tutorial documentation with updated examples and guidance. (#2976)
- [Common, PyTorch] Added bias and dbias support for GroupedLinear layers. (#2885)
- [Common, PyTorch] Added variable grouped swizzle support for flexible grouped tensor memory layouts. (#2914)
- [Common, PyTorch] Implemented a row-scaled NVFP4 forward propagation recipe. (#2931)
- [Common, PyTorch] Expanded grouped GEMM support with NVFP4 on Blackwell and FP8 block scaling on Hopper. (#2971)
- [Common, JAX] Added a top-k operation for faster MoE routing. (#2890)
- [Common, JAX] Enabled the cuDNN fused attention backend for no-mask bidirectional sliding-window attention. (#2961)
Fixed Issues
- [PyTorch] Fixed variable-length attention cache reuse across devices and inference/training modes. (#2728)
- [PyTorch] Fixed FSDP2 memory leaks for FP8 weight workspaces and transpose caches. (#2805)
- [PyTorch] Fixed TE fuser behavior in torch.no_grad() paths by avoiding invalid gradient-flag updates on non-leaf tensors. (#2919)
- [PyTorch] Fixed distributed checkpoint loading for FSDP2 for models initialized with QuantizedModelInit. (#2974)
- [Common, PyTorch] Fixed cuBLAS grouped GEMM when weight dimensions are not divisible by 128. (#2954)
- [Common, PyTorch] Fixed int32 overflow and -1 sentinel value handling in moe_permute. (#2907)
- [Common, PyTorch] Fixed context-parallel FlashAttention output handling when FA3 is installed without FA2. (#2825)
- [Common, PyTorch] Disabled RHT quantization fusion on unsupported GPU architectures to avoid launch failures. (#2968)
- [PyTorch] Fixed a crash caused by GroupedLinear weight-gradient allocation. (#3049)
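
The class of bug behind the moe_permute fix (#2907) can be illustrated with a short, standard-library-only sketch. It is not the library's code: the helper names as_int32 and row_offsets are hypothetical, and treating -1 as "token not routed" is an assumption made for the example.

```python
import ctypes

INT32_MAX = 2**31 - 1


def as_int32(value):
    """Wrap an integer the way a 32-bit signed offset would wrap."""
    return ctypes.c_int32(value).value


def row_offsets(routing_map, hidden_size):
    """Turn destination rows into element offsets (hypothetical helper).

    Entries equal to -1 are treated as a sentinel (assumed to mean the
    token is not routed) and yield None instead of a negative offset.
    """
    offsets = []
    for dest_row in routing_map:
        if dest_row < 0:
            offsets.append(None)
            continue
        # Python integers are arbitrary precision, so this cannot wrap.
        offsets.append(dest_row * hidden_size)
    return offsets


# An offset for row 600,000 with 4,096 columns exceeds the int32 range.
wide = 600_000 * 4096
wrapped = as_int32(wide)  # what a 32-bit signed index would hold
offsets = row_offsets([0, -1, 600_000], hidden_size=4096)
```

In a kernel, the equivalent remedy is typically to compute such offsets in a 64-bit type and to check for the sentinel before using a row index.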
Breaking Changes in This Release
- [Common, PyTorch] The original FP8 delayed-scaling fused attention path has been removed. FP8 attention now uses the current cuDNN-backed implementation. (#2959)
- [Common, PyTorch, JAX] Removed the legacy f16_max512 fused-attention backend. BF16/FP16 attention is routed through the maintained arbitrary-sequence backend, but explicit selections of the old backend must be updated. (#2949)
Deprecated Features
There are no deprecated features in this release.