Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing JAX interface for adjoint diff #1349

Merged
merged 11 commits into from
May 31, 2021
Merged

Fixing JAX interface for adjoint diff #1349

merged 11 commits into from
May 31, 2021

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented May 21, 2021

Closes Issue #1345

Currently, the jax interface fails to return correct results with the adjoint differentiation method. The Jax interface makes a copy of the current tape to take it's derivative, but the tape copy didn't copy over differentiation information.

I have attempted to rectify this by improving the tape copy method, but now trying to take the derivative of the circuit causes my kernel to end.

@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit .github/CHANGELOG.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@codecov
Copy link

codecov bot commented May 21, 2021

Codecov Report

Merging #1349 (97f23b8) into master (23b194e) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #1349   +/-   ##
=======================================
  Coverage   98.16%   98.16%           
=======================================
  Files         154      154           
  Lines       11544    11548    +4     
=======================================
+ Hits        11332    11336    +4     
  Misses        212      212           
Impacted Files Coverage Δ
pennylane/interfaces/autograd.py 100.00% <ø> (ø)
pennylane/tape/jacobian_tape.py 98.00% <100.00%> (+0.04%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 23b194e...97f23b8. Read the comment docs.

@albi3ro albi3ro requested a review from josh146 May 21, 2021 18:57
jnp.array(0.1, dtype=jnp.float64)
except UserWarning:
return False
return True
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to test this line, given float64 mode must be configured on startup.

@albi3ro albi3ro added the review-ready 👌 PRs which are ready for review by someone from the core team. label May 21, 2021
@albi3ro albi3ro requested a review from antalszava May 25, 2021 15:27
Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this @albi3ro! Just two suggestions re: improving the precision check and moving the location of the warning logic, but once addressed this should be good for merging.

.github/CHANGELOG.md Outdated Show resolved Hide resolved
Comment on lines 38 to 46
import jax.numpy as jnp

with warnings.catch_warnings():
warnings.filterwarnings("error")
try:
jnp.array(0.1, dtype=jnp.float64)
except UserWarning:
return False
return True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice workaround, but it feels a bit too reliant on implementation details of JAX and how it raises warnings.

I would suggest instead simply querying the JAX configuration:

>>> from jax.config import config
>>> config.x64_enabled
False
>>> config.update("jax_enable_x64", True)
>>> config.x64_enabled
True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I would recommend moving this to interfaces/jax.py, so that all JAX imports are isolated in a single place. The warning can then be raised by the tape mixin class querying jacobian_options dictionary

Comment on lines 463 to 471
if interface == "jax":
if not _jax_float64_support():
warnings.warn(
"float64 support not enabled for jax. "
"May cause inaccuracies for `diff_method='finite-diff'`",
UserWarning,
)

return JacobianTape, interface, device, {"method": "numeric"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I would say this definitely feels like something that lives on the interface rather than the QNode :) Especially because it simply raises a warning, rather than changing the diff method.

@albi3ro albi3ro changed the title JacobianTape copy duplicates jacobian_options Fixing JAX interface for adjoint diff and finite-diff May 27, 2021
* Fixes drawing QNodes that contain multiple measurements on a single wire.
[(#1353)](https://github.com/PennyLaneAI/pennylane/pull/1353)


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

@@ -533,7 +538,8 @@ def jacobian(self, device, params=None, **options):
# First order (forward) finite-difference will be performed.
# Compute the value of the tape at the current parameters here. This ensures
# this computation is only performed once, for all parameters.
options["y0"] = np.asarray(self.execute_device(params, device))
params_f64 = np.array(params, dtype=np.float64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the jax device tests doesn't seem to like this line 🤔
https://github.com/PennyLaneAI/pennylane/pull/1349/checks?check_run_id=2697038221#step:10:63

On another note: wouldn't this casting be a bit too specific at a general place? This would potentially cast the parameters all the time regardless of the interface. Would we only like to do it for the jax interface?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albi3ro I think I agree with @antalszava in this case. While there is a good argument for enforcing finite difference to always use float64, I feel like this should be an interface level option/default, rather than enforced in the Jacobian method

@@ -55,6 +55,23 @@ def test_parameter_info(self, make_tape):
}


class TestTapeCopying:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth adding an integration test similar to the original example in the #1345 issue to ensure that the results using the jax interface are correct.

@antalszava
Copy link
Contributor

Hi @albi3ro, it must have been quite some digging to work this out, nice one! 🥇 One device test seems to be failing and also wondering if as Josh suggested, it could be good to add the fix specifically to the jax interface code.

@albi3ro albi3ro removed the review-ready 👌 PRs which are ready for review by someone from the core team. label May 28, 2021
@albi3ro
Copy link
Contributor Author

albi3ro commented May 28, 2021

The one test is failing because of an issue I just raised, Issue #1369 . When a parameter is a matrix instead of float, params is a deprecated ragged array and can't be cast to float64.

While this problem arose with the jax interface, it will actually occur any time we use finite differences with float32, or any other less precise data type. Therefore, I think the solution should be more general than just the jax interface.

numeric_pd uses float64 parameters anyway. "y0" just uses the less precise data type. Simple fix, just one that took me forever to find.

Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @albi3ro!

When a parameter is a matrix instead of float, params is a deprecated ragged array and can't be cast to float64.

I think this might be a bug with the test not specifying requires_grad=False on the matrix, as opposed to a bug in the code, see
#1369 (comment)

I have two requests:

  • I feel like the general float64 fix within the JacobianTape is out of scope for this PR, and should be considered separately. Instead, I think it should be sufficient to just make the JAX interface use float64 in the finite-diff tests (with maybe small modifications to jax_ops.py).

  • Does this mean that there are no adjoint tests for the JAX interface? We should probably add them now to ensure that this bug is fixed :)

@@ -533,7 +538,8 @@ def jacobian(self, device, params=None, **options):
# First order (forward) finite-difference will be performed.
# Compute the value of the tape at the current parameters here. This ensures
# this computation is only performed once, for all parameters.
options["y0"] = np.asarray(self.execute_device(params, device))
params_f64 = np.array(params, dtype=np.float64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albi3ro I think I agree with @antalszava in this case. While there is a good argument for enforcing finite difference to always use float64, I feel like this should be an interface level option/default, rather than enforced in the Jacobian method

@albi3ro albi3ro changed the title Fixing JAX interface for adjoint diff and finite-diff Fixing JAX interface for adjoint diff May 31, 2021
@albi3ro
Copy link
Contributor Author

albi3ro commented May 31, 2021

I added the finite-diff change to this PR because I thought it was going to be easy. But there needs to better a better solution than just the easy one, so that is getting moved to it's own PR.

I had the adjoint test before, but it got removed when I fire and brimstoned the PR after a bad merge that messed everything up. It's back. Everything should be good to go now.

@albi3ro albi3ro requested a review from josh146 May 31, 2021 13:49
Copy link
Member

@josh146 josh146 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @albi3ro 🙂

@@ -119,7 +119,7 @@ def get_parameters(self, trainable_only=True, return_arraybox=False):
qml.RY(0.543, wires=0)
qml.CNOT(wires=[0, 'a'])
qml.RX(0.133, wires='a')
expval(qml.PauliZ(wires=[0]))
qml.expval(qml.PauliZ(wires=[0]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, good catch

grad_adjoint = jax.grad(qnode_adjoint)(params1, params2)
grad_backprop = jax.grad(qnode_backprop)(params1, params2)

assert np.allclose(grad_adjoint, grad_backprop)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to have this test in 💪

@albi3ro albi3ro merged commit a4d7abd into master May 31, 2021
@albi3ro albi3ro deleted the tape_copy branch May 31, 2021 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants