Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StateMP accepts wires #4570

Merged
merged 8 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
process, `DensityMatrixMP`.
[(#4558)](https://github.com/PennyLaneAI/pennylane/pull/4558)

* The `StateMP` measurement now accepts a wire order (eg. a device wire order). The `process_state`
method will re-order the given state to go from the inputted wire-order to the process's wire-order.
If the process's wire-order contains extra wires, it will assume those are in the zero-state.
[(#4570)](https://github.com/PennyLaneAI/pennylane/pull/4570)

<h3>Breaking changes 💔</h3>

* The `__eq__` and `__hash__` methods of `Operator` and `MeasurementProcess` no longer rely on the
Expand Down
10 changes: 10 additions & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,15 @@ def _cond_tf(pred, true_fn, false_fn, args):
)


def _pad_tf(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
if ar.ndim(paddings) == 1:
paddings = (paddings,)
return _i("tf").pad(tensor, paddings, mode=mode, constant_values=constant_values, name=name)
timmysilv marked this conversation as resolved.
Show resolved Hide resolved


ar.register_function("tensorflow", "pad", _pad_tf)


# -------------------------------- Torch --------------------------------- #

ar.autoray._FUNC_ALIASES["torch", "unstack"] = "unbind"
Expand Down Expand Up @@ -694,6 +703,7 @@ def _sum_torch(tensor, axis=None, keepdims=False, dtype=None):

ar.register_function("torch", "sum", _sum_torch)
ar.register_function("torch", "cond", _cond)
ar.register_function("torch", "pad", _i("torch").nn.functional.pad)


# -------------------------------- JAX --------------------------------- #
Expand Down
34 changes: 28 additions & 6 deletions pennylane/measurements/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Sequence, Optional

import pennylane as qml
from pennylane.wires import Wires
from pennylane.wires import Wires, WireError

from .measurements import State, StateMeasurement

Expand Down Expand Up @@ -131,12 +131,13 @@ class StateMP(StateMeasurement):
Please refer to :func:`state` for detailed documentation.

Args:
wires (.Wires): The wires the measurement process applies to.
id (str): custom label given to a measurement instance, can be useful for some applications
where the instance has to be identified
"""

def __init__(self, *, id: Optional[str] = None):
super().__init__(wires=None, id=id)
def __init__(self, wires: Optional[Wires] = None, id: Optional[str] = None):
super().__init__(wires=wires, id=id)

@property
def return_type(self):
Expand All @@ -155,7 +156,29 @@ def shape(self, device, shots):

def process_state(self, state: Sequence[complex], wire_order: Wires):
# pylint:disable=redefined-outer-name
return state
wires = self.wires
if not wires or wire_order == wires:
return state

if set(wire_order) == set(wires):
state = qml.math.reshape(state, (2,) * len(wire_order))
desired_axes = [wires.index(w) for w in wire_order]
return qml.math.flatten(qml.math.transpose(state, desired_axes))

if not wires.contains_wires(wire_order):
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
raise WireError(
f"Unexpected wires {set(wire_order) - set(wires)} found in wire order. Expected wire order to be a subset of {wires}"
)

# pad with zeros, put existing wires last
state = qml.math.pad(state, (0, 2 ** len(wires) - 2 ** len(wire_order)))
state = qml.math.reshape(state, (2,) * len(wires))

# re-order
new_wire_order = Wires.unique_wires([wires, wire_order]) + wire_order
desired_axes = [new_wire_order.index(w) for w in wires]
state = qml.math.transpose(state, desired_axes)
return qml.math.flatten(state)


class DensityMatrixMP(StateMP):
Expand All @@ -170,8 +193,7 @@ class DensityMatrixMP(StateMP):
"""

def __init__(self, wires: Wires, id: Optional[str] = None):
# pylint:disable=non-parent-init-called,super-init-not-called
StateMeasurement.__init__(self, wires=wires, id=id)
super().__init__(wires=wires, id=id)

def shape(self, device, shots):
num_shot_elements = (
Expand Down
85 changes: 80 additions & 5 deletions tests/measurements/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from pennylane.math.quantum import reduce_statevector, reduce_dm
from pennylane.math.matrix_manipulation import _permute_dense_matrix
from pennylane.wires import Wires, WireError


class TestStateMP:
Expand All @@ -45,17 +46,91 @@ class TestStateMP:
def test_process_state_vector(self, vec):
"""Test the processing of a state vector."""

mp = StateMP()
mp = StateMP(wires=None)
assert mp.return_type == State
assert mp.numeric_type is complex

processed = mp.process_state(vec, None)
assert qml.math.allclose(processed, vec)

def test_state_does_not_accept_wires(self):
"""Test that StateMP does not accept wires."""
with pytest.raises(TypeError, match="unexpected keyword argument 'wires'"):
StateMP(wires=[0])
@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
def test_state_returns_itself_if_wires_match(self, interface):
"""Test that when wire_order matches the StateMP, the state is returned."""
ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface)
assert StateMP(wires=[1, 0]).process_state(ket, wire_order=Wires([1, 0])) is ket

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
@pytest.mark.parametrize("wires, wire_order", [([1, 0], [0, 1]), (["b", "a"], ["a", "b"])])
def test_reorder_state(self, interface, wires, wire_order):
"""Test that a state can be re-ordered."""
ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface)
result = StateMP(wires=wires).process_state(ket, wire_order=Wires(wire_order))
assert qml.math.allclose(result, np.array([0.48j, -0.64j, 0.48, 0.36]))
assert qml.math.get_interface(ket) == interface

@pytest.mark.parametrize(
"mp_wires, expected_state",
[
([0, 1, 2], [1, 0, 2, 0, 3, 0, 4, 0]),
([2, 0, 1], [1, 2, 3, 4, 0, 0, 0, 0]),
([1, 0, 2], [1, 0, 3, 0, 2, 0, 4, 0]),
([1, 2, 0], [1, 3, 0, 0, 2, 4, 0, 0]),
],
)
@pytest.mark.parametrize("custom_wire_labels", [False, True])
def test_expand_state_over_wires(self, mp_wires, expected_state, custom_wire_labels):
"""Test the expanded state is correctly ordered with extra wires being zero."""
wire_order = [0, 1]
if custom_wire_labels:
# non-lexicographical-ordered
wire_map = {0: "b", 1: "c", 2: "a"}
mp_wires = [wire_map[w] for w in mp_wires]
wire_order = ["b", "c"]
mp = StateMP(wires=mp_wires)
ket = np.arange(1, 5)
result = mp.process_state(ket, wire_order=Wires(wire_order))
assert np.array_equal(result, expected_state)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "autograd", "jax", "torch", "tensorflow"])
def test_expand_state_all_interfaces(self, interface):
"""Test that expanding the state over wires preserves interface."""
mp = StateMP(wires=[4, 2, 0, 1])
ket = qml.math.array([0.48j, 0.48, -0.64j, 0.36], like=interface)
result = mp.process_state(ket, wire_order=Wires([1, 2]))
reshaped = qml.math.reshape(result, (2, 2, 2, 2))
assert qml.math.all(reshaped[1, :, 1, :] == 0)
assert qml.math.allclose(reshaped[0, :, 0, :], np.array([[0.48j, -0.64j], [0.48, 0.36]]))
if interface != "autograd":
# autograd.numpy.pad drops pennylane tensor for some reason
assert qml.math.get_interface(result) == interface

@pytest.mark.jax
@pytest.mark.parametrize(
"wires,expected",
[
([1, 0], np.array([0.48j, -0.64j, 0.48, 0.36])),
([2, 1, 0], np.array([0.48j, -0.64j, 0.48, 0.36, 0.0, 0.0, 0.0, 0.0])),
],
)
def test_state_jax_jit(self, wires, expected):
"""Test that re-ordering and expanding works with jax-jit."""
import jax

@jax.jit
def get_state(ket):
return StateMP(wires=wires).process_state(ket, wire_order=Wires([0, 1]))

result = get_state(jax.numpy.array([0.48j, 0.48, -0.64j, 0.36]))
assert qml.math.allclose(result, expected)
assert isinstance(result, jax.Array)

def test_wire_ordering_error(self):
"""Test that a wire order error is raised when unknown wires are given."""
with pytest.raises(WireError, match=r"Unexpected wires \{2\} found in wire order"):
StateMP(wires=[0, 1]).process_state([1, 0], wire_order=[2])


class TestDensityMatrixMP:
Expand Down
Loading