# An Efficient Algorithm for Sparse Quantum State Preparation

Implementation of the [paper from Niels Gleinig & Torsten Hoefler](https://htor.inf.ethz.ch/publications/img/quantum_dac.pdf) using Classiq's python SDK.

In [6]:
# Classiq related (algorithmic part)
from classiq import *
import numpy as np
from typing import List, Tuple
from classiq.qmod import control, unitary
from classiq.execution import ClassiqBackendPreferences, ExecutionPreferences
from classiq.synthesis import set_execution_preferences, SerializedQuantumProgram


# Qiskit related (end of computation)
from qiskit import QuantumCircuit
from qiskit.circuit.controlledgate import ControlledGate
from qiskit.quantum_info import Operator
from qiskit.circuit.library import *
from math import log2, ceil
from qiskit.qasm2 import dumps
import qiskit_aer
import qiskit

In [7]:
def gate_matrix(alpha: float, beta: float) -> List[float]:
    """ Implementation of the gate G as described in the paper (II.A)

    Args:
        alpha (float): Probability of ket 0
        beta (float): Probability of ket 1

    Returns:
        List[float]: List representing the G matrix
    """
    alpha = np.sqrt(alpha)
    beta = np.sqrt(beta)
    a11 = np.sin(beta)
    a12 = np.exp(1j * alpha) * np.cos(beta)
    a21 = np.exp(- 1j * alpha) * np.cos(beta)
    a22 = - np.sin(beta)

    return np.array([[a11, a12], [a21, a22]]).tolist()

In [8]:

def state_to_bitwise(state: List[float]) -> List[List[int]]:
    """ Converts an array of coefficients into a set of bitwise indexes, removing empty ones

    Args:
        state (List[float]): State to convert

    Returns:
        List[List[int]]: List of bitwise indexes
    """
    size = 1
    length = 2
    state_len = len(state)
    while state_len > length:
        size += 1
        length *= 2
    result = []
    for i in range(state_len):
        if state[i] != 0:
            bitwise = []
            while i > 0:
                bitwise.append(i%2)
                i = i//2
            while len(bitwise) < size:
                bitwise.append(0)
            bitwise.reverse()
            result.append(bitwise)
    return result

In [9]:
def optimal_split(T: List[str]) -> Tuple[int, List[str], List[str]]:
    """ Finds a bit such that it splits T into two sets as unequal as possible but neither are empty and:
    .. math::
        b \in \{1, 2, ..., n\}
        T_0 := \{x \in T | x[b] == 0\}
        T_1 := \{x \in T | x[b] == 1\}
    
    Args:
        T (List[str]): Set T

    Raises:
        IndexError: State should not be empy
        RuntimeError: No split for the given state

    Returns:
        Tuple[int, List[str], List[str]]: b, T0, T1 as desribed above
    """
    if T == []:
        raise IndexError("State should not be empy")
    append_T0 = False
    append_T1 = False
    for bit_nb in range(len(T[0])):
        T0, T1 = [], []
        for state in T:
            if state[bit_nb] == 0:
                T0.append(state)
                append_T0 = True
            else:
                T1.append(state)
                append_T1 = True
        if append_T0 and append_T1:
            return bit_nb, T0, T1
    raise RuntimeError("No split for the given state")

  """ Finds a bit such that it splits T into two sets as unequal as possible but neither are empty and:


In [10]:
def bitwise_to_int(bitwise : List[int]) -> int:
    """ Converts a list of bits to its decimal representation

    Args:
        bitwise (List[int]): List of bits

    Returns:
        int: Decimal representation of bitwise
    """
    res = 0
    for i in bitwise:
        res *= 2
        res += i
    return res

In [11]:
def build_T_prime(T: List[str], dif_qubits: List[int], dif_values: List[int]) -> List[str]:
    """ Builds T' according to the paper:
    Let T' ⊂ S denote the set of strings that have the values in dif_values on the bits dif_qubits;

    Args:
        T (List[str]): State s represented in bits
        dif_qubits (List[int]): Stack of bits b ∈ {1, 2, . . . , n} that will hold in the end the bits that we use as control for the “merging” step
        dif_values (List[int]): Stack of boolean values

    Returns:
        List[str]: List of bits T' as described above
    """
    T_prime = []
    for state in T:
        matches = True
        for qubit, value in zip(dif_qubits, dif_values):
            matches = int(state[qubit]) == value
            if not matches:
                break
        if matches:
            T_prime.append(state)
    return T_prime

In [12]:
def translate_circuit(full_circuit: list[QuantumProgram], state: List[int]) -> QuantumCircuit:
    """ Translates the circuit from a list of QuantumProgram of Classiq to a QuantumCircuit in Qiskit.
        It reverses the order of gates while inversing them and adding the required X gates (from Algorithm 2)

    Args:
        full_circuit (list[QuantumProgram]): List of Classiq QuantumPrograms to concatenate
        state (List[int]): List of amplitudes required for the state

    Returns:
        QuantumCircuit: Final Qiskit circuit for sparse state preparation
    """

    nb_qubits : int = ceil(log2(len(state)))
    gates = []

    for circ in full_circuit:

        for gate in circ.debug_info:
            if "XGate" in gate.name:
                target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in gate.registers if reg.name == "TARGET").qubit_indexes_absolute))
                gates.append((XGate(), target))

            for child in gate.children:
                if child.generated_function.name == "control":
                    ctrl = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == "control_group").qubit_indexes_absolute))
                    target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == "TARGET").qubit_indexes_absolute))
                    strip_and_replace = lambda x: complex(x.replace(" ", "").replace("*I", "j").strip(']['))
                    i = np.array(list(map(strip_and_replace, child.parameters[0].value.strip('][').split(',')))).reshape(2, 2).tolist()

                    UC = UnitaryGate(i).control(1)
                    gates.append((UC, [ctrl, target]))

                elif "CXGate" in child.generated_function.name:
                    ctrl = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == "CTRL").qubit_indexes_absolute))
                    target = list(map(lambda x: nb_qubits - 1 - x, next(reg for reg in child.registers if reg.name == "TARGET").qubit_indexes_absolute))

                    gates.append((CXGate(), [ctrl, target]))

    T = state_to_bitwise(state)
    for i in range(len(T[0])):
        if T[0][i] == 1:
            gates.append((XGate(), [i]))

    # reverse QC and inverse gates
    rev_qc = QuantumCircuit(nb_qubits)

    for i in range(len(gates) - 1, -1, -1):
        gate, trgts = gates[i]
        if isinstance(gate, ControlledGate) and not isinstance(gate, CXGate):
            rev_qc.append(Operator(gate.inverse()), trgts)
        else:
            rev_qc.append(gate.inverse(), trgts)

    return rev_qc

In [13]:
def algo1_classic_part(state : list[float]) -> Tuple[int, list[int], list[int], list[int]]:
    """ Classical part of the algorithm (not requiring any Classiq) for easier understanding
        Basically lines 1 to 28 of Algorithm 1
        Called by main

    Args:
        state (list[float]): State we want to prepare

    Returns:
        Tuple[int, list[int], list[int], list[int]]: dif, x1, x2, dif_qubits as they are needed for the quantum part of the algorithm
    """
    dif_qubits = []
    dif_values = []
    T = state_to_bitwise(state)
    T_copy = T.copy()
    while len(T) > 1:
        bit_nb, T0, T1 = optimal_split(T)
        dif_qubits.append(bit_nb)
        if (len(T0) < len(T1)):
            T = T0
            dif_values.append(0)
        else:
            T = T1
            dif_values.append(1)
    
    dif = dif_qubits.pop()
    dif_values.pop()
    x1 = T[0]
    T_prime = build_T_prime(T_copy, dif_qubits, dif_values)
    T_prime.remove(x1)

    while len(T_prime) > 1:
        bit_nb, T0, T1 = optimal_split(T_prime)
        dif_qubits.append(bit_nb)
        if (len(T0) < len(T1)):
            T_prime = T0
            dif_values.append(0)
        else:
            T_prime = T1
            dif_values.append(1)
    
    x2 = T_prime[0]
    return (dif, x1, x2, dif_qubits)

In [14]:
@qfunc(generative=True)
def algo1(quantum_circuit : QArray[QBit], dif : int, x1 : list[int], x2 : list[int], dif_qubits : list[int], state : list[float]):
    """ Quantum part of Algorithm 1
        Called by main

    Args:
        quantum_circuit (QArray[QBit]): Quantum circuit as required by Classiq
        dif (int): last value appended to dif_qubit
        x1 (list[int]): Single element of T
        x2 (list[int]): Single element of T'
        dif_qubits (list[int]): Stack of bits
        state (list[float]): State to prepare /!\ Global variable
    """

    qdif=quantum_circuit[dif]
    size_n = len(x1)

    if x1[dif] != 1:
        X(qdif)

    for b in range(size_n):
        if b != dif and x1[b] != x2[b]:
            CX(ctrl=qdif, target=quantum_circuit[b])

    for b in dif_qubits:
        if x2[b] != 1:
            X(quantum_circuit[b])

    if dif == 0:
        target_qbit : QArray[QBit] = QArray("traget_qbit", length=1)
        control_group : QArray[QBit] = QArray("control_group",length=size_n-1)
        bind(quantum_circuit, [target_qbit, control_group])

        control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))

        bind([target_qbit, control_group], quantum_circuit)
    elif dif == size_n - 1:
        target_qbit : QArray[QBit] = QArray("traget_qbit", length=1)
        control_group : QArray[QBit] = QArray("control_group",length=size_n-1)
        bind(quantum_circuit, [control_group, target_qbit])

        control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))

        bind([control_group, target_qbit], quantum_circuit)
    else:
        target_qbit : QArray[QBit] = QArray("traget_qbit", length=1)
        before_target : QArray[QBit] = QArray("before",length=max(0,dif-1))
        after_target : QArray[QBit] = QArray("after", length=size_n-dif-1)
        bind(quantum_circuit, [before_target, target_qbit, after_target])
        control_group : QArray[QBit] = QArray("control_group",length=size_n-1)
        bind([before_target,after_target], control_group)

        control(ctrl=control_group, stmt_block=lambda: unitary(gate_matrix(state[bitwise_to_int(x1)], state[bitwise_to_int(x2)]), target=target_qbit))

        bind(control_group, [before_target, after_target])
        bind([before_target, target_qbit, after_target], quantum_circuit)

  """ Quantum part of Algorithm 1


In [15]:
@qfunc
def main(quantum_circuit: Output[QArray[QBit]]) -> None:
    """ Basically Algorithm 2 but it **needs** to be called main
        The while loop will call this every iteration

    Args:
        quantum_circuit (Output[QArray[QBit]]): Quantum circuit on which the gates will be added
    """
    T = state_to_bitwise(state)
    size_n = len(T[0])
    allocate(size_n, quantum_circuit)
    if len(T)>1:
        dif, x1, x2, dif_qubits = algo1_classic_part(state)
        algo1(quantum_circuit, dif, x1, x2, dif_qubits, state)
        kept_state = bitwise_to_int(x1)
        merged_state = bitwise_to_int(x2)
        state[kept_state] += state[merged_state]
        state[merged_state] = 0.

In [16]:
def algo1_iter() -> SerializedQuantumProgram :
    """ Driver code for an iteration of algo1 to leverage Classiq

    Returns:
        SerializedQuantumProgram: Synthetized program from Classiq
    """
    quantum_program = create_model(main)
    backend_preferences = ClassiqBackendPreferences(backend_name="simulator_statevector")
    qmod_b_load = set_execution_preferences(
        quantum_program,
        execution_preferences=ExecutionPreferences(
            num_shots=1, backend_preferences=backend_preferences
        ),
    )
    qprog_b_load = synthesize(qmod_b_load)
    return qprog_b_load

# State to prepare

In [18]:
# BELL STATE
# state = [0.5, 0, 0, 0.5]
# GHZ STATE 4 Qbits
# state : List[float] = [0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5]
state = [0 for _ in range(8)]
state[1] = 2 / np.sqrt(168)
state[4] = 8 / np.sqrt(168)
state[7] = 10 / np.sqrt(168)

In [19]:
T = state_to_bitwise(state)
full_circuit : List[QuantumProgram] = []
while len(T) > 1:
    qprog_b_load = algo1_iter()
    full_circuit.insert(0, QuantumProgram.from_qprog(qprog_b_load))
    T = state_to_bitwise(state)
qprog_b_load = algo1_iter()
full_circuit.append(QuantumProgram.from_qprog(qprog_b_load))

In [22]:
rev_qc = translate_circuit(full_circuit, state)
rev_qc.measure_all()
bc = qiskit_aer.Aer.get_backend("aer_simulator")
tqc = qiskit.transpile(rev_qc, bc)
job = bc.run(tqc, shots=10000)
print(job.result().get_counts())
rev_qc.draw()

{'100': 1899, '110': 4008, '001': 4093}


In [23]:
qasm_str = dumps(rev_qc)
with open("circuit.qasm", "w") as text_file:
    text_file.write(qasm_str)