diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 06e6332ee5e..19164addbbc 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -525,6 +525,10 @@
Bug fixes
+* Fixes a bug where an incorrect number of executions are recorded by
+ a QNode using a custom cache with `diff_method="backprop"`.
+ [(#2171)](https://github.com/PennyLaneAI/pennylane/pull/2171)
+
* Fixes a bug where the `default.qubit.jax` device can't be used with `diff_method=None` and jitting.
[(#2136)](https://github.com/PennyLaneAI/pennylane/pull/2136)
diff --git a/pennylane/qnode.py b/pennylane/qnode.py
index bdfd1983c98..0a0a7da2720 100644
--- a/pennylane/qnode.py
+++ b/pennylane/qnode.py
@@ -217,6 +217,7 @@ def __init__(
self._original_device = device
self.gradient_fn = None
self.gradient_kwargs = None
+ self._tape_cached = False
self._update_gradient_fn()
functools.update_wrapper(self, func)
@@ -265,7 +266,9 @@ def _update_original_device(self):
# of the user's device before and after executing the tape.
if self.device is not self._original_device:
- self._original_device._num_executions += 1 # pylint: disable=protected-access
+
+ if not self._tape_cached:
+ self._original_device._num_executions += 1 # pylint: disable=protected-access
# Update for state vector simulators that have the _pre_rotated_state attribute
if hasattr(self._original_device, "_pre_rotated_state"):
@@ -546,6 +549,14 @@ def __call__(self, *args, **kwargs):
# construct the tape
self.construct(args, kwargs)
+ cache = self.execute_kwargs.get("cache", False)
+ using_custom_cache = (
+ hasattr(cache, "__getitem__")
+ and hasattr(cache, "__setitem__")
+ and hasattr(cache, "__delitem__")
+ )
+ self._tape_cached = using_custom_cache and self.tape.hash in cache
+
res = qml.execute(
[self.tape],
device=self.device,
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index f527848754b..36340e0b68e 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -711,6 +711,56 @@ def func():
assert dev.num_executions == 6
+ def test_num_exec_caching_device_swap(self):
+ """Tests that if we swapped the original device (e.g., when
+ diff_method='backprop') then the number of executions recorded is
+ correct."""
+ dev = qml.device("default.qubit", wires=2)
+
+ cache = {}
+
+ @qml.qnode(dev, diff_method="backprop", cache=cache)
+ def circuit():
+ qml.RY(0.345, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ for _ in range(15):
+ circuit()
+
+ # Although we've evaluated the QNode more than once, due to caching,
+ # there was one device execution recorded
+ assert dev.num_executions == 1
+ assert cache != {}
+
+ def test_num_exec_caching_device_swap_two_exec(self):
+ """Tests that if we swapped the original device (e.g., when
+ diff_method='backprop') then the number of executions recorded is
+ correct even with multiple QNode evaluations."""
+ dev = qml.device("default.qubit", wires=2)
+
+ cache = {}
+
+ @qml.qnode(dev, diff_method="backprop", cache=cache)
+ def circuit():
+ qml.RY(0.345, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ for _ in range(15):
+ circuit()
+
+ @qml.qnode(dev, diff_method="backprop", cache=cache)
+ def circuit():
+ qml.RZ(0.345, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ for _ in range(15):
+ circuit()
+
+ # Although we've evaluated the QNode several times, due to caching,
+ # there were two device executions recorded
+ assert dev.num_executions == 2
+ assert cache != {}
+
@pytest.mark.parametrize("diff_method", ["parameter-shift", "finite-diff"])
def test_single_expectation_value_with_argnum_one(self, diff_method, tol):
"""Tests correct output shape and evaluation for a QNode