diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 06e6332ee5e..19164addbbc 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -525,6 +525,10 @@

Bug fixes

+* Fixes a bug where an incorrect number of executions are recorded by + a QNode using a custom cache with `diff_method="backprop"`. + [(#2171)](https://github.com/PennyLaneAI/pennylane/pull/2171) + * Fixes a bug where the `default.qubit.jax` device can't be used with `diff_method=None` and jitting. [(#2136)](https://github.com/PennyLaneAI/pennylane/pull/2136) diff --git a/pennylane/qnode.py b/pennylane/qnode.py index bdfd1983c98..0a0a7da2720 100644 --- a/pennylane/qnode.py +++ b/pennylane/qnode.py @@ -217,6 +217,7 @@ def __init__( self._original_device = device self.gradient_fn = None self.gradient_kwargs = None + self._tape_cached = False self._update_gradient_fn() functools.update_wrapper(self, func) @@ -265,7 +266,9 @@ def _update_original_device(self): # of the user's device before and after executing the tape. if self.device is not self._original_device: - self._original_device._num_executions += 1 # pylint: disable=protected-access + + if not self._tape_cached: + self._original_device._num_executions += 1 # pylint: disable=protected-access # Update for state vector simulators that have the _pre_rotated_state attribute if hasattr(self._original_device, "_pre_rotated_state"): @@ -546,6 +549,14 @@ def __call__(self, *args, **kwargs): # construct the tape self.construct(args, kwargs) + cache = self.execute_kwargs.get("cache", False) + using_custom_cache = ( + hasattr(cache, "__getitem__") + and hasattr(cache, "__setitem__") + and hasattr(cache, "__delitem__") + ) + self._tape_cached = using_custom_cache and self.tape.hash in cache + res = qml.execute( [self.tape], device=self.device, diff --git a/tests/test_qnode.py b/tests/test_qnode.py index f527848754b..36340e0b68e 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -711,6 +711,56 @@ def func(): assert dev.num_executions == 6 + def test_num_exec_caching_device_swap(self): + """Tests that if we swapped the original device (e.g., when + diff_method='backprop') then the number of executions recorded is + correct.""" + dev = qml.device("default.qubit", wires=2) + + cache = {} + + @qml.qnode(dev, diff_method="backprop", cache=cache) + def circuit(): + qml.RY(0.345, wires=0) + return qml.expval(qml.PauliZ(0)) + + for _ in range(15): + circuit() + + # Although we've evaluated the QNode more than once, due to caching, + # there was one device execution recorded + assert dev.num_executions == 1 + assert cache != {} + + def test_num_exec_caching_device_swap_two_exec(self): + """Tests that if we swapped the original device (e.g., when + diff_method='backprop') then the number of executions recorded is + correct even with multiple QNode evaluations.""" + dev = qml.device("default.qubit", wires=2) + + cache = {} + + @qml.qnode(dev, diff_method="backprop", cache=cache) + def circuit(): + qml.RY(0.345, wires=0) + return qml.expval(qml.PauliZ(0)) + + for _ in range(15): + circuit() + + @qml.qnode(dev, diff_method="backprop", cache=cache) + def circuit(): + qml.RZ(0.345, wires=0) + return qml.expval(qml.PauliZ(0)) + + for _ in range(15): + circuit() + + # Although we've evaluated the QNode several times, due to caching, + # there were two device executions recorded + assert dev.num_executions == 2 + assert cache != {} + @pytest.mark.parametrize("diff_method", ["parameter-shift", "finite-diff"]) def test_single_expectation_value_with_argnum_one(self, diff_method, tol): """Tests correct output shape and evaluation for a QNode