Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve mid-circuit measurement conversion abilities #417

Merged
merged 25 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
080529c
add support for conditional ops
obliviateandsurrender Feb 10, 2024
6d3e434
happy `codefactor`
obliviateandsurrender Feb 10, 2024
a449ba4
add logic for `IfElseOp`
obliviateandsurrender Feb 10, 2024
35e0693
add support of `SwitchCaseOp`
obliviateandsurrender Feb 12, 2024
2ec080d
happy `codefactor`
obliviateandsurrender Feb 12, 2024
72914a2
fix `control_values`
obliviateandsurrender Feb 12, 2024
bc9d5f0
minor tweaks
obliviateandsurrender Feb 12, 2024
b396667
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane-qis…
obliviateandsurrender Feb 13, 2024
5655416
codefactor?
obliviateandsurrender Feb 13, 2024
91f192e
codefactor?
obliviateandsurrender Feb 13, 2024
a6817ac
`changelog`
obliviateandsurrender Feb 13, 2024
8872599
minor tweaks
obliviateandsurrender Feb 14, 2024
bc11bc7
address comments
obliviateandsurrender Feb 15, 2024
9b1fca7
minor tweaks
obliviateandsurrender Feb 19, 2024
905f3d1
readying master merging
obliviateandsurrender Feb 19, 2024
0da02b4
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane-qis…
obliviateandsurrender Feb 19, 2024
6ef6b6c
Merge branch 'master' into cond-mid-meas-support
obliviateandsurrender Feb 20, 2024
0856c7c
apply suggestions
obliviateandsurrender Feb 20, 2024
1e3727b
happy `black`
obliviateandsurrender Feb 20, 2024
6137bcf
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane-qis…
obliviateandsurrender Feb 21, 2024
41daafa
minor tweak
obliviateandsurrender Feb 21, 2024
cf26bf1
happy `black`
obliviateandsurrender Feb 21, 2024
2f8edc2
address comments
obliviateandsurrender Feb 21, 2024
9188978
fix tests?
obliviateandsurrender Feb 21, 2024
927cb2d
minor tweak
obliviateandsurrender Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
a dictionary. The old dictionary UI continues to be supported.
[(#406)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/406)

* Measurement operations are now added to the PennyLane template when a `QuantumCircuit`
* Measurement operations are now added to the PennyLane template when a ``QuantumCircuit``
is converted using `load`. Additionally, one can override any existing terminal
measurements by providing a list of PennyLane
`measurements <https://docs.pennylane.ai/en/stable/introduction/measurements.html>`_ themselves.
[(#405)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/405)

* Added support for coverting conditional operations based on mid-circuit measurements and
two of the ``ControlFlowOp`` operations - ``IfElseOp`` and ``SwitchCaseOp`` when converting
a ``QuantumCircuit`` using `load`.
[(#417)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/417)

### Breaking changes 💔

### Deprecations 👋
Expand Down
264 changes: 183 additions & 81 deletions pennylane_qiskit/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
"""
from typing import Dict, Any
import warnings
from functools import partial, reduce

import numpy as np
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter, ParameterExpression, Measure, Barrier
from qiskit.circuit import Parameter, ParameterExpression, Measure, Barrier, ControlFlowOp
from qiskit.circuit.controlflow.switch_case import _DefaultCaseType
from qiskit.circuit.library import GlobalPhaseGate
from qiskit.exceptions import QiskitError
from sympy import lambdify
Expand Down Expand Up @@ -49,6 +51,31 @@ def _check_parameter_bound(param: Parameter, unbound_params: Dict[Parameter, Any
raise ValueError(f"The parameter {param} was not bound correctly.".format(param))


def _process_basic_param_args(params, *args, **kwargs):
"""Process the basic conditions for parameter dictionary computation.

Returns:
params (dict): A dictionary mapping ``quantum_circuit.parameters`` to values
flag (bool): Indicating whether the returned ``params`` can be used.
"""

# if no kwargs are passed, and a dictionary has been passed as a single argument, then assume it is params
if params is None and not kwargs and (len(args) == 1 and isinstance(args[0], dict)):
return (args[0], True)

if not args and not kwargs:
return (params, True)

# make params dict if using args and/or kwargs
if params is not None:
raise RuntimeError(
"Cannot define parameters via the params kwarg when passing Parameter values "
"as individual args or kwargs."
)

return ({}, False)


def _format_params_dict(quantum_circuit, params, *args, **kwargs):
"""Processes the inputs for calling the quantum function and returns
a dictionary of the format ``{Parameter("name"): value}`` for all the parameters.
Expand All @@ -73,23 +100,11 @@ def _format_params_dict(quantum_circuit, params, *args, **kwargs):
params (dict): A dictionary mapping ``quantum_circuit.parameters`` to values
"""

# if no kwargs are passed, and a dictionary has been passed as a single argument, then assume it is params
if params is None and not kwargs and (len(args) == 1 and isinstance(args[0], dict)):
return args[0]
params, flag = _process_basic_param_args(params, *args, **kwargs)

if not args and not kwargs:
if flag:
return params

# make params dict if using args and/or kwargs
if params is not None:
raise RuntimeError(
"Cannot define parameters via the params kwarg when passing Parameter values "
"as individual args or kwargs."
)

# create en empty params dict
params = {}

# populate it with any parameters defined as kwargs
for k, v in kwargs.items():
# the key needs to be the actual Parameter, whereas kwargs keys are parameter names
Expand All @@ -107,8 +122,9 @@ def _format_params_dict(quantum_circuit, params, *args, **kwargs):
# all other checks regarding correct arguments will be processed in _check_circuit_and_assign_parameters
# (based on the full params dict generated by this function), but this information can only be captured here
if len(args) > len(arg_parameters):
s = "s" if len(arg_parameters) > 1 else ""
raise TypeError(
f"Expected {len(arg_parameters)} positional argument{'s' if len(arg_parameters) > 1 else ''} but {len(args)} were given"
f"Expected {len(arg_parameters)} positional argument{s} but {len(args)} were given"
)
params.update(dict(zip(arg_parameters, args)))

Expand Down Expand Up @@ -213,24 +229,7 @@ def map_wires(qc_wires: list, wires: list) -> dict:
)


def execute_supported_operation(operation_name: str, parameters: list, wires: list):
"""Utility function that executes an operation that is natively supported by PennyLane.

Args:
operation_name (str): Name of the PL operator to be executed
parameters (str): parameters of the operation that will be executed
wires (list): wires of the operation
"""
operation = getattr(pennylane_ops, operation_name)

if not parameters:
operation(wires=wires)
elif operation_name in ["QubitStateVector", "StatePrep"]:
operation(np.array(parameters), wires=wires)
else:
operation(*parameters, wires=wires)


# pylint:disable=too-many-statements, too-many-branches
def load(quantum_circuit: QuantumCircuit, measurements=None):
"""Loads a PennyLane template from a Qiskit QuantumCircuit.
Warnings are created for each of the QuantumCircuit instructions that were
Expand All @@ -246,7 +245,7 @@ def load(quantum_circuit: QuantumCircuit, measurements=None):
function: the resulting PennyLane template
"""

# pylint:disable=too-many-branches
# pylint:disable=too-many-branches, fixme, protected-access
def _function(*args, params: dict = None, wires: list = None, **kwargs):
"""Returns a PennyLane quantum function created based on the input QuantumCircuit.
Warnings are created for each of the QuantumCircuit instructions that were
Expand Down Expand Up @@ -328,56 +327,62 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs):
wire_map = map_wires(qc_wires, wires)

# Stores the measurements encountered in the circuit
mid_circ_meas, terminal_meas = [], []
terminal_meas = []
mid_circ_meas, mid_circ_regs = [], {}
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

# Processing the dictionary of parameters passed
for idx, (op, qargs, _) in enumerate(qc.data):
# the new Singleton classes have different names than the objects they represent, but base_class.__name__ still matches
instruction_name = getattr(op, "base_class", op.__class__).__name__

operation_wires = [wire_map[hash(qubit)] for qubit in qargs]

for idx, (ops, qargs, cargs) in enumerate(qc.data):
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
# the new Singleton classes have different names than the objects they represent,
# but base_class.__name__ still matches
instruction_name = getattr(ops, "base_class", ops.__class__).__name__
# New Qiskit gates that are not natively supported by PL (identical
# gates exist with a different name)
# TODO: remove the following when gates have been renamed in PennyLane
instruction_name = "U3Gate" if instruction_name == "UGate" else instruction_name

# pylint:disable=protected-access
if (
instruction_name in inv_map
and inv_map[instruction_name] in pennylane_ops._qubit__ops__
):
# Extract the bound parameters from the operation. If the bound parameters are a
# Qiskit ParameterExpression, then replace it with the corresponding PennyLane
# variable from the unbound_params dictionary.

pl_parameters = []
for p in op.params:
_check_parameter_bound(p, unbound_params)

if isinstance(p, ParameterExpression):
if p.parameters: # non-empty set = has unbound parameters
ordered_params = tuple(p.parameters)

f = lambdify(ordered_params, p._symbol_expr, modules=qml.numpy)
f_args = []
for i_ordered_params in ordered_params:
f_args.append(unbound_params.get(i_ordered_params))
pl_parameters.append(f(*f_args))
else: # needed for qiskit<0.43.1
pl_parameters.append(float(p)) # pragma: no cover
else:
pl_parameters.append(p)

execute_supported_operation(
inv_map[instruction_name], pl_parameters, operation_wires
)
# Define operator builders and helpers
operation_func = None
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

elif instruction_name in dagger_map:
gate = dagger_map[instruction_name]
qml.adjoint(gate)(wires=operation_wires)
def operation_overlapper(op):
return op
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

elif isinstance(op, Measure):
operation_wires = [wire_map[hash(qubit)] for qubit in qargs]
operation_kwargs = {"wires": operation_wires}
operation_args = []
operation_cond = False

# Extract the bound parameters from the operation. If the bound parameters are a
# Qiskit ParameterExpression, then replace it with the corresponding PennyLane
# variable from the var_ref_map dictionary.
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
operation_params = []
for p in ops.params:
_check_parameter_bound(p, unbound_params)

if isinstance(p, ParameterExpression):
if p.parameters: # non-empty set = has unbound parameters
ordered_params = tuple(p.parameters)
f = lambdify(ordered_params, p._symbol_expr, modules=qml.numpy)
f_args = []
for i_ordered_params in ordered_params:
f_args.append(unbound_params.get(i_ordered_params))
operation_params.append(f(*f_args))
else: # needed for qiskit<0.43.1
operation_params.append(float(p)) # pragma: no cover
else:
operation_params.append(p)
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

if instruction_name in dagger_map:
operation_func = dagger_map[instruction_name]
operation_overlapper = qml.adjoint

elif instruction_name in inv_map:
operation_name = inv_map[instruction_name]
operation_func = getattr(pennylane_ops, operation_name)
operation_args.extend(operation_params)
if operation_name in ["QubitStateVector", "StatePrep"]:
operation_args = [np.array(operation_params)]
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

elif isinstance(ops, Measure):
# Store the current operation wires
op_wires = set(operation_wires)
# Look-ahead for more gate(s) on its wire(s)
Expand All @@ -391,22 +396,91 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs):
meas_terminal = False
break

# Allows for adding terminal measurements
if meas_terminal:
terminal_meas.extend(operation_wires)

# Allows for queing the mid-circuit measurements
if not meas_terminal:
mid_circ_meas.append(qml.measure(wires=operation_wires))
else:
terminal_meas.extend(operation_wires)
operation_func = qml.measure
mid_circ_meas.append(qml.measure(wires=operation_wires))

# Allows for tracking conditional operations
for carg in cargs:
mid_circ_regs[carg] = mid_circ_meas[-1]

# TODO: this may contain some logic for the bigger ControlFlowOps
elif isinstance(ops, ControlFlowOp):
operation_cond = True

else:

try:
operation_matrix = op.to_matrix()
pennylane_ops.QubitUnitary(operation_matrix, wires=operation_wires)
operation_args = [ops.to_matrix()]
operation_func = qml.QubitUnitary

except (AttributeError, QiskitError):
warnings.warn(
f"{__name__}: The {instruction_name} instruction is not supported by PennyLane,"
" and has not been added to the template.",
UserWarning,
)

# Check if it is a conditional operation
if operation_cond or (ops.condition and ops.condition[0] in mid_circ_regs):
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
# Iteratively recurse over to build different branches
with qml.QueuingManager.stop_recording():
branch_funcs = [
partial(load(branch_inst, measurements=None), params=params, wires=wires)
for branch_inst in operation_params
if isinstance(branch_inst, QuantumCircuit)
]
lillian542 marked this conversation as resolved.
Show resolved Hide resolved

# Get the functions for handling condition
true_fn, false_fn, elif_fns, cond_op = _conditional_funcs(
ops, cargs, operation_func, branch_funcs, instruction_name
)
res_reg, res_bit = cond_op

# Check for multi-qubit register
if tuple(cargs) not in mid_circ_regs:
ctrl_cargs = min(len(cargs), len(qargs))
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
mid_circ_regs[tuple(cargs)] = sum(
2**idx * qml.measure(wires=operation_wires[idx])
for idx in range(ctrl_cargs)
)

# Check for elif branches (doesn't require qjit)
if elif_fns:
for elif_fn in elif_fns:
qml.cond(mid_circ_regs[res_reg] == elif_fn[0], elif_fn[1])(
*operation_args, **operation_kwargs
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
)

# Check if just conditional requires some extra work
if isinstance(res_bit, str):
# Handles the default case in the SwitchCaseOp
if res_bit == "SwitchDefault":
elif_bits = [elif_fn[0] for elif_fn in elif_fns]
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
qml.cond(
reduce(
lambda m0, m1: m0 & m1,
[(mid_circ_regs[res_reg] != elif_bit) for elif_bit in elif_bits],
),
true_fn,
)(*operation_args, **operation_kwargs)
# Just do the routine conditional
else:
qml.cond(
mid_circ_regs[res_reg] == res_bit,
true_fn,
false_fn,
)(*operation_args, **operation_kwargs)

# Check if it is not a mid-circuit measurement
elif operation_func and not isinstance(ops, Measure):
operation_overlapper(operation_func)(*operation_args, **operation_kwargs)

# Use the user-provided measurements
if measurements:
if qml.queuing.QueuingManager.active_context():
Expand Down Expand Up @@ -436,3 +510,31 @@ def load_qasm_from_file(file: str):
function: the new PennyLane template
"""
return load(QuantumCircuit.from_qasm_file(file))


# pylint:disable=fixme, protected-access
def _conditional_funcs(ops, cargs, operation_func, branch_funcs, ctrl_flow_type):
"""Builds the conditional functions for Controlled flows"""
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
true_fn, false_fn, elif_fns = operation_func, None, ()
# Logic for using legacy c_if
if not isinstance(ops, ControlFlowOp):
return true_fn, false_fn, elif_fns, ops.condition

# Logic for handling IfElseOp
if ctrl_flow_type == "IfElseOp":
true_fn = branch_funcs[0]
if len(branch_funcs) == 2:
false_fn = branch_funcs[1]

# Logic for handling SwitchCaseOp
elif ctrl_flow_type == "SwitchCaseOp":
elif_fns = []
for res_bit, case in ops._case_map.items():
if not isinstance(case, _DefaultCaseType):
elif_fns.append((res_bit, branch_funcs[case]))
ops.condition = [tuple(cargs), "SwitchCase"]
if any((isinstance(case, _DefaultCaseType) for case in ops._case_map)):
true_fn = branch_funcs[-1]
ops.condition = [tuple(cargs), "SwitchDefault"]
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

return true_fn, false_fn, elif_fns, ops.condition