Use the new multi_dispatch decorator in the math module. Add tensordot tests. #2096

Merged · 15 commits · Jan 19, 2022
9 changes: 6 additions & 3 deletions doc/releases/changelog-dev.md
@@ -323,10 +323,13 @@
[(#2063)](https://github.com/PennyLaneAI/pennylane/pull/2063)

* Added a new `multi_dispatch` decorator that helps ease the definition of new functions
inside PennyLane. We can decorate the function, indicating the arguments that are
tensors handled by the interface:
inside PennyLane. The decorator is used throughout the math module, demonstrating use cases.
[(#2082)](https://github.com/PennyLaneAI/pennylane/pull/2084)

[(#2096)](https://github.com/PennyLaneAI/pennylane/pull/2096)

We can decorate a function, indicating the arguments that are
tensors handled by the interface:

```pycon
>>> @qml.math.multi_dispatch(argnum=[0, 1])
... def some_function(tensor1, tensor2, option, like):
91 changes: 47 additions & 44 deletions pennylane/math/multi_dispatch.py
@@ -98,7 +98,7 @@ def multi_dispatch(argnum=None, tensor_list=None):


Args:
argnum (list[int]): A list of integers indicating indicating the indices
argnum (list[int]): A list of integers indicating the indices
to dispatch (i.e., the arguments that are tensors handled by an interface).
If ``None``, dispatch over all arguments.
tensor_list (list[int]): a list of integers indicating which indices
@@ -126,12 +126,12 @@ def multi_dispatch(argnum=None, tensor_list=None):
>>> stack = multi_dispatch(argnum=0, tensor_list=0)(autoray.numpy.stack)

We can also use the ``multi_dispatch`` decorator to dispatch
arguments of more more elaborate custom functions. Here is an example
arguments of more elaborate custom functions. Here is an example
of a ``custom_function`` that
computes :math:`c \\sum_i (v_i)^T v_i`, where :math:`v_i` are vectors in ``values`` and
:math:`c` is a fixed ``coefficient``. Note how ``argnum=0`` only points to the first argument ``values``,
how ``tensor_list=0`` indicates that said first argument is a list of vectors, and that ``coefficient`` is not
dispatched.
:math:`c` is a fixed ``coefficient``. Note how ``argnum=0`` only points to the first
argument ``values``, how ``tensor_list=0`` indicates that said first argument is a
list of vectors, and that ``coefficient`` is not dispatched.

>>> @math.multi_dispatch(argnum=0, tensor_list=0)
>>> def custom_function(values, like, coefficient=10):
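(The example body is cut off at this point of the diff view. For orientation, here is a hedged, self-contained sketch of how such a decorated function could be written and called; the body below, including the use of `qml.math.stack` and `qml.math.sum`, is illustrative and not the PR's verbatim code.)

```python
# Illustrative sketch only -- the exact function body is truncated in this diff.
import pennylane as qml
from pennylane import math
from pennylane import numpy as np


@math.multi_dispatch(argnum=0, tensor_list=0)
def custom_function(values, like, coefficient=10):
    """Compute coefficient * sum_i (v_i)^T v_i for the vectors in ``values``."""
    # `like` holds the name of the interface detected from `values`;
    # stacking the list lets us sum the squared entries of all vectors at once.
    return coefficient * qml.math.sum(qml.math.stack(values) ** 2)


vectors = [np.array([0.5, 1.0]), np.array([-0.2, 0.3])]
print(custom_function(vectors))  # 10 * (0.25 + 1.0 + 0.04 + 0.09) = 13.8
```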
@@ -179,7 +179,8 @@ def wrapper(*args, **kwargs):
return decorator


def block_diag(values):
@multi_dispatch(argnum=[0], tensor_list=[0])
def block_diag(values, like):
"""Combine a sequence of 2D tensors to form a block diagonal tensor.

Args:
@@ -203,12 +204,12 @@ def block_diag(values):
[ 0, 0, -1, -6, -3, 0],
[ 0, 0, 0, 0, 0, 5]])
"""
interface = _multi_dispatch(values)
values = np.coerce(values, like=interface)
return np.block_diag(values, like=interface)
values = np.coerce(values, like=like)
return np.block_diag(values, like=like)
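A short usage sketch (hedged; the exact output formatting depends on the interface of the inputs):

```python
import pennylane as qml
from pennylane import numpy as np

t1 = np.array([[1, 2], [3, 4]])
t2 = np.array([[5]])

# the decorator detects the interface from the list of tensors and
# dispatches to the corresponding block_diag implementation
print(qml.math.block_diag([t1, t2]))
# [[1 2 0]
#  [3 4 0]
#  [0 0 5]]
```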


def concatenate(values, axis=0):
@multi_dispatch(argnum=[0], tensor_list=[0])
def concatenate(values, like, axis=0):
"""Concatenate a sequence of tensors along the specified axis.

.. warning::
@@ -235,9 +236,7 @@ def concatenate(values, axis=0):
<tf.Tensor: shape=(9,), dtype=float32, numpy=
array([6.00e-01, 1.00e-01, 6.00e-01, 1.00e-01, 2.00e-01, 3.00e-01, 5.00e+00, 8.00e+00, 1.01e+02], dtype=float32)>
"""
interface = _multi_dispatch(values)

if interface == "torch":
if like == "torch":
import torch

if axis is None:
@@ -248,15 +247,19 @@
else:
values = [torch.as_tensor(t) for t in values]

if interface == "tensorflow" and axis is None:
if like == "tensorflow" and axis is None:
# flatten and then concatenate zero'th dimension
# to reproduce numpy's behaviour
values = [np.flatten(np.array(t)) for t in values]
axis = 0

return np.concatenate(values, axis=axis, like=interface)
return np.concatenate(values, axis=axis, like=like)
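As a usage sketch (hedged; assuming TensorFlow is installed, and assuming `qml.math.shape` behaves like its NumPy counterpart), the decorator resolves `like="tensorflow"` from the inputs and the branches above apply:

```python
import tensorflow as tf
import pennylane as qml

t1 = tf.constant([0.6, 0.1, 0.6])
t2 = tf.Variable([0.1, 0.2, 0.3])
t3 = tf.constant([5.0, 8.0, 101.0])

# dispatched to TensorFlow; concatenation runs along axis 0
res = qml.math.concatenate([t1, t2, t3])
print(qml.math.shape(res))  # (9,)
```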


# Note that diag is not eligible for the multi_dispatch decorator because
# it is used sometimes with iterable `values` that need to be interpreted
# as a list of tensors, and sometimes with a single tensor `values` that
# might not be iterable (for example a TensorFlow `Variable`)
Member:
@dwierichs is it possible to modify the decorator to allow for this?

Contributor Author (dwierichs):
I'm not sure. It would have to try to iterate over the arguments marked via `tensor_list` (via the list method `extend`), and if that fails, fall back to appending the respective argument to the arguments to dispatch over.
I'd consider this an unreasonable overhead given the currently single use case. What do you think?

Member (josh146), Jan 18, 2022:
Oh, I was thinking something simpler, although maybe it is not in the spirit of 'multiple dispatch'.

Basically, check `not isinstance(arg, (list, tuple))`, and if this is the case, call `qml.math.get_interface(arg)`.

Basically, we assume that a `tensor_list` argument must be a built-in Python sequence (either list or tuple), and if this is not the case, we assume it must be a tensor and simply single-dispatch.

Contributor Author (dwierichs):
Ah I see, that sounds cool! I'll give it a go :)

Contributor Author (dwierichs):
I implemented this now, even a bit more lightweight: if an argument marked as a tensor list is not a tuple or list, it is simply treated as if it were not marked as a tensor list.

Member (josh146):
ah, nice call!
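For illustration, a minimal sketch of the lightweight fallback described in this thread, assuming a simplified view of the decorator internals (the helper name and structure below are hypothetical, not the PR's exact code):

```python
# Hypothetical helper illustrating the idea discussed above: arguments marked
# as tensor lists are only expanded when they are genuine Python sequences.
def _gather_dispatch_args(args, argnum, tensor_list):
    dispatch_args = []
    for idx in argnum:
        arg = args[idx]
        if idx in tensor_list and isinstance(arg, (list, tuple)):
            # a real sequence of tensors: dispatch over its elements
            dispatch_args.extend(arg)
        else:
            # e.g. a single TensorFlow Variable: fall back to single dispatch
            dispatch_args.append(arg)
    return dispatch_args
```

The interface would then be inferred from `dispatch_args` exactly as for any other dispatched argument.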

def diag(values, k=0):
"""Construct a diagonal tensor from a list of scalars.

@@ -292,14 +295,14 @@ def diag(values, k=0):
[0.0000, 0.0000, 0.0000]])
"""
interface = _multi_dispatch(values)

if isinstance(values, (list, tuple)):
values = np.stack(np.coerce(values, like=interface), like=interface)

return np.diag(values, k=k, like=interface)


def dot(tensor1, tensor2):
@multi_dispatch(argnum=[0, 1])
def dot(tensor1, tensor2, like):
"""Returns the matrix or dot product of two tensors.

* If both tensors are 0-dimensional, elementwise multiplication
@@ -323,34 +326,34 @@ def dot(tensor1, tensor2):
Returns:
tensor_like: the matrix or dot product of two tensors
"""
interface = _multi_dispatch([tensor1, tensor2])
x, y = np.coerce([tensor1, tensor2], like=interface)
x, y = np.coerce([tensor1, tensor2], like=like)

if interface == "torch":
if like == "torch":
if x.ndim == 0 and y.ndim == 0:
return x * y

if x.ndim <= 2 and y.ndim <= 2:
return x @ y

return np.tensordot(x, y, axes=[[-1], [-2]], like=interface)
return np.tensordot(x, y, axes=[[-1], [-2]], like=like)

if interface == "tensorflow":
if like == "tensorflow":
if len(np.shape(x)) == 0 and len(np.shape(y)) == 0:
return x * y

if len(np.shape(y)) == 1:
return np.tensordot(x, y, axes=[[-1], [0]], like=interface)
return np.tensordot(x, y, axes=[[-1], [0]], like=like)

if len(np.shape(x)) == 2 and len(np.shape(y)) == 2:
return x @ y

return np.tensordot(x, y, axes=[[-1], [-2]], like=interface)
return np.tensordot(x, y, axes=[[-1], [-2]], like=like)

return np.dot(x, y, like=interface)
return np.dot(x, y, like=like)
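A brief usage sketch of the rules above (hedged; assuming PyTorch is installed):

```python
import torch
import pennylane as qml

m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
v = torch.tensor([1.0, -1.0])

print(qml.math.dot(m, v))  # tensor([-1., -1.])  matrix @ vector
print(qml.math.dot(v, v))  # tensor(2.)          inner product of two 1D tensors
print(qml.math.dot(m, m))  # 2x2 matrix product
```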


def tensordot(tensor1, tensor2, axes=None):
@multi_dispatch(argnum=[0, 1])
def tensordot(tensor1, tensor2, like, axes=None):
"""Returns the tensor product of two tensors.
In general ``axes`` specifies either the set of axes for both
tensors that are contracted (with the first/second entry of ``axes``
@@ -376,11 +379,12 @@ def tensordot(tensor1, tensor2, axes=None):
Returns:
tensor_like: the tensor product of the two input tensors
"""
interface = _multi_dispatch([tensor1, tensor2])
return np.tensordot(tensor1, tensor2, axes=axes, like=interface)
x, y = np.coerce([tensor1, tensor2], like=like)
return np.tensordot(x, y, axes=axes, like=like)
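A short sketch of the two most common `axes` conventions (hedged; `qml.math.shape` is assumed to behave like its NumPy counterpart):

```python
import pennylane as qml
from pennylane import numpy as np

v0 = np.array([0.1, 5.3, -0.9, 1.1])
v2 = np.array([-0.4, 9.1, 1.6])

inner = qml.math.tensordot(v0, v0, axes=[[0], [0]])  # contract axis 0 with axis 0
outer = qml.math.tensordot(v0, v2, axes=0)           # outer product
print(qml.math.shape(outer))  # (4, 3)
```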


def get_trainable_indices(values):
@multi_dispatch(argnum=[0], tensor_list=[0])
def get_trainable_indices(values, like):
"""Returns a set containing the trainable indices of a sequence of
values.

@@ -403,10 +407,9 @@ def get_trainable_indices(values):
tensor(0.0899685, requires_grad=True)
"""
trainable = requires_grad
interface = _multi_dispatch(values)
trainable_params = set()

if interface == "jax":
if like == "jax":
import jax

if not any(isinstance(v, jax.core.Tracer) for v in values):
@@ -420,7 +423,7 @@ def get_trainable_indices(values):
trainable = requires_grad

for idx, p in enumerate(values):
if trainable(p, interface=interface):
if trainable(p, interface=like):
trainable_params.add(idx)

return trainable_params
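A minimal usage sketch with Autograd parameters (hedged):

```python
import pennylane as qml
from pennylane import numpy as np

values = [
    np.array([0.1, 0.2], requires_grad=True),
    np.array([1.0, 2.0], requires_grad=False),
]
# only the first entry is trainable, so only its index is returned
print(qml.math.get_trainable_indices(values))  # {0}
```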
@@ -459,7 +462,8 @@ def ones_like(tensor, dtype=None):
return np.ones_like(tensor)


def safe_squeeze(tensor, axis=None, exclude_axis=None):
@multi_dispatch(argnum=[0])
def safe_squeeze(tensor, like, axis=None, exclude_axis=None):
"""Squeeze a tensor either along all axes, specified axes or all
but a set of excluded axes. For selective squeezing, catch errors
and do nothing if the selected axes do not have size 1.
@@ -474,8 +478,7 @@ def safe_squeeze(tensor, axis=None, exclude_axis=None):
or not excluded and that have size 1. If no axes are specified or excluded,
all axes are attempted to be squeezed.
"""
interface = _multi_dispatch([tensor])
if interface == "tensorflow":
if like == "tensorflow":
from tensorflow.python.framework.errors_impl import InvalidArgumentError

exception = InvalidArgumentError
@@ -508,7 +511,8 @@ def safe_squeeze(tensor, axis=None, exclude_axis=None):
return tensor


def stack(values, axis=0):
@multi_dispatch(argnum=[0], tensor_list=[0])
def stack(values, like, axis=0):
"""Stack a sequence of tensors along the specified axis.

.. warning::
@@ -537,9 +541,8 @@ def stack(values, axis=0):
[1.00e-01, 2.00e-01, 3.00e-01],
[5.00e+00, 8.00e+00, 1.01e+02]], dtype=float32)>
"""
interface = _multi_dispatch(values)
values = np.coerce(values, like=interface)
return np.stack(values, axis=axis, like=interface)
values = np.coerce(values, like=like)
return np.stack(values, axis=axis, like=like)


def where(condition, x=None, y=None):
@@ -612,7 +615,8 @@ def where(condition, x=None, y=None):
return np.where(condition, x, y, like=_multi_dispatch([condition, x, y]))


def frobenius_inner_product(A, B, normalize=False):
@multi_dispatch(argnum=[0, 1])
def frobenius_inner_product(A, B, like, normalize=False):
r"""Frobenius inner product between two matrices.

.. math::
@@ -637,8 +641,7 @@ def frobenius_inner_product(A, B, normalize=False):
>>> qml.math.frobenius_inner_product(A, B)
3.091948202943376
"""
interface = _multi_dispatch([A, B])
A, B = np.coerce([A, B], like=interface)
A, B = np.coerce([A, B], like=like)

inner_product = np.sum(A * B)

@@ -649,6 +652,7 @@ def frobenius_inner_product(A, B, normalize=False):
return inner_product
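A small worked instance of the inner product and its normalized variant (hedged sketch):

```python
import numpy as np
import pennylane as qml

A = np.array([[1.0, 2.0], [0.0, -1.0]])
B = np.array([[3.0, 0.5], [1.0, 2.0]])

# sum_ij A_ij * B_ij = 1*3 + 2*0.5 + 0*1 + (-1)*2 = 2.0
print(qml.math.frobenius_inner_product(A, B))  # 2.0
# normalized: 2.0 / sqrt(sum(A**2) * sum(B**2)) = 2.0 / sqrt(6.0 * 14.25)
print(qml.math.frobenius_inner_product(A, B, normalize=True))
```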


@multi_dispatch(argnum=[0, 2])
def scatter_element_add(tensor, index, value, like=None):
"""In-place addition of a multidimensional value over various
indices of a tensor.
@@ -682,8 +686,7 @@ def scatter_element_add(tensor, index, value, like=None):
if len(np.shape(tensor)) == 0 and index == ():
return tensor + value

interface = like or _multi_dispatch([tensor, value])
return np.scatter_element_add(tensor, index, value, like=interface)
return np.scatter_element_add(tensor, index, value, like=like)
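A usage sketch (hedged): for interfaces without true in-place assignment the update is returned as a new tensor, so the result should be taken from the return value:

```python
import pennylane as qml
from pennylane import numpy as np

T = np.zeros((2, 3))
# add 0.5 at index (1, 2); the returned tensor has 0.5 at that position
# and zeros elsewhere
T_new = qml.math.scatter_element_add(T, (1, 2), 0.5)
print(T_new)
```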


def unwrap(values, max_depth=None):
88 changes: 87 additions & 1 deletion tests/math/test_functions.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Unit tests for the TensorBox functional API in pennylane.fn.fn
"""
from functools import partial
import itertools
import numpy as onp
import pytest
@@ -546,7 +547,9 @@ def test_multidimensional_product(self, t1, t2):


class TestTensordotTorch:
"""Tests for the tensor product function in torch."""
"""Tests for the tensor product function in torch.
This test is required because the functionality of tensordot for Torch
is being patched in PennyLane, as compared to autoray."""

v1 = torch.tensor([0.1, 0.5, -0.9, 1.0, -4.2, 0.1], dtype=torch.float64)
v2 = torch.tensor([4.3, -1.2, 8.2, 0.6, -4.2, -11.0], dtype=torch.float64)
@@ -762,6 +765,89 @@ def test_tensordot_torch_tensor_matrix(self, M, expected, axes1, axes2):
assert fn.allclose(fn.tensordot(self.T1, M, axes=[axes1, axes2]), expected)


class TestTensordotDifferentiability:

v0 = np.array([0.1, 5.3, -0.9, 1.1])
v1 = np.array([0.5, -1.7, -2.9, 0.0])
v2 = np.array([-0.4, 9.1, 1.6])
exp_shapes = ((len(v0), len(v2), len(v0)), (len(v0), len(v2), len(v2)))
exp_jacs = (np.zeros(exp_shapes[0]), np.zeros(exp_shapes[1]))
for i in range(len(v0)):
exp_jacs[0][i, :, i] = v2
for i in range(len(v2)):
exp_jacs[1][:, i, i] = v0

def test_autograd(self):
"""Tests differentiability of tensordot with Autograd."""
v0 = np.array(self.v0, requires_grad=True)
v1 = np.array(self.v1, requires_grad=True)
v2 = np.array(self.v2, requires_grad=True)

# Test inner product
jac = qml.jacobian(partial(fn.tensordot, axes=[0, 0]), argnum=(0, 1))(v0, v1)
assert all(fn.allclose(jac[i], _v) for i, _v in enumerate([v1, v0]))

# Test outer product
jac = qml.jacobian(partial(fn.tensordot, axes=0), argnum=(0, 1))(v0, v2)
assert all(fn.shape(jac[i]) == self.exp_shapes[i] for i in [0, 1])
assert all(fn.allclose(jac[i], self.exp_jacs[i]) for i in [0, 1])

def test_torch(self):
"""Tests differentiability of tensordot with Torch."""
jac_fn = torch.autograd.functional.jacobian

v0 = torch.tensor(self.v0, requires_grad=True, dtype=torch.float64)
v1 = torch.tensor(self.v1, requires_grad=True, dtype=torch.float64)
v2 = torch.tensor(self.v2, requires_grad=True, dtype=torch.float64)

# Test inner product
jac = jac_fn(partial(fn.tensordot, axes=[[0], [0]]), (v0, v1))
assert all(fn.allclose(jac[i], _v) for i, _v in enumerate([v1, v0]))

# Test outer product
jac = jac_fn(partial(fn.tensordot, axes=0), (v0, v2))
assert all(fn.shape(jac[i]) == self.exp_shapes[i] for i in [0, 1])
assert all(fn.allclose(jac[i], self.exp_jacs[i]) for i in [0, 1])

def test_jax(self):
"""Tests differentiability of tensordot with JAX."""
jac_fn = jax.jacobian

v0 = jnp.array(self.v0)
v1 = jnp.array(self.v1)
v2 = jnp.array(self.v2)

# Test inner product
jac = jac_fn(partial(fn.tensordot, axes=[[0], [0]]), argnums=(0, 1))(v0, v1)
assert all(fn.allclose(jac[i], _v) for i, _v in enumerate([v1, v0]))

# Test outer product
jac = jac_fn(partial(fn.tensordot, axes=0), argnums=(0, 1))(v0, v2)
assert all(fn.shape(jac[i]) == self.exp_shapes[i] for i in [0, 1])
assert all(fn.allclose(jac[i], self.exp_jacs[i]) for i in [0, 1])

def test_tensorflow(self):
"""Tests differentiability of tensordot with TensorFlow."""

def jac_fn(func, args):
with tf.GradientTape() as tape:
out = func(*args)
return tape.jacobian(out, args)

v0 = tf.Variable(self.v0, dtype=tf.float64)
v1 = tf.Variable(self.v1, dtype=tf.float64)
v2 = tf.Variable(self.v2, dtype=tf.float64)

# Test inner product
jac = jac_fn(partial(fn.tensordot, axes=[[0], [0]]), (v0, v1))
assert all(fn.allclose(jac[i], _v) for i, _v in enumerate([v1, v0]))

# Test outer product
jac = jac_fn(partial(fn.tensordot, axes=0), (v0, v2))
assert all(fn.shape(jac[i]) == self.exp_shapes[i] for i in [0, 1])
assert all(fn.allclose(jac[i], self.exp_jacs[i]) for i in [0, 1])


# the following test data is of the form
# [original shape, axis to expand, new shape]
expand_dims_test_data = [