Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradient/Jacobians with pytrees #500

Merged
merged 11 commits into from
Feb 12, 2024
Merged

Gradient/Jacobians with pytrees #500

merged 11 commits into from
Feb 12, 2024

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Feb 8, 2024

Context:
No matter what is the structure of the return, when taking derivatives we also return the flatten derivatives.

Description of the Change:
Add support for any return when taking the derivatives and keep the structure. (Pytree support)

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(phi, psi):
    qml.RY(phi, wires=0)
    qml.RX(psi, wires=0)
    return [{"expval0": qml.expval(qml.PauliZ(0))}, qml.expval(qml.PauliZ(0))]

psi = 0.1
print(qjit(jacobian(circuit, argnum=[0, 1]))(psi, phi))
[{'expval0': (array(-0.0978434), array(-0.19767681))}, (array(-0.0978434), array(-0.19767681))]

@rmoyard rmoyard changed the title Gradient with pytrees Gradient/Jacobians with pytrees Feb 8, 2024
@rmoyard
Copy link
Contributor Author

rmoyard commented Feb 9, 2024

[sc-55112]

@codecov-commenter
Copy link

codecov-commenter commented Feb 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (0509f9e) 99.54% compared to head (99e8022) 99.55%.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #500   +/-   ##
=======================================
  Coverage   99.54%   99.55%           
=======================================
  Files          43       43           
  Lines        7775     7786   +11     
  Branches      536      540    +4     
=======================================
+ Hits         7740     7751   +11     
  Misses         18       18           
  Partials       17       17           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@rmoyard rmoyard marked this pull request as ready for review February 9, 2024 19:04
Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🥳

@rmoyard rmoyard merged commit 3e1fbd8 into main Feb 12, 2024
34 checks passed
@rmoyard rmoyard deleted the gradient_pytree branch February 12, 2024 16:25
rmoyard added a commit that referenced this pull request Feb 12, 2024
**Context:**

Following #500, we aim to
add support for arbitrary return of functions for VJP and JVP.

**Description of the Change:**

- JVP and VJP are updated to support pytree as return.
- Clean the tests.
@@ -522,6 +523,27 @@ def _check_grad_params(
return GradParams(method, scalar_out, h, argnum)


def _unflatten_derivatives(results, out_tree, argnum, num_results):
"""Unflatten the flat list of derivatives results given the out tree."""
num_trainable_params = len(argnum) if isinstance(argnum, list) else 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking for a list is unnecessarily restrictive, maybe better to check for any iterable?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should include tests for Pytree arguments as well, unless we already have those?

Comment on lines +529 to +539
results = tuple(
tuple(results[i * num_trainable_params : i * num_trainable_params + num_trainable_params])
for i in range(0, num_results)
)
# In jax, argnums=[int] wraps single derivatives in a tuple
if (
(isinstance(argnum, int) or argnum is None)
and num_trainable_params == 1
and num_results != 1
):
results = tuple(r[0] if len(r) == 1 else r for r in results)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find these lines very difficult to follow, maybe there is a way to refactor them or explain what they are doing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is doing:

flat list of derivatives -> structure of the jacobian (tuple of tuple if multiple return and multiple parameters)

Comment on lines +540 to +543
if out_tree.children() != []:
results = tree_unflatten(out_tree, results)
else:
results = results[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this sort of special case really required? I would have hoped that unflattening the tree would take care of it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes me too, but in some cases there is an error internal to Jax.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants