added multi_dispatch decorator #2084

Merged (28 commits, Jan 17, 2022)
Changes from 7 commits

Commits (28)
a1e2ea6  added multi_dispatch decorator (Qottmann, Jan 12, 2022)
31661c1  added release notes and ran black on only my files (Qottmann, Jan 12, 2022)
96ff418  typo in PR number, now 2084 (Qottmann, Jan 12, 2022)
c606d34  updated black, reran on test_multi_dispatch (Qottmann, Jan 12, 2022)
790e52a  added black version requirement, optional commit (Qottmann, Jan 12, 2022)
c864c48  Update doc/requirements.txt (josh146, Jan 13, 2022)
cc7fc37  moved black>=21 requirement out of docs to /requirements.txt (Qottmann, Jan 13, 2022)
b4a6e8c  Update tests/math/test_multi_disptach.py (Qottmann, Jan 13, 2022)
a25e9b5  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 13, 2022)
7c97b7f  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 13, 2022)
ba212d4  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 13, 2022)
bd3b15e  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 13, 2022)
ac9a82e  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 13, 2022)
48b1834  added example in multi_dispatch and expanded tests (Qottmann, Jan 13, 2022)
9da9a98  merged conflict in changelog-dev (Qottmann, Jan 13, 2022)
5907880  changed argument handling to allow for argnum=0 and added custom func… (Qottmann, Jan 14, 2022)
e8cf981  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 14, 2022)
a945579  Update doc/releases/changelog-dev.md (Qottmann, Jan 14, 2022)
a620e13  Update doc/releases/changelog-dev.md (Qottmann, Jan 14, 2022)
65110b1  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 14, 2022)
421eb7c  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 14, 2022)
3bdf5a7  small review (Qottmann, Jan 14, 2022)
c0f0a8e  Merge branch 'master' into master (josh146, Jan 16, 2022)
3ff2742  Update doc/releases/changelog-dev.md (Qottmann, Jan 17, 2022)
7df030b  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 17, 2022)
cf81b3e  Update pennylane/math/multi_dispatch.py (Qottmann, Jan 17, 2022)
fb79fbe  Update tests/math/test_multi_disptach.py (Qottmann, Jan 17, 2022)
ba7cb0d  Update tests/math/test_multi_disptach.py (Qottmann, Jan 17, 2022)
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
@@ -273,6 +273,9 @@
* The QAOA module now accepts both NetworkX and RetworkX graphs as function inputs.
[(#1791)](https://github.com/PennyLaneAI/pennylane/pull/1791)

* Added `multi_dispatch` decorator that helps ease the definition of new functions.
  [(#2084)](https://github.com/PennyLaneAI/pennylane/pull/2084)
Review comment (Member):
It would be good here to add a short description and code example here to help explain to readers the importance/impact of this change 🙂

The reason this tends to be important is that readers of the changelog are a very specific demographic --- it includes both regular users, who are likely to be more aware of changes occurring, but also non-regular users who may only interact with PennyLane once in a while by reading the changelog.

So when writing the changelog entries, we often try to think --- 'how can we communicate this change to a skim-reader in a succinct manner?'

Review comment (Member):

Note that this approach is a bit unique to PL; quite a few OSS projects tend to have simpler changelogs! The reason we take this approach is that we often use our changelog to advertise the new releases.

In this case, you could likely copy and adapt the intro to the multi_dispatch docstring.
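For instance, the entry could borrow the docstring's stack example. A sketch only (the exact wording and formatting would be the author's call):

* Added a `multi_dispatch` decorator that eases the definition of new interface-dispatching functions in `qml.math`. For example, autoray's `stack` can be made interface-aware in a single line:

  >>> stack = qml.math.multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack)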


<h3>Breaking changes</h3>

* The behaviour of `RotosolveOptimizer` has been changed regarding
@@ -341,4 +344,4 @@

This release contains contributions from (in alphabetical order):

- Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Jay Soni, Antal Száva, David Wierichs, Shaoming Zhang
+ Juan Miguel Arrazola, Ali Asadi, Esther Cruz, Olivia Di Matteo, Diego Guala, Ankit Khandelwal, Korbinian Kottmann, Jay Soni, Antal Száva, David Wierichs, Shaoming Zhang
2 changes: 2 additions & 0 deletions pennylane/math/__init__.py
@@ -35,6 +35,7 @@

from .multi_dispatch import (
_multi_dispatch,
multi_dispatch,
block_diag,
concatenate,
diag,
@@ -75,6 +76,7 @@ def __getattr__(name):

__all__ = [
"_multi_dispatch",
"multi_dispatch",
"allclose",
"allequal",
"block_diag",
75 changes: 75 additions & 0 deletions pennylane/math/multi_dispatch.py
@@ -14,6 +14,8 @@
"""Multiple dispatch functions"""
# pylint: disable=import-outside-toplevel,too-many-return-statements
import warnings
from collections.abc import Sequence
import functools

from autograd.numpy.numpy_boxes import ArrayBox
from autoray import numpy as np
@@ -81,6 +83,79 @@ def _multi_dispatch(values):
return "numpy"


def multi_dispatch(argnum=None, tensor_list=None):
"""Decorater to dispatch arguments handled by the interface.

    This helps simplify definitions of new functions inside PennyLane. Instead of writing

    >>> def some_function(tensor1, tensor2, option):
    ...     interface = qml.math._multi_dispatch([tensor1, tensor2])
    ...     ...

    we can decorate the function, indicating the arguments that are tensors handled
    by the interface:


    >>> @qml.math.multi_dispatch(argnum=[0, 1])
    ... def some_function(tensor1, tensor2, option, like):
    ...     # the interface string is stored in `like`.
    ...     ...

    Args:
        argnum (list[int]): a list of integers indicating the indices
            to dispatch over (i.e. the arguments that are tensors handled by an interface).
            If None, dispatch over all arguments.
        tensor_list (list[int]): a list of integers indicating which indices
            in argnum are lists of tensors.
            If None, this option is ignored.

    Returns:
        callable: the decorator, which wraps the given function so that its
        tensor arguments are dispatched to the detected interface

    .. seealso:: :func:`pennylane.math.multi_dispatch._multi_dispatch`

    .. note::
        This decorator makes the interface argument "like" optional, as it utilizes
        the utility function `_multi_dispatch` to automatically detect the appropriate
        interface based on the tensor types.

    **Examples**

    We can redefine external functions to be suitable for PennyLane. Here, we
    redefine autoray's `stack` function:

    >>> stack = multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack)
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            argnums = argnum or list(range(len(args)))
Review comment (Member):
Just to make sure I am following this correctly: if argnum is not provided, the default is to assume that argnum=[0, 1, 2, ..., len(args)]?

Review comment (Member):
@Qottmann I think there might be a bug here (and as well with tensor_lists below!) The issue being that argnum=0 and tensor_list=0 are both valid inputs, but python treats 0 as False, and so the boolean statement will choose the latter argument in both cases!

This is a sign that more tests should be added, to ensure these edge cases work as expected.
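For reference, a falsy-safe rewrite might read as follows (a sketch of the reviewer's point, not the committed fix; explicit `is None` checks avoid treating 0 as "not provided"):

            argnums = list(range(len(args))) if argnum is None else argnum
            tensor_lists = [] if tensor_list is None else tensor_list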

            tensor_lists = tensor_list or []

            if not isinstance(argnums, Sequence):
                argnums = [argnums]
            if not isinstance(tensor_lists, Sequence):
                tensor_lists = [tensor_lists]
Review comment (Member), on lines +158 to +161:
Nice 💯


            dispatch_args = []

            for a in argnums:
                if a in tensor_lists:
                    dispatch_args.extend(args[a])
                else:
                    dispatch_args.append(args[a])

            interface = kwargs.pop("like", None)
            interface = interface or _multi_dispatch(dispatch_args)
            kwargs["like"] = interface

            return fn(*args, **kwargs)

        return wrapper

    return decorator
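For a concrete picture of what the decorator produces, here is a usage sketch mirroring the new test in tests/math/test_multi_disptach.py (an illustrative doctest, assuming the argnum=0 / tensor_list=0 handling flagged above behaves as intended):

>>> import autoray, torch
>>> from pennylane import math as fn
>>> stack = fn.multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack)
>>> stack([torch.tensor([1.0, 0.0]), torch.tensor([2.0, 3.0])])
tensor([[1., 0.],
        [2., 3.]])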


def block_diag(values):
"""Combine a sequence of 2D tensors to form a block diagonal tensor.

1 change: 1 addition & 0 deletions requirements.txt
@@ -13,3 +13,4 @@ semantic_version==2.6
dask[delayed]==2021.10
autoray>=0.2.5
matplotlib==3.4
black>=21
43 changes: 43 additions & 0 deletions tests/math/test_multi_disptach.py
@@ -0,0 +1,43 @@
# Copyright 2018-2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Assertion test for multi_dispatch function/decorator
"""
import autoray
import numpy as onp
import pytest
from pennylane import numpy as np
from pennylane import math as fn


tf = pytest.importorskip("tensorflow", minversion="2.1")
torch = pytest.importorskip("torch")

test_multi_dispatch_stack_data = [
[[1.0, 0.0], [2.0, 3.0]],
([1.0, 0.0], [2.0, 3.0]),
onp.array([[1.0, 0.0], [2.0, 3.0]]),
np.array([[1.0, 0.0], [2.0, 3.0]]),
# torch.tensor([[1.,0.],[2.,3.]]),
Review comment (Member):
How come this one is commented out?

tf.Variable([[1.0, 0.0], [2.0, 3.0]]),
tf.constant([[1.0, 0.0], [2.0, 3.0]]),
]


@pytest.mark.parametrize("x", test_multi_dispatch_stack_data)
def test_multi_dispatch_stack(x):
    """Test that the decorated autoray function stack can handle all inputs"""
    stack = fn.multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack)
    res = stack(x)
    print(res)
    assert fn.allequal(res, [[1.0, 0.0], [2.0, 3.0]])
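Following up on the review note that more edge-case tests are needed, a test along these lines could exercise argnum=0 directly (a sketch only, not part of this commit; it assumes the falsy-zero handling is fixed and that _multi_dispatch prefers TensorFlow when several interfaces are present):

def test_multi_dispatch_argnum_zero_only():
    """Edge-case sketch: argnum=0 should dispatch over the first argument
    only, not silently fall back to dispatching over all arguments."""

    @fn.multi_dispatch(argnum=0)
    def first_interface(tensor, other, like=None):
        # `like` holds the interface detected from `tensor` alone
        return like

    # the torch tensor in position 0 should determine the interface,
    # even though position 1 holds a TensorFlow tensor
    assert first_interface(torch.tensor([1.0]), tf.constant([1.0])) == "torch"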