Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix transpile for state measurement #4732

Merged
merged 8 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.33.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,10 @@
<h3>Bug fixes 🐛</h3>

* Fixes `LocalHilbertSchmidt.compute_decomposition` so the template can be used in a qnode.
[(#4719)](https://github.com/PennyLaneAI/pennylane/pull/4719)

* Fixes `transforms.transpile` with arbitrary measurement processes.
[(#4732)](https://github.com/PennyLaneAI/pennylane/pull/4732)

* Providing `work_wires=None` to `qml.GroverOperator` no longer interprets `None` as a wire.
[(#4668)](https://github.com/PennyLaneAI/pennylane/pull/4668)
Expand Down
37 changes: 7 additions & 30 deletions pennylane/transforms/transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def circuit():
"The transpile transform only supports gates acting on 1 or 2 qubits."
)

gates = []

# we wrap all manipulations inside stop_recording() so that we don't queue anything due to unrolling of templates
# or newly applied swap gates
with QueuingManager.stop_recording():
Expand All @@ -121,6 +119,7 @@ def stop_at(obj):
# make copy of ops
list_op_copy = expanded_tape.operations.copy()
measurements = expanded_tape.measurements.copy()
gates = []

while len(list_op_copy) > 0:
op = list_op_copy[0]
Expand All @@ -141,7 +140,7 @@ def stop_at(obj):
continue

# since in each iteration, we adjust indices of each op, we reset logical -> phyiscal mapping
map_wires = {w: w for w in tape.wires}
wire_map = {w: w for w in tape.wires}

# to make sure two qubit gates which act on non-neighbouring qubits q1, q2 can be applied, we first look
# for the shortest path between the two qubits in the connectivity graph. We then move the q2 into the
Expand All @@ -156,17 +155,16 @@ def stop_at(obj):
# swap wires
gates.append(SWAP(wires=[w0, w1]))
# update logical -> phyiscal mapping
map_wires = {
k: (w0 if v == w1 else (w1 if v == w0 else v)) for k, v in map_wires.items()
wire_map = {
k: (w0 if v == w1 else (w1 if v == w0 else v)) for k, v in wire_map.items()
}

# append op to gates with adjusted indices and remove from list
gates.append(_adjust_op_indices(op, map_wires))
gates.append(op.map_wires(wire_map))
list_op_copy.pop(0)

# adjust qubit indices in remaining ops + measurements to new mapping
list_op_copy = [_adjust_op_indices(op, map_wires) for op in list_op_copy]
measurements = [_adjust_mmt_indices(m, map_wires) for m in measurements]
list_op_copy = [op.map_wires(wire_map) for op in list_op_copy]
measurements = [m.map_wires(wire_map) for m in measurements]
new_tape = type(tape)(gates, measurements, shots=tape.shots)

def null_postprocessing(results):
Expand All @@ -176,24 +174,3 @@ def null_postprocessing(results):
return results[0]

return [new_tape], null_postprocessing


def _adjust_op_indices(_op, _map_wires):
"""helper function which adjusts wires in Operation according to the map _map_wires"""
_new_wires = Wires([_map_wires[w] for w in _op.wires])
_params = _op.parameters
if len(_params) == 0:
return type(_op)(wires=_new_wires)
return type(_op)(*_params, wires=_new_wires)


def _adjust_mmt_indices(_m, _map_wires):
"""helper function which adjusts wires in MeasurementProcess according to the map _map_wires"""
_new_wires = Wires([_map_wires[w] for w in _m.wires])

# change wires of observable
if _m.obs is None:
return type(_m)(eigvals=_m.eigvals(), wires=_new_wires)

_new_obs = type(_m.obs)(wires=_new_wires, id=_m.obs.id)
return type(_m)(obs=_new_obs)
15 changes: 15 additions & 0 deletions tests/transforms/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,18 @@ def circuit(param):
assert qml.math.allclose(
original_expectation, transpiled_expectation, atol=np.finfo(float).eps
)

def test_transpile_state(self):
"""Test that transpile works with state measurement process."""

tape = qml.tape.QuantumScript([qml.PauliX(0), qml.CNOT((0,2))], [qml.state()], shots=100)
batch, fn = qml.transforms.transpile(tape, coupling_map = [(0,1), (1,2)])

assert len(batch) == 1
assert fn(("a", )) == "a"

assert batch[0][0] == qml.PauliX(0)
assert batch[0][1] == qml.SWAP((1,2))
assert batch[0][2] == qml.CNOT((0,1))
assert batch[0][3] == qml.state()
assert batch[0].shots == tape.shots
Loading