-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added multi_dispatch decorator #2084
Changes from 7 commits
a1e2ea6
31661c1
96ff418
c606d34
790e52a
c864c48
cc7fc37
b4a6e8c
a25e9b5
7c97b7f
ba212d4
bd3b15e
ac9a82e
48b1834
9da9a98
5907880
e8cf981
a945579
a620e13
65110b1
421eb7c
3bdf5a7
c0f0a8e
3ff2742
7df030b
cf81b3e
fb79fbe
ba7cb0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
"""Multiple dispatch functions""" | ||
# pylint: disable=import-outside-toplevel,too-many-return-statements | ||
import warnings | ||
from collections.abc import Sequence | ||
import functools | ||
|
||
from autograd.numpy.numpy_boxes import ArrayBox | ||
from autoray import numpy as np | ||
|
@@ -81,6 +83,79 @@ def _multi_dispatch(values): | |
return "numpy" | ||
|
||
|
||
def multi_dispatch(argnum=None, tensor_list=None): | ||
"""Decorater to dispatch arguments handled by the interface. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here is a link to the rendered version: https://pennylane--2084.org.readthedocs.build/en/2084/code/api/pennylane.math.multi_dispatch.html It looks really nice!
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
This helps simplify definitions of new functions inside pennylane. Instead of writing | ||
|
||
>>> def some_function(tensor1, tensor2, option): | ||
... interface = qml.math._multi_dispatch([tensor1, tensor2]) | ||
... ... | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
We can decorate the function, indicating the arguments that are tensors handled | ||
by the interface | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
>>> @qml.math.multi_dispatch(argnum=[0, 1]) | ||
... def some_function(tensor1, tensor2, option, like): | ||
... # the interface string is stored in `like`. | ||
... ... | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
argnum (list[int]): a list of integers indicating indicating the indices | ||
to dispatch (i.e. the arguments that are tensors handled by an interface) | ||
If None, dispatch over all arguments | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tensor_lists(list[int]): a list of integers indicating which indices | ||
in argnum are lists of tensors. | ||
If None, this option is ignored. | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Returns: | ||
decorator: | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
.. seealso:: :func:`pennylane.math.multi_dispatch._multi_dispatch` | ||
|
||
.. note:: | ||
This decorator makes the interface argument "like" optional as it utilizes | ||
the utility function `_multi_dispatch` to automatically detect the appropriate | ||
interface based on the tensor types. | ||
|
||
** Examples ** | ||
We can redefine external functions to be suitable for pennylane. Here, we | ||
redefine autoray's `stack` function. | ||
>>> stack = multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack) | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
""" | ||
|
||
def decorator(fn): | ||
@functools.wraps(fn) | ||
def wrapper(*args, **kwargs): | ||
argnums = argnum or list(range(len(args))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to make sure I am following this correctly: if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Qottmann I think there might be a bug here (and as well with This is a sign that more tests should be added, to ensure these edge cases work as expected. |
||
tensor_lists = tensor_list or [] | ||
|
||
if not isinstance(argnums, Sequence): | ||
argnums = [argnums] | ||
if not isinstance(tensor_lists, Sequence): | ||
tensor_lists = [tensor_lists] | ||
Comment on lines
+158
to
+161
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice 💯 |
||
|
||
dispatch_args = [] | ||
|
||
for a in argnums: | ||
if a in tensor_lists: | ||
dispatch_args.extend(args[a]) | ||
else: | ||
dispatch_args.append(args[a]) | ||
|
||
interface = kwargs.pop("like", None) | ||
interface = interface or _multi_dispatch(dispatch_args) | ||
kwargs["like"] = interface | ||
|
||
return fn(*args, **kwargs) | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
def block_diag(values): | ||
"""Combine a sequence of 2D tensors to form a block diagonal tensor. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ semantic_version==2.6 | |
dask[delayed]==2021.10 | ||
autoray>=0.2.5 | ||
matplotlib==3.4 | ||
black>=21 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2018-2020 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" Assertion test for multi_dispatch function/decorator | ||
""" | ||
import autoray | ||
import numpy as onp | ||
import pytest | ||
from pennylane import numpy as np | ||
from pennylane import math as fn | ||
|
||
|
||
tf = pytest.importorskip("tensorflow", minversion="2.1") | ||
torch = pytest.importorskip("torch") | ||
|
||
test_multi_dispatch_stack_data = [ | ||
[[1.0, 0.0], [2.0, 3.0]], | ||
([1.0, 0.0], [2.0, 3.0]), | ||
onp.array([[1.0, 0.0], [2.0, 3.0]]), | ||
np.array([[1.0, 0.0], [2.0, 3.0]]), | ||
# torch.tensor([[1.,0.],[2.,3.]]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come this one is commented out? |
||
tf.Variable([[1.0, 0.0], [2.0, 3.0]]), | ||
tf.constant([[1.0, 0.0], [2.0, 3.0]]), | ||
] | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@pytest.mark.parametrize("x", test_multi_dispatch_stack_data) | ||
def test_multi_dispatch_stack(x): | ||
"""Test that the decorated autoray function stack can handle all inputs""" | ||
stack = fn.multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack) | ||
res = stack(x) | ||
print(res) | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert fn.allequal(res, [[1.0, 0.0], [2.0, 3.0]]) | ||
Qottmann marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good here to add a short description and code example here to help explain to readers the importance/impact of this change 🙂
The reason that this tends to be important is that the readers of the changelog tend to be a very specific demographic --- it includes both regular users, who are likely to be more aware of changes occuring, but also non-regular users who maybe only interact with PennyLane once and a while by reading the changelog.
So when writing the changelog entries, we often try to think --- 'how can we communicate this change to a skim-reader in a succinct manner?'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this approach is a bit unique to PL; quite a few OSS projects tend to have simpler changelogs! The reason we take this approach is that we often use our changelog to advertise the new releases.
In this case, you could likely copy and adapt the intro to the
multi_dipatch
docstring