Skip to content

Commit

Permalink
feat: Track classical register indices for measurements (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
Altanali committed Jun 26, 2024
1 parent 7e717a2 commit c9730d2
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/braket/default_simulator/openqasm/_helpers/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ def convert_index(index: Union[RangeDefinition, IntegerLiteral]) -> Union[int, s

def flatten_indices(indices: list[IndexElement]) -> list:
"""Convert a[i][j][k] to the equivalent a[i, j, k]"""
return sum((index for index in indices), [])
result = []
for index in indices:
if isinstance(index, DiscreteSet):
result.append(index)
else:
result += index
return result


def unwrap_var_type(var_type: ClassicalType) -> ClassicalType:
Expand Down
11 changes: 9 additions & 2 deletions src/braket/default_simulator/openqasm/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from __future__ import annotations

from collections.abc import Iterable
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
self.results = []
self.qubit_set = set()
self.measured_qubits = []
self.target_classical_indices = []

if instructions:
for instruction in instructions:
Expand All @@ -61,11 +63,16 @@ def add_instruction(self, instruction: [GateOperation, KrausOperation]) -> None:
self.instructions.append(instruction)
self.qubit_set |= set(instruction.targets)

def add_measure(self, target: tuple[int]):
for qubit in target:
def add_measure(self, target: tuple[int], classical_targets: Iterable[int] = None):
for index, qubit in enumerate(target):
if qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.target_classical_indices.append(
classical_targets[index]
if classical_targets
else max(index, len(self.target_classical_indices))
)

def add_result(self, result: Results) -> None:
"""
Expand Down
34 changes: 32 additions & 2 deletions src/braket/default_simulator/openqasm/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ._helpers.arrays import (
convert_range_def_to_range,
create_empty_array,
flatten_indices,
get_elements,
get_type_width,
)
Expand Down Expand Up @@ -471,12 +472,41 @@ def _(self, node: QuantumGateModifier) -> QuantumGateModifier:
@visit.register
def _(self, node: QuantumMeasurement) -> None:
qubits = self.context.get_qubits(self.visit(node.qubit))
self.context.add_measure(qubits)
return qubits

@visit.register
def _(self, node: QuantumMeasurementStatement) -> None:
"""The measure is performed but the assignment is ignored"""
self.visit(node.measure)
qubits = self.visit(node.measure)
targets = []
if node.target:
if isinstance(node.target, IndexedIdentifier):
indices = flatten_indices(node.target.indices)
if len(node.target.indices) != 1:
raise ValueError(
"Multi-Dimensional indexing not supported for classical registers."
)
elem = indices[0]
if isinstance(elem, DiscreteSet):
self._uses_advanced_language_features = True
target_indices = [self.visit(val).value for val in elem.values]
targets.extend(target_indices)
elif isinstance(elem, RangeDefinition):
self._uses_advanced_language_features = True
target_indices = convert_range_def_to_range(self.visit(elem))
targets.extend(target_indices)
else:
target_idx = elem.value
targets.append(target_idx)

if not len(targets):
targets = None

if targets and len(targets) != len(qubits):
raise ValueError(
f"Number of qubits ({len(qubits)}) does not match number of provided classical targets ({len(targets)})"
)
self.context.add_measure(qubits, targets)

@visit.register
def _(self, node: ClassicalAssignment) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/braket/default_simulator/openqasm/program_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def add_kraus_instruction(self, matrices: list[np.ndarray], target: list[int]):
"""
raise NotImplementedError

def add_measure(self, target: tuple[int]):
def add_measure(self, target: tuple[int], classical_targets: Iterable[int] = None):
"""Add qubit targets to be measured"""


Expand Down Expand Up @@ -902,5 +902,5 @@ def add_kraus_instruction(self, matrices: list[np.ndarray], target: list[int]):
def add_result(self, result: Results) -> None:
self._circuit.add_result(result)

def add_measure(self, target: tuple[int]):
self._circuit.add_measure(target)
def add_measure(self, target: tuple[int], classical_targets: Iterable[int] = None):
self._circuit.add_measure(target, classical_targets)
Original file line number Diff line number Diff line change
Expand Up @@ -2058,7 +2058,7 @@ def test_basis_rotation_hermitian():
"b[0] = measure q[0];",
]
),
[0],
([0], [0]),
),
(
"\n".join(
Expand All @@ -2068,19 +2068,19 @@ def test_basis_rotation_hermitian():
"b = measure q;",
]
),
[0, 1, 2],
([0, 1, 2], [0, 1, 2]),
),
(
"\n".join(
[
"bit[1] b;",
"bit[2] b;",
"qubit[2] q;",
"h q[0];",
"h q[1];",
"b[0] = measure q[0:1];",
"b[0:1] = measure q[0:1];",
]
),
[0, 1],
([0, 1], [0, 1]),
),
(
"\n".join(
Expand All @@ -2091,11 +2091,11 @@ def test_basis_rotation_hermitian():
"cnot q[0], q[1];",
"cnot q[1], q[2];",
"b[0] = measure q[0];",
"b[1] = measure q[1];",
"b[2] = measure q[2];",
"b[2] = measure q[1];",
"b[1] = measure q[2];",
]
),
[0, 1, 2],
([0, 1, 2], [0, 2, 1]),
),
(
"\n".join(
Expand All @@ -2105,10 +2105,10 @@ def test_basis_rotation_hermitian():
"h q[0];",
"h q[1];",
"cnot q[1], q[2];",
"b[0] = measure q[{0, 2}];",
"b[{2, 1}] = measure q[{0, 2}];",
]
),
[0, 2],
([0, 2], [2, 1]),
),
(
"\n".join(
Expand All @@ -2119,7 +2119,7 @@ def test_basis_rotation_hermitian():
"b[0] = measure $0;",
]
),
[0],
([0], [0]),
),
(
"\n".join(
Expand All @@ -2130,7 +2130,7 @@ def test_basis_rotation_hermitian():
"}",
]
),
[0, 1, 2],
([0, 1, 2], [0, 1, 2]),
),
(
"\n".join(
Expand All @@ -2144,7 +2144,7 @@ def test_basis_rotation_hermitian():
"measure q[0];",
]
),
[1, 0],
([1, 0], [0, 1]),
),
(
"\n".join(
Expand All @@ -2154,13 +2154,14 @@ def test_basis_rotation_hermitian():
"b[0] = measure q[1:5];",
]
),
[1],
([1], [0]),
),
],
)
def test_measurement(qasm, expected):
circuit = Interpreter().build_circuit(qasm)
assert circuit.measured_qubits == expected
assert circuit.measured_qubits == expected[0]
assert circuit.target_classical_indices == expected[1]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -2239,3 +2240,23 @@ def test_measure_invalid_qubit():
def test_measure_qubit_out_of_range(qasm, expected):
with pytest.raises(IndexError, match=expected):
Interpreter().build_circuit(qasm)


@pytest.mark.parametrize(
"qasm,error_message",
[
(
"\n".join(["OPENQASM 3.0;" "bit[2] b;", "qubit[1] q;", "b[{0, 1}] = measure q[0];"]),
re.escape(
"Number of qubits (1) does not match number of provided classical targets (2)"
),
),
(
"\n".join(["OPENQASM 3.0;" "bit[2] b;", "qubit[2] q;", "b[0][2] = measure q[1];"]),
re.escape("Multi-Dimensional indexing not supported for classical registers."),
),
],
)
def test_invalid_measurement_with_classical_indices(qasm, error_message):
with pytest.raises(ValueError, match=error_message):
Interpreter().build_circuit(qasm)

0 comments on commit c9730d2

Please sign in to comment.