-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add from_plxpr conversion function #837
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing work @albi3ro
This approach is definitely much nicer to read :) But that's also due to the clear structure you gave these functions 💯
I had a good bit of nitpicky language/naming/comment/error message comments, otherwise some optional suggestions and ideas.
# Note that the value of rtd_kwargs is a string version of | ||
# the info kwargs, not the info kwargs itself |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be nice for future development to mention why we do this :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah... I'm not sure why. Think it has to do with mlir.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay! @dime10 , do you have context here? I think it might be nice to comment on it :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! This is just the easiest way to serialize the data since it is going through the IR. The alternative would be to provide typed attributes for each config option which is more work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the info. Adding that as a clarifying source code comment.
jaxpr = jax.make_jaxpr(circuit)() | ||
qml.capture.disable() | ||
|
||
with pytest.raises(NotImplementedError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reusing a suggestion from the code itself.
with pytest.raises(NotImplementedError): | |
with pytest.raises(NotImplementedError, match="Only wire-based and observable-based measurements"): |
call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] | ||
call_jaxpr_c = catalxpr.eqns[0].params["call_jaxpr"] | ||
|
||
compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could parametrize the expval, probs, var
tests, they are very similar.
**Context:** Catalyst PR #837 (PennyLaneAI/catalyst#837) needs a couple minor updates to the capture module. **Description of the Change:** 1) makes it possible to do `from pennylane.capture import AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to touch private functions 2) Adds `qnode` as a keyword argument that gets bound to the qnode primitive 3) Makes it so we can capture a sample measurement specified like `qml.sample(wires=1)` **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-66703] --------- Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
**Context:** Catalyst PR #837 (PennyLaneAI/catalyst#837) needs a couple minor updates to the capture module. **Description of the Change:** 1) makes it possible to do `from pennylane.capture import AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to touch private functions 2) Adds `qnode` as a keyword argument that gets bound to the qnode primitive 3) Makes it so we can capture a sample measurement specified like `qml.sample(wires=1)` **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-66703] --------- Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #837 +/- ##
==========================================
- Coverage 97.93% 97.92% -0.02%
==========================================
Files 73 74 +1
Lines 10330 10465 +135
Branches 1170 1211 +41
==========================================
+ Hits 10117 10248 +131
- Misses 170 171 +1
- Partials 43 46 +3 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work!
for s in shots: | ||
for m in measurements: | ||
shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires) | ||
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This map is interesting, I wonder if it is actually needed? Abstract JAX arrays can already hold Python types like int
, float
, complex
which are considered "weak types" and will adapt to the bitwidth of "strong types" like int64
, or fall back to the same types defined in the map above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like tests still pass when getting rid of it, so I'm going to say we don't need it until proven otherwise.
operator is consumed. | ||
""" | ||
|
||
def read(self, var): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code is really nicely structured! The only thing I'm wondering is whether we can't use the interpreter state for the whole conversion procedure, since it seems to duplicate environment and free standing read function above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking of defining a base class for jaxpr interpreters and reusing the structure for a variety of different algorithms using the template design pattern.
You can see a prototype here:
https://github.com/PennyLaneAI/pennylane/blob/plxpr-interpreter/pennylane/capture/interpreters.py
The design and implementation is going to take a little bit of work, but in the end, we will be able to use that design to make this code much nicer :) I do see this code getting restructured once that happens.
Right now we free standing functions with a mutable input, but we could also make them all class methods. I'm fine with promoting everything to class methods now too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rewriting te code to follow the structure I'm thinking of for a "Plxpr interpreter template".
|
||
_deallocate(state) | ||
# Read the final result of the Jaxpr from the environment | ||
return [state.read(outvar) for outvar in measurements] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing this is expected for now, but just noting that here we can only ever return MP results, not values from classical ops.
wires = [state.get_wire(w) for w in wire_values] | ||
|
||
invals = [state.read(invar) for invar in eqn.invars[:-n_wires]] | ||
outvals = qinst_p.bind( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have two custom primitives because they deviate from the form qinst_p
has, whose conversion isn't included here:
- qunitary_p
- gphase_p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adding.
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Thanks for your comments @albi3ro, let me know once you are finished with them and I'll give it a final look :) |
Context:
This PR ports the pennylane PR #5883 to catalyst and renames it
from_plxpr
.Adding this code to catalyst solves many of the dependency and testing issues that arise from placing this code in the pennylane resposity. In the future, if we move more of the catalyst frontend into pennylane, this function may move there as well.
Description of the Change:
Adds a function for converting pennylane variant jaxpr into catalyst variant jaxpr.
Benefits:
Opens up the ability to have improved program capture directly integrated with catalyst.
Possible Drawbacks:
Related GitHub Issues:
[sc-61537]