In [3]:
# IMPORTS
import os

import numpy as np
import torch
import torch.nn as nn
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import scienceplots  # noqa: F401 -- imported so that the 'science' style below is available

from src.rotated_surface_code import RotatedCode
from src.graph_representation import get_torch_graph

# PLOT CONFIGURATION
plt.style.use('science')
plt.rcParams['figure.dpi'] = 300
# Override the style's colour cycle with a fixed five-colour palette
PALETTE = ["#66c2a5", "#fc8d62", "#8da0cb", "#e78ac3", "#a6d854"]
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[colors.hex2color(c) for c in PALETTE])

# Create error chains and syndromes

In [6]:
# Lattice side length: the code has d * d qubits (see generate_error_chain)
d = 5
code = RotatedCode(d)
# Parity-check matrix, used for both X and Z stabilizers in get_syndrome
H_check = code.H
# Total per-qubit error probability, split evenly over X, Y and Z errors
p = 0.1
# Seed the global RNGs so the sampled error chain below is reproducible
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [8]:
def generate_error_chain(p, code_size, rng=None):
    """Sample an unbiased (X/Y/Z equally likely) Pauli error chain.

    Parameters
    ----------
    p : float
        Total error probability per qubit; X, Y and Z each occur with p/3.
    code_size : int
        Side length of the square lattice; ``code_size**2`` qubits are sampled.
    rng : numpy.random.Generator, optional
        Source of uniform draws. If None, the legacy global ``np.random``
        state is used (so ``np.random.seed`` controls the result).

    Returns
    -------
    np.ndarray
        uint8 array of length ``code_size**2`` with labels
        0 = no error, 1 = X, 2 = Y, 3 = Z.
    """
    n_qubits = int(code_size ** 2)
    p_x = p_y = p_z = p / 3  # Unbiased noise
    # Cumulative thresholds partitioning [0, 1) into X | Y | Z | no-error
    t_x = p_x
    t_y = p_x + p_y
    t_z = p_x + p_y + p_z

    draws = np.random.random(n_qubits) if rng is None else rng.random(n_qubits)

    # Assign from the outermost interval inwards; each later mask overwrites
    # the earlier one. Strict `<` gives each interval probability exactly p/3
    # and guarantees no errors when p == 0 (a draw can be exactly 0.0).
    noise = np.zeros(n_qubits, dtype=np.uint8)
    noise[draws < t_z] = 3
    noise[draws < t_y] = 2
    noise[draws < t_x] = 1
    return noise

In [9]:
def _syndrome_string_to_matrix(syndrome_string, M):
    """Scatter one stabilizer type's flat syndrome string onto an (M, M) uint8 grid.

    The string is split into M rows. Even rows are written to columns
    2, 4, ... and odd rows to columns 1, 3, ..., M - 3.
    """
    matrix = np.zeros((M, M), dtype=np.uint8)
    rows = syndrome_string.reshape(M, -1)
    matrix[::2, 2::2] = rows[::2]
    matrix[1::2, 1:M - 2:2] = rows[1::2]
    return matrix


def get_syndrome(error_string, H_check, code_size):
    """Compute the combined syndrome matrix of an error chain.

    Parameters
    ----------
    error_string : np.ndarray
        Flat array of ``code_size**2`` per-qubit labels: 0 = no error,
        1 = X, 2 = Y (counts as both X and Z), 3 = Z.
    H_check : np.ndarray
        Parity-check matrix applied to both stabilizer types (one column
        per qubit).
    code_size : int
        Side length of the square qubit lattice.

    Returns
    -------
    np.ndarray
        ``(code_size + 1, code_size + 1)`` uint8 matrix where entries of 1
        are x defects (triggered by Z-type errors) and entries of 3 are z
        defects (triggered by X-type errors).

    NOTE(review): the reshape/scatter layout assumes H_check has
    ``(code_size + 1) * (code_size - 1) / 2`` rows, which appears to hold
    only for odd code_size -- confirm before using even sizes.
    """
    M = code_size + 1

    # Separate binary strings for X-type and Z-type errors (Y contributes to both)
    error_string_x = np.isin(error_string, (1, 2))
    error_string_z = np.isin(error_string, (2, 3))

    # Re-order X-errors to get correct qubit indices for Z-stabilizers
    error_string_x = np.rot90(error_string_x.reshape(code_size, code_size)).flatten()

    # The same check matrix serves both types: X errors are checked by Z
    # stabilizers and vice versa. Z-stabilizer defects are tagged with 3.
    syndrome_string_x = (H_check @ error_string_z) % 2
    syndrome_string_z = ((H_check @ error_string_x) % 2) * 3

    syndrome_matrix_X = _syndrome_string_to_matrix(syndrome_string_x, M)
    syndrome_matrix_Z = _syndrome_string_to_matrix(syndrome_string_z, M)

    # Rotate the Z grid back and combine: 1 entries are x defects, 3 entries z defects
    return syndrome_matrix_X + np.rot90(syndrome_matrix_Z, -1)

def get_eq_class(error, code_size):
    """Return the equivalence class of an error chain as a two-bit array.

    The first bit is the parity of X-type errors (labels 1 or 2) on the
    north edge (first row); the second bit is the parity of Z-type errors
    (labels 2 or 3) on the west edge (first column):
    [0, 0] = I, [1, 0] = X, [0, 1] = Z, [1, 1] = Y.
    """
    grid = error.reshape(code_size, code_size)
    north_edge = grid[0, :]
    west_edge = grid[:, 0]

    x_parity = np.count_nonzero(np.isin(north_edge, (1, 2))) % 2
    z_parity = np.count_nonzero(np.isin(west_edge, (2, 3))) % 2

    return np.array([x_parity, z_parity])

In [15]:
# Sample a random error chain on the d x d lattice
error_chain = generate_error_chain(p, d)
# Determine the syndrome consistent with the error
syndrome = get_syndrome(error_chain, H_check, d)
assert syndrome.shape == (d + 1, d + 1)
# Determine the equivalence class of the error chain as [x, z] parity bits
true_eq_class = get_eq_class(error_chain, d)


In [16]:
# Show the combined syndrome matrix: entries of 1 are x defects, entries of 3 are z defects
syndrome

array([[0, 0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=uint8)

# Create the matching graph with virtual nodes closest to each node

In [18]:
# Build the graph representation from the syndrome matrix and the true class.
# NOTE(review): presumably true_eq_class is carried along as the target label
# (the last array of the displayed output has the same [x, z] shape) -- confirm
# against src.graph_representation.get_torch_graph.
graph = get_torch_graph(syndrome, true_eq_class)

In [19]:
# Inspect the graph, which is returned as a list of arrays.
# NOTE(review): these look like node features, edge index, edge attributes
# and the label, in that order -- confirm against get_torch_graph.
graph

[array([[0., 1., 1., 0.],
        [0., 1., 3., 0.]], dtype=float32),
 array([[0, 1],
        [1, 0]]),
 array([[0.5],
        [0.5]], dtype=float32),
 array([[0., 0.]], dtype=float32)]