Skip to content

Commit

Permalink
Convert additional bit_indices to dag.find_bit in transpiler passes (#…
Browse files Browse the repository at this point in the history
…10463)

* Use dag find_bit in additional places

* Fix dag viz

* Lint

* More lint and cleanup

(cherry picked from commit d02a254)
  • Loading branch information
enavarro51 authored and mergify[bot] committed Jul 24, 2023
1 parent 7d964cb commit 9afa8e6
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 62 deletions.
24 changes: 9 additions & 15 deletions qiskit/transpiler/passes/optimization/consolidate_blocks.py
Expand Up @@ -90,15 +90,11 @@ def run(self, dag):
if self.decomposer is None:
return dag

# compute ordered indices for the global circuit wires
global_index_map = {wire: idx for idx, wire in enumerate(dag.qubits)}
blocks = self.property_set["block_list"] or []
basis_gate_name = self.decomposer.gate.name
all_block_gates = set()
for block in blocks:
if len(block) == 1 and self._check_not_in_basis(
block[0].name, block[0].qargs, global_index_map
):
if len(block) == 1 and self._check_not_in_basis(dag, block[0].name, block[0].qargs):
all_block_gates.add(block[0])
dag.substitute_node(block[0], UnitaryGate(block[0].op.to_matrix()))
else:
Expand All @@ -111,11 +107,11 @@ def run(self, dag):
if isinstance(nd, DAGOpNode) and getattr(nd.op, "condition", None):
block_cargs |= set(getattr(nd.op, "condition", None)[0])
all_block_gates.add(nd)
block_index_map = self._block_qargs_to_indices(block_qargs, global_index_map)
block_index_map = self._block_qargs_to_indices(dag, block_qargs)
for nd in block:
if nd.op.name == basis_gate_name:
basis_count += 1
if self._check_not_in_basis(nd.op.name, nd.qargs, global_index_map):
if self._check_not_in_basis(dag, nd.op.name, nd.qargs):
outside_basis = True
if len(block_qargs) > 2:
q = QuantumRegister(len(block_qargs))
Expand Down Expand Up @@ -153,9 +149,7 @@ def run(self, dag):
for run in runs:
if any(gate in all_block_gates for gate in run):
continue
if len(run) == 1 and not self._check_not_in_basis(
run[0].name, run[0].qargs, global_index_map
):
if len(run) == 1 and not self._check_not_in_basis(dag, run[0].name, run[0].qargs):
dag.substitute_node(run[0], UnitaryGate(run[0].op.to_matrix()))
else:
qubit = run[0].qargs[0]
Expand Down Expand Up @@ -201,15 +195,15 @@ def _handle_control_flow_ops(self, dag):
node.op = node.op.replace_blocks(pass_manager.run(block) for block in node.op.blocks)
return dag

def _check_not_in_basis(self, gate_name, qargs, global_index_map):
def _check_not_in_basis(self, dag, gate_name, qargs):
if self.target is not None:
return not self.target.instruction_supported(
gate_name, tuple(global_index_map[qubit] for qubit in qargs)
gate_name, tuple(dag.find_bit(qubit).index for qubit in qargs)
)
else:
return self.basis_gates and gate_name not in self.basis_gates

def _block_qargs_to_indices(self, block_qargs, global_index_map):
def _block_qargs_to_indices(self, dag, block_qargs):
"""Map each qubit in block_qargs to its wire position among the block's wires.
Args:
block_qargs (list): list of qubits that a block acts on
Expand All @@ -218,7 +212,7 @@ def _block_qargs_to_indices(self, block_qargs, global_index_map):
Returns:
dict: mapping from qarg to position in block
"""
block_indices = [global_index_map[q] for q in block_qargs]
block_indices = [dag.find_bit(q).index for q in block_qargs]
ordered_block_indices = {bit: index for index, bit in enumerate(sorted(block_indices))}
block_positions = {q: ordered_block_indices[global_index_map[q]] for q in block_qargs}
block_positions = {q: ordered_block_indices[dag.find_bit(q).index] for q in block_qargs}
return block_positions
Expand Up @@ -136,7 +136,6 @@ def __init__(
self.model = None
self.dag = None
self.parse_backend_properties()
self.qubit_indices = None

def powerset(self, iterable):
"""
Expand Down Expand Up @@ -189,8 +188,8 @@ def cx_tuple(self, gate):
Note: current implementation assumes that the CX error rates and
crosstalk behavior are independent of gate direction
"""
physical_q_0 = self.qubit_indices[gate.qargs[0]]
physical_q_1 = self.qubit_indices[gate.qargs[1]]
physical_q_0 = self.dag.find_bit(gate.qargs[0]).index
physical_q_1 = self.dag.find_bit(gate.qargs[1]).index
r_0 = min(physical_q_0, physical_q_1)
r_1 = max(physical_q_0, physical_q_1)
return (r_0, r_1)
Expand All @@ -199,7 +198,7 @@ def singleq_tuple(self, gate):
"""
Representation for single-qubit gate
"""
physical_q_0 = self.qubit_indices[gate.qargs[0]]
physical_q_0 = self.dag.find_bit(gate.qargs[0]).index
tup = (physical_q_0,)
return tup

Expand Down Expand Up @@ -310,15 +309,15 @@ def create_z3_vars(self):
active_qubits_list = []
for gate in self.dag.gate_nodes():
for q in gate.qargs:
active_qubits_list.append(self.qubit_indices[q])
active_qubits_list.append(self.dag.find_bit(q).index)
for active_qubit in list(set(active_qubits_list)):
q_var_name = "l_" + str(active_qubit)
self.qubit_lifetime[active_qubit] = z3.Real(q_var_name)

meas_q = []
for node in self.dag.op_nodes():
if isinstance(node.op, Measure):
meas_q.append(self.qubit_indices[node.qargs[0]])
meas_q.append(self.dag.find_bit(node.qargs[0]).index)

self.measured_qubits = list(set(self.input_measured_qubits).union(set(meas_q)))
self.measure_start = z3.Real("meas_start")
Expand All @@ -330,7 +329,7 @@ def basic_bounds(self):
for gate in self.gate_start_time:
self.opt.add(self.gate_start_time[gate] >= 0)
for gate in self.gate_duration:
q_0 = self.qubit_indices[gate.qargs[0]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
if isinstance(gate.op, U1Gate):
dur = self.bp_u1_dur[q_0]
elif isinstance(gate.op, U2Gate):
Expand Down Expand Up @@ -384,7 +383,7 @@ def fidelity_constraints(self):
import z3

for gate in self.gate_start_time:
q_0 = self.qubit_indices[gate.qargs[0]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
no_xtalk = False
if gate not in self.xtalk_overlap_set:
no_xtalk = True
Expand Down Expand Up @@ -439,23 +438,23 @@ def coherence_constraints(self):
if isinstance(gate.op, Barrier):
continue
if len(gate.qargs) == 1:
q_0 = self.qubit_indices[gate.qargs[0]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
self.last_gate_on_qubit[q_0] = gate
else:
q_0 = self.qubit_indices[gate.qargs[0]]
q_1 = self.qubit_indices[gate.qargs[1]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
q_1 = self.dag.find_bit(gate.qargs[1]).index
self.last_gate_on_qubit[q_0] = gate
self.last_gate_on_qubit[q_1] = gate

self.first_gate_on_qubit = {}
for gate in self.dag.topological_op_nodes():
if len(gate.qargs) == 1:
q_0 = self.qubit_indices[gate.qargs[0]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
if q_0 not in self.first_gate_on_qubit:
self.first_gate_on_qubit[q_0] = gate
else:
q_0 = self.qubit_indices[gate.qargs[0]]
q_1 = self.qubit_indices[gate.qargs[1]]
q_0 = self.dag.find_bit(gate.qargs[0]).index
q_1 = self.dag.find_bit(gate.qargs[1]).index
if q_0 not in self.first_gate_on_qubit:
self.first_gate_on_qubit[q_0] = gate
if q_1 not in self.first_gate_on_qubit:
Expand Down Expand Up @@ -719,7 +718,6 @@ def run(self, dag):
"""
self.dag = dag

self.qubit_indices = {bit: idx for idx, bit in enumerate(dag.qubits)}
# process input program
self.assign_gate_id(self.dag)
self.extract_dag_overlap_sets(self.dag)
Expand Down
5 changes: 2 additions & 3 deletions qiskit/transpiler/passes/routing/algorithms/bip_model.py
Expand Up @@ -87,7 +87,6 @@ def __init__(self, dag, coupling_map, qubit_subset, dummy_timesteps=None):
)

self._index_to_virtual = dict(enumerate(dag.qubits))
self._virtual_to_index = {v: i for i, v in self._index_to_virtual.items()}

# Construct internal circuit model
# Extract layers with 2-qubit gates
Expand All @@ -96,8 +95,8 @@ def __init__(self, dag, coupling_map, qubit_subset, dummy_timesteps=None):
for lay in dag.layers():
laygates = []
for node in lay["graph"].two_qubit_ops():
i1 = self._virtual_to_index[node.qargs[0]]
i2 = self._virtual_to_index[node.qargs[1]]
i1 = self._dag.find_bit(node.qargs[0]).index
i2 = self._dag.find_bit(node.qargs[1]).index
laygates.append(((i1, i2), node))
if laygates:
self._to_su4layer.append(len(self.su4layers))
Expand Down
5 changes: 2 additions & 3 deletions qiskit/transpiler/passes/routing/sabre_swap.py
Expand Up @@ -282,14 +282,13 @@ def recurse(block, block_qubit_indices):
return process_dag(block_dag, block_qubit_indices)

def process_dag(block_dag, wire_map):
clbit_indices = {bit: idx for idx, bit in enumerate(block_dag.clbits)}
dag_list = []
node_blocks = {}
for node in block_dag.topological_op_nodes():
cargs = {clbit_indices[x] for x in node.cargs}
cargs = {block_dag.find_bit(x).index for x in node.cargs}
if node.op.condition is not None:
for clbit in block_dag._bits_in_operation(node.op):
cargs.add(clbit_indices[clbit])
cargs.add(block_dag.find_bit(clbit).index)
if isinstance(node.op, ControlFlowOp):
node_blocks[node._node_id] = [
recurse(
Expand Down
7 changes: 3 additions & 4 deletions qiskit/transpiler/passes/scheduling/alap.py
Expand Up @@ -62,9 +62,8 @@ def run(self, dag):
new_dag.add_creg(creg)

idle_before = {q: 0 for q in dag.qubits + dag.clbits}
bit_indices = {bit: index for index, bit in enumerate(dag.qubits)}
for node in reversed(list(dag.topological_op_nodes())):
op_duration = self._get_node_duration(node, bit_indices, dag)
op_duration = self._get_node_duration(node, dag)

# compute t0, t1: instruction interval, note that
# t0: start time of instruction
Expand Down Expand Up @@ -131,7 +130,7 @@ def run(self, dag):

for bit in node.qargs:
delta = t0 - idle_before[bit]
if delta > 0 and self._delay_supported(bit_indices[bit]):
if delta > 0 and self._delay_supported(dag.find_bit(bit).index):
new_dag.apply_operation_front(Delay(delta, time_unit), [bit], [])
idle_before[bit] = t1

Expand All @@ -142,7 +141,7 @@ def run(self, dag):
delta = circuit_duration - before
if not (delta > 0 and isinstance(bit, Qubit)):
continue
if self._delay_supported(bit_indices[bit]):
if self._delay_supported(dag.find_bit(bit).index):
new_dag.apply_operation_front(Delay(delta, time_unit), [bit], [])

new_dag.name = dag.name
Expand Down
11 changes: 7 additions & 4 deletions qiskit/transpiler/passes/scheduling/asap.py
Expand Up @@ -69,9 +69,8 @@ def run(self, dag):
new_dag.add_creg(creg)

idle_after = {q: 0 for q in dag.qubits + dag.clbits}
bit_indices = {q: index for index, q in enumerate(dag.qubits)}
for node in dag.topological_op_nodes():
op_duration = self._get_node_duration(node, bit_indices, dag)
op_duration = self._get_node_duration(node, dag)

# compute t0, t1: instruction interval, note that
# t0: start time of instruction
Expand Down Expand Up @@ -150,7 +149,11 @@ def run(self, dag):
# Add delay to qubit wire
for bit in node.qargs:
delta = t0 - idle_after[bit]
if delta > 0 and isinstance(bit, Qubit) and self._delay_supported(bit_indices[bit]):
if (
delta > 0
and isinstance(bit, Qubit)
and self._delay_supported(dag.find_bit(bit).index)
):
new_dag.apply_operation_back(Delay(delta, time_unit), [bit], [])
idle_after[bit] = t1

Expand All @@ -161,7 +164,7 @@ def run(self, dag):
delta = circuit_duration - after
if not (delta > 0 and isinstance(bit, Qubit)):
continue
if self._delay_supported(bit_indices[bit]):
if self._delay_supported(dag.find_bit(bit).index):
new_dag.apply_operation_back(Delay(delta, time_unit), [bit], [])

new_dag.name = dag.name
Expand Down
4 changes: 1 addition & 3 deletions qiskit/transpiler/passes/scheduling/base_scheduler.py
Expand Up @@ -11,7 +11,6 @@
# that they have been altered from the originals.

"""Base circuit scheduling pass."""
from typing import Dict
from qiskit.transpiler import InstructionDurations
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.passes.scheduling.time_unit_conversion import TimeUnitConversion
Expand Down Expand Up @@ -258,11 +257,10 @@ def __init__(
@staticmethod
def _get_node_duration(
node: DAGOpNode,
bit_index_map: Dict,
dag: DAGCircuit,
) -> int:
"""A helper method to get duration from node or calibration."""
indices = [bit_index_map[qarg] for qarg in node.qargs]
indices = [dag.find_bit(qarg).index for qarg in node.qargs]

if dag.has_calibration_for(node):
# If node has calibration, this value should be the highest priority
Expand Down
3 changes: 1 addition & 2 deletions qiskit/transpiler/passes/scheduling/scheduling/alap.py
Expand Up @@ -46,9 +46,8 @@ def run(self, dag):

node_start_time = {}
idle_before = {q: 0 for q in dag.qubits + dag.clbits}
bit_indices = {bit: index for index, bit in enumerate(dag.qubits)}
for node in reversed(list(dag.topological_op_nodes())):
op_duration = self._get_node_duration(node, bit_indices, dag)
op_duration = self._get_node_duration(node, dag)

# compute t0, t1: instruction interval, note that
# t0: start time of instruction
Expand Down
3 changes: 1 addition & 2 deletions qiskit/transpiler/passes/scheduling/scheduling/asap.py
Expand Up @@ -46,9 +46,8 @@ def run(self, dag):

node_start_time = {}
idle_after = {q: 0 for q in dag.qubits + dag.clbits}
bit_indices = {bit: index for index, bit in enumerate(dag.qubits)}
for node in dag.topological_op_nodes():
op_duration = self._get_node_duration(node, bit_indices, dag)
op_duration = self._get_node_duration(node, dag)

# compute t0, t1: instruction interval, note that
# t0: start time of instruction
Expand Down
Expand Up @@ -13,7 +13,7 @@
"""Base circuit scheduling pass."""

import warnings
from typing import Dict

from qiskit.transpiler import InstructionDurations
from qiskit.transpiler.basepasses import AnalysisPass
from qiskit.transpiler.passes.scheduling.time_unit_conversion import TimeUnitConversion
Expand Down Expand Up @@ -59,11 +59,10 @@ def __init__(self, durations: InstructionDurations = None, target: Target = None
@staticmethod
def _get_node_duration(
node: DAGOpNode,
bit_index_map: Dict,
dag: DAGCircuit,
) -> int:
"""A helper method to get duration from node or calibration."""
indices = [bit_index_map[qarg] for qarg in node.qargs]
indices = [dag.find_bit(qarg).index for qarg in node.qargs]

if dag.has_calibration_for(node):
# If node has calibration, this value should be the highest priority
Expand Down
22 changes: 14 additions & 8 deletions qiskit/visualization/dag_visualization.py
Expand Up @@ -105,8 +105,6 @@ def node_attr_func(node):
edge_attr_func = None

else:
qubit_indices = {bit: index for index, bit in enumerate(dag.qubits)}
clbit_indices = {bit: index for index, bit in enumerate(dag.clbits)}
register_bit_labels = {
bit: f"{reg.name}[{idx}]"
for reg in list(dag.qregs.values()) + list(dag.cregs.values())
Expand All @@ -127,18 +125,26 @@ def node_attr_func(node):
n["fillcolor"] = "lightblue"
if isinstance(node, DAGInNode):
if isinstance(node.wire, Qubit):
label = register_bit_labels.get(node.wire, f"q_{qubit_indices[node.wire]}")
label = register_bit_labels.get(
node.wire, f"q_{dag.find_bit(node.wire).index}"
)
else:
label = register_bit_labels.get(node.wire, f"c_{clbit_indices[node.wire]}")
label = register_bit_labels.get(
node.wire, f"c_{dag.find_bit(node.wire).index}"
)
n["label"] = label
n["color"] = "black"
n["style"] = "filled"
n["fillcolor"] = "green"
if isinstance(node, DAGOutNode):
if isinstance(node.wire, Qubit):
label = register_bit_labels.get(node.wire, f"q[{qubit_indices[node.wire]}]")
label = register_bit_labels.get(
node.wire, f"q[{dag.find_bit(node.wire).index}]"
)
else:
label = register_bit_labels.get(node.wire, f"c[{clbit_indices[node.wire]}]")
label = register_bit_labels.get(
node.wire, f"c[{dag.find_bit(node.wire).index}]"
)
n["label"] = label
n["color"] = "black"
n["style"] = "filled"
Expand All @@ -150,9 +156,9 @@ def node_attr_func(node):
def edge_attr_func(edge):
e = {}
if isinstance(edge, Qubit):
label = register_bit_labels.get(edge, f"q_{qubit_indices[edge]}")
label = register_bit_labels.get(edge, f"q_{dag.find_bit(edge).index}")
else:
label = register_bit_labels.get(edge, f"c_{clbit_indices[edge]}")
label = register_bit_labels.get(edge, f"c_{dag.find_bit(edge).index}")
e["label"] = label
return e

Expand Down

0 comments on commit 9afa8e6

Please sign in to comment.