Add a tape unwrapping context manager #1491
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report

    @@            Coverage Diff             @@
    ##           master    #1491      +/-   ##
    ==========================================
    + Coverage   98.34%   98.36%   +0.02%
    ==========================================
      Files         181      182       +1
      Lines       12899    12933      +34
    ==========================================
    + Hits        12685    12722      +37
    + Misses        214      211       -3

Continue to review full report at Codecov.
pennylane/math/single_dispatch.py

    def _to_numpy_jax(x):
        from jaxlib.xla_extension import DeviceArray

        return np.array(x) if isinstance(x, DeviceArray) else x
what's going on here? Don't we always want np arrays?
Oh I think you just gave me an idea!!
It worked!

So, JAX has multiple types of arrays. On the forward pass, everything is a DeviceArray; this can be converted into a NumPy array via np.array(x).

On the backward pass, however, JAX instead passes ConcreteArray objects, which don't support NumPy conversion. I suddenly wondered if they have a private attribute that allows you to extract the NumPy array, and they do!
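The dispatch pattern under discussion can be sketched in plain Python. The classes below are toy stand-ins, not real JAX types, and the `_val` attribute is a hypothetical placeholder for whatever private attribute JAX actually exposes:

```python
# Illustrative stand-ins for JAX's array types; these are NOT real JAX
# classes, and ._val below is a hypothetical private attribute standing
# in for whatever JAX exposes internally.
class DeviceLike:
    """Mimics a forward-pass array that supports direct conversion."""

    def __init__(self, data):
        self._data = list(data)

    def tolist(self):
        return list(self._data)


class ConcreteLike:
    """Mimics a backward-pass array whose data is only reachable
    through a private attribute."""

    def __init__(self, data):
        self._val = DeviceLike(data)  # hypothetical private attribute


def to_plain(x):
    # Forward-pass arrays convert directly; backward-pass arrays need
    # the private-attribute fallback described in the comment above.
    if isinstance(x, DeviceLike):
        return x.tolist()
    if isinstance(x, ConcreteLike):
        return x._val.tolist()
    return x  # anything else passes through unchanged


print(to_plain(DeviceLike([0.1, 0.3])))  # [0.1, 0.3]
print(to_plain(ConcreteLike([0.5])))     # [0.5]
```

The point is only the two-branch dispatch: a direct conversion when the type supports it, and a private-attribute fallback when it doesn't.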
        Original parameters: [<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.1>,
          <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.3>]
        """
        return UnwrapTape(self)
So one could alternatively do

    with tape.unwrap() as unwrapped_tape:

which is the same as

    with Unwrap(tape) as unwrapped_tape:

right?
Yep, this is purely a shortcut so that developers can simply call tape.unwrap() on their existing circuits, rather than worrying about importing an additional function (and potentially causing circular import issues).
The only serious comment would be to change the name UnwrapTape to Unwrap, but since this is not super important and the PR definitely improves the code base, I approve :)
Context: In #1490, improved handling for tensor unwrapping and determining trainability was added to PennyLane. A big application for this functionality is in the PL interfaces; in particular, the interfaces typically must perform the following logic:
This PR provides a context manager to the tape to automate this procedure for any autodiff framework.
Description of the Change:
A new tape method, tape.unwrap(), is added. This method is a context manager; inside the context, the tape's parameters are unwrapped to NumPy arrays and floats, and the trainable parameter indices are set. These changes are temporary, and are reverted on exiting the context.
Example:
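The example itself did not survive in this rendering, so here is a minimal self-contained sketch of the pattern, using toy Tape, UnwrapTape, and FakeVariable classes rather than PennyLane's actual implementation:

```python
class FakeVariable:
    """Stand-in for a framework tensor (e.g. a tf.Variable)."""

    def __init__(self, value):
        self.value = value

    def __float__(self):
        return self.value


class UnwrapTape:
    """Context manager that temporarily replaces wrapped parameters with
    plain floats and restores the originals on exit. A toy sketch of the
    idea in this PR, not PennyLane's actual implementation."""

    def __init__(self, tape):
        self.tape = tape
        self._original = None

    def __enter__(self):
        self._original = list(self.tape.params)
        self.tape.params = [float(p) for p in self.tape.params]
        return self.tape

    def __exit__(self, exc_type, exc, tb):
        # revert the unwrapping even if the body raised
        self.tape.params = self._original


class Tape:
    """Toy tape holding a list of parameters."""

    def __init__(self, params):
        self.params = params

    def unwrap(self):
        # the shortcut discussed above: a method wrapping the context manager
        return UnwrapTape(self)


tape = Tape([FakeVariable(0.1), FakeVariable(0.3)])
with tape.unwrap() as unwrapped:
    print(unwrapped.params)  # plain floats inside the context
print(type(tape.params[0]).__name__)  # originals restored on exit
```

The method form, tape.unwrap(), and the standalone form, UnwrapTape(tape), are interchangeable here, mirroring the equivalence discussed in the review comments above.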
Benefits:
This logic, which was previously duplicated across all interfaces, is now provided as a separate, well-tested, and easily extensible unit.
Due to the previous code duplication, bugs fixed in some interfaces were not always found and fixed in others.
The logic is now covered by unit tests.
It is possible to use this unwrapping capability elsewhere, outside the interfaces. For example, some plugins may need to unwrap tape parameters in certain cases; see Unwrap tensor in random_layer #893.
Possible Drawbacks: None, as far as I can tell.
Related GitHub Issues: n/a