In [None]:
import cirq

def spacer_around_CNOT(circuit: cirq.Circuit) -> cirq.Circuit:
    """Insert identity spacer moments next to CNOT gates, then shift the spacers.

    Helper for apply_commutation. The input circuit is not modified; a new
    circuit is built from a copied list of moments.

    Phase 1 ("spread"): for every moment that contains a ``cirq.CNotPowGate``
    operation, the neighbouring moments are inspected. If the previous moment
    acts on both qubits of the CNOT, a new moment holding only ``I`` on the
    CNOT's second qubit (the target, per the original author's description)
    is inserted directly before the CNOT moment. Otherwise, if the next moment
    acts on both qubits, the same kind of identity moment is inserted directly
    after the CNOT moment. The scan of that moment stops after the first
    insertion.

    Phase 2 ("shift"): for every identity operation found in moment ``i``,
    an operation in moment ``i + 1`` that is not a CNOT and acts on exactly
    the same qubits is looked up; when found, the two operations trade places
    between the two moments. The previous moment is only considered when
    there is no next moment (the backward branch is an ``elif`` of the
    forward bounds check).

    Args:
        circuit: The circuit to add spacers to.

    Returns:
        A new ``cirq.Circuit`` built from the modified moments.
    """
    # Verified. See code_verification\apply_commutation.ipynb for test cases.
    moments = list(circuit)
    # spread
    # A while loop is used (rather than for/range) because the list grows as
    # spacer moments are inserted and the index is adjusted manually.
    i = 0
    while i < len(moments):
        moment = moments[i]
        cnot_ops = [op for op in moment.operations if isinstance(op.gate, cirq.CNotPowGate)]
        for cnot in cnot_ops:
            q0, q1 = cnot.qubits
            if i > 0:
                left = moments[i - 1]
                if q0 in left.qubits and q1 in left.qubits:
                    # Spacer goes before the CNOT moment, which therefore moves to i + 1.
                    moments.insert(i, cirq.Moment([cirq.I(q1)]))
                    # Together with the unconditional `i += 1` below, this skips past
                    # the CNOT moment itself.
                    # NOTE(review): as a consequence the right-hand neighbour of this
                    # CNOT is never examined when a left spacer was added, and any
                    # other CNOTs in the same moment are not processed — confirm
                    # this is intended.
                    i += 1
                    break
            if i + 1 < len(moments):
                right = moments[i + 1]
                if q0 in right.qubits and q1 in right.qubits:
                    # Spacer goes right after the CNOT moment; the next loop
                    # iteration lands on the spacer, which contains no CNOT.
                    moments.insert(i + 1, cirq.Moment([cirq.I(q1)]))
                    break
        i += 1
    
    circuit = cirq.Circuit(moments)

    # shift
    new_moments = list(circuit)
    # The shift phase only swaps operations between existing moments, so the
    # moment count is fixed here.
    num_moments = len(new_moments)

    # Locations of identities that have already been moved, so they are not
    # moved a second time when the outer loop reaches their new moment.
    shifted_ids = set()  # (moment index, qubit tuple)

    for i in range(num_moments):
        moment = new_moments[i]
        ops = list(moment.operations)

        for op_idx, op in enumerate(ops):
            if isinstance(op.gate, cirq.IdentityGate):
                id_key = (i, tuple(op.qubits))
                if id_key in shifted_ids:
                    continue

                id_qubits = set(op.qubits)

                # try swap with i+1 moment
                if i + 1 < num_moments:
                    next_ops = list(new_moments[i + 1].operations)
                    for j, next_op in enumerate(next_ops):
                        # Only trade places with a non-CNOT operation that acts on
                        # exactly the identity's qubits.
                        if not isinstance(next_op.gate, cirq.CNotPowGate) and set(next_op.qubits) == id_qubits:
                            # `ops` is updated in place: slot op_idx now holds the
                            # operation taken from the next moment.
                            ops[op_idx], next_ops[j] = next_ops[j], ops[op_idx]
                            new_moments[i] = cirq.Moment(ops)
                            new_moments[i + 1] = cirq.Moment(next_ops)
                            # The old key cannot be in the set at this point (the
                            # `continue` above would have fired), so this discard
                            # has no effect.
                            shifted_ids.discard((i, tuple(op.qubits)))
                            shifted_ids.add((i + 1, tuple(op.qubits)))  # mark new location
                            break

                # try swap with i-1 moment
                # NOTE(review): because this is an `elif` of the bounds check above,
                # the backward swap is attempted only for identities in the last
                # moment, not as a fallback when no forward swap partner was found —
                # confirm this is intended.
                elif i - 1 >= 0:
                    prev_ops = list(new_moments[i - 1].operations)
                    for j, prev_op in enumerate(prev_ops):
                        if not isinstance(prev_op.gate, cirq.CNotPowGate) and set(prev_op.qubits) == id_qubits:
                            ops[op_idx], prev_ops[j] = prev_ops[j], ops[op_idx]
                            new_moments[i] = cirq.Moment(ops)
                            new_moments[i - 1] = cirq.Moment(prev_ops)
                            shifted_ids.discard((i, tuple(op.qubits)))
                            shifted_ids.add((i - 1, tuple(op.qubits)))
                            break
    return cirq.Circuit(new_moments)

def _ops_are_disjoint(ops) -> bool:
    """Return True when no two operations in *ops* act on a common qubit."""
    used_qubits = set()
    for op in ops:
        if not used_qubits.isdisjoint(op.qubits):
            return False
        used_qubits.update(op.qubits)
    return True


def apply_commutation(circuit: cirq.Circuit) -> cirq.Circuit:
    """Swap commuting, qubit-sharing operations between adjacent moments.

    The function first applies identity spacers around the CNOT gates
    (``spacer_around_CNOT``), then walks over each pair of adjacent moments
    and exchanges operations that commute and share at least one qubit, and
    finally drops every identity operation and rebuilds the circuit. The
    spacers exist because CNOT is a multi-qubit gate and may overlap with
    other single-qubit gates. As a result, some reported swaps may not be
    visible in the output because the partner was an identity spacer that is
    removed afterwards.

    A swap between two moments is only committed when it is safe:
    the rebuilt moments must not contain overlapping operations and every
    operation moved forward must commute with every operation moved backward.
    Otherwise that pair of moments is left unchanged (previously such cases
    raised ``ValueError`` from ``cirq.Moment`` or could reorder non-commuting
    gates). Operations whose commutation cannot be determined are treated as
    non-commuting instead of raising ``TypeError``.

    Side effect: prints one ``Swapping <op1> and <op2>`` line per committed
    exchange.

    Args:
        circuit: The circuit to transform. It is not modified.

    Returns:
        A new ``cirq.Circuit`` built from the reordered operations with all
        identity operations removed.
    """
    # Verified and debugged. See code_verification\apply_commutation.ipynb for test cases.
    atol = 1e-10

    # Add spacers around CNOT gate
    circuit = spacer_around_CNOT(circuit)

    # swap/commutation
    moments = list(circuit)

    # The number of moments never changes here; moments[i + 1] written in one
    # iteration is read as the left-hand moment in the next one.
    for i in range(len(moments) - 1):
        m1_ops = list(moments[i].operations)
        m2_ops = list(moments[i + 1].operations)

        # For each op in the earlier moment, pick the first op in the later
        # moment that shares a qubit with it and commutes with it.
        pairs = []
        for op1 in m1_ops:
            for op2 in m2_ops:
                if set(op1.qubits).isdisjoint(op2.qubits):
                    continue
                # default=False: an undeterminable commutation (e.g. a
                # non-unitary operation) means "do not swap", not an error.
                if cirq.commutes(op1, op2, atol=atol, default=False):
                    pairs.append((op1, op2))
                    break

        if not pairs:
            continue

        to_move_to_m2 = [op1 for op1, _ in pairs]
        # Several op1's may have picked the same op2 (e.g. two single-qubit
        # gates followed by one two-qubit gate); move that op2 only once.
        to_move_to_m1 = []
        for _, op2 in pairs:
            if op2 not in to_move_to_m1:
                to_move_to_m1.append(op2)

        new_m1_ops = [op for op in m1_ops if op not in to_move_to_m2] + to_move_to_m1
        new_m2_ops = [op for op in m2_ops if op not in to_move_to_m1] + to_move_to_m2

        # Exchanging the two groups is only an identity of the circuit when
        # both rebuilt moments are valid and the groups commute as a whole.
        swap_is_safe = (
            _ops_are_disjoint(new_m1_ops)
            and _ops_are_disjoint(new_m2_ops)
            and all(
                cirq.commutes(moved_fwd, moved_back, atol=atol, default=False)
                for moved_fwd in to_move_to_m2
                for moved_back in to_move_to_m1
            )
        )
        if not swap_is_safe:
            continue

        for op1, op2 in pairs:
            print(f"Swapping {op1} and {op2}")

        moments[i] = cirq.Moment(new_m1_ops)
        moments[i + 1] = cirq.Moment(new_m2_ops)

    circuit = cirq.Circuit(moments)

    # remove the identity spacers
    ops_in_order = [
        op
        for moment in circuit
        for op in moment.operations
        if not isinstance(op.gate, cirq.IdentityGate)
    ]

    return cirq.Circuit(ops_in_order)


In [None]:
import numpy as np
def circuits_equivalent(circ1: cirq.Circuit, circ2: cirq.Circuit) -> bool:
    """Return True when the two circuits' unitaries agree up to a global phase.

    The matrices are obtained with ``cirq.unitary`` and compared element-wise
    with an absolute tolerance of 1e-8 after the relative global phase has
    been factored out.

    Args:
        circ1: First circuit.
        circ2: Second circuit.

    Returns:
        True if ``U2 == exp(i * phi) * U1`` for some real ``phi`` (within
        tolerance); False otherwise, including when the two unitaries have
        different shapes.
    """
    U1 = cirq.unitary(circ1)
    U2 = cirq.unitary(circ2)
    # Matrices of different size cannot be equivalent; bail out instead of
    # letting the comparison below fail on mismatched shapes.
    if U1.shape != U2.shape:
        return False
    # np.vdot conjugates its first argument and flattens both, so this is
    # tr(U1^dagger U2), which equals exp(i * phi) * dim when U2 = exp(i * phi) * U1.
    # Its angle is the global phase; no eigendecomposition is needed.
    global_phase = np.exp(1j * np.angle(np.vdot(U1, U2)))
    return np.allclose(global_phase * U1, U2, atol=1E-8)

In [123]:
# Example usage
def test_1():
    """Run apply_commutation on a two-qubit example and print the outcome.

    Prints the original circuit, the transformed circuit, and whether the
    two are equivalent according to circuits_equivalent.
    """
    ctrl, tgt = cirq.LineQubit.range(2)
    gate_sequence = [
        cirq.Z(ctrl),
        cirq.CNOT(ctrl, tgt),
        cirq.Z(ctrl),
        cirq.Z(tgt),
        cirq.H(tgt),
        cirq.CNOT(ctrl, tgt),
        cirq.X(ctrl),
        cirq.Z(tgt),
        cirq.Y(tgt),
    ]
    original = cirq.Circuit(gate_sequence)

    print("Original Circuit:")
    print(original)

    transformed = apply_commutation(original)
    print("\nTransformed Circuit:")
    print(transformed)

    is_equivalent = circuits_equivalent(original, transformed)
    print(f"Circuit equivalence: {is_equivalent}")

def test_2():
    """Run apply_commutation on a three-qubit example and print the outcome.

    Same gate sequence as the two-qubit example, with an additional CNOT
    between the second and third qubit before the final Y gate. Prints the
    original circuit, the transformed circuit, and the equivalence check.
    """
    first, second, third = cirq.LineQubit.range(3)
    gate_sequence = [
        cirq.Z(first),
        cirq.CNOT(first, second),
        cirq.Z(first),
        cirq.Z(second),
        cirq.H(second),
        cirq.CNOT(first, second),
        cirq.X(first),
        cirq.Z(second),
        cirq.CNOT(second, third),
        cirq.Y(second),
    ]
    original = cirq.Circuit(gate_sequence)

    print("Original Circuit:")
    print(original)

    transformed = apply_commutation(original)
    print("\nTransformed Circuit:")
    print(transformed)

    is_equivalent = circuits_equivalent(original, transformed)
    print(f"Circuit equivalence: {is_equivalent}")

if __name__ == "__main__":
    # Run every example in order, each preceded by its banner line.
    for banner, run_case in (
        ("############ TEST CASE 1: #############", test_1),
        ("############ TEST CASE 2: #############", test_2),
    ):
        print(banner)
        run_case()

############ TEST CASE 1: #############
Original Circuit:
0: ───Z───@───Z───────@───X───────
          │           │
1: ───────X───Z───H───X───Z───Y───
Swapping Z(q(0)) and CNOT(q(0), q(1))
Swapping Z(q(1)) and I(q(1))
Swapping Z(q(1)) and I(q(1))

Transformed Circuit:
0: ───@───Z───Z───@───X───────
      │           │
1: ───X───Z───H───X───Z───Y───
Circuit equivalence: True
############ TEST CASE 2: #############
Original Circuit:
0: ───Z───@───Z───────@───X───────────
          │           │
1: ───────X───Z───H───X───Z───@───Y───
                              │
2: ───────────────────────────X───────
Swapping Z(q(0)) and CNOT(q(0), q(1))
Swapping Z(q(1)) and I(q(1))
Swapping Z(q(1)) and I(q(1))
Swapping Z(q(1)) and CNOT(q(1), q(2))

Transformed Circuit:
0: ───@───Z───Z───@───X───────────
      │           │
1: ───X───Z───H───X───@───Z───Y───
                      │
2: ───────────────────X───────────
Circuit equivalence: True
