Skip to content

Commit

Permalink
Only update object's queue metadata if already in the queue (#2612)
Browse files Browse the repository at this point in the history
* add safe_update_info

* testing

* Update pennylane/queuing.py

* changelog and updated test

* linting

* add tensor queuing tests

Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca>
  • Loading branch information
albi3ro and Jaybsoni committed May 30, 2022
1 parent 5e411af commit e40c47e
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 48 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@
* Sparse Hamiltonians representation has changed from COOrdinate (COO) to Compressed Sparse Row (CSR) format. The CSR representation is more performant for arithmetic operations and matrix vector products. This change decreases the `expval()` calculation time, for `qml.SparseHamiltonian`, specially for large workflows. Also, the CRS format consumes less memory for the `qml.SparseHamiltonian` storage.
[(#2561)](https://github.com/PennyLaneAI/pennylane/pull/2561)

* A new method `safe_update_info` is added to `qml.QueuingContext`. This method is substituted
for `qml.QueuingContext.update_info` in a variety of places.
[(#2612)](https://github.com/PennyLaneAI/pennylane/pull/2612)

* `BasisEmbedding` can accept an int as argument instead of a list of bits (optionally). Example: `qml.BasisEmbedding(4, wires = range(4))` is now equivalent to `qml.BasisEmbedding([0,1,0,0], wires = range(4))` (because 4=0b100).
[(#2601)](https://github.com/PennyLaneAI/pennylane/pull/2601)

Expand Down
7 changes: 1 addition & 6 deletions pennylane/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,7 @@ def expand(self):
def queue(self, context=qml.QueuingContext):
"""Append the measurement process to an annotated queue."""
if self.obs is not None:
try:
context.update_info(self.obs, owner=self)
except qml.queuing.QueuingError:
self.obs.queue(context=context)
context.update_info(self.obs, owner=self)

context.safe_update_info(self.obs, owner=self)
context.append(self, owns=self.obs)
else:
context.append(self)
Expand Down
14 changes: 5 additions & 9 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,11 +1741,7 @@ def queue(self, context=qml.QueuingContext, init=False): # pylint: disable=argu
else:
raise ValueError("Can only perform tensor products between observables.")

try:
context.update_info(o, owner=self)
except qml.queuing.QueuingError:
o.queue(context=context)
context.update_info(o, owner=self)
context.safe_update_info(o, owner=self)

context.append(self, owns=tuple(constituents))
return self
Expand Down Expand Up @@ -1849,16 +1845,16 @@ def __matmul__(self, other):
owning_info = qml.QueuingContext.get_info(self)["owns"] + (other,)

# update the annotated queue information
qml.QueuingContext.update_info(self, owns=owning_info)
qml.QueuingContext.update_info(other, owner=self)
qml.QueuingContext.safe_update_info(self, owns=owning_info)
qml.QueuingContext.safe_update_info(other, owner=self)

return self

def __rmatmul__(self, other):
if isinstance(other, Observable):
self.obs[:0] = [other]
if qml.QueuingContext.recording():
qml.QueuingContext.update_info(other, owner=self)
qml.QueuingContext.safe_update_info(self, owns=tuple(self.obs))
qml.QueuingContext.safe_update_info(other, owner=self)
return self

raise ValueError("Can only perform tensor products between observables.")
Expand Down
8 changes: 1 addition & 7 deletions pennylane/ops/qubit/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from pennylane import numpy as np

from pennylane.operation import Observable, Tensor
from pennylane.queuing import QueuingError
from pennylane.wires import Wires

OBS_MAP = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "Hadamard": "H", "Identity": "I"}
Expand Down Expand Up @@ -633,11 +632,6 @@ def __isub__(self, H):
def queue(self, context=qml.QueuingContext):
"""Queues a qml.Hamiltonian instance"""
for o in self.ops:
try:
context.update_info(o, owner=self)
except QueuingError:
o.queue(context=context)
context.update_info(o, owner=self)

context.safe_update_info(o, owner=self)
context.append(self, owns=tuple(self.ops))
return self
25 changes: 24 additions & 1 deletion pennylane/queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,27 @@ def update_info(cls, obj, **kwargs):
if cls.recording():
cls.active_context()._update_info(obj, **kwargs) # pylint: disable=protected-access

# pylint: disable=protected-access
@classmethod
def safe_update_info(cls, obj, **kwargs):
"""Updates information of an object in the active queue if it is already in the queue.
Args:
obj: the object with metadata to be updated
"""
if cls.recording():
cls.active_context()._safe_update_info(obj, **kwargs)

@abc.abstractmethod
def _safe_update_info(self, obj, **kwargs):
"""Updates information of an object in the queue instance only if the object is in the queue.
If the object is not in the queue, nothing is done and no errors are raised.
"""

@abc.abstractmethod
def _update_info(self, obj, **kwargs):
"""Updates information of an object in the queue instance."""
"""Updates information of an object in the queue instance. Raises a ``QueuingError`` if the object
is not in the queue."""

@classmethod
def get_info(cls, obj):
Expand Down Expand Up @@ -222,6 +240,10 @@ def _append(self, obj, **kwargs):
def _remove(self, obj):
del self._queue[obj]

def _safe_update_info(self, obj, **kwargs):
if obj in self._queue:
self._queue[obj].update(kwargs)

def _update_info(self, obj, **kwargs):
if obj not in self._queue:
raise QueuingError(f"Object {obj} not in the queue.")
Expand All @@ -240,6 +262,7 @@ def _get_info(self, obj):
append = _append
remove = _remove
update_info = _update_info
safe_update_info = _safe_update_info
get_info = _get_info

@property
Expand Down
22 changes: 10 additions & 12 deletions tests/ops/qubit/test_hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,19 +788,12 @@ def test_arithmetic_errors(self):
with pytest.raises(ValueError, match="Cannot subtract"):
H -= A

def test_hamiltonian_queue(self):
"""Tests that Hamiltonian are queued correctly"""

# Outside of tape
def test_hamiltonian_queue_outside(self):
"""Tests that Hamiltonian are queued correctly when components are defined outside the recording context."""

queue = [
qml.Hadamard(wires=1),
qml.PauliX(wires=0),
qml.PauliZ(0),
qml.PauliZ(2),
qml.PauliZ(0) @ qml.PauliZ(2),
qml.PauliX(1),
qml.PauliZ(1),
qml.Hamiltonian(
[1, 3, 1], [qml.PauliX(1), qml.PauliZ(0) @ qml.PauliZ(2), qml.PauliZ(1)]
),
Expand All @@ -813,9 +806,14 @@ def test_hamiltonian_queue(self):
qml.PauliX(wires=0)
qml.expval(H)

assert np.all([q1.compare(q2) for q1, q2 in zip(tape.queue, queue)])
assert len(tape.queue) == 3
assert isinstance(tape.queue[0], qml.Hadamard)
assert isinstance(tape.queue[1], qml.PauliX)
assert isinstance(tape.queue[2], qml.measurements.MeasurementProcess)
assert H.compare(tape.queue[2].obs)

# Inside of tape
def test_hamiltonian_queue_inside(self):
"""Tests that Hamiltonian are queued correctly when components are instantiated inside the recording context."""

queue = [
qml.Hadamard(wires=1),
Expand Down Expand Up @@ -1278,7 +1276,7 @@ def test_grouping_does_not_alter_queue(self):
with qml.tape.QuantumTape() as tape:
H = qml.Hamiltonian(coeffs, obs, grouping_type="qwc")

assert tape.queue == [a, b, c, H]
assert tape.queue == [H]

def test_grouping_method_can_be_set(self):
r"""Tests that the grouping method can be controlled by kwargs.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,21 +656,21 @@ def test_annotating_tensor_return_type(self, op1, op2, stat_func, return_type):
)
def test_queueing_tensor_observable(self, op1, op2, stat_func, return_type):
"""Test that if the constituent components of a tensor operation are not
found in the queue for annotation, that they are queued first and then annotated."""
found in the queue for annotation, they are not queued or annotated."""
A = op1(0)
B = op2(1)

with AnnotatedQueue() as q:
tensor_op = A @ B
stat_func(tensor_op)

assert q.queue[:-1] == [A, B, tensor_op]
assert len(q._queue) == 2

assert q.queue[0] is tensor_op
meas_proc = q.queue[-1]
assert isinstance(meas_proc, MeasurementProcess)
assert meas_proc.return_type == return_type

assert q._get_info(A) == {"owner": tensor_op}
assert q._get_info(B) == {"owner": tensor_op}
assert q._get_info(tensor_op) == {"owns": (A, B), "owner": meas_proc}


Expand Down
40 changes: 33 additions & 7 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,13 +861,9 @@ def test_queuing_defined_outside(self):
with qml.tape.QuantumTape() as tape:
T.queue()

assert len(tape.queue) == 3
assert tape.queue[0] is op1
assert tape.queue[1] is op2
assert tape.queue[2] is T
assert len(tape.queue) == 1
assert tape.queue[0] is T

assert tape._queue[op1] == {"owner": T}
assert tape._queue[op2] == {"owner": T}
assert tape._queue[T] == {"owns": (op1, op2)}

def test_queuing(self):
Expand All @@ -887,7 +883,7 @@ def test_queuing(self):
assert tape._queue[op2] == {"owner": T}
assert tape._queue[T] == {"owns": (op1, op2)}

def test_queuing_matmul(self):
def test_queuing_observable_matmul(self):
"""Test queuing when tensor constructed with matmul."""

with qml.tape.QuantumTape() as tape:
Expand All @@ -900,6 +896,36 @@ def test_queuing_matmul(self):
assert tape._queue[op2] == {"owner": t}
assert tape._queue[t] == {"owns": (op1, op2)}

def test_queuing_tensor_matmul(self):
"""Tests the tensor-specific matmul method updates queuing metadata."""

with qml.tape.QuantumTape() as tape:
op1 = qml.PauliX(0)
op2 = qml.PauliY(1)
t = Tensor(op1, op2)

op3 = qml.PauliZ(2)
t2 = t @ op3

assert tape._queue[t2] == {"owns": (op1, op2, op3)}
assert tape._queue[op3] == {"owner": t2}

def test_queuing_tensor_rmatmul(self):
"""Tests tensor-specific rmatmul updates queuing metatadata."""

with qml.tape.QuantumTape() as tape:
op1 = qml.PauliX(0)
op2 = qml.PauliY(1)

t1 = op1 @ op2

op3 = qml.PauliZ(3)

t2 = op3 @ t1

assert tape._queue[op3] == {"owner": t2}
assert tape._queue[t2] == {"owns": (op3, op1, op2)}

def test_name(self):
"""Test that the names of the observables are
returned as expected"""
Expand Down
44 changes: 42 additions & 2 deletions tests/test_queuing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,15 @@ def test_update_info(self):
q.append(A, inv=True)
assert QueuingContext.get_info(A) == {"inv": True}

assert q._get_info(A) == {"inv": True}
qml.QueuingContext.update_info(A, key="value1")

# should pass silently because no longer recording
qml.QueuingContext.update_info(A, key="value2")

assert q._get_info(A) == {"inv": True, "key": "value1"}

q._update_info(A, inv=False, owner=None)
assert q._get_info(A) == {"inv": False, "owner": None}
assert q._get_info(A) == {"inv": False, "owner": None, "key": "value1"}

def test_update_error(self):
"""Test that an exception is raised if get_info is called
Expand All @@ -234,6 +239,41 @@ def test_update_error(self):
with pytest.raises(QueuingError, match="not in the queue"):
q._update_info(B, inv=True)

def test_safe_update_info_queued(self):
"""Test the `safe_update_info` method if the object is already queued."""
op = qml.RX(0.5, wires=1)

with AnnotatedQueue() as q:
q.append(op, key="value1")
assert q.get_info(op) == {"key": "value1"}
qml.QueuingContext.safe_update_info(op, key="value2")

qml.QueuingContext.safe_update_info(op, key="no changes here")
assert q.get_info(op) == {"key": "value2"}

q.safe_update_info(op, key="value3")
assert q.get_info(op) == {"key": "value3"}

q._safe_update_info(op, key="value4")
assert q.get_info(op) == {"key": "value4"}

def test_safe_update_info_not_queued(self):
"""Tests the safe_update_info method passes silently if the object is
not already queued."""
op = qml.RX(0.5, wires=1)

with AnnotatedQueue() as q:
qml.QueuingContext.safe_update_info(op, key="value2")
qml.QueuingContext.safe_update_info(op, key="no changes here")

assert len(q.queue) == 0

q.safe_update_info(op, key="value3")
assert len(q.queue) == 0

q._safe_update_info(op, key="value4")
assert len(q.queue) == 0

def test_append_annotating_object(self):
"""Test appending an object that writes annotations when queuing itself"""

Expand Down

0 comments on commit e40c47e

Please sign in to comment.