In [57]:
import qiskit
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit import Aer

import numpy as np
from math import floor, ceil

%matplotlib inline

In [2]:
def RTL(qc, a, b, c):
    """Append the relative-phase Toffoli of fig 3 (dashed box) to ``qc``.

    ``a`` and ``b`` are the controls and ``c`` is the target; the gate is the
    circuit's built-in ``rccx``.
    """
    controls = (a, b)
    qc.rccx(*controls, c)
    
def RTL_inv(qc, a, b, c):
    """Append the inverse of ``RTL(qc, a, b, c)``.

    ``RTL`` is a single ``rccx``, which is its own inverse, so undoing it is just
    another ``rccx`` on the same qubits. Keep this in step with ``RTL``.
    """
    qc.rccx(a, b, c)

In [3]:
def RTS(qc, a, b, c):
    """Append gates 2-6 of fig 3 to ``qc``; every gate acts on ``c``.

    Sequence: H, T, CX(b->c), T^dagger, CX(a->c). ``RTS_inv`` appends the
    reverse sequence.
    """
    qc.h(c)
    for phase_gate, control in ((qc.t, b), (qc.tdg, a)):
        phase_gate(c)
        qc.cx(control, c)
    
def RTS_inv(qc, a, b, c):
    """Append the inverse of ``RTS``: CX(a->c), T, CX(b->c), T^dagger, H on ``c``."""
    for control, phase_gate in ((a, qc.t), (b, qc.tdg)):
        qc.cx(control, c)
        phase_gate(c)
    qc.h(c)

In [70]:
def SRTS(qc, a, b, c):
    """Append the dashed sub-circuit of circuit 3 to ``qc``.

    An H on ``c`` is followed by two identical rounds of
    CX(c->b), T^dagger(b), CX(a->b), T(b). All phase gates land on ``b``;
    ``c`` only receives the opening H and otherwise acts as a CNOT control.
    """
    qc.h(c)
    for _ in range(2):
        for control, phase_gate in ((c, qc.tdg), (a, qc.t)):
            qc.cx(control, b)
            phase_gate(b)
    
def SRTS_inv(qc, a, b, c):
    """Append the inverse of ``SRTS``.

    Two rounds of T^dagger(b), CX(a->b), T(b), CX(c->b), then a closing H on ``c``.
    """
    for _ in range(2):
        for phase_gate, control in ((qc.tdg, a), (qc.t, c)):
            phase_gate(b)
            qc.cx(control, b)
    qc.h(c)

In [71]:
def RT4L(qc, a, b, c, d):
    """Append the relative-phase 3-controlled X of fig 4 to ``qc``.

    ``a``, ``b`` and ``c`` are the controls and ``d`` is the target; the gate is
    the circuit's built-in ``rcccx``. Undo it with ``RT4L_inv``, not a second call.
    """
    controls = (a, b, c)
    qc.rcccx(*controls, d)
    
def RT4L_inv(qc, a, b, c, d):
    """Append the inverse of ``RT4L`` as explicit gates, all acting on ``d``.

    Structure: an outer layer H, T, CX(c->d), T^dagger, H; then four
    phase/CNOT pairs T+CX(b), T^dagger+CX(a), T+CX(b), T^dagger+CX(a); then
    the same outer layer again.
    """
    def outer_layer():
        # Shared opening/closing layer on d.
        for gate in (qc.h, qc.t):
            gate(d)
        qc.cx(c, d)
        for gate in (qc.tdg, qc.h):
            gate(d)

    outer_layer()
    for phase_gate, control in ((qc.t, b), (qc.tdg, a), (qc.t, b), (qc.tdg, a)):
        phase_gate(d)
        qc.cx(control, d)
    outer_layer()

In [72]:
def RT4S(qc, a, b, c, d):
    """Append the dashed sub-circuit of fig 4 to ``qc``; every gate targets ``d``.

    Opening layer H, T, CX(c->d), T^dagger, H, followed by
    CX(a->d), T, CX(b->d), T^dagger, CX(a->d). ``RT4S_inv`` appends the reverse.
    """
    # Opening layer, bracketed by Hadamards on d.
    for gate in (qc.h, qc.t):
        gate(d)
    qc.cx(c, d)
    for gate in (qc.tdg, qc.h):
        gate(d)
    # Phase section: CNOTs from a and b interleaved with T / T^dagger.
    qc.cx(a, d)
    for phase_gate, control in ((qc.t, b), (qc.tdg, a)):
        phase_gate(d)
        qc.cx(control, d)
    
def RT4S_inv(qc, a, b, c, d):
    """Append the inverse of ``RT4S``; every gate targets ``d``.

    CX(a->d), T, CX(b->d), T^dagger, CX(a->d), then the closing layer
    H, T, CX(c->d), T^dagger, H.
    """
    for control, phase_gate in ((a, qc.t), (b, qc.tdg)):
        qc.cx(control, d)
        phase_gate(d)
    qc.cx(a, d)
    # Closing layer, bracketed by Hadamards on d.
    for gate in (qc.h, qc.t):
        gate(d)
    qc.cx(c, d)
    for gate in (qc.tdg, qc.h):
        gate(d)

In [184]:
def _rcccx_dg(circuit, a, b, c, d):
    """Append the exact inverse of ``circuit.rcccx(a, b, c, d)`` as elementary gates.

    Unlike ``rccx``, ``rcccx`` is NOT its own inverse: applying it twice leaves a
    -1 phase on the subspace where its first two controls are both 1. Uncomputing
    an ancilla therefore needs this reversed sequence (T and T^dagger swapped).
    It is the same gate list as ``RT4L_inv``, repeated here so this cell is
    self-contained.
    """
    circuit.h(d)
    circuit.t(d)
    circuit.cx(c, d)
    circuit.tdg(d)
    circuit.h(d)
    circuit.t(d)
    circuit.cx(b, d)
    circuit.tdg(d)
    circuit.cx(a, d)
    circuit.t(d)
    circuit.cx(b, d)
    circuit.tdg(d)
    circuit.cx(a, d)
    circuit.h(d)
    circuit.t(d)
    circuit.cx(c, d)
    circuit.tdg(d)
    circuit.h(d)


def apply_mct_clean(self, controls, target, ancilla_register):
    """Append a multi-controlled X on ``target`` using clean ancillas.

    Parameters
    ----------
    self : QuantumCircuit
        Circuit the gates are appended to. (Named ``self`` presumably so the
        function can be bound as a QuantumCircuit method.)
    controls : sequence of qubits
        Control qubits; the target flips iff all of them are 1.
    target : qubit
        Target qubit.
    ancilla_register : sequence of qubits
        Work qubits. At least ``ceil((len(controls) - 2) / 2)`` are required for
        3 or more controls. They must start in |0>. Because every relative-phase
        gate is undone by its exact inverse, they are returned to |0> with no
        leftover phase.

    Raises
    ------
    ValueError
        If ``controls`` is empty or too few ancillas are supplied. The check
        runs before any gate is appended.
    """
    n = len(controls)
    if n == 0:
        raise ValueError("apply_mct_clean needs at least one control qubit")
    if n == 1:
        self.cx(controls[0], target)
        return
    if n == 2:
        self.ccx(controls[0], controls[1], target)
        return

    num_ancillas = ceil((n - 2) / 2)
    if len(ancilla_register) < num_ancillas:
        raise ValueError(
            "{} controls need at least {} clean ancillas, got {}".format(
                n, num_ancillas, len(ancilla_register)))
    ancilla = ancilla_register[:num_ancillas]

    if n == 3:
        # rccx is its own inverse, so the same gate uncomputes ancilla[0].
        self.rccx(controls[0], controls[1], ancilla[0])
        self.ccx(controls[2], ancilla[0], target)
        self.rccx(controls[0], controls[1], ancilla[0])
        return

    if n == 4:
        self.rcccx(controls[0], controls[1], controls[2], ancilla[0])
        self.ccx(controls[3], ancilla[0], target)
        _rcccx_dg(self, controls[0], controls[1], controls[2], ancilla[0])
        return

    # n >= 5. ancilla[0] collects controls 0-2. Each rung then folds two more
    # controls into the next ancilla: ancilla[k+1] ^= c_i & c_{i+1} & ancilla[k].
    # Even n leaves one control for the final ccx. Odd n leaves two, so one extra
    # rccx on the last two ancillas is needed around that ccx.
    ladder_stop = n - 1 if n % 2 == 0 else n - 3
    rungs = [
        (controls[i], controls[i + 1], ancilla[k], ancilla[k + 1])
        for k, i in enumerate(range(3, ladder_stop, 2))
    ]

    self.rcccx(controls[0], controls[1], controls[2], ancilla[0])
    self.barrier()
    for rung in rungs:
        self.rcccx(*rung)
        self.barrier()

    if n % 2 == 0:
        self.ccx(controls[-1], ancilla[-1], target)
        self.barrier()
    else:
        self.rccx(controls[-2], ancilla[-2], ancilla[-1])
        self.barrier()
        self.ccx(controls[-1], ancilla[-1], target)
        self.barrier()
        self.rccx(controls[-2], ancilla[-2], ancilla[-1])
        self.barrier()

    # Uncompute in reverse order with the true inverse of rcccx.
    for rung in reversed(rungs):
        _rcccx_dg(self, *rung)
        self.barrier()
    _rcccx_dg(self, controls[0], controls[1], controls[2], ancilla[0])
    
# Demo: 5 controls need ceil((5 - 2) / 2) = 2 clean (|0>) ancillas plus one target.
qr = QuantumRegister(5, 'qr')
anc = QuantumRegister(2, 'anc')
target = QuantumRegister(1, 'target')

qc = QuantumCircuit(qr, anc, target)
apply_mct_clean(qc, qr, target, anc)

# Optional (disabled) check: print the full unitary. Unlike a measurement-based
# test, the unitary also exposes any relative phases left behind by the circuit.
# backend = Aer.get_backend('unitary_simulator')
# job = qiskit.execute(qc, backend)
# result = job.result()
# print(result.get_unitary(qc, decimals=3))

# Optional: draw the circuit (barriers separate the compute / uncompute stages).
# qc.draw(output='mpl')

In [185]:
def _dirty_ladder_down(circuit, controls, anc):
    """Append the descending half of the dirty-ancilla ladder (top rung first).

    Even ``len(controls)``: rungs are ``RT4S(anc[i], controls[2i+3], controls[2i+4], anc[i+1])``.
    Odd ``len(controls)``: each rung shifts up by one ancilla and one control, and a
    final ``RTS(anc[0], controls[3], anc[1])`` closes the ladder at the bottom.
    A barrier follows every rung.
    """
    n = len(controls)
    shift = (n - 4) % 2  # 1 when an RTS rung is needed at the bottom
    for i in reversed(range((n - 4) // 2)):
        RT4S(circuit, anc[i + shift], controls[2 * i + 3 + shift],
             controls[2 * i + 4 + shift], anc[i + 1 + shift])
        circuit.barrier()
    if shift:
        RTS(circuit, anc[0], controls[3], anc[1])
        circuit.barrier()


def _dirty_ladder_up(circuit, controls, anc):
    """Append the inverse of ``_dirty_ladder_down`` (bottom rung first), barrier after each."""
    n = len(controls)
    shift = (n - 4) % 2
    if shift:
        RTS_inv(circuit, anc[0], controls[3], anc[1])
        circuit.barrier()
    for i in range((n - 4) // 2):
        RT4S_inv(circuit, anc[i + shift], controls[2 * i + 3 + shift],
                 controls[2 * i + 4 + shift], anc[i + 1 + shift])
        circuit.barrier()


def apply_mct_dirty(self, controls, target, ancilla):
    """Append a multi-controlled X on ``target`` using dirty (arbitrary-state) ancillas.

    Parameters
    ----------
    self : QuantumCircuit
        Circuit the gates (and barriers) are appended to.
    controls : sequence of qubits
        Control qubits; the target flips iff all of them are 1.
    target : qubit
        Target qubit.
    ancilla : sequence of qubits
        Work qubits in any state. At least ``ceil((len(controls) - 2) / 2)`` are
        required for 3 or more controls. Each ladder pass is followed by its
        inverse, so the ancillas end in the state they started in.

    Raises
    ------
    ValueError
        If ``controls`` is empty or too few ancillas are supplied. The check
        runs before any gate is appended.
    """
    n = len(controls)
    if n == 0:
        raise ValueError("apply_mct_dirty needs at least one control qubit")
    if n == 1:
        self.cx(controls[0], target)
        return
    if n == 2:
        self.ccx(controls[0], controls[1], target)
        return

    num_ancillas = ceil((n - 2) / 2)
    if len(ancilla) < num_ancillas:
        raise ValueError(
            "{} controls need at least {} ancillas, got {}".format(
                n, num_ancillas, len(ancilla)))

    if n == 3:
        # Toggle ancilla[0] by c0&c1 between the two SRTS halves, then undo the toggle.
        SRTS(self, controls[2], ancilla[0], target)
        RTL(self, controls[0], controls[1], ancilla[0])
        SRTS_inv(self, controls[2], ancilla[0], target)
        RTL_inv(self, controls[0], controls[1], ancilla[0])
        return

    anc = ancilla[:num_ancillas]

    # Same shape as the n == 3 case: anc[-1] is toggled between SRTS and SRTS_inv
    # by a ladder rooted in RT4L, and a second ladder pass using RT4L_inv undoes it.
    SRTS(self, controls[-1], anc[-1], target)
    self.barrier()

    _dirty_ladder_down(self, controls, anc)
    RT4L(self, controls[0], controls[1], controls[2], anc[0])
    self.barrier()
    _dirty_ladder_up(self, controls, anc)

    SRTS_inv(self, controls[-1], anc[-1], target)
    self.barrier()

    _dirty_ladder_down(self, controls, anc)
    RT4L_inv(self, controls[0], controls[1], controls[2], anc[0])
    self.barrier()
    _dirty_ladder_up(self, controls, anc)

In [186]:
# Demo: 3 controls with the dirty-ancilla construction. For 3 controls only
# ancilla[0] is used, so the other two 'anc' qubits are left untouched.
qr = QuantumRegister(3, 'qr')
anc = QuantumRegister(3, 'anc')
target = QuantumRegister(1, 'target')

qc = QuantumCircuit(qr, anc, target)
apply_mct_dirty(qc, qr, target, anc)

# Optional (disabled) check: print the full unitary, which also shows phases.
# backend = Aer.get_backend('unitary_simulator')
# job = qiskit.execute(qc, backend)
# result = job.result()
# print(result.get_unitary(qc, decimals=3))

# Optional: draw the circuit.
# qc.draw(output='mpl')

In [187]:
# Exhaustive classical check of apply_mct_clean for 5..9 controls: the target must
# read 1 exactly when every control is 1. Measuring in the computational basis
# cannot reveal relative-phase errors, so this does not replace a unitary check.
backend = Aer.get_backend('qasm_simulator')  # fetched once, reused for every run
for n in range(5, 10):
    print('{}-bit controls'.format(n))
    for inp in range(2**n):
        qr = QuantumRegister(n, 'qr')
        anc = QuantumRegister(max(ceil((n-2)/2), 1), 'anc')
        target = QuantumRegister(1)
        cr = ClassicalRegister(1, 'cr')
        qc = QuantumCircuit(qr, anc, target, cr)

        # Encode the input: qr[i] is flipped to |1> when bit i of inp is set.
        for i in range(n):
            if (inp & (1 << i)) > 0:
                qc.x(qr[i])

        # Call apply_mct_dirty(qc, qr, target, anc) here instead to check the dirty-ancilla version.
        apply_mct_clean(qc, qr, target, anc)
        qc.barrier()
        qc.measure(target, cr[0])

        job = qiskit.execute(qc, backend, shots=10)
        result = job.result()
        counts = result.get_counts()
        # Fail loudly instead of relying on someone reading the printed lines.
        expected = '1' if inp == 2**n - 1 else '0'
        assert set(counts) == {expected}, (n, inp, counts)
        if '1' in counts:
            print('{} got 1'.format(inp))

5-bit controls
31 got 1
6-bit controls
63 got 1
7-bit controls
127 got 1
8-bit controls
255 got 1
9-bit controls
511 got 1


In [None]:
def apply_mct(circuit, controls, target, anc, mode='clean-ancilla'):
    if len(controls) == 1:
        circuit.cx(controls[0], target)
    else
    if mode == 'clean-ancilla':
        