Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the contraction of quantum and classical Jacobians consistent in gradient_transform #4945

Merged
merged 21 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@

<h3>Breaking changes 💔</h3>

* Applying a `gradient_transform` directly to a QNode now returns a result with the same shape and
type, regardless of whether the QNode contains classical processing.
[(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)

* State measurements preserve `dtype`.
[(#5547)](https://github.com/PennyLaneAI/pennylane/pull/5547)

Expand Down Expand Up @@ -496,6 +500,10 @@

<h3>Bug fixes 🐛</h3>

* Fixed a bug where the shape and type of derivatives obtained by applying a gradient transform to
a QNode differed depending on whether the QNode uses classical processing.
[(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)

* `ApproxTimeEvolution`, `CommutingEvolution`, `QDrift`, and `TrotterProduct`
now de-queue their input observable.
[(#5524)](https://github.com/PennyLaneAI/pennylane/pull/5524)
Expand Down
48 changes: 29 additions & 19 deletions pennylane/gradients/gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ def reorder_grads(grads, tape_specs):
return _move_first_axis_to_third_pos(grads, num_params, shots.num_copies, num_measurements)


tdot = partial(qml.math.tensordot, axes=[[0], [0]])
stack = qml.math.stack


# pylint: disable=too-many-return-statements,too-many-branches
def _contract_qjac_with_cjac(qjac, cjac, tape):
"""Contract a quantum Jacobian with a classical preprocessing Jacobian.
Expand All @@ -397,52 +401,58 @@ def _contract_qjac_with_cjac(qjac, cjac, tape):
cjac = cjac[0]

cjac_is_tuple = isinstance(cjac, tuple)
if not cjac_is_tuple:
is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]

if not qml.math.is_abstract(cjac) and (
is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))
):
# Classical Jacobian is the identity. No classical processing is present in the QNode
return qjac

multi_meas = num_measurements > 1

# This block only figures out whether there is a single tape parameter or not
if cjac_is_tuple:
multi_params = True
single_tape_param = False
else:
# Peel out a single measurement's and single shot setting's qjac
_qjac = qjac
if multi_meas:
_qjac = _qjac[0]
if has_partitioned_shots:
_qjac = _qjac[0]
multi_params = isinstance(_qjac, tuple)

tdot = partial(qml.math.tensordot, axes=[[0], [0]])
single_tape_param = not isinstance(_qjac, tuple)

if not multi_params:
if single_tape_param:
# Without dimension (e.g. expval) or with dimension (e.g. probs)
def _reshape(x):
return qml.math.reshape(x, (1,) if x.shape == () else (1, -1))

if not (multi_meas or has_partitioned_shots):
# Single parameter, single measurements
# Single parameter, single measurements, no shot vector
return tdot(_reshape(qjac), cjac)

if not (multi_meas and has_partitioned_shots):
# Single parameter, multiple measurements or shot vector, but not both
return tuple(tdot(_reshape(q), cjac) for q in qjac)

# Single parameter, multiple measurements
# Single parameter, multiple measurements, and shot vector
return tuple(tuple(tdot(_reshape(_q), cjac) for _q in q) for q in qjac)

if not multi_meas:
# Multiple parameters, single measurement
qjac = qml.math.stack(qjac)
qjac = stack(qjac)
if not cjac_is_tuple:
return tdot(qjac, qml.math.stack(cjac))
cjac = stack(cjac)
if has_partitioned_shots:
return tuple(tdot(stack(q), cjac) for q in qjac)
return tdot(qjac, cjac)
if has_partitioned_shots:
return tuple(tuple(tdot(q, c) for c in cjac if c is not None) for q in qjac)
return tuple(tdot(qjac, c) for c in cjac if c is not None)

# Multiple parameters, multiple measurements
if not cjac_is_tuple:
return tuple(tdot(qml.math.stack(q), qml.math.stack(cjac)) for q in qjac)
return tuple(tuple(tdot(qml.math.stack(q), c) for c in cjac if c is not None) for q in qjac)
cjac = stack(cjac)
if has_partitioned_shots:
return tuple(tuple(tdot(stack(_q), cjac) for _q in q) for q in qjac)
return tuple(tdot(stack(q), cjac) for q in qjac)
if has_partitioned_shots:
return tuple(
tuple(tuple(tdot(stack(_q), c) for c in cjac if c is not None) for _q in q)
for q in qjac
)
return tuple(tuple(tdot(stack(q), c) for c in cjac if c is not None) for q in qjac)
30 changes: 22 additions & 8 deletions tests/gradients/core/test_gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,50 +186,64 @@ class TestGradientTransformIntegration:

@pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 500], 3e-1)])
@pytest.mark.parametrize("slicing", [False, True])
def test_acting_on_qnodes_single_param(self, shots, slicing, atol):
@pytest.mark.parametrize("prefactor", [1.0, 2.0])
def test_acting_on_qnodes_single_param(self, shots, slicing, prefactor, atol):
"""Test that a gradient transform acts on QNodes with a single parameter correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
def circuit(weights):
if slicing:
qml.RX(weights[0], wires=[0])
qml.RX(prefactor * weights[0], wires=[0])
else:
qml.RX(weights, wires=[0])
qml.RX(prefactor * weights, wires=[0])
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))

grad_fn = qml.gradients.param_shift(circuit)

w = np.array([0.543] if slicing else 0.543, requires_grad=True)
res = grad_fn(w)
assert circuit.interface == "auto"
expected = np.array([-np.sin(w[0] if slicing else w), 0])

# Need to multiply 0 with w to get the right output shape for non-scalar w
expected = (-prefactor * np.sin(prefactor * w), w * 0)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(shots, list):
assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
else:
assert np.allclose(res, expected, atol=atol, rtol=0)

@pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 100], 2e-1)])
def test_acting_on_qnodes_multi_param(self, shots, atol):
@pytest.mark.parametrize("prefactor", [1.0, 2.0])
def test_acting_on_qnodes_multi_param(self, shots, prefactor, atol):
"""Test that a gradient transform acts on QNodes with multiple parameters correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
def circuit(weights):
qml.RX(weights[0], wires=[0])
qml.RY(weights[1], wires=[1])
qml.RY(prefactor * weights[1], wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(1))
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

grad_fn = qml.gradients.param_shift(circuit)

w = np.array([0.543, -0.654], requires_grad=True)
res = grad_fn(w)
assert circuit.interface == "auto"
x, y = w
expected = np.array([[-np.sin(x), 0], [0, -2 * np.cos(y) * np.sin(y)]])
y *= prefactor
expected = np.array(
[
[-np.sin(x), 0],
[
2 * np.cos(x) * np.sin(x) * np.cos(y) ** 2,
2 * prefactor * np.cos(y) * np.sin(y) * np.cos(x) ** 2,
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
],
]
)
if isinstance(shots, list):
assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
else:
Expand Down
Loading
Loading