[Frontend] Execute tapes in cuda-quantum #477
Conversation
ebe3c8c to c172170 (Compare)
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@           Coverage Diff           @@
##             main     #477     +/- ##
=======================================
+ Coverage   99.52%   99.54%   +0.02%
=======================================
  Files          45       48       +3
  Lines        7942     8430     +488
  Branches      537      558      +21
=======================================
+ Hits         7904     8392     +488
  Misses         20       20
  Partials       18       18
```
This reverts commit e44f48c.
Thanks @erick-xanadu, I think this is good to go pending some final adjustments 🎉
```python
    "RY",
    "RZ",
    "SWAP",
    # "CSWAP", This is a bug in cuda-quantum. CSWAP is not exposed.
```
(Q only): Is it stated somewhere that it is included, but it actually isn't, so we keep it here to remind ourselves?
I found out on Friday that it was added in version 0.6.0 but not in version 0.5.0 (which is what was available when I started the branch). I will add it before merging.
```python
one_compiler_per_distribution = pl_version == ">=0.32,<=0.34"
if one_compiler_per_distribution:
```
Mentioned this offline, but ideally we just require the right dev version of PL now :)
```python
# TODO(@erick-xanadu): why does the device instruction list the whole
# name instead of a short name?
```
Good question, in Catalyst that name is either what is returned by `get_c_interface` or what is in the (temporary) device map.
I think you're free to produce the short name for this compilation pipeline down on L838.
I added this ticket, which is relevant: #517. It is also relevant to the discussion above about `qdevice_p`.
Well, the `qdevice_p` primitive is there to enable the device to function. For Catalyst devices it needs the name of the C++ device so it can find the right entry point function in the low-level library (it is independent of whatever long or short name the Python device class has). For this compilation pipeline we are free to supply whatever name is suitable for it.
Maybe this can solve the issue in the discussion above by getting rid of the map entirely? Instead, the `qdevice_p` primitive is already supplied with the right backend name upon instantiation, which can be sourced from the Python device class (similar to how `get_c_interface` obtains the right device name in Catalyst).
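A rough sketch of that idea (class names, attribute names, and the dict shape here are all hypothetical, not Catalyst's actual API): each Python device class carries its own backend name, and the `qdevice_p` binding step asks the instance directly instead of consulting a global device map.

```python
# Hypothetical sketch: sourcing the backend name from the device instance
# itself instead of a global map. All names here are illustrative.

class BaseCudaQDevice:
    """Each device class knows its own CUDA-quantum backend name."""

    cuda_backend = None  # overridden by subclasses

    def backend_name(self):
        if self.cuda_backend is None:
            raise NotImplementedError("device does not define a backend")
        return self.cuda_backend


class SoftwareQQPP(BaseCudaQDevice):
    cuda_backend = "qpp-cpu"


class NvidiaStatevec(BaseCudaQDevice):
    cuda_backend = "nvidia"


def bind_qdevice(device):
    # Instead of looking the device up in a map, ask the instance directly,
    # analogous to how get_c_interface supplies the entry point in Catalyst.
    return {"primitive": "qdevice_p", "backend": device.backend_name()}


print(bind_qdevice(SoftwareQQPP()))  # {'primitive': 'qdevice_p', 'backend': 'qpp-cpu'}
```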
> Maybe this can solve the issue in the discussion above by getting rid of the map entirely? Instead the qdevice_p primitive is already supplied by the right backend name upon instantiation, which can be sourced from the Python device class (similar to how get_c_interface obtains the right device name in Catalyst).
I see what you mean. Thanks!
Thanks Erick, left some comments but looks great to me!
Would be good just to get some verification of what works and what doesn't. From my testing, the following works:
- expval
- counts with shots specified
- basic arithmetic (tested addition, `jnp.sin`, and `jnp.cos`)
What doesn't work/errors:
- var/probs (get clear error message)
- expval(float * obs) (get vague JAX error)
- expval(obs + obs) (get vague CUDA quantum error)
- gradients (get vague JAX error)
- counts/sample without shots specified (get a vague JAX error)
What provides unexpected results:
- sample (I am getting arrays with odd integer values like 11)
- state (it returns a `cudaq.State()` object, which I then have to convert to a JAX array manually)
- I expected the function tracing to only happen once, and to not repeat on repeated execution of the function
```python
from catalyst.cuda.catalyst_to_cuda_interpreter import interpret


def qjit(fn=None, **kwargs):
```
@erick-xanadu is this the main user entry point for the catalyst QJIT?
A quick question I have: because `qjit` is not re-using/subclassing the Catalyst `QJIT` class, it doesn't have support for features like AOT (based on type signatures), autograph, etc.
Out of scope for this PR, but do you see a straightforward pathway for adding these features in the future?
> @erick-xanadu is this the main user entry point for the catalyst QJIT?
No. This is only the entry point for the CUDA Catalyst qjit, which is different from the Catalyst QJIT. (A bit confusing: two different functions.)
> it doesn't have support of features like AOT (based on type signatures), autograph, etc.
Right now we generate JAXPR every single time the function is called and then evaluate this JAXPR according to the CUDA Quantum interpreter, which executes the quantum instructions in CUDA Quantum. We can evaluate it abstractly in order to generate a new JAXPR and save this JAXPR representation. I think this can be added somewhat easily.
```python
def __init__(self, shots=None, wires=None, mps=False, multi_gpu=False):
    self.mps = mps
    self.multi_gpu = multi_gpu
    super().__init__(wires=wires, shots=shots)
```
@erick-xanadu recommend making this small change, just because it means `multi_gpu` and `mps` are options to `softwareq.qpp`, which doesn't make sense 🤔
```python
with pytest.raises(CompileError, match="Cannot translate tapes with context"):
    catalyst.cuda.qjit(wrapper)(1.0)
```
@erick-xanadu quick question; I notice that with the following circuit,

```python
dev = qml.device("softwareq.qpp", wires=2)

@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    print("compiling")
    return qml.expval(qml.PauliZ(0))
```

the string 'compiling' prints with every execution. So it seems that the function is being traced with every execution -- is this expected?
At the moment, yes. We can save the JAXPR and avoid that, but I don't think it makes such a big impact yet. This is a completely different pipeline from what we had in the past, so it takes time to catch up.
Thanks! I can definitely improve these error messages!
The first two are interesting. I'll have to take a look! Thanks!
**Context:** A map from quantum tapes to CUDA-quantum's Python interface.

**Description of the Change:**
* Adds a CUDA-quantum device
* Adds a map from Catalyst JAXPR to CUDA-quantum's Python interface.

In this PR we add a CUDA-quantum device that describes the legal operations for CUDA-quantum kernels via the Python interface, and we add a way to execute tapes in this CUDA-quantum device. The CUDA-quantum device targets `qpp-cpu` by default. Execution on CUDA-quantum's API is driven by a custom JAX interpreter. A JAX interpreter will iterate over Catalyst JAXPR, but instead of executing Catalyst's instructions, it will execute equivalent instructions in CUDA-quantum's API.

**Benefits:** Users can execute tapes via CUDA-quantum's Python interface.

**Possible Drawbacks:** Due to symbol conflicts (possibly in OpenMP), running `import cudaq` before running any Kokkos kernel will result in a segfault.

CUDA-quantum's operations that are not reachable from tapes:
* `ch`
* `sdg`
* `tdg`
* `cs`
* `ct`
* `r1`

Catalyst's operations that are unimplemented:

```
zne_p: unimplemented_impl,
qunitary_p: unimplemented_impl,
hermitian_p: unimplemented_impl,
tensorobs_p: unimplemented_impl,
var_p: unimplemented_impl,
probs_p: unimplemented_impl,
cond_p: unimplemented_impl,
while_p: unimplemented_impl,
for_p: unimplemented_impl,
grad_p: unimplemented_impl,
func_p: unimplemented_impl,
jvp_p: unimplemented_impl,
vjp_p: unimplemented_impl,
print_p: unimplemented_impl,
```

[sc-55098]

---------

Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Josh Izaac <josh146@gmail.com>
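To make the custom-interpreter idea concrete, here is a toy version in plain Python: a list of `(primitive, params)` equations stands in for Catalyst JAXPR, and a logging class stands in for CUDA-quantum's Python API. All names below are made up for illustration; they are not the PR's actual primitives or API.

```python
# Toy illustration of the custom-interpreter pipeline: walk an IR produced
# by one frontend and execute each instruction with a different backend.

# A "program" as a list of (primitive, params) equations, standing in for JAXPR.
program = [
    ("qdevice", {"backend": "qpp-cpu"}),
    ("qinst", {"gate": "RX", "wires": [0], "param": 0.5}),
    ("expval", {"obs": "PauliZ", "wires": [0]}),
]


class ToyBackend:
    """Stand-in for CUDA-quantum's Python API: just logs what it would run."""

    def __init__(self):
        self.log = []

    def set_target(self, name):
        self.log.append(f"target {name}")

    def apply(self, gate, wires, param):
        self.log.append(f"{gate}({param}) on {wires}")

    def observe(self, obs, wires):
        self.log.append(f"expval {obs} on {wires}")


def interpret(program, backend):
    # One handler per primitive, analogous to the primitive -> implementation
    # table in the PR description; unknown primitives raise (= "unimplemented").
    rules = {
        "qdevice": lambda p: backend.set_target(p["backend"]),
        "qinst": lambda p: backend.apply(p["gate"], p["wires"], p["param"]),
        "expval": lambda p: backend.observe(p["obs"], p["wires"]),
    }
    for prim, params in program:
        rules[prim](params)


backend = ToyBackend()
interpret(program, backend)
print(backend.log[0])  # target qpp-cpu
```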