Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add from_plxpr conversion function #837

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open

add from_plxpr conversion function #837

wants to merge 18 commits into from

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jun 20, 2024

Context:

This PR ports the pennylane PR #5883 to catalyst and renames it from_plxpr.

Adding this code to catalyst solves many of the dependency and testing issues that arise from placing this code in the pennylane resposity. In the future, if we move more of the catalyst frontend into pennylane, this function may move there as well.

Description of the Change:

Adds a function for converting pennylane variant jaxpr into catalyst variant jaxpr.

Benefits:

Opens up the ability to have improved program capture directly integrated with catalyst.

Possible Drawbacks:

Related GitHub Issues:

[sc-61537]

Copy link

@dwierichs dwierichs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing work @albi3ro
This approach is definitely much nicer to read :) But that's also due to the clear structure you gave these functions 💯

I had a good bit of nitpicky language/naming/comment/error message comments, otherwise some optional suggestions and ideas.

doc/changelog.md Outdated Show resolved Hide resolved
frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
Comment on lines +88 to +89
# Note that the value of rtd_kwargs is a string version of
# the info kwargs, not the info kwargs itself

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice for future development to mention why we do this :)

Copy link
Contributor Author

@albi3ro albi3ro Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I'm not sure why. Think it has to do with mlir.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! @dime10 , do you have context here? I think it might be nice to comment on it :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! This is just the easiest way to serialize the data since it is going through the IR. The alternative would be to provide typed attributes for each config option which is more work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info. Adding that as a clarifying source code comment.

frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
jaxpr = jax.make_jaxpr(circuit)()
qml.capture.disable()

with pytest.raises(NotImplementedError):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reusing a suggestion from the code itself.

Suggested change
with pytest.raises(NotImplementedError):
with pytest.raises(NotImplementedError, match="Only wire-based and observable-based measurements"):

frontend/test/pytest/test_from_plxpr.py Outdated Show resolved Hide resolved
call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"]
call_jaxpr_c = catalxpr.eqns[0].params["call_jaxpr"]

compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could parametrize the expval, probs, var tests, they are very similar.

frontend/test/pytest/test_from_plxpr.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_from_plxpr.py Show resolved Hide resolved
albi3ro added a commit to PennyLaneAI/pennylane that referenced this pull request Jun 21, 2024
**Context:**

Catalyst PR #837 (PennyLaneAI/catalyst#837)
needs a couple minor updates to the capture module.

**Description of the Change:**

1) makes it possible to do `from pennylane.capture import
AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to
touch private functions

2) Adds `qnode` as a keyword argument that gets bound to the qnode
primitive

3) Makes it so we can capture a sample measurement specified like
`qml.sample(wires=1)`

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-66703]

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
mudit2812 pushed a commit to PennyLaneAI/pennylane that referenced this pull request Jul 2, 2024
**Context:**

Catalyst PR #837 (PennyLaneAI/catalyst#837)
needs a couple minor updates to the capture module.

**Description of the Change:**

1) makes it possible to do `from pennylane.capture import
AbstractOperator, AbstractMeasurement, qnode_prim` so we don't have to
touch private functions

2) Adds `qnode` as a keyword argument that gets bound to the qnode
primitive

3) Makes it so we can capture a sample measurement specified like
`qml.sample(wires=1)`

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-66703]

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Copy link

codecov bot commented Jul 15, 2024

Codecov Report

Attention: Patch coverage is 96.29630% with 5 lines in your changes missing coverage. Please review.

Project coverage is 97.92%. Comparing base (5fa4b21) to head (f1f90aa).

Files Patch % Lines
frontend/catalyst/from_plxpr.py 96.29% 2 Missing and 3 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #837      +/-   ##
==========================================
- Coverage   97.93%   97.92%   -0.02%     
==========================================
  Files          73       74       +1     
  Lines       10330    10465     +135     
  Branches     1170     1211      +41     
==========================================
+ Hits        10117    10248     +131     
- Misses        170      171       +1     
- Partials       43       46       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work!

frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
for s in shots:
for m in measurements:
shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires)
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This map is interesting, I wonder if it is actually needed? Abstract JAX arrays can already hold Python types like int, float, complex which are considered "weak types" and will adapt to the bitwidth of "strong types" like int64, or fall back to the same types defined in the map above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like tests still pass when getting rid of it, so I'm going to say we don't need it until proven otherwise.

frontend/catalyst/from_plxpr.py Outdated Show resolved Hide resolved
operator is consumed.
"""

def read(self, var):
Copy link
Collaborator

@dime10 dime10 Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is really nicely structured! The only thing I'm wondering is whether we can't use the interpreter state for the whole conversion procedure, since it seems to duplicate environment and free standing read function above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of defining a base class for jaxpr interpreters and reusing the structure for a variety of different algorithms using the template design pattern.

You can see a prototype here:

https://github.com/PennyLaneAI/pennylane/blob/plxpr-interpreter/pennylane/capture/interpreters.py

The design and implementation is going to take a little bit of work, but in the end, we will be able to use that design to make this code much nicer :) I do see this code getting restructured once that happens.

Right now we free standing functions with a mutable input, but we could also make them all class methods. I'm fine with promoting everything to class methods now too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewriting te code to follow the structure I'm thinking of for a "Plxpr interpreter template".


_deallocate(state)
# Read the final result of the Jaxpr from the environment
return [state.read(outvar) for outvar in measurements]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing this is expected for now, but just noting that here we can only ever return MP results, not values from classical ops.

wires = [state.get_wire(w) for w in wire_values]

invals = [state.read(invar) for invar in eqn.invars[:-n_wires]]
outvals = qinst_p.bind(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two custom primitives because they deviate from the form qinst_p has, whose conversion isn't included here:

  • qunitary_p
  • gphase_p

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding.

@dime10
Copy link
Collaborator

dime10 commented Jul 30, 2024

Thanks for your comments @albi3ro, let me know once you are finished with them and I'll give it a final look :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants