Skip to content

Commit

Permalink
[Queuing] Adds AnnotatingQueue() (#728)
Browse files Browse the repository at this point in the history
* Adds AnnotatingQueue()

* revert nesting change to fix tests

* added tests

* added tests

* update changelog

* increase coverage

* suggested changes
  • Loading branch information
josh146 committed Aug 4, 2020
1 parent 51dd8b4 commit 474982e
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 11 deletions.
6 changes: 5 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
[7.78800783e-01 1.94700196e-01 2.43375245e-02 2.02812704e-03 1.26757940e-04]
```

* Refactor of the QNode queuing architecture.
[(#722)](https://github.com/PennyLaneAI/pennylane/pull/722)
[(#728)](https://github.com/PennyLaneAI/pennylane/pull/728)

<h3>Breaking changes</h3>

<h3>Bug fixes</h3>
Expand All @@ -96,7 +100,7 @@

This release contains contributions from (in alphabetical order):

Josh Izaac, Maria Schuld, Antal Száva, Nicola Vitucci
Josh Izaac, Nathan Killoran, Maria Schuld, Antal Száva, Nicola Vitucci

# Release 0.10.0 (current release)

Expand Down
91 changes: 81 additions & 10 deletions pennylane/_queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
This module contains the :class:`QueuingContext` abstract base class.
"""
import abc
from collections import OrderedDict
from collections import OrderedDict, deque

import pennylane as qml

Expand All @@ -42,8 +42,8 @@ class QueuingContext(abc.ABC):

# TODO: update docstring

_active_contexts = []
"""The list of contexts that are currently active."""
_active_contexts = deque()
"""The stack of contexts that are currently active."""

def __enter__(self):
"""Adds this instance to the global list of active contexts.
Expand All @@ -61,22 +61,33 @@ def __exit__(self, exception_type, exception_value, traceback):
QueuingContext._active_contexts.remove(self)

@abc.abstractmethod
def _append(self, obj):
def _append(self, obj, **kwargs):
"""Append an object to this QueuingContext instance.
Args:
obj: The object to be appended
"""

@classmethod
def append(cls, obj):
def active_context(cls):
"""Returns the currently active queuing context."""
if cls._active_contexts:
return cls._active_contexts[-1]

return None

@classmethod
def append(cls, obj, **kwargs):
"""Append an object to the queue(s).
Args:
obj: the object to be appended
"""
# TODO: this method should append only to `cls.active_context`, *not*
# all active contexts. However this will require a refactor in
# the template decorator and the operation recorder.
for context in cls._active_contexts:
context._append(obj) # pylint: disable=protected-access
context._append(obj, **kwargs) # pylint: disable=protected-access

@abc.abstractmethod
def _remove(self, obj):
Expand All @@ -93,15 +104,28 @@ def remove(cls, obj):
Args:
obj: the object to be removed
"""
# TODO: this method should remove only from `cls.active_context`, *not*
# all active contexts. However this will require a refactor in
# the template decorator and the operation recorder.
for context in cls._active_contexts:
# We use the duck-typing approach to assume that the underlying remove
# behaves like list.remove and throws a ValueError if the operator
# is not in the list
# behaves like `list.remove(obj)` or `del dict[key]` and throws a
# ValueError or KeyError if the operator is not present
try:
context._remove(obj) # pylint: disable=protected-access
except ValueError:
except (ValueError, KeyError):
pass

@classmethod
def update_info(cls, obj, **kwargs):
"""Updates information of an object in the queue."""
cls.active_context().update_info(obj, **kwargs)

@classmethod
def get_info(cls, obj):
"""Returns information of an object in the queue."""
return cls.active_context().get_info(obj)


class Queue(QueuingContext):
"""Lightweight class that maintains a basic queue of operations and pre/post-processing steps
Expand All @@ -110,13 +134,60 @@ class Queue(QueuingContext):
def __init__(self):
self.queue = []

def _append(self, obj):
def _append(self, obj, **kwargs):
self.queue.append(obj)

def _remove(self, obj):
self.queue.remove(obj)


class AnnotatedQueue(QueuingContext):
"""Lightweight class that maintains a basic queue of operations, in addition
to annotations."""

def __init__(self):
self._queue = OrderedDict()

def _append(self, obj, **kwargs):
self._queue[obj] = kwargs

def _remove(self, obj):
del self._queue[obj]

def update_info(self, obj, **kwargs):
"""Updates the annotated information of an object in the queue.
Args:
obj: the object to update
kwargs: Keyword arguments and values to add to the annotation.
If a particular keyword already exists in the annotation,
the value is updated.
"""
if obj not in self._queue:
raise ValueError(f"Object {obj} not in the queue.")

self._queue[obj].update(kwargs)

def get_info(self, obj):
"""Returns the annotated information of an object in the queue.
Args:
obj: the object to query
Returns:
dict: the annotated information
"""
if obj not in self._queue:
raise ValueError(f"Object {obj} not in the queue.")

return self._queue[obj]

@property
def queue(self):
"""Returns a list of objects in the annotated queue"""
return list(self._queue.keys())


class OperationRecorder(Queue):
"""A template and quantum function inspector,
allowing easy introspection of operators that have been
Expand Down
149 changes: 149 additions & 0 deletions tests/test_queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def test_remove_no_context(self):

QueuingContext.remove(qml.PauliZ(0))

def test_no_active_context(self, mock_queuing_context):
"""Test that if there are no active contexts, active_context() returns None"""
assert mock_queuing_context.active_context() is None


class TestQueue:
"""Test the Queue class."""
Expand Down Expand Up @@ -283,3 +287,148 @@ def template(x):
template(3)

assert str(recorder) == expected_output


class TestAnnotatedQueue:
"""Tests for the annotated queue class"""

def test_remove_not_in_queue(self):
"""Test that remove does not fail when the object to be removed is not in the queue."""

with qml._queuing.AnnotatedQueue() as q1:
op1 = qml.PauliZ(0)
op2 = qml.PauliZ(1)
q1.append(op1)
q1.append(op2)

with qml._queuing.AnnotatedQueue() as q2:
q2.append(op1)
q2.remove(op2)

def test_append_qubit_gates(self):
"""Test that gates are successfully appended to the queue."""
with qml._queuing.AnnotatedQueue() as q:
ops = [
qml.RX(0.5, wires=0),
qml.RY(-10.1, wires=1),
qml.CNOT(wires=[0, 1]),
qml.PhaseShift(-1.1, wires=18),
qml.T(wires=99),
]
assert q.queue == ops

def test_append_qubit_observables(self):
"""Test that ops that are also observables are successfully
appended to the queue."""
with qml._queuing.AnnotatedQueue() as q:
# wire repetition is deliberate, Queue contains no checks/logic
# for circuits
ops = [
qml.Hadamard(wires=0),
qml.PauliX(wires=1),
qml.PauliY(wires=1),
qml.Hermitian(np.ones([2, 2]), wires=7),
]
assert q.queue == ops

def test_append_tensor_ops(self):
"""Test that ops which are used as inputs to `Tensor`
are successfully added to the queue, but no `Tensor` object is."""

with qml._queuing.AnnotatedQueue() as q:
A = qml.PauliZ(0)
B = qml.PauliY(1)
tensor_op = qml.operation.Tensor(A, B)
assert q.queue == [A, B]
assert tensor_op.obs == [A, B]
assert all(not isinstance(op, qml.operation.Tensor) for op in q.queue)

def test_append_tensor_ops_overloaded(self):
"""Test that Tensor ops created using `@`
are successfully added to the queue, but no `Tensor` object is."""

with qml._queuing.AnnotatedQueue() as q:
A = qml.PauliZ(0)
B = qml.PauliY(1)
tensor_op = A @ B
assert q.queue == [A, B]
assert tensor_op.obs == [A, B]
assert all(not isinstance(op, qml.operation.Tensor) for op in q.queue)

def test_get_info(self):
"""Test that get_info correctly returns an annotation"""
A = qml.RZ(0.5, wires=1)

with qml._queuing.AnnotatedQueue() as q:
q.append(A, inv=True)

assert q.get_info(A) == {"inv": True}

def test_get_info_error(self):
"""Test that an exception is raised if get_info is called
for a non-existent object"""

with qml._queuing.AnnotatedQueue() as q:
A = qml.PauliZ(0)

B = qml.PauliY(1)

with pytest.raises(ValueError, match="not in the queue"):
q.get_info(B)

def test_update_info(self):
"""Test that update_info correctly updates an annotation"""
A = qml.RZ(0.5, wires=1)

with qml._queuing.AnnotatedQueue() as q:
q.append(A, inv=True)
assert qml.QueuingContext.get_info(A) == {"inv": True}

assert q.get_info(A) == {"inv": True}

q.update_info(A, inv=False, owner=None)
assert q.get_info(A) == {"inv": False, "owner": None}

def test_update_error(self):
"""Test that an exception is raised if get_info is called
for a non-existent object"""

with qml._queuing.AnnotatedQueue() as q:
A = qml.PauliZ(0)

B = qml.PauliY(1)

with pytest.raises(ValueError, match="not in the queue"):
q.update_info(B, inv=True)

def test_append_annotating_object(self):
"""Test appending an object that writes annotations when queuing itself"""

class AnnotatingTensor(qml.operation.Tensor):
"""Dummy tensor class that queues itself on initialization
to an annotating queue."""

def __init__(self, *args):
super().__init__(*args)
self.queue()

def queue(self):
qml.QueuingContext.append(self, owns=tuple(self.obs))

for o in self.obs:
try:
qml.QueuingContext.update_info(o, owner=self)
except AttributeError:
pass

return self

with qml._queuing.AnnotatedQueue() as q:
A = qml.PauliZ(0)
B = qml.PauliY(1)
tensor_op = AnnotatingTensor(A, B)

assert q.queue == [A, B, tensor_op]
assert q.get_info(A) == {"owner": tensor_op}
assert q.get_info(B) == {"owner": tensor_op}
assert q.get_info(tensor_op) == {"owns": (A, B)}

0 comments on commit 474982e

Please sign in to comment.