Syntax such as a[1::100], where a is a JAX array inside a jit-compiled function, appears to lower to a gather operation rather than a strided slice, at least on TPU. This is inefficient, and jax.lax.slice already supports strides.
System info (python version, jaxlib version, accelerator, etc.)
N/A
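For reference, a minimal sketch of the two forms being compared. The function names and the array size are made up for illustration; only `jax.lax.slice`, `jax.jit`, and `jax.make_jaxpr` are real JAX APIs. The printed jaxprs depend on the installed JAX version, so which primitive shows up for the indexing form is not guaranteed here.

```python
import jax
import jax.numpy as jnp
from jax import lax


def index_syntax(a):
    # NumPy-style strided indexing, as described in the report.
    return a[1::100]


def explicit_strided_slice(a):
    # Same selection spelled with lax.slice:
    # (operand, start_indices, limit_indices, strides).
    return lax.slice(a, (1,), (a.shape[0],), (100,))


a = jnp.arange(1000)  # hypothetical input, chosen only for the demo

via_index = jax.jit(index_syntax)(a)
via_slice = jax.jit(explicit_strided_slice)(a)

# Both forms should select the same elements.
same = bool(jnp.array_equal(via_index, via_slice))
print("results match:", same)

# Inspect which primitives each form traces to on this JAX version.
print(jax.make_jaxpr(index_syntax)(a))
print(jax.make_jaxpr(explicit_strided_slice)(a))
```

Until the indexing path handles strides, calling `lax.slice` directly as in `explicit_strided_slice` is one possible workaround when the start, stop, and step are static.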
Thanks for the report! Lowering contiguous slices to lax.slice was an optimization we made a while ago, and at the time we scoped the problem to contiguous slices for simplicity. Adding strided slices to the logic would require adding support for them in this utility:
It seems like something that should work if normal slices work. It is surprising that syntax supported by both Python and NumPy indexing, and backed by a feature JAX already supports, doesn't lower efficiently.