Skip to content

Commit

Permalink
Add custom pattern to broadcast (#603)
Browse files Browse the repository at this point in the history
* add custom pattern

* add figure to docstrign

* polish

* fix image paths

* add an output test

* black code

* polish

* revert import change
  • Loading branch information
mariaschuld committed Apr 27, 2020
1 parent 488bdd7 commit 8be49d9
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 11 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Expand Up @@ -2,6 +2,9 @@

<h3>New features since last release</h3>

* The ``templates.broadcast`` function can now take custom patterns.
[(#603)](https://github.com/XanaduAI/pennylane/pull/603)

* PennyLane QNodes can now be converted into Keras layers, allowing for creation of quantum and
hybrid models using the Keras API.
[(#529)](https://github.com/XanaduAI/pennylane/pull/529)
Expand Down
Binary file added doc/_static/templates/broadcast_custom.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 11 additions & 1 deletion doc/code/qml_templates.rst
Expand Up @@ -36,13 +36,23 @@ State preperations
Custom templates
----------------

Custom templates can be constructed using the template decorator.
The template decorator can used to register a quantum function as a template.

.. autosummary::
:toctree:

pennylane.templates.template

Broadcasting function
---------------------

The broadcast function creates a new template by broadcasting templates (or normal gates) over wires in a
predefined pattern. You can import this function both via ``qml.broadcast`` and ``qml.templates.broadcast``.

.. autosummary::

pennylane.broadcast

Utility functions for input checks
----------------------------------

Expand Down
4 changes: 4 additions & 0 deletions doc/introduction/templates.rst
Expand Up @@ -240,6 +240,10 @@ takes single quantum operations or other templates and applies them to wires in
:description: broadcast (all-to-all)
:figure: ../_static/templates/broadcast_alltoall.png

.. customgalleryitem::
:link: ../code/api/pennylane.broadcast.html
:description: broadcast (custom)
:figure: ../_static/templates/broadcast_custom.png

.. raw:: html

Expand Down
97 changes: 89 additions & 8 deletions pennylane/templates/broadcast.py
Expand Up @@ -125,6 +125,14 @@ def broadcast(unitary, wires, pattern, parameters=None, kwargs=None):
:width: 20%
:target: javascript:void(0);
* A custom pattern can be passed by provding a list of wire lists to ``pattern``. The ``unitary`` is applied
to each set of wires specified in the list.
.. figure:: ../../_static/templates/broadcast_custom.png
:align: center
:width: 20%
:target: javascript:void(0);
Each ``unitary`` may depend on a different set of parameters. These are passed as a list by the ``parameters``
argument.
Expand Down Expand Up @@ -400,6 +408,8 @@ def circuit(pars):
.. code-block:: python
dev = qml.device('default.qubit', wires=4)
@qml.qnode(dev)
def circuit(pars):
broadcast(unitary=qml.CRot, pattern='pyramid',
Expand All @@ -416,6 +426,8 @@ def circuit(pars):
.. code-block:: python
dev = qml.device('default.qubit', wires=4)
@qml.qnode(dev)
def circuit(pars):
broadcast(unitary=qml.CRot, pattern='ring',
Expand All @@ -430,9 +442,49 @@ def circuit(pars):
pars6 = [3, -2, -3]
circuit([pars1, pars2, pars3, pars4, pars5, pars6])
* Custom pattern
For a custom pattern, the wire lists for each application of the unitary is
passed to ``pattern``:
.. code-block:: python
dev = qml.device('default.qubit', wires=5)
pattern = [[0, 1], [3, 4]]
@qml.qnode(dev)
def circuit():
broadcast(unitary=qml.CNOT, pattern=pattern,
wires=range(5))
return qml.expval(qml.PauliZ(0))
circuit()
When using a parametrized unitary, make sure that the number of wire lists in ``pattern`` corresponds to the
number of parameters in ``parameters``.
.. code-block:: python
pattern = [[0, 1], [3, 4]]
@qml.qnode(dev)
def circuit(pars):
broadcast(unitary=qml.CRot, pattern=pattern,
wires=range(5), parameters=pars)
return qml.expval(qml.PauliZ(0))
pars1 = [1, 2, 3]
pars2 = [-1, 3, 1]
pars = [pars1, pars2]
assert len(pars) == len(pattern)
circuit(pars)
"""

OPTIONS = ["single", "double", "double_odd", "chain", "ring", "pyramid", "all_to_all"]
OPTIONS = ["single", "double", "double_odd", "chain", "ring", "pyramid", "all_to_all", "custom"]

#########
# Input checks
Expand All @@ -446,20 +498,45 @@ def circuit(pars):
"Iterable; got {}".format(type(parameters)),
)

check_type(
pattern, [str], msg="'pattern' must be a string; got {}".format(type(pattern)),
)

if kwargs is None:
kwargs = {}

check_type(
kwargs, [dict], msg="'kwargs' must be a dictionary; got {}".format(type(kwargs)),
)

check_is_in_options(
pattern, OPTIONS, msg="did not recognize option {} for 'pattern'".format(pattern),
)
custom_pattern = None

if isinstance(pattern, str):
check_is_in_options(
pattern, OPTIONS, msg="did not recognize option {} for 'pattern'".format(pattern),
)
else:
check_type(
pattern,
[Iterable],
msg="a custom pattern must be a list of lists of wire indices"
"; got {}".format(parameters),
)
for wire_set in pattern:
check_type(
wire_set,
[Iterable],
msg="a custom pattern must be a list of lists of wire indices"
"; got {}".format(parameters),
)
for wire in wire_set:
check_type(
wire,
[int],
msg="a custom pattern must be a list of lists of wire indices"
"; got {}".format(parameters),
)

# remember the wire pattern
custom_pattern = pattern
# set "pattern" to "custom", indicating that custom settings have to be used
pattern = "custom"

n_parameters = {
"single": len(wires),
Expand All @@ -471,6 +548,9 @@ def circuit(pars):
"all_to_all": 0 if len(wires) in [0, 1] else len(wires) * (len(wires) - 1) // 2,
}

if pattern == "custom":
n_parameters["custom"] = len(custom_pattern)

# check that enough parameters for pattern
if parameters is not None:
shape = get_shape(parameters)
Expand Down Expand Up @@ -505,6 +585,7 @@ def circuit(pars):
"ring": wires_ring(wires),
"pyramid": wires_pyramid(wires),
"all_to_all": wires_all_to_all(wires),
"custom": custom_pattern,
}

# broadcast the unitary
Expand Down
66 changes: 64 additions & 2 deletions tests/templates/test_broadcast.py
Expand Up @@ -121,8 +121,8 @@ def KwargTemplateDouble(par, wires, a=True):
]


class TestConstructorBroadcast:
"""Tests the broadcast template constructor."""
class TestBuiltinPatterns:
"""Tests the built-in patterns ("single", "ring", etc) of the broadcast template constructor."""

@pytest.mark.parametrize("unitary, parameters", [(RX, [[0.1], [0.2], [0.3]]),
(Rot,
Expand Down Expand Up @@ -288,3 +288,65 @@ def test_wire_sequence_generating_functions(self, function, wires, target):

sequence = function(wires)
assert sequence == target


class TestCustomPattern:
"""Additional tests for using broadcast with a custom pattern."""

@pytest.mark.parametrize("custom_pattern, pattern", [([[0, 1], [1, 2], [2, 3], [3, 0]], "ring"),
([[0, 1], [1, 2], [2, 3]], "chain"),
([[0, 1], [2, 3]], "double")
])
def test_reproduce_builtin_patterns(self, custom_pattern, pattern):
"""Tests that the custom pattern can reproduce the built in patterns."""

dev = qml.device('default.qubit', wires=4)

# qnode using custom pattern
@qml.qnode(dev)
def circuit1():
broadcast(unitary=qml.CNOT, pattern=custom_pattern, wires=range(4))
return [qml.expval(qml.PauliZ(wires=w)) for w in range(4)]

# qnode using built-in pattern
@qml.qnode(dev)
def circuit2():
broadcast(unitary=qml.CNOT, pattern=pattern, wires=range(4))
return [qml.expval(qml.PauliZ(wires=w)) for w in range(4)]

custom = circuit1()
built_in = circuit2()
assert np.allclose(custom, built_in)

@pytest.mark.parametrize("custom_pattern", [1,
[1, 2],
[['a'], ['b']]
])
def test_exception_custom_pattern_not_valid(self, custom_pattern):
"""Tests that an exception is raised if the pattern is not a list of lists of integers."""

dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev)
def circuit():
broadcast(unitary=qml.Hadamard, wires=[0, 1], pattern=custom_pattern)
return qml.expval(qml.PauliZ(0))

with pytest.raises(ValueError, match="a custom pattern must be a list"):
circuit()

@pytest.mark.parametrize("custom_pattern, expected", [([[0], [2], [3], [2]], [-1., 1., 1., -1.]),
([[3], [2], [0]], [-1., 1., -1., -1.]),
])
def test_correct_output(self, custom_pattern, expected):
"""Tests the output for simple cases."""

dev = qml.device('default.qubit', wires=4)

@qml.qnode(dev)
def circuit():
broadcast(unitary=qml.PauliX, wires=range(4), pattern=custom_pattern)
return [qml.expval(qml.PauliZ(w)) for w in range(4)]

res = circuit()
assert np.allclose(res, expected)

0 comments on commit 8be49d9

Please sign in to comment.