New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bind lightning provided VJPs #4914
Conversation
Hello. You may have forgotten to update the changelog!
|
…lane into lightning-vjps
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## master #4914 +/- ##
==========================================
- Coverage 99.67% 99.66% -0.02%
==========================================
Files 394 394
Lines 35670 35457 -213
==========================================
- Hits 35554 35337 -217
- Misses 116 120 +4 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. I just left some suggestions to fix typos. Thanks @albi3ro .
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_jax_qnode_default_qubit_2.py
Outdated
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_torch_qnode_default_qubit_2.py
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! Thanks @albi3ro for putting it together!
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I won't make any strong claim about where these tests should be while we have this circular dependency thing going on... feels like anything goes. I'm happiest to not have to copy-paste or re-write tests, so this seems sensible. just a few comments, but mostly lgtm!
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_autograd_qnode_default_qubit_2.py
Show resolved
Hide resolved
tests/interfaces/default_qubit_2_integration/test_jax_jit_qnode_default_qubit_2.py
Show resolved
Hide resolved
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
For the code: ```python n_wires = 20 n_layers = 5 dev = qml.device('lightning.qubit', wires=n_wires) shape = qml.StronglyEntanglingLayers.shape(n_wires=n_wires, n_layers=n_layers) rng = qml.numpy.random.default_rng(seed=42) params = rng.random(shape) params_torch = torch.tensor(params, requires_grad=True) @qml.qnode(dev, device_vjp=True) def circuit_dev_jp(params): qml.StronglyEntanglingLayers(params, wires=range(n_wires)) return [qml.expval(qml.PauliZ(i)) for i in range(n_wires)] @qml.qnode(dev, diff_method="adjoint", device_vjp=False) def circuit_dev_jac(params): qml.StronglyEntanglingLayers(params, wires=range(n_wires)) return [qml.expval(qml.PauliZ(i)) for i in range(n_wires)] ``` We get: ![Screenshot 2023-12-04 at 10 54 40 AM](https://github.com/PennyLaneAI/pennylane/assets/6364575/f88ea573-a437-4775-a697-4f5da0de4a2d) --------- Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com> Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Open questions:
Where do we test this? Do we test it in
pennylane
orpennylane-lightning
?Is this the expected way to compute the vjps?
Is there any extra validation or preprocessing we need to do?
Can we re-use the device state between calls to
vjp
?[sc-51789]
For the code:
We get: