In [31]:
import sys
sys.path.append('../')
from BoundaryMPS import BoundaryMPS
import numpy as np
from scipy.linalg import block_diag

In [32]:
def solver(H, s):
    '''
    get pure error give syndrome and parity check matrix
    '''
    columns = []
    for i in range(H.shape[0]):
        for j in range(H.shape[1]):
            if H[i, j] == 1:
                columns.append(j)
                break

    H_reduce = H[:, columns]
    assert H_reduce.shape == (len(s), len(s))
    error_columns = s.copy()
    for i in range(len(s)):
        for j in range(i+1, len(s)):
            if H_reduce[j, i] == 1:
                H_reduce[j] = H_reduce[i] - H_reduce[j]
                error_columns[j] = abs(error_columns[i] - error_columns[j])

    for i in range(len(s)-1, -1, -1):
        for j in range(i-1, -1, -1):
            if H_reduce[j, i] == 1:
                H_reduce[j][i:] = H_reduce[i][i:] - H_reduce[j][i:]
                error_columns[j] = abs(error_columns[i] - error_columns[j])
    
    pure_error = np.array(
        [0 if i not in columns else error_columns[columns.index(i)] 
        for i in range(H.shape[1])]
    )

    return pure_error

def enumerate_config(shape):
    assert len(shape) == 4
    configs = np.array([
        [i, j, k, l] 
        for i in range(shape[0])
        for j in range(shape[1])
        for k in range(shape[2])
        for l in range(shape[3])
    ])
    return configs

class SurfaceCode:
    """
    k=1 surface code for testing
    """
    def __init__(self, d, error_prob:np.ndarray) -> None:
        self.d = d
        self.L = 2 * d - 1
        self.n = self.L ** 2 // 2 + 1
        self.m = self.n - 1

        # location of physical qubits
        self.loc_q = [
            (i, j) for i in range(self.L) for j in range(i%2, self.L, 2)
        ]
        # location of x checks
        self.loc_x = [
            (i, j) for i in range(1, self.L, 2) for j in range(0, self.L, 2)
        ]
        # locations of z checks
        self.loc_z = [
            (i, j) for i in range(0, self.L, 2) for j in range(1, self.L, 2)
        ]

        # constructing parity check matrix
        self.Hx = np.zeros([len(self.loc_x), self.n])
        self.Hz = np.zeros([len(self.loc_z), self.n])
        for i in range(len(self.loc_x)):
            x, y = self.loc_x[i]
            check_loc = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
            for loc in check_loc:
                if loc in self.loc_q:
                    self.Hx[i, self.loc_q.index(loc)] = 1
        for i in range(len(self.loc_z)):
            x, y = self.loc_z[i]
            check_loc = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
            for loc in check_loc:
                if loc in self.loc_q:
                    self.Hz[i, self.loc_q.index(loc)] = 1
        # self.Hx, self.Hz = self.Hz, self.Hx
        self.H = block_diag(*(self.Hz, self.Hx))
        
        self.error_prob = error_prob.astype(np.float64)

        # constructing logical operators, 0: I, 1: X, 2: Z, 3: Y
        self.logical_op = np.zeros([4, 2 * len(self.loc_q)])
        self.logical_op[1] = np.hstack([
            np.array([1 if i == 0 else 0 for i, j in self.loc_q]),
            np.zeros(len(self.loc_q))
        ])
        self.logical_op[2] = np.hstack([
            np.zeros(len(self.loc_q)),
            np.array([1 if j == 0 else 0 for i, j in self.loc_q])
        ])
        self.logical_op[3] = self.logical_op[1:3].sum(axis=0) % 2
        print(self.Hz)
        print(self.Hx)
        print(self.logical_op[1])
        print(self.logical_op[2])
        assert not (self.Hx @ self.logical_op[2, self.n:].T % 2).any()
        assert not (self.Hz @ self.logical_op[1, :self.n].T % 2).any()

        self.bmps_contractor = BoundaryMPS(np)
        pass
    
    def pure_error(self, syndrome):
        """
        function for generating error operator from syndrome
        """
        check_length = self.Hx.shape[0]
        assert len(syndrome) == 2 * check_length

        pure_error_x = solver(self.Hz, syndrome[:check_length])
        pure_error_z = solver(self.Hx, syndrome[check_length:])
                
        error = np.hstack([pure_error_x, pure_error_z])
        assert np.allclose((self.H @ error) % 2, syndrome)
            
        return error
    
    def syndrome(self, error_operator):
        return (self.H @ error_operator) % 2
    
    def mps_decoding(self, pure_error, print_coset=False):
        """
        function for finding most likely cosets given syndrome
        """
        prob = np.zeros(self.logical_op.shape[0], dtype=np.float64)
        for i in range(self.logical_op.shape[0]):
            tn = self.construct_tn(pure_error, self.logical_op[i])
            result, error = self.bmps_contractor(tn, 'normal', 32, 100)
            prob[i] = result[0] * 10 ** result[1]
        if print_coset:
            print('cosets probabilties:', prob)
        index = np.argmax(prob)
        return index, prob

    def construct_tn(self, pure_error, logical_op):
        """
        function for generating tensor network given syndrome and logical operator
        """
        tensors = [[0] * self.L for _ in range(self.L)]
        pure_error = pure_error.reshape(2, self.n)
        logical_op = logical_op.reshape(2, self.n)

        for i in range(self.L):
            for j in range(self.L):
                if (i, j) in self.loc_x or (i, j) in self.loc_z:
                    neigh = [(i-1, j), (i, j-1), (i+1, j), (i, j+1)]
                    shape = [1 if loc not in self.loc_q else 2 for loc in neigh]
                    tensors[i][j] = np.zeros(np.prod(shape), dtype=np.float64)
                    tensors[i][j][0] = tensors[i][j][-1] = 1
                elif (i, j) in self.loc_q:
                    ind = self.loc_q.index((i, j))
                    neigh = [(i-1, j), (i, j-1), (i+1, j), (i, j+1)]
                    shape = [1 if loc not in self.loc_x + self.loc_z else 2 for loc in neigh]
                    is_x = [False if loc not in self.loc_x else True for loc in neigh]
                    is_z = [False if loc not in self.loc_z else True for loc in neigh]
                    tensors[i][j] = np.zeros(np.prod(shape), dtype=np.float64)
                    configs = enumerate_config(shape)
                    for k, config in enumerate(configs):
                        pauli_config = pure_error[:, ind] + logical_op[:, ind]
                        pauli_config[0] += config[is_x].sum()
                        pauli_config[1] += config[is_z].sum()
                        tensors[i][j][k] = self.error_prob[((pauli_config % 2) * np.array([1, 2])).sum().astype(np.int8)]
                tensors[i][j] = tensors[i][j].reshape(shape)

        return tensors
    
    def run_sim(self, iter_num, seed=0):
        np.random.seed(seed)
        success_num = 0
        for i in range(iter_num):
            error_ixyz = np.random.choice([0, 1, 3, 2], self.n, p=self.error_prob)
            error_binary = np.hstack([(error_ixyz // 2 + error_ixyz % 2) % 2, error_ixyz // 2])
            syndrome = self.syndrome(error_binary)
            pure_error = self.pure_error(syndrome)
            ind_logical, coset_probs = self.mps_decoding(pure_error, False)
            recovery_op = (self.logical_op[ind_logical] + pure_error) % 2
            success = (block_diag(*(self.logical_op[2, self.n:], self.logical_op[1, :self.n])) @ ((recovery_op + error_binary) % 2)).sum() == 0
            if success:
                success_num += 1
            print(f'trial {i}:')
            print('error:', error_binary)
            print('syndrome:', syndrome)
            print('pure error:', pure_error)
            print(ind_logical, -np.log(coset_probs))
            print(recovery_op)
            print(success)
        return 1 - success_num/iter_num

In [33]:
error_prob = np.array([0.7, 0.1, 0.1, 0.1])
model = SurfaceCode(2, error_prob)
print(model.run_sim(1000, 0))


[[1. 1. 1. 0. 0.]
 [0. 0. 1. 1. 1.]]
[[1. 0. 1. 1. 0.]
 [0. 1. 1. 0. 1.]]
[1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 1. 0.]
trial 0:
error: [0 1 0 0 0 0 0 0 0 0]
syndrome: [1. 0. 0. 0.]
pure error: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
0 [3.52540099 3.52540099 5.05145729 5.05145729]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
False
trial 1:
error: [0 0 0 1 0 0 0 1 1 0]
syndrome: [0. 1. 0. 1.]
pure error: [1. 0. 1. 0. 0. 0. 1. 0. 0. 0.]
3 [4.67989373 4.67989373 4.67989373 3.62834895]
[0. 1. 1. 0. 0. 1. 1. 0. 1. 0.]
False
trial 2:
error: [1 0 0 1 0 0 0 0 1 0]
syndrome: [1. 1. 1. 0.]
pure error: [0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
0 [4.4096034  5.05145729 4.4096034  5.05145729]
[0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
False
trial 3:
error: [0 0 0 1 0 0 0 1 0 1]
syndrome: [0. 1. 1. 0.]
pure error: [1. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
2 [4.67989373 4.67989373 3.62834895 4.67989373]
[1. 0. 1. 0. 0. 0. 0. 0. 1. 0.]
False
trial 4:
error: [1 1 0 1 0 1 0 0 0 0]
syndrome: [0. 1. 1. 0.]
pure error: [1. 0. 1. 0. 0. 1. 0