Latest Changes
0.6.1 (2025-09-04)
Added
[Torch/JAX] Support for variable leading batch dimensions in triangle multiplicative update
[Torch/JAX] Triangle attention kernel support for additional input configs: any hidden_dim<=32 that is divisible by 4 for tf32/fp32, and any hidden_dim<=128 that is divisible by 8 for bf16/fp16. In the rare case that the kernel does not support an input config, the call now falls back to torch instead of erroring out.
[Torch/JAX] Tuned config for RTX PRO 6000 GPUs for triangle multiplicative update.
[JAX] vmap support for triangle multiplicative update and triangle attention
[Torch] Improved error reporting on import failure: the error message now includes traceback information
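As a rough illustration of the hidden_dim rule in the triangle attention entry above, the check can be sketched in plain Python. The function name `kernel_supports_hidden_dim` and the string dtype labels are hypothetical and not part of the library. The sketch covers only the hidden_dim/dtype condition stated in these notes; the actual kernel may impose other constraints.

```python
def kernel_supports_hidden_dim(hidden_dim: int, dtype: str) -> bool:
    """Hypothetical helper: mirrors only the hidden_dim/dtype rule from the notes."""
    if dtype in ("tf32", "fp32"):
        # tf32/fp32: hidden_dim <= 32 and divisible by 4
        return 0 < hidden_dim <= 32 and hidden_dim % 4 == 0
    if dtype in ("bf16", "fp16"):
        # bf16/fp16: hidden_dim <= 128 and divisible by 8
        return 0 < hidden_dim <= 128 and hidden_dim % 8 == 0
    # Assumption: any other dtype is treated as unsupported in this sketch.
    return False


if __name__ == "__main__":
    for dim, dt in [(32, "fp32"), (36, "fp32"), (128, "bf16"), (12, "fp16")]:
        status = "kernel" if kernel_supports_hidden_dim(dim, dt) else "fallback"
        print(f"hidden_dim={dim:>3} dtype={dt}: {status}")
```

Configs for which the check returns False correspond to the cases where, per the entry above, the torch fallback is used.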
Bug fix
[Torch/JAX] Fixed an illegal memory access stemming from int32 indexing on longer sequences in triangle multiplicative update and in attention with pair bias.
[JAX] Switched from nondiff_argnames to nondiff_argnums for compatibility with older JAX versions
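To see why int32 indexing can break on longer sequences, the arithmetic can be sketched in plain Python. The tensor shapes below are hypothetical and chosen only for illustration; the notes do not say which tensors or sequence lengths triggered the issue.

```python
INT32_MAX = 2**31 - 1


def max_flat_index(shape):
    """Largest flat (row-major) element index for a tensor of this shape."""
    n = 1
    for d in shape:
        n *= d
    return n - 1


def fits_int32(shape):
    """True if every flat element index is representable as a signed int32."""
    return max_flat_index(shape) <= INT32_MAX


def wrap_int32(i):
    """Value obtained when i is stored in a signed 32-bit integer (two's complement)."""
    return (i + 2**31) % 2**32 - 2**31


if __name__ == "__main__":
    # Hypothetical pair-representation shapes: (batch, seq, seq, channels)
    short = (1, 1024, 1024, 128)   # 2**27 elements
    long_ = (1, 4096, 4096, 256)   # 2**32 elements
    print(fits_int32(short), fits_int32(long_))
    # An index just past INT32_MAX wraps to a negative offset.
    print(wrap_int32(2**31))
```

Once the element count exceeds 2**31, a flat index held in int32 wraps to a negative value, which a kernel would then use as an out-of-bounds offset.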