Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error when creating qnode with backprop and finite-shots #1588

Merged
merged 22 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ and requirements-ci.txt (unpinned). This latter would be used by the CI.

<h3>Breaking changes</h3>

* An error is raised during QNode creation when a user requests backpropagation on
a device with finite-shots.
[(#1588)](https://github.com/PennyLaneAI/pennylane/pull/1588)

* The class `qml.Interferometer` is deprecated and will be renamed `qml.InterferometerUnitary`
after one release cycle.
[(#1546)](https://github.com/PennyLaneAI/pennylane/pull/1546)
Expand Down
6 changes: 6 additions & 0 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ def _validate_backprop_method(device, interface):
qml.QuantumFunctionError: if the device does not support backpropagation, or the
interface provided is not compatible with the device
"""
if device.shots is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great 👍

raise qml.QuantumFunctionError(
"Devices with finite shots are incompatible with backpropogation. "
"Please set shots=None or chose a different diff_method."
)

# determine if the device supports backpropagation
backprop_interface = device.capabilities().get("passthru_interface", None)

Expand Down
11 changes: 6 additions & 5 deletions tests/devices/test_default_qubit_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,15 @@ def circuit():
expected = jnp.array([amplitude, 0, jnp.conj(amplitude), 0])
assert jnp.allclose(state, expected, atol=tol, rtol=0)

@pytest.mark.skip(reason="sampling doesnt work in backprop, fails with parameter-shift")
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
def test_sampling_with_jit(self):
"""Test that sampling works with a jax.jit"""

@jax.jit
def circuit(key):
dev = qml.device("default.qubit.jax", wires=1, shots=1000, prng_key=key)

@qml.qnode(dev, interface="jax", diff_method="backprop")
@qml.qnode(dev, interface="jax", diff_method=None)
def inner_circuit():
qml.Hadamard(0)
return qml.sample(qml.PauliZ(wires=0))
Expand All @@ -176,7 +177,7 @@ def test_sampling_op_by_op(self):
"""Test that op-by-op sampling works as a new user would expect"""
dev = qml.device("default.qubit.jax", wires=1, shots=1000)

@qml.qnode(dev, interface="jax", diff_method="backprop")
@qml.qnode(dev, interface="jax", diff_method=None)
def circuit():
qml.Hadamard(0)
return qml.sample(qml.PauliZ(wires=0))
Expand All @@ -191,7 +192,7 @@ def test_sampling_analytic_mode(self):
"""
dev = qml.device("default.qubit.jax", wires=1, shots=None)

@qml.qnode(dev, interface="jax", diff_method="backprop")
@qml.qnode(dev, interface="jax", diff_method=None)
def circuit():
return qml.sample(qml.PauliZ(wires=0))

Expand All @@ -206,7 +207,7 @@ def test_gates_dont_crash(self):
"""Test for gates that weren't covered by other tests."""
dev = qml.device("default.qubit.jax", wires=2, shots=1000)

@qml.qnode(dev, interface="jax", diff_method="backprop")
@qml.qnode(dev, interface="jax", diff_method=None)
def circuit():
qml.CRZ(0.0, wires=[0, 1])
qml.CRX(0.0, wires=[0, 1])
Expand All @@ -222,7 +223,7 @@ def test_diagonal_doesnt_crash(self):
"""Test that diagonal gates can be used."""
dev = qml.device("default.qubit.jax", wires=1, shots=1000)

@qml.qnode(dev, interface="jax", diff_method="backprop")
@qml.qnode(dev, interface="jax", diff_method=None)
def circuit():
qml.DiagonalQubitUnitary(np.array([1.0, 1.0]), wires=0)
return qml.sample(qml.PauliZ(wires=0))
Expand Down
47 changes: 4 additions & 43 deletions tests/devices/test_default_qubit_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def test_sample_observables(self):
shots = 100
dev = qml.device("default.qubit.tf", wires=2, shots=shots)

@qml.qnode(dev, diff_method="backprop", interface="tf")
@qml.qnode(dev, diff_method="best", interface="tf")
def circuit(a):
qml.RX(a, wires=0)
return qml.sample(qml.PauliZ(0))
Expand All @@ -1324,28 +1324,11 @@ def circuit(a):
assert res.shape == (shots,)
assert set(res.numpy()) == {-1, 1}

def test_sample_observables_non_differentiable(self):
"""Test that sampled observables cannot be differentiated."""
shots = 100
dev = qml.device("default.qubit.tf", wires=2, shots=shots)

@qml.qnode(dev, diff_method="backprop", interface="tf")
def circuit(a):
qml.RX(a, wires=0)
return qml.sample(qml.PauliZ(0))

a = tf.Variable(0.54)

with tf.GradientTape() as tape:
res = circuit(a)

assert tape.gradient(res, a) is None

def test_estimating_marginal_probability(self, tol):
"""Test that the probability of a subset of wires is accurately estimated."""
dev = qml.device("default.qubit.tf", wires=2, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="tf")
@qml.qnode(dev, diff_method=None, interface="tf")
def circuit():
qml.PauliX(0)
return qml.probs(wires=[0])
Expand All @@ -1361,7 +1344,7 @@ def test_estimating_full_probability(self, tol):
"""Test that the probability of a subset of wires is accurately estimated."""
dev = qml.device("default.qubit.tf", wires=2, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="tf")
@qml.qnode(dev, diff_method=None, interface="tf")
def circuit():
qml.PauliX(0)
qml.PauliX(1)
Expand All @@ -1379,7 +1362,7 @@ def test_estimating_expectation_values(self, tol):
of shots produces a numeric tensor"""
dev = qml.device("default.qubit.tf", wires=3, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="tf")
@qml.qnode(dev, diff_method=None, interface="tf")
def circuit(a, b):
qml.RX(a, wires=[0])
qml.RX(b, wires=[1])
Expand All @@ -1397,28 +1380,6 @@ def circuit(a, b):
# expected = [tf.cos(a), tf.cos(a) * tf.cos(b)]
# assert np.allclose(res, expected, atol=tol, rtol=0)

def test_estimating_expectation_values_not_differentiable(self, tol):
"""Test that finite shots results in non-differentiable QNodes"""

dev = qml.device("default.qubit.tf", wires=3, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="tf")
def circuit(a, b):
qml.RX(a, wires=[0])
qml.RX(b, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

a = tf.Variable(0.543)
b = tf.Variable(0.43)

with tf.GradientTape() as tape:
res = circuit(a, b)

assert isinstance(res, tf.Tensor)
grad = tape.gradient(res, [a, b])
assert grad == [None, None]


class TestHighLevelIntegration:
"""Tests for integration with higher level components of PennyLane."""
Expand Down
28 changes: 4 additions & 24 deletions tests/devices/test_default_qubit_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ def test_sample_observables(self):
shots = 100
dev = qml.device("default.qubit.torch", wires=2, shots=shots)

@qml.qnode(dev, diff_method="backprop", interface="torch")
@qml.qnode(dev, diff_method=None, interface="torch")
def circuit(a):
qml.RX(a, wires=0)
return qml.sample(qml.PauliZ(0))
Expand All @@ -1451,7 +1451,7 @@ def test_estimating_marginal_probability(self, tol):
"""Test that the probability of a subset of wires is accurately estimated."""
dev = qml.device("default.qubit.torch", wires=2, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="torch")
@qml.qnode(dev, diff_method=None, interface="torch")
def circuit():
qml.PauliX(0)
return qml.probs(wires=[0])
Expand All @@ -1467,7 +1467,7 @@ def test_estimating_full_probability(self, tol):
"""Test that the probability of a subset of wires is accurately estimated."""
dev = qml.device("default.qubit.torch", wires=2, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="torch")
@qml.qnode(dev, diff_method=None, interface="torch")
def circuit():
qml.PauliX(0)
qml.PauliX(1)
Expand All @@ -1485,7 +1485,7 @@ def test_estimating_expectation_values(self, tol):
of shots produces a numeric tensor"""
dev = qml.device("default.qubit.torch", wires=3, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="torch")
@qml.qnode(dev, diff_method=None, interface="torch")
def circuit(a, b):
qml.RX(a, wires=[0])
qml.RX(b, wires=[1])
Expand All @@ -1503,26 +1503,6 @@ def circuit(a, b):
# expected = [torch.cos(a), torch.cos(a) * torch.cos(b)]
# assert np.allclose(res, expected, atol=tol, rtol=0)

def test_estimating_expectation_values_not_differentiable(self, tol):
"""Test that finite shots results in non-differentiable QNodes"""

dev = qml.device("default.qubit.torch", wires=3, shots=1000)

@qml.qnode(dev, diff_method="backprop", interface="torch")
def circuit(a, b):
qml.RX(a, wires=[0])
qml.RX(b, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))

a = torch.tensor(0.543)
b = torch.tensor(0.43)

res = circuit(a, b)

with pytest.raises(RuntimeError):
res.backward()


class TestHighLevelIntegration:
"""Tests for integration with higher level components of PennyLane."""
Expand Down
9 changes: 9 additions & 0 deletions tests/tape/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def test_validate_device_method(self, monkeypatch):
assert interface == "interface"
assert device is dev

@pytest.mark.parametrize("interface", ("autograd", "torch", "tensorflow", "jax"))
def test_validate_backprop_method_finite_shots(self, interface):
"""Tests that an error is raised for backpropagation with finite shots."""

dev = qml.device("default.qubit", wires=1, shots=3)

with pytest.raises(qml.QuantumFunctionError, match="Devices with finite shots"):
Copy link
Contributor

@rmoyard rmoyard Sep 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test 💯

QNode._validate_backprop_method(dev, interface)

def test_validate_backprop_method_invalid_device(self):
"""Test that the method for validating the backprop diff method
tape raises an exception if the device does not support backprop."""
Expand Down