Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a tape unwrapping context manager #1491

Merged
merged 16 commits into from
Aug 4, 2021
Merged

Add a tape unwrapping context manager #1491

merged 16 commits into from
Aug 4, 2021

Conversation

josh146
Copy link
Member

@josh146 josh146 commented Aug 3, 2021

Context: In #1490, improved handling for tensor unwrapping and determining trainability was added to PennyLane. A big application for this functionality is in the PL interfaces; in particular, the interfaces typically must perform the following logic:

  • Extract the tape parameters
  • Determine which are trainable
  • Unwrap all parameters to NumPy arrays and floats
  • Set the tape parameters as unwrapped
  • Execute the device
  • Set the original tape parameters

This PR provides a context manager to the tape to automate this procedure for any autodiff framework.

Description of the Change:

A new tape method, tape.unwrap() is added. This method is a context manager; inside the context, the tapes parameters are unwrapped to NumPy arrays and floats, the the trainable parameter indices are set.

These changes are temporary, and reverted on exiting the context.

Example:

>>> with tf.GradientTape():
...     with qml.tape.QuantumTape() as tape:
...         qml.RX(tf.Variable(0.1), wires=0)
...         qml.RY(tf.constant(0.2), wires=0)
...         qml.RZ(tf.Variable(0.3), wires=0)
...     with UnwrapTape(tape) as unwrapped_tape:
...         print("Trainable params:", unwrapped_tape.trainable_params)
...         print("Unwrapped params:", unwrapped_tape.get_parameters())
Trainable params: {0, 2}
Unwrapped params: [0.1, 0.3]
>>> print("Original parameters:", tape.get_parameters())
Original parameters: [<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.1>,
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.3>]

Benefits:

  • This logic, which was previously duplicated across all interfaces, is now provided as a separate, well tested, and easily extendible unit.

  • Due to the previous code duplication, bugs fixed in some interfaces were not always found and fixed in others.

  • The logic is now well tested with unit tests

  • It is possible to use this unwrapping capability elsewhere, outside the interfaces. For example, some plugins may need to unwrap tape parameters in certain cases, see Unwrap tensor in random_layer #893.

Possible Drawbacks: None, as far as I can tell.

Related GitHub Issues: n/a

@josh146 josh146 changed the title Add a tape unwrapping context manager [WIP] Add a tape unwrapping context manager Aug 3, 2021
@github-actions
Copy link
Contributor

github-actions bot commented Aug 3, 2021

Hello. You may have forgotten to update the changelog!
Please edit .github/CHANGELOG.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@codecov
Copy link

codecov bot commented Aug 3, 2021

Codecov Report

Merging #1491 (64e95bc) into master (ab71001) will increase coverage by 0.02%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1491      +/-   ##
==========================================
+ Coverage   98.34%   98.36%   +0.02%     
==========================================
  Files         181      182       +1     
  Lines       12899    12933      +34     
==========================================
+ Hits        12685    12722      +37     
+ Misses        214      211       -3     
Impacted Files Coverage Δ
pennylane/math/single_dispatch.py 99.47% <100.00%> (+0.01%) ⬆️
pennylane/tape/__init__.py 100.00% <100.00%> (ø)
pennylane/tape/tape.py 98.61% <100.00%> (+0.82%) ⬆️
pennylane/tape/unwrap.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ab71001...64e95bc. Read the comment docs.

@josh146 josh146 changed the title [WIP] Add a tape unwrapping context manager Add a tape unwrapping context manager Aug 3, 2021
@josh146 josh146 added the review-ready 👌 PRs which are ready for review by someone from the core team. label Aug 3, 2021
.github/CHANGELOG.md Outdated Show resolved Hide resolved
def _to_numpy_jax(x):
from jaxlib.xla_extension import DeviceArray

return np.array(x) if isinstance(x, DeviceArray) else x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's going on here? Don't we always want np arrays?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I think you just gave me an idea!!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It worked!!!!!!!!!!

So, JAX has multiple types of arrays. On the forward pass, everything is a DeviceArray; this can be converted into a NumPy array via np.array(x).

On the backward pass, however, JAX instead passes ConcreteArray objects, which don't support NumPy conversion. I suddenly wondered if they have a private attribute that allows you to extract the NumPy array, and they do!

Original parameters: [<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.1>,
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.3>]
"""
return UnwrapTape(self)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So one could alternatively do

with tape.unwrap() as unwrapped_tape:

which is the same as

with Unwrap(tape) as unwrapped_tape:

right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is purely a shortcut so that developers can simply call tape.unwrap() on their existing circuits, rather than worrying about importing an additional functions (and potentially causing circular import issues)

Copy link
Contributor

@mariaschuld mariaschuld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only serious comment would be to change the name UnwrapTape to Unwrap, but since this is not super important and the PR definitely improves the code base I approve :)

.github/CHANGELOG.md Outdated Show resolved Hide resolved
@josh146 josh146 merged commit 0428134 into master Aug 4, 2021
@josh146 josh146 deleted the unwrap2 branch August 4, 2021 11:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants