Strided indexing turns into gather #21151

Open
jewillco opened this issue May 9, 2024 · 3 comments
Labels
enhancement (New feature or request), performance (make things lean and fast)

Comments

jewillco commented May 9, 2024

Description

Indexing syntax such as a[1::100], where a is a JAX array inside a jit-compiled function, appears to lower to a gather operation rather than a strided slice, at least on TPU. This is inefficient, and jax.lax.slice already supports strides.
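
A minimal sketch of the comparison, assuming a 1-D array of length 1000 (the array and the function names `strided_index` and `strided_slice` are illustrative, not from the report). Both functions should return the same values; which primitives appear in the printed jaxprs depends on the JAX version, so no particular output is claimed here.

```python
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.arange(1000)  # illustrative input, not from the report

@jax.jit
def strided_index(x):
    # NumPy-style strided indexing, as described in the report
    return x[1::100]

@jax.jit
def strided_slice(x):
    # lax.slice(operand, start_indices, limit_indices, strides)
    return lax.slice(x, (1,), (x.shape[0],), (100,))

indexed = strided_index(a)
sliced = strided_slice(a)

# Print the traced primitives for each form to see how the indexing is expressed.
print(jax.make_jaxpr(lambda x: x[1::100])(a))
print(jax.make_jaxpr(lambda x: lax.slice(x, (1,), (x.shape[0],), (100,)))(a))
```

Comparing the two printed jaxprs is one way to see whether the indexing form is expressed as a gather or as a slice on a given installation.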

System info (python version, jaxlib version, accelerator, etc.)

N/A

jewillco added the bug (Something isn't working) label on May 9, 2024
jakevdp added the enhancement (New feature or request) and performance (make things lean and fast) labels and removed the bug (Something isn't working) label on May 9, 2024
Collaborator

jakevdp commented May 9, 2024

Thanks for the report! Lowering contiguous slices to lax.slice was an optimization we made a while ago, and at the time we scoped the problem to contiguous slices for simplicity. Adding strided slices to the logic would require adding support for them in this utility:

def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None:
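
For illustration only, here is a rough sketch of what mapping a strided slice onto lax.slice could look like. The helper name `take_via_strided_slice` is hypothetical and this is not the JAX utility above or its actual logic; it assumes a single Python slice applied to the leading axis, with static integer bounds and a positive step, and returns None for anything else.

```python
import jax.numpy as jnp
from jax import lax

def take_via_strided_slice(arr, idx):
    """Hypothetical helper: compute arr[idx] with lax.slice, or return None.

    Assumption: idx is one slice object with static integer bounds,
    applied to the leading axis only.
    """
    if not isinstance(idx, slice) or arr.ndim == 0:
        return None
    # Resolve None and negative bounds against the leading dimension.
    start, stop, step = idx.indices(arr.shape[0])
    if step < 1:
        return None  # negative strides would also need a reversal; not handled
    stop = max(start, stop)  # lax.slice needs limit >= start
    # Slice the leading axis with a stride; take every other axis in full.
    starts = (start,) + (0,) * (arr.ndim - 1)
    limits = (stop,) + tuple(arr.shape[1:])
    strides = (step,) + (1,) * (arr.ndim - 1)
    return lax.slice(arr, starts, limits, strides)

x = jnp.arange(1000)
y = jnp.arange(60).reshape(20, 3)
strided_1d = take_via_strided_slice(x, slice(1, None, 100))
strided_2d = take_via_strided_slice(y, slice(2, 17, 5))
unsupported = take_via_strided_slice(x, slice(None, None, -1))
```

Returning None for unsupported cases mirrors the `Array | None` return type in the signature above, so a caller could fall back to the general gather path.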

Author

jewillco commented May 9, 2024

It seems like something that should work if normal slices work. It is surprising that syntax supported by Python and NumPy indexing, and backed by a feature JAX already has, doesn't lower efficiently.

Collaborator

jakevdp commented May 9, 2024

Agreed, that's why I marked this as a performance-related enhancement.
