Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Enable Jax dynamic API and use abstracted_axes in qjit #366

Merged
merged 82 commits into from
Dec 5, 2023

Conversation

erick-xanadu
Copy link
Contributor

@erick-xanadu erick-xanadu commented Nov 16, 2023

Context: Enable Jax Dynamic API and add abstracted_axes top-level argument

Description of the Change:

  • Enable dynamic_shapes option in JAX.
  • Automatically deduce implicit parameters when calling the compiled function
  • Remove implicit parameters from wrapped function arguments.

Benefits:

  • Jax dynamic API allows to compile tensors with variable dimension sizes.
  • jax.numpy.ones/zeros/empy etc could now be used in the compiled programs.

Drawbacks:

  • Accessing dynamically-shaped tensors inside Catalyst control-flow primitives is not yet supported
  • Abstracted axes of more than 3 dimensions appear to not work on some jax.lax functions.

Related GitHub Issues:

This PR contains a workaround for the following Jax issue

Also related: Please do not remove

[sc-47630]
[sc-47631]

@erick-xanadu erick-xanadu force-pushed the eochoa/2023-11-16/abstracted-axes branch 3 times, most recently from bda393e to bd41005 Compare November 17, 2023 20:27
@erick-xanadu erick-xanadu force-pushed the eochoa/2023-11-16/abstracted-axes branch from bd41005 to 1737f9d Compare November 17, 2023 20:29
@erick-xanadu erick-xanadu changed the title Use abstracted_axes in qjit [Frontend] Use abstracted_axes in qjit Nov 17, 2023
frontend/test/pytest/test_jax_dynamic_api.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Show resolved Hide resolved
frontend/test/pytest/test_jax_dynamic_api.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀

Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documentation looks great @erick-xanadu @grwlf!

doc/changelog.md Outdated Show resolved Hide resolved
doc/changelog.md Outdated Show resolved Hide resolved
frontend/catalyst/compilation_pipelines.py Show resolved Hide resolved
Co-authored-by: Josh Izaac <josh146@gmail.com>
@erick-xanadu erick-xanadu merged commit b0f9a3f into main Dec 5, 2023
19 checks passed
@erick-xanadu erick-xanadu deleted the eochoa/2023-11-16/abstracted-axes branch December 5, 2023 13:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants