
v2.7


ptrendx released this 01 Oct 00:20

Release v2.7

Key Features and Enhancements

  • [PyTorch] Added support for applying LayerNorm and RMSNorm to key and query tensors.
  • [PyTorch] Improved performance of the FP8 per-tensor current scaling recipe by fusing the amax computation into the activation kernel.
  • [PyTorch] Added support for multi-tensor swizzle kernels for MXFP8 grouped GEMMs.
  • [PyTorch] Fused the zero-padding and swizzle operations for MXFP8 scale inverses to improve performance.
  • [PyTorch] Expanded the debug API based on nvdlfw-inspect to log more advanced tensor statistics.
  • [PyTorch] Reduced the number of calls to the CUDA driver to improve the performance of the core library.
  • [Jax] Added new checkpointing policies that let users switch to TE GEMMs seamlessly without unnecessary recomputation.
  • [Core] Added support for the cuBLASMp backend for overlapping tensor-parallel (TP) communication with GEMM.
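As a rough illustration of the first item, the sketch below applies RMSNorm to one attention head's query and key vectors before the scaled dot product, which is the usual "QK-norm" formulation. It is plain Python, not Transformer Engine code; the helper names and the choice to normalize over the head dimension are assumptions, and TE's actual module arguments are not shown.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm over one head's vector: x / sqrt(mean(x^2) + eps) * gamma."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gamma)]

def qk_norm_score(q, k, gamma_q, gamma_k):
    """Scaled dot-product logit after normalizing query and key (hypothetical helper)."""
    qn = rms_norm(q, gamma_q)
    kn = rms_norm(k, gamma_k)
    return sum(a * b for a, b in zip(qn, kn)) / math.sqrt(len(q))

if __name__ == "__main__":
    ones = [1.0] * 4
    q, k = [1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]
    # Scaling q by a large factor barely changes the logit once q is normalized.
    print(qk_norm_score(q, k, ones, ones))
    print(qk_norm_score([1000.0 * v for v in q], k, ones, ones))
```

Because the normalized logit is nearly independent of the raw magnitudes of q and k, QK normalization is commonly used to keep attention logits from growing without bound during training.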
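The amax fusion for per-tensor current scaling can be pictured with the following plain-Python sketch. With current scaling, the FP8 scale is derived from the absolute maximum (amax) of the tensor being quantized right now, so an unfused implementation writes the activation output and then reads it back to find amax, while a fused variant tracks amax inside the activation loop. The GELU (tanh approximation) activation, the E4M3 maximum of 448, and all function names are assumptions for illustration, and the actual FP8 cast is omitted.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3

def gelu(x):
    """GELU, tanh approximation."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def activation_then_amax(xs):
    """Unfused: one pass computes the activation, a second pass re-reads it for amax."""
    ys = [gelu(x) for x in xs]
    amax = max((abs(y) for y in ys), default=0.0)
    return ys, amax

def activation_with_fused_amax(xs):
    """Fused: amax is accumulated in the same pass that computes the activation."""
    ys, amax = [], 0.0
    for x in xs:
        y = gelu(x)
        ys.append(y)
        amax = max(amax, abs(y))
    return ys, amax

def current_scale(amax, fp8_max=E4M3_MAX):
    """Per-tensor current scale: map this tensor's amax onto the FP8 range."""
    return fp8_max / amax if amax > 0.0 else 1.0

if __name__ == "__main__":
    xs = [-3.0, -0.5, 0.0, 1.5, 4.0]
    ys, amax = activation_with_fused_amax(xs)
    scale = current_scale(amax)
    # Every scaled value now fits in the E4M3 range before the (omitted) FP8 cast.
    print(amax, max(abs(y * scale) for y in ys))
```

Both variants return the same values; the fused one simply avoids a second read of the activation output, which is where the performance gain in a GPU kernel would come from.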

Fixed Issues

  • [PyTorch] Fixed a potential illegal memory access when using TP overlap.
  • [PyTorch] Fixed the logic for choosing the correct attention backend depending on the cuDNN version.
  • [PyTorch] Fixed a crash when using CUDA graphs by disabling garbage collection during capture.
  • [PyTorch] Fixed a bug when using double buffering for CPU offloading.
  • [PyTorch] Fixed a bug that occurred when overlapping gradient reduction and fusing weight gradient accumulation at the same time.
  • [PyTorch] Made multiple improvements and fixes to the TE sequential API, including support for additional operations such as dropout and constant scaling.
  • [PyTorch] Fixed a bug in the make_graphed_callables function when applied to multiple modules with different input requirements.
  • [PyTorch] Fixed a crash in the permute operation when running with the FP8 datatype on input sizes that require padding.
  • [PyTorch] Fixed a bug when using the Triton cross entropy kernel with CUDA graphs.
  • [PyTorch] Fixed a bug when exporting an MXFP8 model to ONNX.
  • [PyTorch/Core] Disabled the cuDNN attention backend on Blackwell for cuDNN v9.12 and later when the user requests a deterministic configuration.
  • [Core] Fixed integer overflow in quantization kernels when computing offsets for large tensors.
  • [Jax] Fixed partition rules for GEMM to correctly handle sequence parallelism.
  • [Jax] Fixed sharding specs for TE GEMM custom call operands when using DP.
  • [Jax] Fixed a crash when using GroupedQuantizeFFI with CUDA graphs.
  • [Jax] Fixed the fused_attn sharding constraint so that it can be used under JAX shard_map.
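One of the fixes above avoids a crash under CUDA graphs by disabling garbage collection during capture. A minimal sketch of that general pattern in Python, assuming the capture runs inside a `with` block, is a context manager that turns off the cyclic collector and restores its previous state afterwards. The `gc_disabled` name is hypothetical and is not a Transformer Engine API.

```python
import contextlib
import gc

@contextlib.contextmanager
def gc_disabled():
    """Disable Python's cyclic garbage collector for the duration of the block.

    The previous state is restored on exit, including when the block raises.
    """
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
```

It could wrap a PyTorch capture region, for example `with gc_disabled(), torch.cuda.graph(g): ...`. Restoring the previous state, rather than unconditionally calling `gc.enable()`, keeps the helper safe when the caller had already disabled collection.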
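The integer-overflow fix for quantization kernels concerns offsets into large tensors. The arithmetic below, sketched in Python, shows how a flat row-major offset `row * num_cols + col` exceeds the signed 32-bit limit of 2,147,483,647 for a 65,536 x 65,536 tensor (2^32 elements) and wraps to a negative value, while a 64-bit computation stays correct. The wraparound helper models two's-complement behavior; in C and C++, signed overflow is actually undefined behavior, and the specific kernels and index types changed in TE are not described here.

```python
INT32_MAX = 2**31 - 1

def wrap_int32(v):
    """Reduce v to a signed 32-bit value, modeling two's-complement wraparound."""
    v &= 0xFFFFFFFF
    return v - 2**32 if v > INT32_MAX else v

def offset_int32(row, col, num_cols):
    """Flat row-major offset computed with 32-bit intermediates (overflow-prone)."""
    return wrap_int32(wrap_int32(row * num_cols) + col)

def offset_int64(row, col, num_cols):
    """Same offset with wide arithmetic (Python ints stand in for int64 here)."""
    return row * num_cols + col

if __name__ == "__main__":
    cols = 65536  # a 65536 x 65536 tensor holds 2**32 elements
    print(offset_int32(40000, 0, cols))  # → -1673527296 (wrapped)
    print(offset_int64(40000, 0, cols))  # → 2621440000
```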

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

The deprecated device_id argument of the multi-tensor C APIs has been removed.

Deprecated Features

There are no deprecated features in this release.