diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index befcbfa123a..56ed89c9498 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -122,7 +122,7 @@ * New basis sets, `6-311g` and `CC-PVDZ`, are added to the qchem basis set repo. [#3279](https://github.com/PennyLaneAI/pennylane/pull/3279) -* Added a `pauli_decompose()` which takes a hermitian matrix and decomposes it in the +* Added a `pauli_decompose()` which takes a hermitian matrix and decomposes it in the Pauli basis, returning it either as a `Hamiltonian` or `PauliSentence` instance. [(#3384)](https://github.com/PennyLaneAI/pennylane/pull/3384) @@ -290,7 +290,7 @@ Replaces `qml.transforms.make_tape` with `make_qscript`. [(#3429)](https://github.com/PennyLaneAI/pennylane/pull/3429) -* Add a UserWarning when creating a `Tensor` object with overlapping wires, +* Add a UserWarning when creating a `Tensor` object with overlapping wires, informing that this can in some cases lead to undefined behaviour. [(#3459)](https://github.com/PennyLaneAI/pennylane/pull/3459) @@ -425,6 +425,39 @@ [-0.38466667, -0.19233333, 0. , 0. , 0.19233333]])> ``` +* The JAX-JIT interface now supports gradient transforms and device gradient execution in `backward` mode with the new + return types system. + [(#3235)](https://github.com/PennyLaneAI/pennylane/pull/3235) + + ```python + import pennylane as qml + import jax + from jax import numpy as jnp + + jax.config.update("jax_enable_x64", True) + + qml.enable_return() + + dev = qml.device("lightning.qubit", wires=2) + + @jax.jit + @qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift") + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1)) + + a, b = jnp.array(1.0), jnp.array(2.0) + ``` + + ```pycon + >>> jax.jacobian(circuit, argnums=[0, 1])(a, b) + ((DeviceArray(0.35017549, dtype=float64, weak_type=True), + DeviceArray(-0.4912955, dtype=float64, weak_type=True)), + (DeviceArray(5.55111512e-17, dtype=float64, weak_type=True), + DeviceArray(0., dtype=float64, weak_type=True))) + ``` + * Updated `qml.transforms.split_non_commuting` to support the new return types. [(#3414)](https://github.com/PennyLaneAI/pennylane/pull/3414) diff --git a/pennylane/interfaces/execution.py b/pennylane/interfaces/execution.py index 77f8ece0782..2394e20b9a5 100644 --- a/pennylane/interfaces/execution.py +++ b/pennylane/interfaces/execution.py @@ -740,7 +740,10 @@ def _get_jax_execute_fn(interface: str, tapes: Sequence[QuantumTape]): interface = get_jax_interface_name(tapes) if interface == "jax-jit": - from .jax_jit import execute as _execute + if qml.active_return(): + from .jax_jit_tuple import execute_tuple as _execute + else: + from .jax_jit import execute as _execute else: if qml.active_return(): from .jax import execute_new as _execute diff --git a/pennylane/interfaces/jax_jit_tuple.py b/pennylane/interfaces/jax_jit_tuple.py new file mode 100644 index 00000000000..dbdb82bc306 --- /dev/null +++ b/pennylane/interfaces/jax_jit_tuple.py @@ -0,0 +1,299 @@ +# Copyright 2018-2022 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains functions for adding the JAX interface +to a PennyLane Device class. +""" + +# pylint: disable=too-many-arguments +import jax +import jax.numpy as jnp + +import pennylane as qml +from pennylane.interfaces import InterfaceUnsupportedError +from pennylane.interfaces.jax import _compute_jvps +from pennylane.interfaces.jax_jit import _validate_jax_version, _numeric_type_to_dtype + +dtype = jnp.float64 + + +def _create_shape_dtype_struct(tape, device): + """Auxiliary function for creating the shape and dtype object structure + given a tape.""" + + def process_single_shape(shape, tape_dtype): + return jax.ShapeDtypeStruct(tuple(shape), tape_dtype) + + num_measurements = len(tape.measurements) + shape = tape.shape(device) + if num_measurements == 1: + tape_dtype = _numeric_type_to_dtype(tape.numeric_type) + return process_single_shape(shape, tape_dtype) + + tape_dtype = tuple(_numeric_type_to_dtype(elem) for elem in tape.numeric_type) + return tuple(process_single_shape(s, d) for s, d in zip(shape, tape_dtype)) + + +def _tapes_shape_dtype_tuple(tapes, device): + """Auxiliary function for defining the jax.ShapeDtypeStruct objects given + the tapes and the device. + + The jax.pure_callback function expects jax.ShapeDtypeStruct objects to + describe the output of the function call. + """ + shape_dtypes = [] + + for t in tapes: + shape_and_dtype = _create_shape_dtype_struct(t, device) + shape_dtypes.append(shape_and_dtype) + return shape_dtypes + + +def _jac_shape_dtype_tuple(tapes, device): + """Auxiliary function for defining the jax.ShapeDtypeStruct objects when + computing the jacobian associated with the tapes and the device. + + The jax.pure_callback function expects jax.ShapeDtypeStruct objects to + describe the output of the function call. + """ + shape_dtypes = [] + + for t in tapes: + shape_and_dtype = _create_shape_dtype_struct(t, device) + + if len(t.trainable_params) == 1: + shape_dtypes.append(shape_and_dtype) + else: + num_measurements = len(t.measurements) + if num_measurements == 1: + s = [shape_and_dtype for _ in range(len(t.trainable_params))] + shape_dtypes.append(tuple(s)) + else: + s = [tuple(_s for _ in range(len(t.trainable_params))) for _s in shape_and_dtype] + shape_dtypes.append(tuple(s)) + + if len(tapes) == 1: + return shape_dtypes[0] + + return tuple(shape_dtypes) + + +def execute_tuple(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1): + """Execute a batch of tapes with JAX parameters on a device. + + Args: + tapes (Sequence[.QuantumTape]): batch of tapes to execute + device (.Device): Device to use to execute the batch of tapes. + If the device does not provide a ``batch_execute`` method, + by default the tapes will be executed in serial. + execute_fn (callable): The execution function used to execute the tapes + during the forward pass. This function must return a tuple ``(results, jacobians)``. + If ``jacobians`` is an empty list, then ``gradient_fn`` is used to + compute the gradients during the backwards pass. + gradient_kwargs (dict): dictionary of keyword arguments to pass when + determining the gradients of tapes + gradient_fn (callable): the gradient function to use to compute quantum gradients + _n (int): a positive integer used to track nesting of derivatives, for example + if the nth-order derivative is requested. + max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies + the maximum order of derivatives to support. Increasing this value allows + for higher order derivatives to be extracted, at the cost of additional + (classical) computational overhead during the backwards pass. + + Returns: + list[list[float]]: A nested list of tape results. Each element in + the returned list corresponds in order to the provided tapes. + """ + # pylint: disable=unused-argument + if max_diff > 1: + raise InterfaceUnsupportedError( + "The JAX-JIT interface only supports first order derivatives." + ) + + if any( + m.return_type in (qml.measurements.Counts, qml.measurements.AllCounts) + for t in tapes + for m in t.measurements + ): + # Obtaining information about the shape of the Counts measurements is + # not implemeneted and is required for the callback logic + raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.") + + _validate_jax_version() + + for tape in tapes: + # set the trainable parameters + params = tape.get_parameters(trainable_only=False) + tape.trainable_params = qml.math.get_trainable_indices(params) + + parameters = tuple(list(t.get_parameters()) for t in tapes) + + if gradient_fn is None: + return _execute_fwd_tuple( + parameters, + tapes=tapes, + device=device, + execute_fn=execute_fn, + gradient_kwargs=gradient_kwargs, + _n=_n, + ) + + return _execute_bwd_tuple( + parameters, + tapes=tapes, + device=device, + execute_fn=execute_fn, + gradient_fn=gradient_fn, + gradient_kwargs=gradient_kwargs, + _n=_n, + ) + + +def _execute_bwd_tuple( + params, + tapes=None, + device=None, + execute_fn=None, + gradient_fn=None, + gradient_kwargs=None, + _n=1, +): # pylint: disable=dangerous-default-value,unused-argument + + # Copy a given tape with operations and set parameters + def _copy_tape(t, a): + tc = t.copy(copy_operations=True) + tc.set_parameters(a) + return tc + + @jax.custom_jvp + def execute_wrapper(params): + def wrapper(p): + """Compute the forward pass.""" + new_tapes = [_copy_tape(t, a) for t, a in zip(tapes, p)] + with qml.tape.Unwrap(*new_tapes): + res, _ = execute_fn(new_tapes, **gradient_kwargs) + return res + + shape_dtype_structs = _tapes_shape_dtype_tuple(tapes, device) + res = jax.pure_callback(wrapper, shape_dtype_structs, params) + return res + + @execute_wrapper.defjvp + def execute_wrapper_jvp(primals, tangents): + # pylint: disable=unused-variable + params = primals[0] + multi_measurements = [len(tape.measurements) > 1 for tape in tapes] + + # Execution: execute the function first + evaluation_results = execute_wrapper(params) + + # Backward: branch off based on the gradient function is a device method. + if isinstance(gradient_fn, qml.gradients.gradient_transform): + # Gradient function is a gradient transform + + res_from_callback = _grad_transform_jac_via_callback(params, device) + if len(tapes) == 1: + res_from_callback = [res_from_callback] + + jvps = _compute_jvps(res_from_callback, tangents[0], multi_measurements) + else: + # Gradient function is a device method + res_from_callback = _device_method_jac_via_callback(params, device) + if len(tapes) == 1: + res_from_callback = [res_from_callback] + + jvps = _compute_jvps(res_from_callback, tangents[0], multi_measurements) + + return evaluation_results, jvps + + def _grad_transform_jac_via_callback(params, device): + """Perform a callback to compute the jacobian of tapes using a gradient transform (e.g., parameter-shift or + finite differences grad transform). + + Note: we are not using the batch_jvp pipeline and rather split the steps of unwrapping tapes and the JVP + computation because: + + 1. Tape unwrapping has to happen in the callback: otherwise jitting is broken and Tracer objects + are converted to NumPy, something that raises an error; + + 2. Passing in the tangents as an argument to the wrapper function called by the jax.pure_callback raises an + error (as of jax and jaxlib 0.3.25): + ValueError: Pure callbacks do not support transpose. Please use jax.custom_vjp to use callbacks while + taking gradients. + + Solution: Use the callback to compute the jacobian and then separately compute the JVP using the + tangent. + """ + + def wrapper(params): + new_tapes = [_copy_tape(t, a) for t, a in zip(tapes, params)] + + with qml.tape.Unwrap(*new_tapes): + all_jacs = [] + for new_t in new_tapes: + jvp_tapes, res_processing_fn = gradient_fn( + new_t, shots=device.shots, **gradient_kwargs + ) + jacs = execute_fn(jvp_tapes)[0] + jacs = res_processing_fn(jacs) + all_jacs.append(jacs) + + if len(all_jacs) == 1: + return all_jacs[0] + + return all_jacs + + expected_shapes = _jac_shape_dtype_tuple(tapes, device) + res = jax.pure_callback(wrapper, expected_shapes, params) + return res + + def _device_method_jac_via_callback(params, device): + """Perform a callback to compute the jacobian of tapes using a device method (e.g., adjoint). + + Note: we are not using the batch_jvp pipeline and rather split the steps of unwrapping tapes and the JVP + computation because: + + 1. Tape unwrapping has to happen in the callback: otherwise jitting is broken and Tracer objects + are converted to NumPy, something that raises an error; + + 2. Passing in the tangents as an argument to the wrapper function called by the jax.pure_callback raises an + error (as of jax and jaxlib 0.3.25): + ValueError: Pure callbacks do not support transpose. Please use jax.custom_vjp to use callbacks while + taking gradients. + + Solution: Use the callback to compute the jacobian and then separately compute the JVP using the + tangent. + """ + + def wrapper(params): + new_tapes = [_copy_tape(t, a) for t, a in zip(tapes, params)] + with qml.tape.Unwrap(*new_tapes): + return gradient_fn(new_tapes, **gradient_kwargs) + + shape_dtype_structs = _jac_shape_dtype_tuple(tapes, device) + return jax.pure_callback(wrapper, shape_dtype_structs, params) + + return execute_wrapper(params) + + +# The execute function in forward mode +def _execute_fwd_tuple( + params, + tapes=None, + device=None, + execute_fn=None, + gradient_kwargs=None, + _n=1, +): # pylint: disable=dangerous-default-value,unused-argument + raise NotImplementedError("Forward mode execution for device gradients is not yet implemented.") diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 80d3d60f19f..4c8af9fc14c 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -366,7 +366,14 @@ def function(x): import jax from jax.interpreters.partial_eval import DynamicJaxprTracer - if isinstance(tensor, (jax.ad.JVPTracer, jax.interpreters.batching.BatchTracer)): + if isinstance( + tensor, + ( + jax.ad.JVPTracer, + jax.interpreters.batching.BatchTracer, + jax.interpreters.partial_eval.JaxprTracer, + ), + ): # Tracer objects will be used when computing gradients or applying transforms. # If the value of the tracer is known, it will contain a ConcreteArray. # Otherwise, it will be abstract. diff --git a/tests/returntypes/test_jax_jit_new.py b/tests/returntypes/test_jax_jit_new.py new file mode 100644 index 00000000000..d7359bac72f --- /dev/null +++ b/tests/returntypes/test_jax_jit_new.py @@ -0,0 +1,1026 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the JAX-JIT interface""" +import sys +import pytest + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") + +from jax.config import config + +config.update("jax_enable_x64", True) + +import numpy as np + +import pennylane as qml +from pennylane.gradients import param_shift +from pennylane.interfaces import execute, InterfaceUnsupportedError +from pennylane.interfaces.jax_jit import _execute_with_fwd + + +@pytest.mark.parametrize( + "version, package, should_raise", + [ + ("0.3.16", jax, True), + ("0.3.17", jax, False), + ("0.3.18", jax, False), + ("0.3.14", jax.lib, True), + ("0.3.15", jax.lib, False), + ("0.3.16", jax.lib, False), + ], +) +def test_raise_version_error(package, version, should_raise, monkeypatch): + """Test JAX version error""" + a = jnp.array([0.1, 0.2]) + + dev = qml.device("default.qubit", wires=1) + + with qml.tape.QuantumTape() as tape: + qml.expval(qml.PauliZ(0)) + + with monkeypatch.context() as m: + m.setattr(package, "__version__", version) + + if should_raise: + msg = "requires version 0.3.17 or higher for JAX and 0.3.15 or higher JAX lib" + with pytest.raises(InterfaceUnsupportedError, match=msg): + execute([tape], dev, gradient_fn=param_shift, interface="jax-jit") + else: + execute([tape], dev, gradient_fn=param_shift, interface="jax-jit") + + +class TestJaxExecuteUnitTests: + """Unit tests for jax execution""" + + def test_import_error(self, mocker): + """Test that an exception is caught on import error""" + + mock = mocker.patch.object(jax, "custom_jvp") + mock.side_effect = ImportError() + + dev = qml.device("default.qubit", wires=2, shots=None) + + with qml.tape.QuantumTape() as tape: + qml.expval(qml.PauliY(1)) + + with pytest.raises( + qml.QuantumFunctionError, + match="jax not found. Please install the latest version " + "of jax to enable the 'jax' interface", + ): + qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift, interface="jax-jit") + + def test_jacobian_options(self, mocker, tol): + """Test setting jacobian options""" + spy = mocker.spy(qml.gradients, "param_shift") + + a = jnp.array([0.1, 0.2]) + + dev = qml.device("default.qubit", wires=1) + + def cost(a, device): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + device, + gradient_fn=param_shift, + gradient_kwargs={"shifts": [(np.pi / 4,)] * 2}, + interface="jax-jit", + )[0] + + res = jax.grad(cost)(a, device=dev) + + for args in spy.call_args_list: + assert args[1]["shifts"] == [(np.pi / 4,)] * 2 + + def test_incorrect_mode(self): + """Test that an error is raised if an gradient transform + is used with mode=forward""" + a = jnp.array([0.1, 0.2]) + + dev = qml.device("default.qubit", wires=1) + + def cost(a, device): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + device, + gradient_fn=param_shift, + mode="forward", + interface="jax-jit", + )[0] + + with pytest.raises( + ValueError, match="Gradient transforms cannot be used with mode='forward'" + ): + res = jax.grad(cost)(a, device=dev) + + def test_unknown_interface(self): + """Test that an error is raised if the interface is unknown""" + a = jnp.array([0.1, 0.2]) + + dev = qml.device("default.qubit", wires=1) + + def cost(a, device): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + device, + gradient_fn=param_shift, + interface="None", + )[0] + + with pytest.raises(ValueError, match="Unknown interface"): + cost(a, device=dev) + + # TODO + @pytest.mark.skip() + def test_forward_mode(self, mocker): + """Test that forward mode uses the `device.execute_and_gradients` pathway""" + dev = qml.device("default.qubit", wires=1) + spy = mocker.spy(dev, "execute_and_gradients") + + def cost(a): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn="device", + interface="jax-jit", + gradient_kwargs={ + "method": "adjoint_jacobian", + "use_device_state": True, + }, + )[0] + + a = jnp.array([0.1, 0.2]) + jax.jit(cost)(a) + + # adjoint method only performs a single device execution, but gets both result and gradient + assert dev.num_executions == 1 + spy.assert_called() + + def test_backward_mode(self, mocker): + """Test that backward mode uses the `device.batch_execute` and `device.gradients` pathway""" + dev = qml.device("default.qubit", wires=1) + spy_execute = mocker.spy(qml.devices.DefaultQubit, "batch_execute") + spy_gradients = mocker.spy(qml.devices.DefaultQubit, "gradients") + + def cost(a): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn="device", + mode="backward", + interface="jax-jit", + gradient_kwargs={"method": "adjoint_jacobian"}, + )[0] + + a = jnp.array([0.1, 0.2]) + jax.jit(cost)(a) + + assert dev.num_executions == 1 + spy_execute.assert_called() + spy_gradients.assert_not_called() + + jax.grad(cost)(a) + spy_gradients.assert_called() + + def test_max_diff_error(self): + """Test that an error is being raised if max_diff > 1 for the JAX + interface.""" + a = jnp.array([0.1, 0.2]) + + dev = qml.device("default.qubit", wires=1) + + with pytest.raises( + InterfaceUnsupportedError, + match="The JAX-JIT interface only supports first order derivatives.", + ): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + execute( + [tape], + dev, + interface="jax-jit", + gradient_fn=param_shift, + gradient_kwargs={"shift": np.pi / 4}, + max_diff=2, + ) + + +class TestCaching: + """Test for caching behaviour""" + + def test_cache_maxsize(self, mocker): + """Test the cachesize property of the cache""" + dev = qml.device("default.qubit", wires=1) + spy = mocker.spy(qml.interfaces, "cache_execute") + + def cost(a, cachesize): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn=param_shift, + cachesize=cachesize, + interface="jax-jit", + )[0] + + params = jnp.array([0.1, 0.2]) + jax.jit(jax.grad(cost), static_argnums=1)(params, cachesize=2) + cache = spy.call_args[0][1] + + assert cache.maxsize == 2 + assert cache.currsize == 2 + assert len(cache) == 2 + + def test_custom_cache(self, mocker): + """Test the use of a custom cache object""" + dev = qml.device("default.qubit", wires=1) + spy = mocker.spy(qml.interfaces, "cache_execute") + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn=param_shift, + cache=cache, + interface="jax-jit", + )[0] + + custom_cache = {} + params = jnp.array([0.1, 0.2]) + jax.grad(cost)(params, cache=custom_cache) + + cache = spy.call_args[0][1] + assert cache is custom_cache + + def test_custom_cache_multiple(self, mocker): + """Test the use of a custom cache object with multiple tapes""" + dev = qml.device("default.qubit", wires=1) + spy = mocker.spy(qml.interfaces, "cache_execute") + + a = jnp.array(0.1) + b = jnp.array(0.2) + + def cost(a, b, cache): + with qml.tape.QuantumTape() as tape1: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + + with qml.tape.QuantumTape() as tape2: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + + res = execute( + [tape1, tape2], + dev, + gradient_fn=param_shift, + cache=cache, + interface="jax-jit", + ) + return res[0] + + custom_cache = {} + jax.grad(cost)(a, b, cache=custom_cache) + + cache = spy.call_args[0][1] + assert cache is custom_cache + + def test_caching_param_shift(self, tol): + """Test that, when using parameter-shift transform, + caching produces the optimum number of evaluations.""" + dev = qml.device("default.qubit", wires=1) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn=param_shift, + cache=cache, + interface="jax-jit", + )[0] + + # Without caching, 5 evaluations are required to compute + # the Jacobian: 1 (forward pass) + 2 (backward pass) * (2 shifts * 2 params) + params = jnp.array([0.1, 0.2]) + jax.grad(cost)(params, cache=None) + assert dev.num_executions == 5 + + # With caching, 5 evaluations are required to compute + # the Jacobian: 1 (forward pass) + (2 shifts * 2 params) + dev._num_executions = 0 + jac_fn = jax.grad(cost) + grad1 = jac_fn(params, cache=True) + assert dev.num_executions == 5 + + # Check that calling the cost function again + # continues to evaluate the device (that is, the cache + # is emptied between calls) + grad2 = jac_fn(params, cache=True) + assert dev.num_executions == 10 + assert np.allclose(grad1, grad2, atol=tol, rtol=0) + + # Check that calling the cost function again + # with different parameters produces a different Jacobian + grad2 = jac_fn(2 * params, cache=True) + assert dev.num_executions == 15 + assert not np.allclose(grad1, grad2, atol=tol, rtol=0) + + def test_caching_adjoint_backward(self): + """Test that caching produces the optimum number of adjoint evaluations + when mode=backward""" + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(0)) + + return execute( + [tape], + dev, + gradient_fn="device", + cache=cache, + mode="backward", + interface="jax-jit", + gradient_kwargs={"method": "adjoint_jacobian"}, + )[0] + + # Without caching, 2 evaluations are required. + # 1 for the forward pass, and one per output dimension + # on the backward pass. + jax.grad(cost)(params, cache=None) + assert dev.num_executions == 2 + + # With caching, also 2 evaluations are required. One + # for the forward pass, and one for the backward pass. + dev._num_executions = 0 + jac_fn = jax.grad(cost) + grad1 = jac_fn(params, cache=True) + assert dev.num_executions == 2 + + +execute_kwargs = [ + {"gradient_fn": param_shift}, + # TODO: add forward implementation + # { + # "gradient_fn": "device", + # "mode": "forward", + # "gradient_kwargs": {"method": "adjoint_jacobian", "use_device_state": True}, + # }, + { + "gradient_fn": "device", + "mode": "backward", + "gradient_kwargs": {"method": "adjoint_jacobian"}, + }, +] + + +@pytest.mark.parametrize("execute_kwargs", execute_kwargs) +class TestJaxExecuteIntegration: + """Test the jax interface execute function + integrates well for both forward and backward execution""" + + def test_execution(self, execute_kwargs): + """Test execution""" + # TODO + if execute_kwargs.get("mode", None) == "forward": + pytest.skip("TODO") + dev = qml.device("default.qubit", wires=1) + + def cost(a, b): + with qml.tape.QuantumTape() as tape1: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + + with qml.tape.QuantumTape() as tape2: + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.expval(qml.PauliZ(0)) + + return execute([tape1, tape2], dev, interface="jax-jit", **execute_kwargs) + + a = jnp.array(0.1) + b = jnp.array(0.2) + res = cost(a, b) + + assert len(res) == 2 + assert res[0].shape == () + assert res[1].shape == () + + def test_scalar_jacobian(self, execute_kwargs, tol): + """Test scalar jacobian calculation""" + a = jnp.array(0.1) + dev = qml.device("default.qubit", wires=2) + + def cost(a): + with qml.tape.QuantumTape() as tape: + qml.RY(a, wires=0) + qml.expval(qml.PauliZ(0)) + return execute([tape], dev, interface="jax-jit", **execute_kwargs)[0] + + res = jax.jit(jax.grad(cost))(a) + assert res.shape == () + + # compare to standard tape jacobian + with qml.tape.QuantumTape() as tape: + qml.RY(a, wires=0) + qml.expval(qml.PauliZ(0)) + + tape.trainable_params = [0] + tapes, fn = param_shift(tape) + expected = fn(dev.batch_execute(tapes)) + + assert expected.shape == () + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_reusing_quantum_tape(self, execute_kwargs, tol): + """Test re-using a quantum tape by passing new parameters""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + dev = qml.device("default.qubit", wires=2) + + with qml.tape.QuantumTape() as tape: + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliZ(0)) + + assert tape.trainable_params == [0, 1] + + def cost(a, b): + + # An explicit call to _update() is required here to update the + # trainable parameters in between tape executions. + # This is different from how the autograd interface works. + # Unless the update is issued, the validation check related to the + # number of provided parameters fails in the tape: (len(params) != + # required_length) and the tape produces incorrect results. + tape._update() + tape.set_parameters([a, b]) + return execute([tape], dev, interface="jax-jit", **execute_kwargs)[0] + + jac_fn = jax.jit(jax.grad(cost)) + jac = jac_fn(a, b) + + a = jnp.array(0.54) + b = jnp.array(0.8) + + # check that the cost function continues to depend on the + # values of the parameters for subsequent calls + res2 = cost(2 * a, b) + expected = [np.cos(2 * a)] + assert np.allclose(res2, expected, atol=tol, rtol=0) + + jac_fn = jax.jit(jax.grad(lambda a, b: cost(2 * a, b))) + jac = jac_fn(a, b) + expected = -2 * np.sin(2 * a) + assert np.allclose(jac, expected, atol=tol, rtol=0) + + def test_grad_with_backward_mode(self, execute_kwargs): + """Test jax grad for adjoint diff method in backward mode""" + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + expected_results = jnp.array([-0.3875172, -0.18884787, -0.38355705]) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(0)) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + )[0] + return res + + cost = jax.jit(cost) + + results = jax.grad(cost)(params, cache=None) + for r, e in zip(results, expected_results): + assert jnp.allclose(r, e, atol=1e-7) + + def test_classical_processing_single_tape(self, execute_kwargs, tol): + """Test classical processing within the quantum tape for a single tape""" + a = jnp.array(0.1) + b = jnp.array(0.2) + c = jnp.array(0.3) + + def cost(a, b, c, device): + with qml.tape.QuantumTape() as tape: + qml.RY(a * c, wires=0) + qml.RZ(b, wires=0) + qml.RX(c + c**2 + jnp.sin(a), wires=0) + qml.expval(qml.PauliZ(0)) + + return execute([tape], device, interface="jax-jit", **execute_kwargs)[0] + + dev = qml.device("default.qubit", wires=2) + res = jax.jit(jax.grad(cost, argnums=(0, 1, 2)), static_argnums=3)(a, b, c, device=dev) + assert len(res) == 3 + + def test_classical_processing_multiple_tapes(self, execute_kwargs, tol): + """Test classical processing within the quantum tape for multiple + tapes""" + dev = qml.device("default.qubit", wires=2) + params = jax.numpy.array([0.3, 0.2]) + + def cost_fn(x): + with qml.tape.QuantumTape() as tape1: + qml.Hadamard(0) + qml.RY(x[0], wires=[0]) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliZ(0)) + + with qml.tape.QuantumTape() as tape2: + qml.Hadamard(0) + qml.CRX(2 * x[0] * x[1], wires=[0, 1]) + qml.RX(2 * x[1], wires=[1]) + qml.expval(qml.PauliZ(0)) + + result = execute( + tapes=[tape1, tape2], device=dev, interface="jax-jit", **execute_kwargs + ) + return result[0] + result[1] - 7 * result[1] + + res = jax.jit(jax.grad(cost_fn))(params) + assert res.shape == (2,) + + def test_multiple_tapes_output(self, execute_kwargs, tol): + """Test the output types for the execution of multiple quantum tapes""" + dev = qml.device("default.qubit", wires=2) + params = jax.numpy.array([0.3, 0.2]) + + def cost_fn(x): + with qml.tape.QuantumTape() as tape1: + qml.Hadamard(0) + qml.RY(x[0], wires=[0]) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliZ(0)) + + with qml.tape.QuantumTape() as tape2: + qml.Hadamard(0) + qml.CRX(2 * x[0] * x[1], wires=[0, 1]) + qml.RX(2 * x[1], wires=[1]) + qml.expval(qml.PauliZ(0)) + + return execute(tapes=[tape1, tape2], device=dev, interface="jax-jit", **execute_kwargs) + + res = jax.jit(cost_fn)(params) + assert isinstance(res, list) + assert all(isinstance(r, jnp.ndarray) for r in res) + assert all(r.shape == () for r in res) + + def test_matrix_parameter(self, execute_kwargs, tol): + """Test that the jax interface works correctly + with a matrix parameter""" + a = jnp.array(0.1) + U = jnp.array([[0, 1], [1, 0]]) + + def cost(a, U, device): + with qml.tape.QuantumTape() as tape: + qml.QubitUnitary(U, wires=0) + qml.RY(a, wires=0) + qml.expval(qml.PauliZ(0)) + + tape.trainable_params = [0] + return execute([tape], device, interface="jax-jit", **execute_kwargs)[0] + + dev = qml.device("default.qubit", wires=2) + res = jax.jit(cost, static_argnums=2)(a, U, device=dev) + assert np.allclose(res, -np.cos(a), atol=tol, rtol=0) + + jac_fn = jax.grad(cost, argnums=(0)) + res = jac_fn(a, U, device=dev) + assert np.allclose(res, np.sin(a), atol=tol, rtol=0) + + def test_differentiable_expand(self, execute_kwargs, tol): + """Test that operation and nested tapes expansion + is differentiable""" + + class U3(qml.U3): + def expand(self): + tape = qml.tape.QuantumTape() + theta, phi, lam = self.data + wires = self.wires + tape._ops += [ + qml.Rot(lam, theta, -lam, wires=wires), + qml.PhaseShift(phi + lam, wires=wires), + ] + return tape + + def cost_fn(a, p, device): + tape = qml.tape.QuantumTape() + + with tape: + qml.RX(a, wires=0) + U3(*p, wires=0) + qml.expval(qml.PauliX(0)) + + tape = tape.expand(stop_at=lambda obj: device.supports_operation(obj.name)) + return execute([tape], device, interface="jax-jit", **execute_kwargs)[0] + + a = jnp.array(0.1) + p = jnp.array([0.1, 0.2, 0.3]) + + dev = qml.device("default.qubit", wires=1) + res = jax.jit(cost_fn, static_argnums=2)(a, p, device=dev) + expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * ( + np.cos(p[2]) * np.sin(p[1]) + np.cos(p[0]) * np.cos(p[1]) * np.sin(p[2]) + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + jac_fn = jax.jit(jax.grad(cost_fn, argnums=(1)), static_argnums=2) + res = jac_fn(a, p, device=dev) + expected = jnp.array( + [ + np.cos(p[1]) * (np.cos(a) * np.cos(p[0]) - np.sin(a) * np.sin(p[0]) * np.sin(p[2])), + np.cos(p[1]) * np.cos(p[2]) * np.sin(a) + - np.sin(p[1]) + * (np.cos(a) * np.sin(p[0]) + np.cos(p[0]) * np.sin(a) * np.sin(p[2])), + np.sin(a) + * (np.cos(p[0]) * np.cos(p[1]) * np.cos(p[2]) - np.sin(p[1]) * np.sin(p[2])), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_independent_expval(self, execute_kwargs): + """Tests computing an expectation value that is independent of trainable + parameters.""" + # TODO + if execute_kwargs.get("mode", None) == "forward": + pytest.skip("TODO") + + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(1)) + + res = execute([tape], dev, cache=cache, interface="jax-jit", **execute_kwargs) + return res[0] + + res = jax.jit(jax.grad(cost), static_argnums=1)(params, cache=None) + assert res.shape == (3,) + + +@pytest.mark.parametrize("execute_kwargs", execute_kwargs) +class TestVectorValuedJIT: + """Test vector-valued returns for the JAX-JIT interface.""" + + @pytest.mark.parametrize( + "ret_type, shape, expected_type", + [ + ([qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))], (), tuple), + ([qml.probs(wires=[0, 1])], (4,), jnp.ndarray), + ], + ) + def test_shapes(self, execute_kwargs, ret_type, shape, expected_type): + """Test the shape of the result of vector-valued QNodes.""" + adjoint = execute_kwargs.get("gradient_kwargs", {}).get("method", "") == "adjoint_jacobian" + if adjoint: + pytest.skip("The adjoint diff method doesn't support probabilities.") + + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + + idx = 0 + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + for r in ret_type: + qml.apply(r) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + ) + return res[0] + + res = jax.jit(cost)(params, cache=None) + assert isinstance(res, expected_type) + + if expected_type is tuple: + for r in res: + assert r.shape == shape + else: + assert res.shape == shape + + def test_independent_expval(self, execute_kwargs): + """Tests computing an expectation value that is independent of trainable + parameters.""" + # TODO + if execute_kwargs.get("mode", None) == "forward": + pytest.skip("TODO") + + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(1)) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + ) + return res[0] + + res = jax.jit(jax.grad(cost), static_argnums=1)(params, cache=None) + assert res.shape == (3,) + + ret_and_output_dim = [ + ([qml.probs(wires=0)], (2,), jnp.ndarray), + ([qml.state()], (4,), jnp.ndarray), + ([qml.density_matrix(wires=0)], (2, 2), jnp.ndarray), + # Multi measurements + ([qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))], (), tuple), + ([qml.var(qml.PauliZ(0)), qml.var(qml.PauliZ(1))], (), tuple), + ([qml.probs(wires=0), qml.probs(wires=1)], (2,), tuple), + ] + + @pytest.mark.parametrize("ret, out_dim, expected_type", ret_and_output_dim) + def test_vector_valued_qnode(self, execute_kwargs, ret, out_dim, expected_type): + """Tests the shape of vector-valued QNode results.""" + + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + if execute_kwargs.get("mode", None) == "forward": + pytest.skip("TODO") + + grad_meth = ( + execute_kwargs["gradient_kwargs"]["method"] + if "gradient_kwargs" in execute_kwargs + else "" + ) + if "adjoint" in grad_meth and any( + r.return_type + in (qml.measurements.Probability, qml.measurements.State, qml.measurements.Variance) + for r in ret + ): + pytest.skip("Adjoint does not support probs") + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + + for r in ret: + qml.apply(r) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + )[0] + return res + + res = jax.jit(cost, static_argnums=1)(params, cache=None) + + assert isinstance(res, expected_type) + if expected_type is tuple: + for r in res: + assert r.shape == out_dim + else: + assert res.shape == out_dim + + def test_qnode_sample(self, execute_kwargs): + """Tests computing multiple expectation values in a tape.""" + dev = qml.device("default.qubit", wires=2, shots=10) + params = jnp.array([0.1, 0.2, 0.3]) + + grad_meth = ( + execute_kwargs["gradient_kwargs"]["method"] + if "gradient_kwargs" in execute_kwargs + else "" + ) + if "adjoint" in grad_meth or "backprop" in grad_meth: + pytest.skip("Adjoint does not support probs") + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.sample(qml.PauliZ(0)) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + )[0] + return res + + res = jax.jit(cost, static_argnums=1)(params, cache=None) + assert res.shape == (dev.shots,) + + def test_multiple_expvals_grad(self, execute_kwargs): + """Tests computing multiple expectation values in a tape.""" + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" + if fwd_mode: + pytest.skip("The forward mode is tested separately as it should raise an error.") + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(0)) + qml.expval(qml.PauliZ(1)) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + )[0] + return res[0] + res[1] + + res = jax.jit(jax.grad(cost), static_argnums=1)(params, cache=None) + assert res.shape == (3,) + + def test_multi_tape_jacobian_probs_expvals(self, execute_kwargs): + """Test the jacobian computation with multiple tapes with probability + and expectation value computations.""" + fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" + if fwd_mode: + pytest.skip("The forward mode is tested separately as it should raise an error.") + + adjoint = execute_kwargs.get("gradient_kwargs", {}).get("method", "") == "adjoint_jacobian" + if adjoint: + pytest.skip("The adjoint diff method doesn't support probabilities.") + + def cost(x, y, device, interface, ek): + with qml.tape.QuantumTape() as tape1: + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliZ(0)) + qml.expval(qml.PauliZ(1)) + + with qml.tape.QuantumTape() as tape2: + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + qml.probs(wires=[0]) + qml.probs(wires=[1]) + + return qml.execute([tape1, tape2], device, **ek, interface=interface)[0] + + dev = qml.device("default.qubit", wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + x_ = np.array(0.543) + y_ = np.array(-0.654) + + res = cost(x, y, dev, interface="jax-jit", ek=execute_kwargs) + + exp = cost(x_, y_, dev, interface="autograd", ek=execute_kwargs) + + for r, e in zip(res, exp): + assert jnp.allclose(r, e, atol=1e-7) + + # TODO: update when fwd mode is implemented + def test_multiple_expvals_raises_fwd_device_grad(self, execute_kwargs): + """Tests computing multiple expectation values in a tape.""" + execute_kwargs = { + "gradient_fn": "device", + "mode": "forward", + "gradient_kwargs": {"method": "adjoint_jacobian", "use_device_state": True}, + } + # fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" + # if not fwd_mode: + # pytest.skip("Forward mode is not turned on.") + + dev = qml.device("default.qubit", wires=2) + params = jnp.array([0.1, 0.2, 0.3]) + + def cost(a, cache): + with qml.tape.QuantumTape() as tape: + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + qml.RY(a[2], wires=0) + qml.expval(qml.PauliZ(0)) + qml.expval(qml.PauliZ(1)) + + res = qml.interfaces.execute( + [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs + ) + return res[0] + + with pytest.raises(NotImplementedError): + # with pytest.raises(InterfaceUnsupportedError): + jax.jacobian(cost)(params, cache=None) + + def test_assertion_error_fwd(self, execute_kwargs): + """Test that an assertion is raised if by chance there is a difference + in the number of tapes and the number of parameters sequences passed + to _execute_with_fwd.""" + a = 0.3 + b = 0.3 + + with qml.tape.QuantumTape() as tape: + qml.RY(a, wires=0) + qml.RY(b, wires=0) + qml.expval(qml.PauliZ(0)) + + device = qml.device("default.qubit", wires=2) + + # Create arguments for 2 tapes + params = [[0.2], [0.3]] + + # But pass only 1 tape + tapes = [tape] + + with pytest.raises(AssertionError): + _execute_with_fwd( + params, + tapes=tapes, + device=device, + execute_fn=lambda a: a, # Some dummy function + gradient_kwargs=None, + _n=1, + ) + + +def test_diff_method_None_jit(): + """Test that jitted execution works when `gradient_fn=None`.""" + + dev = qml.device("default.qubit.jax", wires=1, shots=10) + + @jax.jit + def wrapper(x): + with qml.tape.QuantumTape() as tape: + qml.RX(x, wires=0) + qml.expval(qml.PauliZ(0)) + + return qml.execute([tape], dev, gradient_fn=None) + + assert jnp.allclose(wrapper(jnp.array(0.0))[0], 1.0) diff --git a/tests/returntypes/test_jax_jit_qnode_new.py b/tests/returntypes/test_jax_jit_qnode_new.py new file mode 100644 index 00000000000..2e0352d90d0 --- /dev/null +++ b/tests/returntypes/test_jax_jit_qnode_new.py @@ -0,0 +1,2003 @@ +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for using the JAX-JIT interface with a QNode""" + +import pytest + +import pennylane as qml +from pennylane import numpy as np +from pennylane import qnode +from pennylane.tape import QuantumTape + +qubit_device_and_diff_method = [ + ["default.qubit", "backprop", "forward", "jax"], + # Jit + ["default.qubit", "finite-diff", "backward", "jax-jit"], + ["default.qubit", "parameter-shift", "backward", "jax-jit"], + # TODO: + # ["default.qubit", "adjoint", "forward", "jax-jit"], + ["default.qubit", "adjoint", "backward", "jax-jit"], +] + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") +jnp = jax.numpy + +from jax.config import config + +config.update("jax_enable_x64", True) + + +@pytest.mark.parametrize("dev_name,diff_method,mode,interface", qubit_device_and_diff_method) +class TestQNode: + """Test that using the QNode with JAX integrates with the PennyLane + stack""" + + def test_execution_with_interface(self, dev_name, diff_method, mode, interface): + """Test execution works with the interface""" + if diff_method == "backprop": + pytest.skip("Test does not support backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a, wires=0) + qml.RX(0.2, wires=0) + return qml.expval(qml.PauliZ(0)) + + a = np.array(0.1, requires_grad=True) + jax.jit(circuit)(a) + + assert circuit.interface == interface + + # the tape is able to deduce trainable parameters + assert circuit.qtape.trainable_params == [0] + + # gradients should work + grad = jax.jit(jax.grad(circuit))(a) + assert isinstance(grad, jnp.DeviceArray) + assert grad.shape == () + + def test_changing_trainability(self, dev_name, diff_method, mode, interface, mocker, tol): + """Test changing the trainability of parameters changes the + number of differentiation requests made""" + if diff_method != "parameter-shift": + pytest.skip("Test only supports parameter-shift") + + a = jnp.array(0.1) + b = jnp.array(0.2) + + dev = qml.device("default.qubit", wires=2) + + @qnode(dev, interface=interface, diff_method="parameter-shift") + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Hamiltonian([1, 1], [qml.PauliZ(0), qml.PauliY(1)])) + + grad_fn = jax.grad(circuit, argnums=[0, 1]) + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + res = grad_fn(a, b) + + # the tape has reported both arguments as trainable + assert circuit.qtape.trainable_params == [0, 1] + + expected = [-np.sin(a) + np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)] + assert np.allclose(res, expected, atol=tol, rtol=0) + + # The parameter-shift rule has been called for each argument + assert len(spy.spy_return[0]) == 4 + + # make the second QNode argument a constant + grad_fn = jax.grad(circuit, argnums=0) + res = grad_fn(a, b) + + # the tape has reported only the first argument as trainable + assert circuit.qtape.trainable_params == [0] + + expected = [-np.sin(a) + np.sin(a) * np.sin(b)] + assert np.allclose(res, expected, atol=tol, rtol=0) + + # The parameter-shift rule has been called only once + assert len(spy.spy_return[0]) == 2 + + # trainability also updates on evaluation + a = np.array(0.54, requires_grad=False) + b = np.array(0.8, requires_grad=True) + circuit(a, b) + assert circuit.qtape.trainable_params == [1] + + def test_classical_processing(self, dev_name, diff_method, mode, interface, tol): + """Test classical processing within the quantum tape""" + qml.disable_return() + a = jnp.array(0.1) + b = jnp.array(0.2) + c = jnp.array(0.3) + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(a, b, c): + qml.RY(a * c, wires=0) + qml.RZ(b, wires=0) + qml.RX(c + c**2 + jnp.sin(a), wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.grad(circuit, argnums=[0, 2])(a, b, c) + + if diff_method == "finite-diff": + assert circuit.qtape.trainable_params == [0, 2] + + assert len(res) == 2 + + def test_matrix_parameter(self, dev_name, diff_method, mode, interface, tol): + """Test that the jax interface works correctly + with a matrix parameter""" + U = jnp.array([[0, 1], [1, 0]]) + a = jnp.array(0.1) + + dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(U, a): + qml.QubitUnitary(U, wires=0) + qml.RY(a, wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.grad(circuit, argnums=1)(U, a) + assert np.allclose(res, np.sin(a), atol=tol, rtol=0) + + if diff_method == "finite-diff": + assert circuit.qtape.trainable_params == [1] + + def test_differentiable_expand(self, dev_name, diff_method, mode, interface, tol): + """Test that operation and nested tape expansion + is differentiable""" + + class U3(qml.U3): + def expand(self): + theta, phi, lam = self.data + wires = self.wires + + with QuantumTape() as tape: + qml.Rot(lam, theta, -lam, wires=wires) + qml.PhaseShift(phi + lam, wires=wires) + + return tape + + dev = qml.device(dev_name, wires=1) + a = jnp.array(0.1) + p = jnp.array([0.1, 0.2, 0.3]) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(a, p): + qml.RX(a, wires=0) + U3(p[0], p[1], p[2], wires=0) + return qml.expval(qml.PauliX(0)) + + res = jax.jit(circuit)(a, p) + expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * ( + np.cos(p[2]) * np.sin(p[1]) + np.cos(p[0]) * np.cos(p[1]) * np.sin(p[2]) + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.jit(jax.grad(circuit, argnums=1))(a, p) + expected = np.array( + [ + np.cos(p[1]) * (np.cos(a) * np.cos(p[0]) - np.sin(a) * np.sin(p[0]) * np.sin(p[2])), + np.cos(p[1]) * np.cos(p[2]) * np.sin(a) + - np.sin(p[1]) + * (np.cos(a) * np.sin(p[0]) + np.cos(p[0]) * np.sin(a) * np.sin(p[2])), + np.sin(a) + * (np.cos(p[0]) * np.cos(p[1]) * np.cos(p[2]) - np.sin(p[1]) * np.sin(p[2])), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_jacobian_options(self, dev_name, diff_method, mode, interface, mocker, tol): + """Test setting jacobian options""" + if diff_method != "finite-diff": + pytest.skip("Test only applies to finite diff.") + + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = np.array([0.1, 0.2], requires_grad=True) + + dev = qml.device("default.qubit", wires=1) + + @qnode(dev, interface=interface, diff_method="finite-diff", h=1e-8, approx_order=2) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + if diff_method in {"finite-diff", "parameter-shift"} and interface == "jax-jit": + # No jax.jacobian support for call + pytest.xfail(reason="batching rules are implemented only for id_tap, not for call.") + + jax.jit(jax.jacobian(circuit))(a) + + for args in spy.call_args_list: + assert args[1]["approx_order"] == 2 + assert args[1]["h"] == 1e-8 + + +vv_qubit_device_and_diff_method = [ + ["default.qubit", "backprop", "forward", "jax"], + # Jit + ["default.qubit", "finite-diff", "backward", "jax-jit"], + ["default.qubit", "parameter-shift", "backward", "jax-jit"], + # TODO: + # ["default.qubit", "adjoint", "forward", "jax-jit"], + ["default.qubit", "adjoint", "backward", "jax-jit"], +] + + +@pytest.mark.parametrize("dev_name,diff_method,mode,interface", vv_qubit_device_and_diff_method) +class TestVectorValuedQNode: + """Test that using vector-valued QNodes with JAX integrate with the + PennyLane stack""" + + def test_diff_expval_expval(self, dev_name, diff_method, mode, interface, mocker, tol): + """Test jacobian calculation""" + if diff_method == "parameter-shift": + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + elif diff_method == "finite-diff": + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = np.array(0.1, requires_grad=True) + b = np.array(0.2, requires_grad=True) + + dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliY(1)) + + res = jax.jit(circuit)(a, b) + + assert circuit.qtape.trainable_params == [0, 1] + assert isinstance(res, tuple) + assert len(res) == 2 + + expected = [np.cos(a), -np.cos(a) * np.sin(b)] + assert np.allclose(res[0], expected[0], atol=tol, rtol=0) + assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + + res = jax.jit(jax.jacobian(circuit, argnums=[0, 1]))(a, b) + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]) + assert isinstance(res, tuple) + assert len(res) == 2 + + assert isinstance(res[0], tuple) + assert isinstance(res[0][0], jax.numpy.ndarray) + assert res[0][0].shape == () + assert np.allclose(res[0][0], expected[0][0], atol=tol, rtol=0) + assert isinstance(res[0][1], jax.numpy.ndarray) + assert res[0][1].shape == () + assert np.allclose(res[0][1], expected[0][1], atol=tol, rtol=0) + + assert isinstance(res[1], tuple) + assert isinstance(res[1][0], jax.numpy.ndarray) + assert res[1][0].shape == () + assert np.allclose(res[1][0], expected[1][0], atol=tol, rtol=0) + assert isinstance(res[1][1], jax.numpy.ndarray) + assert res[1][1].shape == () + assert np.allclose(res[1][1], expected[1][1], atol=tol, rtol=0) + + if diff_method in ("parameter-shift", "finite-diff"): + spy.assert_called() + + def test_jacobian_no_evaluate(self, dev_name, diff_method, mode, interface, mocker, tol): + """Test jacobian calculation when no prior circuit evaluation has been performed""" + + if diff_method == "parameter-shift": + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + elif diff_method == "finite-diff": + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = jax.numpy.array(0.1) + b = jax.numpy.array(0.2) + + dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliY(1)) + + if diff_method == "adjoint" and mode == "backward": + + # TODO: jit here too when the following issue is resolved: + # https://github.com/PennyLaneAI/pennylane/issues/3475 + jac_fn = jax.jacobian(circuit, argnums=[0, 1]) + else: + jac_fn = jax.jit(jax.jacobian(circuit, argnums=[0, 1])) + + res = jac_fn(a, b) + + assert isinstance(res, tuple) + assert len(res) == 2 + + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]) + + assert isinstance(res[0][0], jax.numpy.ndarray) + assert res[0][0].shape == () + assert np.allclose(res[0][0], expected[0][0], atol=tol, rtol=0) + + assert isinstance(res[0][1], jax.numpy.ndarray) + assert res[0][1].shape == () + assert np.allclose(res[0][1], expected[0][1], atol=tol, rtol=0) + + assert isinstance(res[1][0], jax.numpy.ndarray) + assert res[1][0].shape == () + assert np.allclose(res[1][0], expected[1][0], atol=tol, rtol=0) + + assert isinstance(res[1][1], jax.numpy.ndarray) + assert res[1][1].shape == () + assert np.allclose(res[1][1], expected[1][1], atol=tol, rtol=0) + + if diff_method in ("parameter-shift", "finite-diff"): + spy.assert_called() + + # call the Jacobian with new parameters + a = jax.numpy.array(0.6) + b = jax.numpy.array(0.832) + + res = jac_fn(a, b) + + assert isinstance(res, tuple) + assert len(res) == 2 + + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]) + + assert isinstance(res[0][0], jax.numpy.ndarray) + assert res[0][0].shape == () + assert np.allclose(res[0][0], expected[0][0], atol=tol, rtol=0) + + assert isinstance(res[0][1], jax.numpy.ndarray) + assert res[0][1].shape == () + assert np.allclose(res[0][1], expected[0][1], atol=tol, rtol=0) + + assert isinstance(res[1][0], jax.numpy.ndarray) + assert res[1][0].shape == () + assert np.allclose(res[1][0], expected[1][0], atol=tol, rtol=0) + + assert isinstance(res[1][1], jax.numpy.ndarray) + assert res[1][1].shape == () + assert np.allclose(res[1][1], expected[1][1], atol=tol, rtol=0) + + def test_diff_single_probs(self, dev_name, diff_method, mode, interface, tol): + """Tests correct output shape and evaluation for a tape + with a single prob output""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support probs") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.probs(wires=[1]) + + res = jax.jit(jax.jacobian(circuit, argnums=[0, 1]))(x, y) + + expected = np.array( + [ + [-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2], + [np.cos(y) * np.sin(x) / 2, np.cos(x) * np.sin(y) / 2], + ] + ) + + assert isinstance(res, tuple) + assert len(res) == 2 + + assert isinstance(res[0], jax.numpy.ndarray) + assert res[0].shape == (2,) + + assert isinstance(res[1], jax.numpy.ndarray) + assert res[1].shape == (2,) + + assert np.allclose(res[0], expected.T[0], atol=tol, rtol=0) + assert np.allclose(res[1], expected.T[1], atol=tol, rtol=0) + + def test_diff_multi_probs(self, dev_name, diff_method, mode, interface, tol): + """Tests correct output shape and evaluation for a tape + with multiple prob outputs""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support probs") + + dev = qml.device(dev_name, wires=3) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.probs(wires=[0]), qml.probs(wires=[1, 2]) + + res = circuit(x, y) + + assert isinstance(res, tuple) + assert len(res) == 2 + + expected = np.array( + [ + [np.cos(x / 2) ** 2, np.sin(x / 2) ** 2], + [(1 + np.cos(x) * np.cos(y)) / 2, 0, (1 - np.cos(x) * np.cos(y)) / 2, 0], + ] + ) + + assert isinstance(res[0], jax.numpy.ndarray) + assert res[0].shape == (2,) + assert np.allclose(res[0], expected[0], atol=tol, rtol=0) + + assert isinstance(res[1], jax.numpy.ndarray) + assert res[1].shape == (4,) + assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + + jac = jax.jit(jax.jacobian(circuit, argnums=[0, 1]))(x, y) + expected_0 = np.array( + [ + [-np.sin(x) / 2, np.sin(x) / 2], + [0, 0], + ] + ) + + expected_1 = np.array( + [ + [-np.cos(y) * np.sin(x) / 2, 0, np.sin(x) * np.cos(y) / 2, 0], + [-np.cos(x) * np.sin(y) / 2, 0, np.cos(x) * np.sin(y) / 2, 0], + ] + ) + + assert isinstance(jac, tuple) + assert isinstance(jac[0], tuple) + + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == (2,) + assert np.allclose(jac[0][0], expected_0[0], atol=tol, rtol=0) + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == (2,) + assert np.allclose(jac[0][1], expected_0[1], atol=tol, rtol=0) + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == (4,) + + assert np.allclose(jac[1][0], expected_1[0], atol=tol, rtol=0) + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == (4,) + assert np.allclose(jac[1][1], expected_1[1], atol=tol, rtol=0) + + def test_diff_expval_probs(self, dev_name, diff_method, mode, interface, tol): + """Tests correct output shape and evaluation for a tape + with prob and expval outputs""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support probs") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)), qml.probs(wires=[1]) + + res = jax.jit(circuit)(x, y) + expected = np.array( + [np.cos(x), [(1 + np.cos(x) * np.cos(y)) / 2, (1 - np.cos(x) * np.cos(y)) / 2]] + ) + assert isinstance(res, tuple) + assert len(res) == 2 + + assert isinstance(res[0], jax.numpy.ndarray) + assert res[0].shape == () + assert np.allclose(res[0], expected[0], atol=tol, rtol=0) + + assert isinstance(res[1], jax.numpy.ndarray) + assert res[1].shape == (2,) + assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + + jac = jax.jit(jax.jacobian(circuit, argnums=[0, 1]))(x, y) + expected = [ + [-np.sin(x), 0], + [ + [-np.sin(x) * np.cos(y) / 2, np.cos(y) * np.sin(x) / 2], + [-np.cos(x) * np.sin(y) / 2, np.cos(x) * np.sin(y) / 2], + ], + ] + + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert np.allclose(jac[0][0], expected[0][0], atol=tol, rtol=0) + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == () + assert np.allclose(jac[0][1], expected[0][1], atol=tol, rtol=0) + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == (2,) + assert np.allclose(jac[1][0], expected[1][0], atol=tol, rtol=0) + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == (2,) + assert np.allclose(jac[1][1], expected[1][1], atol=tol, rtol=0) + + def test_diff_expval_probs_sub_argnums(self, dev_name, diff_method, mode, interface, tol): + """Tests correct output shape and evaluation for a tape with prob and expval outputs with less + trainable parameters (argnums) than parameters.""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support probs") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)), qml.probs(wires=[1]) + + jac = jax.jit(jax.jacobian(circuit, argnums=[0]))(x, y) + + expected = [ + [-np.sin(x), 0], + [ + [-np.sin(x) * np.cos(y) / 2, np.cos(y) * np.sin(x) / 2], + [-np.cos(x) * np.sin(y) / 2, np.cos(x) * np.sin(y) / 2], + ], + ] + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 1 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert np.allclose(jac[0][0], expected[0][0], atol=tol, rtol=0) + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 1 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == (2,) + assert np.allclose(jac[1][0], expected[1][0], atol=tol, rtol=0) + + def test_diff_var_probs(self, dev_name, diff_method, mode, interface, tol): + """Tests correct output shape and evaluation for a tape + with prob and variance outputs""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support probs") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.var(qml.PauliZ(0)), qml.probs(wires=[1]) + + res = jax.jit(circuit)(x, y) + + expected = np.array( + [np.sin(x) ** 2, [(1 + np.cos(x) * np.cos(y)) / 2, (1 - np.cos(x) * np.cos(y)) / 2]] + ) + + assert isinstance(res[0], jax.numpy.ndarray) + assert res[0].shape == () + assert np.allclose(res[0], expected[0], atol=tol, rtol=0) + + assert isinstance(res[1], jax.numpy.ndarray) + assert res[1].shape == (2,) + assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + + jac = jax.jit(jax.jacobian(circuit, argnums=[0, 1]))(x, y) + expected = [ + [2 * np.cos(x) * np.sin(x), 0], + [ + [-np.sin(x) * np.cos(y) / 2, np.cos(y) * np.sin(x) / 2], + [-np.cos(x) * np.sin(y) / 2, np.cos(x) * np.sin(y) / 2], + ], + ] + + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert np.allclose(jac[0][0], expected[0][0], atol=tol, rtol=0) + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == () + assert np.allclose(jac[0][1], expected[0][1], atol=tol, rtol=0) + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == (2,) + assert np.allclose(jac[1][0], expected[1][0], atol=tol, rtol=0) + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == (2,) + assert np.allclose(jac[1][1], expected[1][1], atol=tol, rtol=0) + + +@pytest.mark.parametrize("interface", ["jax", "jax-jit"]) +class TestShotsIntegration: + """Test that the QNode correctly changes shot value, and + remains differentiable.""" + + def test_diff_method_None(self, interface): + """Test jax device works with diff_method=None.""" + dev = qml.device("default.qubit.jax", wires=1, shots=10) + + @jax.jit + @qml.qnode(dev, diff_method=None, interface=interface) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + assert jnp.allclose(circuit(jnp.array(0.0)), 1) + + def test_changing_shots(self, interface, mocker, tol): + """Test that changing shots works on execution""" + dev = qml.device("default.qubit", wires=2, shots=None) + a, b = jnp.array([0.543, -0.654]) + + @qnode(dev, diff_method=qml.gradients.param_shift, interface=interface) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + spy = mocker.spy(dev, "sample") + + # execute with device default shots (None) + res = circuit(a, b) + assert np.allclose(res, -np.cos(a) * np.sin(b), atol=tol, rtol=0) + spy.assert_not_called() + + # execute with shots=100 + res = circuit(a, b, shots=100) + spy.assert_called_once() + assert spy.spy_return.shape == (100,) + + # device state has been unaffected + assert dev.shots is None + res = circuit(a, b) + assert np.allclose(res, -np.cos(a) * np.sin(b), atol=tol, rtol=0) + spy.assert_called_once() # no additional calls + + def test_gradient_integration(self, interface, tol, mocker): + """Test that temporarily setting the shots works + for gradient computations""" + dev = qml.device("default.qubit", wires=2, shots=1) + a, b = jnp.array([0.543, -0.654]) + + spy = mocker.spy(dev, "batch_execute") + + @qnode(dev, diff_method=qml.gradients.param_shift, interface=interface) + def cost_fn(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + # TODO: jit when https://github.com/PennyLaneAI/pennylane/issues/3474 is resolved + res = jax.grad(cost_fn, argnums=[0, 1])(a, b, shots=30000) + assert dev.shots == 1 + + expected = [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)] + assert np.allclose(res, expected, atol=0.1, rtol=0) + + def test_update_diff_method(self, mocker, interface, tol): + """Test that temporarily setting the shots updates the diff method""" + dev = qml.device("default.qubit", wires=2, shots=100) + a, b = jnp.array([0.543, -0.654]) + + spy = mocker.spy(qml, "execute") + + # We're choosing interface="jax" such that backprop can be used in the + # test later + @qnode(dev, interface="jax") + def cost_fn(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + # since we are using finite shots, parameter-shift will + # be chosen + assert cost_fn.gradient_fn is qml.gradients.param_shift + + cost_fn(a, b) + assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift + + # if we set the shots to None, backprop can now be used + cost_fn(a, b, shots=None) + assert spy.call_args[1]["gradient_fn"] == "backprop" + + # original QNode settings are unaffected + assert cost_fn.gradient_fn is qml.gradients.param_shift + cost_fn(a, b) + assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift + + +@pytest.mark.parametrize("dev_name,diff_method,mode,interface", qubit_device_and_diff_method) +class TestQubitIntegration: + """Tests that ensure various qubit circuits integrate correctly""" + + def test_sampling(self, dev_name, diff_method, mode, interface): + """Test sampling works as expected""" + if mode == "forward": + pytest.skip("Sampling not possible with forward mode differentiation.") + + if diff_method == "adjoint": + pytest.skip("Adjoint warns with finite shots") + + dev = qml.device(dev_name, wires=2, shots=10) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(): + qml.Hadamard(wires=[0]) + qml.CNOT(wires=[0, 1]) + return qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliX(1)) + + res = jax.jit(circuit)() + + assert isinstance(res, tuple) + + assert isinstance(res[0], jnp.DeviceArray) + assert res[0].shape == (10,) + assert isinstance(res[1], jnp.DeviceArray) + assert res[1].shape == (10,) + + def test_counts(self, dev_name, diff_method, mode, interface): + """Test counts works as expected""" + if mode == "forward": + pytest.skip("Sampling not possible with forward mode differentiation.") + + if diff_method == "adjoint": + pytest.skip("Adjoint warns with finite shots") + + dev = qml.device(dev_name, wires=2, shots=10) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(): + qml.Hadamard(wires=[0]) + qml.CNOT(wires=[0, 1]) + return qml.counts(qml.PauliZ(0)), qml.counts(qml.PauliX(1)) + + if interface == "jax-jit": + with pytest.raises( + NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts." + ): + jax.jit(circuit)() + else: + res = jax.jit(circuit)() + + assert isinstance(res, tuple) + + assert isinstance(res[0], dict) + assert len(res[0]) == 2 + assert isinstance(res[1], dict) + assert len(res[1]) == 2 + + def test_chained_qnodes(self, dev_name, diff_method, mode, interface): + """Test that the gradient of chained QNodes works without error""" + dev = qml.device(dev_name, wires=2) + + class Template(qml.templates.StronglyEntanglingLayers): + def expand(self): + with qml.tape.QuantumTape() as tape: + qml.templates.StronglyEntanglingLayers(*self.parameters, self.wires) + return tape + + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) + def circuit1(weights): + Template(weights, wires=[0, 1]) + return qml.expval(qml.PauliZ(0)) + + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) + def circuit2(data, weights): + qml.templates.AngleEmbedding(jnp.stack([data, 0.7]), wires=[0, 1]) + Template(weights, wires=[0, 1]) + return qml.expval(qml.PauliX(0)) + + def cost(weights): + w1, w2 = weights + c1 = circuit1(w1) + c2 = circuit2(c1, w2) + return jnp.sum(c2) ** 2 + + w1 = qml.templates.StronglyEntanglingLayers.shape(n_wires=2, n_layers=3) + w2 = qml.templates.StronglyEntanglingLayers.shape(n_wires=2, n_layers=4) + + weights = [ + jnp.array(np.random.random(w1)), + jnp.array(np.random.random(w2)), + ] + + grad_fn = jax.jit(jax.grad(cost)) + res = grad_fn(weights) + + assert len(res) == 2 + + +hessian_qubit_device_and_diff_method = [ + ["default.qubit", "backprop", "forward", "jax"], + # TODO: + # Jit + # ["default.qubit", "finite-diff", "backward", "jax-jit"], + # ["default.qubit", "parameter-shift", "backward", "jax-jit"], + # ["default.qubit", "adjoint", "forward", "jax-jit"], + # ["default.qubit", "adjoint", "backward", "jax-jit"], +] + + +@pytest.mark.parametrize( + "dev_name,diff_method,mode,interface", hessian_qubit_device_and_diff_method +) +class TestQubitIntegrationHigherOrder: + """Tests that ensure various qubit circuits integrate correctly when computing higher-order derivatives""" + + def test_second_derivative(self, dev_name, diff_method, mode, interface, tol): + """Test second derivative calculation of a scalar-valued QNode""" + + if diff_method == "adjoint": + pytest.skip("Adjoint does not second derivative.") + + if interface == "jax-jit": + pytest.skip("JAX-JIT doesn't yet support Hessians.") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + g = jax.grad(circuit)(x) + g2 = jax.grad(lambda x: jnp.sum(jax.grad(circuit)(x)))(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + expected_g2 = [ + -np.cos(a) * np.cos(b) + np.sin(a) * np.sin(b), + np.sin(a) * np.sin(b) - np.cos(a) * np.cos(b), + ] + if diff_method == "finite-diff": + assert np.allclose(g2, expected_g2, atol=10e-2, rtol=0) + else: + assert np.allclose(g2, expected_g2, atol=tol, rtol=0) + + def test_hessian(self, dev_name, diff_method, mode, interface, tol): + """Test hessian calculation of a scalar-valued QNode""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support second derivative.") + + if interface == "jax-jit": + pytest.skip("JAX-JIT doesn't yet support Hessians.") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + grad_fn = jax.grad(circuit) + g = grad_fn(x) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + hess = jax.jacobian(grad_fn)(x) + + expected_hess = [ + [-np.cos(a) * np.cos(b), np.sin(a) * np.sin(b)], + [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)], + ] + if diff_method == "finite-diff": + assert np.allclose(hess, expected_hess, atol=10e-2, rtol=0) + else: + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued(self, dev_name, diff_method, mode, interface, tol): + """Test hessian calculation of a vector-valued QNode""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support second derivative.") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.probs(wires=0) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + + a, b = x + + expected_res = [0.5 + 0.5 * np.cos(a) * np.cos(b), 0.5 - 0.5 * np.cos(a) * np.cos(b)] + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + jac_fn = jax.jacobian(circuit) + g = jac_fn(x) + + expected_g = [ + [-0.5 * np.sin(a) * np.cos(b), -0.5 * np.cos(a) * np.sin(b)], + [0.5 * np.sin(a) * np.cos(b), 0.5 * np.cos(a) * np.sin(b)], + ] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + hess = jax.jacobian(jac_fn)(x) + + expected_hess = [ + [ + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.sin(a) * np.sin(b)], + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.cos(a) * np.cos(b)], + ], + [ + [0.5 * np.cos(a) * np.cos(b), -0.5 * np.sin(a) * np.sin(b)], + [-0.5 * np.sin(a) * np.sin(b), 0.5 * np.cos(a) * np.cos(b)], + ], + ] + if diff_method == "finite-diff": + assert np.allclose(hess, expected_hess, atol=10e-2, rtol=0) + else: + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued_postprocessing( + self, dev_name, diff_method, interface, mode, tol + ): + """Test hessian calculation of a vector valued QNode with post-processing""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support second derivative.") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(x): + qml.RX(x[0], wires=0) + qml.RY(x[1], wires=0) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(0)) + + def cost_fn(x): + return x @ jax.numpy.array(circuit(x)) + + x = jnp.array( + [0.76, -0.87], + ) + res = cost_fn(x) + + a, b = x + + expected_res = x @ jnp.array([np.cos(a) * np.cos(b), np.cos(a) * np.cos(b)]) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + grad_fn = jax.grad(cost_fn) + g = grad_fn(x) + + expected_g = [ + np.cos(b) * (np.cos(a) - (a + b) * np.sin(a)), + np.cos(a) * (np.cos(b) - (a + b) * np.sin(b)), + ] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + hess = jax.jacobian(grad_fn)(x) + + expected_hess = [ + [ + -(np.cos(b) * ((a + b) * np.cos(a) + 2 * np.sin(a))), + -(np.cos(b) * np.sin(a)) + (-np.cos(a) + (a + b) * np.sin(a)) * np.sin(b), + ], + [ + -(np.cos(b) * np.sin(a)) + (-np.cos(a) + (a + b) * np.sin(a)) * np.sin(b), + -(np.cos(a) * ((a + b) * np.cos(b) + 2 * np.sin(b))), + ], + ] + + if diff_method == "finite-diff": + assert np.allclose(hess, expected_hess, atol=10e-2, rtol=0) + else: + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued_separate_args( + self, dev_name, diff_method, mode, interface, mocker, tol + ): + """Test hessian calculation of a vector valued QNode that has separate input arguments""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support second derivative.") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode, max_diff=2) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.probs(wires=0) + + a = jnp.array(1.0) + b = jnp.array(2.0) + res = circuit(a, b) + + expected_res = [0.5 + 0.5 * np.cos(a) * np.cos(b), 0.5 - 0.5 * np.cos(a) * np.cos(b)] + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + jac_fn = jax.jacobian(circuit, argnums=[0, 1]) + g = jac_fn(a, b) + + expected_g = np.array( + [ + [-0.5 * np.sin(a) * np.cos(b), -0.5 * np.cos(a) * np.sin(b)], + [0.5 * np.sin(a) * np.cos(b), 0.5 * np.cos(a) * np.sin(b)], + ] + ) + assert np.allclose(g, expected_g.T, atol=tol, rtol=0) + + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + hess = jax.jacobian(jac_fn, argnums=[0, 1])(a, b) + + if diff_method == "backprop": + spy.assert_not_called() + elif diff_method == "parameter-shift": + spy.assert_called() + + expected_hess = np.array( + [ + [ + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.cos(a) * np.cos(b)], + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.sin(a) * np.sin(b)], + ], + [ + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.sin(a) * np.sin(b)], + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.cos(a) * np.cos(b)], + ], + ] + ) + if diff_method == "finite-diff": + assert np.allclose(hess, expected_hess, atol=10e-2, rtol=0) + else: + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_state(self, dev_name, diff_method, mode, interface, tol): + """Test that the state can be returned and differentiated""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support states") + + dev = qml.device(dev_name, wires=2) + + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.state() + + def cost_fn(x, y): + res = circuit(x, y) + assert res.dtype is np.dtype("complex128") + probs = jnp.abs(res) ** 2 + return probs[0] + probs[2] + + res = cost_fn(x, y) + + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + res = jax.grad(cost_fn, argnums=[0, 1])(x, y) + expected = np.array([-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2]) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_projector(self, dev_name, diff_method, mode, interface, tol): + """Test that the variance of a projector is correctly returned""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support projectors") + + dev = qml.device(dev_name, wires=2) + P = jnp.array([1]) + x, y = 0.765, -0.654 + + @qnode(dev, diff_method=diff_method, interface=interface, mode=mode) + def circuit(x, y): + qml.RX(x, wires=0) + qml.RY(y, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.var(qml.Projector(P, wires=0) @ qml.PauliX(1)) + + res = circuit(x, y) + expected = 0.25 * np.sin(x / 2) ** 2 * (3 + np.cos(2 * y) + 2 * np.cos(x) * np.sin(y) ** 2) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.grad(circuit, argnums=[0, 1])(x, y) + expected = np.array( + [ + 0.5 * np.sin(x) * (np.cos(x / 2) ** 2 + np.cos(2 * y) * np.sin(x / 2) ** 2), + -2 * np.cos(y) * np.sin(x / 2) ** 4 * np.sin(y), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + +# TODO: Add CV test when return types and custom diff are compatible +@pytest.mark.xfail(reason="CV variables with new return types.") +@pytest.mark.parametrize( + "diff_method,kwargs", + [["finite-diff", {}], ("parameter-shift", {}), ("parameter-shift", {"force_order2": True})], +) +@pytest.mark.parametrize("interface", ["jax-jit", "jax"]) +class TestCV: + """Tests for CV integration""" + + def test_first_order_observable(self, diff_method, kwargs, interface, tol): + """Test variance of a first order CV observable""" + dev = qml.device("default.gaussian", wires=1) + + r = 0.543 + phi = -0.654 + + @qnode(dev, interface=interface, diff_method=diff_method, **kwargs) + def circuit(r, phi): + qml.Squeezing(r, 0, wires=0) + qml.Rotation(phi, wires=0) + return qml.var(qml.X(0)) + + res = circuit(r, phi) + expected = np.exp(2 * r) * np.sin(phi) ** 2 + np.exp(-2 * r) * np.cos(phi) ** 2 + assert np.allclose(res, expected, atol=tol, rtol=0) + + # circuit jacobians + res = jax.grad(circuit, argnums=[0, 1])(r, phi) + expected = np.array( + [ + 2 * np.exp(2 * r) * np.sin(phi) ** 2 - 2 * np.exp(-2 * r) * np.cos(phi) ** 2, + 2 * np.sinh(2 * r) * np.sin(2 * phi), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_second_order_observable(self, diff_method, kwargs, interface, tol): + """Test variance of a second order CV expectation value""" + dev = qml.device("default.gaussian", wires=1) + + n = 0.12 + a = 0.765 + + @qnode(dev, interface=interface, diff_method=diff_method, **kwargs) + def circuit(n, a): + qml.ThermalState(n, wires=0) + qml.Displacement(a, 0, wires=0) + return qml.var(qml.NumberOperator(0)) + + res = circuit(n, a) + expected = n**2 + n + np.abs(a) ** 2 * (1 + 2 * n) + assert np.allclose(res, expected, atol=tol, rtol=0) + + # circuit jacobians + res = jax.grad(circuit, argnums=[0, 1])(n, a) + expected = np.array([2 * a**2 + 2 * n + 1, 2 * a * (2 * n + 1)]) + assert np.allclose(res, expected, atol=tol, rtol=0) + + +# TODO: add support for fwd mode to JAX-JIT +@pytest.mark.parametrize("interface", ["jax-python"]) +def test_adjoint_reuse_device_state(mocker, interface): + """Tests that the jax interface reuses the device state for adjoint differentiation""" + dev = qml.device("default.qubit", wires=1) + + @qnode(dev, interface=interface, diff_method="adjoint") + def circ(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + spy = mocker.spy(dev, "adjoint_jacobian") + + grad = jax.grad(circ)(1.0) + assert circ.device.num_executions == 1 + + spy.assert_called_with(mocker.ANY, use_device_state=True) + + +@pytest.mark.parametrize("dev_name,diff_method,mode,interface", qubit_device_and_diff_method) +class TestTapeExpansion: + """Test that tape expansion within the QNode integrates correctly + with the JAX interface""" + + @pytest.mark.parametrize("max_diff", [1, 2]) + def test_gradient_expansion_trainable_only( + self, dev_name, diff_method, mode, max_diff, interface, mocker + ): + """Test that a *supported* operation with no gradient recipe is only + expanded for parameter-shift and finite-differences when it is trainable.""" + if diff_method not in ("parameter-shift", "finite-diff"): + pytest.skip("Only supports gradient transforms") + + if max_diff == 2 and interface == "jax-jit": + pytest.skip("TODO: add Hessian support to JAX-JIT.") + + dev = qml.device(dev_name, wires=1) + + class PhaseShift(qml.PhaseShift): + grad_method = None + + def expand(self): + with qml.tape.QuantumTape() as tape: + qml.RY(3 * self.data[0], wires=self.wires) + return tape + + @qnode(dev, diff_method=diff_method, mode=mode, max_diff=max_diff, interface=interface) + def circuit(x, y): + qml.Hadamard(wires=0) + PhaseShift(x, wires=0) + PhaseShift(2 * y, wires=0) + return qml.expval(qml.PauliX(0)) + + spy = mocker.spy(circuit.device, "batch_execute") + x = jnp.array(0.5) + y = jnp.array(0.7) + circuit(x, y) + + spy = mocker.spy(circuit.gradient_fn, "transform_fn") + res = jax.grad(circuit, argnums=[0])(x, y) + + input_tape = spy.call_args[0][0] + assert len(input_tape.operations) == 3 + assert input_tape.operations[1].name == "RY" + assert input_tape.operations[1].data[0] == 3 * x + assert input_tape.operations[2].name == "PhaseShift" + assert input_tape.operations[2].grad_method is None + + @pytest.mark.parametrize("max_diff", [1, 2]) + def test_hamiltonian_expansion_analytic( + self, dev_name, diff_method, mode, max_diff, interface, mocker + ): + """Test that the Hamiltonian is not expanded if there + are non-commuting groups and the number of shots is None + and the first and second order gradients are correctly evaluated""" + if diff_method == "adjoint": + pytest.skip("The adjoint method does not yet support Hamiltonians") + + if max_diff == 2 and interface == "jax-jit": + pytest.skip("TODO: add Hessian support to JAX-JIT.") + + dev = qml.device(dev_name, wires=3, shots=None) + spy = mocker.spy(qml.transforms, "hamiltonian_expand") + obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] + + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode, max_diff=max_diff) + def circuit(data, weights, coeffs): + weights = weights.reshape(1, -1) + qml.templates.AngleEmbedding(data, wires=[0, 1]) + qml.templates.BasicEntanglerLayers(weights, wires=[0, 1]) + return qml.expval(qml.Hamiltonian(coeffs, obs)) + + d = jnp.array([0.1, 0.2]) + w = jnp.array([0.654, -0.734]) + c = jnp.array([-0.6543, 0.24, 0.54]) + + # test output + res = circuit(d, w, c) + expected = c[2] * np.cos(d[1] + w[1]) - c[1] * np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]) + assert np.allclose(res, expected) + spy.assert_not_called() + + # test gradients + grad = jax.grad(circuit, argnums=[1, 2])(d, w, c) + expected_w = [ + -c[1] * np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), + -c[1] * np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]) - c[2] * np.sin(d[1] + w[1]), + ] + expected_c = [0, -np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]), np.cos(d[1] + w[1])] + assert np.allclose(grad[0], expected_w) + assert np.allclose(grad[1], expected_c) + + # TODO: Add parameter shift when the bug with trainable params and hamiltonian_grad is solved. + # test second-order derivatives + if diff_method in "backprop" and max_diff == 2: + grad2_c = jax.jacobian(jax.grad(circuit, argnums=[2]), argnums=[2])(d, w, c) + assert np.allclose(grad2_c, 0) + + grad2_w_c = jax.jacobian(jax.grad(circuit, argnums=[1]), argnums=[2])(d, w, c) + expected = [0, -np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), 0], [ + 0, + -np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]), + -np.sin(d[1] + w[1]), + ] + assert np.allclose(grad2_w_c, expected) + + @pytest.mark.parametrize("max_diff", [1, 2]) + def test_hamiltonian_expansion_finite_shots( + self, dev_name, diff_method, mode, interface, max_diff, mocker + ): + """Test that the Hamiltonian is expanded if there + are non-commuting groups and the number of shots is finite + and the first and second order gradients are correctly evaluated""" + if diff_method in ("adjoint", "backprop", "finite-diff"): + pytest.skip("The adjoint and backprop methods do not yet support sampling") + + if max_diff == 2 and interface == "jax-jit": + pytest.skip("TODO: add Hessian support to JAX-JIT.") + + dev = qml.device(dev_name, wires=3, shots=50000) + spy = mocker.spy(qml.transforms, "hamiltonian_expand") + obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] + + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode, max_diff=max_diff) + def circuit(data, weights, coeffs): + weights = weights.reshape(1, -1) + qml.templates.AngleEmbedding(data, wires=[0, 1]) + qml.templates.BasicEntanglerLayers(weights, wires=[0, 1]) + H = qml.Hamiltonian(coeffs, obs) + H.compute_grouping() + return qml.expval(H) + + d = jnp.array([0.1, 0.2]) + w = jnp.array([0.654, -0.734]) + c = jnp.array([-0.6543, 0.24, 0.54]) + + # test output + res = circuit(d, w, c) + expected = c[2] * np.cos(d[1] + w[1]) - c[1] * np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]) + assert np.allclose(res, expected, atol=0.1) + spy.assert_called() + + # test gradients + grad = jax.grad(circuit, argnums=[1, 2])(d, w, c) + expected_w = [ + -c[1] * np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), + -c[1] * np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]) - c[2] * np.sin(d[1] + w[1]), + ] + expected_c = [0, -np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]), np.cos(d[1] + w[1])] + assert np.allclose(grad[0], expected_w, atol=0.1) + assert np.allclose(grad[1], expected_c, atol=0.1) + + # TODO: Fix hamiltonian grad for parameter shift and jax + # # test second-order derivatives + # if diff_method == "parameter-shift" and max_diff == 2: + + # grad2_c = jax.jacobian(jax.grad(circuit, argnum=2), argnum=2)(d, w, c) + # assert np.allclose(grad2_c, 0, atol=0.1) + + # grad2_w_c = jax.jacobian(jax.grad(circuit, argnum=1), argnum=2)(d, w, c) + # expected = [0, -np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), 0], [ + # 0, + # -np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]), + # -np.sin(d[1] + w[1]), + # ] + # assert np.allclose(grad2_w_c, expected, atol=0.1) + + +jit_qubit_device_and_diff_method = [ + ["default.qubit", "backprop", "forward"], + # Jit + ["default.qubit", "finite-diff", "backward"], + ["default.qubit", "parameter-shift", "backward"], + # TODO: + # ["default.qubit", "adjoint", "forward"], + ["default.qubit", "adjoint", "backward"], +] + +jacobian_fn = [jax.jacobian, jax.jacrev, jax.jacfwd] + + +@pytest.mark.parametrize("dev_name,diff_method,mode", jit_qubit_device_and_diff_method) +@pytest.mark.parametrize("jacobian", jacobian_fn) +class TestJIT: + """Test JAX JIT integration with the QNode and automatic resolution of the + correct JAX interface variant.""" + + def test_gradient(self, dev_name, diff_method, mode, jacobian, tol): + """Test derivative calculation of a scalar valued QNode""" + dev = qml.device(dev_name, wires=1) + + if diff_method == "adjoint": + pytest.xfail(reason="The adjoint method is not using host-callback currently") + + @qnode(dev, diff_method=diff_method, interface="jax-jit", mode=mode) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + g = jax.jit(jacobian(circuit))(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + @pytest.mark.filterwarnings( + "ignore:Requested adjoint differentiation to be computed with finite shots." + ) + @pytest.mark.parametrize("shots", [10, 1000]) + def test_hermitian(self, dev_name, diff_method, mode, shots, jacobian): + """Test that the jax device works with qml.Hermitian and jitting even + when shots>0. + + Note: before a fix, the cases of shots=10 and shots=1000 were failing due + to different reasons, hence the parametrization in the test. + """ + dev = qml.device(dev_name, wires=2, shots=shots) + + if diff_method == "backprop": + pytest.skip("Backpropagation is unsupported if shots > 0.") + + if diff_method == "adjoint" and mode == "forward": + pytest.skip("Computing the gradient for Hermitian is not supported with adjoint.") + + projector = np.array(qml.matrix(qml.PauliZ(0) @ qml.PauliZ(1))) + + @qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circ(projector): + return qml.expval(qml.Hermitian(projector, wires=range(2))) + + assert jnp.allclose(jax.jit(circ)(projector), 1) + + @pytest.mark.filterwarnings( + "ignore:Requested adjoint differentiation to be computed with finite shots." + ) + @pytest.mark.parametrize("shots", [10, 1000]) + def test_probs_obs_none(self, dev_name, diff_method, mode, shots, jacobian): + """Test that the jax device works with qml.probs, a MeasurementProcess + that has obs=None even when shots>0.""" + dev = qml.device(dev_name, wires=2, shots=shots) + + if diff_method == "backprop": + pytest.skip("Backpropagation is unsupported if shots > 0.") + + @qml.qnode(dev, interface="jax", diff_method="parameter-shift") + def circuit(): + return qml.probs(wires=0) + + assert jnp.allclose(circuit(), jnp.array([1.0, 0.0])) + + @pytest.mark.xfail( + reason="Non-trainable parameters are not being correctly unwrapped by the interface" + ) + def test_gradient_subset(self, dev_name, diff_method, mode, jacobian, tol): + """Test derivative calculation of a scalar valued QNode with respect + to a subset of arguments""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.RZ(c, wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.jit(jacobian(circuit, argnums=[0, 1]))(a, b, 0.0) + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + def test_gradient_scalar_cost_vector_valued_qnode( + self, dev_name, diff_method, mode, jacobian, tol + ): + """Test derivative calculation of a scalar valued cost function that + uses the output of a vector-valued QNode""" + dev = qml.device(dev_name, wires=2) + + if diff_method == "adjoint": + pytest.xfail(reason="The adjoint method is not using host-callback currently") + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.probs(wires=[1]) + + def cost(x, y, idx): + res = circuit(x, y) + return res[idx] + + x = jnp.array(1.0) + y = jnp.array(2.0) + expected_g = ( + np.array([-np.sin(x) * np.cos(y) / 2, np.cos(y) * np.sin(x) / 2]), + np.array([-np.cos(x) * np.sin(y) / 2, np.cos(x) * np.sin(y) / 2]), + ) + + idx = 0 + g0 = jax.jit(jacobian(cost, argnums=0))(x, y, idx) + g1 = jax.jit(jacobian(cost, argnums=1))(x, y, idx) + assert np.allclose(g0, expected_g[0][idx], atol=tol, rtol=0) + assert np.allclose(g1, expected_g[1][idx], atol=tol, rtol=0) + + idx = 1 + g0 = jax.jit(jacobian(cost, argnums=0))(x, y, idx) + g1 = jax.jit(jacobian(cost, argnums=1))(x, y, idx) + + assert np.allclose(g0, expected_g[0][idx], atol=tol, rtol=0) + assert np.allclose(g1, expected_g[1][idx], atol=tol, rtol=0) + + def test_matrix_parameter(self, dev_name, diff_method, mode, jacobian, tol): + """Test that the JAX-JIT interface works correctly with a matrix + parameter""" + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circ(p, U): + qml.QubitUnitary(U, wires=0) + qml.RY(p, wires=0) + return qml.expval(qml.PauliZ(0)) + + p = jnp.array(0.1) + U = jnp.array([[0, 1], [1, 0]]) + res = jax.jit(circ)(p, U) + assert np.allclose(res, -np.cos(p), atol=tol, rtol=0) + + jac_fn = jax.jit(jax.grad(circ, argnums=(0))) + res = jac_fn(p, U) + assert np.allclose(res, np.sin(p), atol=tol, rtol=0) + + +qubit_device_and_diff_method_and_mode = [ + ["default.qubit", "backprop", "forward"], + ["default.qubit", "finite-diff", "backward"], + ["default.qubit", "parameter-shift", "backward"], + # TODO: forward mode + # ["default.qubit", "adjoint", "forward"], + ["default.qubit", "adjoint", "backward"], +] + + +@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method_and_mode) +@pytest.mark.parametrize("shots", [None, 10000]) +@pytest.mark.parametrize("jacobian", jacobian_fn) +class TestReturn: + """Class to test the shape of the Grad/Jacobian/Hessian with different return types.""" + + def test_grad_single_measurement_param(self, dev_name, diff_method, mode, jacobian, shots): + """For one measurement and one param, the gradient is a float.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=1, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a, wires=0) + qml.RX(0.2, wires=0) + return qml.expval(qml.PauliZ(0)) + + a = jax.numpy.array(0.1) + + grad = jax.jit(jacobian(circuit))(a) + + assert isinstance(grad, jax.numpy.ndarray) + assert grad.shape == () + + def test_grad_single_measurement_multiple_param( + self, dev_name, diff_method, mode, jacobian, shots + ): + """For one measurement and multiple param, the gradient is a tuple of arrays.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=1, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.expval(qml.PauliZ(0)) + + a = jax.numpy.array(0.1) + b = jax.numpy.array(0.2) + + grad = jax.jit(jacobian(circuit, argnums=[0, 1]))(a, b) + + assert isinstance(grad, tuple) + assert len(grad) == 2 + assert grad[0].shape == () + assert grad[1].shape == () + + def test_grad_single_measurement_multiple_param_array( + self, dev_name, diff_method, mode, jacobian, shots + ): + """For one measurement and multiple param as a single array params, the gradient is an array.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=1, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + a = jax.numpy.array([0.1, 0.2]) + + grad = jax.jit(jacobian(circuit))(a) + + assert isinstance(grad, jax.numpy.ndarray) + assert grad.shape == (2,) + + def test_jacobian_single_measurement_param_probs( + self, dev_name, diff_method, mode, jacobian, shots + ): + """For a multi dimensional measurement (probs), check that a single array is returned with the correct + dimension""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a, wires=0) + qml.RX(0.2, wires=0) + return qml.probs(wires=[0, 1]) + + a = jax.numpy.array(0.1) + + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, jax.numpy.ndarray) + assert jac.shape == (4,) + + def test_jacobian_single_measurement_probs_multiple_param( + self, dev_name, diff_method, mode, jacobian, shots + ): + """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with + the correct dimension""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.probs(wires=[0, 1]) + + a = jax.numpy.array(0.1) + b = jax.numpy.array(0.2) + + jac = jax.jit(jacobian(circuit, argnums=[0, 1]))(a, b) + + assert isinstance(jac, tuple) + + assert isinstance(jac[0], jax.numpy.ndarray) + assert jac[0].shape == (4,) + + assert isinstance(jac[1], jax.numpy.ndarray) + assert jac[1].shape == (4,) + + def test_jacobian_single_measurement_probs_multiple_param_single_array( + self, dev_name, diff_method, mode, jacobian, shots + ): + """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with + the correct dimension""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.probs(wires=[0, 1]) + + a = jax.numpy.array([0.1, 0.2]) + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, jax.numpy.ndarray) + assert jac.shape == (4, 2) + + def test_jacobian_expval_expval_multiple_params( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with multiple params return a tuple of arrays.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + dev = qml.device(dev_name, wires=2, shots=shots) + + par_0 = jax.numpy.array(0.1) + par_1 = jax.numpy.array(0.2) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.expval(qml.PauliZ(0)) + + jac = jax.jit(jacobian(circuit, argnums=[0, 1]))(par_0, par_1) + + assert isinstance(jac, tuple) + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == () + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == () + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == () + + def test_jacobian_expval_expval_multiple_params_array( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with a multiple params array return a single array.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.expval(qml.PauliZ(0)) + + a = jax.numpy.array([0.1, 0.2]) + + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, tuple) + assert len(jac) == 2 # measurements + + assert isinstance(jac[0], jax.numpy.ndarray) + assert jac[0].shape == (2,) + + assert isinstance(jac[1], jax.numpy.ndarray) + assert jac[1].shape == (2,) + + def test_jacobian_var_var_multiple_params(self, dev_name, diff_method, mode, jacobian, shots): + """The jacobian of multiple measurements with multiple params return a tuple of arrays.""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of var.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + par_0 = jax.numpy.array(0.1) + par_1 = jax.numpy.array(0.2) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.var(qml.PauliZ(0)) + + jac = jax.jit(jacobian(circuit, argnums=[0, 1]))(par_0, par_1) + + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == () + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == () + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == () + + def test_jacobian_var_var_multiple_params_array( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with a multiple params array return a single array.""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of var.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.var(qml.PauliZ(0)) + + a = jax.numpy.array([0.1, 0.2]) + + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, tuple) + assert len(jac) == 2 # measurements + + assert isinstance(jac[0], jax.numpy.ndarray) + assert jac[0].shape == (2,) + + assert isinstance(jac[1], jax.numpy.ndarray) + assert jac[1].shape == (2,) + + def test_jacobian_multiple_measurement_single_param( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with a single params return an array.""" + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + dev = qml.device(dev_name, wires=2, shots=shots) + + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a, wires=0) + qml.RX(0.2, wires=0) + return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) + + a = jax.numpy.array(0.1) + + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], jax.numpy.ndarray) + assert jac[0].shape == () + + assert isinstance(jac[1], jax.numpy.ndarray) + assert jac[1].shape == (4,) + + def test_jacobian_multiple_measurement_multiple_param( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) + + a = np.array(0.1, requires_grad=True) + b = np.array(0.2, requires_grad=True) + + jac = jax.jit(jacobian(circuit, argnums=[0, 1]))(a, b) + + assert isinstance(jac, tuple) + assert len(jac) == 2 + + assert isinstance(jac[0], tuple) + assert len(jac[0]) == 2 + assert isinstance(jac[0][0], jax.numpy.ndarray) + assert jac[0][0].shape == () + assert isinstance(jac[0][1], jax.numpy.ndarray) + assert jac[0][1].shape == () + + assert isinstance(jac[1], tuple) + assert len(jac[1]) == 2 + assert isinstance(jac[1][0], jax.numpy.ndarray) + assert jac[1][0].shape == (4,) + assert isinstance(jac[1][1], jax.numpy.ndarray) + assert jac[1][1].shape == (4,) + + def test_jacobian_multiple_measurement_multiple_param_array( + self, dev_name, diff_method, mode, jacobian, shots + ): + """The jacobian of multiple measurements with a multiple params array return a single array.""" + if diff_method == "adjoint": + pytest.skip("Test does not supports adjoint because of probabilities.") + if shots is not None and diff_method in ("backprop", "adjoint"): + pytest.skip("Test does not support finite shots and adjoint/backprop") + + dev = qml.device(dev_name, wires=2, shots=shots) + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) + + a = jax.numpy.array([0.1, 0.2]) + + jac = jax.jit(jacobian(circuit))(a) + + assert isinstance(jac, tuple) + assert len(jac) == 2 # measurements + + assert isinstance(jac[0], jax.numpy.ndarray) + assert jac[0].shape == (2,) + + assert isinstance(jac[1], jax.numpy.ndarray) + assert jac[1].shape == (4, 2) diff --git a/tests/returntypes/test_jax_new.py b/tests/returntypes/test_jax_new.py index 33bbdffd276..315f14a8121 100644 --- a/tests/returntypes/test_jax_new.py +++ b/tests/returntypes/test_jax_new.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the jax interface""" +"""Unit tests for the JAX-Python interface""" import sys import pytest @@ -29,16 +29,12 @@ import pennylane as qml from pennylane.gradients import param_shift from pennylane.interfaces import execute, InterfaceUnsupportedError -from pennylane.interfaces.jax_jit import _execute_with_fwd -# TODO: add jax-jit when it supports new return types. -# "jax-jit" -@pytest.mark.parametrize("interface", ["jax-python"]) class TestJaxExecuteUnitTests: """Unit tests for jax execution""" - def test_import_error(self, mocker, interface): + def test_import_error(self, mocker): """Test that an exception is caught on import error""" mock = mocker.patch.object(jax, "custom_jvp") @@ -54,9 +50,9 @@ def test_import_error(self, mocker, interface): match="jax not found. Please install the latest version " "of jax to enable the 'jax' interface", ): - qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift, interface=interface) + qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift, interface="jax-python") - def test_jacobian_options(self, mocker, interface, tol): + def test_jacobian_options(self, mocker, tol): """Test setting jacobian options""" spy = mocker.spy(qml.gradients, "param_shift") @@ -75,7 +71,7 @@ def cost(a, device): device, gradient_fn=param_shift, gradient_kwargs={"shifts": [(np.pi / 4,)] * 2}, - interface=interface, + interface="jax-python", )[0] res = jax.grad(cost)(a, device=dev) @@ -83,7 +79,7 @@ def cost(a, device): for args in spy.call_args_list: assert args[1]["shifts"] == [(np.pi / 4,)] * 2 - def test_incorrect_mode(self, interface): + def test_incorrect_mode(self): """Test that an error is raised if an gradient transform is used with mode=forward""" a = jnp.array([0.1, 0.2]) @@ -101,7 +97,7 @@ def cost(a, device): device, gradient_fn=param_shift, mode="forward", - interface=interface, + interface="jax-python", )[0] with pytest.raises( @@ -109,7 +105,7 @@ def cost(a, device): ): res = jax.grad(cost)(a, device=dev) - def test_unknown_interface(self, interface): + def test_unknown_interface(self): """Test that an error is raised if the interface is unknown""" a = jnp.array([0.1, 0.2]) @@ -131,7 +127,7 @@ def cost(a, device): with pytest.raises(ValueError, match="Unknown interface"): cost(a, device=dev) - def test_forward_mode(self, interface, mocker): + def test_forward_mode(self, mocker): """Test that forward mode uses the `device.execute_and_gradients` pathway""" dev = qml.device("default.qubit", wires=1) spy = mocker.spy(dev, "execute_and_gradients") @@ -146,7 +142,7 @@ def cost(a): [tape], dev, gradient_fn="device", - interface=interface, + interface="jax-python", gradient_kwargs={ "method": "adjoint_jacobian", "use_device_state": True, @@ -160,7 +156,7 @@ def cost(a): assert dev.num_executions == 1 spy.assert_called() - def test_backward_mode(self, interface, mocker): + def test_backward_mode(self, mocker): """Test that backward mode uses the `device.batch_execute` and `device.gradients` pathway""" dev = qml.device("default.qubit", wires=1) spy_execute = mocker.spy(qml.devices.DefaultQubit, "batch_execute") @@ -177,7 +173,7 @@ def cost(a): dev, gradient_fn="device", mode="backward", - interface=interface, + interface="jax-python", gradient_kwargs={"method": "adjoint_jacobian"}, )[0] @@ -192,13 +188,10 @@ def cost(a): spy_gradients.assert_called() -# TODO: add jax-jit when it supports new return types. -# "jax-jit" -@pytest.mark.parametrize("interface", ["jax-python"]) class TestCaching: """Test for caching behaviour""" - def test_cache_maxsize(self, interface, mocker): + def test_cache_maxsize(self, mocker): """Test the cachesize property of the cache""" dev = qml.device("default.qubit", wires=1) spy = mocker.spy(qml.interfaces, "cache_execute") @@ -214,7 +207,7 @@ def cost(a, cachesize): dev, gradient_fn=param_shift, cachesize=cachesize, - interface=interface, + interface="jax-python", )[0] params = jnp.array([0.1, 0.2]) @@ -225,7 +218,7 @@ def cost(a, cachesize): assert cache.currsize == 2 assert len(cache) == 2 - def test_custom_cache(self, interface, mocker): + def test_custom_cache(self, mocker): """Test the use of a custom cache object""" dev = qml.device("default.qubit", wires=1) spy = mocker.spy(qml.interfaces, "cache_execute") @@ -241,7 +234,7 @@ def cost(a, cache): dev, gradient_fn=param_shift, cache=cache, - interface=interface, + interface="jax-python", )[0] custom_cache = {} @@ -251,7 +244,7 @@ def cost(a, cache): cache = spy.call_args[0][1] assert cache is custom_cache - def test_custom_cache_multiple(self, interface, mocker): + def test_custom_cache_multiple(self, mocker): """Test the use of a custom cache object with multiple tapes""" dev = qml.device("default.qubit", wires=1) spy = mocker.spy(qml.interfaces, "cache_execute") @@ -275,7 +268,7 @@ def cost(a, b, cache): dev, gradient_fn=param_shift, cache=cache, - interface=interface, + interface="jax-python", ) return res[0] @@ -285,7 +278,7 @@ def cost(a, b, cache): cache = spy.call_args[0][1] assert cache is custom_cache - def test_caching_param_shift(self, interface, tol): + def test_caching_param_shift(self, tol): """Test that, when using parameter-shift transform, caching produces the optimum number of evaluations.""" dev = qml.device("default.qubit", wires=1) @@ -301,7 +294,7 @@ def cost(a, cache): dev, gradient_fn=param_shift, cache=cache, - interface=interface, + interface="jax-python", )[0] # Without caching, 5 evaluations are required to compute @@ -330,7 +323,7 @@ def cost(a, cache): assert dev.num_executions == 15 assert not np.allclose(grad1, grad2, atol=tol, rtol=0) - def test_caching_adjoint_backward(self, interface): + def test_caching_adjoint_backward(self): """Test that caching produces the optimum number of adjoint evaluations when mode=backward""" dev = qml.device("default.qubit", wires=2) @@ -349,7 +342,7 @@ def cost(a, cache): gradient_fn="device", cache=cache, mode="backward", - interface=interface, + interface="jax-python", gradient_kwargs={"method": "adjoint_jacobian"}, )[0] @@ -382,15 +375,12 @@ def cost(a, cache): ] -# TODO: add jax-jit when it supports new return types. -# "jax-jit" @pytest.mark.parametrize("execute_kwargs", execute_kwargs) -@pytest.mark.parametrize("interface", ["jax-python"]) class TestJaxExecuteIntegration: """Test the jax interface execute function integrates well for both forward and backward execution""" - def test_execution(self, execute_kwargs, interface): + def test_execution(self, execute_kwargs): """Test execution""" dev = qml.device("default.qubit", wires=1) @@ -405,7 +395,7 @@ def cost(a, b): qml.RX(b, wires=0) qml.expval(qml.PauliZ(0)) - return execute([tape1, tape2], dev, interface=interface, **execute_kwargs) + return execute([tape1, tape2], dev, interface="jax-python", **execute_kwargs) a = jnp.array(0.1) b = jnp.array(0.2) @@ -415,7 +405,7 @@ def cost(a, b): assert res[0].shape == () assert res[1].shape == () - def test_scalar_jacobian(self, execute_kwargs, interface, tol): + def test_scalar_jacobian(self, execute_kwargs, tol): """Test scalar jacobian calculation""" a = jnp.array(0.1) dev = qml.device("default.qubit", wires=2) @@ -424,7 +414,7 @@ def cost(a): with qml.tape.QuantumTape() as tape: qml.RY(a, wires=0) qml.expval(qml.PauliZ(0)) - return execute([tape], dev, interface=interface, **execute_kwargs)[0] + return execute([tape], dev, interface="jax-python", **execute_kwargs)[0] res = jax.grad(cost)(a) assert res.shape == () @@ -441,7 +431,7 @@ def cost(a): assert expected.shape == () assert np.allclose(res, expected, atol=tol, rtol=0) - def test_reusing_quantum_tape(self, execute_kwargs, interface, tol): + def test_reusing_quantum_tape(self, execute_kwargs, tol): """Test re-using a quantum tape by passing new parameters""" a = jnp.array(0.1) b = jnp.array(0.2) @@ -466,7 +456,7 @@ def cost(a, b): # required_length) and the tape produces incorrect results. tape._update() tape.set_parameters([a, b]) - return execute([tape], dev, interface=interface, **execute_kwargs)[0] + return execute([tape], dev, interface="jax-python", **execute_kwargs)[0] jac_fn = jax.grad(cost) jac = jac_fn(a, b) @@ -485,7 +475,7 @@ def cost(a, b): expected = -2 * np.sin(2 * a) assert np.allclose(jac, expected, atol=tol, rtol=0) - def test_grad_with_backward_mode(self, execute_kwargs, interface): + def test_grad_with_backward_mode(self, execute_kwargs): """Test jax grad for adjoint diff method in backward mode""" dev = qml.device("default.qubit", wires=2) params = jnp.array([0.1, 0.2, 0.3]) @@ -499,18 +489,15 @@ def cost(a, cache): qml.expval(qml.PauliZ(0)) res = qml.interfaces.execute( - [tape], dev, cache=cache, interface=interface, **execute_kwargs + [tape], dev, cache=cache, interface="jax-python", **execute_kwargs )[0] return res - if interface == "jax-jit": - cost = jax.jit(cost) - results = jax.grad(cost)(params, cache=None) for r, e in zip(results, expected_results): assert jnp.allclose(r, e, atol=1e-7) - def test_classical_processing_single_tape(self, execute_kwargs, interface, tol): + def test_classical_processing_single_tape(self, execute_kwargs, tol): """Test classical processing within the quantum tape for a single tape""" a = jnp.array(0.1) b = jnp.array(0.2) @@ -523,13 +510,13 @@ def cost(a, b, c, device): qml.RX(c + c**2 + jnp.sin(a), wires=0) qml.expval(qml.PauliZ(0)) - return execute([tape], device, interface=interface, **execute_kwargs)[0] + return execute([tape], device, interface="jax-python", **execute_kwargs)[0] dev = qml.device("default.qubit", wires=2) res = jax.grad(cost, argnums=(0, 1, 2))(a, b, c, device=dev) assert len(res) == 3 - def test_classical_processing_multiple_tapes(self, execute_kwargs, interface, tol): + def test_classical_processing_multiple_tapes(self, execute_kwargs, tol): """Test classical processing within the quantum tape for multiple tapes""" dev = qml.device("default.qubit", wires=2) @@ -549,14 +536,14 @@ def cost_fn(x): qml.expval(qml.PauliZ(0)) result = execute( - tapes=[tape1, tape2], device=dev, interface=interface, **execute_kwargs + tapes=[tape1, tape2], device=dev, interface="jax-python", **execute_kwargs ) return result[0] + result[1] - 7 * result[1] res = jax.grad(cost_fn)(params) assert res.shape == (2,) - def test_multiple_tapes_output(self, execute_kwargs, interface, tol): + def test_multiple_tapes_output(self, execute_kwargs, tol): """Test the output types for the execution of multiple quantum tapes""" dev = qml.device("default.qubit", wires=2) params = jax.numpy.array([0.3, 0.2]) @@ -574,14 +561,16 @@ def cost_fn(x): qml.RX(2 * x[1], wires=[1]) qml.expval(qml.PauliZ(0)) - return execute(tapes=[tape1, tape2], device=dev, interface=interface, **execute_kwargs) + return execute( + tapes=[tape1, tape2], device=dev, interface="jax-python", **execute_kwargs + ) res = cost_fn(params) assert isinstance(res, list) assert all(isinstance(r, jnp.ndarray) for r in res) assert all(r.shape == () for r in res) - def test_matrix_parameter(self, execute_kwargs, interface, tol): + def test_matrix_parameter(self, execute_kwargs, tol): """Test that the jax interface works correctly with a matrix parameter""" a = jnp.array(0.1) @@ -594,7 +583,7 @@ def cost(a, U, device): qml.expval(qml.PauliZ(0)) tape.trainable_params = [0] - return execute([tape], device, interface=interface, **execute_kwargs)[0] + return execute([tape], device, interface="jax-python", **execute_kwargs)[0] dev = qml.device("default.qubit", wires=2) res = cost(a, U, device=dev) @@ -604,7 +593,7 @@ def cost(a, U, device): res = jac_fn(a, U, device=dev) assert np.allclose(res, np.sin(a), atol=tol, rtol=0) - def test_differentiable_expand(self, execute_kwargs, interface, tol): + def test_differentiable_expand(self, execute_kwargs, tol): """Test that operation and nested tapes expansion is differentiable""" @@ -628,7 +617,7 @@ def cost_fn(a, p, device): qml.expval(qml.PauliX(0)) tape = tape.expand(stop_at=lambda obj: device.supports_operation(obj.name)) - return execute([tape], device, interface=interface, **execute_kwargs)[0] + return execute([tape], device, interface="jax-python", **execute_kwargs)[0] a = jnp.array(0.1) p = jnp.array([0.1, 0.2, 0.3]) @@ -654,7 +643,7 @@ def cost_fn(a, p, device): ) assert np.allclose(res, expected, atol=tol, rtol=0) - def test_independent_expval(self, execute_kwargs, interface): + def test_independent_expval(self, execute_kwargs): """Tests computing an expectation value that is independent of trainable parameters.""" dev = qml.device("default.qubit", wires=2) @@ -667,7 +656,7 @@ def cost(a, cache): qml.RY(a[2], wires=0) qml.expval(qml.PauliZ(1)) - res = execute([tape], dev, cache=cache, interface=interface, **execute_kwargs) + res = execute([tape], dev, cache=cache, interface="jax-python", **execute_kwargs) return res[0] res = jax.grad(cost)(params, cache=None) @@ -884,278 +873,3 @@ def cost(x, y, device, interface, ek): assert res[1][1][0].shape == (2,) assert isinstance(res[1][1][1], jax.numpy.ndarray) assert res[1][1][1].shape == (2,) - - -# TODO: add jit tests -@pytest.mark.xfail(reason="Add interface for Jax-jit") -@pytest.mark.parametrize("execute_kwargs", execute_kwargs) -class TestVectorValuedJIT: - """Test vector-valued returns for the JAX jit Python interface.""" - - @pytest.mark.parametrize( - "ret_type, shape", - [ - ([qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))], (2,)), - ([qml.probs(wires=[0, 1])], (1, 4)), - ], - ) - def test_shapes(self, execute_kwargs, ret_type, shape): - """Test the shape of the result of vector-valued QNodes.""" - adjoint = execute_kwargs.get("gradient_kwargs", {}).get("method", "") == "adjoint_jacobian" - if adjoint: - pytest.skip("The adjoint diff method doesn't support probabilities.") - - dev = qml.device("default.qubit", wires=2) - params = jnp.array([0.1, 0.2, 0.3]) - - idx = 0 - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - for r in ret_type: - qml.apply(r) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - ) - return res[0] - - res = cost(params, cache=None) - assert res.shape == shape - - def test_independent_expval(self, execute_kwargs): - """Tests computing an expectation value that is independent trainable - parameters.""" - dev = qml.device("default.qubit", wires=2) - params = jnp.array([0.1, 0.2, 0.3]) - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - qml.expval(qml.PauliZ(1)) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - ) - return res[0][0] - - res = jax.grad(cost)(params, cache=None) - assert res.shape == (3,) - - ret_and_output_dim = [ - ([qml.probs(wires=0)], (1, 2)), - ([qml.state()], (1, 4)), - ([qml.density_matrix(wires=0)], (1, 2, 2)), - # Multi measurements - ([qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))], (2,)), - ([qml.var(qml.PauliZ(0)), qml.var(qml.PauliZ(1))], (2,)), - ([qml.probs(wires=0), qml.probs(wires=1)], (2, 2)), - ] - - @pytest.mark.parametrize("ret, out_dim", ret_and_output_dim) - def test_vector_valued_qnode(self, execute_kwargs, ret, out_dim): - """Tests the shape of vector-valued QNode results.""" - - dev = qml.device("default.qubit", wires=2) - params = jnp.array([0.1, 0.2, 0.3]) - - grad_meth = ( - execute_kwargs["gradient_kwargs"]["method"] - if "gradient_kwargs" in execute_kwargs - else "" - ) - if "adjoint" in grad_meth and any( - r.return_type - in (qml.measurements.Probability, qml.measurements.State, qml.measurements.Variance) - for r in ret - ): - pytest.skip("Adjoint does not support probs") - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - - for r in ret: - qml.apply(r) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - )[0] - return res - - res = cost(params, cache=None) - assert res.shape == out_dim - - def test_qnode_sample(self, execute_kwargs): - """Tests computing multiple expectation values in a tape.""" - dev = qml.device("default.qubit", wires=2, shots=10) - params = jnp.array([0.1, 0.2, 0.3]) - - grad_meth = ( - execute_kwargs["gradient_kwargs"]["method"] - if "gradient_kwargs" in execute_kwargs - else "" - ) - if "adjoint" in grad_meth or "backprop" in grad_meth: - pytest.skip("Adjoint does not support probs") - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - qml.sample(qml.PauliZ(0)) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - )[0] - return res - - res = cost(params, cache=None) - assert res.shape == (1, dev.shots) - - def test_multiple_expvals_grad(self, execute_kwargs): - """Tests computing multiple expectation values in a tape.""" - dev = qml.device("default.qubit", wires=2) - params = jnp.array([0.1, 0.2, 0.3]) - fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" - if fwd_mode: - pytest.skip("The forward mode is tested separately as it should raise an error.") - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliZ(1)) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - )[0] - return res[0] + res[1] - - res = jax.grad(cost)(params, cache=None) - assert res.shape == (3,) - - def test_multi_tape_jacobian_probs_expvals(self, execute_kwargs): - """Test the jacobian computation with multiple tapes with probability - and expectation value computations.""" - fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" - if fwd_mode: - pytest.skip("The forward mode is tested separately as it should raise an error.") - - adjoint = execute_kwargs.get("gradient_kwargs", {}).get("method", "") == "adjoint_jacobian" - if adjoint: - pytest.skip("The adjoint diff method doesn't support probabilities.") - - def cost(x, y, device, interface, ek): - with qml.tape.QuantumTape() as tape1: - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliZ(1)) - - with qml.tape.QuantumTape() as tape2: - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - qml.probs(wires=[0]) - qml.probs(wires=[1]) - - return qml.execute([tape1, tape2], device, **ek, interface=interface)[0] - - dev = qml.device("default.qubit", wires=2) - x = jnp.array(0.543) - y = jnp.array(-0.654) - - x_ = np.array(0.543) - y_ = np.array(-0.654) - - res = cost(x, y, dev, interface="jax-jit", ek=execute_kwargs) - - exp = cost(x_, y_, dev, interface="autograd", ek=execute_kwargs) - - for r, e in zip(res, exp): - assert jnp.allclose(r, e, atol=1e-7) - - def test_multiple_expvals_raises_fwd_device_grad(self, execute_kwargs): - """Tests computing multiple expectation values in a tape.""" - fwd_mode = execute_kwargs.get("mode", "not forward") == "forward" - if not fwd_mode: - pytest.skip("Forward mode is not turned on.") - - dev = qml.device("default.qubit", wires=2) - params = jnp.array([0.1, 0.2, 0.3]) - - def cost(a, cache): - with qml.tape.QuantumTape() as tape: - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - qml.RY(a[2], wires=0) - qml.expval(qml.PauliZ(0)) - qml.expval(qml.PauliZ(1)) - - res = qml.interfaces.execute( - [tape], dev, cache=cache, interface="jax-jit", **execute_kwargs - ) - return res[0] - - with pytest.raises(InterfaceUnsupportedError): - jax.jacobian(cost)(params, cache=None) - - def test_assertion_error_fwd(self, execute_kwargs): - """Test that an assertion is raised if by chance there is a difference - in the number of tapes and the number of parameters sequences passed - to _execute_with_fwd.""" - a = 0.3 - b = 0.3 - - with qml.tape.QuantumTape() as tape: - qml.RY(a, wires=0) - qml.RY(b, wires=0) - qml.expval(qml.PauliZ(0)) - - device = qml.device("default.qubit", wires=2) - - # Create arguments for 2 tapes - params = [[0.2], [0.3]] - - # But pass only 1 tape - tapes = [tape] - - with pytest.raises(AssertionError): - _execute_with_fwd( - params, - tapes=tapes, - device=device, - execute_fn=lambda a: a, # Some dummy function - gradient_kwargs=None, - _n=1, - ) - - -# TODO: add jit tests -@pytest.mark.xfail(reason="Add interface for Jax-jit") -def test_diff_method_None_jit(): - """Test that jitted execution works when `gradient_fn=None`.""" - - dev = qml.device("default.qubit.jax", wires=1, shots=10) - - @jax.jit - def wrapper(x): - with qml.tape.QuantumTape() as tape: - qml.RX(x, wires=0) - qml.expval(qml.PauliZ(0)) - - return qml.execute([tape], dev, gradient_fn=None) - - assert jnp.allclose(wrapper(jnp.array(0.0))[0], 1.0) diff --git a/tests/returntypes/test_jax_qnode_new.py b/tests/returntypes/test_jax_qnode_new.py index 91769dc91bc..050bbd6c192 100644 --- a/tests/returntypes/test_jax_qnode_new.py +++ b/tests/returntypes/test_jax_qnode_new.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Integration tests for using the jax interface and its jittable variant with -a QNode""" +"""Integration tests for using the JAX-Python interface with a QNode""" import pytest @@ -30,14 +29,6 @@ ["default.qubit", "adjoint", "backward", "jax-python"], ] -# TODO: add jit to the tests -""" -# Jit -["default.qubit", "finite-diff", "backward", "jax-jit"], -["default.qubit", "parameter-shift", "backward", "jax-jit"], -["default.qubit", "adjoint", "forward", "jax-jit"], -["default.qubit", "adjoint", "backward", "jax-jit"], -""" pytestmark = pytest.mark.jax jax = pytest.importorskip("jax") @@ -60,7 +51,7 @@ def test_execution_with_interface(self, dev_name, diff_method, mode, interface): dev = qml.device(dev_name, wires=1) - @qnode(dev, interface=interface, diff_method=diff_method) + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -233,10 +224,6 @@ def circuit(a): qml.RX(a[1], wires=0) return qml.expval(qml.PauliZ(0)) - if diff_method in {"finite-diff", "parameter-shift"} and interface == "jax-jit": - # No jax.jacobian support for call - pytest.xfail(reason="batching rules are implemented only for id_tap, not for call.") - jax.jacobian(circuit)(a) for args in spy.call_args_list: @@ -253,15 +240,6 @@ def circuit(a): ["default.qubit", "adjoint", "backward", "jax-python"], ] -# TODO: add jit to the tests -""" -# Jit -["default.qubit", "finite-diff", "backward", "jax-jit"], -["default.qubit", "parameter-shift", "backward", "jax-jit"], -["default.qubit", "adjoint", "forward", "jax-jit"], -["default.qubit", "adjoint", "backward", "jax-jit"], -""" - @pytest.mark.parametrize("dev_name,diff_method,mode,interface", vv_qubit_device_and_diff_method) class TestVectorValuedQNode: @@ -671,9 +649,7 @@ def circuit(x, y): assert np.allclose(jac[1][1], expected[1][1], atol=tol, rtol=0) -# TODO: add jit tests -# "jax-jit" -@pytest.mark.parametrize("interface", ["jax-python"]) +@pytest.mark.parametrize("interface", ["jax", "jax-python"]) class TestShotsIntegration: """Test that the QNode correctly changes shot value, and remains differentiable.""" @@ -682,7 +658,6 @@ def test_diff_method_None(self, interface): """Test jax device works with diff_method=None.""" dev = qml.device("default.qubit.jax", wires=1, shots=10) - @jax.jit @qml.qnode(dev, diff_method=None, interface=interface) def circuit(x): qml.RX(x, wires=0) @@ -838,12 +813,12 @@ def expand(self): qml.templates.StronglyEntanglingLayers(*self.parameters, self.wires) return tape - @qnode(dev, interface=interface, diff_method=diff_method) + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) def circuit1(weights): Template(weights, wires=[0, 1]) return qml.expval(qml.PauliZ(0)) - @qnode(dev, interface=interface, diff_method=diff_method) + @qnode(dev, interface=interface, diff_method=diff_method, mode=mode) def circuit2(data, weights): qml.templates.AngleEmbedding(jnp.stack([data, 0.7]), wires=[0, 1]) Template(weights, wires=[0, 1]) @@ -868,6 +843,23 @@ def cost(weights): assert len(res) == 2 + +hessian_qubit_device_and_diff_method = [ + ["default.qubit", "backprop", "forward", "jax"], + # Python + ["default.qubit", "finite-diff", "backward", "jax-python"], + ["default.qubit", "parameter-shift", "backward", "jax-python"], + ["default.qubit", "adjoint", "forward", "jax-python"], + ["default.qubit", "adjoint", "backward", "jax-python"], +] + + +@pytest.mark.parametrize( + "dev_name,diff_method,mode,interface", hessian_qubit_device_and_diff_method +) +class TestQubitIntegrationHigherOrder: + """Tests that ensure various qubit circuits integrate correctly when computing higher-order derivatives""" + def test_second_derivative(self, dev_name, diff_method, mode, interface, tol): """Test second derivative calculation of a scalar-valued QNode""" @@ -1169,7 +1161,7 @@ def circuit(x, y): "diff_method,kwargs", [["finite-diff", {}], ("parameter-shift", {}), ("parameter-shift", {"force_order2": True})], ) -@pytest.mark.parametrize("interface", ["jax-jit", "jax-python"]) +@pytest.mark.parametrize("interface", ["jax", "jax-python"]) class TestCV: """Tests for CV integration""" @@ -1223,7 +1215,6 @@ def circuit(n, a): assert np.allclose(res, expected, atol=tol, rtol=0) -# TODO: Add "jax-jit" @pytest.mark.parametrize("interface", ["jax-python"]) def test_adjoint_reuse_device_state(mocker, interface): """Tests that the jax interface reuses the device state for adjoint differentiation""" @@ -1400,844 +1391,3 @@ def circuit(data, weights, coeffs): # -np.sin(d[1] + w[1]), # ] # assert np.allclose(grad2_w_c, expected, atol=0.1) - - -jit_qubit_device_and_diff_method = [ - ["default.qubit", "backprop", "forward"], - # Jit - ["default.qubit", "finite-diff", "backward"], - ["default.qubit", "parameter-shift", "backward"], - ["default.qubit", "adjoint", "forward"], - ["default.qubit", "adjoint", "backward"], -] - - -# TODO: Add interface for Jax-jit and create a new test file. -@pytest.mark.xfail(reason="Add interface for Jax-jit") -@pytest.mark.parametrize("dev_name,diff_method,mode", jit_qubit_device_and_diff_method) -class TestJIT: - """Test JAX JIT integration with the QNode and automatic resolution of the - correct JAX interface variant.""" - - def test_gradient(self, dev_name, diff_method, mode, tol): - """Test derivative calculation of a scalar valued QNode""" - dev = qml.device(dev_name, wires=1) - - if diff_method == "adjoint": - pytest.xfail(reason="The adjoint method is not using host-callback currently") - - @jax.jit - @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) - def circuit(x): - qml.RY(x[0], wires=0) - qml.RX(x[1], wires=0) - return qml.expval(qml.PauliZ(0)) - - x = jnp.array([1.0, 2.0]) - res = circuit(x) - g = jax.grad(circuit)(x) - - a, b = x - - expected_res = np.cos(a) * np.cos(b) - assert np.allclose(res, expected_res, atol=tol, rtol=0) - - expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] - assert np.allclose(g, expected_g, atol=tol, rtol=0) - - @pytest.mark.filterwarnings( - "ignore:Requested adjoint differentiation to be computed with finite shots." - ) - @pytest.mark.parametrize("shots", [10, 1000]) - def test_hermitian(self, dev_name, diff_method, mode, shots): - """Test that the jax device works with qml.Hermitian and jitting even - when shots>0. - - Note: before a fix, the cases of shots=10 and shots=1000 were failing due - to different reasons, hence the parametrization in the test. - """ - dev = qml.device(dev_name, wires=2, shots=shots) - - if diff_method == "backprop": - pytest.skip("Backpropagation is unsupported if shots > 0.") - - if diff_method == "adjoint" and mode == "forward": - pytest.skip("Computing the gradient for Hermitian is not supported with adjoint.") - - projector = np.array(qml.matrix(qml.PauliZ(0) @ qml.PauliZ(1))) - - @jax.jit - @qml.qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circ(projector): - return qml.expval(qml.Hermitian(projector, wires=range(2))) - - assert jnp.allclose(circ(projector), 1) - - @pytest.mark.filterwarnings( - "ignore:Requested adjoint differentiation to be computed with finite shots." - ) - @pytest.mark.parametrize("shots", [10, 1000]) - def test_probs_obs_none(self, dev_name, diff_method, mode, shots): - """Test that the jax device works with qml.probs, a MeasurementProcess - that has obs=None even when shots>0.""" - dev = qml.device(dev_name, wires=2, shots=shots) - - if diff_method == "backprop": - pytest.skip("Backpropagation is unsupported if shots > 0.") - - @qml.qnode(dev, interface="jax", diff_method="parameter-shift") - def circuit(): - return qml.probs(wires=0) - - assert jnp.allclose(circuit(), jnp.array([1.0, 0.0])) - - @pytest.mark.xfail( - reason="Non-trainable parameters are not being correctly unwrapped by the interface" - ) - def test_gradient_subset(self, dev_name, diff_method, mode, tol): - """Test derivative calculation of a scalar valued QNode with respect - to a subset of arguments""" - a = jnp.array(0.1) - b = jnp.array(0.2) - - dev = qml.device(dev_name, wires=1) - - @jax.jit - @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) - def circuit(a, b): - qml.RY(a, wires=0) - qml.RX(b, wires=0) - qml.RZ(c, wires=0) - return qml.expval(qml.PauliZ(0)) - - res = jax.grad(circuit, argnums=[0, 1])(a, b, 0.0) - - expected_res = np.cos(a) * np.cos(b) - assert np.allclose(res, expected_res, atol=tol, rtol=0) - - expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] - assert np.allclose(g, expected_g, atol=tol, rtol=0) - - def test_gradient_scalar_cost_vector_valued_qnode(self, dev_name, diff_method, mode, tol): - """Test derivative calculation of a scalar valued cost function that - uses the output of a vector-valued QNode""" - dev = qml.device(dev_name, wires=2) - - if diff_method == "adjoint": - pytest.xfail(reason="The adjoint method is not using host-callback currently") - - @jax.jit - @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.probs(wires=[1]) - - def cost(x, y, idx): - res = circuit(x, y) - return res[idx] - - x = jnp.array(1.0) - y = jnp.array(2.0) - expected_g = ( - np.array([-np.sin(x) * np.cos(y) / 2, np.cos(y) * np.sin(x) / 2]), - np.array([-np.cos(x) * np.sin(y) / 2, np.cos(x) * np.sin(y) / 2]), - ) - - idx = 0 - g0 = jax.grad(cost, argnums=0)(x, y, idx) - g1 = jax.grad(cost, argnums=1)(x, y, idx) - assert np.allclose(g0, expected_g[0][idx], atol=tol, rtol=0) - assert np.allclose(g1, expected_g[1][idx], atol=tol, rtol=0) - - idx = 1 - g0 = jax.grad(cost, argnums=0)(x, y, idx) - g1 = jax.grad(cost, argnums=1)(x, y, idx) - - assert np.allclose(g0, expected_g[0][idx], atol=tol, rtol=0) - assert np.allclose(g1, expected_g[1][idx], atol=tol, rtol=0) - - -qubit_device_and_diff_method_and_mode = [ - ["default.qubit", "backprop", "forward"], - ["default.qubit", "finite-diff", "backward"], - ["default.qubit", "parameter-shift", "backward"], - ["default.qubit", "adjoint", "forward"], - ["default.qubit", "adjoint", "backward"], -] - -jacobian_fn = [jax.jacobian, jax.jacrev, jax.jacfwd] - - -@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method_and_mode) -@pytest.mark.parametrize("shots", [None, 10000]) -class TestReturn: - """Class to test the shape of the Grad/Jacobian/Hessian with different return types.""" - - def test_grad_single_measurement_param(self, dev_name, diff_method, mode, shots): - """For one measurement and one param, the gradient is a float.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=1, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a, wires=0) - qml.RX(0.2, wires=0) - return qml.expval(qml.PauliZ(0)) - - a = jax.numpy.array(0.1) - - grad = jax.grad(circuit)(a) - - assert isinstance(grad, jax.numpy.ndarray) - assert grad.shape == () - - def test_grad_single_measurement_multiple_param(self, dev_name, diff_method, mode, shots): - """For one measurement and multiple param, the gradient is a tuple of arrays.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=1, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a, b): - qml.RY(a, wires=0) - qml.RX(b, wires=0) - return qml.expval(qml.PauliZ(0)) - - a = jax.numpy.array(0.1) - b = jax.numpy.array(0.2) - - grad = jax.grad(circuit, argnums=[0, 1])(a, b) - - assert isinstance(grad, tuple) - assert len(grad) == 2 - assert grad[0].shape == () - assert grad[1].shape == () - - def test_grad_single_measurement_multiple_param_array(self, dev_name, diff_method, mode, shots): - """For one measurement and multiple param as a single array params, the gradient is an array.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=1, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - return qml.expval(qml.PauliZ(0)) - - a = jax.numpy.array([0.1, 0.2]) - - grad = jax.grad(circuit)(a) - - assert isinstance(grad, jax.numpy.ndarray) - assert grad.shape == (2,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_single_measurement_param_probs( - self, dev_name, diff_method, mode, jacobian, shots - ): - """For a multi dimensional measurement (probs), check that a single array is returned with the correct - dimension""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a, wires=0) - qml.RX(0.2, wires=0) - return qml.probs(wires=[0, 1]) - - a = jax.numpy.array(0.1) - - jac = jacobian(circuit)(a) - - assert isinstance(jac, jax.numpy.ndarray) - assert jac.shape == (4,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_single_measurement_probs_multiple_param( - self, dev_name, diff_method, mode, jacobian, shots - ): - """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with - the correct dimension""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a, b): - qml.RY(a, wires=0) - qml.RX(b, wires=0) - return qml.probs(wires=[0, 1]) - - a = jax.numpy.array(0.1) - b = jax.numpy.array(0.2) - - jac = jacobian(circuit, argnums=[0, 1])(a, b) - - assert isinstance(jac, tuple) - - assert isinstance(jac[0], jax.numpy.ndarray) - assert jac[0].shape == (4,) - - assert isinstance(jac[1], jax.numpy.ndarray) - assert jac[1].shape == (4,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_single_measurement_probs_multiple_param_single_array( - self, dev_name, diff_method, mode, jacobian, shots - ): - """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with - the correct dimension""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - return qml.probs(wires=[0, 1]) - - a = jax.numpy.array([0.1, 0.2]) - jac = jacobian(circuit)(a) - - assert isinstance(jac, jax.numpy.ndarray) - assert jac.shape == (4, 2) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_expval_expval_multiple_params( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The hessian of multiple measurements with multiple params return a tuple of arrays.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - dev = qml.device(dev_name, wires=2, shots=shots) - - par_0 = jax.numpy.array(0.1) - par_1 = jax.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.expval(qml.PauliZ(0)) - - jac = jacobian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(jac, tuple) - - assert isinstance(jac[0], tuple) - assert len(jac[0]) == 2 - assert isinstance(jac[0][0], jax.numpy.ndarray) - assert jac[0][0].shape == () - assert isinstance(jac[0][1], jax.numpy.ndarray) - assert jac[0][1].shape == () - - assert isinstance(jac[1], tuple) - assert len(jac[1]) == 2 - assert isinstance(jac[1][0], jax.numpy.ndarray) - assert jac[1][0].shape == () - assert isinstance(jac[1][1], jax.numpy.ndarray) - assert jac[1][1].shape == () - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_expval_expval_multiple_params_array( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The jacobian of multiple measurements with a multiple params array return a single array.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.expval(qml.PauliZ(0)) - - a = jax.numpy.array([0.1, 0.2]) - - jac = jacobian(circuit)(a) - - assert isinstance(jac, tuple) - assert len(jac) == 2 # measurements - - assert isinstance(jac[0], jax.numpy.ndarray) - assert jac[0].shape == (2,) - - assert isinstance(jac[1], jax.numpy.ndarray) - assert jac[1].shape == (2,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_var_var_multiple_params(self, dev_name, diff_method, mode, jacobian, shots): - """The hessian of multiple measurements with multiple params return a tuple of arrays.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of var.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - par_0 = jax.numpy.array(0.1) - par_1 = jax.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.var(qml.PauliZ(0)) - - jac = jacobian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(jac, tuple) - assert len(jac) == 2 - - assert isinstance(jac[0], tuple) - assert len(jac[0]) == 2 - assert isinstance(jac[0][0], jax.numpy.ndarray) - assert jac[0][0].shape == () - assert isinstance(jac[0][1], jax.numpy.ndarray) - assert jac[0][1].shape == () - - assert isinstance(jac[1], tuple) - assert len(jac[1]) == 2 - assert isinstance(jac[1][0], jax.numpy.ndarray) - assert jac[1][0].shape == () - assert isinstance(jac[1][1], jax.numpy.ndarray) - assert jac[1][1].shape == () - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_var_var_multiple_params_array( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The jacobian of multiple measurements with a multiple params array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of var.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.var(qml.PauliZ(0)) - - a = jax.numpy.array([0.1, 0.2]) - - jac = jacobian(circuit)(a) - - assert isinstance(jac, tuple) - assert len(jac) == 2 # measurements - - assert isinstance(jac[0], jax.numpy.ndarray) - assert jac[0].shape == (2,) - - assert isinstance(jac[1], jax.numpy.ndarray) - assert jac[1].shape == (2,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_multiple_measurement_single_param( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The jacobian of multiple measurements with a single params return an array.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - dev = qml.device(dev_name, wires=2, shots=shots) - - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a, wires=0) - qml.RX(0.2, wires=0) - return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) - - a = jax.numpy.array(0.1) - - jac = jacobian(circuit)(a) - - assert isinstance(jac, tuple) - assert len(jac) == 2 - - assert isinstance(jac[0], jax.numpy.ndarray) - assert jac[0].shape == () - - assert isinstance(jac[1], jax.numpy.ndarray) - assert jac[1].shape == (4,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_multiple_measurement_multiple_param( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a, b): - qml.RY(a, wires=0) - qml.RX(b, wires=0) - return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) - - a = np.array(0.1, requires_grad=True) - b = np.array(0.2, requires_grad=True) - - jac = jacobian(circuit, argnums=[0, 1])(a, b) - - assert isinstance(jac, tuple) - assert len(jac) == 2 - - assert isinstance(jac[0], tuple) - assert len(jac[0]) == 2 - assert isinstance(jac[0][0], jax.numpy.ndarray) - assert jac[0][0].shape == () - assert isinstance(jac[0][1], jax.numpy.ndarray) - assert jac[0][1].shape == () - - assert isinstance(jac[1], tuple) - assert len(jac[1]) == 2 - assert isinstance(jac[1][0], jax.numpy.ndarray) - assert jac[1][0].shape == (4,) - assert isinstance(jac[1][1], jax.numpy.ndarray) - assert jac[1][1].shape == (4,) - - @pytest.mark.parametrize("jacobian", jacobian_fn) - def test_jacobian_multiple_measurement_multiple_param_array( - self, dev_name, diff_method, mode, jacobian, shots - ): - """The jacobian of multiple measurements with a multiple params array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because of probabilities.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - @qnode(dev, interface="jax", diff_method=diff_method, mode=mode) - def circuit(a): - qml.RY(a[0], wires=0) - qml.RX(a[1], wires=0) - return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1]) - - a = jax.numpy.array([0.1, 0.2]) - - jac = jacobian(circuit)(a) - - assert isinstance(jac, tuple) - assert len(jac) == 2 # measurements - - assert isinstance(jac[0], jax.numpy.ndarray) - assert jac[0].shape == (2,) - - assert isinstance(jac[1], jax.numpy.ndarray) - assert jac[1].shape == (4, 2) - - def test_hessian_expval_multiple_params(self, dev_name, diff_method, mode, shots): - """The hessian of single a measurement with multiple params return a tuple of arrays.""" - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - dev = qml.device(dev_name, wires=2, shots=shots) - - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - - par_0 = jax.numpy.array(0.1) - par_1 = jax.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)) - - hess = jax.hessian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], tuple) - assert len(hess[0]) == 2 - assert isinstance(hess[0][0], jax.numpy.ndarray) - assert hess[0][0].shape == () - assert hess[0][1].shape == () - - assert isinstance(hess[1], tuple) - assert len(hess[1]) == 2 - assert isinstance(hess[1][0], jax.numpy.ndarray) - assert hess[1][0].shape == () - assert hess[1][1].shape == () - - def test_hessian_expval_multiple_param_array(self, dev_name, diff_method, mode, shots): - """The hessian of single measurement with a multiple params array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - params = jax.numpy.array([0.1, 0.2]) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x): - qml.RX(x[0], wires=[0]) - qml.RY(x[1], wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)) - - hess = jax.hessian(circuit)(params) - - assert isinstance(hess, jax.numpy.ndarray) - assert hess.shape == (2, 2) - - def test_hessian_var_multiple_params(self, dev_name, diff_method, mode, shots): - """The hessian of single a measurement with multiple params return a tuple of arrays.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - dev = qml.device(dev_name, wires=2, shots=shots) - - par_0 = jax.numpy.array(0.1) - par_1 = jax.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)) - - hess = jax.hessian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], tuple) - assert len(hess[0]) == 2 - assert isinstance(hess[0][0], jax.numpy.ndarray) - assert hess[0][0].shape == () - assert hess[0][1].shape == () - - assert isinstance(hess[1], tuple) - assert len(hess[1]) == 2 - assert isinstance(hess[1][0], jax.numpy.ndarray) - assert hess[1][0].shape == () - assert hess[1][1].shape == () - - def test_hessian_var_multiple_param_array(self, dev_name, diff_method, mode, shots): - """The hessian of single measurement with a multiple params array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - params = jax.numpy.array([0.1, 0.2]) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x): - qml.RX(x[0], wires=[0]) - qml.RY(x[1], wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)) - - hess = jax.hessian(circuit)(params) - - assert isinstance(hess, jax.numpy.ndarray) - assert hess.shape == (2, 2) - - def test_hessian_probs_expval_multiple_params(self, dev_name, diff_method, mode, shots): - """The hessian of multiple measurements with multiple params return a tuple of arrays.""" - dev = qml.device(dev_name, wires=2, shots=shots) - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - par_0 = jax.numpy.array(0.1) - par_1 = jax.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.probs(wires=[0, 1]) - - hess = jax.hessian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], tuple) - assert len(hess[0]) == 2 - assert isinstance(hess[0][0], tuple) - assert len(hess[0][0]) == 2 - assert isinstance(hess[0][0][0], jax.numpy.ndarray) - assert hess[0][0][0].shape == () - assert isinstance(hess[0][0][1], jax.numpy.ndarray) - assert hess[0][0][1].shape == () - assert isinstance(hess[0][1], tuple) - assert len(hess[0][1]) == 2 - assert isinstance(hess[0][1][0], jax.numpy.ndarray) - assert hess[0][1][0].shape == () - assert isinstance(hess[0][1][1], jax.numpy.ndarray) - assert hess[0][1][1].shape == () - - assert isinstance(hess[1], tuple) - assert len(hess[1]) == 2 - assert isinstance(hess[1][0], tuple) - assert len(hess[1][0]) == 2 - assert isinstance(hess[1][0][0], jax.numpy.ndarray) - assert hess[1][0][0].shape == (4,) - assert isinstance(hess[1][0][1], jax.numpy.ndarray) - assert hess[1][0][1].shape == (4,) - assert isinstance(hess[1][1], tuple) - assert len(hess[1][1]) == 2 - assert isinstance(hess[1][1][0], jax.numpy.ndarray) - assert hess[1][1][0].shape == (4,) - assert isinstance(hess[1][1][1], jax.numpy.ndarray) - assert hess[1][1][1].shape == (4,) - - def test_hessian_expval_probs_multiple_param_array(self, dev_name, diff_method, mode, shots): - """The hessian of multiple measurements with a multiple param array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - params = jax.numpy.array([0.1, 0.2]) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x): - qml.RX(x[0], wires=[0]) - qml.RY(x[1], wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0) @ qml.PauliX(1)), qml.probs(wires=[0, 1]) - - hess = jax.hessian(circuit)(params) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], jax.numpy.ndarray) - assert hess[0].shape == (2, 2) - - assert isinstance(hess[1], jax.numpy.ndarray) - assert hess[1].shape == (4, 2, 2) - - def test_hessian_probs_var_multiple_params(self, dev_name, diff_method, mode, shots): - """The hessian of multiple measurements with multiple params return a tuple of arrays.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - par_0 = qml.numpy.array(0.1) - par_1 = qml.numpy.array(0.2) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x, y): - qml.RX(x, wires=[0]) - qml.RY(y, wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.probs(wires=[0, 1]) - - hess = jax.hessian(circuit, argnums=[0, 1])(par_0, par_1) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], tuple) - assert len(hess[0]) == 2 - assert isinstance(hess[0][0], tuple) - assert len(hess[0][0]) == 2 - assert isinstance(hess[0][0][0], jax.numpy.ndarray) - assert hess[0][0][0].shape == () - assert isinstance(hess[0][0][1], jax.numpy.ndarray) - assert hess[0][0][1].shape == () - assert isinstance(hess[0][1], tuple) - assert len(hess[0][1]) == 2 - assert isinstance(hess[0][1][0], jax.numpy.ndarray) - assert hess[0][1][0].shape == () - assert isinstance(hess[0][1][1], jax.numpy.ndarray) - assert hess[0][1][1].shape == () - - assert isinstance(hess[1], tuple) - assert len(hess[1]) == 2 - assert isinstance(hess[1][0], tuple) - assert len(hess[1][0]) == 2 - assert isinstance(hess[1][0][0], jax.numpy.ndarray) - assert hess[1][0][0].shape == (4,) - assert isinstance(hess[1][0][1], jax.numpy.ndarray) - assert hess[1][0][1].shape == (4,) - assert isinstance(hess[1][1], tuple) - assert len(hess[1][1]) == 2 - assert isinstance(hess[1][1][0], jax.numpy.ndarray) - assert hess[1][1][0].shape == (4,) - assert isinstance(hess[1][1][1], jax.numpy.ndarray) - assert hess[1][1][1].shape == (4,) - - def test_hessian_var_probs_multiple_param_array(self, dev_name, diff_method, mode, shots): - """The hessian of multiple measurements with a multiple param array return a single array.""" - if diff_method == "adjoint": - pytest.skip("Test does not supports adjoint because second order diff.") - if shots is not None and diff_method in ("backprop", "adjoint"): - pytest.skip("Test does not support finite shots and adjoint/backprop") - - dev = qml.device(dev_name, wires=2, shots=shots) - - params = jax.numpy.array([0.1, 0.2]) - - @qnode(dev, interface="jax", diff_method=diff_method, max_diff=2, mode=mode) - def circuit(x): - qml.RX(x[0], wires=[0]) - qml.RY(x[1], wires=[1]) - qml.CNOT(wires=[0, 1]) - return qml.var(qml.PauliZ(0) @ qml.PauliX(1)), qml.probs(wires=[0, 1]) - - hess = jax.hessian(circuit)(params) - - assert isinstance(hess, tuple) - assert len(hess) == 2 - - assert isinstance(hess[0], jax.numpy.ndarray) - assert hess[0].shape == (2, 2) - - assert isinstance(hess[1], jax.numpy.ndarray) - assert hess[1].shape == (4, 2, 2)