# In [None]:  (Jupyter cell marker left over from the notebook export)
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

BITS = 4  # Each number is 4 bits (0..15)

# ---------- Parse output ----------
def parse_output_bitstring(bitstr, n_items, bits=None):
    """Decode a measured bitstring into a list of unsigned integers.

    The rightmost character of ``bitstr`` is treated as classical bit 0.
    Item ``r`` occupies classical bits ``r * bits`` .. ``r * bits + bits - 1``,
    least significant bit first. Any character other than ``'1'`` counts as 0.

    Args:
        bitstr: Measurement string, e.g. a key of the counts dictionary.
        n_items: Number of integers to extract.
        bits: Width of each integer in bits; defaults to the module-level
            ``BITS`` when omitted.

    Returns:
        A list of ``n_items`` integers; element 0 is read from the rightmost
        ``bits`` characters.

    Raises:
        ValueError: If ``bitstr`` has fewer than ``n_items * bits`` characters.
    """
    width = BITS if bits is None else bits
    total = len(bitstr)
    needed = n_items * width
    if total < needed:
        # Without this guard the negative indices computed below would wrap
        # around and silently read bits from the wrong end of the string.
        raise ValueError(
            f"bitstring has {total} bits but {n_items} items of {width} bits "
            f"need {needed}"
        )

    values = []
    for r in range(n_items):
        # The least significant bit of item r sits furthest to the right.
        lsb_pos = total - 1 - r * width
        values.append(
            sum(1 << j for j in range(width) if bitstr[lsb_pos - j] == '1')
        )
    return values

# ---------- Low-level: Controlled SWAP ----------
def cswap_bit(qc, ctrl, a, b):
    """Append a swap of qubits ``a`` and ``b`` controlled on ``ctrl`` to ``qc``.

    The swap is spelled out as three Toffoli gates that all share ``ctrl``.
    When ``ctrl`` is set they act as CNOT(a->b), CNOT(b->a), CNOT(a->b),
    which exchanges the two qubits; when it is clear nothing happens.
    """
    for source, target in ((a, b), (b, a), (a, b)):
        qc.ccx(ctrl, source, target)

def cswap_register(qc, ctrl, A, B):
    """Append a controlled swap of registers ``A`` and ``B`` to ``qc``.

    The first ``BITS`` positions of both registers are exchanged pairwise,
    every pair using the same control qubit ``ctrl``.
    """
    # Generator keeps the pairs lazy, so gates are appended one pair at a time.
    qubit_pairs = ((A[idx], B[idx]) for idx in range(BITS))
    for qubit_a, qubit_b in qubit_pairs:
        cswap_bit(qc, ctrl, qubit_a, qubit_b)

# ---------- Comparator: flag = [A > B], and fully clean equality chain ----------
def compute_flag_a_gt_b_and_cleanup(qc, A, B, eq_chain, tmp1, tmp2, flag):
    """XOR the predicate ``A > B`` into ``flag`` and leave every ancilla clean.

    ``A`` and ``B`` are equally long lists of qubit indices, least significant
    bit first. The registers are scanned from the most significant bit down
    while ``eq_chain`` tracks "all higher bits were equal"; ``flag`` is
    toggled at the first differing bit where A holds 1 and B holds 0.

    Args:
        qc: Circuit (or any object offering ``x``, ``cx`` and ``ccx``) to
            append gates to.
        A: Qubit indices of the left operand, LSB first.
        B: Qubit indices of the right operand, LSB first.
        eq_chain: ``len(A) + 1`` ancilla qubits, expected to be 0 on entry;
            returned to 0 on exit.
        tmp1: Scratch qubit for ``A[k] AND NOT B[k]``; 0 on entry and exit.
        tmp2: Scratch qubit for ``XNOR(A[k], B[k])``; 0 on entry and exit.
        flag: Qubit that receives ``flag ^= (A > B)``.
    """
    width = len(A)
    assert len(B) == width
    assert len(eq_chain) == width + 1

    def toggle_eq_link(a, b, eq_prev, eq_next):
        # tmp2 = XNOR(a, b)
        qc.cx(a, tmp2); qc.cx(b, tmp2); qc.x(tmp2)
        # eq_next ^= eq_prev & XNOR(a, b)
        qc.ccx(eq_prev, tmp2, eq_next)
        # Clear tmp2 again
        qc.x(tmp2); qc.cx(b, tmp2); qc.cx(a, tmp2)

    # eq0 = 1: before looking at any bit, "all higher bits are equal" holds.
    qc.x(eq_chain[0])

    # Forward: from MSB -> LSB; link i of the chain covers bit width - 1 - i.
    for link in range(width):
        bit = width - 1 - link
        a, b = A[bit], B[bit]
        eq_prev = eq_chain[link]

        # tmp1 = A[bit] & ~B[bit]
        qc.x(b); qc.ccx(a, b, tmp1); qc.x(b)

        # At the first different bit with A=1, B=0, set flag ^= 1
        qc.ccx(eq_prev, tmp1, flag)

        # Clear tmp1
        qc.x(b); qc.ccx(a, b, tmp1); qc.x(b)

        toggle_eq_link(a, b, eq_prev, eq_chain[link + 1])

    # Backward: undo the chain from its LAST link to its first. Each link can
    # only be cleared while the link before it still holds its forward value;
    # clearing front-to-back zeroes eq_chain[1] first and strands the rest.
    for link in reversed(range(width)):
        bit = width - 1 - link
        toggle_eq_link(A[bit], B[bit], eq_chain[link], eq_chain[link + 1])

    # eq0 restored from 1 to 0
    qc.x(eq_chain[0])

def compute_flag_a_lt_b_and_cleanup(qc, A, B, eq_chain, tmp1, tmp2, flag):
    """Append the ``A < B`` comparator by delegating with swapped operands.

    ``A < B`` holds exactly when ``B > A``, so the greater-than comparator is
    reused with the two registers exchanged; ancillas and flag pass through.
    """
    greater_side, lesser_side = B, A
    compute_flag_a_gt_b_and_cleanup(
        qc, greater_side, lesser_side, eq_chain, tmp1, tmp2, flag
    )

# ---------- Compare-and-swap (ascending: swap if A>B) ----------
def compare_and_swap_up(qc, A, B, eq_chain, tmp1, tmp2, flag):
    """Append one ascending compare-and-swap element for registers A and B.

    Three stages are appended in order: the ``A > B`` comparator targeting
    ``flag``, a swap of the two registers controlled on ``flag``, and the
    same comparator once more.

    NOTE(review): assuming the comparator XORs ``[A > B]`` into ``flag``,
    the registers satisfy ``A <= B`` after the controlled swap, so the second
    comparator would XOR 0 and ``flag`` would keep the first comparison's
    outcome rather than return to 0 -- confirm the intended role of the
    second pass.
    """
    comparator_args = (qc, A, B, eq_chain, tmp1, tmp2, flag)

    # Stage 1: comparator writes into flag.
    compute_flag_a_gt_b_and_cleanup(*comparator_args)
    # Stage 2: exchange the registers when flag is set.
    cswap_register(qc, flag, A, B)
    # Stage 3: identical comparator applied to the (possibly swapped) registers.
    compute_flag_a_gt_b_and_cleanup(*comparator_args)

# ---------- Build ONLY the sorting unitary (no input preparation, no measurement) ----------
def build_sort_unitary():
    """Build the sorting circuit for four ``BITS``-wide values.

    Contains only the compare-and-swap network: no input preparation and no
    measurement.

    Qubit layout, in index order: four data registers of ``BITS`` qubits,
    ``BITS + 1`` equality-chain ancillas, two scratch qubits, and one flag
    qubit per comparator.

    Returns:
        Tuple ``(sort_u, num_data, total_qubits)``: the circuit, the number of
        data qubits, and the circuit's total qubit count.
    """
    # Register pairs in the order the compare-and-swap elements are applied.
    network = ((0, 1), (2, 3), (0, 2), (1, 3), (1, 2))

    num_data = 4 * BITS
    anc_eq = BITS + 1
    num_flags = len(network)
    total_qubits = num_data + anc_eq + 2 + num_flags

    sort_u = QuantumCircuit(total_qubits)  # quantum-only subcircuit

    # Data registers occupy consecutive BITS-sized slices starting at qubit 0.
    regs = [list(range(i * BITS, (i + 1) * BITS)) for i in range(4)]

    # Ancillas follow the data qubits: chain, two scratch qubits, then flags.
    eq_chain = list(range(num_data, num_data + anc_eq))
    tmp1 = num_data + anc_eq
    tmp2 = tmp1 + 1
    first_flag = tmp2 + 1
    flags = list(range(first_flag, first_flag + num_flags))

    # Each comparator gets its own flag qubit.
    for (low, high), flag_qubit in zip(network, flags):
        compare_and_swap_up(
            sort_u, regs[low], regs[high], eq_chain, tmp1, tmp2, flag_qubit
        )

    return sort_u, num_data, total_qubits

# ---------- Build ONLY the input-preparation subcircuit ----------
def build_prepare_input(values, total_qubits):
    """Build a circuit that writes ``values`` into the leading data qubits.

    Value ``i`` is stored in qubits ``i * BITS`` .. ``i * BITS + BITS - 1``,
    least significant bit first, by applying an X gate for every set bit.

    Args:
        values: Iterable of integers to encode.
        total_qubits: Width of the circuit to create.

    Returns:
        The preparation circuit.
    """
    prep = QuantumCircuit(total_qubits)
    for slot, number in enumerate(values):
        offset = slot * BITS
        set_bits = [j for j in range(BITS) if (number >> j) & 1]
        for j in set_bits:
            prep.x(offset + j)
    return prep

# ---------- Main ----------
if __name__ == "__main__":
    # Read and validate the four 4-bit inputs.
    values = list(map(int, input("Enter 4 numbers (0..15): ").split()))
    if not (len(values) == 4 and all(0 <= v <= 15 for v in values)):
        raise ValueError("Please enter exactly 4 integers in 0..15")

    mode = input("Choose mode: sort / inverse / sort+inverse: ").strip()

    sim = AerSimulator(method="matrix_product_state")

    sort_u, num_data, total_qubits = build_sort_unitary()
    prep = build_prepare_input(values, total_qubits)

    # Each mode maps to a lazily built quantum-only pipeline.
    pipelines = {
        # prepare -> sort
        "sort": lambda: prep.compose(sort_u),
        # prepare -> inverse(sort)
        "inverse": lambda: prep.compose(sort_u.inverse()),
        # prepare -> sort -> inverse(sort)
        "sort+inverse": lambda: prep.compose(sort_u).compose(sort_u.inverse()),
    }
    build_pipeline = pipelines.get(mode)
    if build_pipeline is None:
        raise ValueError("Invalid mode")
    quantum_part = build_pipeline()

    # Wrap the pipeline in a circuit that measures only the data qubits.
    qc = QuantumCircuit(total_qubits, num_data)
    qc.compose(quantum_part, inplace=True)
    data_bits = range(num_data)
    qc.measure(data_bits, data_bits)

    # Single-shot run; the one key in the counts is the measured bitstring.
    tqc = transpile(qc, sim, optimization_level=0)
    counts = sim.run(tqc, shots=1).result().get_counts()
    bitstr = list(counts)[0]
    output = parse_output_bitstring(bitstr, 4)

    print("Original input:", values)
    print("Final result:", output)
    # Text circuit (no extra deps)
    # print(qc.draw("text"))