Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Jax dynamic API support in quantum control-flow primitives #370

Closed
wants to merge 167 commits into from

Conversation

sergei-mironov
Copy link
Contributor

@sergei-mironov sergei-mironov commented Nov 24, 2023

In this PR we enable the Jax dynamic API in the Catalyst quantum control flow primitives: while-loops, for-loops, conditionals in both quantum and classical tracing modes.

An extended description of the solution is available in the readme.

Earlier related internal documents:

Tasks:

  • conditionals
  • while-loops
  • for-loops
  • Finalization and clean-up
    • Fix remaining tests
    • Fix CodeFactor issues
    • Merge the PR with the recent main
    • Write the documentation and clean-up
      • Write a Readme file describing the proposed tracing architecture.
      • Split the utils/jax_extras.py into jax_extras/tracing.py, jax_extras/lowering.py and jax_extras/patches.py.
      • Code refactoring (Sergei expects David to put ToDo marks highlighting the problems)
        • Find better names for some classes/variables
        • Write comments where required
    • Fix CodeCov issues

[sc-50789]

@sergei-mironov sergei-mironov force-pushed the dynshape-quantum-primitives branch 2 times, most recently from 58e385a to c0c2208 Compare February 1, 2024 10:21
@dime10
Copy link
Collaborator

dime10 commented Feb 2, 2024

UPD: @rmoyard to think about it, running a benchmark is still a good thing to do anyways, do you have an idea what exactly to measure here?

I think the capture component is likely the only relevant thing here, so from starting the tracing of a user function to the MLIR generation. Ideally one of our more complex test cases (something that generates a decent bit of IR), and comparing the MLIR generation time before and after this PR.

Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a fantastic document, great work @grwlf!

FRONTEND.md Show resolved Hide resolved
FRONTEND.md Outdated Show resolved Hide resolved
FRONTEND.md Show resolved Hide resolved
@sergei-mironov
Copy link
Contributor Author

@dime10 @rmoyard, Do you have any further updates regarding this PR? I would be glad to answer questions/suggestions.

@dime10
Copy link
Collaborator

dime10 commented Feb 8, 2024

Thanks for the PR @grwlf , I have some initial comments!

My main concern is that the frontend is getting even more complex, for a very specific case. Could we imagine an abstraction where we check if we are dealing with dynamic shaped arrays and switch between a set of original functions and new functions (dynamic)? If not I think we need to benchmark the potential regressions for non-dynamic shaped arrays and see if we accept the results.

As suggested I ran a benchmark comparing 9bdfe8b (current PR head) against 96a8d1e (last sync point in main) using the Grover benchmark in the repo for 23 qubits.

What I measured is the time to run the single statement fun.get_mlir(params) using %%timeit -r50. The results show no measurable difference (exceeding observed fluctuations) between the two commits of around 345ms.

I conclude that from a performance point of view this PR is unproblematic 👍

sergei-mironov pushed a commit that referenced this pull request Feb 23, 2024
This is an utility PR split `./utils/jax_extras.py` into three
sub-modules. The goal is to make a complex
#370 PR simpler.

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
rauletorresc pushed a commit that referenced this pull request Feb 26, 2024
This is an utility PR split `./utils/jax_extras.py` into three
sub-modules. The goal is to make a complex
#370 PR simpler.

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
sergei-mironov pushed a commit that referenced this pull request Jun 13, 2024
…loop) (#775)

**Context:** Catalyst had only limited support for dynamic-shaped
arrays. Notably, using them in control-flow primitives was not allowed.
By this PR we allow to use dynamic-shaped arrays in for-loops. In
subsequent PRs we plan to support other primitives as well.

**Description of the Change:**
* Add support for dynamic-shaped arrays in for-loops.

[sc-50789]
[sc-56476]
[sc-60522]

**Benefits:**

**Possible Drawbacks:**

Capturing dynamic-shaped arrays from outer scopes is not supported.
Capturing static arrays continues to work as usual.

``` python
@qjit(abstracted_axes={1: 'n'})
def g(x, y):

    @catalyst.for_loop(0, 10, 1)
    def loop(_, a):
        # Attempt to capture `x` from the outer scope.
        return a * x

    return jnp.sum(loop(y))

a = jnp.ones([1,3], dtype=float)
b = jnp.ones([1,3], dtype=float)
g(a, b)
```


**Related GitHub Issues:**
* This PR supersedes #370

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
sergei-mironov pushed a commit that referenced this pull request Jun 18, 2024
…, while-loop) (#777)

**Context:** Catalyst had only limited support for dynamic-shaped
arrays. Notably, using them in control-flow primitives was not allowed.
By this PR we allow to use dynamic-shaped arrays in while-loops and
conditionals. In subsequent PRs we plan to support other primitives as
well.

**Description of the Change:**

[sc-56476]
[sc-60522]
[sc-50789]

**Benefits:**

* Dynamically-shaped arrays can be used in the bodies of while and cond
primitives

**Possible Drawbacks:**

**Related GitHub Issues:**

* To be merged next to #775
* This PR supersedes #370

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
@sergei-mironov
Copy link
Contributor Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants