Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make templates valid Pytrees #5698

Merged
merged 30 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cd1e401
fix templates
dwierichs May 16, 2024
1e5b022
revert unwanted
dwierichs May 16, 2024
b959e3d
qdrift fix; delete flatten tests
dwierichs May 16, 2024
20f1d3f
flip sign flatten test
dwierichs May 16, 2024
8594edd
reduce items
dwierichs May 16, 2024
ac00ccd
lint
dwierichs May 16, 2024
77236db
Merge branch 'master' into valid-templates
dwierichs May 16, 2024
bbe80a8
finish lint
dwierichs May 16, 2024
a8ca571
changelog
dwierichs May 16, 2024
dd9bbed
Merge branch 'master' into valid-templates
dwierichs May 17, 2024
162e687
map all the wires
dwierichs May 17, 2024
6f3eb71
Merge branch 'master' into valid-templates
dwierichs May 17, 2024
9d705f3
iqp
dwierichs May 17, 2024
cf1589b
bind_new_parameters QDrift
dwierichs May 22, 2024
de0acc2
review
dwierichs May 22, 2024
c28aaa3
update qdrift to not store the decomposition, but to sample an intern…
dwierichs May 23, 2024
5420b51
merge
dwierichs May 23, 2024
6fb3490
revert test removal
dwierichs May 23, 2024
30975ae
bundle type checking
dwierichs May 23, 2024
60bef1c
"[skip ci]"
dwierichs May 23, 2024
c929177
lint '[skip ci]'
dwierichs May 23, 2024
7a935ef
random decomp
dwierichs May 23, 2024
761c455
merge
dwierichs May 28, 2024
a708d1d
Update tests/templates/test_subroutines/test_qdrift.py
dwierichs May 28, 2024
18dd447
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
2bf69d4
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
0fc8c07
note
dwierichs May 29, 2024
cb241d2
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
c8468f0
Merge branch 'master' into valid-templates
dwierichs May 30, 2024
4bf0ac5
Merge branch 'master' into valid-templates
dwierichs May 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pennylane/templates/subroutines/amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def compute_decomposition(**kwargs):
else:
for _ in range(iters):
ops.append(O)
if qml.QueuingManager.recording():
qml.apply(O)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
ops.append(qml.Reflection(U, np.pi, reflection_wires=reflection_wires))

return ops
Expand Down
11 changes: 5 additions & 6 deletions pennylane/templates/subroutines/kupccgsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,10 @@ def ansatz(weights):
grad_method = None

def _flatten(self):
hyperparameters = (
("k", self.hyperparameters["k"]),
("delta_sz", self.hyperparameters["delta_sz"]),
# tuple version of init_state is essentially identical, but is hashable
("init_state", tuple(self.hyperparameters["init_state"])),

# Do not need to flatten s_wires or d_wires because they are derived hyperparameters
hyperparameters = tuple(
(key, self.hyperparameters[key]) for key in ["k", "delta_sz", "init_state"]
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
)
return self.data, (self.wires, hyperparameters)

Expand Down Expand Up @@ -242,7 +241,7 @@ def __init__(self, weights, wires, k=1, delta_sz=0, init_state=None, id=None):
raise ValueError(f"Elements of 'init_state' must be integers; got {init_state.dtype}")

self._hyperparameters = {
"init_state": init_state,
"init_state": tuple(init_state),
"s_wires": s_wires,
"d_wires": d_wires,
"k": k,
Expand Down
44 changes: 13 additions & 31 deletions pennylane/templates/subroutines/qdrift.py
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ def my_circ(time):

"""

def _flatten(self):
h = self.hyperparameters["base"]
hashable_hyperparameters = tuple(
(key, value) for key, value in self.hyperparameters.items() if key != "base"
)
return (h, self.data[-1]), hashable_hyperparameters

@classmethod
def _unflatten(cls, data, metadata):
return cls(*data, **dict(metadata))
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

def __init__( # pylint: disable=too-many-arguments
self, hamiltonian, time, n=1, seed=None, decomposition=None, id=None
):
Expand Down Expand Up @@ -178,7 +189,7 @@ def __init__( # pylint: disable=too-many-arguments
"coefficients of the input Hamiltonian."
)

if decomposition is None: # need to do this to allow flatten and _unflatten
if decomposition is None: # need to do this to allow flatten and unflatten
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
unwrapped_coeffs = unwrap(coeffs)
decomposition = _sample_decomposition(unwrapped_coeffs, ops, time, n=n, seed=seed)

Expand All @@ -188,42 +199,13 @@ def __init__( # pylint: disable=too-many-arguments
"base": hamiltonian,
"decomposition": decomposition,
}
super().__init__(time, wires=hamiltonian.wires, id=id)
super().__init__(*hamiltonian.data, time, wires=hamiltonian.wires, id=id)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["base"])
context.append(self)
return self

@classmethod
def _unflatten(cls, data, metadata):
"""Recreate an operation from its serialized format.

Args:
data: the trainable component of the operation
metadata: the non-trainable component of the operation

The output of ``Operator._flatten`` and the class type must be sufficient to reconstruct the original
operation with ``Operator._unflatten``.

**Example:**

>>> op = qml.Rot(1.2, 2.3, 3.4, wires=0)
>>> op._flatten()
((1.2, 2.3, 3.4), (<Wires = [0]>, ()))
>>> qml.Rot._unflatten(*op._flatten())
>>> op = qml.PauliRot(1.2, "XY", wires=(0,1))
>>> op._flatten()
((1.2,), (<Wires = [0, 1]>, (('pauli_word', 'XY'),)))
>>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") )
>>> type(op)._unflatten(*op._flatten())
Controlled(U2(3.4, 4.5, wires=['a']), control_wires=['b', 'c'])

"""
hyperparameters_dict = dict(metadata[1])
hamiltonian = hyperparameters_dict.pop("base")
return cls(hamiltonian, *data, **hyperparameters_dict)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def compute_decomposition(*args, **kwargs): # pylint: disable=unused-argument
r"""Representation of the operator as a product of other operators (static method).
Expand Down
4 changes: 2 additions & 2 deletions pennylane/templates/subroutines/qmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ def __init__(self, probs, func, target_wires, estimation_wires, id=None):
"The probability distribution must have a length that is a power of two"
)

target_wires = list(target_wires)
estimation_wires = list(estimation_wires)
target_wires = tuple(target_wires)
estimation_wires = tuple(estimation_wires)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
wires = target_wires + estimation_wires

if num_target_wires != len(target_wires):
Expand Down
16 changes: 9 additions & 7 deletions pennylane/templates/subroutines/qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
Contains the QuantumPhaseEstimation template.
"""
# pylint: disable=too-many-arguments,arguments-differ
import copy

import pennylane as qml
from pennylane.operation import AnyWires, Operator
from pennylane.queuing import QueuingManager
from pennylane.resource.error import ErrorOperation, SpectralNormError
from pennylane.wires import Wires


class QuantumPhaseEstimation(ErrorOperation):
Expand Down Expand Up @@ -232,17 +235,16 @@ def error(self):

# pylint: disable=protected-access
def map_wires(self, wire_map: dict):
new_op = super().map_wires(wire_map)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["unitary"] = qml.map_wires(
new_op._hyperparameters["unitary"], wire_map
)

new_op._hyperparameters["estimation_wires"] = [
wire_map.get(wire, wire) for wire in self.estimation_wires
]
new_op._hyperparameters["target_wires"] = [
wire_map.get(wire, wire) for wire in self.target_wires
]
for key in ["estimation_wires", "target_wires"]:
new_op._hyperparameters[key] = [
wire_map.get(wire, wire) for wire in self.hyperparameters[key]
]

return new_op

Expand Down
5 changes: 3 additions & 2 deletions pennylane/templates/subroutines/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class Reflection(Operation):
Args:
U (Operator): the operator that prepares the state :math:`|\Psi\rangle`
alpha (float): the angle of the operator, default is :math:`\pi`
reflection_wires (Any or Iterable[Any]): subsystem of wires on which to reflect, the default is ``None`` and the reflection will be applied on the ``U`` wires
reflection_wires (Any or Iterable[Any]): subsystem of wires on which to reflect, the
default is ``None`` and the reflection will be applied on the ``U`` wires.

**Example**

Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(self, U, alpha=np.pi, reflection_wires=None, id=None):

self._hyperparameters = {
"base": U,
"reflection_wires": reflection_wires,
"reflection_wires": tuple(reflection_wires),
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
}

super().__init__(alpha, wires=wires, id=id)
Expand Down
13 changes: 6 additions & 7 deletions pennylane/templates/subroutines/trotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__( # pylint: disable=too-many-arguments
"check_hermitian": check_hermitian,
}

super().__init__(time, wires=hamiltonian.wires, id=id)
super().__init__(*hamiltonian.data, time, wires=hamiltonian.wires, id=id)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["base"])
Expand Down Expand Up @@ -288,7 +288,7 @@ def error(
SpectralNormError: The spectral norm error.
"""
base_unitary = self.hyperparameters["base"]
t, p, n = (self.parameters[0], self.hyperparameters["order"], self.hyperparameters["n"])
t, p, n = (self.parameters[-1], self.hyperparameters["order"], self.hyperparameters["n"])

parameters = [t] + base_unitary.parameters
if any(
Expand Down Expand Up @@ -343,10 +343,10 @@ def _flatten(self):
(<Wires = ['b', 'c']>, (True, True), <Wires = []>))
"""
hamiltonian = self.hyperparameters["base"]
time = self.parameters[0]
time = self.data[-1]

hashable_hyperparameters = tuple(
(key, value) for key, value in self.hyperparameters.items() if key != "base"
item for item in self.hyperparameters.items() if item[0] != "base"
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
)
return (hamiltonian, time), hashable_hyperparameters

Expand Down Expand Up @@ -375,8 +375,7 @@ def _unflatten(cls, data, metadata):
Controlled(U2(3.4, 4.5, wires=['a']), control_wires=['b', 'c'])

"""
hyperparameters_dict = dict(metadata)
return cls(*data, **hyperparameters_dict)
return cls(*data, **dict(metadata))

@staticmethod
def compute_decomposition(*args, **kwargs):
Expand All @@ -399,7 +398,7 @@ def compute_decomposition(*args, **kwargs):
Returns:
list[Operator]: decomposition of the operator
"""
time = args[0]
time = args[-1]
n = kwargs["n"]
order = kwargs["order"]
ops = kwargs["base"].operands
Expand Down
6 changes: 5 additions & 1 deletion pennylane/templates/subroutines/uccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ def __init__(self, weights, wires, s_wires=None, d_wires=None, init_state=None,
if init_state.dtype != np.dtype("int"):
raise ValueError(f"Elements of 'init_state' must be integers; got {init_state.dtype}")

self._hyperparameters = {"init_state": init_state, "s_wires": s_wires, "d_wires": d_wires}
self._hyperparameters = {
"init_state": tuple(init_state),
"s_wires": tuple(tuple(w) for w in s_wires),
"d_wires": tuple(tuple(tuple(w) for w in dw) for dw in d_wires),
}
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

super().__init__(weights, wires=wires, id=id)

Expand Down
24 changes: 7 additions & 17 deletions tests/templates/test_subroutines/test_amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def test_error_wrong_work_wire(self, wires, fixed_point, work_wire):
with pytest.raises(ValueError, match="work_wire must be different from the wires of O."):
qml.AmplitudeAmplification(U, O, iters=3, fixed_point=fixed_point, work_wire=work_wire)

def test_standard_validity(self):
"""Test standard validity using assert_valid."""
U = generator(wires=range(3))
O = oracle([0, 2], wires=range(3))
op = qml.AmplitudeAmplification(U, O, iters=3, fixed_point=False)
qml.ops.functions.assert_valid(op)


@pytest.mark.parametrize(
"n_wires, items, iters",
Expand Down Expand Up @@ -257,23 +264,6 @@ def circuit3():
assert np.allclose(circuit1(), circuit3())


# pylint: disable=protected-access
def test_flatten_and_unflatten():
"""Test the _flatten and _unflatten methods for AmplitudeAmplification."""

op = qml.AmplitudeAmplification(qml.RX(0.25, wires=0), qml.PauliZ(0))
data, metadata = op._flatten()

assert len(data) == 2
assert len(metadata) == 5

new_op = type(op)._unflatten(*op._flatten())
assert qml.equal(op, new_op)
assert op is not new_op

assert hash(metadata)


def test_amplification():
"""Test that AmplitudeAmplification amplifies a marked element."""

Expand Down
19 changes: 0 additions & 19 deletions tests/templates/test_subroutines/test_approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,6 @@ def test_standard_validity():
qml.ops.functions.assert_valid(op)


# pylint: disable=protected-access
def test_flatten_unflatten():
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Tests the _flatten and _unflatten methods."""
H = 2.0 * qml.PauliX(0) + 3.0 * qml.PauliY(0)
t = 0.1
op = qml.ApproxTimeEvolution(H, t, n=20)
data, metadata = op._flatten()
assert data[0] is H
assert data[1] == t
assert metadata == (20,)

# check metadata hashable
assert hash(metadata)

new_op = type(op)._unflatten(*op._flatten())
assert qml.equal(op, new_op)
assert new_op is not op


def test_queuing():
"""Test that ApproxTimeEvolution de-queues the input hamiltonian."""

Expand Down
22 changes: 0 additions & 22 deletions tests/templates/test_subroutines/test_commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,6 @@ def test_standard_validity():
qml.ops.functions.assert_valid(op)


# pylint: disable=protected-access
def test_flatten_unflatten():
"""Unit tests for the flatten and unflatten methods."""
H = 2.0 * qml.PauliX(0) @ qml.PauliY(1) + 3.0 * qml.PauliY(0) @ qml.PauliZ(1)
time = 0.5
frequencies = (2, 4)
shifts = (1, 0.5)
op = qml.CommutingEvolution(H, time, frequencies=frequencies, shifts=shifts)
data, metadata = op._flatten()

assert hash(metadata)

assert len(data) == 2
assert data[1] is H
assert data[0] == time
assert metadata == (frequencies, shifts)

new_op = type(op)._unflatten(*op._flatten())
assert qml.equal(op, new_op)
assert op is not new_op


def test_adjoint():
"""Tests the CommutingEvolution.adjoint method provides the correct adjoint operation."""

Expand Down
22 changes: 0 additions & 22 deletions tests/templates/test_subroutines/test_controlled_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,6 @@ def test_id(self):
op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2], id="a")
assert op.id == "a"

# pylint: disable=protected-access
def test_flatten_and_unflatten(self):
"""Test the _flatten and _unflatten methods for ControlledSequence"""

op = qml.ControlledSequence(qml.RX(0.25, wires=3), control=[0, 1, 2])
data, metadata = op._flatten()

assert len(data) == 1
assert qml.equal(data[0], op.base)

assert len(metadata) == 1
assert metadata[0] == op.control

# make sure metadata is hashable
assert hash(metadata)

new_op = type(op)._unflatten(*op._flatten())

assert qml.equal(op.base, new_op.base)
assert op.control_wires == new_op.control_wires
assert op is not new_op

def test_overlapping_wires_error(self):
"""Test that an error is raised if the wires of the base
operator and the control wires overlap"""
Expand Down
21 changes: 0 additions & 21 deletions tests/templates/test_subroutines/test_double_excitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,6 @@ def test_standard_validity():
qml.ops.functions.assert_valid(op)


# pylint: disable=protected-access
def test_flatten_unflatten():
"""Test the _flatten and _unflatten methods."""
weight = 0.5
wires1 = qml.wires.Wires((0, 1))
wires2 = qml.wires.Wires((2, 3, 4))
op = qml.FermionicDoubleExcitation(weight, wires1=wires1, wires2=wires2)

data, metadata = op._flatten()
assert data == (0.5,)
assert metadata[0] == wires1
assert metadata[1] == wires2

# test that its hashable
assert hash(metadata)

new_op = type(op)._unflatten(*op._flatten())
assert qml.equal(op, new_op)
assert op is not new_op


class TestDecomposition:
"""Tests that the template defines the correct decomposition."""

Expand Down
7 changes: 0 additions & 7 deletions tests/templates/test_subroutines/test_fable.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ def test_standard_validity(self, input_matrix):
op = qml.FABLE(input_matrix, wires=range(5), tol=0.01)
qml.ops.functions.assert_valid(op)

# pylint: disable=protected-access
def test_flatten_unflatten(self, input_matrix):
"""Test the flatten and unflatten methods."""
op = qml.FABLE(input_matrix, wires=range(5), tol=0.01)
new_op = type(op)._unflatten(*op._flatten())
assert qml.equal(op, new_op)

@pytest.mark.parametrize(
("input", "wires"),
[
Expand Down
Loading
Loading