
v0.6.1


@hsadasivan released this 04 Sep 22:05
eb0ffd7

Latest Changes

0.6.1 (2025-09-04)

Added

  • [Torch/JAX] Support for variable leading batch dimensions in triangle multiplicative update
  • [Torch/JAX] Triangle attention kernel support for additional input configurations: any hidden_dim <= 32 that is divisible by 4 for tf32/fp32, and any hidden_dim <= 128 that is divisible by 8 for bf16/fp16. In the rare case that the kernel does not support an input configuration, it now falls back to torch instead of raising an error.
  • [Torch/JAX] Tuned configuration for triangle multiplicative update on RTX PRO 6000 GPUs.
  • [JAX] vmap support for triangle multiplicative update and triangle attention
  • [Torch] Improved error reporting on import failure: the error message now includes the original traceback.
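
Support for variable leading batch dimensions usually comes down to collapsing any number of leading dims into one before calling a kernel that expects a fixed rank, then restoring them afterwards. The sketch below does that bookkeeping on shapes only. The helper names, and the assumption that the kernel sees a pair tensor whose last three dims are (N, N, D), are hypothetical and not the library's API.

```python
import math
from typing import Tuple

Shape = Tuple[int, ...]


def collapse_leading_dims(shape: Shape, trailing: int = 3) -> Tuple[Shape, Shape]:
    """Split `shape` into (leading_dims, collapsed_shape).

    collapsed_shape is (prod(leading_dims), *trailing_dims), i.e. a fixed-rank
    shape a kernel with a single batch axis could consume (hypothetical helper).
    """
    if len(shape) < trailing:
        raise ValueError(f"expected at least {trailing} dims, got shape {shape}")
    split = len(shape) - trailing
    leading, rest = shape[:split], shape[split:]
    # math.prod(()) == 1, so an input without batch dims gets a batch of size 1.
    return leading, (math.prod(leading), *rest)


def restore_leading_dims(leading: Shape, collapsed_out: Shape) -> Shape:
    """Undo collapse_leading_dims on an output shape of the form (B, *trailing)."""
    return (*leading, *collapsed_out[1:])
```

For example, a (2, 3, 64, 64, 128) input collapses to (6, 64, 64, 128) with leading dims (2, 3), and the output shape is rebuilt from those same leading dims.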
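
The triangle attention support rule and the fallback behavior described above can be sketched as a predicate plus a dispatcher. This is a minimal illustration, assuming dtypes are identified by the hypothetical strings "fp32", "tf32", "fp16", and "bf16", and that the fast kernel and the torch fallback are passed in as plain callables. It is not the library's actual dispatch code.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical encoding of the stated support matrix:
# dtype -> (maximum hidden_dim, required divisor).
_SUPPORTED: Dict[str, Tuple[int, int]] = {
    "fp32": (32, 4),
    "tf32": (32, 4),
    "fp16": (128, 8),
    "bf16": (128, 8),
}


def kernel_supports(hidden_dim: int, dtype: str) -> bool:
    """Return True if (hidden_dim, dtype) falls inside the supported set."""
    rule = _SUPPORTED.get(dtype)
    if rule is None:
        return False
    max_dim, divisor = rule
    # Requiring hidden_dim > 0 is an added sanity check, not part of the stated rule.
    return 0 < hidden_dim <= max_dim and hidden_dim % divisor == 0


def triangle_attention(
    x: Any,
    hidden_dim: int,
    dtype: str,
    kernel: Callable[[Any], Any],
    fallback: Callable[[Any], Any],
) -> Any:
    """Use the fast kernel when the config is supported, otherwise the fallback path."""
    if kernel_supports(hidden_dim, dtype):
        return kernel(x)
    return fallback(x)
```

Checking support up front and routing to a slower reference path keeps unsupported shapes working instead of raising, at the cost of performance only for those rare configurations.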
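
One common way to make an import failure easier to diagnose is to catch the ImportError and re-raise it with the full original traceback in the message. The sketch below shows that pattern with the standard library only. The helper name import_with_traceback is hypothetical, and the release note does not say how the library implements its improved reporting.

```python
import importlib
import traceback
from types import ModuleType


def import_with_traceback(module_name: str) -> ModuleType:
    """Import `module_name`; on failure, raise an ImportError whose message
    embeds the original traceback (hypothetical helper)."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # format_exc() renders the exception currently being handled.
        details = traceback.format_exc()
        raise ImportError(
            f"Failed to import {module_name!r}. Original traceback:\n{details}"
        ) from exc
```

Chaining with `from exc` also preserves the original exception as `__cause__`, so both the message and the exception chain point at the root failure.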

Bug fix

  • [Torch/JAX] Fixed an illegal memory access caused by int32 indexing overflow on longer sequences in triangle multiplicative update and attention with pair bias.
  • [JAX] Switched from nondiff_argnames to nondiff_argnums for compatibility with older JAX versions.
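
The int32 indexing bug can be made concrete with a little arithmetic. A row-major flat index into a (batch, N, N, D) tensor grows with N squared, so for long sequences it can pass 2**31 - 1 and wrap to a negative value in int32, which then addresses memory outside the tensor. The sketch below assumes an illustrative (batch, seq_len, seq_len, hidden) layout with hidden = 128; the kernels' actual tensors and index expressions are not described in the release note.

```python
INT32_MAX = 2**31 - 1


def as_int32(value: int) -> int:
    """Emulate two's-complement int32 wrap-around of a Python int."""
    return (value + 2**31) % 2**32 - 2**31


def max_flat_index(batch: int, seq_len: int, hidden: int) -> int:
    """Largest row-major flat index into a (batch, seq_len, seq_len, hidden) tensor."""
    return batch * seq_len * seq_len * hidden - 1


def fits_int32(batch: int, seq_len: int, hidden: int) -> bool:
    """True if every flat index of the tensor is representable as int32."""
    return max_flat_index(batch, seq_len, hidden) <= INT32_MAX
```

With batch = 1 and hidden = 128, seq_len = 4096 just fits (the largest index is exactly 2**31 - 1), while seq_len = 4097 does not, and its largest index wraps negative under int32. Computing offsets in int64 avoids the wrap.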
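
For reference, the positional form marks non-differentiable arguments by index. The sketch below uses jax.custom_vjp with nondiff_argnums on a hypothetical toy function; the release note does not say whether the library's kernels use custom_vjp or custom_jvp, so treat this only as an illustration of the argument style. With nondiff_argnums, the backward rule receives the non-differentiable arguments first, followed by the residuals and the output cotangent.

```python
from functools import partial

import jax


# `scale` (position 1) is a static, non-differentiable Python float.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x**2


def _scaled_square_fwd(x, scale):
    # Same signature as the primal; returns (output, residuals for backward).
    return scaled_square(x, scale), x


def _scaled_square_bwd(scale, residual, g):
    # Non-differentiable args come first, then residuals, then the cotangent.
    x = residual
    # One cotangent per differentiable input (here only x).
    return (g * 2.0 * scale * x,)


scaled_square.defvjp(_scaled_square_fwd, _scaled_square_bwd)
```

Here jax.grad(scaled_square)(3.0, 2.0) differentiates only with respect to x and evaluates to 2 * 2.0 * 3.0 = 12.0.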