In [20]:
import numpy as np
from numpy import linalg as LA
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from scipy.linalg import sqrtm
from numpy.linalg import matrix_rank
from jax.numpy import linalg as jla
from scipy.linalg import logm, expm

from scipy.linalg import block_diag
import copy
import numba
from numba import int64, float64, jit, njit, vectorize
import matplotlib.pyplot as plt

In [21]:
# Seed NumPy's global RNG. With SEED = None the seed is drawn from OS entropy,
# so runs are NOT reproducible; set SEED to a fixed integer for deterministic runs.
# `dither` is referenced only by a commented-out hard_dec_rr variant (in the code
# visible here); the active hard_dec_rr is deterministic and ignores it.
import os
import sys

SEED = None
random_bytes = SEED if SEED is not None else int.from_bytes(os.urandom(4), sys.byteorder)

np.random.seed(random_bytes)
dither = np.random.randint(0, 2, 10)

In [22]:
def qubit(x):
  '''
    Return the computational basis state |x> as a length-2 integer vector.

          Arguments:
                  x : 0 or 1

          Returns:
                  state |0> = [1, 0] or |1> = [0, 1]

          Raises:
                  ValueError: if x is neither 0 nor 1
  '''
  if x == 0:
    return np.array([1, 0])
  if x == 1:
    return np.array([0, 1])
  # Fail loudly instead of returning None, which only breaks later (e.g. in np.kron).
  raise ValueError(f'qubit expects 0 or 1, got {x!r}')

def pauli(x):
    '''
    Return a Pauli matrix.

          Arguments:
                  x : {0,1,2,3}

          Returns:
                 Pauli Matrices: {I,X,Y,Z} (Y is complex, the others integer)

          Raises:
                 ValueError: if x is not one of 0, 1, 2, 3
    '''
    if x == 0:
        return np.array([[1, 0], [0, 1]])
    if x == 1:
        return np.array([[0, 1], [1, 0]])
    if x == 2:
        return np.array([[0, -1j], [+1j, 0]])
    if x == 3:
        return np.array([[1, 0], [0, -1]])
    # Previously this printed 'invalid' and returned None, which surfaced later
    # as an unrelated TypeError at the matrix product.
    raise ValueError(f'invalid Pauli index {x!r}; expected 0, 1, 2 or 3')
def rhom(delta,g,x):
  '''
   Build the 2x2 matrix rho = [[delta, g], [g, 1-delta]] and conjugate it by
   pauli(x), i.e. return P rho P with P = pauli(x). For x == 1 (Pauli X) this
   swaps the two diagonal entries; for x == 0 rho is returned unchanged.

       Arguments:
                delta(float): delta parameter of the BSCQ channel
                g(float): gamma parameter of the BSCQ channel
                x(0 or 1): 1 indicates applying Pauli X on the density matrix

       Returns:
                (float[:,:]): density matrix
  '''
  conjugator = pauli(x)
  rho = np.array([[delta, g], [g, 1 - delta]])
  left = conjugator @ rho
  return left @ conjugator
@njit
def rhom_jit(delta:float64,g:float64) -> (float64[:,:]):
    '''
       Numba-compiled counterpart of rhom without the Pauli option: build the
       float64 matrix [[delta, g], [g, 1-delta]].

          Arguments:
                delta(float64): delta parameter of the BSCQ channel
                g(float64): gamma parameter of the BSCQ channel

          Returns:
                (float64[:,:]): density matrix
    '''
    rho = np.empty((2, 2), dtype=np.float64)
    rho[0, 0] = delta
    rho[0, 1] = g
    rho[1, 0] = g
    rho[1, 1] = 1 - delta
    return rho
@njit('(int64[:])(int64[:])') # Input/output specifications to make Numba work
def polar_transform(u):
    '''
    Encode polar information vector u (recursive polar transform over GF(2))

          Arguments:
                  u (int64[:]): Numpy array of input bits; the length is halved at
                                every level, so presumably a power of two

          Returns:
                  x (int64[:]): Numpy array of encoded bits. The first half is the
                                transform of (u[0::2] XOR u[1::2]) and the second half
                                is the transform of u[1::2]. A length-1 input is
                                returned as-is.
    '''
    n = len(u)
    if n == 1:
        return u
    half = n // 2
    evens = u[0::2]
    odds = u[1::2]
    # Every slot is overwritten below, so no zero-initialisation is needed.
    x = np.empty(n, dtype=np.int64)
    x[:half] = polar_transform((evens + odds) % 2)
    x[half:] = polar_transform(odds)
    return x

@njit
def polar_design(biterrd,d):
    '''
    Design polar code from additive channel noise scores

    Channels are ranked from least to most noisy. The longest prefix of that
    ranking whose running noise total stays strictly below d is declared
    information bits.

          Arguments:
                  biterrd (float[:]): Numpy array of channel noise scores (e.g., error rates)
                  d (float): Sum constraint of total noise for unfrozen channels

          Returns:
                  f (float[:]): Numpy array indicating frozen bits by 0 and info bits by 0.5
    '''
    ranking = np.argsort(biterrd)
    running_total = np.cumsum(biterrd[ranking])

    # Number of least-noisy channels that fit under the budget -> information bits
    n_info = np.sum(running_total < d)
    f = np.zeros(len(biterrd))
    f[ranking[:n_info]] = 0.5
    return f



# Check-node operation in P1 domain
#   For two independent bits with P1 equal to w1,w2, return the P1 of their XOR
#   (i.e. the probability of odd parity)
@vectorize([float64(float64,float64)],nopython=True)
def cnop(w1,w2):
    only_first = w1*(1-w2)
    only_second = w2*(1-w1)
    return only_first + only_second

# Bit-node operation in P1 domain
#   For two independent P1 observations (w1,w2) of a uniform bit, return the
#   posterior P1 of the bit
@vectorize([float64(float64,float64)],nopython=True)
def vnop(w1,w2):
    both_one = w1*w2
    both_zero = (1-w1)*(1-w2)
    return both_one / (both_one + both_zero)

# Hard decision in P1 domain
#   Return a hard MAP decision for a P1 observation. The active hard_dec_rr below is
#   deterministic: ties at 0.5 resolve to 1. The commented-out variants used
#   randomized tie breaking.

# @vectorize([int64(float64)],nopython=True)
# def hard_dec_rr(w):
#     if w==0.5:
#       return 1
#     else:
#       return np.int64((1-np.sign(1-2*w)>2*np.random.random_sample(1)).all())

# @vectorize([int64(float64)],nopython=True)
# def hard_dec_rr(w):
#     return np.int64(((w+2e-12*np.random.random_sample(1))>0.5+1e-12).all())


@vectorize([int64(float64)],nopython=True)
def hard_dec_rr(w):
    # Deterministic hard decision on a P1 value. The result is 1 when
    # w + 2e-12 > 0.5 + 1e-12 (so an exact tie at 0.5 maps to 1) and 0 otherwise.
    # The "rr" (randomized rounding) name is left over from the commented-out variants.
    above_threshold = (w + 2e-12) > (0.5 + 1e-12)
    return np.int64(above_threshold)



# @vectorize([int64(float64)],nopython=True)
# def hard_dec_rr(w):
#     return np.int64(((w+2e-12*dither[0])>0.5+1e-12).all())
# def hard_dec(s:float64)->int64:
#   if abs(s-0.5)<10**(-5):
#     o=int(np.random.random()>0.5)
#     return o
#   else:
#     return np.round(s)

@njit # Input/output specifications below to make Numba work
def polar_decode(y: float64[:],f: float64[:]) -> (int64[:],int64[:]):
    '''
    Recursive successive cancellation polar decoder from P1 observations

          Arguments:
                  y (float[:]): channel observations (P1 values) in output order;
                                the length is halved at every level, so presumably
                                a power of two
                  f (float[:]): input a priori values in input order. 0.5 marks an
                                information bit; any other value is used directly as
                                the frozen bit value (0 or 1 in compute_error_pattern)

          Returns:
                  u (int[:]): input hard decisions in input order
                  x (int[:]): re-encoded hard decisions in output order
    '''
    # Recurse down to length 1
    N = len(y)
    if (N==1):

        # Hard decision on the single observation (hard_dec_rr maps a tie at 0.5 to 1).
        # It is computed for both branches but only used for information bits.
        x = hard_dec_rr(y)

        if (f[0]==1/2):
            # Information bit: use the hard decision as both u and x
            return x, x.copy()
        else:
            # Frozen bit: (u,x) = (f,f), ignoring the observation

            return f.astype(np.int64),f.astype(np.int64)
    else:
        # Check-node step: P1 of the XOR of each (even, odd) pair of observations
        u1est = cnop(y[::2],y[1::2])
        # Decode the first half of the inputs from the XOR messages
        uhat1, u1hardprev = polar_decode(u1est,f[:(N//2)])

        # Bit-node step: cnop with the re-encoded hard bits flips the even
        # observation wherever that bit is 1; then combine with the odd observation
        u2est = vnop(cnop(u1hardprev,y[::2]),y[1::2])
        # Decode the second half of the inputs
        uhat2, u2hardprev = polar_decode(u2est,f[(N//2):])
    # Pass u decisions up and interleave x1,x2 hard decisions
    #   note: Numba doesn't like np.concatenate
    u = np.zeros(N,dtype=np.int64)
    u[:(N//2)] = uhat1
    u[(N//2):] = uhat2
    # Re-encode: even outputs are x1 XOR x2 (cnop on 0/1 values is XOR), odd outputs are x2
    x1 = cnop(u1hardprev,u2hardprev)
    x2 = u2hardprev
    x = np.zeros(N,dtype=np.int64)
    x[::2] = x1
    x[1::2] = x2

    return u, x


def classical_cnop(y1,y2):
  # Check-node rule in P1 domain: P1 of the XOR of two independent bits
  first_only = y1*(1-y2)
  second_only = y2*(1-y1)
  return first_only + second_only
def classical_vnop(y1,y2):
  # Bit-node rule in P1 domain: posterior P1 of a uniform bit seen through two
  # independent observations
  agree_one = y1*y2
  agree_zero = (1-y1)*(1-y2)
  return agree_one/(agree_one+agree_zero)



def hard_dec(s):
  # Hard decision on a P1 value: values within 1e-5 of 0.5 resolve to the int 1,
  # everything else is rounded with np.round (float result).
  near_tie = abs(s-0.5) < 10**(-5)
  return 1 if near_tie else np.round(s)

def hard_dec_vec(s):
  # Apply hard_dec_rr to every entry of s and collect the 0/1 decisions in a
  # float array of length len(s).
  decisions = np.zeros(len(s))
  for pos, value in enumerate(s):
    decisions[pos] = hard_dec_rr(value)
  return decisions



def message_permute(msg,source, des):
  '''
  Reorder a conditional message so that its qubit axes follow the order in `des`.

  msg is viewed as a tensor with one binary axis per label in `source`. The
  first label is the most significant bit of the flat index. The axes are
  rearranged so that they appear in the order given by `des`.

  Arguments
          msg: message probabilities 1D numpy array of length 2**len(source)
          source: labels (frozen-qubit indices) of the current axes; when its
                  shape differs from des it is unwrapped with source[0]
                  (e.g. a 1 x k array)
          des:    the same labels in the desired order

  Returns
         numpy array with the same shape as msg holding the permuted probabilities
  '''
  if np.shape(source) != np.shape(des):
    source = source[0]

  n_axes = len(source)
  # For every target slot, find the current axis that carries that label.
  origin = [np.argwhere(label == source)[0, 0] for label in des]
  tensor = np.reshape(msg, 2 * np.ones(n_axes, dtype=int))
  tensor = np.moveaxis(tensor, origin, np.arange(n_axes))
  return np.reshape(tensor, np.shape(msg))





def cnop_vec(msg1,msg2,conditional_list1,conditional_list2):

  '''
   Check-node combine two messages that are conditioned on frozen qubits.

   Each message holds one P1 value per assignment of the frozen qubits it is
   conditioned on (big-endian over its conditioning list). The result is
   conditioned on the union of both lists. For every joint assignment, the two
   matching entries are combined with classical_cnop.

   Arguments:

         msg1: first message vector (1D numpy array)
         msg2: second message vector (1D numpy array)
         conditional_list1: list of indices of frozen qubits first message vector is conditioned on
         conditional_list2: list of indices of frozen qubits second message vector is conditioned on


   Returns:
         msg: checknode combined message (sorted 1D numpy array)
         ordered_list: sorted list of indices of frozen qubits (in ascending order)

  '''
  shared = np.intersect1d(conditional_list1, conditional_list2)
  only1 = np.setdiff1d(conditional_list1, shared)
  only2 = np.setdiff1d(conditional_list2, shared)

  # Put the shared qubits first (most significant) in both messages.
  aligned1 = message_permute(msg1, conditional_list1, np.concatenate([shared, only1]))
  aligned2 = message_permute(msg2, conditional_list2, np.concatenate([shared, only2]))

  n_shared = 2**len(shared)
  n_only1 = 2**len(only1)
  n_only2 = 2**len(only2)
  # Entries come out laid over [shared, only1, only2], big-endian.
  combined = np.array([
      classical_cnop(aligned1[s*n_only1 + a], aligned2[s*n_only2 + b])
      for s in range(n_shared)
      for a in range(n_only1)
      for b in range(n_only2)
  ])

  union_list = np.concatenate([shared, only1, only2])
  ordered_list = np.sort(union_list)
  return message_permute(combined, union_list, ordered_list), ordered_list


def vnop_vec(msg1,msg2,conditional_list1,conditional_list2):
  '''
   Bit-node combine two messages that are conditioned on frozen qubits.

   Each message holds one P1 value per assignment of the frozen qubits it is
   conditioned on (big-endian over its conditioning list). The result is
   conditioned on the union of both lists. For every joint assignment, the two
   matching entries are combined with classical_vnop.

   Arguments:

         msg1: first message vector (1D numpy array)
         msg2: second message vector (1D numpy array)
         conditional_list1: list of indices of frozen qubits first message vector is conditioned on
         conditional_list2: list of indices of frozen qubits second message vector is conditioned on


   Returns:
         msg: bitnode combined message (sorted 1D numpy array)
         ordered_list: sorted list of indices of frozen qubits (in ascending order)

  '''
  shared = np.intersect1d(conditional_list1, conditional_list2)
  only1 = np.setdiff1d(conditional_list1, shared)
  only2 = np.setdiff1d(conditional_list2, shared)

  # Put the shared qubits first (most significant) in both messages.
  aligned1 = message_permute(msg1, conditional_list1, np.concatenate([shared, only1]))
  aligned2 = message_permute(msg2, conditional_list2, np.concatenate([shared, only2]))

  n_shared = 2**len(shared)
  n_only1 = 2**len(only1)
  n_only2 = 2**len(only2)
  # Entries come out laid over [shared, only1, only2], big-endian.
  combined = np.array([
      classical_vnop(aligned1[s*n_only1 + a], aligned2[s*n_only2 + b])
      for s in range(n_shared)
      for a in range(n_only1)
      for b in range(n_only2)
  ])

  union_list = np.concatenate([shared, only1, only2])
  ordered_list = np.sort(union_list)
  return message_permute(combined, union_list, ordered_list), ordered_list

def bin_array(k,w):
  # Convert the first w characters of a binary string (e.g. np.binary_repr output)
  # into a float array of 0.0/1.0 values.
  digits = [int(k[pos]) for pos in range(w)]
  out = np.zeros(w)
  out[:] = digits
  return out
def bin2dec(b):
  # Interpret b as a big-endian sequence of bits and return its integer value
  # (Horner evaluation: shift left, then add the next bit).
  value = 0
  for bit in b:
    value = value*2 + bit
  return value

def bit_reversal(v,n):
  '''
   Reverse the n-bit binary representation of v.

   Arguments:
           v: decimal value
           n: number of bits for the binary representation of the decimal value

    Returns:
           decimal value of the bit-reversed number (e.g. v=1, n=3: 001 -> 100 = 4)
  '''
  bits = np.array(bin_array(np.binary_repr(v, n), n), dtype=int)
  return bin2dec(bits[::-1])
def sample_qubit_sequence(seq):
  '''
  Build the computational-basis state |seq[0] seq[1] ... seq[-1]> as an integer
  vector of length 2**len(seq). This equals np.kron(qubit(seq[0]), qubit(seq[1]), ...),
  so seq[0] is the most significant bit of the index. Any nonzero entry is
  treated as 1.

  Arguments:
       seq: binary sequence to determine the joint state (must be non-empty)

   returns:
         joint quantum state 2^(len(seq)) 1D array with a single 1
  '''
  # seq[0] is read first so an empty sequence still raises IndexError.
  index = 0 if seq[0] == 0 else 1
  for bit in seq[1:]:
    index = 2*index + (0 if bit == 0 else 1)
  state = np.zeros(2**len(seq), dtype=int)
  state[index] = 1
  return state
def conditional_frozen_unitary(hard_dec,frozen_set_length):

  '''
  Build the frozen-qubit-controlled unitary for one information qubit.

  The result is block diagonal, sum_i |i><i| (x) M_i. Here i runs over the
  assignments of the frozen qubits, read as a big-endian binary number, and
  M_i acts on the information qubit: M_i = I if hard_dec[i] == 0, otherwise
  Pauli X.

   Arguments:

          hard_dec: hard-decision message vector conditioned on the frozen qubits,
                    one entry per assignment (2**frozen_set_length entries)
          frozen_set_length: number of frozen qubits the unitary depends on;
                             0 yields a plain 2x2 I or X

    returns:
           float unitary of dimension (2^(frozen_set_length+1) x 2^(frozen_set_length+1))
  '''
  l = frozen_set_length
  identity = np.eye(2)
  flip = np.array([[0.0, 1.0], [1.0, 0.0]])  # Pauli X
  u = np.zeros((2**(l+1), 2**(l+1)))
  # kron(|i><i|, M) places M in diagonal block i. Writing the block directly
  # avoids building 2**l outer products. It also handles l == 0: there
  # np.binary_repr(0, 0) returns '0', which produced an empty basis sequence
  # and an IndexError.
  for i in range(2**l):
    u[2*i:2*i+2, 2*i:2*i+2] = identity if hard_dec[i] == 0 else flip
  return u


def compute_permutation(number_qubits,information_qubit,frozen_list):
  '''
   Compute the axis permutation (and its inverse) that moves the frozen control
   qubits and then the information qubit to the front of the qubit tensor.

     Arguments:
            number_qubits: number of qubits present in the decoding
            information_qubit: array whose first entry is the index of the information qubit
            frozen_list: list of indices of frozen qubits the information qubit is dependent on

      Returns:
            p: permutation, p[j] = new position of qubit j. frozen_list[k] -> k,
               information_qubit[0] -> len(frozen_list), and the remaining qubits
               follow in ascending index order.
            pi: inverse permutation (pi[p[j]] == j) to flip back to the original indices

      Raises:
            ValueError: if the inputs do not describe a permutation (e.g. repeated
                        or overlapping qubit indices)
  '''
  n = number_qubits
  l = len(frozen_list)

  p = np.zeros(n, dtype=int)
  p[information_qubit[0]] = l
  for j in range(l):
    p[frozen_list[j]] = j

  # Qubits that are neither controls nor the target keep their relative order.
  excluded = {int(v) for v in information_qubit} | {int(v) for v in frozen_list}
  next_pos = l + 1
  for j in range(n):
    if j not in excluded:
      p[j] = next_pos
      next_pos += 1

  if not np.array_equal(np.sort(p), np.arange(n)):
    raise ValueError('frozen_list and information_qubit must be distinct qubit indices in range(number_qubits)')

  # Inverse permutation directly (no n x n permutation matrix needed).
  pi = np.empty(n, dtype=int)
  pi[p] = np.arange(n)
  return p, pi

def apply_frozen_unitary(q_state,number_qubits,unitary,information_qubit,frozen_list):
  '''
      Apply a unitary that acts on the frozen (control) qubits together with the
      information qubit of a joint state vector.

          Arguments:
                q_state(float[:]): joint qubit state with 2**number_qubits amplitudes
                number_qubits: number of qubits present in the decoding
                unitary: (2**l x 2**l) matrix, l = len(frozen_list) + len(information_qubit).
                         It acts on the frozen qubits (in frozen_list order) followed
                         by the information qubit.
                information_qubit: index of the information qubit (1-element array)
                frozen_list: list of indices of frozen qubits the information qubit is dependent on

         Returns:
                 q_out(float[:]): final qubit state after applying the unitary, same shape as q_state
  '''
  active = len(np.concatenate([frozen_list, information_qubit]))
  qubit_axes = 2 * np.ones(number_qubits, dtype=int)
  all_axes = np.array(range(number_qubits))

  tensor = np.reshape(q_state, qubit_axes)
  perm, inv_perm = compute_permutation(number_qubits, information_qubit, frozen_list)

  # Bring the active qubits to the front, fold them into the row index and apply the unitary.
  tensor = np.moveaxis(tensor, all_axes, perm)
  transformed = unitary @ np.reshape(tensor, (2**active, 2**(number_qubits - active)))

  # Undo the axis permutation and restore the caller's shape.
  restored = np.moveaxis(np.reshape(transformed, qubit_axes), all_axes, inv_perm)
  return np.reshape(restored, np.shape(q_state))


def polar_compression(q_state,yi,msg,yd,frozen_set,code_length):
  '''
       Quantum successive cancellation pass over a binary polar code. It recovers
       the complete quantum state from its compressed version.

       Classical P1 messages, each conditioned on a set of frozen qubits, are
       combined in the same pattern as polar_decode, using cnop_vec / vnop_vec
       in place of cnop / vnop. At each information position the conditional
       messages are hard-decided, and a frozen-qubit-controlled X unitary is
       applied to q_state.

       Arguments:
               q_state(float[:]): joint quantum state vector; apply_frozen_unitary
                                  treats it as a state of 2**code_length qubits
               yi(int[:]): qubit indices (output order) of the code bits handled by this call
               msg(List(float[:])): conditional P1 messages, one per entry of yi;
                                    msg[i] holds 2**len(yd[i]) values
               yd(List(int[:])): for each message, the indices of the frozen qubits it is conditioned on
               frozen_set(int[:]): mask over the input positions of this sub-code,
                                   1 = information qubit, anything else = frozen qubit
               code_length(int): n = log2 of the full code length; used for bit
                                 reversal and for the qubit count 2**n

        Returns:
               q_state: quantum state after all unitaries of this sub-code have been applied
               msg4: hard-decision messages passed back up, interleaved in output order
                     (at a leaf: a 1-element array wrapping the single message)
               yd4: frozen-qubit index lists that each returned message is conditioned on
                    (at a leaf: a 1-element array wrapping the single list)
  '''

  N=len(yi)
  f=frozen_set
  lc=code_length
  if N==1:
    if f[0]==1:
      # Information qubit: hard-decide the message for every assignment of the
      # frozen qubits it is conditioned on ...
      h=hard_dec_vec(msg[0])
      # ... then apply X to it, controlled on those frozen qubits, exactly for
      # the assignments whose decision is 1
      u=conditional_frozen_unitary(h,len(yd[0]))
      # bit_reversal(yi[0], lc) selects the qubit the unitary acts on
      information_qubit=np.array([bit_reversal(yi[0],lc)])
      q_state=apply_frozen_unitary(q_state,2**lc,u,information_qubit,np.array(yd[0],dtype=int))
      return q_state,np.array([h]),np.array([yd[0]])

    else:
      # Frozen qubit: q_state is unchanged. Its message is [0,1], conditioned on
      # the frozen qubit itself, i.e. P1 equals that qubit's value.

      return q_state,np.array([[0,1]]),np.array([bit_reversal(yi[0],lc)])

  # Check-node combine each (even, odd) pair -> messages for the first half
  msg1=[]
  yd1=[]
  for i in range(0,N,2):

    msg11,list1=cnop_vec(msg[i],msg[i+1],yd[i],yd[i+1])
    msg1.append(msg11)
    yd1.append(list1)

  # Decode the first half (yi[::2]).
  # NOTE(review): np.array() on a list of arrays with unequal lengths raises in
  # NumPy >= 1.24. This presumably relies on all messages at one level having
  # equal length — confirm for other frozen sets.
  q_state,new_msg1,new_list1=polar_compression(q_state,yi[::2],np.array(msg1),np.array(yd1),f[:(N//2)],lc)

  # Combine the first-half decisions with the even messages; cnop with a 0/1
  # decision flips the message wherever the decided bit is 1
  msg2=[]
  yd2=[]

  for i in range(0,N,2):
    msg21,list2=cnop_vec(msg[i],new_msg1[i//2],yd[i],np.array([new_list1[i//2]]))
    msg2.append(msg21)
    yd2.append(list2)

  # Bit-node combine with the odd messages -> messages for the second half
  msg3=[]
  yd3=[]
  for i in range(0,N,2):
    msg31,list3=vnop_vec(msg[i+1],msg2[i//2],yd[i+1],yd2[i//2])
    msg3.append(msg31)
    yd3.append(list3)


  # Decode the second half (yi[1::2]); unlike the first call, msg3/yd3 are passed as lists
  q_state,new_msg2,new_list2=polar_compression(q_state,yi[1::2],msg3,yd3,f[(N//2):],lc)

  # Re-encode the decisions for the caller, interleaved in output order:
  #   even position = first-half XOR second-half decision,
  #   odd position  = second-half decision (cnop with a constant 0 message leaves it unchanged)
  msg4=[]
  yd4=[]
  for i in range(0,N,2):
    msg41,list4=cnop_vec(new_msg1[i//2],new_msg2[i//2],np.array([new_list1[i//2]]),np.array([new_list2[i//2]]))
    msg42,list42=cnop_vec(np.array([0]),new_msg2[i//2],np.array([]),np.array([new_list2[i//2]]))
    msg4.append(msg41)
    msg4.append(msg42)
    yd4.append(list4)
    yd4.append(list42)

  return q_state,msg4,yd4


In [23]:
'''
Code to verify decompression of quantum state using polar code works

'''


def compute_error_pattern(frozen_set,frozen_pattern,chan_err):
  '''
    Run the classical SC decoder (polar_decode) on a constant channel observation,
    with the frozen positions fixed to frozen_pattern.

    Arguments
        frozen_set: 0/1 mask over input positions, 0 = frozen, 1 = information
        frozen_pattern: values of the frozen bits, in ascending position order
        chan_err: P1 value used for every channel observation

    returns
         u: complete sequence in input order
         polar_transform(u): complete sequence in the output order
  '''
  frozen_pos = np.where(frozen_set == 0)[0]
  info_pos = np.where(frozen_set == 1)[0]
  n = len(frozen_pos) + len(info_pos)

  # A priori vector for polar_decode: the frozen value at frozen positions,
  # 0.5 to mark information bits.
  priors = np.zeros(n)
  for k, pos in enumerate(frozen_pos):
    priors[pos] = frozen_pattern[k]
  priors[info_pos] = 0.5

  u, _ = polar_decode(chan_err * np.ones(n), priors)
  return u, polar_transform(u)


def compute_error_pattern_classical(frozen_set,chan_err):
  '''
    Run compute_error_pattern for every possible assignment of the frozen bits.

    Arguments
        frozen_set: 0/1 mask over input positions, 0 = frozen, 1 = information
        chan_err: P1 value used for every channel observation

    returns
         o: list of [frozen_pattern, (u, polar_transform(u))] pairs, one per pattern,
            with patterns enumerated in increasing binary order
  '''
  n_frozen = len(np.where(frozen_set == 0)[0])
  patterns = (np.array(bin_array(np.binary_repr(idx, n_frozen), n_frozen), dtype=int)
              for idx in range(2**n_frozen))
  return [[pattern, compute_error_pattern(frozen_set, pattern, chan_err)] for pattern in patterns]


def compute_syndrome_pattern(frozen_set,frozen_pattern,chan_err,code_length):
  '''
    Run the quantum decoder (polar_compression) on the basis state whose frozen
    qubits hold frozen_pattern and whose information qubits are |0>. Then read
    off the resulting basis state.

    Arguments
        frozen_set: 0/1 mask over qubit positions, 0 = frozen, 1 = information
        frozen_pattern: values of the frozen qubits, in ascending position order
        chan_err: P1 value used as the unconditioned message of every qubit
        code_length: length (N) of the polar code, given as n = log2(N)

    returns
         u: binary representation (len(frozen_set) bits) of the index of the first
            nonzero amplitude of the decoded state, in the input order
         polar_transform(u): the same sequence in the output order
  '''
  n = len(frozen_set)
  yi = np.array(range(n))
  msg = [np.array([chan_err]) for _ in range(n)]
  yd = [np.array([]) for _ in range(n)]

  # Frozen qubits take their pattern value; information qubits stay 0.
  bits = np.zeros(n)
  for k, pos in enumerate(np.where(frozen_set == 0)[0]):
    bits[pos] = frozen_pattern[k]

  start_state = sample_qubit_sequence(np.array(bits))
  decoded = polar_compression(start_state, yi, msg, yd, frozen_set, code_length)[0]
  first_index = np.nonzero(decoded)[0][0]
  u = np.array(bin_array(np.binary_repr(first_index, n), n), dtype=int)
  return u, polar_transform(u)

def compute_error_pattern_quantum(frozen_set,chan_err,code_length):
  '''
    Run compute_syndrome_pattern for every possible assignment of the frozen qubits.

    Arguments
        frozen_set: 0/1 mask over qubit positions, 0 = frozen, 1 = information
        chan_err: P1 value used as the message of every qubit
        code_length: length (N) of the polar code, given as n = log2(N)

    returns
         o: list of [frozen_pattern, (u, polar_transform(u))] pairs, one per pattern,
            with patterns enumerated in increasing binary order
  '''
  n_frozen = len(np.where(frozen_set == 0)[0])
  results = []
  for idx in range(2**n_frozen):
    pattern = np.array(bin_array(np.binary_repr(idx, n_frozen), n_frozen), dtype=int)
    outcome = compute_syndrome_pattern(frozen_set, pattern, chan_err, code_length)
    results.append([pattern, outcome])
  return results


Test Cases
Length-16 polar code with a pre-defined frozen set (0 = frozen qubit, 1 = information qubit).
For each frozen pattern, the classical output sequences should match the quantum outputs.

In [24]:
# Same length-16 mask (0 = frozen, 1 = information) for both decoders; P1 = 0.9, n = log2(16) = 4
frozen_mask_16 = np.array([0,1,0,1,0,1,0,1,0,1,0,1,1,1,1,1])
o = compute_error_pattern_classical(frozen_mask_16, 0.9)
o1 = compute_error_pattern_quantum(frozen_mask_16, 0.9, 4)

In [25]:
# Spot-check frozen pattern #23: the classical and quantum results should agree
for decoded in (o[23], o1[23]):
  print(decoded)

[array([0, 1, 0, 1, 1, 1]), (array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]), array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]))]
[array([0, 1, 0, 1, 1, 1]), (array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]), array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]))]


Verify that decoding works on a superposition of basis states

In [26]:
#testing for quantum: decode an equal superposition of three basis states (N = 8, n = 3)
yi=np.array(range(8))
msg=[]
yd=[]
for i in range(8):
  msg.append(np.array([0.9]))  # unconditioned P1 message for every qubit
  yd.append(np.array([]))      # no frozen qubits conditioned on yet

superposed_inputs=[[1,0,0,0,1,0,1,0],[1,0,1,0,1,0,1,0],[1,0,1,0,0,0,0,0]]
q=(1/np.sqrt(3))*sum(sample_qubit_sequence(np.array(bits)) for bits in superposed_inputs)
p=polar_compression(q,yi,msg,yd,np.array([0,1,0,1,0,1,0,1]),3)
print(np.nonzero(p[0]))

(array([206, 241, 254]),)


In [27]:
#verifying with individual states: decode each basis state of the superposition on its own
for bits in ([1,0,0,0,1,0,1,0], [1,0,1,0,1,0,1,0], [1,0,1,0,0,0,0,0]):
  q1=sample_qubit_sequence(np.array(bits))
  p1=polar_compression(q1,yi,msg,yd,np.array([0,1,0,1,0,1,0,1]),3)
  print(np.nonzero(p1[0]))

(array([206]),)
(array([254]),)
(array([241]),)
