In [1]:
from qiskit import QuantumCircuit


Reference paper: <https://arxiv.org/html/2408.01304v1> (arXiv:2408.01304)

In [478]:
import numpy as np
from qiskit.circuit.library import MCXGate, MCMTGate, XGate
from qiskit.circuit import Gate
from itertools import product
from qiskit.quantum_info import Statevector

class MTGate(Gate):
    """Apply the same single-qubit gate to every qubit of a register ("multi-target").

    ``SelectNetwork`` uses this in place of ``MCMTGate`` when it has no
    control qubits (presumably because ``MCMTGate`` needs at least one control).

    Args:
        gate: single-qubit gate applied once to each target qubit.
        num_target_qubits: number of qubits the composite gate acts on.

    Raises:
        ValueError: if ``gate`` does not act on exactly one qubit.
    """
    def __init__(
        self,
        gate,
        num_target_qubits: int
        ):
        # Fail fast: a multi-qubit gate would otherwise only error when the
        # definition is built (e.g. on ``decompose()``), far from the cause.
        gate_width = getattr(gate, 'num_qubits', None)
        if gate_width != 1:
            raise ValueError(f"MTGate requires a single-qubit gate, got {gate_width} qubits")
        super().__init__('MT', num_target_qubits, [])
        self.gate = gate
        self.num_target_qubits = num_target_qubits

    def _define(self):
        qc = QuantumCircuit(self.num_target_qubits, name='MT')
        # One independent copy of the gate per qubit.
        for qubit in range(self.num_target_qubits):
            qc.append(self.gate, [qubit])
        self.definition = qc

class CMXGate(Gate):
    """CNOT fan-out: flip qubits ``1..n`` when qubit 0 is |1>.

    Acts on ``num_target_qubits + 1`` qubits; qubit 0 is the control. The
    definition uses ``2n - 1`` CNOTs, each between neighbouring indices
    ``(k, k + 1)``:
      1. a ladder down to qubit 1;
      2. ``cx(0, 1)``;
      3. the ladder back up.
    Tracing the XORs, each target bit ``b_k`` ends as ``b_k ^ a`` (``a`` = control
    bit). That is the same action as ``n`` CNOTs ``cx(0, k)``, but using only
    adjacent pairs (presumably for nearest-neighbour connectivity).

    Args:
        num_target_qubits: number of target qubits ``n`` (must be >= 1).

    Raises:
        ValueError: if ``num_target_qubits < 1``.
    """
    def __init__(
        self,
        num_target_qubits: int
        ):
        # The definition always contains cx(0, 1), so at least one target is needed.
        if num_target_qubits < 1:
            raise ValueError(f"CMXGate needs at least one target qubit, got {num_target_qubits}")
        super().__init__('CMT', num_target_qubits + 1, [])
        self.num_target_qubits = num_target_qubits

    def _define(self):
        n = self.num_target_qubits
        qc = QuantumCircuit(n + 1, name='CMT')

        # Ladder down: cx(n-1, n), ..., cx(1, 2) turns b_{k+1} into b_{k+1} ^ b_k.
        for k in range(n - 1, 0, -1):
            qc.cx(k, k + 1)

        qc.cx(0, 1)

        # Ladder up: cx(1, 2), ..., cx(n-1, n) cancels the b_k terms, leaving b_k ^ a.
        for k in range(1, n):
            qc.cx(k, k + 1)

        self.definition = qc


def binary_combinations(n):
    """Return every length-``n`` bit string, in ascending numeric order.

    The left-most character is the most significant bit, so entry ``k`` is the
    ``n``-digit binary form of ``k`` (e.g. ``n=2`` -> ``['00', '01', '10', '11']``).
    ``n == 0`` yields ``['']``.
    """
    bit_tuples = product('01', repeat=n)
    return [''.join(bits) for bits in bit_tuples]

class SelectGate(Gate):
    """Table-lookup (SELECT) gate built from multi-controlled X gates.

    Layout:
      * address register: the first ``log2(len(target_state))`` qubits;
      * data register: the remaining ``max(target_state).bit_length()`` qubits.

    For each address ``i``, an X is applied to every data qubit whose bit is set
    in ``target_state[i]``, controlled on the address register equalling ``i``.
    A data register starting in |0...0> therefore ends holding
    ``target_state[i]``.

    Bit order:
      * data qubit ``num_ctrl_qubits + b`` holds bit ``b`` of the value (LSB first);
      * the address is matched with ``ctrl_state = binary_combinations(...)[i]``
        (MSB-first string). Per Qiskit's ``ctrl_state`` convention the
        right-most character maps to control qubit 0, so qubit 0 is the address LSB.

    A single-entry table has no address qubits; ``MTGate`` is then used instead
    of ``MCMTGate``, mirroring ``SelectNetwork``.

    Args:
        target_state: non-negative integers, one per address. Its length must be
            a power of two.

    Raises:
        ValueError: if the length is not a power of two, or an entry is not a
            non-negative integer.
    """
    def __init__(
        self,
        target_state: list
    ) -> None:
        
        self.target_state = target_state
        state_length = len(self.target_state)
        if state_length <= 0 or (state_length & (state_length - 1)) != 0:
            raise ValueError(f"Number of state must equal to 2^n, got {state_length}")
        # Negative values would be mangled by bin() ('-0b101'[2:] == 'b101').
        for value in self.target_state:
            if not isinstance(value, (int, np.integer)) or value < 0:
                raise ValueError(f"target_state entries must be non-negative integers, got {value!r}")

        # Exact log2 of a power of two.
        self.num_ctrl_qubits = state_length.bit_length() - 1
        # int() so numpy integer entries (which lack .bit_length()) also work.
        self.num_target_qubits = int(max(self.target_state)).bit_length()

        super().__init__("Select", self.num_ctrl_qubits + self.num_target_qubits, [])

    def _define(self):
        combinations = binary_combinations(self.num_ctrl_qubits)
        qc = QuantumCircuit(self.num_qubits, name='Select')
        ctrl_qubits = list(range(self.num_ctrl_qubits))

        for i, ctrl_state in enumerate(combinations):
            target_value = int(self.target_state[i])
            # Data qubits for every set bit, LSB on the lowest data qubit.
            target_qubits = [
                self.num_ctrl_qubits + bit
                for bit, digit in enumerate(bin(target_value)[2:][::-1])
                if digit == '1'
            ]
            if not target_qubits:
                continue  # value 0: nothing to write for this address

            if self.num_ctrl_qubits == 0:
                # Single-entry table: there is no address to condition on.
                mx_gate = MTGate(gate=XGate(), num_target_qubits=len(target_qubits))
            else:
                mx_gate = MCMTGate(
                    gate=XGate(),
                    num_ctrl_qubits=self.num_ctrl_qubits,
                    num_target_qubits=len(target_qubits),
                    ctrl_state=ctrl_state,
                )

            qc.append(mx_gate, ctrl_qubits + target_qubits)

        self.definition = qc

class SelectNetwork(Gate):
    """SELECT stage that loads ``lamb`` table words per address in parallel.

    ``target_state`` is split into consecutive groups of ``lamb`` words. Group
    ``i`` is written into ``lamb`` separate data blocks, controlled on the
    address equalling ``i``; word ``j`` of the group goes to block ``j``.

    Layout:
      * qubits ``0 .. num_ctrl_qubits - 1``: the address register, with
        ``num_ctrl_qubits = log2(len(target_state) / lamb)``. Qubit 0 is the LSB
        per Qiskit's ``ctrl_state`` convention, since the strings come from
        ``binary_combinations`` (MSB first).
      * then ``lamb`` blocks of ``num_target_qubits`` qubits. Bit ``b`` of word
        ``j`` lives at ``num_ctrl_qubits + j * num_target_qubits + b``.

    With a single group (no address qubits), an unconditional ``MTGate`` is used.

    Args:
        lamb: words loaded per address (>= 1); must divide ``len(target_state)``.
        target_state: non-negative integers; ``len(target_state) / lamb`` must be
            a power of two.
        barrier: insert a barrier between the per-address blocks (for drawing).

    Raises:
        ValueError: on an invalid ``lamb``, table length, or entry.
    """
    def __init__(
        self,
        lamb: int,
        target_state: list,
        barrier: bool = False
    ) -> None:
        
        self.lamb = lamb
        self.target_state = target_state
        self.barrier = barrier

        # Guard before the modulo below (lamb == 0 would raise ZeroDivisionError).
        if self.lamb < 1:
            raise ValueError(f"lambda must be >= 1, got {self.lamb}")

        state_length = len(self.target_state)

        if state_length % self.lamb != 0:
            raise ValueError(f"Number of state {state_length} must be divisible by lambda {self.lamb}")
        state_length_dl = int(state_length // self.lamb)
        if state_length_dl <= 0 or (state_length_dl & (state_length_dl - 1)) != 0:
            raise ValueError(f"Number of state {state_length}/{self.lamb}={state_length_dl} must equal to 2^n")
        # Negative values would be mangled by bin() ('-0b101'[2:] == 'b101').
        for value in self.target_state:
            if not isinstance(value, (int, np.integer)) or value < 0:
                raise ValueError(f"target_state entries must be non-negative integers, got {value!r}")

        # Exact log2 of a power of two.
        self.num_ctrl_qubits = state_length_dl.bit_length() - 1
        # int() so numpy integer entries (which lack .bit_length()) also work.
        self.num_target_qubits = int(max(self.target_state)).bit_length()
        self.total_num_target_qubits = self.num_target_qubits * self.lamb

        super().__init__("Select", self.num_ctrl_qubits + self.total_num_target_qubits, [])

    def _define(self):
        combinations = binary_combinations(self.num_ctrl_qubits)
        qc = QuantumCircuit(self.num_qubits, name='Select')
        ctrl_qubits = list(range(self.num_ctrl_qubits))
        
        for i, ctrl_state in enumerate(combinations):
            # Collect the data qubits to flip for all lamb words of group i.
            target_qubits = []
            for j in range(self.lamb):
                target_value = int(self.target_state[i*self.lamb + j])
                block_offset = self.num_ctrl_qubits + j*self.num_target_qubits
                target_qubits += [
                    block_offset + bit
                    for bit, digit in enumerate(bin(target_value)[2:][::-1])
                    if digit == '1'
                ]
            
            if not target_qubits:
                continue  # every word in this group is 0
            
            if self.num_ctrl_qubits == 0:
                mx_gate = MTGate(gate=XGate(), num_target_qubits=len(target_qubits))
            else:
                mx_gate = MCMTGate(
                    gate=XGate(),
                    num_ctrl_qubits=self.num_ctrl_qubits,
                    num_target_qubits=len(target_qubits),
                    ctrl_state=ctrl_state,
                )
                
            qc.append(mx_gate, ctrl_qubits + target_qubits)
            # Visual separator between address blocks, but not after the last one.
            if self.barrier and i != len(combinations) - 1:
                qc.barrier()
                
        self.definition = qc

class LambdaSwapNetwork(Gate):
    """SWAP stage: move data block ``i`` into block 0 when the address equals ``i``.

    Layout:
      * ``num_ctrl_qubits = ceil(log2(lamb))`` address qubits;
      * one ancilla at index ``num_ctrl_qubits`` (assumed to start in |0>);
      * ``lamb`` blocks of ``num_target_qubits`` qubits, starting at
        ``num_ctrl_qubits + 1``.

    For each ``i`` in ``1 .. lamb - 1``:
      1. X the ancilla, controlled on the address equalling ``i``;
      2. CSWAP block 0 with block ``i`` qubit by qubit, controlled on the ancilla;
      3. X the ancilla again to uncompute it.

    Address 0 needs no swap. Address values ``>= lamb`` (possible when ``lamb``
    is not a power of two) trigger no swap.

    Args:
        lamb: number of data blocks (>= 1); need not be a power of two.
        num_target_qubits: qubits per data block.

    Raises:
        ValueError: if ``lamb < 1``.
    """
    def __init__(
        self,
        lamb: int,
        num_target_qubits: int
    ) -> None:
        if lamb < 1:
            raise ValueError(f"lambda must be >= 1, got {lamb}")
        self.lamb = lamb
        # Exact ceil(log2(lamb)). Rounding is not enough: round(log2(5)) == 2
        # bits cannot address 5 blocks, and combinations[i] would go out of range.
        self.num_ctrl_qubits = int(lamb - 1).bit_length()
        self.num_target_qubits = num_target_qubits
        self.total_num_target_qubits = self.num_target_qubits * self.lamb
        self.num_ancilla_qubit = 1
        num_qubits = self.num_ctrl_qubits + self.num_ancilla_qubit + self.total_num_target_qubits
        super().__init__("Swap", num_qubits, [])

    def _define(self):
        combinations = binary_combinations(self.num_ctrl_qubits)
        qc = QuantumCircuit(self.num_qubits, name='Swap')
        ctrl_qubits = list(range(self.num_ctrl_qubits))
        ancilla = self.num_ctrl_qubits
        first_block = self.num_ctrl_qubits + 1  # first qubit of block 0

        # Address 0 already has its word in block 0, so start at 1.
        for i in range(1, self.lamb):
            mcx_gate = MCMTGate(
                gate=XGate(),
                num_ctrl_qubits=self.num_ctrl_qubits,
                num_target_qubits=1,
                ctrl_state=combinations[i],
            )
            qc.append(mcx_gate, ctrl_qubits + [ancilla])  # ancilla ^= [address == i]
            for j in range(self.num_target_qubits):
                qc.cswap(ancilla, first_block + j, first_block + j + i*self.num_target_qubits)
            qc.append(mcx_gate, ctrl_qubits + [ancilla])  # uncompute the ancilla

        self.definition = qc

In [479]:
# CNOT fan-out from qubit 0 onto qubits 1-3; qubits 4-5 stay idle.
qc = QuantumCircuit(6)
qc.append(CMXGate(num_target_qubits=3), range(4))
qc.decompose().draw()

In [480]:
# MTGate(X, 3) placed on qubits 0, 1 and 4 decomposes into one X per qubit.
qc = QuantumCircuit(6)
qc.append(MTGate(gate=XGate(), num_target_qubits=3), [0, 1, 4])
qc.decompose().draw()

In [481]:
# Look up the 4-entry table [0, 1, 5, 8]. Setting qubit 0 selects address 1
# (qubit 0 is the address LSB), so the data register should end holding 1.
select_gate_1 = SelectGate(target_state=[0, 1, 5, 8])
qc = QuantumCircuit(select_gate_1.num_qubits)
qc.x(0)
qc.append(select_gate_1, qc.qubits)
qc.decompose().draw()

In [482]:
# Width = 2 address qubits (4 entries) + 4 data qubits (8 needs 4 bits).
select_gate_1.num_ctrl_qubits + select_gate_1.num_target_qubits

6

In [483]:
# lamb=3 over 6 entries -> 2 groups: 1 address qubit + 3 blocks of 3 data qubits.
# Qubit 0 selects address 1, i.e. words 3, 4, 5 loaded into blocks 0, 1, 2.
select_gate_1 = SelectNetwork(lamb=3, target_state=[0, 1, 2, 3, 4, 5], barrier=True)
qc = QuantumCircuit(select_gate_1.num_qubits)
qc.x(0)
qc.barrier()
qc.append(select_gate_1, qc.qubits)
qc.decompose().draw()

In [484]:
# Build the gate inline: swap_gate_1 is only defined in the next cell, so
# referencing it here raises NameError on Restart & Run All.
# Width = 1 address qubit + 1 ancilla + 2 blocks x 2 qubits = 6.
LambdaSwapNetwork(lamb=2, num_target_qubits=2).num_qubits

6

In [485]:
# lamb=2 blocks of 2 qubits: q0 address, q1 ancilla, q2-q3 block 0, q4-q5 block 1.
# Qubit 0 sets address 1 and qubit 5 marks bit 1 of block 1; the network
# should swap it into qubit 3 (bit 1 of block 0).
swap_gate_1 = LambdaSwapNetwork(lamb=2, num_target_qubits=2)
qc = QuantumCircuit(swap_gate_1.num_qubits)
qc.x([0, 5])
qc.append(swap_gate_1, qc.qubits)

qc.decompose().draw()


In [486]:
# Expected: a single basis state with qubits 0 and 3 set and the ancilla back in |0>.
Statevector.from_instruction(qc).draw('latex')

<IPython.core.display.Latex object>