Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make templates valid Pytrees #5698

Merged
merged 30 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cd1e401
fix templates
dwierichs May 16, 2024
1e5b022
revert unwanted
dwierichs May 16, 2024
b959e3d
qdrift fix; delete flatten tests
dwierichs May 16, 2024
20f1d3f
flip sign flatten test
dwierichs May 16, 2024
8594edd
reduce items
dwierichs May 16, 2024
ac00ccd
lint
dwierichs May 16, 2024
77236db
Merge branch 'master' into valid-templates
dwierichs May 16, 2024
bbe80a8
finish lint
dwierichs May 16, 2024
a8ca571
changelog
dwierichs May 16, 2024
dd9bbed
Merge branch 'master' into valid-templates
dwierichs May 17, 2024
162e687
map all the wires
dwierichs May 17, 2024
6f3eb71
Merge branch 'master' into valid-templates
dwierichs May 17, 2024
9d705f3
iqp
dwierichs May 17, 2024
cf1589b
bind_new_parameters QDrift
dwierichs May 22, 2024
de0acc2
review
dwierichs May 22, 2024
c28aaa3
update qdrift to not store the decomposition, but to sample an intern…
dwierichs May 23, 2024
5420b51
merge
dwierichs May 23, 2024
6fb3490
revert test removal
dwierichs May 23, 2024
30975ae
bundle type checking
dwierichs May 23, 2024
60bef1c
"[skip ci]"
dwierichs May 23, 2024
c929177
lint '[skip ci]'
dwierichs May 23, 2024
7a935ef
random decomp
dwierichs May 23, 2024
761c455
merge
dwierichs May 28, 2024
a708d1d
Update tests/templates/test_subroutines/test_qdrift.py
dwierichs May 28, 2024
18dd447
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
2bf69d4
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
0fc8c07
note
dwierichs May 29, 2024
cb241d2
Merge branch 'master' into valid-templates
dwierichs May 29, 2024
c8468f0
Merge branch 'master' into valid-templates
dwierichs May 30, 2024
4bf0ac5
Merge branch 'master' into valid-templates
dwierichs May 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

<h3>Improvements 🛠</h3>

* A number of templates have been updated to be valid pytrees and PennyLane operations.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* `ctrl` now works with tuple-valued `control_values` when applied to any already controlled operation.
[(#5725)](https://github.com/PennyLaneAI/pennylane/pull/5725)

Expand Down Expand Up @@ -121,6 +124,10 @@

<h3>Breaking changes 💔</h3>

* A custom decomposition can no longer be provided to `QDrift`. Instead, apply the operations in your custom
operation directly with `qml.apply`.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* Sampling observables composed of `X`, `Y`, `Z` and `Hadamard` now returns values of type `float` instead of `int`.
[(#5607)](https://github.com/PennyLaneAI/pennylane/pull/5607)

Expand Down Expand Up @@ -156,6 +163,12 @@

<h3>Bug fixes 🐛</h3>

* `QuantumPhaseEstimation.map_wires` on longer modifies the original operation instance.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* The decomposition of `AmplitudeAmplification` now correctly queues all operations.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces.
[(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672)

Expand Down
24 changes: 19 additions & 5 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inner_func(*args, **kwargs):
return inner_func


def _check_decomposition(op):
def _check_decomposition(op, skip_wire_mapping):
"""Checks involving the decomposition."""
if op.has_decomposition:
decomp = op.decomposition()
Expand All @@ -64,6 +64,18 @@ def _check_decomposition(op):
assert o1 == o3, "decomposition must match queued operations"
assert o1 == o4, "decomposition must match expansion"
assert isinstance(o1, qml.operation.Operator), "decomposition must contain operators"

if skip_wire_mapping:
return
# Check that mapping wires transitions to the decomposition
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
mapped_op = op.map_wires(wire_map)
mapped_decomp = mapped_op.decomposition()
orig_decomp = op.decomposition()
for mapped_op, orig_op in zip(mapped_decomp, orig_decomp):
assert (
mapped_op.wires == qml.map_wires(orig_op, wire_map).wires
), "Operators in decomposition of wire-mapped operator must have mapped wires."
else:
failure_comment = "If has_decomposition is False, then decomposition must raise a ``DecompositionUndefinedError``."
_assert_error_raised(
Expand Down Expand Up @@ -216,17 +228,19 @@ def _check_bind_new_parameters(op):
assert qml.math.allclose(d1, d2), failure_comment


def _check_wires(op):
def _check_wires(op, skip_wire_mapping):
"""Check that wires are a ``Wires`` class and can be mapped."""
assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"

if skip_wire_mapping:
return
wire_map = {w: ascii_lowercase[i] for i, w in enumerate(op.wires)}
mapped_op = op.map_wires(wire_map)
new_wires = qml.wires.Wires(list(ascii_lowercase[: len(op.wires)]))
assert mapped_op.wires == new_wires, "wires must be mappable with map_wires"


def assert_valid(op: qml.operation.Operator, skip_pickle=False) -> None:
def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mapping=False) -> None:
"""Runs basic validation checks on an :class:`~.operation.Operator` to make
sure it has been correctly defined.

Expand Down Expand Up @@ -278,14 +292,14 @@ def __init__(self, wires):
assert qml.math.allclose(d, p), "data and parameters must match."

if len(op.wires) <= 26:
_check_wires(op)
_check_wires(op, skip_wire_mapping)
_check_copy(op)
_check_pytree(op)
if not skip_pickle:
_check_pickle(op)
_check_bind_new_parameters(op)

_check_decomposition(op)
_check_decomposition(op, skip_wire_mapping)
_check_matrix(op)
_check_matrix_matches_decomp(op)
_check_eigendecomposition(op)
10 changes: 10 additions & 0 deletions pennylane/ops/functions/bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def bind_new_parameters_commuting_evolution(
return qml.CommutingEvolution(new_hamiltonian, time, frequencies=freq, shifts=shifts)


@bind_new_parameters.register
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
def bind_new_parameters_qdrift(op: qml.QDrift, params: Sequence[TensorLike]):
new_hamiltonian = bind_new_parameters(op.hyperparameters["base"], params[:-1])
time = params[-1]
n = op.hyperparameters["n"]
seed = op.hyperparameters["seed"]

return qml.QDrift(new_hamiltonian, time, n=n, seed=seed)


@bind_new_parameters.register
def bind_new_parameters_fermionic_double_excitation(
op: qml.FermionicDoubleExcitation, params: Sequence[TensorLike]
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/qubit/arithmetic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def compute_decomposition(value, geq=True, wires=None, work_wires=None, **kwargs
small_val = not geq and value == 0
large_val = geq and value > 2 ** len(control_wires) - 1
if small_val or large_val:
gates = [Identity(0)]
gates = [Identity(wires[0])]
dwierichs marked this conversation as resolved.
Show resolved Hide resolved

else:
values = range(value, 2 ** (len(control_wires))) if geq else range(value)
Expand Down
7 changes: 5 additions & 2 deletions pennylane/ops/qubit/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,13 @@ def compute_decomposition(D, wires):
ops = [QubitUnitary(qml.math.tensordot(global_phase, qml.math.eye(2), axes=0), wires[0])]
for wire0 in range(n):
# Single PauliZ generators correspond to the coeffs at powers of two
ops.append(qml.RZ(coeffs[1 << wire0], n - 1 - wire0))
ops.append(qml.RZ(coeffs[1 << wire0], wires[n - 1 - wire0]))
# Double PauliZ generators correspond to the coeffs at the sum of two powers of two
ops.extend(
qml.IsingZZ(coeffs[(1 << wire0) + (1 << wire1)], [n - 1 - wire0, n - 1 - wire1])
qml.IsingZZ(
coeffs[(1 << wire0) + (1 << wire1)],
[wires[n - 1 - wire0], wires[n - 1 - wire1]],
)
for wire1 in range(wire0)
)

Expand Down
11 changes: 11 additions & 0 deletions pennylane/templates/embeddings/iqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
Contains the IQPEmbedding template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy
from itertools import combinations

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires


class IQPEmbedding(Operation):
Expand Down Expand Up @@ -186,6 +188,15 @@ def __init__(self, features, wires, n_repeats=1, pattern=None, id=None):

super().__init__(features, wires=wires, id=id)

def map_wires(self, wire_map):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["pattern"] = [
[wire_map.get(w, w) for w in wires] for wires in new_op._hyperparameters["pattern"]
]
return new_op

@property
def num_params(self):
return 1
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/all_singles_doubles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
Contains the AllSinglesDoubles template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import BasisState
from pennylane.wires import Wires


class AllSinglesDoubles(Operation):
Expand Down Expand Up @@ -155,6 +158,15 @@ def __init__(self, weights, wires, hf_state, singles=None, doubles=None, id=None

super().__init__(weights, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
for key in ["singles", "doubles"]:
new_op._hyperparameters[key] = tuple(
tuple(wire_map[w] for w in wires) for wires in new_op._hyperparameters[key]
)
return new_op

@property
def num_params(self):
return 1
Expand Down
27 changes: 21 additions & 6 deletions pennylane/templates/subroutines/amplitude_amplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"""

# pylint: disable-msg=too-many-arguments
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import Operation
from pennylane.wires import Wires


def _get_fixed_point_angles(iters, p_min):
Expand Down Expand Up @@ -101,16 +104,12 @@ def circuit():

def _flatten(self):
data = (self.hyperparameters["U"], self.hyperparameters["O"])
metadata = tuple(
(key, value) for key, value in self.hyperparameters.items() if key not in ["O", "U"]
)
metadata = tuple(item for item in self.hyperparameters.items() if item[0] not in ["O", "U"])
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
return data, metadata

@classmethod
def _unflatten(cls, data, metadata):
U, O = (data[0], data[1])
hyperparams_dict = dict(metadata)
return cls(U, O, **hyperparams_dict)
return cls(*data, **dict(metadata))

def __init__(
self, U, O, iters=1, fixed_point=False, work_wire=None, p_min=0.9, reflection_wires=None
Expand Down Expand Up @@ -169,10 +168,26 @@ def compute_decomposition(**kwargs):
else:
for _ in range(iters):
ops.append(O)
if qml.QueuingManager.recording():
qml.apply(O)
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
ops.append(qml.Reflection(U, np.pi, reflection_wires=reflection_wires))

return ops

def map_wires(self, wire_map: dict):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["U"] = new_op._hyperparameters["U"].map_wires(wire_map)
new_op._hyperparameters["O"] = new_op._hyperparameters["O"].map_wires(wire_map)
new_op._hyperparameters["reflection_wires"] = Wires(
[wire_map.get(wire, wire) for wire in new_op._hyperparameters["reflection_wires"]]
)
new_op._hyperparameters["work_wire"] = wire_map.get(
w := new_op._hyperparameters["work_wire"], w
)
return new_op

def queue(self, context=qml.QueuingManager):
for op in [self.hyperparameters["U"], self.hyperparameters["O"]]:
context.remove(op)
Expand Down
11 changes: 11 additions & 0 deletions pennylane/templates/subroutines/approx_time_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
Contains the ApproxTimeEvolution template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import PauliRot
from pennylane.wires import Wires


class ApproxTimeEvolution(Operation):
Expand Down Expand Up @@ -139,6 +142,14 @@ def __init__(self, hamiltonian, time, n, id=None):
# trainable parameters are passed to the base init method
super().__init__(*hamiltonian.data, time, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["hamiltonian"] = qml.map_wires(
new_op._hyperparameters["hamiltonian"], wire_map
)
return new_op

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/commuting_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
Contains the CommutingEvolution template.
"""
# pylint: disable-msg=too-many-arguments,import-outside-toplevel
import copy

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires


class CommutingEvolution(Operation):
Expand Down Expand Up @@ -139,6 +142,15 @@ def __init__(self, hamiltonian, time, frequencies=None, shifts=None, id=None):

super().__init__(time, *hamiltonian.parameters, wires=hamiltonian.wires, id=id)

def map_wires(self, wire_map: dict):
# pylint: disable=protected-access
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
new_op._hyperparameters["hamiltonian"] = qml.map_wires(
new_op._hyperparameters["hamiltonian"], wire_map
)
return new_op

def queue(self, context=qml.QueuingManager):
context.remove(self.hyperparameters["hamiltonian"])
context.append(self)
Expand Down
12 changes: 12 additions & 0 deletions pennylane/templates/subroutines/fermionic_double_excitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
Contains the FermionicDoubleExcitation template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import copy

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.ops import CNOT, RX, RZ, Hadamard
from pennylane.wires import Wires


def _layer1(weight, s, r, q, p, set_cnot_wires):
Expand Down Expand Up @@ -532,6 +535,15 @@ def __init__(self, weight, wires1=None, wires2=None, id=None):
wires = wires1 + wires2
super().__init__(weight, wires=wires, id=id)

def map_wires(self, wire_map: dict):
new_op = copy.deepcopy(self)
new_op._wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
for key in ["wires1", "wires2"]:
new_op._hyperparameters[key] = Wires(
[wire_map.get(wire, wire) for wire in self._hyperparameters[key]]
)
return new_op

@property
def num_params(self):
return 1
Expand Down
3 changes: 3 additions & 0 deletions pennylane/templates/subroutines/hilbert_schmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __init__(self, *params, v_function, v_wires, u_tape, id=None):

super().__init__(*params, wires=wires, id=id)

def map_wires(self, wire_map: dict):
raise NotImplementedError("Mapping the wires of HilbertSchmidt is not implemented.")

@property
def num_params(self):
return self._num_params
Expand Down
Loading
Loading