-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Jax dynamic API support in quantum control-flow primitives #370
Conversation
58e385a
to
c0c2208
Compare
c0c2208
to
ee7baec
Compare
I think the capture component is likely the only relevant thing here, so from starting the tracing of a user function to the MLIR generation. Ideally one of our more complex test cases (something that generates a decent bit of IR), and comparing the MLIR generation time before and after this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a fantastic document, great work @grwlf!
As suggested I ran a benchmark comparing 9bdfe8b (current PR head) against 96a8d1e (last sync point in main) using the Grover benchmark in the repo for 23 qubits. What I measured is the time to run the single statement I conclude that from a performance point of view this PR is unproblematic 👍 |
This is an utility PR split `./utils/jax_extras.py` into three sub-modules. The goal is to make a complex #370 PR simpler. --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
This is an utility PR split `./utils/jax_extras.py` into three sub-modules. The goal is to make a complex #370 PR simpler. --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
…loop) (#775) **Context:** Catalyst had only limited support for dynamic-shaped arrays. Notably, using them in control-flow primitives was not allowed. By this PR we allow to use dynamic-shaped arrays in for-loops. In subsequent PRs we plan to support other primitives as well. **Description of the Change:** * Add support for dynamic-shaped arrays in for-loops. [sc-50789] [sc-56476] [sc-60522] **Benefits:** **Possible Drawbacks:** Capturing dynamic-shaped arrays from outer scopes is not supported. Capturing static arrays continues to work as usual. ``` python @qjit(abstracted_axes={1: 'n'}) def g(x, y): @catalyst.for_loop(0, 10, 1) def loop(_, a): # Attempt to capture `x` from the outer scope. return a * x return jnp.sum(loop(y)) a = jnp.ones([1,3], dtype=float) b = jnp.ones([1,3], dtype=float) g(a, b) ``` **Related GitHub Issues:** * This PR supersedes #370 --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
…, while-loop) (#777) **Context:** Catalyst had only limited support for dynamic-shaped arrays. Notably, using them in control-flow primitives was not allowed. By this PR we allow to use dynamic-shaped arrays in while-loops and conditionals. In subsequent PRs we plan to support other primitives as well. **Description of the Change:** [sc-56476] [sc-60522] [sc-50789] **Benefits:** * Dynamically-shaped arrays can be used in the bodies of while and cond primitives **Possible Drawbacks:** **Related GitHub Issues:** * To be merged next to #775 * This PR supersedes #370 --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
This PR is outdated, so I am closing it. The same changes are already merged by: |
In this PR we enable the Jax dynamic API in the Catalyst quantum control flow primitives: while-loops, for-loops, conditionals in both quantum and classical tracing modes.
An extended description of the solution is available in the readme.
Earlier related internal documents:
Tasks:
utils/jax_extras.py
intojax_extras/tracing.py
,jax_extras/lowering.py
andjax_extras/patches.py
.[sc-50789]