Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

default.qubit returns the new DefaultQubit device #4436

Merged
merged 112 commits into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
112 commits
Select commit Hold shift + click to select a range
0cd4c04
change default.qubit entrypoint to DQ2; copy tests
timmysilv Jul 31, 2023
de4e5fa
remove tests for old device; use explicit constructor for DQ1
timmysilv Jul 31, 2023
42ea50f
a few test fixes
timmysilv Aug 3, 2023
840f30a
add is_state_batched to mp.process_state; probs supports jitting
timmysilv Aug 3, 2023
4d295fb
copy from DQ (wires init and Pow support); fix tests/ops
timmysilv Aug 3, 2023
efe288a
fix tests/test_*.py
timmysilv Aug 4, 2023
4f5c7b6
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 4, 2023
5f7f875
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 8, 2023
e756ed0
only squeeze the 1-wire dim when sampling a pauli obs; tape test fixes
timmysilv Aug 9, 2023
cd0c260
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 9, 2023
d725ccc
add old default.qubit as default.qubit.legacy
timmysilv Aug 9, 2023
e1f533c
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 10, 2023
6abc713
fix the transforms module
timmysilv Aug 10, 2023
55b1ac4
use the legacy short name instead of the constructor
timmysilv Aug 10, 2023
25bb1f0
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 11, 2023
4e71b68
don't import DQ2 at top level
timmysilv Aug 11, 2023
90400bd
fix qinfo module tests
timmysilv Aug 13, 2023
d88e053
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 13, 2023
e7d54d3
Revert "add is_state_batched to mp.process_state; probs supports jitt…
timmysilv Aug 14, 2023
b3a07f2
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 14, 2023
6c7f69e
put back adjoint metric tensor change with new fix
timmysilv Aug 14, 2023
f7ae4d8
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 14, 2023
b508abb
use legacy device for legacy interface tests
timmysilv Aug 14, 2023
f3af961
xfail test for amplitudeembedding found mid-circuit
timmysilv Aug 15, 2023
3ae33b5
test amplitudeembedding using new device
timmysilv Aug 15, 2023
cfc016e
use dev.execute in gradients; adjoint_diff uses legacy
timmysilv Aug 15, 2023
6576771
wip: start testing old and new interface devices
timmysilv Aug 16, 2023
043d7d0
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 18, 2023
c59ee29
finish touching up gradients
timmysilv Aug 18, 2023
8c15eea
rename measurement test devices; copy for new device api
timmysilv Aug 18, 2023
025c8e4
fix measurement tests expect classical shadow
timmysilv Aug 21, 2023
10ac41d
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 22, 2023
6e5cafd
put back legacy qubit device test; copy batch test to DQ2
timmysilv Aug 22, 2023
a60911c
fix classical shadow tests; change ci to use legacy
timmysilv Aug 22, 2023
29ba4f4
fix some tests; drawer support for dq2
timmysilv Aug 23, 2023
15c68d7
device setup changes
timmysilv Aug 24, 2023
f5b1c58
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 24, 2023
16e4381
more test changes
timmysilv Aug 25, 2023
440b697
map wires for pulse stuff; more test fixes
timmysilv Aug 25, 2023
bdec875
temp for ci: use legacy in templates
timmysilv Aug 25, 2023
4e212e1
Revert "temp for ci: use legacy in templates"
timmysilv Aug 25, 2023
a79fb06
fix templates
timmysilv Aug 25, 2023
54f38dd
more test fixes
timmysilv Aug 25, 2023
2a800d0
fix test with nonsense observable
timmysilv Aug 25, 2023
6d3da6e
slight additional test coverage
timmysilv Aug 25, 2023
519a96b
use old device for now on torch qcut test
timmysilv Aug 26, 2023
ab07983
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 28, 2023
4a58615
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 29, 2023
6f9e33b
change some tests to use DQ2
timmysilv Aug 29, 2023
bbbb999
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 29, 2023
7e79aec
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 31, 2023
c5cd3f5
test device_batch_transform=False on legacy
timmysilv Aug 31, 2023
736d9df
move defer_measurements qnode tests to legacy device for now
timmysilv Aug 31, 2023
52eefdd
fix numeric_type stuff
timmysilv Aug 31, 2023
a09ceeb
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Aug 31, 2023
b3d858c
fix tape tests
timmysilv Sep 1, 2023
ad04b2a
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 1, 2023
72b866d
identity to make test work on new device
timmysilv Sep 1, 2023
f4d9295
fix sample test for numeric_type
timmysilv Sep 1, 2023
98ba5cc
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 5, 2023
18e0971
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 7, 2023
5a76b1e
remove old dq2 measurements tests
timmysilv Sep 7, 2023
ca03c82
put back other things already on master
timmysilv Sep 7, 2023
a5f2f59
fix broken tests
timmysilv Sep 7, 2023
37aa6b2
fix fisher test
timmysilv Sep 7, 2023
3ddbc0e
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 7, 2023
8b46f45
actually use new device
timmysilv Sep 7, 2023
ae4f13b
remove unneeded skips from tests
timmysilv Sep 8, 2023
6cb266c
try to run qnn tests with new device
timmysilv Sep 11, 2023
b869962
add xfail for now
timmysilv Sep 11, 2023
b910763
check out qinfo transforms file from master
timmysilv Sep 11, 2023
8723238
add back DQ2 support
timmysilv Sep 11, 2023
a4dd56a
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 11, 2023
4a92d44
clean up tests to what they used to be with latest changes
timmysilv Sep 11, 2023
c6eb716
fix more tests with latest changes to master
timmysilv Sep 11, 2023
5feaa9f
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 12, 2023
b261500
fix other template tests
timmysilv Sep 13, 2023
3213091
copy return types test for DQ2
timmysilv Sep 13, 2023
7d74ca4
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 13, 2023
c45997e
fix preprocess changes
timmysilv Sep 13, 2023
d2e0925
patch in device wires for jax jit test
timmysilv Sep 13, 2023
c49fa53
fix drawer usage of DQ2 preprocess
timmysilv Sep 13, 2023
5c9faeb
check out amplitude from master
timmysilv Sep 13, 2023
b1784bb
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 13, 2023
4d8fe20
see if torch qcut issue magically fixed in CI
timmysilv Sep 13, 2023
d99dc18
Revert "see if torch qcut issue magically fixed in CI"
timmysilv Sep 13, 2023
68ed406
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 14, 2023
dea780f
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 14, 2023
e2a6eb7
little improvements and coverage
timmysilv Sep 15, 2023
0ae2acd
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 15, 2023
df680e7
don't xfail tests unnecessarily
timmysilv Sep 15, 2023
3b55ec2
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 15, 2023
52f7890
fix up some tests after merge
timmysilv Sep 15, 2023
0b49d18
fix unrelated mpl bug
timmysilv Sep 15, 2023
31a1060
reject backprop with interface=None; fix docs tests
timmysilv Sep 15, 2023
253e61d
fix DQ2 backprop validation test
timmysilv Sep 15, 2023
274c9fa
remove old note on return system
timmysilv Sep 18, 2023
8c10163
skip Snapshot for adjoint execution backwards pass
timmysilv Sep 19, 2023
75cc45d
rename default.qubit.2 to default.qubit
timmysilv Sep 19, 2023
9c8c24a
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 19, 2023
e054db0
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 19, 2023
0d2ecbf
test qnn and hardwareHamiltonian with new device
timmysilv Sep 19, 2023
0f23793
add link to changelog entry
timmysilv Sep 19, 2023
37e5bea
fix state tests to not check custom dtype
timmysilv Sep 19, 2023
95ba06c
more state dtype fixes
timmysilv Sep 19, 2023
2de5990
autoray dispatch to tf.size
timmysilv Sep 19, 2023
aef00fa
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 19, 2023
9142944
codecov happiness
timmysilv Sep 19, 2023
ec89baa
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 19, 2023
0c10ef3
Merge branch 'master' into replace-dq-entry-with-dq2
timmysilv Sep 20, 2023
032cd3b
also add to breaking changes
timmysilv Sep 20, 2023
0fa4047
Merge branch 'master' into replace-dq-entry-with-dq2
albi3ro Sep 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/interface-unit-tests.yml
Expand Up @@ -263,9 +263,9 @@ jobs:
strategy:
matrix:
config:
- device: default.qubit
- device: default.qubit.legacy
shots: None
- device: default.qubit
- device: default.qubit.legacy
shots: 10000
# - device: default.qubit.tf
# shots: None
Expand Down
11 changes: 9 additions & 2 deletions doc/releases/changelog-dev.md
Expand Up @@ -17,9 +17,10 @@
* Quantum information transforms are updated to the new transform program system.
[(#4569)](https://github.com/PennyLaneAI/pennylane/pull/4569)

* `qml.devices.DefaultQubit` now implements the new device API. The old version of `default.qubit`
is still accessible via `qml.devices.DefaultQubitLegacy`, or via short name `default.qubit.legacy`.
* `default.qubit` now implements the new device API. The old version of the device is still
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
accessible by the short name `default.qubit.legacy`, or directly via `qml.devices.DefaultQubitLegacy`.
[(#4594)](https://github.com/PennyLaneAI/pennylane/pull/4594)
[(#4436)](https://github.com/PennyLaneAI/pennylane/pull/4436)

<h3>Improvements 🛠</h3>

Expand Down Expand Up @@ -184,6 +185,12 @@
which effectively just called `marginal_prob` with `np.abs(state) ** 2`.
[(#4602)](https://github.com/PennyLaneAI/pennylane/pull/4602)

* `default.qubit` now implements the new device API. If you initialize a device
with `qml.device("default.qubit")`, all functions and properties that were tied to the old
device API will no longer be on the device. The legacy version can still be accessed with
`qml.device("default.qubit.legacy", wires=n_wires)`.
[(#4436)](https://github.com/PennyLaneAI/pennylane/pull/4436)

<h3>Deprecations 👋</h3>

* The ``prep`` keyword argument in ``QuantumScript`` is deprecated and will be removed from `QuantumScript`.
Expand Down
3 changes: 2 additions & 1 deletion pennylane/devices/default_qubit.py
Expand Up @@ -146,7 +146,7 @@ def f(x):
@property
def name(self):
"""The name of the device."""
return "default.qubit.2"
return "default.qubit"

# pylint:disable = too-many-arguments
def __init__(
Expand Down Expand Up @@ -193,6 +193,7 @@ def supports_derivatives(
if (
execution_config.gradient_method == "backprop"
and execution_config.device_options.get("max_workers", self._max_workers) is None
and execution_config.interface is not None
):
return True

Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/default_qubit_legacy.py
Expand Up @@ -104,8 +104,8 @@ class DefaultQubitLegacy(QubitDevice):
returns analytical results.
"""

name = "Default qubit PennyLane plugin"
short_name = "default.qubit"
name = "Default qubit PennyLane plugin (Legacy)"
short_name = "default.qubit.legacy"
pennylane_requires = __version__
version = __version__
author = "Xanadu Inc."
Expand Down
3 changes: 0 additions & 3 deletions pennylane/devices/device_api.py
Expand Up @@ -41,9 +41,6 @@ class Device(abc.ABC):
"""A device driver that can control one or more backends. A backend can be either a physical
Quantum Processing Unit or a virtual one such as a simulator.

Device drivers should be configured to run under :func:`~.enable_return`, the newer
return shape specification, as the old return shape specification is deprecated.

Only the ``execute`` method must be defined to construct a device driver.

.. details::
Expand Down
2 changes: 2 additions & 0 deletions pennylane/devices/qubit/adjoint_jacobian.py
Expand Up @@ -75,6 +75,8 @@ def adjoint_jacobian(tape: QuantumTape, state=None):
param_number = len(tape.get_parameters(trainable_only=False, operations_only=True)) - 1
trainable_param_number = len(tape.trainable_params) - 1
for op in reversed(tape.operations[tape.num_preps :]):
if isinstance(op, qml.Snapshot):
continue
adj_op = qml.adjoint(op)
ket = apply_operation(adj_op, ket)

Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Expand Up @@ -454,6 +454,7 @@ def _cond_tf(pred, true_fn, false_fn, args):
"vander",
lambda *args, **kwargs: _i("tf").experimental.numpy.vander(*args, **kwargs),
)
ar.register_function("tensorflow", "size", lambda x: _i("tf").size(x))


# -------------------------------- Torch --------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion pennylane/measurements/state.py
Expand Up @@ -222,4 +222,4 @@ def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
wire_map = dict(zip(wire_order, range(len(wire_order))))
mapped_wires = [wire_map[w] for w in self.wires]
return qml.math.reduce_statevector(state, indices=mapped_wires, c_dtype=state.dtype)
return qml.math.reduce_statevector(state, indices=mapped_wires)
12 changes: 7 additions & 5 deletions pennylane/optimize/spsa.py
Expand Up @@ -262,11 +262,13 @@ def compute_grad(self, objective_fn, args, kwargs):
yminus = objective_fn(*thetaminus, **kwargs)
try:
# pylint: disable=protected-access
shots = (
Shots(objective_fn.device._raw_shot_sequence)
if objective_fn.device.shot_vector is not None
else Shots(None)
)
dev_shots = objective_fn.device.shots
if isinstance(dev_shots, Shots):
shots = dev_shots if dev_shots.has_partitioned_shots else Shots(None)
elif objective_fn.device.shot_vector is not None:
shots = Shots(objective_fn.device._raw_shot_sequence) # pragma: no cover
else:
shots = Shots(None)
if np.prod(objective_fn.func(*args).shape(objective_fn.device, shots)) > 1:
raise ValueError(
"The objective function must be a scalar function for the gradient "
Expand Down
4 changes: 3 additions & 1 deletion pennylane/qnode.py
Expand Up @@ -705,7 +705,9 @@ def _validate_backprop_method(device, interface, shots=None):
config = qml.devices.ExecutionConfig(gradient_method="backprop", interface=interface)
if device.supports_derivatives(config):
return "backprop", {}, device
raise qml.QuantumFunctionError(f"Device {device.name} does not support backprop")
raise qml.QuantumFunctionError(
f"Device {device.name} does not support backprop with {config}"
)

mapped_interface = INTERFACE_MAP.get(interface, interface)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -47,7 +47,7 @@
# TODO: rename entry point 'pennylane.plugins' to 'pennylane.devices'.
# This requires a rename in the setup file of all devices, and is best done during another refactor
"pennylane.plugins": [
"default.qubit = pennylane.devices:DefaultQubitLegacy",
"default.qubit = pennylane.devices:DefaultQubit",
"default.qubit.legacy = pennylane.devices:DefaultQubitLegacy",
"default.gaussian = pennylane.devices:DefaultGaussian",
"default.qubit.tf = pennylane.devices.default_qubit_tf:DefaultQubitTF",
Expand Down
59 changes: 57 additions & 2 deletions tests/devices/experimental/test_default_qubit_2.py
Expand Up @@ -27,7 +27,7 @@

def test_name():
"""Tests the name of DefaultQubit."""
assert DefaultQubit().name == "default.qubit.2"
assert DefaultQubit().name == "default.qubit"


def test_shots():
Expand Down Expand Up @@ -200,6 +200,56 @@ def test_tracking_resources(self):
assert len(tracker.history["resources"]) == 1
assert tracker.history["resources"][0] == expected_resources

def test_tracking_batched_execution(self):
"""Test the number of times the device is executed over a QNode's
lifetime is tracked by the device's tracker."""

dev_1 = qml.device("default.qubit", wires=2)

def circuit_1(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliX(1))

node_1 = qml.QNode(circuit_1, dev_1)
num_evals_1 = 10

with qml.Tracker(dev_1, persistent=True) as tracker1:
for _ in range(num_evals_1):
node_1(0.432, np.array([0.12, 0.5, 3.2]))
assert tracker1.totals["executions"] == num_evals_1

# test a second instance of a default qubit device
dev_2 = qml.device("default.qubit", wires=2)

def circuit_2(x):
qml.RX(x, wires=[0])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliX(1))

node_2 = qml.QNode(circuit_2, dev_2)
num_evals_2 = 5

with qml.Tracker(dev_2) as tracker2:
for _ in range(num_evals_2):
node_2(np.array([0.432, 0.61, 8.2]))
assert tracker2.totals["executions"] == num_evals_2

# test a new circuit on an existing instance of a qubit device
def circuit_3(y):
qml.RY(y, wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0) @ qml.PauliX(1))

node_3 = qml.QNode(circuit_3, dev_1)
num_evals_3 = 7

with tracker1:
for _ in range(num_evals_3):
node_3(np.array([0.12, 1.214]))
assert tracker1.totals["executions"] == num_evals_1 + num_evals_3


# pylint: disable=too-few-public-methods
class TestPreprocessing:
Expand Down Expand Up @@ -302,7 +352,7 @@ def test_supports_backprop(self):
assert dev.supports_jvp() is True
assert dev.supports_vjp() is True

config = ExecutionConfig(gradient_method="backprop")
config = ExecutionConfig(gradient_method="backprop", interface="auto")
assert dev.supports_derivatives(config) is True
assert dev.supports_jvp(config) is True
assert dev.supports_vjp(config) is True
Expand All @@ -317,6 +367,11 @@ def test_supports_backprop(self):
assert dev.supports_jvp(config) is False
assert dev.supports_vjp(config) is False

config = ExecutionConfig(gradient_method="backprop", interface=None)
assert dev.supports_derivatives(config) is False
assert dev.supports_jvp(config) is False
assert dev.supports_vjp(config) is False

def test_supports_adjoint(self):
"""Test that DefaultQubit says that it supports adjoint differentiation."""
dev = DefaultQubit()
Expand Down
4 changes: 2 additions & 2 deletions tests/devices/test_default_qubit_autograd.py
Expand Up @@ -487,7 +487,7 @@ def circuit(a, b):

def cost(a, b):
prob_wire_1 = circuit(a, b)
return prob_wire_1[1] - prob_wire_1[0]
return prob_wire_1[1] - prob_wire_1[0] # pylint:disable=unsubscriptable-object

res = cost(a, b)
expected = -np.cos(a) * np.cos(b)
Expand All @@ -513,7 +513,7 @@ def circuit(a, b):

def cost(a, b):
prob_wire_1 = circuit(a, b)
return prob_wire_1[:, 1] - prob_wire_1[:, 0]
return prob_wire_1[:, 1] - prob_wire_1[:, 0] # pylint:disable=unsubscriptable-object

res = cost(a, b)
expected = -np.cos(a) * np.cos(b)
Expand Down
6 changes: 3 additions & 3 deletions tests/devices/test_default_qubit_legacy.py
Expand Up @@ -643,7 +643,7 @@ def test_apply_errors_qubit_state_vector(self, qubit_device_2_wires):
with pytest.raises(
DeviceError,
match="Operation StatePrep cannot be used after other Operations have already been applied "
"on a default.qubit device.",
"on a default.qubit.legacy device.",
):
qubit_device_2_wires.reset()
qubit_device_2_wires.apply(
Expand All @@ -664,7 +664,7 @@ def test_apply_errors_basis_state(self, qubit_device_2_wires):
with pytest.raises(
DeviceError,
match="Operation BasisState cannot be used after other Operations have already been applied "
"on a default.qubit device.",
"on a default.qubit.legacy device.",
):
qubit_device_2_wires.reset()
qubit_device_2_wires.apply(
Expand Down Expand Up @@ -2091,7 +2091,7 @@ def test_apply_parametrized_evolution_raises_error(self):
param_ev = qml.evolve(ParametrizedHamiltonian([1], [qml.PauliX(0)]))
with pytest.raises(
NotImplementedError,
match="The device default.qubit cannot execute a ParametrizedEvolution operation",
match="The device default.qubit.legacy cannot execute a ParametrizedEvolution operation",
):
self.dev._apply_parametrized_evolution(state=self.state, operation=param_ev)

Expand Down
4 changes: 2 additions & 2 deletions tests/devices/test_default_qubit_legacy_broadcasting.py
Expand Up @@ -462,7 +462,7 @@ def test_apply_errors_qubit_state_vector_broadcasted(self, qubit_device_2_wires)
with pytest.raises(
DeviceError,
match="Operation StatePrep cannot be used after other Operations have already been applied "
"on a default.qubit device.",
"on a default.qubit.legacy device.",
):
qubit_device_2_wires.apply([qml.RZ(0.5, wires=[0]), vec])

Expand Down Expand Up @@ -491,7 +491,7 @@ def test_apply_errors_basis_state_broadcasted(self, qubit_device_2_wires):
with pytest.raises(
DeviceError,
match="Operation BasisState cannot be used after other Operations have already been applied "
"on a default.qubit device.",
"on a default.qubit.legacy device.",
):
qubit_device_2_wires.apply([vec])

Expand Down
14 changes: 8 additions & 6 deletions tests/devices/test_default_qubit_tf.py
Expand Up @@ -1873,7 +1873,7 @@ def circuit(a, b):
# get the probability of wire 1
prob_wire_1 = circuit(a, b)
# compute Prob(|1>_1) - Prob(|0>_1)
res = prob_wire_1[1] - prob_wire_1[0]
res = prob_wire_1[1] - prob_wire_1[0] # pylint:disable=unsubscriptable-object

expected = -tf.cos(a) * tf.cos(b)
assert np.allclose(res, expected, atol=tol, rtol=0)
Expand All @@ -1900,7 +1900,7 @@ def circuit(a, b):
# get the probability of wire 1
prob_wire_1 = circuit(a, b)
# compute Prob(|1>_1) - Prob(|0>_1)
res = prob_wire_1[:, 1] - prob_wire_1[:, 0]
res = prob_wire_1[:, 1] - prob_wire_1[:, 0] # pylint:disable=unsubscriptable-object

expected = -tf.cos(a) * tf.cos(b)
assert np.allclose(res, expected, atol=tol, rtol=0)
Expand Down Expand Up @@ -1938,6 +1938,7 @@ def circuit(a, b):
[-0.5 * np.sin(a) * (np.cos(b) + 1), 0.5 * np.sin(b) * (1 - np.cos(a))]
)

# pylint:disable=no-member
assert np.allclose(res.numpy(), expected_cost, atol=tol, rtol=0)

res = tape.gradient(res, [a_tf, b_tf])
Expand Down Expand Up @@ -1971,6 +1972,7 @@ def circuit(a, b):
[-0.5 * np.sin(a) * (np.cos(b) + 1), 0.5 * np.sin(b) * (1 - np.cos(a))]
)

# pylint:disable=no-member
assert np.allclose(res.numpy(), expected_cost, atol=tol, rtol=0)

jac = tape.jacobian(res, [a_tf, b_tf])
Expand Down Expand Up @@ -2112,8 +2114,8 @@ def circuit(a):
res = circuit(a)

assert isinstance(res, tf.Tensor)
assert res.shape == (shots,)
assert set(res.numpy()) == {-1, 1}
assert res.shape == (shots,) # pylint:disable=comparison-with-callable
assert set(res.numpy()) == {-1, 1} # pylint:disable=no-member

def test_estimating_marginal_probability(self, tol):
"""Test that the probability of a subset of wires is accurately estimated."""
Expand Down Expand Up @@ -2190,8 +2192,8 @@ def circuit(a):
res = circuit(a)

assert isinstance(res, tf.Tensor)
assert res.shape == (3, shots)
assert set(res.numpy().flat) == {-1, 1}
assert res.shape == (3, shots) # pylint:disable=comparison-with-callable
assert set(res.numpy().flat) == {-1, 1} # pylint:disable=no-member

@pytest.mark.parametrize("batch_size", [2, 3])
def test_estimating_marginal_probability_broadcasted(self, batch_size, tol):
Expand Down