In [1]:
import numpy as np
from qiskit import QuantumCircuit, QuantumRegister, transpile
from typing import Optional

In [30]:
# Number of slots permuted by the quantum Fisher–Yates shuffle.
NUM_ELEMENTS: int = 12
# Qubits per slot of the (optional) data register.
NUM_DATA_QUBITS: int = 1
# Qubits per slot of the permutation register: enough to encode 0 .. NUM_ELEMENTS - 1.
NUM_QUBITS: int = int.bit_length(NUM_ELEMENTS - 1)

In [31]:
def init_perm_reg(perm_reg: QuantumRegister, num_elements: int):
    """Return a circuit on `perm_reg` that writes the identity permutation.

    `perm_reg` is split into `num_elements` consecutive slots of
    len(perm_reg) // num_elements qubits each. Slot k is set to |k> in little-endian
    binary: qubit k * width + b carries bit b of k. Slot 0 stays |0...0>.

    Args:
        perm_reg: register holding the slots.
        num_elements: number of slots (>= 1).

    Returns:
        A QuantumCircuit over `perm_reg` made only of X gates.

    Raises:
        ValueError: if num_elements < 1, or if a slot is too narrow to hold num_elements - 1.
    """
    if num_elements < 1:
        raise ValueError(f"num_elements must be at least 1, got {num_elements}")
    width: int = len(perm_reg) // num_elements
    needed: int = (num_elements - 1).bit_length()
    # A narrower slot would make the X gates below spill into the next slot's qubits.
    if width < needed:
        raise ValueError(
            f"perm_reg has {len(perm_reg)} qubits ({width} per slot) but {needed} per slot "
            f"are needed for {num_elements} elements"
        )
    qc: QuantumCircuit = QuantumCircuit(perm_reg)
    for k in range(1, num_elements):
        for bit in range(k.bit_length()):
            if (k >> bit) & 1:
                qc.x(perm_reg[k * width + bit])
    return qc

In [32]:
def shukla_vedula(qreg: QuantumRegister, num_states: int):
    """Prepare an equal-weight superposition of |0>, ..., |num_states - 1> on `qreg`.

    This follows the Shukla–Vedula construction. qreg[0] is the least-significant bit.
    Write num_states = sum_j 2**l_j with l_0 < l_1 < ... (its set bits). Then:
      * the qubits of every set bit above l_0 start in |1>,
      * H gates make the lowest l_0 qubits uniform,
      * an RY followed by a chain of 0-controlled RY / H gates splits off 2**l_m states
        at each set bit.

    Args:
        qreg: target register; needs at least (num_states - 1).bit_length() qubits.
        num_states: number of basis states in the superposition (>= 1).

    Returns:
        A QuantumCircuit over `qreg` that maps |0...0> to the uniform superposition.

    Raises:
        ValueError: if num_states < 1 or `qreg` has too few qubits.
    """
    if num_states < 1:
        raise ValueError(f"num_states must be at least 1, got {num_states}")
    num_qubits: int = (num_states - 1).bit_length()
    # Slicing a register past its end silently truncates, which would give a wrong state.
    if len(qreg) < num_qubits:
        raise ValueError(
            f"qreg has {len(qreg)} qubits but {num_qubits} are needed for {num_states} states"
        )
    qc: QuantumCircuit = QuantumCircuit(qreg)
    if num_qubits != num_states.bit_length():  # num_states is a power of 2 (1 included)
        qc.h(qreg[:num_qubits])
        return qc

    # Positions of the set bits of num_states, ascending. There are at least two here.
    bit_pos: list[int] = [index for (index, bit) in enumerate(np.binary_repr(num_states)[::-1]) if bit == '1']
    qc.x(qreg[bit_pos[1:]])
    cur_num_states: int = 1 << bit_pos[0]
    # RY(theta) on |1> leaves amplitude sqrt(cur_num_states / num_states) on |1>.
    theta: float = -2 * np.arccos(np.sqrt(cur_num_states / num_states))
    if bit_pos[0] > 0:  # num_states is even: the lowest bit_pos[0] qubits are always uniform
        qc.h(qreg[0:bit_pos[0]])
    qc.ry(theta, qreg[bit_pos[1]])
    qc.ch(qreg[bit_pos[1]], qreg[bit_pos[0]:bit_pos[1]], ctrl_state='0')
    for m in range(1, len(bit_pos) - 1):
        # In the branch where qubit bit_pos[m] is 0, split off 2**bit_pos[m] of the
        # remaining num_states - cur_num_states states.
        theta = -2 * np.arccos(np.sqrt(2**bit_pos[m] / (num_states - cur_num_states)))
        qc.cry(theta, qreg[bit_pos[m]], qreg[bit_pos[m + 1]], ctrl_state='0')
        qc.ch(qreg[bit_pos[m + 1]], qreg[bit_pos[m]:bit_pos[m + 1]], ctrl_state='0')
        cur_num_states += 1 << bit_pos[m]
    return qc

In [41]:
def _controlled_slot_swap(qc: QuantumCircuit, reg: QuantumRegister, width: int, i: int, j: int, anc_qubits: list) -> None:
    """Append gates to `qc` that swap `width`-qubit slots i and j of `reg` iff `anc_qubits` encode j.

    Each qubit pair is swapped with the CX / multi-controlled-X / CX decomposition of a
    controlled SWAP. The ancilla value is read little-endian (anc_qubits[0] is the LSB).
    """
    # Qiskit bitstrings are little-endian: the rightmost character is the first control
    # (the slot-i qubit, which must be 1), and the prefix spells j on the ancilla qubits.
    ctrl_state: str = np.binary_repr(j, len(anc_qubits)) + '1'
    for q in range(width):
        qubit_j = reg[j * width + q]
        qubit_i = reg[i * width + q]
        qc.cx(qubit_j, qubit_i, ctrl_state='1')
        qc.mcx([qubit_i] + list(anc_qubits), qubit_j, ctrl_state=ctrl_state)
        qc.cx(qubit_j, qubit_i, ctrl_state='1')


def quantum_fisher_yates(num_elements: int, perm_reg: QuantumRegister, ancilla_reg: QuantumRegister, data_reg: Optional[QuantumRegister] = None, disentangling: bool = True) -> QuantumCircuit:
    """Build a quantum Fisher–Yates shuffle over `num_elements` slots.

    How the circuit is built:
      * `perm_reg` is first set to the identity permutation (slot k = |k>, little-endian)
        via `init_perm_reg`.
      * For each step i = 1 .. num_elements - 1, an index register is put into a uniform
        superposition over 0 .. i using `shukla_vedula`.
      * Slot i is then swapped with slot j whenever the index equals j < i.
      * If `data_reg` is given, its slots are swapped in the same way.

    Args:
        num_elements: number of slots to permute.
        perm_reg: permutation register. Assumed to be exactly
            num_elements * (num_elements - 1).bit_length() qubits, since `init_perm_reg`
            derives the slot width from its length.
        ancilla_reg: index qubits. With `disentangling`, the index is uncomputed after each
            step and reused, so (num_elements - 1).bit_length() qubits suffice. Otherwise,
            step i uses a fresh block of i.bit_length() qubits (sum over all steps).
        data_reg: optional register of num_elements slots of len(data_reg) // num_elements
            qubits, permuted together with `perm_reg`.
        disentangling: if True, reset the index to |0> after each step. If False, leave it
            entangled and move on to fresh ancilla qubits.

    Returns:
        The QuantumCircuit, with registers ordered (data_reg, perm_reg, ancilla_reg).
        data_reg is omitted when it is None.
    """
    num_qubits: int = (num_elements - 1).bit_length()
    if data_reg is None:
        qc: QuantumCircuit = QuantumCircuit(perm_reg, ancilla_reg)
        num_data_qubits: int = 0
    else:
        qc = QuantumCircuit(data_reg, perm_reg, ancilla_reg)
        num_data_qubits = len(data_reg) // num_elements
    # Map explicitly onto perm_reg. With qubits=None, compose() places the sub-circuit on
    # the first qubits of qc, which are the data qubits whenever data_reg is given.
    qc.compose(init_perm_reg(perm_reg, num_elements), qubits=list(perm_reg), inplace=True)

    offset: int = 0
    for i in range(1, num_elements):
        num_ctrl: int = i.bit_length()
        index_qubits = ancilla_reg[offset:offset + num_ctrl]
        # Uniform superposition of the swap-partner index over 0 .. i.
        qc.compose(other=shukla_vedula(QuantumRegister(num_ctrl), i + 1), qubits=index_qubits, inplace=True)
        # Index value i means "leave slot i in place", so only j < i needs a swap.
        for j in range(i):
            _controlled_slot_swap(qc, perm_reg, num_qubits, i, j, index_qubits)
            if data_reg is not None:
                _controlled_slot_swap(qc, data_reg, num_data_qubits, i, j, index_qubits)
        if disentangling:
            # After the swaps, value i sits in slot j exactly when the index equals j.
            # XOR-ing j into the index whenever slot j holds i therefore returns it to |0>.
            # All slot values are <= i, so the low num_ctrl qubits of a slot determine it.
            for j in range(1, i + 1):
                slot_low_qubits = perm_reg[j * num_qubits:j * num_qubits + num_ctrl]
                for k in range(j.bit_length()):
                    if (j >> k) & 1:
                        qc.mcx(control_qubits=slot_low_qubits, target_qubit=ancilla_reg[offset + k], ctrl_state=i)
        else:
            offset += num_ctrl
    return qc

In [42]:
# Registers sized from the configuration cell. The data register `a` is created, but the
# shuffle below is built without it (data_reg=None).
a: QuantumRegister = QuantumRegister(NUM_ELEMENTS * NUM_DATA_QUBITS, 'data')
b: QuantumRegister = QuantumRegister(NUM_ELEMENTS * NUM_QUBITS, 'perm')
# With disentangling=False, step i uses its own block of i.bit_length() ancilla qubits.
c: QuantumRegister = QuantumRegister(sum(step.bit_length() for step in range(1, NUM_ELEMENTS)), 'anc')

quc: QuantumCircuit = quantum_fisher_yates(NUM_ELEMENTS, b, c, data_reg=None, disentangling=False)
print(quc.count_ops())

# Gate counts after lowering to a fixed basis: first without optimisation, then at level 3.
target_basis = ['rx', 'ry', 'rz', 'h', 'cx']
decomposed = transpile(quc, basis_gates=target_basis, optimization_level=0)
print(decomposed.count_ops())
decomposed2 = transpile(quc, basis_gates=target_basis, optimization_level=3)
print(decomposed2.count_ops())

OrderedDict({'cx': 528, 'mcx_o1': 40, 'mcx_o3': 40, 'mcx_o5': 36, 'mcx_o7': 32, 'x': 30, 'mcx_o9': 28, 'mcx_o11': 24, 'mcx_o13': 20, 'mcx_o15': 16, 'ch_o0': 15, 'mcx_o17': 12, 'h': 10, 'ry': 8, 'mcx_o19': 8, 'ccx_o1': 4, 'mcx_o21': 4, 'cry_o0': 2})
OrderedDict({'rz': 8720, 'cx': 8171, 'h': 3632, 'rx': 1216, 'ry': 12})
OrderedDict({'rz': 8207, 'cx': 8169, 'h': 3096, 'rx': 154, 'ry': 33})
