In [1]:
# plot the full bethe circuit for L=4,M=2. cleaned up version of the code.

from qiskit import *
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
import os
import json

def Theta_matrix(k_i,Jxy,Jz):
    """Return the matrix Theta[i, j] = 2*arctan(Jz*sin((k_i-k_j)/2) / (Jz*cos((k_i-k_j)/2) - Jxy*cos((k_i+k_j)/2))).

    Parameters
    ----------
    k_i : array_like
        Quasi-momenta. Any sequence is accepted (it is converted with ``np.asarray``);
        for a 1-D input of length m the result has shape (m, m).
    Jxy, Jz : float
        Coupling constants entering the formula above.

    Returns
    -------
    numpy.ndarray
        Theta with Theta[i, j] computed from the pair (k_i[i], k_i[j]). The diagonal
        is zero because sin(0) = 0.
    """
    k_j = np.asarray(k_i)          # accept lists/tuples as well as arrays
    k_col = k_j[:, np.newaxis]     # column vector so that k_col - k_j broadcasts to [i, j]
    half_diff = (k_col - k_j) / 2.0
    half_sum = (k_col + k_j) / 2.0
    Theta = 2.0*np.arctan(Jz*np.sin(half_diff)/(Jz*np.cos(half_diff) - Jxy*np.cos(half_sum)))
    return Theta

def DickeMukherjee(circ,n,k):
    """Append the Mukherjee-style Dicke-state circuit on qubits 0..n-1 of ``circ`` (in place).

    The last k qubits are first flipped to |1>. Then two sets of Split & Cyclic
    Shift blocks are applied via SCS_mukherjee (arguments: upper index a, lower index b):
      * first set:  SCS^a_k  for a = n, n-1, ..., k+1
      * second set: SCS^i_{i-1} for i = k, k-1, ..., 2
    """
    dicke_pars = (n, k)
    circ.x(range(n - k, n))

    # first set of SCS: upper index runs from n down to k+1, lower index fixed at k
    for upper in range(n, k, -1):
        SCS_mukherjee(dicke_pars, circ, upper, k)

    # second set of SCS: SCS^i_{i-1} for i = k down to 2
    for upper in range(k, 1, -1):
        SCS_mukherjee(dicke_pars, circ, upper, upper - 1)

def SCS_mukherjee(dicke_pars,circ, a, b): # Split & Cyclic Shift transformation
    """Append the Split & Cyclic Shift block SCS^a_b used by DickeMukherjee (in place).

    Parameters
    ----------
    dicke_pars : tuple
        ``(n, k)`` of the Dicke circuit being built (as passed by DickeMukherjee).
    circ : QuantumCircuit
        Circuit to append gates to.
    a, b : int
        Upper and lower index of SCS^a_b. Qubit indices used below are counted
        from zero (hence the ``- 1`` offsets).

    For the first set of SCS blocks (``b == k``), some gates are skipped or replaced
    by cheaper equivalents. ``t = n - a`` is that block's position in the first set.
    """
    n, k = dicke_pars[0], dicke_pars[1]
    t = n - a # only true for the first set of SCS, but the b == k checks below ensure this

    # the leading mu block is omitted for the first k-1 blocks of the first set
    if not ((b == k) and (t < (k - 1))):
        mu(circ, a)

    loopct = 0 # want to skip the first k-2-t M blocks for the first set of SCS
    for l in range(a - 1, a - b, -1):
        skip = (b == k) and (loopct < (k - 2 - t))
        if not skip:
            if (k > 1) and (t == 0) and (l == (n - k + 1)):
                circ.ry(2*np.arccos(np.sqrt(k/n)), n-k-1) # -1 for counting from zero
                circ.cx(n-k-1, n-1)

            elif l == (n - k + 1): # improvement for first non-identity M
                circ.cx(n-k-1, n-t-1)
                # CU: ry / cx / ry sandwich; identity on qubit n-k-1 when qubit n-t-1 is |0>
                # numerator (k - t) equals the original n-t-(n-k+1)+1
                theta = 2*np.arccos(np.sqrt((k - t)/(n - t)))
                alpha = np.pi - theta
                circ.ry(alpha/2, n-k-1)
                circ.cx(n-t-1, n-k-1)
                circ.ry(-alpha/2, n-k-1)
                # end CU
                circ.cx(n-k-1, n-t-1)

            else:
                M(circ, l, a)
        loopct += 1


def mu(circ,n):
    """Append the two-qubit block acting on qubits n-2 and n-1 (counted from zero).

    Gate sequence: cx(n-2, n-1), then a ry / cx / ry sandwich on qubit n-2 with
    control n-1 and angle alpha = pi - 2*arccos(sqrt(1/n)), then cx(n-2, n-1).
    When qubit n-1 is |0> inside the sandwich, the two ry gates cancel.

    Uses ``cx`` rather than the ``cnot`` alias, which was removed in Qiskit 1.0.
    """
    nn = n - 1 # qubits counted from zero

    circ.cx(nn - 1, nn)

    theta = 2*np.arccos(np.sqrt(1/n))
    alpha = np.pi - theta
    circ.ry(alpha/2, nn - 1)
    circ.cx(nn, nn - 1)
    circ.ry(-alpha/2, nn - 1)

    circ.cx(nn - 1, nn)

def _retarget_last_instruction(circ, new_target):
    """Replace the second qubit (the target of a cx) of the last instruction in ``circ``.

    Mutating ``circ.data[-1][1]`` in place does not work on Qiskit >= 0.21:
    legacy tuple indexing of a CircuitInstruction returns a fresh list copy, so
    the assignment is silently lost. Instead, pop the instruction and re-append it.
    """
    last = circ.data[-1]
    if hasattr(last, "operation"): # Qiskit >= 0.21: CircuitInstruction
        operation, qubits = last.operation, list(last.qubits)
    else: # older Qiskit: (instruction, qargs, cargs) tuple
        operation, qubits = last[0], list(last[1])
    qubits[1] = circ.qregs[0][new_target]
    del circ.data[-1]
    circ.append(operation, qubits)


def M(circ,l,n):
    """Append the three-qubit block on qubits l-2, l-1 and n-1 (counted from zero).

    The angle is th = 2*arccos(sqrt((n-l+1)/n)). The ry(+-th/4) / cx pattern on
    qubit l-2 with controls n-1 and l-1 presumably implements a doubly-controlled
    rotation (TODO confirm).

    With ``docancel`` the leading cx(l-1, l-2) is not appended. Instead, the
    previously appended instruction is retargeted onto qubit l-2. This assumes
    that instruction is cx(l-1, n-1), which holds when M follows mu, another M,
    or the special branch in SCS_mukherjee. Over GF(2):
        cx(l-1,n-1) . cx(l-2,n-1) . cx(l-1,l-2)  ==  cx(l-1,l-2) . cx(l-2,n-1)
    so one CNOT is saved.
    """
    th = 2*np.arccos(np.sqrt((n-l+1)/n))
    nn = n-1
    ll = l-1

    docancel = True
    # adjust last gate for cancellation
    if docancel:
        _retarget_last_instruction(circ, ll-1) # change target

    circ.cx(ll-1, nn)
    if not docancel:
        circ.cx(ll, ll-1)
    circ.ry(-th/4, ll-1)
    circ.cx(nn, ll-1)
    circ.ry(th/4, ll-1)
    circ.cx(ll, ll-1)
    circ.ry(-th/4, ll-1)
    circ.cx(nn, ll-1)
    circ.ry(th/4, ll-1)
    circ.cx(ll-1, nn)


def aswap(acircuit,i,j,theta,phi):
    """Append the parametrised two-qubit 'aswap' gate on qubits i and j (in place).

    The gate is a CNOT sandwich. Qubit j is rotated by rz/ry angles shifted by
    pi and pi/2 before the middle cx(i, j), and rotated back afterwards.
    """
    z_angle = phi + np.pi
    y_angle = theta + np.pi/2

    acircuit.cx(j, i)
    # rotate j into the frame set by (theta, phi) ...
    acircuit.rz(-z_angle, j)
    acircuit.ry(-y_angle, j)
    acircuit.cx(i, j)
    # ... and rotate back out of it
    acircuit.ry(y_angle, j)
    acircuit.rz(z_angle, j)
    acircuit.cx(j, i)

def permutation_superposition(circ,M,startqb):
    """Append the permutation-label circuit on an M x M grid of qubits (in place).

    The label qubits form a row-major grid starting at ``startqb``: row r holds
    qubits startqb + r*M ... startqb + r*M + M - 1. For each column c, the
    diagonal qubit (c, c) is flipped. Then, walking rows c-1 down to 0, an aswap
    with angle arccos(sqrt(1/(r+2))) is applied between rows r and r+1 of column c,
    followed by controlled swaps on the earlier columns. Presumably this builds a
    superposition of one-hot permutation labels (TODO confirm).
    """
    grid = [list(range(startqb + row*M, startqb + (row + 1)*M)) for row in range(M)]

    for col in range(M):
        circ.x(grid[col][col])
        for row in reversed(range(col)):
            pivot = grid[row][col]
            aswap(circ, pivot, grid[row + 1][col], np.arccos(np.sqrt(1/(row + 2))), 0.0)
            for prev_col in range(col):
                circ.cswap(pivot, grid[row + 1][prev_col], grid[row][prev_col])

def permutation_superposition_phases(circ,M,L,Theta):
    """Same gate pattern as permutation_superposition (grid starting at qubit L), plus phases.

    After the controlled swaps for grid cell (row, col), a controlled phase
    Theta[col, prev_col] + pi is applied between qubit (row+1, prev_col) and
    qubit (row, col) for every earlier column prev_col. The extra pi adds the -1
    for each transposition (the signature of the permutation).

    Prints the label grid and the first qubit index after it (debug output).
    """
    grid = [list(range(L + row*M, L + (row + 1)*M)) for row in range(M)]
    print(grid, L + M*M)

    for col in range(M):
        circ.x(grid[col][col])
        for row in reversed(range(col)):
            pivot = grid[row][col]
            aswap(circ, pivot, grid[row + 1][col], np.arccos(np.sqrt(1/(row + 2))), 0.0)
            for prev_col in range(col):
                circ.cswap(pivot, grid[row + 1][prev_col], grid[row][prev_col])
            # (qubit - L) % M is the column index, so Theta is indexed by (col, prev_col)
            for prev_col in range(col):
                circ.cp(Theta[col, prev_col] + np.pi, grid[row + 1][prev_col], pivot)

def eikPx_phases_faucetmethod(circ,L,N,k_i,sysqbs,permlabelqbs,expikPxqbs,workexpikPxqbs):
    """Append phase gates controlled by a chain of 'faucet' ancillas (in place).

    All expikPx ancillas start in |1>. For each site ``sitenum`` in 0..L-1:
      1. Ancillas may be flipped, controlled on ``sysqbs[sitenum]``. Ancilla f is
         flipped iff the site qubit is |1>, ancilla f-1 is |0> (no such condition
         for f = 0) and ancilla f+1 is |1> (no such condition for f = N-1).
         Ancillas are processed in the order last, middle (descending), first.
         The site-range guards skip ancillas that cannot change at this site.
      2. For every (downspinnum, labelnum), a cp(Re k_i[labelnum]) is applied
         between permlabelqbs[downspinnum*N + labelnum] and expikPxqbs[downspinnum].

    Presumably this accumulates exp(i k x) phases for each down spin while its
    faucet is still on (TODO confirm).

    Requirements visible from the indexing: N >= 2 (expikPxqbs[-2] and
    expikPxqbs[1] are used), len(expikPxqbs) >= N, len(permlabelqbs) >= N*N,
    len(sysqbs) >= L, and at least one work qubit (workexpikPxqbs[0]).
    """
    # initially set the expikPx qubits to 1 to accrue phases
    for qb in expikPxqbs:
        circ.x(qb)

    for sitenum in range(L):
        # turn off faucet ancillas as need be
        # last one: flipped when the site qubit is |1> and ancilla [-2] is |0> (X-conjugated control)
        if sitenum < N-1: # this faucet can not have been turned off yet
            pass
        else:
            circ.x(expikPxqbs[-2])
            circ.ccx(sysqbs[sitenum],expikPxqbs[-2],expikPxqbs[-1])
            circ.x(expikPxqbs[-2])
        # middle ones: AND of (ancilla f-1 is |0>, ancilla f+1 is |1>) is computed into the
        # work qubit, used together with the site qubit to flip ancilla f, then uncomputed
        for faucetancilla in range(N-2,0,-1): # count backwards
            if faucetancilla > sitenum: # impossible for this ancilla to be turned off at this point in the process
                continue
            elif (L-(N-faucetancilla)) < sitenum: # impossible for this ancilla to still be on
                continue
            else:
                circ.x(expikPxqbs[faucetancilla-1])
                circ.ccx(expikPxqbs[faucetancilla-1],expikPxqbs[faucetancilla+1],workexpikPxqbs[0])
                circ.ccx(workexpikPxqbs[0],sysqbs[sitenum],expikPxqbs[faucetancilla])
                circ.ccx(expikPxqbs[faucetancilla-1],expikPxqbs[faucetancilla+1],workexpikPxqbs[0])
                circ.x(expikPxqbs[faucetancilla-1])
        # first one: flipped when the site qubit is |1> and ancilla 1 is |1>
        if sitenum > (L-N): # first ancilla guaranteed to be off
            pass
        else:
            circ.ccx(sysqbs[sitenum],expikPxqbs[1],expikPxqbs[0])

        # apply the pieces of exp(ikP) phase for this site
        # (only the real part of each k_i entry is used as the phase angle)
        for downspinnum in range(N):
            for labelnum in range(N):
                circ.cp(k_i[labelnum].real,permlabelqbs[downspinnum*N + labelnum],expikPxqbs[downspinnum])


#def main():
# Build the full Bethe circuit for L sites and N down spins, draw it and save it to PDF.
st = time.time()
np.set_printoptions(linewidth=200)
L=4
N=2
Jxy=2.0
Jz=-1.0
k_i = np.array([1.14676529, 3.56562369])
print('k_i values',k_i)
Theta_mat = Theta_matrix(k_i=k_i,Jxy=Jxy,Jz=Jz)
print(Theta_mat)
nqbspermlabel = N**2
print('nqbspermlabel',nqbspermlabel)
nqbsexpikPx = N
nqbsworkexpikPx = 1 # only use 1 work qubit
print('nqbsexpikPx',nqbsexpikPx)
print('nqbsworkexpikPx',nqbsworkexpikPx)
nqbs = L+nqbspermlabel+nqbsexpikPx+nqbsworkexpikPx
print('nqbs',nqbs)
# qubit layout: [system | permutation labels (N x N) | faucet ancillas | work qubit]
sysqbs = list(range(L))
permlabelqbs = list(range(L,L+nqbspermlabel))
expikPxqbs = list(range(L+nqbspermlabel,L+nqbspermlabel+nqbsexpikPx))
workexpikPxqbs = list(range(L+nqbspermlabel+nqbsexpikPx,nqbs))
print('sysqbs',sysqbs)
print('permlabelqbs',permlabelqbs)
print('expikPxqbs',expikPxqbs)
print('workexpikPxqbs',workexpikPxqbs)
ntrials = 1
# inverse of the (phase-free) permutation superposition, used to unprepare the label register
reverse_perm = QuantumCircuit(nqbs,nqbspermlabel)
permutation_superposition(reverse_perm,N,L)
reverse_perm = reverse_perm.inverse()
circuit = QuantumCircuit(nqbs,nqbspermlabel)
DickeMukherjee(circuit,L,N)
permutation_superposition_phases(circuit,N,L,Theta_mat)
eikPx_phases_faucetmethod(circuit,L,N,k_i,sysqbs,permlabelqbs,expikPxqbs,workexpikPxqbs)
circuit = circuit.compose(reverse_perm)
measurecirc = QuantumCircuit(nqbs,nqbspermlabel)
measurecirc.measure(permlabelqbs,range(N**2))
fullcircuit = circuit.compose(measurecirc)
dispcolors = {'cx':('#426299', '#000000'),
              'cswap':('#705399', '#000000'),
              'cp':('#CFFF8B', '#000000')}

outfilename = 'plot_full_bethecircuit_L4M2_test.pdf'

circuit_diagram = fullcircuit.draw(output='mpl',style={'displaycolor':dispcolors}, filename='my_circuit.pdf')

mymeta = {}
# NOTE(review): hard-coded path; basename() yields the directory name 'Notebooks', not a script name
mymeta['plotscript'] = os.path.basename('/home/alberto/Desktop/Notebooks')
mymeta['myfilename'] = outfilename
# Save the figure returned by draw(). plt.savefig() would save pyplot's *current* figure, which
# here is a fresh empty one (the cell output showed "<Figure size 640x480 with 0 Axes>").
circuit_diagram.savefig(outfilename,metadata={"Creator": json.dumps(mymeta)})
et = time.time()
print('time = ',(et-st)/60," min.")

#if __name__ == '__main__':
#    main()


k_i values [1.14676529 3.56562369]
[[ 0.         1.4454685]
 [-1.4454685 -0.       ]]
nqbspermlabel 4
nqbsexpikPx 2
nqbsworkexpikPx 1
nqbs 11
sysqbs [0, 1, 2, 3]
permlabelqbs [4, 5, 6, 7]
expikPxqbs [8, 9]
workexpikPxqbs [10]
[[4, 5], [6, 7]] 8
time =  0.03265887896219889  min.


<Figure size 640x480 with 0 Axes>