In [None]:
import cirq

def _split_ops_into_moments(ops):
    """Pack ``ops`` into consecutive moments without reordering them.

    Operations are appended to the current moment until one touches a qubit
    that the current moment already uses; that operation then starts a new
    moment. Because the input order is kept, operations that share a qubit
    keep their relative order.

    Args:
        ops: Iterable of operations, in the order they should be applied.

    Returns:
        A list of ``cirq.Moment`` objects (empty if ``ops`` is empty).
    """
    moments_list = []
    used_qubits = set()
    current_ops = []
    for op in ops:
        if used_qubits.intersection(op.qubits):
            # Qubit clash: close the current moment and start a new one with op.
            moments_list.append(cirq.Moment(current_ops))
            current_ops = [op]
            used_qubits = set(op.qubits)
        else:
            current_ops.append(op)
            used_qubits.update(op.qubits)
    if current_ops:
        moments_list.append(cirq.Moment(current_ops))
    return moments_list


def apply_commutation(
    circuit: cirq.Circuit,
    atol: float = 1e-6,
    verbose: bool = True,
) -> cirq.Circuit:
    """Swap commuting operations between adjacent moments of ``circuit``.

    Moments are processed in disjoint pairs ``(0, 1), (2, 3), ...``; a trailing
    unpaired moment is copied unchanged. For each pair ``(m1, m2)``, an
    operation ``op1`` of ``m1`` is moved into ``m2`` (and swapped with the next
    not-yet-swapped operation of ``m2``, which is moved into ``m1``) only if
    ``op1`` commutes with *every* operation of ``m2``. The rebuilt operation
    lists are re-packed with ``_split_ops_into_moments`` so no moment acts twice
    on the same qubit.

    Args:
        circuit: Circuit to transform. It is only read; a new circuit is built.
        atol: Absolute tolerance passed to ``cirq.commutes``.
        verbose: If True (the default), print the moments being processed and
            every swap performed.

    Returns:
        A new ``cirq.Circuit`` containing the rearranged moments.
    """
    moments = list(circuit)
    new_moments = []

    if verbose:
        print("Initial Moments:" + str(moments))

    i = 0
    while i < len(moments) - 1:
        m1, m2 = moments[i], moments[i + 1]
        if verbose:
            print(f"\nProcessing Moment {i} and {i + 1}")
            print(f"  m1: {m1}")
            print(f"  m2: {m2}")

        m1_ops = list(m1.operations)
        m2_ops = list(m2.operations)

        to_move_to_m2 = []
        to_move_to_m1 = []
        swapped_m2_indices = set()

        for op1 in m1_ops:
            # A moved op1 ends up after *every* op of m2 (swapped partners land in
            # the new m1, the rest precede it in the new m2), so commuting with a
            # single partner is not enough. Pairs whose commutativity cirq cannot
            # determine are treated as non-commuting.
            if not all(
                cirq.commutes(op1, op2, atol=atol, default=False) for op2 in m2_ops
            ):
                continue
            idx2 = next(
                (k for k in range(len(m2_ops)) if k not in swapped_m2_indices), None
            )
            if idx2 is None:
                break  # every op of m2 already has a partner
            op2 = m2_ops[idx2]
            if verbose:
                print(f"    Swapping op1: {op1} with op2: {op2}")
            to_move_to_m2.append(op1)
            to_move_to_m1.append(op2)
            swapped_m2_indices.add(idx2)

        # Resulting op order: kept m1 ops, ops pulled from m2, kept m2 ops, ops
        # pushed from m1. Only the pushed m1 ops change order relative to m2,
        # which is exactly what the all-commute check above guards.
        new_m1_ops = [op for op in m1_ops if op not in to_move_to_m2] + to_move_to_m1
        new_m2_ops = [
            op for idx, op in enumerate(m2_ops) if idx not in swapped_m2_indices
        ] + to_move_to_m2

        new_m1_moments = _split_ops_into_moments(new_m1_ops)
        new_m2_moments = _split_ops_into_moments(new_m2_ops)

        if verbose:
            print(f"  New m1 moments: {new_m1_moments}")
            print(f"  New m2 moments: {new_m2_moments}")

        new_moments.extend(new_m1_moments)
        new_moments.extend(new_m2_moments)

        i += 2

    if i == len(moments) - 1:
        # Odd number of moments: the last one had no partner.
        new_moments.append(moments[-1])

    return cirq.Circuit(new_moments)

# Example usage
def main():
    """Demo: build a small two-qubit circuit and show it before and after ``apply_commutation``."""
    control, target = cirq.LineQubit.range(2)
    demo_ops = (
        cirq.H(control),
        cirq.CNOT(control, target),
        cirq.X(control),
        cirq.Z(target),
        cirq.Y(target),
    )
    circuit = cirq.Circuit(*demo_ops)

    # sep="\n" reproduces the original heading-then-diagram line layout.
    print("Original Circuit:", circuit, sep="\n")

    transformed = apply_commutation(circuit)
    print("\nTransformed Circuit:", transformed, sep="\n")


if __name__ == "__main__":
    main()


Original Circuit:
0: ───H───@───X───────
          │
1: ───────X───Z───Y───
Initial Moments:[cirq.Moment(
    cirq.H(cirq.LineQubit(0)),
), cirq.Moment(
    cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)),
), cirq.Moment(
    cirq.X(cirq.LineQubit(0)),
    cirq.Z(cirq.LineQubit(1)),
), cirq.Moment(
    cirq.Y(cirq.LineQubit(1)),
)]

Processing Moment 0 and 1
  m1:   ╷ 0
╶─┼───
0 │ H
  │
  m2:   ╷ 0 1
╶─┼─────
0 │ @─X
  │
[cirq.H(cirq.LineQubit(0))]
[cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1))]
  New m1 moments: [cirq.Moment(
    cirq.H(cirq.LineQubit(0)),
)]
  New m2 moments: [cirq.Moment(
    cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)),
)]

Processing Moment 2 and 3
  m1:   ╷ 0 1
╶─┼─────
0 │ X Z
  │
  m2:   ╷ 1
╶─┼───
0 │ Y
  │
[cirq.X(cirq.LineQubit(0)), cirq.Z(cirq.LineQubit(1))]
[cirq.Y(cirq.LineQubit(1))]
    Swapping op1: X(q(0)) with op2: Y(q(1))
  New m1 moments: [cirq.Moment(
    cirq.Z(cirq.LineQubit(1)),
), cirq.Moment(
    cirq.Y(cirq.LineQubit(1)),
)]
  New m2 mom

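Commutation check for operations on different qubits: `Z(q0)` and `Y(q1)` act on disjoint qubits, so `cirq.commutes` is expected to return `True`.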

In [8]:
import cirq

# Z on qubit 0 and Y on qubit 1 touch disjoint qubits.
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit([cirq.Z(q0), cirq.Y(q1)])

# Recover the two operations (Z first, then Y) and ask cirq whether they commute.
ops = [*circuit.all_operations()]
print(cirq.commutes(*ops, atol=1e-6))

True
