[Frontend] Execute tapes in cuda-quantum #477

Merged: 184 commits, Feb 22, 2024

Contributor

@erick-xanadu erick-xanadu commented Jan 25, 2024

Context: A map from quantum tapes to CUDA-quantum's Python interface.

Description of the Change:

  • Adds a CUDA-quantum device
  • Adds a map from Catalyst JAXPR to CUDA-quantum's Python interface.

In this PR we add a CUDA-quantum device that describes the legal operations for CUDA-quantum kernels via the Python interface, and we add a way to execute tapes on this device. The device targets qpp-cpu by default. Execution through CUDA-quantum's API is driven by a custom JAX interpreter: it iterates over the Catalyst JAXPR but, instead of executing Catalyst's instructions, executes the equivalent instructions in CUDA-quantum's API.
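
The interpreter pattern described above can be sketched in plain Python. This is a toy illustration only, not the code in this PR: the `Equation` format, the primitive names `qinst_Hadamard` and `qinst_RX`, and the `RecordingBackend` class are all hypothetical stand-ins for Catalyst JAXPR and the CUDA-quantum API.

```python
from dataclasses import dataclass, field


@dataclass
class Equation:
    """One instruction in a toy, JAXPR-like program (hypothetical format)."""
    primitive: str
    operands: tuple


@dataclass
class RecordingBackend:
    """Stand-in for a target API; it only records the calls it receives."""
    calls: list = field(default_factory=list)

    def h(self, wire):
        self.calls.append(("h", wire))

    def rx(self, angle, wire):
        self.calls.append(("rx", angle, wire))


def interpret(program, backend):
    """Walk the program and run the backend's equivalent of each instruction."""
    # Map each source primitive name to the backend call that replaces it.
    handlers = {
        "qinst_Hadamard": backend.h,
        "qinst_RX": backend.rx,
    }
    for eqn in program:
        handlers[eqn.primitive](*eqn.operands)
    return backend.calls


program = [Equation("qinst_Hadamard", (0,)), Equation("qinst_RX", (0.5, 1))]
trace = interpret(program, RecordingBackend())
```

The source program is never executed by its original runtime here; each instruction is only looked up and re-issued against the other backend, which is the idea the description refers to.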

Benefits: Users can execute tapes via CUDA-quantum's Python interface.

Possible Drawbacks: Due to symbol conflicts (possibly in OpenMP), running import cudaq before running any Kokkos kernel results in a segfault.

CUDA-quantum's operations that are not reachable from tapes:

  • ch
  • sdg
  • tdg
  • cs
  • ct
  • r1

Catalyst's operations that are unimplemented:

    zne_p: unimplemented_impl,
    qunitary_p: unimplemented_impl,
    hermitian_p: unimplemented_impl,
    tensorobs_p: unimplemented_impl,
    var_p: unimplemented_impl,
    probs_p: unimplemented_impl,
    cond_p: unimplemented_impl,
    while_p: unimplemented_impl,
    for_p: unimplemented_impl,
    grad_p: unimplemented_impl,
    func_p: unimplemented_impl,
    jvp_p: unimplemented_impl,
    vjp_p: unimplemented_impl,
    print_p: unimplemented_impl,
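
A table like the one above pairs each primitive with a handler. As a minimal sketch of why registering an explicit fallback is useful, the snippet below uses string keys in place of primitive objects; the signature of `unimplemented_impl`, the `expval_impl` handler, and the `dispatch` helper are assumptions made for illustration and may not match the PR.

```python
def unimplemented_impl(name, *_args, **_kwargs):
    """Hypothetical fallback handler: fail loudly and name the primitive."""
    raise NotImplementedError(f"{name} is not supported by this interpreter")


def expval_impl(_name, value):
    """Hypothetical handler for a supported primitive (returns its input)."""
    return value


# String keys stand in for primitive objects such as var_p or probs_p.
HANDLERS = {
    "expval_p": expval_impl,
    "var_p": unimplemented_impl,
    "probs_p": unimplemented_impl,
}


def dispatch(name, *args):
    """Look up the handler registered for a primitive and call it."""
    return HANDLERS[name](name, *args)
```

With this layout an unsupported primitive produces an error that names it, rather than a generic lookup failure deep inside the interpreter.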

[sc-55098]

@erick-xanadu erick-xanadu force-pushed the eochoa/2024-01-24/cuda-quantum branch 2 times, most recently from ebe3c8c to c172170, January 29, 2024 19:15

codecov bot commented Jan 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (8649c9e) 99.52% compared to head (c0284c2) 99.54%.
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #477      +/-   ##
==========================================
+ Coverage   99.52%   99.54%   +0.02%     
==========================================
  Files          45       48       +3     
  Lines        7942     8430     +488     
  Branches      537      558      +21     
==========================================
+ Hits         7904     8392     +488     
  Misses         20       20              
  Partials       18       18              


@erick-xanadu erick-xanadu added the ci:build-wheels Run the wheel building workflows on this Pull Request label Feb 15, 2024
@erick-xanadu erick-xanadu removed the ci:build-wheels Run the wheel building workflows on this Pull Request label Feb 15, 2024
@erick-xanadu erick-xanadu added the ci:build-wheels Run the wheel building workflows on this Pull Request label Feb 16, 2024
@dime10 dime10 added this to the v0.5.0 milestone Feb 21, 2024
Collaborator

@dime10 dime10 left a comment

Thanks @erick-xanadu, I think this is good to go pending some final adjustments 🎉

frontend/catalyst/cuda/__init__.py
"RY",
"RZ",
"SWAP",
# "CSWAP", This is a bug in cuda-quantum. CSWAP is not exposed.
Collaborator

@dime10 dime10 Feb 21, 2024

(Q only): Is it stated somewhere that it is included, but it actually isn't, so we keep it here to remind ourselves?

Contributor Author

I found out on Friday that it was added in version 0.6.0; it was not in version 0.5.0, which was the version available when I started the branch. I will add it before merging.

Comment on lines 51 to 52
one_compiler_per_distribution = pl_version == ">=0.32,<=0.34"
if one_compiler_per_distribution:
Collaborator

Mentioned this offline, but ideally we just require the right dev version of PL now :)

Comment on lines +280 to +281
# TODO(@erick-xanadu) why does the device instruction lists the whole
# name instead of a short name?
Collaborator

Good question, in Catalyst that name is either what is returned by get_c_interface or what is in the (temporary) device map.

I think you're free to produce the short name for this compilation pipeline down on L838.

Contributor Author

I added this ticket, which is relevant: #517. It is also relevant to the discussion above about qdevice_p.

Collaborator

Well the qdevice_p primitive is there to enable the device to function. For Catalyst devices it needs the name of the C++ device so it can find the right entry point function in the low-level library (it is independent of whatever long or short name the Python device class has). For this compilation pipeline we are free to supply whatever name is suitable for it.

Maybe this can solve the issue in the discussion above by getting rid of the map entirely? Instead the qdevice_p primitive is already supplied by the right backend name upon instantiation, which can be sourced from the Python device class (similar to how get_c_interface obtains the right device name in Catalyst).

Contributor Author

Maybe this can solve the issue in the discussion above by getting rid of the map entirely? Instead the qdevice_p primitive is already supplied by the right backend name upon instantiation, which can be sourced from the Python device class (similar to how get_c_interface obtains the right device name in Catalyst).

I see what you mean. Thanks!

Member

@josh146 josh146 left a comment

Thanks Erick, left some comments but looks great to me!

Would be good just to get some verification of what works and what doesn't. From my testing, the following works:

  • expval
  • counts with shots specified
  • basic arithmetic (tested addition, jnp.sin, and jnp.cos)

What doesn't work/errors:

  • var/probs (get clear error message)
  • expval(float * obs) (get vague JAX error)
  • expval(obs + obs) (get vague CUDA quantum error)
  • gradients (get vague JAX error)
  • counts/sample without shots specified (get a vague JAX error)

What provides unexpected results:

  • sample (I am getting arrays with odd integer values like 11)
  • state (it returns a cudaq.State() object, which I then have to convert to a JAX array manually)
  • I expected the function tracing to only happen once, and to not repeat on repeated execution of the function

from catalyst.cuda.catalyst_to_cuda_interpreter import interpret


def qjit(fn=None, **kwargs):
Member

@erick-xanadu is this the main user entry point for the catalyst QJIT?

Member

A quick question I have: because qjit is not re-using/subclassing the Catalyst QJIT class, it doesn't have support of features like AOT (based on type signatures), autograph, etc.

Out of scope for this PR, but do you see a straightforward pathway for adding these features in the future?

Contributor Author

@erick-xanadu is this the main user entry point for the catalyst QJIT?

No. This is only the entry point for CUDA catalyst QJIT. This is different from the Catalyst QJIT. (A bit confusing, two different functions).

it doesn't have support of features like AOT (based on type signatures), autograph, etc.

Right now we generate JAXPR every single time the function is called and then evaluate this JAXPR according to the CUDA Quantum interpreter, which executes the quantum instructions in CUDA Quantum. We can evaluate it abstractly in order to generate a new JAXPR and save this JAXPR representation. I think this can be added somewhat easily.
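
The caching idea mentioned above (trace once with abstract inputs, save the result, replay it on later calls) can be sketched without JAX. Everything here is a hypothetical toy: `Param`, `trace`, `run`, and `CachedQJIT` are illustrative names, and the recorded list of tuples stands in for a saved JAXPR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Abstract placeholder for the i-th argument, used only while tracing."""
    index: int


def trace(fn, num_args):
    """Run fn once on placeholders and record the operations it emits."""
    program = []
    fn(program, *(Param(i) for i in range(num_args)))
    return program


def run(program, args):
    """Replay a recorded program, substituting concrete argument values."""
    return [
        tuple(args[v.index] if isinstance(v, Param) else v for v in op)
        for op in program
    ]


class CachedQJIT:
    """Trace once per argument-type signature and reuse the saved program."""

    def __init__(self, fn):
        self.fn = fn
        self.programs = {}
        self.trace_count = 0

    def __call__(self, *args):
        key = tuple(type(a) for a in args)
        if key not in self.programs:
            self.trace_count += 1
            self.programs[key] = trace(self.fn, len(args))
        return run(self.programs[key], args)


def circuit(program, x):
    program.append(("rx", x, 0))


jitted = CachedQJIT(circuit)
first = jitted(0.1)
second = jitted(0.2)
```

The second call reuses the program recorded during the first, so the Python body of `circuit` runs only once for a given signature.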

Comment on lines +70 to +73
def __init__(self, shots=None, wires=None, mps=False, multi_gpu=False):
    self.mps = mps
    self.multi_gpu = multi_gpu
    super().__init__(wires=wires, shots=shots)
Member

@erick-xanadu recommend making this small change, just because it means multi_gpu and mps are options to softwareq.qpp, which doesn't make sense 🤔
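
One way to read this suggestion, sketched with hypothetical class names (the suggested diff itself is not shown in this thread): keep only shots and wires on the shared base class, and let each GPU-backed device declare the option that applies to it. The assignment of `multi_gpu` and `mps` to particular subclasses below is an assumption for illustration.

```python
class BaseDevice:
    """Hypothetical shared base: only options every backend understands."""

    def __init__(self, shots=None, wires=None):
        self.shots = shots
        self.wires = wires


class CpuSimulator(BaseDevice):
    """CPU simulator stand-in: accepts no GPU-specific options."""


class GpuStateVector(BaseDevice):
    """GPU state-vector stand-in: the only class that accepts multi_gpu."""

    def __init__(self, shots=None, wires=None, multi_gpu=False):
        super().__init__(shots=shots, wires=wires)
        self.multi_gpu = multi_gpu


class GpuTensorNetwork(BaseDevice):
    """GPU tensor-network stand-in: the only class that accepts mps."""

    def __init__(self, shots=None, wires=None, mps=False):
        super().__init__(shots=shots, wires=wires)
        self.mps = mps
```

With this split, passing `multi_gpu=True` to the CPU simulator fails immediately with a `TypeError` instead of being silently accepted.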


with pytest.raises(CompileError, match="Cannot translate tapes with context"):
    catalyst.cuda.qjit(wrapper)(1.0)

Member

@erick-xanadu quick question; I notice that with the following circuit,

dev = qml.device("softwareq.qpp", wires=2)

@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    print("compiling")
    return qml.expval(qml.PauliZ(0))

the string 'compiling' prints with every execution. So it seems that the function is being traced with every execution -- is this expected?

Contributor Author

At the moment, yes. We can save the JAXPR and avoid that, but I don't think it makes such a big impact yet. This is a completely different pipeline from what we had in the past, so it takes time to catch up.

@erick-xanadu
Contributor Author

What doesn't work/errors:

Thanks! I can definitely improve these error messages!

What provides unexpected results:
sample (I am getting arrays with odd integer values like 11)
state (it returns a cudaq.State() object, which I then have to convert to a JAX array manually)
I expected the function tracing to only happen once, and to not repeat on repeated execution of the function

The first two are interesting. I'll have to take a look! Thanks!

@dime10 dime10 merged commit 5362cbc into main Feb 22, 2024
41 of 43 checks passed
@dime10 dime10 deleted the eochoa/2024-01-24/cuda-quantum branch February 22, 2024 16:18
rauletorresc pushed a commit that referenced this pull request Feb 26, 2024

Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Labels
  • ci:build-wheels: Run the wheel building workflows on this Pull Request
  • frontend: Pull requests that update the frontend
  • requires-wheel-builds: Pull Requests will need wheel building job successful before being merged
5 participants