In [1]:
!pip install jax
!pip install jaxlib
import numpy as np
from numpy import linalg as LA
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from scipy.linalg import sqrtm
from numpy.linalg import matrix_rank
from jax.numpy import linalg as jla
from scipy.linalg import logm, expm

from scipy.linalg import block_diag
import copy
import numba
from numba import int64, float64, jit, njit, vectorize
import matplotlib.pyplot as plt



In [2]:
def qubit(x):
  """Return the single-qubit computational basis vector selected by x.

  x equal to 0 gives [1, 0] and x equal to 1 gives [0, 1]. Any other value
  falls through every comparison and the function returns None.
  """
  for label, amplitudes in ((0, [1,0]), (1, [0,1])):
    if x==label:
      return np.array(amplitudes)

In [3]:
def pauli(x):
    """Return the 2x2 Pauli matrix with index x (0: identity, 1: X, 2: Y, 3: Z).

    For any other x the function prints 'invalid' and returns None.
    """
    matrices = (
        [[1,0],[0,1]],
        [[0,1],[1,0]],
        [[0,-1j],[+1j,0]],
        [[1,0],[0,-1]],
    )
    for index, entries in enumerate(matrices):
        if x==index:
            return np.array(entries)
    print('invalid')

# modified rho based on delta, gamma, and input x rather than theta and p
def rhom(delta,g,x):
  """Build the 2x2 matrix [[delta, g], [g, 1-delta]] conjugated by pauli(x).

  With x=0 the matrix is returned unchanged; with x=1 the two basis states
  are swapped on both sides.
  """
  conjugator = pauli(x)
  base = np.array([[delta, g],[g, 1-delta]])
  return conjugator@base@conjugator

# "bitnode" take input states parameterized by (d1, g1) and (d2, g2) and returns the post-measurement state parameters and probabilities at a bitnode
def bitnode(d1,d2,g1,g2):
  """Combine two input states at a bit node using a paired measurement.

  The inputs are the states rhom(d1, g1, .) and rhom(d2, g2, .). The
  measurement vectors are taken from the eigenvectors of the difference
  between the joint state for root value z=0 and the one for z=1.

  Arguments:
        d1, g1: delta and gamma of the first input state
        d2, g2: delta and gamma of the second input state

  Returns:
        params: 2x2 array; row k is [1 - d, g] for measurement outcome k
        probs:  length-2 array [p0, p1] with the two outcome weights

  NOTE(review): this function returns 1 - d for each outcome, while
  bitnode_jit returns min(d, 1 - d); confirm which convention is intended.
  """
  tol = 10e-10   # orthogonality threshold (numerically 1e-9)
  eps = 10e-21   # keeps the divisions finite when an outcome weight is zero

  # Joint state for root value z=0. It is built once and reused below.
  rho0 = np.kron(rhom(d1,g1,0),rhom(d2,g2,0))
  # Difference between the z=0 and z=1 joint states.
  rhofull = rho0-np.kron(rhom(d1,g1,1),rhom(d2,g2,1))

  _, evecs = np.linalg.eigh(rhofull)
  # fix eigenvector v0
  v0 = evecs[:,0]

  # symmetry operator: bit flip on both qubits
  Un = np.kron(pauli(1),pauli(1))
  # overlap of the second eigenvector with its image under Un
  x = evecs[:,1]@(Un@evecs[:,1])

  if np.abs(x)<tol:
    v1 = evecs[:,1]
  elif np.abs(x)>=tol:
    # not orthogonal: combine evecs[:,1] and evecs[:,2] so that v1@(Un@v1) = 0
    vec1, vec2 = evecs[:,1], evecs[:,2]
    b11 = np.inner(vec1, (Un@vec1).conj())
    b12 = np.inner(vec2, (Un@vec1).conj())
    b21 = np.inner(vec1, (Un@vec2).conj())
    b22 = np.inner(vec2, (Un@vec2).conj())

    alpha = (-b12-b21-np.sqrt((b12+b21)**2-4*b11*b22))/(2*b22)
    v1 = vec1+alpha*vec2
    v1 = v1/np.sqrt(v1@v1)

  def _outcome(v):
    # Weight of the paired element |v><v| + Un|v><v|Un on rho0, followed by
    # the normalised diagonal and off-diagonal parameters for that outcome.
    weight = v@rho0@v+v@Un@rho0@Un@v
    return weight, v@rho0@v/(weight+eps), v@rho0@(Un@v)/(weight+eps)

  # the paired measurement is {|v0><v0| + Un|v0><v0|Un, |v1><v1| + Un|v1><v1|Un}
  p0, d1a, g1a = _outcome(v0)
  p1, d2a, g2a = _outcome(v1)
  return np.array([[1-d1a, g1a], [1-d2a, g2a]]), np.array([p0, p1])

# "checknode" take input states parameterized by (d1, g1) and (d2, g2) and returns the post-measurement state parameters and probabilities at a checknode
def checknode(d1,d2,g1,g2):
  """Combine two input states at a check node using a fixed paired measurement.

  Arguments:
        d1, g1: delta and gamma of the first input state
        d2, g2: delta and gamma of the second input state

  Returns:
        params: 2x2 array; row k is [d, g] for measurement outcome k
        probs:  length-2 array [p0, p1] with the two outcome weights

  NOTE(review): unlike checknode_jit, d is not folded to min(d, 1 - d) here.
  """
  eps = 10e-21   # keeps the divisions finite when an outcome weight is zero

  # State at the check node for z=0: equal mixture of both inputs unflipped
  # and both inputs flipped. The z=1 state was never used, so it is not built.
  rho0 = 1/2*(np.kron(rhom(d1,g1,0),rhom(d2,g2,0)) + np.kron(rhom(d1,g1,1),rhom(d2,g2,1)))

  # for check node combining, the measurement is generated by v0 and v1
  v0 = 1/np.sqrt(2)*np.array([1,0,0,1])
  v1 = 1/np.sqrt(2)*np.array([-1, 0, 0, 1])
  # symmetry operator for a check node: bit flip on the first qubit only
  Un = np.kron(pauli(1),pauli(0))

  def _outcome(v):
    # Weight of the paired element |v><v| + Un|v><v|Un on rho0, followed by
    # the normalised diagonal and off-diagonal parameters for that outcome.
    weight = v@rho0@v+v@Un@rho0@Un@v
    return weight, v@rho0@v/(weight+eps), v@rho0@(Un@v)/(weight+eps)

  p0, d1a, g1a = _outcome(v0)
  p1, d2a, g2a = _outcome(v1)
  return np.array([[d1a, g1a], [d2a, g2a]]), np.array([p0, p1])

# Orthogonal 2x2 matrix with entries +-1/sqrt(2); rho() conjugates the pure state with it.
H = 1/np.sqrt(2)*np.array([[1,-1],[1,1]])
# Pauli-Z matrix.
sgz = np.array([[1,0],[0,-1]])
# Pauli-X (bit flip) matrix; W() uses it to build the channel output for z=1.
sgx = np.array([[0,1],[1,0]])
 
# pure state channel outputs
def rho(th):
  """Return the pure-state channel output for angle th.

  The rank-one matrix with amplitudes (cos(th/2), sin(th/2)) is conjugated by
  the module-level matrix H.
  """
  c, s = np.cos(th/2), np.sin(th/2)
  pure = np.array([[c**2, c*s],
                   [c*s, s**2]])
  return H@pure@np.transpose(H)

# output of the channel given input z, parameter th, and flip parameter p
def W(z,th,p):
  """Return the channel output for input bit z, angle th and mixing weight p.

  The output is (1-p) times the pure state (bit-flipped by sgx when z is 1)
  plus p times the maximally mixed state. Values of z other than 0 and 1
  return None.
  """
  if z == 0:
    signal = (1-p)*rho(th)
  elif z == 1:
    signal = (1-p)*sgx@rho(th)@sgx
  else:
    return None
  return signal+p*np.array([[1/2, 0], [0, 1/2]])


In [4]:
def helstrom(density_mat1,unitary):
  """Eigendecompose the difference between a state and its unitary image.

  Arguments:
        density_mat1: density matrix r1 (assumed Hermitian)
        unitary: matrix u; the second state is r2 = u r1 u^dagger

  Returns:
        l: real eigenvalues of r1 - r2 in ascending order
        v: matrix whose columns are the matching orthonormal eigenvectors
  """
  r1=density_mat1
  u=unitary
  r2=u@r1@np.conjugate(np.transpose(u))
  # r1 - r2 is Hermitian whenever r1 is, so use the Hermitian solver. Unlike
  # LA.eig it always returns real eigenvalues and orthonormal eigenvectors,
  # including for degenerate eigenvalues and complex input.
  l,v=LA.eigh(r1-r2)
  return l,v


def helstrom_success(rho1,unitary):
  """Success probability of distinguishing rho1 from unitary @ rho1 @ unitary^dagger.

  The measurement projects onto the eigenvectors of the difference matrix
  with positive eigenvalue. Both states are weighted equally (0.5 each).
  """
  eigenvalues,eigenvectors=helstrom(rho1,unitary)
  rho2=unitary@rho1@np.conjugate(np.transpose(unitary))
  # columns spanning the positive part of rho1 - rho2
  positive=np.array(eigenvectors[:,eigenvalues>0])
  positive_dag=np.conjugate(np.transpose(positive))
  hit_rho1=np.trace(positive_dag @ rho1 @ positive)
  hit_rho2=1-np.trace(positive_dag@ rho2 @ positive)
  return 0.5*(hit_rho1+hit_rho2)

def helstrom_success_vec(X):
  """Apply helstrom_success to every (delta, gamma) pair held in X.

  X[0] holds the delta values and X[1] the gamma values. The result is a list
  with one success probability per pair, each computed against the bit flip
  pauli(1).
  """
  return [
    helstrom_success(rhom(X[0][i],X[1][i],0),pauli(1))
    for i in range(len(X[0]))
  ]

In [5]:
@njit
def rhom_jit(delta:float64,g:float64) -> (float64[:,:]):
    """Numba-compiled builder of the 2x2 float64 matrix [[delta, g], [g, 1-delta]].

    Unlike rhom it takes no flip index; callers conjugate with the bit-flip
    matrix themselves where needed.
    """
    return np.array([[delta, g],[g, 1-delta]],dtype=float64)

# "bitnode" take input states parameterized by (d1, g1) and (d2, g2) and returns the post-measurement state parameters and probabilities at a bitnode
@njit
def bitnode_jit(d1:float64,d2:float64,g1:float64,g2:float64) ->(float64[:,:],float64[:]):
  """Numba-compiled bit-node combination of the states (d1, g1) and (d2, g2).

  Returns a pair (params, probs). Row k of the 2x2 array params is
  [min(d, 1-d), g] for measurement outcome k, and probs is [p0, p1].

  NOTE(review): the plain-Python bitnode returns 1-d instead of min(d, 1-d);
  confirm which convention is intended.
  """
  # the state given root value z=0 is np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))
  # the state given root value z=1 has both factors conjugated by the bit-flip matrix x
  
  # we find the paired measurements by taking the eigenvectors of the difference matrix rhofull
  x=np.array([[0,1],[1,0]],dtype=float64)
  
  rhofull = np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))-np.kron(x@rhom_jit(d1,g1)@x,x@rhom_jit(d2,g2)@x)
 
  evals, evecs = np.linalg.eigh(rhofull)
  # fix eigenvector v0 
  v0 = evecs[:,0]
  
  # symmetry operator: bit flip on both qubits
  Un = np.kron(x,x)
  # check if the second eigenvector evecs[:,1] is orthogonal to Un@evecs[:,1]
  x1 = evecs[:,1]@(Un@evecs[:,1])
 
  #print("evals")
  #print(evals)
  # 10e-10 is numerically 1e-9
  if np.abs(x1)<10e-10:
    v1 = evecs[:,1]
  # if not orthogonal, combine evecs[:,1], evecs[:,2] to create v1 s.t. v1@(Un@v1)= 0
  if np.abs(x1)>=10e-10:
    vec1, vec2 = evecs[:,1], evecs[:,2]
    b11, b12, b21, b22 = np.dot(vec1, (Un@vec1).conj()), np.dot(vec2, (Un@vec1).conj()), np.dot(vec1, (Un@vec2).conj()), np.dot(vec2, (Un@vec2).conj())
    
    # alpha is a root of b22*alpha**2 + (b12+b21)*alpha + b11 = 0
    alpha = (-b12-b21-np.sqrt((b12+b21)**2-4*b11*b22))/(2*b22)
    v1 = vec1+alpha*vec2
    v1 = v1/np.sqrt(v1@v1)
    ##
  #print("v0")
  #print(v0)
  #print("v1")
  #print(v1)
  ##
  # the paired measurement is then given by {|v0><v0| + Un|v0><v0|Un, |v1><v1| + Un|v1><v1|Un}
  ## find new state parameters (d1a, g1a) for measurement outcome 0
  # find probability p0 of observing measurement  outcome 0
  p0 = v0@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@v0+v0@Un@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@Un@v0
  # 10e-21 keeps the division finite when p0 is zero
  d1a, g1a = v0@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@v0/(p0+10e-21), v0@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@(Un@v0)/(p0+10e-21)
  ## find new state parameters (d2a, g2a) for measurement outcome 1
  # find probability p1 of observing measurement  outcome 1  
  p1 = v1@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@v1+v1@Un@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@Un@v1
  d2a, g2a = v1@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@v1/(p1+10e-21), v1@np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2))@(Un@v1)/(p1+10e-21)
  #print("pvec")
  #print([p0, p1])
  # fold each delta into the smaller of d and 1-d
  d1a=min([d1a,1-d1a])
  d2a=min([d2a,1-d2a])
  return np.array([[d1a, g1a], [d2a, g2a]]), np.array([p0, p1])

# "checknode" take input states parameterized by (d1, g1) and (d2, g2) and returns the post-measurement state parameters and probabilities at a checknode
@njit
def checknode_jit(d1:float64,d2:float64,g1:float64,g2:float64) ->(float64[:,:],float64[:]):
  """Numba-compiled check-node combination of the states (d1, g1) and (d2, g2).

  Returns a pair (params, probs). Row k of the 2x2 array params is
  [min(d, 1-d), g] for measurement outcome k, and probs is [p0, p1].
  """
  # x is the bit-flip matrix, I the identity
  x=np.array([[0,1],[1,0]],dtype=float64)
  I=np.array([[1,0],[0,1]],dtype=float64)
  # rho0, rho1 correspond to the states at a check node when z=0 (z=1) respectively
  # NOTE(review): rho1 is computed but not used below
  rho0, rho1 = 1/2*(np.kron(rhom_jit(d1,g1),rhom_jit(d2,g2)) + np.kron(x@rhom_jit(d1,g1)@x,x@rhom_jit(d2,g2)@x)), 1/2*(np.kron(rhom_jit(d1,g1),x@rhom_jit(d2,g2)@x) + np.kron(x@rhom_jit(d1,g1)@x,rhom_jit(d2,g2)))
  # for check node combining, the optimal choice of eigenvectors appears to always be generated by v0 and v1
  v0 = 1/np.sqrt(2)*np.array([1,0,0,1])
  v1 = 1/np.sqrt(2)*np.array([-1, 0, 0, 1])
  # symmetry operator for a check node: bit flip on the first qubit only
  Un = np.kron(x,I)
  ## find new state parameters (d1a, g1a) for measurement outcome 0
  # find probability p0 of observing measurement  outcome 0
  p0 = v0@rho0@v0+v0@Un@rho0@Un@v0
  # 10e-21 keeps the division finite when p0 is zero
  d1a, g1a = v0@rho0@v0/(p0+10e-21), v0@rho0@(Un@v0)/(p0+10e-21)
  ## find new gamma, delta for second outcome
  p1 = v1@rho0@v1+v1@Un@rho0@Un@v1
  d2a, g2a = v1@rho0@v1/(p1+10e-21), v1@rho0@(Un@v1)/(p1+10e-21)
  # fold each delta into the smaller of d and 1-d, then return the
  # (delta, gamma) pairs together with the outcome probabilities p0 and p1
  d1a=min([d1a,1-d1a])
  d2a=min([d2a,1-d2a])
  return np.array([[d1a, g1a], [d2a, g2a]]), np.array([p0, p1])
  


In [6]:
@njit
def bitnode_jit2(d1:float64,d2:float64,g1:float64,g2:float64,pr=None)->(float64[:,:]):
  """Sample one bit-node measurement outcome and return its [delta, gamma] row.

  Outcome 0 is chosen with probability s, where s is pr when pr is given and
  otherwise the first probability returned by bitnode_jit.
  """
  rho,pb=bitnode_jit(d1,d2,g1,g2)
  if pr==None:
    s=pb[0]
  else:
    s=pr
  # o is 0 with probability s and 1 otherwise
  o=int(np.random.random()>s)
  #choice([0,1],p=[s,1-s])
  if o==0:
    return rho[0]
  else:
    return rho[1]

@njit
def checknode_jit1(d1:float64,d2:float64,g1:float64,g2:float64,pr=None)->(float64[:,:]):
  """Sample one check-node measurement outcome and return its [delta, gamma] row.

  Outcome 0 is chosen with probability s, where s is pr when pr is given and
  otherwise the first probability returned by checknode_jit.
  """
  rho,pc=checknode_jit(d1,d2,g1,g2)
  if pr==None:
    s=pc[0]
  else:
    s=pr
  # o is 0 with probability s and 1 otherwise
  o=int(np.random.random()>s)
  if o==0:
    return rho[0]
  else:
    return rho[1]

@njit
def bitnode_vec_jit(d1:float64[:],d2:float64[:],g1:float64[:],g2:float64[:],pr_vec=None,perm=None)->(float64[:],float64[:]):
    """Apply bitnode_jit2 element-wise to two populations of (delta, gamma) samples.

    Sample i of the first population (d1, g1) is paired with sample p[i] of
    the second (d2, g2). p is perm when given, otherwise a fresh random
    permutation. When pr_vec is given, pr_vec[i] is forwarded as the outcome
    probability for sample i.

    Returns the arrays (d, g) of sampled output parameters, one entry per sample.
    """
    l=np.shape(d1)[0]
    if perm==None:
      p=np.random.permutation(l)
    else:
      p=perm
    d=np.zeros(l)
    g=np.zeros(l)
    if pr_vec==None:
      for i in range(l):
        d[i],g[i]=bitnode_jit2(d1[i],d2[p[i]],g1[i],g2[p[i]])
        
    else:
      for i in range(l):
        d[i],g[i]=bitnode_jit2(d1[i],d2[p[i]],g1[i],g2[p[i]],pr_vec[i])
        
    return d,g
@njit
def checknode_vec_jit(d1:float64[:],d2:float64[:],g1:float64[:],g2:float64[:],pr_vec=None,perm=None)->(float64[:],float64[:]):
    """Apply checknode_jit1 element-wise to two populations of (delta, gamma) samples.

    Sample i of the first population (d1, g1) is paired with sample p[i] of
    the second (d2, g2). p is perm when given, otherwise a fresh random
    permutation. When pr_vec is given, pr_vec[i] is forwarded as the outcome
    probability for sample i.

    Returns the arrays (d, g) of sampled output parameters, one entry per sample.
    """
    l=np.shape(d1)[0]
    if perm==None:
      p=np.random.permutation(l)
    else:
      p=perm
    d=np.zeros(l)
    g=np.zeros(l)
    if pr_vec==None:
      for i in range(l):
        d[i],g[i]=checknode_jit1(d1[i],d2[p[i]],g1[i],g2[p[i]])
        
    else:
      for i in range(l):
        d[i],g[i]=checknode_jit1(d1[i],d2[p[i]],g1[i],g2[p[i]],pr_vec[i])
        
    return d,g
    
    
@njit   
def bitnode_power_jit(d:float64[:],g:float64[:],k:int64,pr_vec=None,perm=None)->(float64[:],float64[:]):
  """Combine k copies of the population (d, g) through repeated bit nodes.

  k equal to 1 returns the inputs untouched. Otherwise the population is
  combined with itself once and then with the running result k-2 more times.
  """
  if k==1:
    return d,g
  acc_d,acc_g=bitnode_vec_jit(d,d,g,g,pr_vec,perm)
  # range(k-2) is empty for k<=2, so no extra combinations happen then
  for _ in range(k-2):
    acc_d,acc_g=bitnode_vec_jit(d,acc_d,g,acc_g,pr_vec,perm)
  return acc_d,acc_g

  
@njit
def checknode_power_jit(d:float64[:],g:float64[:],k:int64,pr_vec=None,perm=None)->(float64[:],float64[:]):
  """Combine k copies of the population (d, g) through repeated check nodes.

  k equal to 1 returns the inputs untouched. Otherwise the population is
  combined with itself once and then with the running result k-2 more times.
  """
  if k==1:
    return d,g
  acc_d,acc_g=checknode_vec_jit(d,d,g,g,pr_vec,perm)
  # range(k-2) is empty for k<=2, so no extra combinations happen then
  for _ in range(k-2):
    acc_d,acc_g=checknode_vec_jit(d,acc_d,g,acc_g,pr_vec,perm)
  return acc_d,acc_g


In [7]:
@njit
def polar(n:int64,d:float64[:],g:float64[:])->(float64[:,:,:]):
  """Evolve a population of (delta, gamma) samples through n polarization stages.

  Arguments:
        n: number of polarization stages
        d: array of delta samples for the input channel
        g: array of gamma samples for the input channel (same length as d)

  Returns:
        Array of shape (2**n, 2, len(d)). Entry [c, 0] holds the delta samples
        and entry [c, 1] the gamma samples of synthesized channel c.

  The original signature read `g=float64[:]`, which made a numba type object
  the default value of g. It is now an annotation like the other parameters,
  so g must always be passed.
  """
  old_stage=np.zeros(shape=(1,2,len(d)),dtype=float64)
  old_stage[0][0],old_stage[0][1]=d,g
  for i in range(n):
    new_stage=np.zeros(shape=(2**(i+1),2,len(d)))
    for j in range(2**(i)):
      # channel j splits into a check-node child (2j) and a bit-node child (2j+1);
      # each child combines channel j with a randomly permuted copy of itself
      new_stage[2*j][0],new_stage[2*j][1]=checknode_vec_jit(old_stage[j][0],old_stage[j][0],old_stage[j][1],old_stage[j][1])
      new_stage[2*j+1][0],new_stage[2*j+1][1]=bitnode_vec_jit(old_stage[j][0],old_stage[j][0],old_stage[j][1],old_stage[j][1])
    old_stage=new_stage
  return old_stage


def polar_design(biterrd,d):
    '''
    Choose the information positions of a polar code from additive noise scores.

          Arguments:
                  biterrd (float[:]): Numpy array of channel noise scores (e.g., error rates)
                  d (float): Strict upper bound on the summed score of the unfrozen channels

          Returns:
                  f (float[:]): Numpy array with 0 at frozen positions and 0.5 at info positions
    '''
    # Rank channels from least to most noisy and accumulate their scores
    ranking = np.argsort(biterrd)
    running_total = np.cumsum(biterrd[ranking])

    # Keep the longest prefix of the ranking whose cumulative score stays below d
    n_info = np.sum(running_total<d)
    f = np.zeros(len(biterrd))
    f[ranking[:n_info]] = 0.5
    return f


In [8]:
def bitnode_gen_unitary(d1,d2,g1,g2):
  """Build the 4x4 matrix whose rows are the bit-node measurement vectors.

  The vectors v0 and v1 are chosen as in bitnode. The rows of the returned
  matrix are, in order, Un@v0, Un@v1, v0 and v1, where Un flips both qubits.
  """
  flip_both = np.kron(pauli(1),pauli(1))
  joint_z0 = np.kron(rhom(d1,g1,0),rhom(d2,g2,0))
  joint_z1 = np.kron(rhom(d1,g1,1),rhom(d2,g2,1))

  evals, evecs = LA.eigh(joint_z0-joint_z1)
  v0 = evecs[:,0]

  # overlap of the second eigenvector with its image under the double flip
  overlap = evecs[:,1]@(flip_both@evecs[:,1])

  if np.abs(overlap)<10e-10:
    v1 = evecs[:,1]
  elif np.abs(overlap)>=10e-10:
    # mix evecs[:,1] and evecs[:,2] so that v1@(flip_both@v1) = 0
    vec1, vec2 = evecs[:,1], evecs[:,2]
    b11 = np.inner(vec1, (flip_both@vec1).conj())
    b12 = np.inner(vec2, (flip_both@vec1).conj())
    b21 = np.inner(vec1, (flip_both@vec2).conj())
    b22 = np.inner(vec2, (flip_both@vec2).conj())

    alpha = (-b12-b21-np.sqrt((b12+b21)**2-4*b11*b22))/(2*b22)
    v1 = vec1+alpha*vec2
    v1 = v1/np.sqrt(v1@v1)

  vb=np.zeros((4,4))
  for col, vec in enumerate((flip_both@v0, flip_both@v1, v0, v1)):
    vb[:,col]=vec
  return np.transpose(vb)

def checknode_gen_unitary(d1,d2,g1,g2):
  """Build the 4x4 matrix whose rows are the check-node measurement vectors.

  The rows are, in order, v0, v1, ux@v0 and ux@v1, where ux flips the first
  qubit. The arguments d1, d2, g1 and g2 are accepted but not used.
  """
  v0 = 1/np.sqrt(2)*np.array([1,0,0,1])
  v1 = 1/np.sqrt(2)*np.array([-1, 0, 0, 1])
  ux=np.kron(pauli(1),np.eye(2))
  vcf=np.zeros((4,4))
  for col, vec in enumerate((v0, v1, ux@v0, ux@v1)):
    vcf[:,col]=vec
  return np.transpose(vcf)

# Fixed check-node matrix shared by cnop / cnop_inverse. checknode_gen_unitary
# ignores its arguments, so the values passed here do not affect CU.
CU=checknode_gen_unitary(0.1,0.1,0,0)

In [9]:
def apply_permutation(qubit_list,number_qubits,conditional_list1=[],conditional_list2=[]):
  """Build the axis permutation that groups the named qubits at the front.

  Destination order: the qubits of conditional_list1 first, then those of
  conditional_list2, then those of qubit_list, then every remaining qubit in
  increasing index order.

  Arguments:
        qubit_list: indices of the qubits the operation acts on
        number_qubits: total number of qubits n
        conditional_list1, conditional_list2: indices of conditioning qubits

  Returns:
        p:  integer array; p[q] is the destination axis of qubit q
        pi: integer array holding the inverse permutation (pi[p[q]] == q)

  Raises:
        ValueError: if the index lists overlap, so that p is not a permutation.
  """
  n=number_qubits
  l=len(qubit_list)
  l1=len(conditional_list1)
  l2=len(conditional_list2)
  p=np.zeros(n,dtype=int)
  for j in range(l):
    p[qubit_list[j]]=l1+l2+j
  for j in range(l1):
    p[conditional_list1[j]]=j
  for j in range(l2):
    p[conditional_list2[j]]=l1+j

  # Qubits not named in any list keep their relative order behind the named ones.
  named={int(q) for q in qubit_list}
  named.update(int(q) for q in conditional_list1)
  named.update(int(q) for q in conditional_list2)
  c=0
  for j in range(n):
    if j not in named:
      p[j]=l+l1+l2+c
      c=c+1

  if not np.array_equal(np.sort(p),np.arange(n)):
    raise ValueError("qubit index lists must be disjoint and within range")

  # Invert the permutation directly rather than through a dense n x n matrix.
  pi=np.zeros(n,dtype=int)
  pi[p]=np.arange(n)
  return p,pi

In [10]:
def apply_unitary(qubit_list,unitary,q_state,number_qubits):
  """Apply a matrix to the qubits in qubit_list of an n-qubit state vector.

  The state is viewed as an n-axis tensor and its axes are rearranged with
  the permutations from apply_permutation. It is then flattened to a
  (2**len(qubit_list), rest) matrix, multiplied by unitary from the left and
  rearranged back. The result has the shape of q_state.
  """
  n=number_qubits
  k=len(qubit_list)
  axes=np.array(range(n))
  tensor_shape=2*np.ones(n,dtype=int)
  forward,backward=apply_permutation(qubit_list,n)
  state=np.moveaxis(np.reshape(q_state,tensor_shape),axes,forward)
  state=unitary@np.reshape(state,(2**k,2**(n-k)))
  state=np.moveaxis(np.reshape(state,tensor_shape),axes,backward)
  return np.reshape(state,np.shape(q_state))

def apply_unitary_inverse(qubit_list,unitary,q_state,number_qubits):
  """Undo apply_unitary by applying the inverse of a unitary to selected qubits.

  Arguments:
        qubit_list: indices of the qubits the unitary acts on
        unitary: the matrix that was applied; its conjugate transpose is used
        q_state: state vector of the full register
        number_qubits: total number of qubits n

  Returns:
        The transformed state, with the same shape as q_state.
  """
  n=number_qubits
  l=len(qubit_list)
  u=unitary
  q=np.reshape(q_state,2*np.ones(n,dtype=int))
  p,pi=apply_permutation(qubit_list,n)
  q1=np.moveaxis(q,np.array(range(n)),p)
  q2=np.reshape(q1,(2**l,2**(n-l)))
  # The inverse of a unitary is its conjugate transpose. A plain transpose
  # is only correct for real matrices, where the two coincide.
  q_out=np.conjugate(np.transpose(u))@q2
  q_out=np.reshape(q_out,2*np.ones(n,dtype=int))
  q_out=np.moveaxis(q_out,np.array(range(n)),pi)
  q_out=np.reshape(q_out,np.shape(q_state))
  return q_out

In [11]:
def apply_conditional_unitary(qubit_list,conditional_list1,conditional_list2,dg_list1,dg_list2,q_state,number_qubits):
  """Apply a bit-node matrix to qubit_list, selected by the conditioning qubits.

  For each basis value i of the conditional_list1 qubits and j of the
  conditional_list2 qubits, the matrix
  bitnode_gen_unitary(dg_list1[i][0], dg_list2[j][0], dg_list1[i][1], dg_list2[j][1])
  is applied to the matching block of the state. The result has the shape of
  q_state.
  """
  n=number_qubits
  k=len(qubit_list)
  k1=len(conditional_list1)
  k2=len(conditional_list2)
  axes=np.array(range(n))
  tensor_shape=2*np.ones(n,dtype=int)
  block_shape=(2**k1,2**k2,2**k,2**(n-k1-k2-k))
  forward,backward=apply_permutation(qubit_list,n,conditional_list1,conditional_list2)
  blocks=np.reshape(np.moveaxis(np.reshape(q_state,tensor_shape),axes,forward),block_shape)
  q_out=np.zeros(block_shape)
  for i in range(2**k1):
    for j in range(2**k2):
      u=bitnode_gen_unitary(dg_list1[i][0],dg_list2[j][0],dg_list1[i][1],dg_list2[j][1])
      q_out[i,j,:]=u@blocks[i,j,:]
  q_out=np.moveaxis(np.reshape(q_out,tensor_shape),axes,backward)
  return np.reshape(q_out,np.shape(q_state))
def apply_conditional_unitary_inverse(qubit_list,conditional_list1,conditional_list2,dg_list1,dg_list2,q_state,number_qubits):
  """Undo apply_conditional_unitary.

  Same block structure as apply_conditional_unitary, but each block is
  multiplied by the transpose of the corresponding bitnode_gen_unitary matrix.
  """
  n=number_qubits
  k=len(qubit_list)
  k1=len(conditional_list1)
  k2=len(conditional_list2)
  axes=np.array(range(n))
  tensor_shape=2*np.ones(n,dtype=int)
  block_shape=(2**k1,2**k2,2**k,2**(n-k1-k2-k))
  forward,backward=apply_permutation(qubit_list,n,conditional_list1,conditional_list2)
  blocks=np.reshape(np.moveaxis(np.reshape(q_state,tensor_shape),axes,forward),block_shape)
  q_out=np.zeros(block_shape)
  for i in range(2**k1):
    for j in range(2**k2):
      u=bitnode_gen_unitary(dg_list1[i][0],dg_list2[j][0],dg_list1[i][1],dg_list2[j][1])
      q_out[i,j,:]=np.transpose(u)@blocks[i,j,:]
  q_out=np.moveaxis(np.reshape(q_out,tensor_shape),axes,backward)
  return np.reshape(q_out,np.shape(q_state))

In [12]:
def sample_state(l,v):
  """Draw one column of v: column 0 with probability l[0], otherwise column 1."""
  column=1 if np.random.random()>l[0] else 0
  return v[:,column]
def sample_state_one(l,v):
  """Unfinished placeholder: ignores both arguments and returns None."""
  return None
def sample_joint_state(l,v,number_qubits):
  """Sample number_qubits independent single-qubit states and return their Kronecker product.

  Each factor is drawn with sample_state(l, v). The factors are drawn and
  multiplied from the first qubit to the last.
  """
  if number_qubits!=1:
    leading=sample_joint_state(l,v,number_qubits-1)
    return np.kron(leading,sample_state(l,v))
  return sample_state(l,v)


def sample_joint_random_state(l,v,v1,number_qubits,bit_sequence):
  """Sample a product state whose factors depend on the transmitted bits.

  Qubit i is drawn with sample_state(l, v) when bit_sequence[i] is 0 and with
  sample_state(l, v1) otherwise. The factors are combined with np.kron from
  the first qubit to the last.
  """
  def draw(bit):
    # v for a 0 bit, v1 for anything else
    if bit==0:
      return sample_state(l,v)
    return sample_state(l,v1)

  s=draw(bit_sequence[0])
  for i in range(1,number_qubits):
    s=np.kron(s,draw(bit_sequence[i]))
  return s

In [13]:
def cnop(quantum_state,informtion_qubit1,information_qubit2,conditional_list1,conditional_list2,number_qubits):
  """Check-node step: apply CU to a qubit pair and derive the merged conditioning data.

  Each conditional_list is a pair [qubit indices, list of (delta, gamma) entries].

  Returns (q_out, merged):
        q_out:  the state after CU has been applied to the two information qubits
        merged: [qubits, params]; qubits is information_qubit2 followed by the
                qubits of both inputs; params holds the first-outcome checknode
                parameters for every (i, j), then the second-outcome ones
  """
  q_out=apply_unitary(np.array([informtion_qubit1,information_qubit2]),CU,quantum_state,number_qubits)
  qubits1,params1=conditional_list1[0],conditional_list1[1]
  qubits2,params2=conditional_list2[0],conditional_list2[1]
  l1=len(qubits1)
  l2=len(qubits2)
  first_outcome=[]
  second_outcome=[]
  for i in range(2**l1):
    for j in range(2**l2):
      rc=checknode(params1[i][0],params2[j][0],params1[i][1],params2[j][1])[0]
      first_outcome.append(rc[0])
      second_outcome.append(rc[1])
  qubits=[information_qubit2]
  qubits.extend(qubits1[i] for i in range(l1))
  qubits.extend(qubits2[j] for j in range(l2))
  return q_out, [qubits,first_outcome+second_outcome]

def cnop_inverse(quantum_state,informtion_qubit1,information_qubit2,conditional_list1,conditional_list2,number_qubits):
  """Undo cnop on the state by applying the inverse of CU to the same qubit pair.

  The two conditional lists are accepted for symmetry with cnop but not used.
  """
  pair=np.array([informtion_qubit1,information_qubit2])
  return apply_unitary_inverse(pair,CU,quantum_state,number_qubits)

In [14]:
def vnop(quantum_state,informtion_qubit1,information_qubit2,conditional_list1,conditional_list2,number_qubits):
  """Bit-node step: apply the bit-node matrix to a qubit pair and derive the merged conditioning data.

  Each conditional_list is a pair [qubit indices, list of (delta, gamma) entries].
  With no conditioning qubits on either side a single matrix from
  bitnode_gen_unitary is applied. Otherwise the matrix is selected by the
  conditioning qubits through apply_conditional_unitary.

  Returns (q_out, merged), where merged is laid out as in cnop but uses the
  bitnode parameters.
  """
  qubit_list=np.array([informtion_qubit1,information_qubit2])
  qubits1,params1=conditional_list1[0],conditional_list1[1]
  qubits2,params2=conditional_list2[0],conditional_list2[1]
  l1=len(qubits1)
  l2=len(qubits2)
  if l1==0 and l2==0:
    u=bitnode_gen_unitary(params1[0][0],params2[0][0],params1[0][1],params2[0][1])
    q_out=apply_unitary(qubit_list,u,quantum_state,number_qubits)
  else:
    q_out=apply_conditional_unitary(qubit_list,qubits1,qubits2,params1,params2,quantum_state,number_qubits)
  first_outcome=[]
  second_outcome=[]
  for i in range(2**l1):
    for j in range(2**l2):
      rc=bitnode(params1[i][0],params2[j][0],params1[i][1],params2[j][1])[0]
      first_outcome.append(rc[0])
      second_outcome.append(rc[1])
  qubits=[information_qubit2]
  qubits.extend(qubits1[i] for i in range(l1))
  qubits.extend(qubits2[j] for j in range(l2))
  return q_out, [qubits,first_outcome+second_outcome]

def vnop_inverse(quantum_state,informtion_qubit1,information_qubit2,conditional_list1,conditional_list2,number_qubits):
  """Undo the state transformation performed by vnop.

  It follows the same case split as vnop: a single inverse matrix when neither
  side has conditioning qubits, otherwise the conditional inverse.
  """
  qubit_list=np.array([informtion_qubit1,information_qubit2])
  qubits1,params1=conditional_list1[0],conditional_list1[1]
  qubits2,params2=conditional_list2[0],conditional_list2[1]
  if len(qubits1)==0 and len(qubits2)==0:
    u=bitnode_gen_unitary(params1[0][0],params2[0][0],params1[0][1],params2[0][1])
    return apply_unitary_inverse(qubit_list,u,quantum_state,number_qubits)
  return apply_conditional_unitary_inverse(qubit_list,qubits1,qubits2,params1,params2,quantum_state,number_qubits)


In [15]:
# returns the expectation <q|(P x I)|q> of projection_operator P on the first qubit of the joint state, i.e. the probability of the outcome associated with P
def measure_first_qubit(quantum_state,number_qubits,projection_operator):
  """Return the expectation of a 2x2 operator acting on the first qubit.

  The state vector is reshaped to (2, 2**(n-1)), multiplied by
  projection_operator from the left, flattened again, and contracted with the
  original state (no complex conjugation is applied).
  """
  dim=2**number_qubits
  as_matrix=np.reshape(quantum_state,(2,dim//2 if number_qubits>=1 else 2**(number_qubits-1)))
  projected=np.reshape(projection_operator@as_matrix,dim)
  return projected@quantum_state


def apply_projection_first_qubit(quantum_state,number_qubits,projection_operator):
  """Apply a 2x2 operator to the first qubit and return the flattened result.

  The returned vector has length 2**number_qubits and is not renormalised.
  """
  n=number_qubits
  as_matrix=np.reshape(quantum_state,(2,2**(n-1)))
  return np.reshape(projection_operator@as_matrix,2**n)


In [16]:
def apply_flip(q_state,number_qubits,index):
  """Apply a bit flip (Pauli X) to the qubit at position index.

  Axis 0 and axis `index` of the state tensor are swapped, the flip is applied
  to the leading axis, and the same swap restores the original axis order.
  The result has the shape of q_state.
  """
  n=number_qubits
  flip=np.array([[0,1],[1,0]])
  axes=np.array(range(n))
  tensor_shape=2*np.ones(n,dtype=int)
  # transposition of axes 0 and index; applying it twice restores the order
  swap=np.array(range(n))
  swap[0]=index
  swap[index]=0
  state=np.moveaxis(np.reshape(q_state,tensor_shape),axes,swap)
  state=flip@np.reshape(state,(2,2**(n-1)))
  state=np.moveaxis(np.reshape(state,tensor_shape),axes,swap)
  return np.reshape(state,np.shape(q_state))

def hard_decision_flip(q_state,number_qubits,hard_decision,index):
  """Flip qubit `index` when hard_decision equals 1; otherwise return q_state unchanged."""
  if hard_decision!=1:
    return q_state
  return apply_flip(q_state,number_qubits,index)

In [37]:
# def polar_transform(u):
#   if len(u)==1:
#     x=u
#   else:
#     u1=np.mod(u[0::2]+u[1::2],2)
#     u2=u[1::2]
#     x=np.concatenate([polar_transform(u1),polar_transform(u2)])
#   return x

@njit('(int64[:])(int64[:])') # Input/output specifications to make Numba work
def polar_transform(u):
    '''
    Apply the recursive polar encoding transform to u

          Arguments:
                  u (int64[:]): Numpy array of input bits

          Returns:
                  x (int64[:]): Numpy array of encoded bits
    '''
    # A single bit is its own encoding
    if (len(u)==1):
        return u
    half = len(u)//2
    x = np.zeros(len(u), dtype=np.int64)
    # First half: transform of the XOR of the entries at even and odd positions
    x[:half] = polar_transform((u[::2]+u[1::2])%2)
    # Second half: transform of the entries at odd positions
    x[half:] = polar_transform(u[1::2])
    return x

In [19]:
def polar_decoder_cq_output(yi,yd,quantum_state,number_qubits,frozen_set):
  """Recursive decoder returning one decision per position (first version).

  yi lists the qubit indices handled at this level and yd the matching
  conditioning data in the [qubits, params] layout produced by cnop / vnop.
  At a leaf (len(yi) == 1) the first qubit is measured when frozen_set[0] is
  1; otherwise frozen_set[0] itself is returned.

  NOTE(review): a later cell in this file redefines polar_decoder_cq_output
  with a different signature (info_set, codeword) and a (bits, state) return
  value. After that cell runs, this five-argument version is shadowed.
  NOTE(review): the projected state assigned to q at the leaf is never
  returned, so the caller's state is not updated by the measurement.
  """
  q=quantum_state
  N=len(yi)
  f=frozen_set
  #pi=np.kron(np.outer(qubit(0),qubit(0)),np.eye(2**number_qubits//2))
  # projector onto |0> for a single qubit
  pi=np.outer(qubit(0),qubit(0))
  if N==1:
    if f[0]==1:
      p=measure_first_qubit(q,number_qubits,pi)
      # o is 0 with probability p and 1 otherwise
      o=int(np.random.random()>p)
      if o==0:
        q=apply_projection_first_qubit(q,number_qubits,pi)
        return o
      else:
        q=apply_projection_first_qubit(q,number_qubits,np.eye(2)-pi)
        return o
    else:
      return f[0]

  # check-node pass over consecutive pairs, then decode the first half
  yd1=[]
  for i in range(0,N,2):
    q,list1=cnop(q,yi[i],yi[i+1],yd[i],yd[i+1],number_qubits)
    yd1.append(list1)
  out1=polar_decoder_cq_output(yi[::2],yd1,q,number_qubits,f[:(N//2)])
  for i in range(0,N,2):
    q=cnop_inverse(q,yi[i],yi[i+1],yd[i],yd[i+1],number_qubits)

  # bit-node pass over the same pairs, then decode the second half
  yd2=[]
  for i in range(0,N,2):
    q,list2=vnop(q,yi[i],yi[i+1],yd[i],yd[i+1],number_qubits)
    yd2.append(list2)
  out2=polar_decoder_cq_output(yi[::2],yd2,q,number_qubits,f[N//2:])

  for i in range(0,N,2):
    q=vnop_inverse(q,yi[i],yi[i+1],yd[i],yd[i+1],number_qubits)

  out=np.hstack([out1,out2])
  return out


In [20]:
def polar_decoder_cq_out_avg_error(dp,gp,number_samples,number_qubits,frozen_set):
  """Average the decoder output over number_samples random product states.

  Every qubit is sampled from the eigen-decomposition of rhom(dp, gp, 0). The
  result is the per-position mean of polar_decoder_cq_output.

  NOTE(review): this calls polar_decoder_cq_output with five arguments and
  adds its return value directly, which matches the first definition in this
  file but not the later six-argument redefinition.
  """
  l,v=LA.eigh(rhom(dp,gp,0))
  leaf=[[],[[dp,gp]]]
  yi=np.array(range(number_qubits),dtype=int)
  # every position starts from the same (shared) leaf description
  yd=[leaf for _ in range(number_qubits)]

  h=np.zeros(number_qubits)
  for _ in range(number_samples):
    q=sample_joint_state(l,v,number_qubits)
    h=h+polar_decoder_cq_output(yi,yd,q,number_qubits,frozen_set)
  return h/number_samples

In [305]:
def polar_decoder_cq(yi,yd,quantum_state,number_qubits,frozen_set):
  """Recursive decoder that returns the sampled decisions and the updated state.

  At a leaf (len(yi) == 1) a decision is sampled from the |0> probability of
  the first qubit, and the state is then flipped on qubit 0 when
  frozen_set[0] equals 1. At an inner level the check-node pass decodes the
  first half of frozen_set and the bit-node pass the second half. Each pass
  is undone before continuing.

  Returns (out, q): the stacked decisions and the state after all passes.
  """
  q=quantum_state
  N=len(yi)
  f=frozen_set
  if N==1:
    zero_projector=np.outer(qubit(0),qubit(0))
    p1=measure_first_qubit(q,number_qubits,zero_projector)
    o=int(np.random.random()>p1)
    q=hard_decision_flip(q,number_qubits,f[0],0)
    return o,q

  half=N//2
  pairs=[(yi[i],yi[i+1],yd[i],yd[i+1]) for i in range(0,N,2)]

  # check-node pass
  yd1=[]
  for a,b,da,db in pairs:
    q,merged=cnop(q,a,b,da,db,number_qubits)
    yd1.append(merged)
  out1,q=polar_decoder_cq(yi[::2],yd1,q,number_qubits,f[:half])
  for a,b,da,db in pairs:
    q=cnop_inverse(q,a,b,da,db,number_qubits)

  # bit-node pass
  yd2=[]
  for a,b,da,db in pairs:
    q,merged=vnop(q,a,b,da,db,number_qubits)
    yd2.append(merged)
  out2,q=polar_decoder_cq(yi[::2],yd2,q,number_qubits,f[half:])
  for a,b,da,db in pairs:
    q=vnop_inverse(q,a,b,da,db,number_qubits)

  return np.hstack([out1,out2]),q


In [306]:
def polar_decoder_cq_avg_error(dp,gp,number_samples,number_qubits,codeword):
  """Estimate per-position error rates of polar_decoder_cq by sampling.

  codeword is encoded with polar_transform. Each qubit is sampled from the
  eigen-decomposition of rhom(dp, gp, 0), using the eigenvectors of the
  bit-flipped state for encoded 1 bits. codeword is also passed to the
  decoder as its frozen_set argument. The result is the fraction of samples
  in which each decoded position differs from codeword.
  """
  base=rhom(dp,gp,0)
  flip=np.array([[0,1],[1,0]])
  encoded_bits=polar_transform(np.array(codeword))
  l,v=LA.eigh(base)
  l1,v1=LA.eigh(flip@base@flip)
  leaf=[[],[[dp,gp]]]

  yi=np.array(range(number_qubits),dtype=int)
  yd=[leaf for _ in range(number_qubits)]

  h=np.zeros(number_qubits)
  for _ in range(number_samples):
    q=sample_joint_random_state(l,v,v1,number_qubits,encoded_bits)
    decoded,_state=polar_decoder_cq(yi,yd,q,number_qubits,codeword)
    h=h+np.mod(decoded+codeword,2)

  return h/number_samples

In [307]:
# Per-position error estimate for the all-zero input: delta = gamma = 0.1, 1000 samples, 8 qubits.
h=polar_decoder_cq_avg_error(0.1,0.1,1000,8,[0,0,0,0,0,0,0,0])

In [308]:
# Per-position error frequencies from the run above.
print(h)

[0.423 0.28  0.296 0.068 0.271 0.053 0.052 0.004]


In [267]:
# Population-based estimate of the synthesized-channel error rates after n
# polarization stages, for comparison with the decoder simulation.
dp=0.1
gp=0.1
n=3
# 20000 identical samples of the input channel parameters
d=np.ones(20000)*dp
g=np.ones(20000)*gp
p=polar(n,d,g)
#print(p.shape)
# mean over the samples; column 0 holds the delta averages (not used below)
ww = np.mean(p,axis=2)[:,0]
#print(ww)
#f=polar_design(ww,0.001)
# print('rate for code with given design constraint:',np.mean(f==0.5))
for i in range(2**n):
  # NOTE: d and g are rebound to scalars here, shadowing the arrays above
  d=np.mean(p[i][0])
  g=np.mean(p[i][1])
  r1=rhom(d,g,0)
  h=helstrom_success(r1,pauli(1))
  print(f'Error rate for channel {i+1}:', 1-h)

Error rate for channel 1: 0.41604365268962873
Error rate for channel 2: 0.27872502680798883
Error rate for channel 3: 0.2796547846448365
Error rate for channel 4: 0.08216633626864445
Error rate for channel 5: 0.27888671816715527
Error rate for channel 6: 0.07190103587830066
Error rate for channel 7: 0.058029348294409755
Error rate for channel 8: 0.0033242216626361287


In [244]:
# Manual walk-through on 4 qubits: eigenvectors of the channel state and of its bit-flipped version.
dp,gp=0.1,0.1
r=rhom(dp,gp,0)
x=np.array([[0,1],[1,0]])
l,v=LA.eigh(r)
l1,v1=LA.eigh(x@r@x)
# projector onto |0>
pi=np.outer(qubit(0),qubit(0))
# weight on |0> of the second eigenvector of each state
print(v[:,1]@pi@v[:,1])
print(v1[:,1]@pi@v1[:,1])
# NOTE: x is rebound from the flip matrix to the encoded bits
x=polar_transform(np.array([1,1,1,1]))
print(x)

0.014928749927334057
0.9850712500726662
[0 0 0 1]


In [245]:
# Product state: three qubits use v[:,1], the last uses v1[:,1], matching the encoded bits [0 0 0 1] printed above.
q=np.kron(np.kron(v[:,1],v[:,1]),np.kron(v[:,1],v1[:,1]))
# projector onto |0> on the first of the 4 qubits
pi1=np.kron(pi,np.eye(8))

In [246]:
# Two levels of check-node operations: pairs (0,1) and (2,3), then (0,2).
q1=apply_unitary([0,1],CU,q,4)
q2=apply_unitary([2,3],CU,q1,4)
q3=apply_unitary([0,2],CU,q2,4)
# weight of |0> on the first qubit
print(q3@pi1@q3)

0.05709342560553635


In [247]:
# bit flip on the first of the 4 qubits, as an explicit 16x16 matrix
flip1=np.kron(pauli(1),np.eye(8))

#q4=flip1@q3
# flip qubit 0, then undo the second-level check-node operation
q4=apply_flip(q3,4,0)
q5=apply_unitary_inverse([0,2],CU,q4,4)

In [248]:
# Check-node output parameters for the two measurement outcomes.
rc,pc=checknode(dp,dp,gp,gp)
dc1=rc[0][0]
dc2=rc[1][0]
gc1=rc[0][1]
gc2=rc[1][1]

# Bit-node operation on qubits (0,2), conditioned on qubits 1 and 3.
q6=apply_conditional_unitary([0,2],[1],[3],np.array([[dc1,gc1],[dc2,gc2]]),np.array([[dc1,gc1],[dc2,gc2]]),q5,4)
print(q6@pi1@q6)

0.004206649156410568


In [249]:
# Flip qubit 0, undo the conditional bit-node and first-level check-node
# operations, then run the bit-node branch followed by a check node on (0,2).
q7=flip1@q6
q7=apply_conditional_unitary_inverse([0,2],[1],[3],np.array([[dc1,gc1],[dc2,gc2]]),np.array([[dc1,gc1],[dc2,gc2]]),q7,4)


q8=apply_unitary_inverse([2,3],CU,q7,4)
q9=apply_unitary_inverse([0,1],CU,q8,4)
V=bitnode_gen_unitary(dp,dp,gp,gp)
q10=apply_unitary([0,1],V,q9,4)
q11=apply_unitary([2,3],V,q10,4)
q12=apply_unitary([0,2],CU,q11,4)
print(q12@pi1@q12)

0.002548782266514117


In [250]:
# Flip qubit 0, undo the check node on (0,2), then apply the bit-node
# operation conditioned on qubits 1 and 3 with bit-node output parameters.
# NOTE: q12 is overwritten in place, so re-running this cell flips it again.
q12=flip1@q12
q13=apply_unitary_inverse([0,2],CU,q12,4)
rb,pb=bitnode(dp,dp,gp,gp)
db1=rb[0][0]
db2=rb[1][0]
gb1=rb[0][1]
gb2=rb[1][1]
q14=apply_conditional_unitary([0,2],[1],[3],np.array([[db1,gb1],[db2,gb2]]),np.array([[db1,gb1],[db2,gb2]]),q13,4)
print(q14@pi1@q14)

7.633160611898096e-06


In [252]:
# Run the recursive decoder on the same 4-qubit state for comparison with the manual steps above.
# NOTE(review): the recorded output lists probabilities, whereas the
# polar_decoder_cq defined in this file returns sampled 0/1 decisions; the
# output appears to come from an earlier revision of the function.
list1=[[],[[dp,gp]]]
yi=np.array(range(4),dtype=int)
yd=[]
for i in range(4):
  yd.append(list1)
print(polar_decoder_cq(yi,yd,q,4,[1,1,1,1])[0])

[5.70934256e-02 4.20664916e-03 2.54878227e-03 7.63316061e-06]


In [273]:
def polar_decoder_cq_output(yi,yd,quantum_state,number_qubits,info_set,codeword):
  """Recursive decoder with measured information bits and known frozen bits.

  At a leaf (len(yi) == 1):
    * if info_set[0] is 1, the first qubit is measured in the computational
      basis; the state is projected on the observed outcome (without
      renormalisation) and flipped back when the outcome is 1;
    * otherwise the known bit codeword[0] is returned and the state is
      flipped on qubit 0 when that bit is 1.
  At an inner level the check-node pass decodes the first halves of info_set
  and codeword and the bit-node pass the second halves. Each pass is undone
  before continuing.

  Returns (out, q): the stacked decisions and the state after all passes.
  """
  q=quantum_state
  N=len(yi)
  f=info_set

  if N==1:
    if f[0]!=1:
      # frozen position: use the known bit
      q=hard_decision_flip(q,number_qubits,codeword[0],0)
      return codeword[0],q
    zero_projector=np.outer(qubit(0),qubit(0))
    p1=measure_first_qubit(q,number_qubits,zero_projector)
    o=int(np.random.random()>p1)
    if o==0:
      q=apply_projection_first_qubit(q,number_qubits,zero_projector)
      return o,q
    q=apply_projection_first_qubit(q,number_qubits,np.eye(2)-zero_projector)
    q=apply_flip(q,number_qubits,0)
    return o,q

  half=N//2
  pairs=[(yi[i],yi[i+1],yd[i],yd[i+1]) for i in range(0,N,2)]

  # check-node pass
  yd1=[]
  for a,b,da,db in pairs:
    q,merged=cnop(q,a,b,da,db,number_qubits)
    yd1.append(merged)
  out1,q=polar_decoder_cq_output(yi[::2],yd1,q,number_qubits,f[:half],codeword[:half])
  for a,b,da,db in pairs:
    q=cnop_inverse(q,a,b,da,db,number_qubits)

  # bit-node pass
  yd2=[]
  for a,b,da,db in pairs:
    q,merged=vnop(q,a,b,da,db,number_qubits)
    yd2.append(merged)
  out2,q=polar_decoder_cq_output(yi[::2],yd2,q,number_qubits,f[half:],codeword[half:])
  for a,b,da,db in pairs:
    q=vnop_inverse(q,a,b,da,db,number_qubits)

  return np.hstack([out1,out2]),q


In [279]:
def polar_decoder_cq_output_avg_error(dp,gp,number_samples,number_qubits,info_set,codeword):
  """Estimate bit and block error rates of polar_decoder_cq_output by sampling.

  codeword is encoded with polar_transform. Each qubit is sampled from the
  eigen-decomposition of rhom(dp, gp, 0), using the eigenvectors of the
  bit-flipped state for encoded 1 bits.

  Returns (bit_error, block_error):
        bit_error:   per-position fraction of samples decoded differently from codeword
        block_error: fraction of samples with at least one wrong position
  """
  base=rhom(dp,gp,0)
  flip=np.array([[0,1],[1,0]])
  encoded_bits=polar_transform(np.array(codeword))
  l,v=LA.eigh(base)
  l1,v1=LA.eigh(flip@base@flip)
  leaf=[[],[[dp,gp]]]

  yi=np.array(range(number_qubits),dtype=int)
  yd=[leaf for _ in range(number_qubits)]

  h=np.zeros(number_qubits)
  block=0

  for _ in range(number_samples):
    q=sample_joint_random_state(l,v,v1,number_qubits,encoded_bits)
    decoded,_state=polar_decoder_cq_output(yi,yd,q,number_qubits,info_set,codeword)
    mismatches=np.mod(decoded+codeword,2)
    h=h+mismatches
    if np.sum(mismatches)>=1:
      block=block+1

  return h/number_samples,block/number_samples

In [342]:
# 8-qubit run: delta=0.05, gamma=0.15, 500 samples; positions 3, 5, 6, 7 are information bits.
h,block=polar_decoder_cq_output_avg_error(0.05,0.15,500,8,[0,0,0,1,0,1,1,1],[1,1,1,1,1,1,0,1])

In [343]:
# Per-position error frequencies for the 8-qubit run.
print(h)

[0.    0.    0.    0.012 0.    0.014 0.06  0.002]


In [344]:
# Block error frequency for the 8-qubit run.
print(block)

0.064


In [336]:
# 4-qubit run: delta=0.05, gamma=0.15, 1000 samples; positions 2 and 3 are information bits.
h,block=polar_decoder_cq_output_avg_error(0.05,0.15,1000,4,[0,0,1,1],[1,0,1,0])

In [337]:
# Per-position error frequencies for the 4-qubit run.
print(h)

[0.    0.    0.061 0.049]


In [338]:
# Block error frequency for the 4-qubit run.
print(block)

0.075
