From 43b1fe05241816fab9efbc1e61c5c83aac809be4 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Fri, 24 Sep 2021 17:13:59 +0800 Subject: [PATCH 1/9] Add JAX integration tests --- pennylane/transforms/batch_transform.py | 7 +- tests/gradients/test_finite_difference.py | 8 +- tests/gradients/test_parameter_shift.py | 8 +- tests/interfaces/test_batch_jax_qnode.py | 1105 +++++++++++++++++++++ tests/transforms/test_batch_transform.py | 5 +- 5 files changed, 1120 insertions(+), 13 deletions(-) create mode 100644 tests/interfaces/test_batch_jax_qnode.py diff --git a/pennylane/transforms/batch_transform.py b/pennylane/transforms/batch_transform.py index 731ab540486..adf1a1cbfe8 100644 --- a/pennylane/transforms/batch_transform.py +++ b/pennylane/transforms/batch_transform.py @@ -221,6 +221,7 @@ def default_qnode_wrapper(self, qnode, targs, tkwargs): the QNode, and return the output of the applying the tape transform to the QNode's constructed tape. """ + max_diff = tkwargs.pop("max_diff", 2) def _wrapper(*args, **kwargs): qnode.construct(args, kwargs) @@ -242,7 +243,11 @@ def _wrapper(*args, **kwargs): gradient_fn = qml.gradients.finite_diff res = qml.execute( - tapes, device=qnode.device, gradient_fn=gradient_fn, interface=interface, max_diff=2 + tapes, + device=qnode.device, + gradient_fn=gradient_fn, + interface=interface, + max_diff=max_diff, ) return processing_fn(res) diff --git a/tests/gradients/test_finite_difference.py b/tests/gradients/test_finite_difference.py index 76ee2d861f3..cb4678d127b 100644 --- a/tests/gradients/test_finite_difference.py +++ b/tests/gradients/test_finite_difference.py @@ -568,7 +568,7 @@ def test_torch(self, approx_order, strategy, tol): torch = pytest.importorskip("torch") from pennylane.interfaces.torch import TorchInterface - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit.torch", wires=2) params = torch.tensor([0.543, -0.654], dtype=torch.float64, requires_grad=True) with TorchInterface.apply(qml.tape.QubitParamShiftTape()) as tape: @@ -578,7 +578,7 @@ def test_torch(self, approx_order, strategy, tol): qml.expval(qml.PauliZ(0) @ qml.PauliX(1)) tapes, fn = finite_diff(tape, n=1, approx_order=approx_order, strategy=strategy) - jac = fn([t.execute(dev) for t in tapes]) + jac = fn(dev.batch_execute(tapes)) cost = torch.sum(jac) cost.backward() hess = params.grad @@ -606,7 +606,7 @@ def test_jax(self, approx_order, strategy, tol): config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit.jax", wires=2) params = jnp.array([0.543, -0.654]) def cost_fn(x): @@ -618,7 +618,7 @@ def cost_fn(x): tape.trainable_params = {0, 1} tapes, fn = finite_diff(tape, n=1, approx_order=approx_order, strategy=strategy) - jac = fn([t.execute(dev) for t in tapes]) + jac = fn(dev.batch_execute(tapes)) return jac res = jax.jacobian(cost_fn)(params) diff --git a/tests/gradients/test_parameter_shift.py b/tests/gradients/test_parameter_shift.py index d59d9e17420..25a17dddb74 100644 --- a/tests/gradients/test_parameter_shift.py +++ b/tests/gradients/test_parameter_shift.py @@ -858,7 +858,7 @@ def test_torch(self, tol): torch = pytest.importorskip("torch") from pennylane.interfaces.torch import TorchInterface - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit.torch", wires=2) params = torch.tensor([0.543, -0.654], dtype=torch.float64, requires_grad=True) with TorchInterface.apply(qml.tape.QubitParamShiftTape()) as tape: @@ -868,7 +868,7 @@ def test_torch(self, tol): 
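             # measuring a variance of a tensor-product observable keeps the
             # parameter-shift Jacobian a differentiable torch tensor, so the
             # cost.backward() call below recovers a row of the Hessian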
qml.var(qml.PauliZ(0) @ qml.PauliX(1)) tapes, fn = qml.gradients.param_shift(tape) - jac = fn([t.execute(dev) for t in tapes]) + jac = fn(dev.batch_execute(tapes)) cost = jac[0, 0] cost.backward() hess = params.grad @@ -891,7 +891,7 @@ def test_jax(self, tol): config.update("jax_enable_x64", True) - dev = qml.device("default.qubit", wires=2) + dev = qml.device("default.qubit.jax", wires=2) params = jnp.array([0.543, -0.654]) def cost_fn(x): @@ -903,7 +903,7 @@ def cost_fn(x): tape.trainable_params = {0, 1} tapes, fn = qml.gradients.param_shift(tape) - jac = fn([t.execute(dev) for t in tapes]) + jac = fn(dev.batch_execute(tapes)) return jac res = jax.jacobian(cost_fn)(params) diff --git a/tests/interfaces/test_batch_jax_qnode.py b/tests/interfaces/test_batch_jax_qnode.py new file mode 100644 index 00000000000..2dbdaa5287c --- /dev/null +++ b/tests/interfaces/test_batch_jax_qnode.py @@ -0,0 +1,1105 @@ +# Copyright 2018-2020 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for using the jax interface with a QNode""" +import pytest +from pennylane import numpy as np + +import pennylane as qml +from pennylane.beta import qnode, QNode +from pennylane.tape import JacobianTape + +qubit_device_and_diff_method = [ + ["default.qubit", "finite-diff", "backward"], + ["default.qubit", "parameter-shift", "backward"], + ["default.qubit", "backprop", "forward"], + ["default.qubit", "adjoint", "forward"], + ["default.qubit", "adjoint", "backward"], +] + +jax = pytest.importorskip("jax") +jnp = jax.numpy + + +from jax.config import config + +config.update("jax_enable_x64", True) + + +@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) +class TestQNode: + """Test that using the QNode with Autograd integrates with the PennyLane stack""" + + def test_execution_with_interface(self, dev_name, diff_method, mode): + """Test execution works with the interface""" + if diff_method == "backprop": + pytest.skip("Test does not support backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, interface="jax", diff_method=diff_method) + def circuit(a): + qml.RY(a, wires=0) + qml.RX(0.2, wires=0) + return qml.expval(qml.PauliZ(0)) + + a = np.array(0.1, requires_grad=True) + circuit(a) + + assert circuit.interface == "jax" + + # the tape is able to deduce trainable parameters + assert circuit.qtape.trainable_params == {0} + + # gradients should work + grad = jax.grad(circuit)(a) + assert isinstance(grad, jnp.DeviceArray) + assert grad.shape == tuple() + + def test_jacobian(self, dev_name, diff_method, mode, mocker, tol): + """Test jacobian calculation""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + if diff_method == "parameter-shift": + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + elif diff_method == "finite-diff": + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = np.array(0.1, requires_grad=True) + b = np.array(0.2, requires_grad=True) + + 
dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliY(1))] + + res = circuit(a, b) + + assert circuit.qtape.trainable_params == {0, 1} + assert res.shape == (2,) + + expected = [np.cos(a), -np.cos(a) * np.sin(b)] + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.jacobian(circuit, argnums=[0, 1])(a, b) + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]).T + assert np.allclose(res, expected, atol=tol, rtol=0) + + if diff_method in ("parameter-shift", "finite-diff"): + spy.assert_called() + + def test_jacobian_no_evaluate(self, dev_name, diff_method, mode, mocker, tol): + """Test jacobian calculation when no prior circuit evaluation has been performed""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + if diff_method == "parameter-shift": + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + elif diff_method == "finite-diff": + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = np.array(0.1, requires_grad=True) + b = np.array(0.2, requires_grad=True) + + dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliY(1))] + + jac_fn = jax.jacobian(circuit, argnums=[0, 1]) + res = jac_fn(a, b) + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]).T + assert np.allclose(res, expected, atol=tol, rtol=0) + + if diff_method in ("parameter-shift", "finite-diff"): + spy.assert_called() + + # call the Jacobian with new parameters + a = np.array(0.6, requires_grad=True) + b = np.array(0.832, requires_grad=True) + + res = jac_fn(a, b) + expected = np.array([[-np.sin(a), 0], [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)]]).T + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_jacobian_options(self, dev_name, diff_method, mode, mocker, tol): + """Test setting jacobian options""" + if diff_method == "backprop": + pytest.skip("Test does not support backprop") + + spy = mocker.spy(qml.gradients.finite_diff, "transform_fn") + + a = np.array([0.1, 0.2], requires_grad=True) + + dev = qml.device("default.qubit", wires=1) + + @qnode(dev, interface="jax", h=1e-8, order=2) + def circuit(a): + qml.RY(a[0], wires=0) + qml.RX(a[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + jax.jacobian(circuit)(a) + + for args in spy.call_args_list: + assert args[1]["order"] == 2 + assert args[1]["h"] == 1e-8 + + def test_changing_trainability(self, dev_name, diff_method, mode, mocker, tol): + """Test changing the trainability of parameters changes the + number of differentiation requests made""" + if diff_method != "parameter-shift": + pytest.skip("Test only supports parameter-shift") + + a = jnp.array(0.1) + b = jnp.array(0.2) + + dev = qml.device("default.qubit", wires=2) + + @qnode(dev, interface="jax", diff_method="parameter-shift") + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.Hamiltonian([1, 1], [qml.PauliZ(0), qml.PauliY(1)])) + + grad_fn = jax.grad(circuit, argnums=[0, 1]) + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + res = grad_fn(a, b) + + # the tape has reported both 
arguments as trainable + assert circuit.qtape.trainable_params == {0, 1} + + expected = [-np.sin(a) + np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)] + assert np.allclose(res, expected, atol=tol, rtol=0) + + # The parameter-shift rule has been called for each argument + assert len(spy.spy_return[0]) == 4 + + # make the second QNode argument a constant + grad_fn = jax.grad(circuit, argnums=0) + res = grad_fn(a, b) + + # the tape has reported only the first argument as trainable + assert circuit.qtape.trainable_params == {0} + + expected = [-np.sin(a) + np.sin(a) * np.sin(b)] + assert np.allclose(res, expected, atol=tol, rtol=0) + + # The parameter-shift rule has been called only once + assert len(spy.spy_return[0]) == 2 + + # trainability also updates on evaluation + a = np.array(0.54, requires_grad=False) + b = np.array(0.8, requires_grad=True) + circuit(a, b) + assert circuit.qtape.trainable_params == {1} + + def test_classical_processing(self, dev_name, diff_method, mode, tol): + """Test classical processing within the quantum tape""" + a = jnp.array(0.1) + b = jnp.array(0.2) + c = jnp.array(0.3) + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, b, c): + qml.RY(a * c, wires=0) + qml.RZ(b, wires=0) + qml.RX(c + c ** 2 + jnp.sin(a), wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.grad(circuit, argnums=[0, 2])(a, b, c) + + if diff_method == "finite-diff": + assert circuit.qtape.trainable_params == {0, 2} + + assert len(res) == 2 + + def test_matrix_parameter(self, dev_name, diff_method, mode, tol): + """Test that the jax interface works correctly + with a matrix parameter""" + U = jnp.array([[0, 1], [1, 0]]) + a = jnp.array(0.1) + + dev = qml.device(dev_name, wires=2) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(U, a): + qml.QubitUnitary(U, wires=0) + qml.RY(a, wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.grad(circuit, argnums=1)(U, a) + assert np.allclose(res, np.sin(a), atol=tol, rtol=0) + + if diff_method == "finite-diff": + assert circuit.qtape.trainable_params == {1} + + def test_differentiable_expand(self, dev_name, diff_method, mode, tol): + """Test that operation and nested tape expansion + is differentiable""" + + class U3(qml.U3): + def expand(self): + theta, phi, lam = self.data + wires = self.wires + + with JacobianTape() as tape: + qml.Rot(lam, theta, -lam, wires=wires) + qml.PhaseShift(phi + lam, wires=wires) + + return tape + + dev = qml.device(dev_name, wires=1) + a = jnp.array(0.1) + p = jnp.array([0.1, 0.2, 0.3]) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, p): + qml.RX(a, wires=0) + U3(p[0], p[1], p[2], wires=0) + return qml.expval(qml.PauliX(0)) + + res = circuit(a, p) + expected = np.cos(a) * np.cos(p[1]) * np.sin(p[0]) + np.sin(a) * ( + np.cos(p[2]) * np.sin(p[1]) + np.cos(p[0]) * np.cos(p[1]) * np.sin(p[2]) + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.grad(circuit, argnums=1)(a, p) + expected = np.array( + [ + np.cos(p[1]) * (np.cos(a) * np.cos(p[0]) - np.sin(a) * np.sin(p[0]) * np.sin(p[2])), + np.cos(p[1]) * np.cos(p[2]) * np.sin(a) + - np.sin(p[1]) + * (np.cos(a) * np.sin(p[0]) + np.cos(p[0]) * np.sin(a) * np.sin(p[2])), + np.sin(a) + * (np.cos(p[0]) * np.cos(p[1]) * np.cos(p[2]) - np.sin(p[1]) * np.sin(p[2])), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + +class TestShotsIntegration: + """Test that the QNode correctly changes shot value, and + 
remains differentiable.""" + + def test_changing_shots(self, mocker, tol): + """Test that changing shots works on execution""" + dev = qml.device("default.qubit", wires=2, shots=None) + a, b = jnp.array([0.543, -0.654]) + + @qnode(dev, diff_method=qml.gradients.param_shift, interface="jax") + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + spy = mocker.spy(dev, "sample") + + # execute with device default shots (None) + res = circuit(a, b) + assert np.allclose(res, -np.cos(a) * np.sin(b), atol=tol, rtol=0) + spy.assert_not_called() + + # execute with shots=100 + res = circuit(a, b, shots=100) + spy.assert_called() + assert spy.spy_return.shape == (100,) + + # device state has been unaffected + assert dev.shots is None + spy = mocker.spy(dev, "sample") + res = circuit(a, b) + assert np.allclose(res, -np.cos(a) * np.sin(b), atol=tol, rtol=0) + spy.assert_not_called() + + def test_gradient_integration(self, tol): + """Test that temporarily setting the shots works + for gradient computations""" + dev = qml.device("default.qubit", wires=2, shots=100) + a, b = jnp.array([0.543, -0.654]) + + @qnode(dev, diff_method=qml.gradients.param_shift, interface="jax") + def cost_fn(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + res = jax.grad(cost_fn, argnums=[0, 1])(a, b, shots=30000) + assert dev.shots == 100 + + expected = [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)] + assert np.allclose(res, expected, atol=0.1, rtol=0) + + def test_update_diff_method(self, mocker, tol): + """Test that temporarily setting the shots updates the diff method""" + dev = qml.device("default.qubit", wires=2, shots=100) + a, b = jnp.array([0.543, -0.654]) + + spy = mocker.spy(qml, "execute") + + @qnode(dev, interface="jax") + def cost_fn(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliY(1)) + + # since we are using finite shots, parameter-shift will + # be chosen + assert cost_fn.gradient_fn is qml.gradients.param_shift + + cost_fn(a, b) + assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift + + # if we set the shots to None, backprop can now be used + cost_fn(a, b, shots=None) + assert spy.call_args[1]["gradient_fn"] == "backprop" + + # original QNode settings are unaffected + assert cost_fn.gradient_fn is qml.gradients.param_shift + cost_fn(a, b) + assert spy.call_args[1]["gradient_fn"] is qml.gradients.param_shift + + +@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) +class TestQubitIntegration: + """Tests that ensure various qubit circuits integrate correctly""" + + def test_probability_differentiation(self, dev_name, diff_method, mode, tol): + """Tests correct output shape and evaluation for a tape + with a single prob output""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.probs(wires=[1]) + + res = jax.jacobian(circuit, argnums=[0, 1])(x, y) + + expected = np.array( + [ + [-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2], + [np.cos(y) * np.sin(x) / 2, np.cos(x) * np.sin(y) / 2], + ] + ) + assert np.allclose(res, expected.T, atol=tol, rtol=0) + + def 
test_multiple_probability_differentiation(self, dev_name, diff_method, mode, tol): + """Tests correct output shape and evaluation for a tape + with multiple prob outputs""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.probs(wires=[0]), qml.probs(wires=[1]) + + res = circuit(x, y) + + expected = np.array( + [ + [np.cos(x / 2) ** 2, np.sin(x / 2) ** 2], + [(1 + np.cos(x) * np.cos(y)) / 2, (1 - np.cos(x) * np.cos(y)) / 2], + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.jacobian(circuit, argnums=[0, 1])(x, y) + expected = np.array( + [ + [[-np.sin(x) / 2, 0], [-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2]], + [ + [np.sin(x) / 2, 0], + [np.cos(y) * np.sin(x) / 2, np.cos(x) * np.sin(y) / 2], + ], + ] + ) + + assert np.allclose(res, expected.T, atol=tol, rtol=0) + + @pytest.mark.xfail(reason="Line 230 in QubitDevice: results = self._asarray(results) fails") + def test_ragged_differentiation(self, dev_name, diff_method, mode, tol): + """Tests correct output shape and evaluation for a tape + with prob and expval outputs""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return [qml.expval(qml.PauliZ(0)), qml.probs(wires=[1])] + + res = circuit(x, y) + + expected = np.array( + [np.cos(x), (1 + np.cos(x) * np.cos(y)) / 2, (1 - np.cos(x) * np.cos(y)) / 2] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.jacobian(circuit, argnums=[0, 1])(x, y) + expected = np.array( + [ + [-np.sin(x), 0], + [-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2], + [np.cos(y) * np.sin(x) / 2, np.cos(x) * np.sin(y) / 2], + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + @pytest.mark.xfail(reason="Line 230 in QubitDevice: results = self._asarray(results) fails") + def test_ragged_differentiation_variance(self, dev_name, diff_method, mode, tol): + """Tests correct output shape and evaluation for a tape + with prob and variance outputs""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + dev = qml.device(dev_name, wires=2) + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return [qml.var(qml.PauliZ(0)), qml.probs(wires=[1])] + + res = circuit(x, y) + + expected = np.array( + [np.sin(x) ** 2, (1 + np.cos(x) * np.cos(y)) / 2, (1 - np.cos(x) * np.cos(y)) / 2] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.jacobian(circuit, argnums=[0, 1])(x, y) + expected = np.array( + [ + [2 * np.cos(x) * np.sin(x), 0], + [-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2], + [np.cos(y) * np.sin(x) / 2, np.cos(x) * np.sin(y) / 2], + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_sampling(self, dev_name, diff_method, mode): + """Test sampling works as expected""" + if diff_method != 
"backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + if mode == "forward": + pytest.skip("Sampling not possible with forward mode differentiation.") + + dev = qml.device(dev_name, wires=2, shots=10) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(): + qml.Hadamard(wires=[0]) + qml.CNOT(wires=[0, 1]) + return [qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliX(1))] + + res = circuit() + + assert res.shape == (2, 10) + assert isinstance(res, jnp.DeviceArray) + + def test_chained_qnodes(self, dev_name, diff_method, mode): + """Test that the gradient of chained QNodes works without error""" + dev = qml.device(dev_name, wires=2) + + class Template(qml.templates.StronglyEntanglingLayers): + def expand(self): + with qml.tape.QuantumTape() as tape: + qml.templates.StronglyEntanglingLayers(*self.parameters, self.wires) + return tape + + @qnode(dev, interface="jax", diff_method=diff_method) + def circuit1(weights): + Template(weights, wires=[0, 1]) + return qml.expval(qml.PauliZ(0)) + + @qnode(dev, interface="jax", diff_method=diff_method) + def circuit2(data, weights): + qml.templates.AngleEmbedding(jnp.stack([data, 0.7]), wires=[0, 1]) + Template(weights, wires=[0, 1]) + return qml.expval(qml.PauliX(0)) + + def cost(weights): + w1, w2 = weights + c1 = circuit1(w1) + c2 = circuit2(c1, w2) + return jnp.sum(c2) ** 2 + + w1 = qml.templates.StronglyEntanglingLayers.shape(n_wires=2, n_layers=3) + w2 = qml.templates.StronglyEntanglingLayers.shape(n_wires=2, n_layers=4) + + weights = [ + jnp.array(np.random.random(w1)), + jnp.array(np.random.random(w2)), + ] + + grad_fn = jax.grad(cost) + res = grad_fn(weights) + + assert len(res) == 2 + + def test_second_derivative(self, dev_name, diff_method, mode, tol): + """Test second derivative calculation of a scalar valued QNode""" + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + g = jax.grad(circuit)(x) + g2 = jax.grad(lambda x: jnp.sum(jax.grad(circuit)(x)))(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + expected_g2 = [ + -np.cos(a) * np.cos(b) + np.sin(a) * np.sin(b), + np.sin(a) * np.sin(b) - np.cos(a) * np.cos(b), + ] + assert np.allclose(g2, expected_g2, atol=tol, rtol=0) + + def test_hessian(self, dev_name, diff_method, mode, tol): + """Test hessian calculation of a scalar valued QNode""" + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + grad_fn = jax.grad(circuit) + g = grad_fn(x) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + hess = jax.jacobian(grad_fn)(x) + + expected_hess = [ + [-np.cos(a) * np.cos(b), np.sin(a) 
* np.sin(b)], + [np.sin(a) * np.sin(b), -np.cos(a) * np.cos(b)], + ] + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued(self, dev_name, diff_method, mode, tol): + """Test hessian calculation of a vector valued QNode""" + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.probs(wires=0) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + + a, b = x + + expected_res = [0.5 + 0.5 * np.cos(a) * np.cos(b), 0.5 - 0.5 * np.cos(a) * np.cos(b)] + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + jac_fn = jax.jacobian(circuit) + g = jac_fn(x) + + expected_g = [ + [-0.5 * np.sin(a) * np.cos(b), -0.5 * np.cos(a) * np.sin(b)], + [0.5 * np.sin(a) * np.cos(b), 0.5 * np.cos(a) * np.sin(b)], + ] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + hess = jax.jacobian(jac_fn)(x) + + expected_hess = [ + [ + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.sin(a) * np.sin(b)], + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.cos(a) * np.cos(b)], + ], + [ + [0.5 * np.cos(a) * np.cos(b), -0.5 * np.sin(a) * np.sin(b)], + [-0.5 * np.sin(a) * np.sin(b), 0.5 * np.cos(a) * np.cos(b)], + ], + ] + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued_postprocessing(self, dev_name, diff_method, mode, tol): + """Test hessian calculation of a vector valued QNode with post-processing""" + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(x): + qml.RX(x[0], wires=0) + qml.RY(x[1], wires=0) + return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(0))] + + def cost_fn(x): + return x @ circuit(x) + + x = jnp.array( + [0.76, -0.87], + ) + res = cost_fn(x) + + a, b = x + + expected_res = x @ jnp.array([np.cos(a) * np.cos(b), np.cos(a) * np.cos(b)]) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + grad_fn = jax.grad(cost_fn) + g = grad_fn(x) + + expected_g = [ + np.cos(b) * (np.cos(a) - (a + b) * np.sin(a)), + np.cos(a) * (np.cos(b) - (a + b) * np.sin(b)), + ] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + hess = jax.jacobian(grad_fn)(x) + + expected_hess = [ + [ + -(np.cos(b) * ((a + b) * np.cos(a) + 2 * np.sin(a))), + -(np.cos(b) * np.sin(a)) + (-np.cos(a) + (a + b) * np.sin(a)) * np.sin(b), + ], + [ + -(np.cos(b) * np.sin(a)) + (-np.cos(a) + (a + b) * np.sin(a)) * np.sin(b), + -(np.cos(a) * ((a + b) * np.cos(b) + 2 * np.sin(b))), + ], + ] + + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_hessian_vector_valued_separate_args(self, dev_name, diff_method, mode, mocker, tol): + """Test hessian calculation of a vector valued QNode that has separate input arguments""" + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + dev = qml.device(dev_name, wires=1) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode, max_diff=2) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + return qml.probs(wires=0) + + a = jnp.array(1.0) + b = jnp.array(2.0) + res = circuit(a, b) + + expected_res = [0.5 + 0.5 * np.cos(a) * np.cos(b), 0.5 - 0.5 * np.cos(a) * np.cos(b)] + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + jac_fn = jax.jacobian(circuit, argnums=[0, 1]) 
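+        # with a list of argnums, jax.jacobian returns one Jacobian per
+        # positional argument; differentiating jac_fn again below assembles
+        # the Hessian blocks compared against expected_hess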
+ g = jac_fn(a, b) + + expected_g = np.array( + [ + [-0.5 * np.sin(a) * np.cos(b), -0.5 * np.cos(a) * np.sin(b)], + [0.5 * np.sin(a) * np.cos(b), 0.5 * np.cos(a) * np.sin(b)], + ] + ) + assert np.allclose(g, expected_g.T, atol=tol, rtol=0) + + spy = mocker.spy(qml.gradients.param_shift, "transform_fn") + hess = jax.jacobian(jac_fn, argnums=[0, 1])(a, b) + + if diff_method == "backprop": + spy.assert_not_called() + elif diff_method == "parameter-shift": + spy.assert_called() + + expected_hess = np.array( + [ + [ + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.cos(a) * np.cos(b)], + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.sin(a) * np.sin(b)], + ], + [ + [0.5 * np.sin(a) * np.sin(b), -0.5 * np.sin(a) * np.sin(b)], + [-0.5 * np.cos(a) * np.cos(b), 0.5 * np.cos(a) * np.cos(b)], + ], + ] + ) + assert np.allclose(hess, expected_hess, atol=tol, rtol=0) + + def test_state(self, dev_name, diff_method, mode, tol): + """Test that the state can be returned and differentiated""" + if diff_method != "backprop": + pytest.skip("JAX interface does not support vector-valued QNodes") + + if diff_method == "adjoint": + pytest.skip("Adjoint does not support states") + + dev = qml.device(dev_name, wires=2) + + x = jnp.array(0.543) + y = jnp.array(-0.654) + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=[0]) + qml.RY(y, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.state() + + def cost_fn(x, y): + res = circuit(x, y) + assert res.dtype is np.dtype("complex128") + probs = jnp.abs(res) ** 2 + return probs[0] + probs[2] + + res = cost_fn(x, y) + + if diff_method not in {"backprop"}: + pytest.skip("Test only supports backprop") + + res = jax.grad(cost_fn, argnums=[0, 1])(x, y) + expected = np.array([-np.sin(x) * np.cos(y) / 2, -np.cos(x) * np.sin(y) / 2]) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_projector(self, dev_name, diff_method, mode, tol): + """Test that the variance of a projector is correctly returned""" + if diff_method == "adjoint": + pytest.skip("Adjoint does not support projectors") + + dev = qml.device(dev_name, wires=2) + P = jnp.array([1]) + x, y = 0.765, -0.654 + + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x, y): + qml.RX(x, wires=0) + qml.RY(y, wires=1) + qml.CNOT(wires=[0, 1]) + return qml.var(qml.Projector(P, wires=0) @ qml.PauliX(1)) + + res = circuit(x, y) + expected = 0.25 * np.sin(x / 2) ** 2 * (3 + np.cos(2 * y) + 2 * np.cos(x) * np.sin(y) ** 2) + assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.grad(circuit, argnums=[0, 1])(x, y) + expected = np.array( + [ + 0.5 * np.sin(x) * (np.cos(x / 2) ** 2 + np.cos(2 * y) * np.sin(x / 2) ** 2), + -2 * np.cos(y) * np.sin(x / 2) ** 4 * np.sin(y), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + +@pytest.mark.parametrize( + "diff_method,kwargs", + [["finite-diff", {}], ("parameter-shift", {}), ("parameter-shift", {"force_order2": True})], +) +class TestCV: + """Tests for CV integration""" + + def test_first_order_observable(self, diff_method, kwargs, tol): + """Test variance of a first order CV observable""" + dev = qml.device("default.gaussian", wires=1) + + r = 0.543 + phi = -0.654 + + @qnode(dev, interface="jax", diff_method=diff_method, **kwargs) + def circuit(r, phi): + qml.Squeezing(r, 0, wires=0) + qml.Rotation(phi, wires=0) + return qml.var(qml.X(0)) + + res = circuit(r, phi) + expected = np.exp(2 * r) * np.sin(phi) ** 2 + np.exp(-2 * r) * np.cos(phi) ** 2 + assert np.allclose(res, expected, 
atol=tol, rtol=0) + + # circuit jacobians + res = jax.grad(circuit, argnums=[0, 1])(r, phi) + expected = np.array( + [ + 2 * np.exp(2 * r) * np.sin(phi) ** 2 - 2 * np.exp(-2 * r) * np.cos(phi) ** 2, + 2 * np.sinh(2 * r) * np.sin(2 * phi), + ] + ) + assert np.allclose(res, expected, atol=tol, rtol=0) + + def test_second_order_observable(self, diff_method, kwargs, tol): + """Test variance of a second order CV expectation value""" + dev = qml.device("default.gaussian", wires=1) + + n = 0.12 + a = 0.765 + + @qnode(dev, interface="jax", diff_method=diff_method, **kwargs) + def circuit(n, a): + qml.ThermalState(n, wires=0) + qml.Displacement(a, 0, wires=0) + return qml.var(qml.NumberOperator(0)) + + res = circuit(n, a) + expected = n ** 2 + n + np.abs(a) ** 2 * (1 + 2 * n) + assert np.allclose(res, expected, atol=tol, rtol=0) + + # circuit jacobians + res = jax.grad(circuit, argnums=[0, 1])(n, a) + expected = np.array([2 * a ** 2 + 2 * n + 1, 2 * a * (2 * n + 1)]) + assert np.allclose(res, expected, atol=tol, rtol=0) + + +def test_adjoint_reuse_device_state(mocker): + """Tests that the jax interface reuses the device state for adjoint differentiation""" + dev = qml.device("default.qubit", wires=1) + + @qnode(dev, interface="jax", diff_method="adjoint") + def circ(x): + qml.RX(x, wires=0) + return qml.expval(qml.PauliZ(0)) + + spy = mocker.spy(dev, "adjoint_jacobian") + + grad = jax.grad(circ)(1.0) + assert circ.device.num_executions == 1 + + spy.assert_called_with(mocker.ANY, use_device_state=True) + + +@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) +class TestTapeExpansion: + """Test that tape expansion within the QNode integrates correctly + with the Autograd interface""" + + @pytest.mark.parametrize("max_diff", [1, 2]) + def test_gradient_expansion_trainable_only(self, dev_name, diff_method, mode, max_diff, mocker): + """Test that a *supported* operation with no gradient recipe is only + expanded for parameter-shift and finite-differences when it is trainable.""" + if diff_method not in ("parameter-shift", "finite-diff"): + pytest.skip("Only supports gradient transforms") + + if max_diff > 1: + pytest.skip("JAX only supports first derivatives") + + dev = qml.device(dev_name, wires=1) + + class PhaseShift(qml.PhaseShift): + grad_method = None + + def expand(self): + with qml.tape.QuantumTape() as tape: + qml.RY(3 * self.data[0], wires=self.wires) + return tape + + @qnode(dev, diff_method=diff_method, mode=mode, max_diff=max_diff, interface="jax") + def circuit(x, y): + qml.Hadamard(wires=0) + PhaseShift(x, wires=0) + PhaseShift(2 * y, wires=0) + return qml.expval(qml.PauliX(0)) + + spy = mocker.spy(circuit.device, "batch_execute") + x = jnp.array(0.5) + y = jnp.array(0.7) + circuit(x, y) + + spy = mocker.spy(circuit.gradient_fn, "transform_fn") + res = jax.grad(circuit, argnums=[0])(x, y) + + input_tape = spy.call_args[0][0] + assert len(input_tape.operations) == 3 + assert input_tape.operations[1].name == "RY" + assert input_tape.operations[1].data[0] == 3 * x + assert input_tape.operations[2].name == "PhaseShift" + assert input_tape.operations[2].grad_method is None + + @pytest.mark.parametrize("max_diff", [1, 2]) + def test_hamiltonian_expansion_analytic(self, dev_name, diff_method, mode, max_diff, mocker): + """Test that the Hamiltonian is not expanded if there + are non-commuting groups and the number of shots is None + and the first and second order gradients are correctly evaluated""" + if diff_method == "adjoint": + pytest.skip("The adjoint method 
does not yet support Hamiltonians") + + if max_diff > 1: + pytest.skip("JAX only supports first derivatives") + + dev = qml.device(dev_name, wires=3, shots=None) + spy = mocker.spy(qml.transforms, "hamiltonian_expand") + obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] + + @qnode(dev, interface="jax", diff_method=diff_method, mode=mode, max_diff=max_diff) + def circuit(data, weights, coeffs): + weights = weights.reshape(1, -1) + qml.templates.AngleEmbedding(data, wires=[0, 1]) + qml.templates.BasicEntanglerLayers(weights, wires=[0, 1]) + return qml.expval(qml.Hamiltonian(coeffs, obs)) + + d = jnp.array([0.1, 0.2]) + w = jnp.array([0.654, -0.734]) + c = jnp.array([-0.6543, 0.24, 0.54]) + + # test output + res = circuit(d, w, c) + expected = c[2] * np.cos(d[1] + w[1]) - c[1] * np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]) + assert np.allclose(res, expected) + spy.assert_not_called() + + # test gradients + grad = jax.grad(circuit, argnums=[1, 2])(d, w, c) + expected_w = [ + -c[1] * np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), + -c[1] * np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]) - c[2] * np.sin(d[1] + w[1]), + ] + expected_c = [0, -np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]), np.cos(d[1] + w[1])] + assert np.allclose(grad[0], expected_w) + assert np.allclose(grad[1], expected_c) + + # test second-order derivatives + if diff_method in ("parameter-shift", "backprop") and max_diff == 2: + + grad2_c = jax.jacobian(jax.grad(circuit, argnum=2), argnum=2)(d, w, c) + assert np.allclose(grad2_c, 0) + + grad2_w_c = jax.jacobian(jax.grad(circuit, argnum=1), argnum=2)(d, w, c) + expected = [0, -np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), 0], [ + 0, + -np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]), + -np.sin(d[1] + w[1]), + ] + assert np.allclose(grad2_w_c, expected) + + # @pytest.mark.xfail(reason="Will fail since expval(H) expands to a vector valued return for finite-shots") + # @pytest.mark.parametrize("max_diff", [1, 2]) + # def test_hamiltonian_expansion_finite_shots( + # self, dev_name, diff_method, mode, max_diff, mocker + # ): + # """Test that the Hamiltonian is expanded if there + # are non-commuting groups and the number of shots is finite + # and the first and second order gradients are correctly evaluated""" + # if diff_method in ("adjoint", "backprop", "finite-diff"): + # pytest.skip("The adjoint and backprop methods do not yet support sampling") + + # if max_diff > 1: + # pytest.skip("JAX only supports first derivatives") + + # dev = qml.device(dev_name, wires=3, shots=50000) + # spy = mocker.spy(qml.transforms, "hamiltonian_expand") + # obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] + + # @qnode(dev, interface="jax", diff_method=diff_method, mode=mode, max_diff=max_diff) + # def circuit(data, weights, coeffs): + # weights = weights.reshape(1, -1) + # qml.templates.AngleEmbedding(data, wires=[0, 1]) + # qml.templates.BasicEntanglerLayers(weights, wires=[0, 1]) + # H = qml.Hamiltonian(coeffs, obs) + # H.compute_grouping() + # return qml.expval(H) + + # d = jnp.array([0.1, 0.2]) + # w = jnp.array([0.654, -0.734]) + # c = jnp.array([-0.6543, 0.24, 0.54]) + + # # test output + # res = circuit(d, w, c) + # expected = c[2] * np.cos(d[1] + w[1]) - c[1] * np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]) + # assert np.allclose(res, expected, atol=0.1) + # spy.assert_called() + + # # test gradients + # grad = jax.grad(circuit, argnums=[1, 2])(d, w, c) + # expected_w = [ + # -c[1] * np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), + # -c[1] * np.cos(d[1] + 
w[1]) * np.sin(d[0] + w[0]) - c[2] * np.sin(d[1] + w[1]), + # ] + # expected_c = [0, -np.sin(d[0] + w[0]) * np.sin(d[1] + w[1]), np.cos(d[1] + w[1])] + # assert np.allclose(grad[0], expected_w, atol=0.1) + # assert np.allclose(grad[1], expected_c, atol=0.1) + + # # test second-order derivatives + # if diff_method == "parameter-shift" and max_diff == 2: + + # grad2_c = jax.jacobian(jax.grad(circuit, argnum=2), argnum=2)(d, w, c) + # assert np.allclose(grad2_c, 0, atol=0.1) + + # grad2_w_c = jax.jacobian(jax.grad(circuit, argnum=1), argnum=2)(d, w, c) + # expected = [0, -np.cos(d[0] + w[0]) * np.sin(d[1] + w[1]), 0], [ + # 0, + # -np.cos(d[1] + w[1]) * np.sin(d[0] + w[0]), + # -np.sin(d[1] + w[1]), + # ] + # assert np.allclose(grad2_w_c, expected, atol=0.1) diff --git a/tests/transforms/test_batch_transform.py b/tests/transforms/test_batch_transform.py index 01c14d5062a..48102f6377c 100644 --- a/tests/transforms/test_batch_transform.py +++ b/tests/transforms/test_batch_transform.py @@ -376,15 +376,12 @@ def test_differentiable_torch(self, diff_method): def test_differentiable_jax(self, diff_method): """Test that a batch transform is differentiable when using jax""" - if diff_method in ("parameter-shift", "finite-diff"): - pytest.skip("Does not support parameter-shift mode") - jax = pytest.importorskip("jax") dev = qml.device("default.qubit", wires=2) qnode = qml.QNode(self.circuit, dev, interface="jax", diff_method=diff_method) def cost(x, weights): - return self.my_transform(qnode, weights)(x) + return self.my_transform(qnode, weights, max_diff=1)(x) weights = jax.numpy.array([0.1, 0.2]) x = jax.numpy.array(0.543) From b6e78c64be3a5371b61d784c6e4e38a94d9f3c07 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Sat, 25 Sep 2021 15:20:02 +0800 Subject: [PATCH 2/9] add test --- tests/gradients/test_gradient_transform.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/gradients/test_gradient_transform.py b/tests/gradients/test_gradient_transform.py index 88670944e07..ede11909aac 100644 --- a/tests/gradients/test_gradient_transform.py +++ b/tests/gradients/test_gradient_transform.py @@ -405,3 +405,27 @@ def circuit(x): res.backward() expected = -2 * np.cos(2 * x_) assert np.allclose(x.grad.detach(), expected, atol=tol, rtol=0) + + def test_jax(self, tol): + """Test that a gradient transform remains differentiable + with JAX""" + jax = pytest.importorskip("jax") + jnp = jax.numpy + dev = qml.device("default.qubit", wires=2) + + @qml.gradients.param_shift + @qml.qnode(dev, interface="jax") + def circuit(x): + qml.RY(x ** 2, wires=[1]) + qml.CNOT(wires=[0, 1]) + return qml.var(qml.PauliX(1)) + + x = jnp.array(-0.654) + + # res = circuit(x) + # expected = -4 * x * np.cos(x ** 2) * np.sin(x ** 2) + # assert np.allclose(res, expected, atol=tol, rtol=0) + + res = jax.grad(circuit)(x) + expected = -2 * (4 * x ** 2 * np.cos(2 * x ** 2) + np.sin(2 * x ** 2)) + assert np.allclose(res, expected, atol=tol, rtol=0) From 653a546a80fb5277be42231c06fcf469bda73447 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 28 Sep 2021 19:32:10 +0800 Subject: [PATCH 3/9] test --- pennylane/math/utils.py | 2 +- tests/collections/test_collections.py | 2 ++ tests/gradients/test_gradient_transform.py | 6 +++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 585c930d3c8..093a62d8d8e 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -279,6 +279,6 @@ def requires_grad(tensor, interface=None): if interface == 
"jax": import jax - return isinstance(tensor, jax.interpreters.ad.JVPTracer) + return isinstance(tensor, (jax.numpy.DeviceArray, jax.interpreters.ad.JVPTracer)) raise ValueError(f"Argument {tensor} is an unknown object") diff --git a/tests/collections/test_collections.py b/tests/collections/test_collections.py index 6842ad24151..173482e26b0 100644 --- a/tests/collections/test_collections.py +++ b/tests/collections/test_collections.py @@ -400,6 +400,8 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su coeffs = coeffs.numpy() expected = np.dot(qcval, coeffs) + print(res) + print(expected) assert np.all(res == expected) def test_unknown_interface(self, monkeypatch): diff --git a/tests/gradients/test_gradient_transform.py b/tests/gradients/test_gradient_transform.py index ede11909aac..8e32ad37e8f 100644 --- a/tests/gradients/test_gradient_transform.py +++ b/tests/gradients/test_gradient_transform.py @@ -422,9 +422,9 @@ def circuit(x): x = jnp.array(-0.654) - # res = circuit(x) - # expected = -4 * x * np.cos(x ** 2) * np.sin(x ** 2) - # assert np.allclose(res, expected, atol=tol, rtol=0) + res = circuit(x) + expected = -4 * x * np.cos(x ** 2) * np.sin(x ** 2) + assert np.allclose(res, expected, atol=tol, rtol=0) res = jax.grad(circuit)(x) expected = -2 * (4 * x ** 2 * np.cos(2 * x ** 2) + np.sin(2 * x ** 2)) From 92cb81d005d47c4986b9388b2389bb92cf20b474 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 28 Sep 2021 20:29:33 +0800 Subject: [PATCH 4/9] update --- pennylane/math/multi_dispatch.py | 14 +++++++++++++- pennylane/math/utils.py | 2 +- tests/transforms/test_metric_tensor.py | 1 + 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index 444dc8a5590..a04330f99c2 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -274,11 +274,23 @@ def get_trainable_indices(values): Trainable: {0} tensor(0.0899685, requires_grad=True) """ + trainable = requires_grad interface = _multi_dispatch(values) trainable_params = set() + if interface == "jax": + import jax + + if not any(isinstance(v, jax.interpreters.ad.JVPTracer) for v in values): + # No JAX tracing is occuring; treat all `DeviceArray` objects as trainable. 
+ def trainable(p, **kwargs): + return isinstance(p, jax.numpy.DeviceArray) + + else: + trainable = requires_grad + for idx, p in enumerate(values): - if requires_grad(p, interface=interface): + if trainable(p, interface=interface): trainable_params.add(idx) return trainable_params diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 093a62d8d8e..585c930d3c8 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -279,6 +279,6 @@ def requires_grad(tensor, interface=None): if interface == "jax": import jax - return isinstance(tensor, (jax.numpy.DeviceArray, jax.interpreters.ad.JVPTracer)) + return isinstance(tensor, jax.interpreters.ad.JVPTracer) raise ValueError(f"Argument {tensor} is an unknown object") diff --git a/tests/transforms/test_metric_tensor.py b/tests/transforms/test_metric_tensor.py index 76ff4370d28..c9cc04e41b0 100644 --- a/tests/transforms/test_metric_tensor.py +++ b/tests/transforms/test_metric_tensor.py @@ -612,6 +612,7 @@ def cost(weights): weights = jnp.array([0.432, 0.12, -0.432]) a, b, c = weights + cost(weights) grad = jax.grad(cost)(weights) expected = np.array( [np.cos(a) * np.cos(b) ** 2 * np.sin(a) / 2, np.cos(a) ** 2 * np.sin(2 * b) / 4, 0] From eedeabb55fd187c30d3177cc9dffb7247a3d3488 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 28 Sep 2021 20:47:22 +0800 Subject: [PATCH 5/9] update --- pennylane/math/multi_dispatch.py | 8 +++++--- tests/collections/test_collections.py | 4 +--- tests/tape/test_unwrap.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index a04330f99c2..bca3ccf077b 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -283,10 +283,12 @@ def get_trainable_indices(values): if not any(isinstance(v, jax.interpreters.ad.JVPTracer) for v in values): # No JAX tracing is occuring; treat all `DeviceArray` objects as trainable. - def trainable(p, **kwargs): - return isinstance(p, jax.numpy.DeviceArray) - + trainable = lambda p, **kwargs: isinstance(p, jax.numpy.DeviceArray) else: + # JAX tracing is occuring; use the default behaviour (only traced arrays + # are treated as trainable). This is required to ensure that `jax.grad(func, argnums=...) + # works correctly, as the argnums argnument determines which parameters are + # traced arrays. trainable = requires_grad for idx, p in enumerate(values): diff --git a/tests/collections/test_collections.py b/tests/collections/test_collections.py index 173482e26b0..40b7e6abe87 100644 --- a/tests/collections/test_collections.py +++ b/tests/collections/test_collections.py @@ -400,9 +400,7 @@ def test_dot_product_qnodes_tensor(self, qnodes, interface, tf_support, torch_su coeffs = coeffs.numpy() expected = np.dot(qcval, coeffs) - print(res) - print(expected) - assert np.all(res == expected) + assert np.allclose(res, expected) def test_unknown_interface(self, monkeypatch): """Test exception raised if the interface is unknown""" diff --git a/tests/tape/test_unwrap.py b/tests/tape/test_unwrap.py index f6cf1f33f63..0424c3e027f 100644 --- a/tests/tape/test_unwrap.py +++ b/tests/tape/test_unwrap.py @@ -167,9 +167,9 @@ def test_unwrap_jax(): assert all(isinstance(i, float) for i in params) assert np.allclose(params, [0.1, 0.2, 0.5, 0.3]) - # During the forward pass, JAX has no concept of trainable - # arrays. - assert tape.trainable_params == set() + # During the forward pass, Device arrays are treated as + # trainable, but no other types are. 
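+    # (the value at index 2 is not a DeviceArray, so it is excluded)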
+ assert tape.trainable_params == {0, 1, 3} # outside the context, the original parameters have been restored. assert tape.get_parameters(trainable_only=False) == p From 04dfcf5e0293b8f1894086d7a08f49fcf7e644d4 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Wed, 29 Sep 2021 01:07:24 +0800 Subject: [PATCH 6/9] Add jit tests --- pennylane/math/multi_dispatch.py | 2 +- pennylane/math/utils.py | 2 +- tests/interfaces/test_batch_jax_qnode.py | 58 ++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/pennylane/math/multi_dispatch.py b/pennylane/math/multi_dispatch.py index bca3ccf077b..765440d213e 100644 --- a/pennylane/math/multi_dispatch.py +++ b/pennylane/math/multi_dispatch.py @@ -281,7 +281,7 @@ def get_trainable_indices(values): if interface == "jax": import jax - if not any(isinstance(v, jax.interpreters.ad.JVPTracer) for v in values): + if not any(isinstance(v, jax.core.Tracer) for v in values): # No JAX tracing is occuring; treat all `DeviceArray` objects as trainable. trainable = lambda p, **kwargs: isinstance(p, jax.numpy.DeviceArray) else: diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 585c930d3c8..fb53819cf18 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -279,6 +279,6 @@ def requires_grad(tensor, interface=None): if interface == "jax": import jax - return isinstance(tensor, jax.interpreters.ad.JVPTracer) + return isinstance(tensor, jax.core.Tracer) raise ValueError(f"Argument {tensor} is an unknown object") diff --git a/tests/interfaces/test_batch_jax_qnode.py b/tests/interfaces/test_batch_jax_qnode.py index 2dbdaa5287c..3f355290abe 100644 --- a/tests/interfaces/test_batch_jax_qnode.py +++ b/tests/interfaces/test_batch_jax_qnode.py @@ -1103,3 +1103,61 @@ def circuit(data, weights, coeffs): # -np.sin(d[1] + w[1]), # ] # assert np.allclose(grad2_w_c, expected, atol=0.1) + + +@pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) +class TestJIT: + """Test JAX JIT integration with the QNode""" + + def test_gradient(self, dev_name, diff_method, mode, tol): + """Test derivative calculation of a scalar valued QNode""" + dev = qml.device(dev_name, wires=1) + + if diff_method == "adjoint": + pytest.xfail(reason="The adjoint method is not using host-callback currently") + + @jax.jit + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(x): + qml.RY(x[0], wires=0) + qml.RX(x[1], wires=0) + return qml.expval(qml.PauliZ(0)) + + x = jnp.array([1.0, 2.0]) + res = circuit(x) + g = jax.grad(circuit)(x) + + a, b = x + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) + + @pytest.mark.xfail( + reason="Non-trainable parameters are not being correctly unwrapped by the interface" + ) + def test_gradient_subset(self, dev_name, diff_method, mode, tol): + """Test derivative calculation of a scalar valued QNode with respect + to a subset of arguments""" + a = jnp.array(0.1) + b = jnp.array(0.2) + + dev = qml.device(dev_name, wires=1) + + @jax.jit + @qnode(dev, diff_method=diff_method, interface="jax", mode=mode) + def circuit(a, b): + qml.RY(a, wires=0) + qml.RX(b, wires=0) + qml.RZ(c, wires=0) + return qml.expval(qml.PauliZ(0)) + + res = jax.grad(circuit, argnums=[0, 1])(a, b, 0.0) + + expected_res = np.cos(a) * np.cos(b) + assert np.allclose(res, expected_res, atol=tol, rtol=0) + + expected_g = [-np.sin(a) * 
np.cos(b), -np.cos(a) * np.sin(b)] + assert np.allclose(g, expected_g, atol=tol, rtol=0) From 31eae7c82af50f0f163d25fa23cb7ce9ac255c38 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Wed, 29 Sep 2021 01:12:05 +0800 Subject: [PATCH 7/9] more --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f9767d2aba9..989fb2d225f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -8,6 +8,7 @@ extended to the JAX interface for scalar functions, via the beta `pennylane.interfaces.batch` module. [(#1634)](https://github.com/PennyLaneAI/pennylane/pull/1634) + [(#1683)](https://github.com/PennyLaneAI/pennylane/pull/1683) For example using the `execute` function from the `pennylane.interfaces.batch` module: From b8e511d2d3d18ec880f7706faa915e94d063c177 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Tue, 19 Oct 2021 22:50:36 +0800 Subject: [PATCH 8/9] fix --- pennylane/transforms/batch_transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pennylane/transforms/batch_transform.py b/pennylane/transforms/batch_transform.py index 861cf5d6490..e888b17f32f 100644 --- a/pennylane/transforms/batch_transform.py +++ b/pennylane/transforms/batch_transform.py @@ -272,7 +272,7 @@ def default_qnode_wrapper(self, qnode, targs, tkwargs): the QNode, and return the output of the applying the tape transform to the QNode's constructed tape. """ - max_diff = tkwargs.pop("max_diff", 2) + transform_max_diff = tkwargs.pop("max_diff", None) if "shots" in inspect.signature(qnode.func).parameters: raise ValueError( @@ -289,6 +289,7 @@ def _wrapper(*args, **kwargs): interface = qnode.interface execute_kwargs = getattr(qnode, "execute_kwargs", {}) max_diff = execute_kwargs.pop("max_diff", 2) + max_diff = transform_max_diff or max_diff gradient_fn = getattr(qnode, "gradient_fn", qnode.diff_method) gradient_kwargs = getattr(qnode, "gradient_kwargs", {}) From 4d508171a2b9d50c48ee46c3204e41b70044cc49 Mon Sep 17 00:00:00 2001 From: Josh Izaac Date: Wed, 20 Oct 2021 00:08:20 +0800 Subject: [PATCH 9/9] Apply suggestions from code review Co-authored-by: antalszava --- doc/releases/changelog-dev.md | 2 +- tests/interfaces/test_batch_jax_qnode.py | 6 +++--- tests/transforms/test_metric_tensor.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index c6e05b4556c..557ad2375f4 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -135,7 +135,7 @@ extended to the JAX interface for scalar functions, via the beta `pennylane.interfaces.batch` module. [(#1634)](https://github.com/PennyLaneAI/pennylane/pull/1634) - [(#1683)](https://github.com/PennyLaneAI/pennylane/pull/1683) + [(#1685)](https://github.com/PennyLaneAI/pennylane/pull/1685) For example using the `execute` function from the `pennylane.interfaces.batch` module: diff --git a/tests/interfaces/test_batch_jax_qnode.py b/tests/interfaces/test_batch_jax_qnode.py index 3f355290abe..d6a5ac2295c 100644 --- a/tests/interfaces/test_batch_jax_qnode.py +++ b/tests/interfaces/test_batch_jax_qnode.py @@ -1,4 +1,4 @@ -# Copyright 2018-2020 Xanadu Quantum Technologies Inc. +# Copyright 2018-2021 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ @pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) class TestQNode: - """Test that using the QNode with Autograd integrates with the PennyLane stack""" + """Test that using the QNode with JAX integrates with the PennyLane stack""" def test_execution_with_interface(self, dev_name, diff_method, mode): """Test execution works with the interface""" @@ -943,7 +943,7 @@ def circ(x): @pytest.mark.parametrize("dev_name,diff_method,mode", qubit_device_and_diff_method) class TestTapeExpansion: """Test that tape expansion within the QNode integrates correctly - with the Autograd interface""" + with the JAX interface""" @pytest.mark.parametrize("max_diff", [1, 2]) def test_gradient_expansion_trainable_only(self, dev_name, diff_method, mode, max_diff, mocker): diff --git a/tests/transforms/test_metric_tensor.py b/tests/transforms/test_metric_tensor.py index 4f5ab52ac33..169a6cea1d7 100644 --- a/tests/transforms/test_metric_tensor.py +++ b/tests/transforms/test_metric_tensor.py @@ -672,7 +672,6 @@ def cost(weights): weights = jnp.array([0.432, 0.12, -0.432]) a, b, c = weights - cost(weights) grad = jax.grad(cost)(weights) expected = np.array( [np.cos(a) * np.cos(b) ** 2 * np.sin(a) / 2, np.cos(a) ** 2 * np.sin(2 * b) / 4, 0]
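
For reference, a minimal end-to-end sketch of the workflow these patches exercise: a JAX-interface QNode built with the beta `qnode` decorator, executed with a per-call shot override, and differentiated with `jax.grad`. The device and gate choices below are illustrative only; the API calls mirror the tests added in tests/interfaces/test_batch_jax_qnode.py.

import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane.beta import qnode

from jax.config import config

config.update("jax_enable_x64", True)  # match the float64 tolerances used in the tests

dev = qml.device("default.qubit", wires=2, shots=None)

@qnode(dev, interface="jax", diff_method="parameter-shift")
def circuit(a, b):
    qml.RY(a, wires=0)
    qml.RX(b, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliY(1))

a, b = jnp.array(0.543), jnp.array(-0.654)

circuit(a, b)             # analytic execution (device default, shots=None)
circuit(a, b, shots=100)  # temporary per-call shot override; dev.shots is unchanged

# argnums selects which positional arguments JAX traces, and hence which
# tape parameters are marked trainable
grad = jax.grad(circuit, argnums=[0, 1])(a, b)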