Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Measurements Pytrees #4607

Merged
merged 8 commits into from
Sep 20, 2023
Merged

Make Measurements Pytrees #4607

merged 8 commits into from
Sep 20, 2023

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Sep 18, 2023

This PR registers all MeasurementProcess objects as jax pytrees. [sc-40588]

@codecov
Copy link

codecov bot commented Sep 19, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (c6542b4) 99.62% compared to head (fc6b71d) 99.62%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #4607   +/-   ##
=======================================
  Coverage   99.62%   99.62%           
=======================================
  Files         375      375           
  Lines       33404    33443   +39     
=======================================
+ Hits        33279    33318   +39     
  Misses        125      125           
Files Changed Coverage Δ
pennylane/measurements/__init__.py 100.00% <ø> (ø)
pennylane/measurements/classical_shadow.py 100.00% <100.00%> (ø)
pennylane/measurements/counts.py 100.00% <100.00%> (ø)
pennylane/measurements/measurements.py 100.00% <100.00%> (ø)
pennylane/measurements/mid_measure.py 100.00% <100.00%> (ø)
pennylane/measurements/mutual_info.py 100.00% <100.00%> (ø)
pennylane/measurements/vn_entropy.py 100.00% <100.00%> (ø)
pennylane/ops/functions/equal.py 98.52% <100.00%> (+0.04%) ⬆️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@albi3ro albi3ro marked this pull request as ready for review September 19, 2023 18:37
@albi3ro albi3ro requested review from vincentmr and a team September 19, 2023 18:38
@mudit2812
Copy link
Contributor

Considering the size of this PR, I'd assume it will get merged before #4544 . So I will update the _flatten and _unflatten` methods there.

Copy link
Contributor

@vincentmr vincentmr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, cheers @albi3ro .

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks awesome, also helped highlight some little things to touch up!

quick question: (un/)flattening should help with copying and replacing operators, right? Curious because we replace measurements in-place in device preprocessing, and I'm wondering if that code can benefit from this change

@vincentmr
Copy link
Contributor

Not necessarily in the scope of the PR, but I was wondering whether we would like to leverage _flatten in dyadics like == or _equal?

@albi3ro
Copy link
Contributor Author

albi3ro commented Sep 20, 2023

Not necessarily in the scope of the PR, but I was wondering whether we would like to leverage _flatten in dyadics like == or _equal?

Potentially we could rewrite a lot of qml.equal (used by __eq__). Tensor data has to use qml.math.allclose instead of ==, so it wouldn't necessarily be as easy as type(obj1) == type(obj2) and obj1._flatten() == obj2._flatten(), but it might still be a better solution than what we currently have.

@albi3ro albi3ro enabled auto-merge (squash) September 20, 2023 19:36
@albi3ro albi3ro merged commit e909e56 into master Sep 20, 2023
39 checks passed
@albi3ro albi3ro deleted the measurement-process-pytree branch September 20, 2023 20:25
Comment on lines +164 to +166
PennyLane measurements are automatically registered as `Pytrees <https://jax.readthedocs.io/en/latest/pytrees.html>`_ .

The :class:`~.MeasurementProcess` definitions are sufficient for all PL measurements.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So does this mean that users creating custom measurements do not need to do anything extra?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Actually that comment is wrong. They need to be overridden if the measurement process has extra metadata.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix that in the tape pytree PR.

mudit2812 pushed a commit that referenced this pull request Sep 21, 2023
This PR registers all `MeasurementProcess` objects as jax pytrees.
[sc-40588]
rmoyard added a commit to PennyLaneAI/catalyst that referenced this pull request Oct 20, 2023
**Context:**
PennyLane 0.33.0 will introduce measurements as pytrees
PennyLaneAI/pennylane#4607
Most measurements have therefore no leaves, this breaks a Catalyst
assumptions for capturing the program.

**Description of the Change:**

- Unflatten the return to get it
- Flatten the return again but this time with `is_leaf` true for
measurement processes.
 
**Benefits:**

Catalyst is up to date with PennyLane master

**Possible Drawbacks:**

Potential slow down
Benchmark:
```
import pennylane as qml
from jax import numpy as jnp

from catalyst import qjit

dev = qml.device("lightning.qubit", wires=3)

import timeit

def my_function_v1():

    @qjit
    @qml.qnode(device=dev)
    def circuit(x: float, y: float):
        qml.RX(x, wires=0)
        qml.RY(y, wires=1)
        qml.CNOT(wires=[0, 1])
        return [tuple([qml.expval(qml.PauliZ(wires=0))]), jnp.sin(y)], {"expval": qml.expval(qml.PauliZ(wires=1))}, tuple([qml.probs(wires=[0,1]), qml.expval(qml.PauliZ(wires=2))])

def benchmark_function(func):
    setup_code = f"from __main__ import {func.__name__}"
    stmt = f"{func.__name__}()"
    execution_time = timeit.timeit(stmt, setup_code, number=100)
    return execution_time

if __name__ == "__main__":
    time_v1 = benchmark_function(my_function_v1)
    print(f"Version execution time: {time_v1:.6f} seconds")
```
PL master:
For 100 runs: Version execution time: 6.361672 seconds
PL 0.32.0:
For 100 runs: Version execution time: 6.301733 seconds
Diff: around 0.06s for hundred runs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants