diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index e3d117d9564..9491b3e3564 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -78,6 +78,9 @@
configurations.
[(#2169)](https://github.com/PennyLaneAI/pennylane/pull/2169)
+ The postprocessing function for the `cut_circuit` transform has been added.
+ [(#2192)](https://github.com/PennyLaneAI/pennylane/pull/2192)
+
Improvements
* The `gradients` module has been streamlined and special-purpose functions
diff --git a/pennylane/grouping/pauli.py b/pennylane/grouping/pauli.py
index 4786d2ce2f9..a7d5ee7fdc5 100644
--- a/pennylane/grouping/pauli.py
+++ b/pennylane/grouping/pauli.py
@@ -309,8 +309,11 @@ def partition_pauli_group(n_qubits: int) -> List[List[str]]:
if not isinstance(n_qubits, int):
raise TypeError("Must specify an integer number of qubits.")
- if n_qubits <= 0:
- raise ValueError("Number of qubits must be at least 1.")
+ if n_qubits < 0:
+ raise ValueError("Number of qubits must be at least 0.")
+
+ if n_qubits == 0:
+ return [[""]]
strings = set() # tracks all the strings that have already been grouped
groups = []
diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py
index 4be4a266e9a..57855710abe 100644
--- a/pennylane/transforms/__init__.py
+++ b/pennylane/transforms/__init__.py
@@ -97,6 +97,7 @@
~transforms.graph_to_tape
~transforms.expand_fragment_tapes
~transforms.contract_tensors
+ ~transforms.qcut_processing_fn
Transforms that act on tapes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -177,4 +178,5 @@
graph_to_tape,
expand_fragment_tapes,
contract_tensors,
+ qcut_processing_fn,
)
diff --git a/pennylane/transforms/qcut.py b/pennylane/transforms/qcut.py
index c9bb1eb9baa..9a72ba77f51 100644
--- a/pennylane/transforms/qcut.py
+++ b/pennylane/transforms/qcut.py
@@ -21,8 +21,9 @@
from itertools import product
from typing import List, Sequence, Tuple
-import pennylane as qml
from networkx import MultiDiGraph, weakly_connected_components
+
+import pennylane as qml
from pennylane import apply, expval
from pennylane.grouping import string_to_pauli_word
from pennylane.measure import MeasurementProcess
@@ -593,12 +594,12 @@ def contract_tensors(
each tensor
use_opt_einsum (bool): Determines whether to use the
`opt_einsum `__ package. This package is useful
- for tensor contractions of large networks but must be installed separately using, e.g.,
- ``pip install opt_einsum``. Both settings for ``use_opt_einsum`` result in a
+ for faster tensor contractions of large networks but must be installed separately using,
+ e.g., ``pip install opt_einsum``. Both settings for ``use_opt_einsum`` result in a
differentiable contraction.
Returns:
- float or array-like: the result of contracting the tensor network
+ float or tensor_like: the result of contracting the tensor network
**Example**
@@ -675,3 +676,166 @@ def contract_tensors(
kwargs = {} if use_opt_einsum else {"like": tensors[0]}
return contract(eqn, *tensors, **kwargs)
+
+
+CHANGE_OF_BASIS = qml.math.array(
+ [[1.0, 1.0, 0.0, 0.0], [-1.0, -1.0, 2.0, 0.0], [-1.0, -1.0, 0.0, 2.0], [1.0, -1.0, 0.0, 0.0]]
+)
+
+
+def _process_tensor(results, n_prep: int, n_meas: int):
+ """Convert a flat slice of an individual circuit fragment's execution results into a tensor.
+
+ This function performs the following steps:
+
+ 1. Reshapes ``results`` into the intermediate shape ``(4,) * n_prep + (4**n_meas,)``
+ 2. Shuffles the final axis to follow the standard product over measurement settings. E.g., for
+ ``n_meas = 2`` the standard product is: II, IX, IY, IZ, XI, ..., ZY, ZZ while the input order
+ will be the result of ``qml.grouping.partition_pauli_group(2)``, i.e., II, IZ, ZI, ZZ, ...,
+ YY.
+ 3. Reshapes into the final target shape ``(4,) * (n_prep + n_meas)``
+ 4. Performs a change of basis for the preparation indices (the first ``n_prep`` indices) from
+ the |0>, |1>, |+>, |+i> basis to the I, X, Y, Z basis using ``CHANGE_OF_BASIS``.
+
+ Args:
+ results (tensor_like): the input execution results
+ n_prep (int): the number of preparation nodes in the corresponding circuit fragment
+ n_meas (int): the number of measurement nodes in the corresponding circuit fragment
+
+ Returns:
+ tensor_like: the corresponding fragment tensor
+ """
+ n = n_prep + n_meas
+ dim_meas = 4**n_meas
+
+ # Step 1
+ intermediate_shape = (4,) * n_prep + (dim_meas,)
+ intermediate_tensor = qml.math.reshape(results, intermediate_shape)
+
+ # Step 2
+ grouped = qml.grouping.partition_pauli_group(n_meas)
+ grouped_flat = [term for group in grouped for term in group]
+ order = qml.math.argsort(grouped_flat)
+
+ if qml.math.get_interface(intermediate_tensor) == "tensorflow":
+ # TensorFlow does not support slicing
+ intermediate_tensor = qml.math.gather(intermediate_tensor, order, axis=-1)
+ else:
+ sl = [slice(None)] * n_prep + [order]
+ intermediate_tensor = intermediate_tensor[tuple(sl)]
+
+ # Step 3
+ final_shape = (4,) * n
+ final_tensor = qml.math.reshape(intermediate_tensor, final_shape)
+
+ # Step 4
+ change_of_basis = qml.math.convert_like(CHANGE_OF_BASIS, intermediate_tensor)
+
+ for i in range(n_prep):
+ axes = [[1], [i]]
+ final_tensor = qml.math.tensordot(change_of_basis, final_tensor, axes=axes)
+
+ axes = list(reversed(range(n_prep))) + list(range(n_prep, n))
+
+ # Use transpose to reorder indices. We must do this because tensordot returns a tensor whose
+ # indices are ordered according to the uncontracted indices of the first tensor, followed
+ # by the uncontracted indices of the second tensor. For example, calculating C_kj T_ij returns
+ # a tensor T'_ki rather than T'_ik.
+ final_tensor = qml.math.transpose(final_tensor, axes=axes)
+
+ final_tensor *= qml.math.power(2, -(n_meas + n_prep) / 2)
+ return final_tensor
+
+
+def _to_tensors(
+ results,
+ prepare_nodes: Sequence[Sequence[PrepareNode]],
+ measure_nodes: Sequence[Sequence[MeasureNode]],
+) -> List:
+ """Process a flat list of execution results from all circuit fragments into the corresponding
+ tensors.
+
+ This function slices ``results`` according to the expected size of fragment tensors derived from
+ the ``prepare_nodes`` and ``measure_nodes`` and then passes onto ``_process_tensor`` for further
+ transformation.
+
+ Args:
+ results (tensor_like): A collection of execution results, provided as a flat tensor,
+ corresponding to the expansion of circuit fragments in the communication graph over
+ measurement and preparation node configurations. These results are processed into
+ tensors by this function.
+ prepare_nodes (Sequence[Sequence[PrepareNode]]): a sequence whose length is equal to the
+ number of circuit fragments, with each element used here to determine the number of
+ preparation nodes in a given fragment
+ measure_nodes (Sequence[Sequence[MeasureNode]]): a sequence whose length is equal to the
+ number of circuit fragments, with each element used here to determine the number of
+ measurement nodes in a given fragment
+
+ Returns:
+ List[tensor_like]: the tensors for each circuit fragment in the communication graph
+ """
+ ctr = 0
+ tensors = []
+
+ for p, m in zip(prepare_nodes, measure_nodes):
+ n_prep = len(p)
+ n_meas = len(m)
+ n = n_prep + n_meas
+
+ dim = 4**n
+ results_slice = results[ctr : dim + ctr]
+
+ tensors.append(_process_tensor(results_slice, n_prep, n_meas))
+
+ ctr += dim
+
+ if len(results) != ctr:
+ raise ValueError(f"The results argument should be a flat list of length {ctr}")
+
+ return tensors
+
+
+def qcut_processing_fn(
+ results: Sequence[Sequence],
+ communication_graph: MultiDiGraph,
+ prepare_nodes: Sequence[Sequence[PrepareNode]],
+ measure_nodes: Sequence[Sequence[MeasureNode]],
+ use_opt_einsum: bool = False,
+):
+ """Processing function for the :func:`cut_circuit` transform.
+
+ .. note::
+
+ This function is designed for use as part of the circuit cutting workflow. Check out the
+ :doc:`transforms ` page for more details.
+
+ Args:
+ results (Sequence[Sequence]): A collection of execution results corresponding to the
+ expansion of circuit fragments in the ``communication_graph`` over measurement and
+ preparation node configurations. These results are processed into tensors and then
+ contracted.
+ communication_graph (MultiDiGraph): the communication graph determining connectivity between
+ circuit fragments
+ prepare_nodes (Sequence[Sequence[PrepareNode]]): a sequence of size
+ ``len(communication_graph.nodes)`` that determines the order of preparation indices in
+ each tensor
+ measure_nodes (Sequence[Sequence[MeasureNode]]): a sequence of size
+ ``len(communication_graph.nodes)`` that determines the order of measurement indices in
+ each tensor
+ use_opt_einsum (bool): Determines whether to use the
+ `opt_einsum `__ package. This package is useful
+ for faster tensor contractions of large networks but must be installed separately using,
+ e.g., ``pip install opt_einsum``. Both settings for ``use_opt_einsum`` result in a
+ differentiable contraction.
+
+ Returns:
+ float or tensor_like: the output of the original uncut circuit arising from contracting
+ the tensor network of circuit fragments
+ """
+ flat_results = qml.math.concatenate(results)
+
+ tensors = _to_tensors(flat_results, prepare_nodes, measure_nodes)
+ result = contract_tensors(
+ tensors, communication_graph, prepare_nodes, measure_nodes, use_opt_einsum
+ )
+ return result
diff --git a/tests/grouping/test_pauli_group.py b/tests/grouping/test_pauli_group.py
index 90f6cfae546..3c02ac63bdc 100644
--- a/tests/grouping/test_pauli_group.py
+++ b/tests/grouping/test_pauli_group.py
@@ -263,5 +263,9 @@ def test_invalid_input(self):
with pytest.raises(TypeError, match="Must specify an integer number"):
partition_pauli_group("3")
- with pytest.raises(ValueError, match="Number of qubits must be at least 1"):
+ with pytest.raises(ValueError, match="Number of qubits must be at least 0"):
partition_pauli_group(-1)
+
+ def test_zero(self):
+ """Test if [[""]] is returned with zero qubits"""
+ assert partition_pauli_group(0) == [[""]]
diff --git a/tests/transforms/test_qcut.py b/tests/transforms/test_qcut.py
index ed22ef942f3..badd90ef56d 100644
--- a/tests/transforms/test_qcut.py
+++ b/tests/transforms/test_qcut.py
@@ -15,17 +15,34 @@
Unit tests for the `pennylane.qcut` package.
"""
import copy
+import itertools
import string
import sys
from itertools import product
-import pennylane as qml
import pytest
from networkx import MultiDiGraph
+from scipy.stats import unitary_group
+
+import pennylane as qml
from pennylane import numpy as np
from pennylane.transforms import qcut
from pennylane.wires import Wires
+I, X, Y, Z = (
+ np.eye(2),
+ qml.PauliX.compute_matrix(),
+ qml.PauliY.compute_matrix(),
+ qml.PauliZ.compute_matrix(),
+)
+
+states_pure = [
+ np.array([1, 0]),
+ np.array([0, 1]),
+ np.array([1, 1]) / np.sqrt(2),
+ np.array([1, 1j]) / np.sqrt(2),
+]
+
with qml.tape.QuantumTape() as tape:
qml.RX(0.432, wires=0)
qml.RY(0.543, wires="a")
@@ -51,6 +68,17 @@
qml.RZ(0.876, wires=3)
qml.expval(qml.PauliZ(wires=[0]))
+
+def kron(*args):
+ """Multi-argument kronecker product"""
+ if len(args) == 1:
+ return args[0]
+ if len(args) == 2:
+ return np.kron(args[0], args[1])
+ else:
+ return np.kron(args[0], kron(*args[1:]))
+
+
# tape containing mid-circuit measurements
with qml.tape.QuantumTape() as mcm_tape:
qml.Hadamard(wires=0)
@@ -1421,3 +1449,286 @@ def test_advanced(self, mocker, use_opt_einsum):
assert eqn == expected_eqn
assert np.allclose(res, np.einsum(eqn, *t))
+
+
+class TestQCutProcessingFn:
+ """Tests for the qcut_processing_fn and contained functions"""
+
+ def test_to_tensors(self, monkeypatch):
+ """Test that _to_tensors correctly reshapes the flat list of results into the original
+ tensors according to the supplied prepare_nodes and measure_nodes. Uses a mock function
+ for _process_tensor since we do not need to process the tensors."""
+ prepare_nodes = [[None] * 3, [None] * 2, [None] * 1, [None] * 4]
+ measure_nodes = [[None] * 2, [None] * 2, [None] * 3, [None] * 3]
+ tensors = [
+ np.arange(4**5).reshape((4,) * 5),
+ np.arange(4**4).reshape((4,) * 4),
+ np.arange(4**4).reshape((4,) * 4),
+ np.arange(4**7).reshape((4,) * 7),
+ ]
+ results = np.concatenate([t.flatten() for t in tensors])
+
+ def mock_process_tensor(r, np, nm):
+ return qml.math.reshape(r, (4,) * (np + nm))
+
+ with monkeypatch.context() as m:
+ m.setattr(qcut, "_process_tensor", mock_process_tensor)
+ tensors_out = qcut._to_tensors(results, prepare_nodes, measure_nodes)
+
+ for t1, t2 in zip(tensors, tensors_out):
+ assert np.allclose(t1, t2)
+
+ def test_to_tensors_raises(self):
+ """Tests if a ValueError is raised when a results vector is passed to _to_tensors with a
+ size that is incompatible with the prepare_nodes and measure_nodes arguments"""
+ prepare_nodes = [[None] * 3]
+ measure_nodes = [[None] * 2]
+ tensors = [np.arange(4**5).reshape((4,) * 5), np.arange(4)]
+ results = np.concatenate([t.flatten() for t in tensors])
+
+ with pytest.raises(ValueError, match="should be a flat list of length 1024"):
+ qcut._to_tensors(results, prepare_nodes, measure_nodes)
+
+ @pytest.mark.parametrize("interface", ["autograd.numpy", "tensorflow", "torch", "jax.numpy"])
+ @pytest.mark.parametrize("n", [1, 2])
+ def test_process_tensor(self, n, interface):
+ """Test if the tensor returned by _process_tensor is equal to the expected value"""
+ lib = pytest.importorskip(interface)
+
+ U = unitary_group.rvs(2**n, random_state=1967)
+
+ # First, create target process tensor
+ basis = np.array([I, X, Y, Z]) / np.sqrt(2)
+ prod_inp = itertools.product(range(4), repeat=n)
+ prod_out = itertools.product(range(4), repeat=n)
+
+ results = []
+
+ # Calculates U_{ijkl} = Tr((b[k] x b[l]) U (b[i] x b[j]) U*)
+ # See Sec. II. A. of https://arxiv.org/abs/1909.07534, below Eq. (2).
+ for inp, out in itertools.product(prod_inp, prod_out):
+ input = kron(*[basis[i] for i in inp])
+ output = kron(*[basis[i] for i in out])
+ results.append(np.trace(output @ U @ input @ U.conj().T))
+
+ target_tensor = np.array(results).reshape((4,) * (2 * n))
+
+ # Now, create the input results vector found from executing over the product of |0>, |1>,
+ # |+>, |+i> inputs and using the grouped Pauli terms for measurements
+ dev = qml.device("default.qubit", wires=n)
+
+ @qml.qnode(dev)
+ def f(state, measurement):
+ qml.QubitStateVector(state, wires=range(n))
+ qml.QubitUnitary(U, wires=range(n))
+ return [qml.expval(qml.grouping.string_to_pauli_word(m)) for m in measurement]
+
+ prod_inp = itertools.product(range(4), repeat=n)
+ prod_out = qml.grouping.partition_pauli_group(n)
+
+ results = []
+
+ for inp, out in itertools.product(prod_inp, prod_out):
+ input = kron(*[states_pure[i] for i in inp])
+ results.append(f(input, out))
+
+ results = qml.math.cast_like(np.concatenate(results), lib.ones(1))
+
+ # Now apply _process_tensor
+ tensor = qcut._process_tensor(results, n, n)
+ assert np.allclose(tensor, target_tensor)
+
+ @pytest.mark.parametrize("use_opt_einsum", [True, False])
+ def test_qcut_processing_fn(self, use_opt_einsum):
+ """Test if qcut_processing_fn returns the expected answer when applied to a simple circuit
+ that is cut up into three fragments:
+ 0: ──RX(0.5)─|─RY(0.6)─|─RX(0.8)──┤ ⟨Z⟩
+ """
+ if use_opt_einsum:
+ pytest.importorskip("opt_einsum")
+
+ ### Find the expected result
+ dev = qml.device("default.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def f(x, y, z):
+ qml.RX(x, wires=0)
+ ### CUT HERE
+ qml.RY(y, wires=0)
+ ### CUT HERE
+ qml.RX(z, wires=0)
+ return qml.expval(qml.PauliZ(0))
+
+ x, y, z = 0.5, 0.6, 0.8
+ expected_result = f(x, y, z)
+
+ ### Find the result using qcut_processing_fn
+
+ meas_basis = [I, Z, X, Y]
+ states = [np.outer(s, s.conj()) for s in states_pure]
+ zero_proj = states[0]
+
+ u1 = qml.RX.compute_matrix(x)
+ u2 = qml.RY.compute_matrix(y)
+ u3 = qml.RX.compute_matrix(z)
+ t1 = np.array([np.trace(b @ u1 @ zero_proj @ u1.conj().T) for b in meas_basis])
+ t2 = np.array([[np.trace(b @ u2 @ s @ u2.conj().T) for b in meas_basis] for s in states])
+ t3 = np.array([np.trace(Z @ u3 @ s @ u3.conj().T) for s in states])
+
+ res = [t1, t2.flatten(), t3]
+ p = [[], [qcut.PrepareNode(wires=0)], [qcut.PrepareNode(wires=0)]]
+ m = [[qcut.MeasureNode(wires=0)], [qcut.MeasureNode(wires=0)], []]
+
+ edges = [
+ (0, 1, 0, {"pair": (m[0][0], p[1][0])}),
+ (1, 2, 0, {"pair": (m[1][0], p[2][0])}),
+ ]
+ g = MultiDiGraph(edges)
+
+ result = qcut.qcut_processing_fn(res, g, p, m, use_opt_einsum=use_opt_einsum)
+ assert np.allclose(result, expected_result)
+
+ @pytest.mark.parametrize("use_opt_einsum", [True, False])
+ def test_qcut_processing_fn_autograd(self, use_opt_einsum):
+ """Test if qcut_processing_fn handles the gradient as expected in the autograd interface
+ using a simple example"""
+ if use_opt_einsum:
+ pytest.importorskip("opt_einsum")
+
+ x = np.array(0.9, requires_grad=True)
+
+ def f(x):
+ t1 = x * np.arange(4)
+ t2 = x**2 * np.arange(16).reshape((4, 4))
+ t3 = np.sin(x * np.pi / 2) * np.arange(4)
+
+ res = [t1, t2.flatten(), t3]
+ p = [[], [qcut.PrepareNode(wires=0)], [qcut.PrepareNode(wires=0)]]
+ m = [[qcut.MeasureNode(wires=0)], [qcut.MeasureNode(wires=0)], []]
+
+ edges = [
+ (0, 1, 0, {"pair": (m[0][0], p[1][0])}),
+ (1, 2, 0, {"pair": (m[1][0], p[2][0])}),
+ ]
+ g = MultiDiGraph(edges)
+
+ return qcut.qcut_processing_fn(res, g, p, m, use_opt_einsum=use_opt_einsum)
+
+ grad = qml.grad(f)(x)
+ expected_grad = (
+ 3 * x**2 * np.sin(x * np.pi / 2) + x**3 * np.cos(x * np.pi / 2) * np.pi / 2
+ ) * f(1)
+
+ assert np.allclose(grad, expected_grad)
+
+ @pytest.mark.parametrize("use_opt_einsum", [True, False])
+ def test_qcut_processing_fn_tf(self, use_opt_einsum):
+ """Test if qcut_processing_fn handles the gradient as expected in the TF interface
+ using a simple example"""
+ if use_opt_einsum:
+ pytest.importorskip("opt_einsum")
+ tf = pytest.importorskip("tensorflow")
+
+ x = tf.Variable(0.9, dtype=tf.float64)
+
+ def f(x):
+ x = tf.cast(x, dtype=tf.float64)
+ t1 = x * tf.range(4, dtype=tf.float64)
+ t2 = x**2 * tf.range(16, dtype=tf.float64)
+ t3 = tf.sin(x * np.pi / 2) * tf.range(4, dtype=tf.float64)
+
+ res = [t1, t2, t3]
+ p = [[], [qcut.PrepareNode(wires=0)], [qcut.PrepareNode(wires=0)]]
+ m = [[qcut.MeasureNode(wires=0)], [qcut.MeasureNode(wires=0)], []]
+
+ edges = [
+ (0, 1, 0, {"pair": (m[0][0], p[1][0])}),
+ (1, 2, 0, {"pair": (m[1][0], p[2][0])}),
+ ]
+ g = MultiDiGraph(edges)
+
+ return qcut.qcut_processing_fn(res, g, p, m, use_opt_einsum=use_opt_einsum)
+
+ with tf.GradientTape() as tape:
+ res = f(x)
+
+ grad = tape.gradient(res, x)
+ expected_grad = (
+ 3 * x**2 * np.sin(x * np.pi / 2) + x**3 * np.cos(x * np.pi / 2) * np.pi / 2
+ ) * f(1)
+
+ assert np.allclose(grad, expected_grad)
+
+ @pytest.mark.parametrize("use_opt_einsum", [True, False])
+ def test_qcut_processing_fn_torch(self, use_opt_einsum):
+ """Test if qcut_processing_fn handles the gradient as expected in the torch interface
+ using a simple example"""
+ if use_opt_einsum:
+ pytest.importorskip("opt_einsum")
+ torch = pytest.importorskip("torch")
+
+ x = torch.tensor(0.9, requires_grad=True, dtype=torch.float64)
+
+ def f(x):
+ t1 = x * torch.arange(4)
+ t2 = x**2 * torch.arange(16)
+ t3 = torch.sin(x * np.pi / 2) * torch.arange(4)
+
+ res = [t1, t2, t3]
+ p = [[], [qcut.PrepareNode(wires=0)], [qcut.PrepareNode(wires=0)]]
+ m = [[qcut.MeasureNode(wires=0)], [qcut.MeasureNode(wires=0)], []]
+
+ edges = [
+ (0, 1, 0, {"pair": (m[0][0], p[1][0])}),
+ (1, 2, 0, {"pair": (m[1][0], p[2][0])}),
+ ]
+ g = MultiDiGraph(edges)
+
+ return qcut.qcut_processing_fn(res, g, p, m, use_opt_einsum=use_opt_einsum)
+
+ res = f(x)
+ res.backward()
+ grad = x.grad
+
+ x_ = x.detach().numpy()
+ f1 = f(torch.tensor(1, dtype=torch.float64))
+ expected_grad = (
+ 3 * x_**2 * np.sin(x_ * np.pi / 2) + x_**3 * np.cos(x_ * np.pi / 2) * np.pi / 2
+ ) * f1
+
+ assert np.allclose(grad.detach().numpy(), expected_grad)
+
+ @pytest.mark.parametrize("use_opt_einsum", [True, False])
+ def test_qcut_processing_fn_jax(self, use_opt_einsum):
+ """Test if qcut_processing_fn handles the gradient as expected in the jax interface
+ using a simple example"""
+ if use_opt_einsum:
+ pytest.importorskip("opt_einsum")
+ jax = pytest.importorskip("jax")
+ jnp = pytest.importorskip("jax.numpy")
+
+ x = jnp.array(0.9)
+
+ def f(x):
+ t1 = x * jnp.arange(4)
+ t2 = x**2 * jnp.arange(16).reshape((4, 4))
+ t3 = jnp.sin(x * np.pi / 2) * jnp.arange(4)
+
+ res = [t1, t2.flatten(), t3]
+ p = [[], [qcut.PrepareNode(wires=0)], [qcut.PrepareNode(wires=0)]]
+ m = [[qcut.MeasureNode(wires=0)], [qcut.MeasureNode(wires=0)], []]
+
+ edges = [
+ (0, 1, 0, {"pair": (m[0][0], p[1][0])}),
+ (1, 2, 0, {"pair": (m[1][0], p[2][0])}),
+ ]
+ g = MultiDiGraph(edges)
+
+ return qcut.qcut_processing_fn(res, g, p, m, use_opt_einsum=use_opt_einsum)
+
+ grad = jax.grad(f)(x)
+ expected_grad = (
+ 3 * x**2 * np.sin(x * np.pi / 2) + x**3 * np.cos(x * np.pi / 2) * np.pi / 2
+ ) * f(1)
+
+ assert np.allclose(grad, expected_grad)