Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax interface for all devices. #1076

Merged
merged 14 commits into from
Feb 9, 2021
Merged

Jax interface for all devices. #1076

merged 14 commits into from
Feb 9, 2021

Conversation

chaserileyroberts
Copy link
Contributor

@chaserileyroberts chaserileyroberts commented Feb 8, 2021

Ported over from #1065

Context:
Previously, only the default.qubit.jax device supported the JAX interface. This PR adds JAX interface support to all devices including built-in gradient support for non-default simulators.

Description of the Change:
The recent addition of jax.experimental.host_callback allows us to finally do jax -> numpy -> numpy -> jax within a jax.jit function! This PR adds the needed scaffolding to do it.

Benefits:
Everyone can start using JAX more and I won't take any more excuses.

Possible Drawbacks:
None.

Related GitHub Issues:
#943

TODOs

  • Vmap support

@chaserileyroberts chaserileyroberts changed the title Move to new branch Jax interface for all devices. Feb 8, 2021
@codecov
Copy link

codecov bot commented Feb 8, 2021

Codecov Report

Merging #1076 (4785bd3) into master (cfdb6f8) will decrease coverage by 0.03%.
The diff coverage is 89.79%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1076      +/-   ##
==========================================
- Coverage   97.74%   97.71%   -0.04%     
==========================================
  Files         153      154       +1     
  Lines       11590    11637      +47     
==========================================
+ Hits        11329    11371      +42     
- Misses        261      266       +5     
Impacted Files Coverage Δ
pennylane/tape/interfaces/jax.py 89.74% <89.74%> (ø)
pennylane/tape/qnode.py 98.88% <90.00%> (-0.35%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update cfdb6f8...4785bd3. Read the comment docs.

.github/CHANGELOG.md Show resolved Hide resolved
.github/CHANGELOG.md Outdated Show resolved Hide resolved
@chaserileyroberts chaserileyroberts added the review-ready 👌 PRs which are ready for review by someone from the core team. label Feb 8, 2021
return_type = self.observables[0].return_type
if return_type is not Variance and return_type is not Expectation:
raise ValueError(
f"Only Variance and Expectation returns are support, given {return_type}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, maybe mention that this is the JAX interface speaking here, and not general PennyLane?

"""Test that the device provides the correct
result for a simple circuit with a device using a different interface."""
if not qml.tape_mode_active():
pytest.skip("Tape mode only test")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't include it the tests fails. :(


weights = jnp.array([0.1, 0.2])
val = jax.jacrev(circuit)(weights)
assert "DeviceArray" in val.__repr__()
Copy link
Contributor

@mariaschuld mariaschuld Feb 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here, what you are testing is not the value, but the return type, right? Maybe mention in test names? test_jacobian sounds kind of like more!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would one have to test shape/values of the return value as well? I'm just thinking of a situation where the return value is some trivial "zero" or so because something went wrong, but still has the right type...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I should add more checks in this test.

def loss(weights, a):
# the following global variable is defined simply for testing
# purposes, so that we can easily extract the transformed QNode
# for verification.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, which global variable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, new_circuit, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this comment should be deleted. new_circuit was treated as a global variable before and checked outside this method. However, we don't want to do that (the similar checks for the autograd interface don't apply here). I'll delete the comment.

assert grad[1].shape == a.shape

# compare against the expected values
tol = 1e-5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not so important, but I remember we tried to use global tol fixtures for closeness checks...?

Copy link
Contributor

@mariaschuld mariaschuld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just left some comments for now. Awesome addition!

@chaserileyroberts
Copy link
Contributor Author

@josh146 this should be ready to be merged

@chaserileyroberts chaserileyroberts added the merge-ready ✔️ All tests pass and the PR is ready to be merged. label Feb 9, 2021
@mariaschuld
Copy link
Contributor

Overwriting the codecov because untested lines are falsely reported due to how jitting works.

@mariaschuld mariaschuld merged commit 12c8927 into master Feb 9, 2021
@mariaschuld mariaschuld deleted the jax_hostcallback branch February 9, 2021 15:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
merge-ready ✔️ All tests pass and the PR is ready to be merged. review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants