Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX integration tests #1685

Merged
merged 18 commits into from
Oct 20, 2021
Merged

Add JAX integration tests #1685

merged 18 commits into from
Oct 20, 2021

Conversation

josh146
Copy link
Member

@josh146 josh146 commented Sep 24, 2021

Context: The JAX interface is missing many of the integration tests that the other interfaces have. This PR adds these integration tests in, and makes note of where the JAX interface may be lacking in feature parity.

Description of the Change:

  • Adds a JAX integration test to test_gradient_transform.py (including JIT tests).

  • Adds QNode integration tests tests/interfaces/test_batch_jax_qnode.py.

  • Adds a batch transform integration test test_batch_transform.py.

  • Modifies qml.math.get_trainable_indices() to correctly return results when JAX is performing a forward-only computation. This is done by making the following change:

    • On forward-only passes, all DeviceArray objects are treated as trainable.
    • On forwards+backwards passes, only jax.core.Tracer objects are treated as trainable (which matches the jax.grad(cost, argnum=...) argument).

    This change allows the metric tensor/gradient transform functions --- which apply in forward-only mode but require knowledge of trainable parameters --- to apply to QNodes when using JAX.

    This is required because, since JAX does not have a method of specifying trainable parameters on the forward pass, perviously gradient transforms would simply register no trainable parameters on forward passes, and return an empty list as a result! Paradoxically, differentiating the gradient transform would work fine, since the trainable parameters are now registered.

Benefits:

  • Better tests for JAX, and a better idea of what works, and what doesn't.

  • Gradient transforms, and the metric tensor, now works for the JAX interface.

Possible Drawbacks:

Several areas were noticed were JAX usage resulted in errors or issues, unlike other interfaces:

  • Ragged QNodes in backprop mode. E.g., return qml.expval(qml.PauliZ(0)), qml.probs(wires=[1]). This appears to be because line 230 in QubitDevice (results = self._asarray(results)) fails.

  • Hamiltonian expansion of expval(H) when using finite shots fails. When using finite shots, expval(H) results in the tape being expanded to multiple tapes, each one having a vector-valued output. Since the JAX parameter-shift interface does not support vector-valued tapes, an error occurs.

  • JIT mode with the adjoint method. Since the adjoint method is not using host_callback in the jax.py interface, it results in a failure attempting to JIT the QNode.

  • JIT mode with jax.grad(cost, argnums=...) where argnums is a subset of allowed arguments, e.g., argnums=[0, 2]. In this scenario, the JAX interface is only passing unwrapped trainable parameters to the host_callback - as a result, non-trainable parameters are remaining as jax.core.Tracer objects on the tape, which cannot be understood by the device.

Related GitHub Issues: n/a

@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@josh146 josh146 mentioned this pull request Sep 25, 2021
@codecov
Copy link

codecov bot commented Sep 28, 2021

Codecov Report

Merging #1685 (db973d5) into master (b352691) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #1685   +/-   ##
=======================================
  Coverage   98.90%   98.90%           
=======================================
  Files         206      206           
  Lines       15388    15396    +8     
=======================================
+ Hits        15219    15227    +8     
  Misses        169      169           
Impacted Files Coverage Δ
pennylane/math/multi_dispatch.py 100.00% <100.00%> (ø)
pennylane/math/utils.py 100.00% <100.00%> (ø)
pennylane/transforms/batch_transform.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update b352691...db973d5. Read the comment docs.

@josh146 josh146 added the review-ready 👌 PRs which are ready for review by someone from the core team. label Sep 28, 2021
@@ -279,6 +279,6 @@ def requires_grad(tensor, interface=None):
if interface == "jax":
import jax

return isinstance(tensor, jax.interpreters.ad.JVPTracer)
return isinstance(tensor, jax.core.Tracer)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.core.Tracer is the original parent class, so this is a lot safer :) There are cases I discovered where JAX will use tracers that aren't JVPTracer.

@@ -400,7 +400,7 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su
coeffs = coeffs.numpy()

expected = np.dot(qcval, coeffs)
assert np.all(res == expected)
assert np.allclose(res, expected)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, this test was failing for me on CI (but not locally) 🤔

Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me! 🙂 Left some questions, but no major blocker.

If I recall correctly, we plan to discontinue support for QNodes with ragged outputs (right? 🤔). Would that affect the failing case of Ragged QNodes in backprop mode?

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
if interface == "jax":
import jax

if not any(isinstance(v, jax.core.Tracer) for v in values):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a great change to have 🥇

How come it's placed here, instead of into the JAX branch of requires_grad?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to be here, since the not any check can only be done here, it cannot be done inside the requires_grad check (which only checks a single tensor at a time) 🤔

I could be wrong though, let me know if you see a way around this!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I think you're right. 🤔 At least nothing else comes to mind that we could use here.

pennylane/transforms/batch_transform.py Show resolved Hide resolved
tests/gradients/test_finite_difference.py Show resolved Hide resolved
tests/gradients/test_finite_difference.py Show resolved Hide resolved
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliY(1))

res = jax.grad(cost_fn, argnums=[0, 1])(a, b, shots=30000)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does the test case make sure that we had shots=30000 instead of shots=100? Would the deviation from the expected analytic result be even bigger than 0.1 with shots=100?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is not the most ideal way of testing this, but I spent a while and couldn't come up with anything better 🤔

The other interfaces use shots=[(1, 1000)], which is nicer since the output shape changes. However, it can't be used for JAX, since JAX only supports scalar outputs :(

assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift

# if we set the shots to None, backprop can now be used
cost_fn(a, b, shots=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we then pass in diff_method="param-shift" if we really wanted to use parameter shift with shots=None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep! The reason the internal diff_method is changing here is because, by default, diff_method="best". If you instead set diff_method="param-shift", then it will not change dynamically.

Comment on lines +831 to +833

if diff_method not in {"backprop"}:
pytest.skip("Test only supports backprop")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this have any effect at this point?

Suggested change
if diff_method not in {"backprop"}:
pytest.skip("Test only supports backprop")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer to leave it in, since we may have support for vector valued QNodes in parameter-shift mode at some point in the future!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right! Would we only test assert res.dtype is np.dtype("complex128") for non-backprop diff methods? That seems to be the only statement before skipping the test.

tests/interfaces/test_batch_jax_qnode.py Outdated Show resolved Hide resolved
tests/transforms/test_metric_tensor.py Outdated Show resolved Hide resolved
josh146 and others added 2 commits October 20, 2021 00:08
Copy link
Contributor

@antalszava antalszava left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! 💯 Thank you for adding these tests to check the JAX interface. 😍

Just double-checking this one:

If I recall correctly, we plan to discontinue support for QNodes with ragged outputs (right? thinking). Would that affect the failing case of Ragged QNodes in backprop mode?

This is just to understand better what the priority of the drawbacks would be that were identified in the PR description. 🙂

@josh146
Copy link
Member Author

josh146 commented Oct 20, 2021

Would that affect the failing case of Ragged QNodes in backprop mode?

Yes! Hopefully, this would allow these QNodes to work with JAX, once we make that change 🙂

@josh146 josh146 merged commit f20c19d into master Oct 20, 2021
@josh146 josh146 deleted the batch-qnode-jax branch October 20, 2021 05:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants