v2.7
Release v2.7
Key Features and Enhancements
- [PyTorch] Added support for applying LayerNorm and RMSNorm to key and query tensors.
- [PyTorch] Improved performance for the FP8 per-tensor current scaling recipe by fusing amax computation into the activation kernel.
- [PyTorch] Added support for multi-tensor swizzle kernels for MXFP8 grouped GEMMs.
- [PyTorch] Fused zero-padding and swizzle operation for MXFP8 scale inverses for improved performance.
- [PyTorch] Expanded the debug API using nvdlfw-inspect in order to log more advanced tensor statistics.
- [PyTorch] Reduced the number of calls to the CUDA driver for improved performance of the core library.
- [Jax] Added new checkpointing policies that allow users to switch to TE GEMMs seamlessly without unnecessary recomputations.
- [Core] Added support for the cuBLASMp backend for overlapping TP communication and GEMM.
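The amax fusion mentioned above for the FP8 per-tensor current scaling recipe can be illustrated with a minimal pure-Python sketch. This is not the Transformer Engine kernel: the function names (`unfused`, `fused`, `current_scale`) and the choice of GELU as the activation are assumptions made for illustration only. The idea shown is that the absolute maximum of the output is tracked while the activation is being computed, so no second pass over the output is needed before deriving the scale.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, used here only as an example activation."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def unfused(xs):
    """Two passes: apply the activation, then scan the output for its amax."""
    ys = [gelu(x) for x in xs]
    amax = max((abs(y) for y in ys), default=0.0)
    return ys, amax


def fused(xs):
    """One pass: track the running amax while producing each output element."""
    ys, amax = [], 0.0
    for x in xs:
        y = gelu(x)
        ys.append(y)
        amax = max(amax, abs(y))
    return ys, amax


def current_scale(amax: float) -> float:
    """Per-tensor scale that maps amax onto the FP8 range (1.0 for all-zero input)."""
    return FP8_E4M3_MAX / amax if amax > 0.0 else 1.0
```

Both variants return the same outputs and the same amax; the fused form simply avoids re-reading the activation output, which is where a GPU kernel would save memory traffic.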
Fixed Issues
- [PyTorch] Fixed a potential illegal memory access when using TP overlap.
- [PyTorch] Fixed the logic for choosing the correct attention backend depending on the cuDNN version.
- [PyTorch] Fixed a crash when using CUDA graphs by disabling garbage collection during capture.
- [PyTorch] Fixed a bug when using double buffering for CPU offloading.
- [PyTorch] Fixed a bug when overlapping gradient reduction and fusing weight gradient accumulation simultaneously.
- [PyTorch] Made multiple improvements and fixes to TE sequential API, including expanding supported operations to cover dropout, constant scale, etc.
- [PyTorch] Fixed a bug in the make_graphed_callables function when applied to multiple modules with different input requirements.
- [PyTorch] Fixed the crash in the permute operation when running with the FP8 datatype for input sizes requiring padding.
- [PyTorch] Fixed a bug when using the Triton cross entropy kernel with CUDA graphs.
- [PyTorch] Fixed a bug when exporting an MXFP8 model to ONNX.
- [PyTorch/Core] Disabled the cuDNN attention backend for cuDNN v9.12 onwards on Blackwell if the user requests a deterministic config.
- [Core] Fixed integer overflow in quantization kernels when computing offsets for large tensors.
- [Jax] Fixed partition rules for GEMM to correctly handle sequence parallelism.
- [Jax] Fixed sharding specs for TE GEMM custom call operands when using DP.
- [Jax] Fixed a crash when using GroupedQuantizeFFI with CUDA graphs.
- [Jax] Fixed the fused_attn sharding constraint so that it can be used under the JAX shard_map.
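The CUDA graphs fix listed above works by disabling garbage collection during capture. A minimal sketch of that general pattern, using only the Python standard library, might look like the following. The names `gc_disabled` and `capture` are hypothetical, and `capture` is a stand-in for a real graph-capture region rather than Transformer Engine's implementation.

```python
import contextlib
import gc


@contextlib.contextmanager
def gc_disabled():
    """Temporarily disable the cyclic garbage collector, then restore its prior state."""
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        # Only re-enable if the collector was on before we entered.
        if was_enabled:
            gc.enable()


def capture(fn, *args):
    """Hypothetical stand-in for a capture region: run fn with the collector off."""
    with gc_disabled():
        return fn(*args)
```

Restoring the previous state, instead of unconditionally calling `gc.enable()`, keeps the helper safe to use in programs that have already turned the collector off.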
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
The deprecated device_id argument for multi-tensor C APIs has been removed.
Deprecated Features
There are no deprecated features in this release.