v0.9.0

@mariogeiger released this 17 Feb 22:56 · 26 commits to main since this release · 2b9cb54

0.9.0 (2026-02-17)

Added

  • GB10 (DGX Spark) support
  • Support for Python 3.13 and 3.14
  • [JAX] Support for Triton 3.6.0
  • [JAX] flax.nnx MACE example
  • [Torch/JAX] Deterministic indexing mode for uniform_1d kernels
  • [Torch/JAX] Parallel JIT compilation for uniform_1d kernels with per-kernel caching, significantly reducing compilation time. New optional environment variable CUEQUIVARIANCE_OPS_NVRTC_CACHE_DIR allows setting a directory for caching compiled kernels.
  • Documentation: new tutorials for JAX and PyTorch segmented polynomials
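As a minimal sketch of how the CUEQUIVARIANCE_OPS_NVRTC_CACHE_DIR variable mentioned above might be set from Python: the helper name configure_nvrtc_cache and the default directory name are hypothetical, and the sketch assumes (the notes do not say) that the variable has to be in the environment before the library compiles its first kernel, so the safest place to call it is before importing cuequivariance_torch or cuequivariance_jax.

```python
import os
import tempfile
from pathlib import Path


def configure_nvrtc_cache(cache_dir=None):
    """Create a cache directory and export it for uniform_1d kernel caching.

    `configure_nvrtc_cache` is a hypothetical helper, not part of the library.
    Only the environment variable name comes from the release notes.
    """
    if cache_dir is None:
        # Assumed default location; any writable directory should do.
        cache_dir = Path(tempfile.gettempdir()) / "cuequivariance_nvrtc_cache"
    path = Path(cache_dir)
    path.mkdir(parents=True, exist_ok=True)
    os.environ["CUEQUIVARIANCE_OPS_NVRTC_CACHE_DIR"] = str(path)
    return path


if __name__ == "__main__":
    # Call this before importing the library (assumption, see above).
    print(configure_nvrtc_cache())
```

Exporting the variable in the shell or job script before launching Python would have the same effect and avoids any dependence on import order.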

Bug Fixes

  • [JAX] Fixed Triton tuning issue for triangular multiplicative update
  • [JAX] Compatibility with JAX 0.8.2: fixed FFI interface and dtype casting issues when x64 mode is not enabled
  • [JAX] Improved triangle attention error messages
  • [Torch/JAX] Fixed yx_rotation descriptor
  • [Torch] TensorRT QDP plugin workaround

Breaking Changes

  • [Torch/JAX] The environment variable CUEQUIVARIANCE_OPS_USE_JIT no longer exists. JIT compilation is now the default behavior for uniform_1d kernels, as it has been for the past few releases.
  • [Torch/JAX] Renamed filter_drop_unsued_operands to filter_drop_unused_operands (typo fix)
  • [Torch/JAX] Removed nvfatbin optional dependency
  • [Torch] Removed deprecated primitive classes: TensorProduct, EquivariantTensorProduct, SymmetricTensorProduct, and IWeightedSymmetricTensorProduct. Use cuet.SegmentedPolynomial with method='uniform_1d' instead, or the high-level APIs (cuet.ChannelWiseTensorProduct, cuet.FullyConnectedTensorProduct, cuet.SymmetricContraction). Attempting to import these classes will raise an ImportError with migration instructions.
  • [Torch] Removed deprecated low-level wrapper classes: TensorProductUniform1d, TensorProductUniform4x1d, TensorProductUniform3x1dIndexed, TensorProductUniform4x1dIndexed, and SymmetricTensorContraction from cuequivariance_ops_torch. Use torch.ops.cuequivariance.uniform_1d or cuet.SegmentedPolynomial instead.
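Before upgrading, it may help to scan a codebase for the identifiers listed above. The following is a rough sketch only: the function names find_stale_usages and stale_env_vars are hypothetical, the hint strings paraphrase the notes, and the check is a plain whole-word text match, so it does not know which package a name came from.

```python
import re

# Identifiers named in the 0.9.0 "Breaking Changes" list.
_POLY = "use cuet.SegmentedPolynomial or a high-level cuet API"
_OPS = "use torch.ops.cuequivariance.uniform_1d or cuet.SegmentedPolynomial"
STALE_NAMES = {
    "filter_drop_unsued_operands": "renamed to filter_drop_unused_operands",
    "TensorProduct": _POLY,
    "EquivariantTensorProduct": _POLY,
    "SymmetricTensorProduct": _POLY,
    "IWeightedSymmetricTensorProduct": _POLY,
    "TensorProductUniform1d": _OPS,
    "TensorProductUniform4x1d": _OPS,
    "TensorProductUniform3x1dIndexed": _OPS,
    "TensorProductUniform4x1dIndexed": _OPS,
    "SymmetricTensorContraction": _OPS,
}
STALE_ENV_VARS = ("CUEQUIVARIANCE_OPS_USE_JIT",)


def find_stale_usages(source):
    """Return (line_number, name, hint) for each removed name found in `source`."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for name, hint in STALE_NAMES.items():
            # \b keeps TensorProduct from matching inside ChannelWiseTensorProduct.
            if re.search(rf"\b{re.escape(name)}\b", line):
                hits.append((lineno, name, hint))
    return hits


def stale_env_vars(environ):
    """Return the removed environment variables that are still set in `environ`."""
    return [name for name in STALE_ENV_VARS if name in environ]
```

A typical call would be find_stale_usages(open("model.py").read()) together with stale_env_vars(os.environ). Because the match is textual, an unrelated class that happens to be called TensorProduct (for example one from another library) is flagged too, so each hit should be reviewed by hand.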

Notes

  • [JAX] DGX Spark/GB10 (sm_121) with CUDA 12.9: This release uses PTX 87, which works correctly for most architectures but is not compatible with DGX Spark/GB10 on CUDA 12.9. To enable DGX Spark/GB10 support with CUDA 12.9, refer to #250 for a simple frontend integration tweak that restricts PTX 88 to sm_121 only. This fix will be merged after the 0.9.0 release.