# In this notebook:
1. the ancilla qubits are not reset between cycles;
2. the circuits use CZ gates.

## Import the decoder and other dependencies

In [1]:
import numpy as np
from mip import Model, xsum, minimize, BINARY
from mip import OptimizationStatus
from bposd.css import css_code
from ldpc import bposd_decoder
import itertools
from ldpc import bposd_decoder 
from scipy.sparse import coo_matrix, hstack 
from ldpc import mod2
from tabulate import tabulate
import pickle
import matplotlib.pyplot as plt

In [131]:
# Experiment configuration
num_cycles = 4                      # number of syndrome-extraction cycles in the dataset being analysed
num_L, num_R = 9, 9                 # number of data qubits in the L block and in the R block
index_Z_check = [0, 1, 2, 6, 7, 8]  # ancilla indices (q_Z{i}) of the Z checks that are read out
index_X_check = [0, 1, 3, 4, 6, 7]  # ancilla indices (q_X{i}) of the X checks that are read out

In [132]:
# Decoder inputs (check matrices, decoding matrices, channel priors) for the [[18,6,3]] code,
# as the file name suggests. NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('Set_decoder_para_18_6_3.pkl', mode='rb') as param_file:
    Set_decoder_para_18_6_3 = pickle.load(param_file)

In [133]:
k, n2 = 6, 9

# Decoder inputs for the X checks at the chosen number of cycles
channel_probsX, HX, HdecX = (
    Set_decoder_para_18_6_3[f"channel_probsX_{num_cycles}"],
    Set_decoder_para_18_6_3[f"HX_{num_cycles}"],
    Set_decoder_para_18_6_3[f"HdecX_{num_cycles}"],
)

# Decoder inputs for the Z checks at the chosen number of cycles
channel_probsZ, HZ, HdecZ = (
    Set_decoder_para_18_6_3[f"channel_probsZ_{num_cycles}"],
    Set_decoder_para_18_6_3[f"HZ_{num_cycles}"],
    Set_decoder_para_18_6_3[f"HdecZ_{num_cycles}"],
)

In [134]:
# Setup BP-OSD decoder parameters (used by both the X and the Z decoder)
my_bp_method = "ms"
my_max_iter = 10000
my_osd_method = "osd_cs"
my_osd_order = 7
my_ms_scaling_factor = 0

# Keyword arguments shared by both decoders. Keeping them in one place guarantees the
# X and Z decoders are configured identically and cannot drift apart.
bposd_shared_kwargs = dict(
    max_iter=my_max_iter,                    # the maximum number of iterations for BP
    bp_method=my_bp_method,
    ms_scaling_factor=my_ms_scaling_factor,  # min-sum scaling factor; 0 selects the variable scaling factor method
    osd_method=my_osd_method,                # the OSD method; one of "osd_e", "osd_cs", "osd0"
    osd_order=my_osd_order,                  # the OSD search depth
)

bpdX = bposd_decoder(
    HdecX,                         # the parity check matrix used for decoding
    channel_probs=channel_probsX,  # per-column error probabilities; overrides the "error_rate" input
    **bposd_shared_kwargs,
)

bpdZ = bposd_decoder(
    HdecZ,                         # the parity check matrix used for decoding
    channel_probs=channel_probsZ,  # per-column error probabilities; overrides the "error_rate" input
    **bposd_shared_kwargs,
)

## Logical Z data

### data processing

In [135]:
# Raw experimental results of the logical-Z experiment with `num_cycles` cycles
with open(f'Experimental_data/remove_leakage_ldpc_test_result_logical_Z_cycle_num={num_cycles}', mode='rb') as data_file:
    data_z = pickle.load(data_file)
data_z.keys()  # notebook display of the available fields

dict_keys(['cycle_num', 'random_initial_state', 'states', 'final_cycle_states'])

In [136]:
# Z check matrix (hz) and logical-Z operators (lz) of the [[18,6,3]] code described by
# Set_decoder_para_18_6_3 (k = 6 logical qubits)
hz = Set_decoder_para_18_6_3["hz"]
lz = Set_decoder_para_18_6_3["lz"]

In [137]:
# data_z['random_initial_state']

In [138]:
# Initial states of the data qubits: one dict per prepared initial state (num_initial_z of them)
random_initial_state_z = data_z['random_initial_state']

num_initial_z = random_initial_state_z.shape[0]


def _qubit_sort_key(item):
    """Sort key for a (qubit_name, bit) pair taken from an initial-state dict.

    Drops the first two characters of the name (presumably a 'q_' prefix — confirm), then
    orders by the alphabetic part and by the numeric index. For names like 'q_L0' this yields
    L0, L1, ..., L8, R0, ..., R8, and stays correct for indices >= 10 (L10 comes after L9).
    """
    label = item[0][2:]
    prefix = label.rstrip("0123456789")
    digits = label[len(prefix):]
    return (prefix, int(digits) if digits else -1)


# Reorder each state's qubits as L0...L8, R0...R8 (in place)
for idx in range(num_initial_z):
    random_initial_state_z[idx] = dict(sorted(random_initial_state_z[idx].items(), key=_qubit_sort_key))

# Convert each dict into a list of bits in that order
list_initial_z = [list(random_initial_state_z[idx].values()) for idx in range(num_initial_z)]

In [142]:
# Record, for every initial state s and every shot p:
#   * dict_cycle_z_parity[key]  : the Z-check syndrome of each round
#                                 (key = 'cycle_1' ... f'cycle_{num_cycles}', then 'final');
#   * dict_cycle_z_devents[key] : the detection events (syndrome XOR the previous round's syndrome);
#   * final_logical_z_outcome   : 1 where a logical Z value flipped relative to the initial state.
# Every list gains exactly one entry per (s, p), so position n refers to the same instance in all of them.
dict_cycle_z_parity = {} ;
dict_cycle_z_devents = {}
for i in range(1, num_cycles+1):
    dict_cycle_z_parity[f'cycle_{i}'] = [] ;
    dict_cycle_z_devents[f'cycle_{i}'] = [] ;

dict_cycle_z_parity['final'] = [] ;
dict_cycle_z_devents['final'] = [] ;

final_logical_z_outcome = [] ;

for s in range(num_initial_z):

    # s-th initial state: its Z-check syndrome (hz times the initial bit string, mod 2) is the
    # reference against which the cycle-1 syndrome is compared
    initial_z_parity = hz @ np.array(list_initial_z[s]).T % 2

    # number of shots recorded for this initial state (read from ancilla 'q_Z1' in cycle 1)
    num_trials_z = len(data_z['states']['cycle_1'][f'q_Z{1}'][s]) ; 
    
    # logical Z values of the prepared state; reference for the final logical readout
    initial_logical_z_outcome = (lz @ list_initial_z[s] ) % 2

    for p in range(num_trials_z):
    
        # Cycle 1: the raw outcomes of the measured Z-check ancillas are used directly as the
        # syndrome (presumably the ancillas start in |0> — confirm)
        dict_cycle_z_parity['cycle_1'].append( np.array( [data_z['states']['cycle_1'][f'q_Z{i}'][s][p] for i in index_Z_check  ] )  )
    
        # Cycles >= 2: the ancillas are not reset, so each raw outcome accumulates the earlier
        # syndromes; XOR with the previous cycle's raw outcome isolates this cycle's syndrome
        for cycle in range(2, 1+num_cycles):
            current = np.array([data_z['states'][f'cycle_{cycle}'][f'q_Z{i}'][s][p] for i in index_Z_check  ]) ;
            former = np.array([data_z['states'][f'cycle_{cycle-1}'][f'q_Z{i}'][s][p] for i in index_Z_check  ]) ;    
            dict_cycle_z_parity[f'cycle_{cycle}'].append( (  current + former ) %2  )
        
        # final data-qubit readout in the order L0..L{num_L-1}, R0..R{num_R-1}
        final_state_z = [ data_z['final_cycle_states'][f'q_L{i}'][s][p]  for i in range(num_L) ] + \
                     [ data_z['final_cycle_states'][f'q_R{i}'][s][p]  for i in range(num_R) ]  ;
    
        # syndrome of the final round, recomputed from the data-qubit readout
        dict_cycle_z_parity['final'].append( hz @ np.array(final_state_z).T % 2 ) ;   
    
        # 1 where a logical Z value differs from its initial value (% 2 maps a difference of -1 to 1)
        final_logical_z_outcome.append(   (lz @ final_state_z - initial_logical_z_outcome) % 2   )

        # Detection events: an event occurs where a round's syndrome differs from the previous one
##---------------------------------------------------------------------------------------------------------------------         
        # first round: compare with the syndrome of the prepared initial state
        dict_cycle_z_devents['cycle_1'].append( (dict_cycle_z_parity['cycle_1'][-1]  + initial_z_parity) %2  )
        
        for cycle in range(2, num_cycles+1):

            dict_cycle_z_devents[f'cycle_{cycle}'].append( (dict_cycle_z_parity[f'cycle_{cycle}'][-1] + \
                                                            dict_cycle_z_parity[f'cycle_{cycle-1}'][-1]) %2 ) ;

        # final round: compare the syndrome from the data readout with the last measured cycle
        dict_cycle_z_devents['final'].append( (dict_cycle_z_parity['final'][-1] + dict_cycle_z_parity[f'cycle_{num_cycles}'][-1]) %2  )

In [143]:
# One flattened detection-event vector per (initial state, shot): the rounds are concatenated
# in the insertion order of dict_cycle_z_devents, i.e. cycle_1, ..., cycle_{num_cycles}, final.
syndrome_history_Z = [
    np.hstack(round_events)
    for round_events in zip(*dict_cycle_z_devents.values())
]
num_detect_z = len(syndrome_history_Z[0])    # detectors per instance
num_instance_z = len(syndrome_history_Z)     # total number of (initial state, shot) instances

### error detections

In [None]:
# Tick labels for the measured Z checks, derived from index_Z_check so the two cannot drift apart
labels_z = [f'Z{i}' for i in index_Z_check]
# plt.bar(labels_z, np.mean(dict_cycle_z_devents['cycle_2'], axis=0))

In [None]:
# Mean detection-event probability of each round (averaged over checks and instances);
# the last entry is the final round reconstructed from the data-qubit readout.
z_list_detect_prob = [
    np.mean(np.mean(dict_cycle_z_devents[round_key], axis=0))
    for round_key in [f'cycle_{c}' for c in range(1, num_cycles + 1)] + ['final']
]

for cycle, prob in enumerate(z_list_detect_prob, start=1):
    print(f'The mean probability of error detection in cycle {cycle} is:', prob)

In [None]:
# Detection probability versus round; the last round is the final data-qubit readout
plt.plot(range(1, len(z_list_detect_prob) + 1), z_list_detect_prob)

### Correlation matrix $p_{ij}$ for error detection events

In [None]:
def p_ij(x_i, x_j, x_ij):
    """Estimate the probability of an error mechanism that flips both detectors i and j.

    Implements
        p_ij = 1/2 - 1/2 * sqrt(1 - 4 (<x_i x_j> - <x_i><x_j>) / (1 - 2<x_i> - 2<x_j> + 4<x_i x_j>))

    Args:
        x_i: mean detection-event rate <x_i> of detector i.
        x_j: mean detection-event rate <x_j> of detector j.
        x_ij: mean joint rate <x_i x_j> of detectors i and j.

    Returns:
        The estimate, as a scalar or elementwise for NumPy array inputs. With NumPy float
        inputs, a zero denominator or a negative radicand gives inf/nan (with a RuntimeWarning)
        rather than raising.
    """
    covariance = x_ij - x_i * x_j
    denominator = 1 - 2 * x_i - 2 * x_j + 4 * x_ij
    return 0.5 - 0.5 * np.sqrt(1 - 4 * covariance / denominator)

In [None]:
# <x_i> and <x_i x_j> over ALL instances (every initial state and every shot). num_trials_z is
# only the shot count of the last initial state, so averaging over range(num_trials_z) would use
# just the first num_trials_z rows of syndrome_history_Z.
# Vectorised: stack the 0/1 detection events into an (instances, detectors) array; its Gram
# matrix divided by the number of instances gives every pairwise mean at once.
z_events = np.asarray(syndrome_history_Z, dtype=float)
z_single_means = z_events.mean(axis=0)
z_pair_means = z_events.T @ z_events / z_events.shape[0]

Z_Set_averages = {
    f"z({s1},{s2})": z_pair_means[s1, s2]
    for s1 in range(num_detect_z-1)
    for s2 in range(s1, num_detect_z)
}

Z_Set_averages.update({f"z({s3})": z_single_means[s3] for s3 in range(num_detect_z)})

In [None]:
# Symmetric matrix of pairwise p_ij estimates; the diagonal is left at zero
Z_Correlation_matrix = np.zeros((num_detect_z, num_detect_z))

for i, j in itertools.combinations(range(num_detect_z), 2):
    Z_Correlation_matrix[i, j] = Z_Correlation_matrix[j, i] = p_ij(
        Z_Set_averages[f"z({i})"], Z_Set_averages[f"z({j})"], Z_Set_averages[f"z({i},{j})"]
    )

In [None]:
# Heat map of the pairwise correlations p_ij. The indices run over all detectors, i.e. the
# detection events of every round concatenated round by round, not over the Z checks alone.
plt.figure(figsize=(8, 6))
cax = plt.imshow(Z_Correlation_matrix, cmap='YlOrRd', interpolation='nearest')
plt.colorbar(cax)
plt.gca().invert_yaxis()  # put index 0 at the bottom-left corner
plt.title('Correlation matrix $p_{ij}$', fontsize=16)
plt.xlabel('Z detector index i', fontsize=12)
plt.ylabel('Z detector index j', fontsize=12)

### error decoding

In [None]:
# Correct errors for logical Z (Z checks): pass the Z-basis detection events, decoder bpdZ,
# matrix HZ and the measured logical-Z flips to error_decoding.
# NOTE(review): `error_decoding` is neither defined nor imported in this notebook; it must be
# provided elsewhere (another module or an earlier session) — TODO confirm.
# Judging by the prints below, the results are presumably counts of logically correct
# instances out of num_instance_z, first without and then with correction — verify.
Z_no_correction_good_trials, Z_no_correction_good_list, Z_good_trials, Z_good_list =     \
            error_decoding(num_instance_z, k, bpdZ, HZ, syndrome_history_Z, final_logical_z_outcome )

In [None]:
# Report logical error rates (1 - fraction of good instances), first without then with decoding
for report_idx, (heading, good_all, good_single) in enumerate((
        ("no error correction for logical Z basis:", Z_no_correction_good_trials, Z_no_correction_good_list),
        ("error correction for logical Z basis:", Z_good_trials, Z_good_list),
)):
    if report_idx > 0:
        print("\n")
    print(heading)
    print(f'Logical error over {num_cycles+1} cycles (six logical qubits):', 1-good_all/num_instance_z)
    print(f'Logical error over {num_cycles+1} cycles (single logical qubit):', 1-good_single/num_instance_z)

In [None]:
# print( "Estimated logical error (Z) per cycle:", 1 - (Z_good_trials/num_trials_z) ** (1/(num_cycles+1)) )

## Logical X data

### data processing

In [None]:
# Raw experimental results of the logical-X experiment with `num_cycles` cycles
with open(f'Experimental_data/remove_leakage_ldpc_test_result_logical_X_cycle_num={num_cycles}', mode='rb') as data_file:
    data_x = pickle.load(data_file)
data_x.keys()

# X check matrix (hx) and logical-X operators (lx) of the [[18,6,3]] code described by
# Set_decoder_para_18_6_3 (k = 6 logical qubits)
hx = Set_decoder_para_18_6_3["hx"]
lx = Set_decoder_para_18_6_3["lx"]

# Initial states of the data qubits: one dict per prepared initial state (num_initial_x of them)
random_initial_state_x = data_x['random_initial_state']
num_initial_x = random_initial_state_x.shape[0]


def _qubit_sort_key(item):
    """Sort key for a (qubit_name, bit) pair taken from an initial-state dict.

    Drops the first two characters of the name (presumably a 'q_' prefix — confirm), then
    orders by the alphabetic part and by the numeric index. For names like 'q_L0' this yields
    L0, L1, ..., L8, R0, ..., R8, and stays correct for indices >= 10 (L10 comes after L9).
    """
    label = item[0][2:]
    prefix = label.rstrip("0123456789")
    digits = label[len(prefix):]
    return (prefix, int(digits) if digits else -1)


# Reorder each state's qubits as L0...L8, R0...R8 (in place)
for idx in range(num_initial_x):
    random_initial_state_x[idx] = dict(sorted(random_initial_state_x[idx].items(), key=_qubit_sort_key))

# Convert each dict into a list of bits in that order
list_initial_x = [list(random_initial_state_x[idx].values()) for idx in range(num_initial_x)]

In [None]:
# Record, for every initial state s and every shot p:
#   * dict_cycle_x_parity[key]  : the X-check syndrome of each round
#                                 (key = 'cycle_1' ... f'cycle_{num_cycles}', then 'final');
#   * dict_cycle_x_devents[key] : the detection events (syndrome XOR the previous round's syndrome);
#   * final_logical_x_outcome   : 1 where a logical X value flipped relative to the initial state.
# Every list gains exactly one entry per (s, p), so position n refers to the same instance in all of them.
dict_cycle_x_parity = {} ;
dict_cycle_x_devents = {}
for i in range(1, num_cycles+1):
    dict_cycle_x_parity[f'cycle_{i}'] = [] ;
    dict_cycle_x_devents[f'cycle_{i}'] = [] ;

dict_cycle_x_parity['final'] = [] ;
dict_cycle_x_devents['final'] = [] ;

final_logical_x_outcome = [] ;

for s in range(num_initial_x):

    # s-th initial state: its X-check syndrome (hx times the initial bit string, mod 2) is the
    # reference against which the cycle-1 syndrome is compared
    initial_x_parity = hx @ np.array(list_initial_x[s]).T % 2

    # number of shots recorded for this initial state (read from ancilla 'q_X1' in cycle 1)
    num_trials_x = len(data_x['states']['cycle_1'][f'q_X{1}'][s]) ; 
    
    # logical X values of the prepared state; reference for the final logical readout
    initial_logical_x_outcome = (lx @ list_initial_x[s] ) % 2

    for p in range(num_trials_x):
    
        # Cycle 1: the raw outcomes of the measured X-check ancillas are used directly as the
        # syndrome (presumably the ancillas start in |0> — confirm)
        dict_cycle_x_parity['cycle_1'].append( np.array( [data_x['states']['cycle_1'][f'q_X{i}'][s][p] for i in index_X_check  ] )  )
    
        # Cycles >= 2: the ancillas are not reset, so each raw outcome accumulates the earlier
        # syndromes; XOR with the previous cycle's raw outcome isolates this cycle's syndrome
        for cycle in range(2, 1+num_cycles):
            current = np.array([data_x['states'][f'cycle_{cycle}'][f'q_X{i}'][s][p] for i in index_X_check  ]) ;
            former = np.array([data_x['states'][f'cycle_{cycle-1}'][f'q_X{i}'][s][p] for i in index_X_check  ]) ;    
            dict_cycle_x_parity[f'cycle_{cycle}'].append( (  current + former ) %2  )
        
        # final data-qubit readout in the order L0..L{num_L-1}, R0..R{num_R-1}
        final_state_x = [ data_x['final_cycle_states'][f'q_L{i}'][s][p]  for i in range(num_L) ] + \
                     [ data_x['final_cycle_states'][f'q_R{i}'][s][p]  for i in range(num_R) ]  ;
    
        # syndrome of the final round, recomputed from the data-qubit readout
        dict_cycle_x_parity['final'].append( hx @ np.array(final_state_x).T % 2 ) ;   
    
        # 1 where a logical X value differs from its initial value (% 2 maps a difference of -1 to 1)
        final_logical_x_outcome.append(   (lx @ final_state_x - initial_logical_x_outcome) % 2   )

        # Detection events: an event occurs where a round's syndrome differs from the previous one
##---------------------------------------------------------------------------------------------------------------------         
        # first round: compare with the syndrome of the prepared initial state
        dict_cycle_x_devents['cycle_1'].append( (dict_cycle_x_parity['cycle_1'][-1]  + initial_x_parity) %2  )
        
        for cycle in range(2, num_cycles+1):

            dict_cycle_x_devents[f'cycle_{cycle}'].append( (dict_cycle_x_parity[f'cycle_{cycle}'][-1] + \
                                                            dict_cycle_x_parity[f'cycle_{cycle-1}'][-1]) %2 ) ;

        # final round: compare the syndrome from the data readout with the last measured cycle
        dict_cycle_x_devents['final'].append( (dict_cycle_x_parity['final'][-1] + dict_cycle_x_parity[f'cycle_{num_cycles}'][-1]) %2  )

In [None]:
# One flattened detection-event vector per (initial state, shot): the rounds are concatenated
# in the insertion order of dict_cycle_x_devents, i.e. cycle_1, ..., cycle_{num_cycles}, final.
syndrome_history_X = [
    np.hstack(round_events)
    for round_events in zip(*dict_cycle_x_devents.values())
]
num_detect_x = len(syndrome_history_X[0])    # detectors per instance
num_instance_x = len(syndrome_history_X)     # total number of (initial state, shot) instances

### error detections

In [None]:
# Tick labels for the measured X checks, derived from index_X_check so the two cannot drift apart
labels_x = [f'X{i}' for i in index_X_check]

# Mean detection-event probability of each round (averaged over checks and instances);
# the last entry is the final round reconstructed from the data-qubit readout.
x_list_detect_prob = [
    np.mean(np.mean(dict_cycle_x_devents[round_key], axis=0))
    for round_key in [f'cycle_{c}' for c in range(1, num_cycles + 1)] + ['final']
]

for cycle, prob in enumerate(x_list_detect_prob, start=1):
    print(f'The mean probability of error detection in cycle {cycle} is:', prob)

# Detection probability versus round; the last round is the final data-qubit readout
plt.plot(range(1, len(x_list_detect_prob) + 1), x_list_detect_prob)

### Correlation matrix $p_{ij}$ for error detection events

In [None]:
# <x_i> and <x_i x_j> over ALL instances (every initial state and every shot). num_trials_x is
# only the shot count of the last initial state, so averaging over range(num_trials_x) would use
# just the first num_trials_x rows of syndrome_history_X.
# Vectorised: stack the 0/1 detection events into an (instances, detectors) array; its Gram
# matrix divided by the number of instances gives every pairwise mean at once.
x_events = np.asarray(syndrome_history_X, dtype=float)
x_single_means = x_events.mean(axis=0)
x_pair_means = x_events.T @ x_events / x_events.shape[0]

X_Set_averages = {
    f"x({s1},{s2})": x_pair_means[s1, s2]
    for s1 in range(num_detect_x-1)
    for s2 in range(s1, num_detect_x)
}

X_Set_averages.update({f"x({s3})": x_single_means[s3] for s3 in range(num_detect_x)})

In [None]:
# Symmetric matrix of pairwise p_ij estimates; the diagonal is left at zero
X_Correlation_matrix = np.zeros((num_detect_x, num_detect_x))

for i, j in itertools.combinations(range(num_detect_x), 2):
    X_Correlation_matrix[i, j] = X_Correlation_matrix[j, i] = p_ij(
        X_Set_averages[f"x({i})"], X_Set_averages[f"x({j})"], X_Set_averages[f"x({i},{j})"]
    )

In [None]:
# i = 6; j = 11

In [None]:
# Heat map of the pairwise correlations p_ij. The indices run over all detectors, i.e. the
# detection events of every round concatenated round by round, not over the X checks alone.
plt.figure(figsize=(8, 6))
cax = plt.imshow(X_Correlation_matrix, cmap='YlOrRd', interpolation='nearest')
plt.colorbar(cax)
plt.gca().invert_yaxis()  # put index 0 at the bottom-left corner
plt.title('Correlation matrix $p_{ij}$', fontsize=16)
plt.xlabel('X detector index i', fontsize=12)
plt.ylabel('X detector index j', fontsize=12)

In [None]:
# print(tabulate(X_Correlation_matrix))

### error decoding

In [None]:
# Correct errors for logical X (X checks): pass the X-basis detection events, decoder bpdX,
# matrix HX and the measured logical-X flips to error_decoding.
# NOTE(review): `error_decoding` is neither defined nor imported in this notebook; it must be
# provided elsewhere (another module or an earlier session) — TODO confirm.
# Judging by the prints below, the results are presumably counts of logically correct
# instances out of num_instance_x, first without and then with correction — verify.
X_no_correction_good_trials, X_no_correction_good_list, X_good_trials, X_good_list =     \
            error_decoding(num_instance_x, k, bpdX, HX, syndrome_history_X, final_logical_x_outcome )

In [None]:
# Report logical error rates (1 - fraction of good instances), first without then with decoding
for report_idx, (heading, good_all, good_single) in enumerate((
        ("no error correction for logical X basis:", X_no_correction_good_trials, X_no_correction_good_list),
        ("error correction for logical X basis:", X_good_trials, X_good_list),
)):
    if report_idx > 0:
        print("\n")
    print(heading)
    print(f'Logical error over {num_cycles+1} cycles (six logical qubits):', 1-good_all/num_instance_x)
    print(f'Logical error over {num_cycles+1} cycles (single logical qubit):', 1-good_single/num_instance_x)