Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the contraction of quantum and classical Jacobians consistent in gradient_transform #4945

Merged
merged 21 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@

<h3>Breaking changes 💔</h3>

* Applying a `gradient_transform` to a QNode directly now gives the same shape and type independent
of whether there is classical processing in the node.
[(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)

* The function `qml.transforms.classical_jacobian` has been moved to the gradients module
and is now accessible as `qml.gradients.classical_jacobian`.
[(#4900)](https://github.com/PennyLaneAI/pennylane/pull/4900)
Expand Down Expand Up @@ -462,6 +466,10 @@

<h3>Bug fixes 🐛</h3>

* Fixed a bug where the shape and type of derivatives obtained by applying a gradient transform to
a QNode differed, based on whether the QNode uses classical coprocessing.
[(#4945)](https://github.com/PennyLaneAI/pennylane/pull/4945)

* Fixed a bug where the parameter-shift rule of `qml.ctrl(op)` was wrong if `op` had a generator
that has two or more eigenvalues and is stored as a `SparseHamiltonian`.
[(#4899)](https://github.com/PennyLaneAI/pennylane/pull/4899)
Expand Down Expand Up @@ -568,4 +576,4 @@ Mudit Pandey,
Matthew Silverman,
Jay Soni,
David Wierichs,
Justin Woodring.
Justin Woodring.
51 changes: 39 additions & 12 deletions pennylane/gradients/gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,54 +411,81 @@ def _contract_qjac_with_cjac(qjac, cjac, tape):
cjac = cjac[0]

cjac_is_tuple = isinstance(cjac, tuple)
if not cjac_is_tuple:
is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]
# skip_cjac = False
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
# if not cjac_is_tuple:
# is_square = cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1]

if not qml.math.is_abstract(cjac) and (
is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))
):
# Classical Jacobian is the identity. No classical processing is present in the QNode
return qjac
# if not qml.math.is_abstract(cjac) and (
# is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0]))
# ):
# skip_cjac = True
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

multi_meas = num_measurements > 1

if cjac_is_tuple:
multi_params = True
single_tape_param = False
else:
_qjac = qjac
if multi_meas:
_qjac = _qjac[0]
if has_partitioned_shots:
_qjac = _qjac[0]
multi_params = isinstance(_qjac, tuple)
single_tape_param = not isinstance(_qjac, tuple)

tdot = partial(qml.math.tensordot, axes=[[0], [0]])

if not multi_params:
if single_tape_param:
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
# Without dimension (e.g. expval) or with dimension (e.g. probs)
def _reshape(x):
return qml.math.reshape(x, (1,) if x.shape == () else (1, -1))

if not (multi_meas or has_partitioned_shots):
# Single parameter, single measurements
# Single parameter, single measurements, no shot vector
# if skip_cjac:
# return qml.math.moveaxis(_reshape(qjac), 0, -1)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tdot(_reshape(qjac), cjac)

if not (multi_meas and has_partitioned_shots):
# Single parameter, multiple measurements or shot vector, but not both
# if skip_cjac:
# return tuple(qml.math.moveaxis(_reshape(q), 0, -1) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tdot(_reshape(q), cjac) for q in qjac)

# Single parameter, multiple measurements
# Single parameter, multiple measurements, and shot vector
# if skip_cjac:
# return tuple(tuple(qml.math.moveaxis(_reshape(_q), 0, -1) for _q in q) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tuple(tdot(_reshape(_q), cjac) for _q in q) for q in qjac)

if not multi_meas:
# Multiple parameters, single measurement
qjac = qml.math.stack(qjac)
if not cjac_is_tuple:
# if skip_cjac:
# return qml.math.moveaxis(qjac, 0, -1)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
if has_partitioned_shots:
return tuple(tdot(qml.math.stack(q), qml.math.stack(cjac)) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tdot(qjac, qml.math.stack(cjac))
if has_partitioned_shots:
print("this")
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tuple(tdot(q, c) for c in cjac if c is not None) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tdot(qjac, c) for c in cjac if c is not None)

# Multiple parameters, multiple measurements
if not cjac_is_tuple:
if has_partitioned_shots:
# if skip_cjac:
# return tuple(tuple(qml.math.moveaxis(qml.math.stack(_q), 0, -1) for _q in q) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(
tuple(tdot(qml.math.stack(_q), qml.math.stack(cjac)) for _q in q) for q in qjac
)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
# if skip_cjac:
# return tuple(qml.math.moveaxis(qml.math.stack(q), 0, -1) for q in qjac)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tdot(qml.math.stack(q), qml.math.stack(cjac)) for q in qjac)
if has_partitioned_shots:
return tuple(
tuple(tuple(tdot(qml.math.stack(_q), c) for c in cjac if c is not None) for _q in q)
for q in qjac
)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return tuple(tuple(tdot(qml.math.stack(q), c) for c in cjac if c is not None) for q in qjac)


Expand Down
33 changes: 24 additions & 9 deletions tests/gradients/core/test_gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,50 +203,65 @@ class TestGradientTransformIntegration:

@pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 500], 3e-1)])
@pytest.mark.parametrize("slicing", [False, True])
def test_acting_on_qnodes_single_param(self, shots, slicing, atol):
@pytest.mark.parametrize("prefactor", [1.0, 2.0])
def test_acting_on_qnodes_single_param(self, shots, slicing, prefactor, atol):
"""Test that a gradient transform acts on QNodes with a single parameter correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
def circuit(weights):
if slicing:
qml.RX(weights[0], wires=[0])
qml.RX(prefactor * weights[0], wires=[0])
else:
qml.RX(weights, wires=[0])
qml.RX(prefactor * weights, wires=[0])
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))

grad_fn = qml.gradients.param_shift(circuit)

w = np.array([0.543] if slicing else 0.543, requires_grad=True)
scalar_param = 0.543
w = np.array([scalar_param] if slicing else scalar_param, requires_grad=True)
res = grad_fn(w)
assert circuit.interface == "auto"
expected = np.array([-np.sin(w[0] if slicing else w), 0])

# Need to multiply 0 with w to get the right output shape for non-scalar w
expected = (-prefactor * np.sin(prefactor * w), w * 0)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(shots, list):
assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
else:
assert np.allclose(res, expected, atol=atol, rtol=0)

@pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 100], 2e-1)])
def test_acting_on_qnodes_multi_param(self, shots, atol):
@pytest.mark.parametrize("prefactor", [1.0, 2.0])
def test_acting_on_qnodes_multi_param(self, shots, prefactor, atol):
"""Test that a gradient transform acts on QNodes with multiple parameters correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
def circuit(weights):
qml.RX(weights[0], wires=[0])
qml.RY(weights[1], wires=[1])
qml.RY(prefactor * weights[1], wires=[1])
qml.CNOT(wires=[0, 1])
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliX(1))
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(1))
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

grad_fn = qml.gradients.param_shift(circuit)

w = np.array([0.543, -0.654], requires_grad=True)
res = grad_fn(w)
assert circuit.interface == "auto"
x, y = w
expected = np.array([[-np.sin(x), 0], [0, -2 * np.cos(y) * np.sin(y)]])
y *= prefactor
expected = np.array(
[
[-np.sin(x), 0],
[
2 * np.cos(x) * np.sin(x) * np.cos(y) ** 2,
2 * prefactor * np.cos(y) * np.sin(y) * np.cos(x) ** 2,
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
],
]
)
if isinstance(shots, list):
assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
else:
Expand Down
Loading
Loading