In [1]:
import sys, os
# Make the local "core" directory importable. The entry is only appended when it
# is not already on sys.path, so re-running this cell does not pile up duplicates.
_core_dir = os.path.abspath(os.path.join(os.getcwd(), "core"))
if _core_dir not in sys.path:
    sys.path.append(_core_dir)

In [2]:
from core.quantum_error_correction_code import SurfaceCode
from core.neural_network import CNNDecoder, CNNDual, load_params
from core.perfect_maximum_likelihood_decoder import PMLD

import jax.numpy as jnp
from jax import random, vmap, jit

In [3]:
# Code distances evaluated throughout the notebook.
distances = [3, 5, 7]

# Total physical error probability `p` and bias parameter `nu`.
p, nu = .01, 500

# Three-entry error-probability vector handed to `code.error`: the first two
# entries share the same small weight and the third carries the biased weight,
# so the entries sum to `p`.
# NOTE(review): presumably ordered (X, Y, Z) with Z as the biased channel -- confirm
# against SurfaceCode.error.
_minor_prob = 1 / (2 * (nu + 1)) * p
_major_prob = nu / (nu + 1) * p
ERROR_PROBS = jnp.array([_minor_prob, _minor_prob, _major_prob])

In [4]:
def get_data(
    data_key,
    code: SurfaceCode,
    batch_size: int,
    parity_info: tuple[jnp.ndarray, ...],
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sample a batch of errors and derive the decoder inputs and labels.

    Args:
        data_key: JAX PRNG key; it is split into ``batch_size + 1`` sub-keys.
        code: Code object providing ``error``, ``syndrome_img`` and ``syndrome``.
        batch_size: Number of errors to sample.
        parity_info: Parity information passed unchanged (not batched) to
            ``code.syndrome_img`` and ``code.syndrome``.

    Returns:
        A 4-tuple ``(imgs, syndromes, logicals, next_key)``:
        * ``imgs``: syndrome images with a singleton axis inserted at position 1,
        * ``syndromes``: first output of ``code.syndrome`` for every error,
        * ``logicals``: second output of ``code.syndrome_img`` for every error,
        * ``next_key``: the one sub-key that was not consumed for sampling.
    """
    # One sub-key per sample plus a spare that is handed back to the caller.
    subkeys = random.split(data_key, num=batch_size+1)
    sample_keys, next_key = subkeys[:-1], subkeys[-1]

    # ERROR_PROBS is the notebook-level probability vector; it is shared by
    # every sample (in_axes=None).
    errors = vmap(
        code.error,
        in_axes=(0, None),
        out_axes=0
    )(sample_keys, ERROR_PROBS)
    imgs, logicals = vmap(
        code.syndrome_img,
        in_axes=(0, None),
        out_axes=0
    )(errors, parity_info)
    syndromes, _ = vmap(
        code.syndrome,
        in_axes=(0, None),
        out_axes=0
    )(errors, parity_info)
    # Insert a singleton axis after the batch axis of the images.
    return imgs[:,None,:,:], syndromes, logicals, next_key

def logicals_of_recovery(
    code: SurfaceCode,
    recovery: jnp.ndarray,
    parity_info: tuple[jnp.ndarray],
) -> jnp.ndarray:
    """Return the logical part of ``code.syndrome`` for a batch of recoveries.

    ``code.syndrome`` is mapped over the leading axis of ``recovery`` while
    ``parity_info`` is shared by all entries; only the second output of
    ``code.syndrome`` is returned.
    """
    batched_syndrome = vmap(code.syndrome, in_axes=(0, None), out_axes=0)
    _unused_syndromes, recovery_logicals = batched_syndrome(recovery, parity_info)
    return recovery_logicals

In [5]:
def evaluate_nn_decoder_batch(
    data_key: jnp.ndarray,
    decoder: CNNDual | CNNDecoder,
    model_params: jnp.ndarray,
    code: SurfaceCode,
    deformation: jnp.ndarray,
    batch_size: int,
) -> jnp.ndarray:
    """Estimate the logical error rate of a neural decoder on one sampled batch.

    Args:
        data_key: JAX PRNG key used to sample the batch of errors.
        decoder: Either a ``CNNDecoder`` (syndrome image only) or a ``CNNDual``
            (syndrome image plus deformation image).
        model_params: Parameters passed to ``decoder.apply_batch``.
        code: Code object used to sample errors and build images.
        deformation: Deformation passed to ``code.deformation_parity_info`` and,
            for ``CNNDual``, to ``code.deformation_image``.
        batch_size: Number of errors sampled for this batch.

    Returns:
        Scalar array: the fraction of samples for which at least one predicted
        logical differs from the true logical.

    Raises:
        ValueError: If ``decoder`` is neither a ``CNNDecoder`` nor a ``CNNDual``.
    """
    parity_info = code.deformation_parity_info(deformation)
    # Only the images and the true logicals are needed here; the plain
    # syndromes and the leftover PRNG key are discarded.
    syndrome_imgs, _, logicals, _ = get_data(
        data_key,
        code,
        batch_size,
        parity_info
    )
    # Decoder outputs are thresholded at zero to obtain boolean predictions.
    if isinstance(decoder, CNNDecoder):
        preds = decoder.apply_batch(model_params, syndrome_imgs) > 0.0
    elif isinstance(decoder, CNNDual):
        # A single deformation image with a leading singleton axis is passed
        # alongside the whole batch of syndrome images.
        deformation_imgs = code.deformation_image(deformation)[None, :, :, :]
        preds = decoder.apply_batch(model_params, syndrome_imgs, deformation_imgs) > 0.0
    else:
        raise ValueError("Unknown decoder type", type(decoder))
    # A sample counts as a logical error if any of its logicals is mispredicted.
    error_rate = jnp.any(logicals != preds, axis=1).mean()
    return error_rate

def evaluate_nn_decoder(
    data_key: jnp.ndarray,
    decoder: CNNDual | CNNDecoder,
    model_params: jnp.ndarray,
    code: SurfaceCode,
    deformation: jnp.ndarray,
    batch_size: int,
    num_errors: int,
) -> jnp.ndarray:
    """Estimate the logical error rate of a neural decoder over many batches.

    ``num_errors // batch_size`` independent batches are evaluated with
    ``evaluate_nn_decoder_batch`` (any remainder of ``num_errors`` is not
    sampled) and the per-batch error rates are averaged.

    Args:
        data_key: JAX PRNG key; split into one key per batch.
        decoder: Decoder forwarded to ``evaluate_nn_decoder_batch``.
        model_params: Parameters of ``decoder``.
        code: Code object forwarded to ``evaluate_nn_decoder_batch``.
        deformation: Deformation forwarded to ``evaluate_nn_decoder_batch``.
        batch_size: Number of errors per batch; must be at least 1.
        num_errors: Total number of errors requested; must be at least
            ``batch_size``.

    Returns:
        Scalar array with the mean of the per-batch logical error rates.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or if ``num_errors`` is
            smaller than ``batch_size`` (no complete batch could be evaluated,
            which previously produced a silent NaN).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    num_batches = num_errors // batch_size
    if num_batches < 1:
        raise ValueError(
            f"num_errors ({num_errors}) must be at least batch_size ({batch_size})"
        )

    keys = random.split(data_key, num_batches)
    # Collect the per-batch rates in a list and stack once, instead of
    # rebuilding an array with .at[i].set on every iteration.
    batch_rates = [
        evaluate_nn_decoder_batch(
            keys[i],
            decoder,
            model_params,
            code,
            deformation,
            batch_size
        )
        for i in range(num_batches)
    ]
    return jnp.stack(batch_rates).mean()

In [None]:
batch_size=10_000
num_errors=1_000_000

# Row layout of `table`: rows are grouped by distance, and within each group
# follow the order of `decoder_names`. Columns follow the deformation order
# returned by `_deformations_for` (labelled CSS, XZZX, XY, C1 further down).
decoder_names = ["PML", "MWPM", "CNN", "CNN-S", "CNN-G"]
cnn_decoder_names = ("CNN", "CNN-S", "CNN-G")
num_decoders = len(decoder_names)
num_deformations = 4

# Start from NaN so that every (decoder, deformation) pair without a result
# (decoder not evaluated here, or model file missing) stays NaN.
table = jnp.full(
    (num_decoders * len(distances), num_deformations),
    jnp.nan,
    dtype=jnp.float32,
)


def _deformations_for(distance):
    """Return the four flattened deformations evaluated for a given distance."""
    num_qubits = distance**2
    return [
        jnp.zeros(num_qubits, dtype=jnp.int32),
        jnp.zeros(num_qubits, dtype=jnp.int32).at[::2].set(3),
        jnp.zeros(num_qubits, dtype=jnp.int32).at[:].set(2),
        jnp.zeros((distance, distance), dtype=jnp.int32).at[1::2, ::2].set(3).flatten().at[::2].set(2)
    ]


def _model_path(decoder_name, distance, deformation):
    """Return the parameter file path of a CNN-type decoder.

    "CNN-G" uses one file per distance; the other decoders use one file per
    distance and deformation (the deformation entries are concatenated).
    """
    if decoder_name == "CNN-G":
        return f"data/CNN-G-{distance}.json"
    deformation_tag = ''.join([str(d) for d in deformation])
    return f"data/{decoder_name}-{distance}-{deformation_tag}.json"


def _int_tuples(layers):
    """Convert a nested sequence of numbers into a list of int tuples."""
    return [tuple(int(v) for v in vals) for vals in layers]


def _build_model(decoder_name, distance, settings):
    """Instantiate the network for `decoder_name` from its stored settings."""
    fc_layers = [int(v) for v in settings["FC_LAYERS"]]
    if decoder_name == "CNN":
        return CNNDecoder(
            input_shape = (1, distance+1, distance+1),
            conv_layers = _int_tuples(settings["CONV_LAYERS"]),
            fc_layers = fc_layers,
        )
    # "CNN-S" and "CNN-G" share the dual-input architecture.
    return CNNDual(
        input_shape_1=(1, distance+1, distance+1),
        input_shape_2=(1, distance, distance),
        conv_layers_input_1=_int_tuples(settings["CONV_LAYERS_INPUT_1"]),
        conv_layers_input_2=_int_tuples(settings["CONV_LAYERS_INPUT_2"]),
        conv_layers_stage_2=_int_tuples(settings["CONV_LAYERS_STAGE_2"]),
        fc_layers=fc_layers,
    )


for dist_idx, L in enumerate(distances):
    code = SurfaceCode(L)
    deformations = _deformations_for(L)
    # Out-of-range .at[...] updates are silently dropped by JAX, so make sure
    # the table really has one column per deformation.
    assert len(deformations) == num_deformations

    # Use the same key every time to ensure that all the decoders see the same set of errors and thus ensure fair comparison between decoders
    key = random.key(723)

    for dec_idx, decoder in enumerate(decoder_names):
        print(f"\nDistance {L} with decoder {decoder} ", end='')
        row = num_decoders * dist_idx + dec_idx
        for col, deformation in enumerate(deformations):
            print(".", end='')

            if decoder == "PML" and L == 3:
                # The exact decoder is only evaluated for the smallest distance.
                parity_info = code.deformation_parity_info(deformation)
                perfect_decoder = PMLD(code, ERROR_PROBS, parity_info)
                table = table.at[row, col].set(perfect_decoder.exact_logical_error_rate())
            elif decoder in cnn_decoder_names:
                model_name = _model_path(decoder, L, deformation)
                try:
                    settings, model_params = load_params(model_name)
                except FileNotFoundError:
                    # No trained model for this configuration: report it and
                    # leave the NaN placeholder in the table.
                    print("(FileNotFoundError)", end="")
                    continue
                model = _build_model(decoder, L, settings)
                table = table.at[row, col].set(evaluate_nn_decoder(
                    key,
                    model,
                    model_params,
                    code,
                    deformation,
                    batch_size,
                    num_errors
                ))
            # Any other decoder (e.g. "MWPM", or "PML" for L != 3) is not
            # evaluated in this cell; its entries remain NaN.
table


Distance 3 with decoder PML ....
Distance 3 with decoder MWPM ....
Distance 3 with decoder CNN ....
Distance 3 with decoder CNN-S ....
Distance 3 with decoder CNN-G ....
Distance 5 with decoder PML ....
Distance 5 with decoder MWPM ....
Distance 5 with decoder CNN ....
Distance 5 with decoder CNN-S ....
Distance 5 with decoder CNN-G ....
Distance 7 with decoder PML ....
Distance 7 with decoder MWPM ....
Distance 7 with decoder CNN ....
Distance 7 with decoder CNN-S ....
Distance 7 with decoder CNN-G ....

Array([[1.72710419e-03, 3.53634357e-04, 8.92877579e-05, 2.49743462e-05],
       [           nan,            nan,            nan,            nan],
       [1.78699999e-03, 3.70999973e-04, 9.79999968e-05, 2.29999969e-05],
       [1.78200006e-03, 3.84000014e-04, 9.79999968e-05, 2.59999979e-05],
       [1.82300003e-03, 3.62000021e-04, 3.87999957e-04, 1.49999978e-04],
       [           nan,            nan,            nan,            nan],
       [           nan,            nan,            nan,            nan],
       [3.18000006e-04, 3.19999963e-05, 1.08999986e-04, 1.29999989e-05],
       [2.92999961e-04, 3.49999973e-05, 2.99999992e-05, 7.99999998e-06],
       [3.12999997e-04, 5.79999942e-05, 1.81999989e-04, 7.49999890e-05],
       [           nan,            nan,            nan,            nan],
       [           nan,            nan,            nan,            nan],
       [7.69999897e-05, 1.79999988e-05, 2.29999991e-04, 2.99999992e-05],
       [1.16999981e-04, 2.09999980e-05, 1.72000000e

In [10]:
import pandas as pd

# Wrap the results table in a labelled DataFrame: one column per deformation,
# one row per (decoder, distance) pair in the same order as the rows of `table`.
df = pd.DataFrame(
    table,
    index=[
        f"{decoder} - {L}"
        for L in distances
        for decoder in ["PML", "MWPM", "CNN", "CNN-S", "CNN-G"]
    ],
    columns=["CSS", "XZZX", "XY", "C1"],
)
df

Unnamed: 0,CSS,XZZX,XY,C1
PML - 3,0.001727,0.000354,8.9e-05,2.5e-05
MWPM - 3,,,,
CNN - 3,0.001787,0.000371,9.8e-05,2.3e-05
CNN-S - 3,0.001782,0.000384,9.8e-05,2.6e-05
CNN-G - 3,0.001823,0.000362,0.000388,0.00015
PML - 5,,,,
MWPM - 5,,,,
CNN - 5,0.000318,3.2e-05,0.000109,1.3e-05
CNN-S - 5,0.000293,3.5e-05,3e-05,8e-06
CNN-G - 5,0.000313,5.8e-05,0.000182,7.5e-05


In [11]:
def num_to_latex(num):
    """Format a number as LaTeX scientific notation with three significant digits.

    Returns an empty string for NaN, the plain string "0.0" for zero, and
    otherwise a math-mode string of the form ``$m.mm\\times10^{e}$``.
    """
    if jnp.isnan(num):
        return ""
    if num == 0:
        return "0.0"
    # "1.73E-03" -> mantissa text "1.73", exponent text "-03".
    mantissa_text, exponent_text = f"{num:.2E}".split("E")
    mantissa = float(mantissa_text)
    exponent = int(exponent_text)
    return "$%.2f\\times10^{%d}$" % (mantissa, exponent)

# Assemble the booktabs table body line by line and join once at the end.
decoders = ["PML", "MWPM", "CNN", "CNN-S", "CNN-G"]

latex_lines = [
    "\\toprule",
    "    $L$&Decoder&CSS&  XZZX&  XY& C1\\\\",
    "\\midrule",
]
for l, L in enumerate(distances):
    for m, decoder in enumerate(decoders):
        # The PML row is only emitted for distance 3.
        if decoder == "PML" and L != 3:
            continue
        # One cell per column of `table` (CSS, XZZX, XY, C1).
        cells = "".join(
            f"& {num_to_latex(table[len(decoders) * l + m, n])}" for n in range(4)
        )
        latex_lines.append(f"    {L}&{decoder}{cells}\\\\")
    # Separate the distance groups, but not after the last one.
    if l < len(distances) - 1:
        latex_lines.append("\\midrule")
latex_lines.append("\\bottomrule")

latex_code = "\n".join(latex_lines)
print(latex_code)

\toprule
    $L$&Decoder&CSS&  XZZX&  XY& C1\\
\midrule
    3&PML& $1.73\times10^{-3}$& $3.54\times10^{-4}$& $8.93\times10^{-5}$& $2.50\times10^{-5}$\\
    3&MWPM& & & & \\
    3&CNN& $1.79\times10^{-3}$& $3.71\times10^{-4}$& $9.80\times10^{-5}$& $2.30\times10^{-5}$\\
    3&CNN-S& $1.78\times10^{-3}$& $3.84\times10^{-4}$& $9.80\times10^{-5}$& $2.60\times10^{-5}$\\
    3&CNN-G& $1.82\times10^{-3}$& $3.62\times10^{-4}$& $3.88\times10^{-4}$& $1.50\times10^{-4}$\\
\midrule
    5&MWPM& & & & \\
    5&CNN& $3.18\times10^{-4}$& $3.20\times10^{-5}$& $1.09\times10^{-4}$& $1.30\times10^{-5}$\\
    5&CNN-S& $2.93\times10^{-4}$& $3.50\times10^{-5}$& $3.00\times10^{-5}$& $8.00\times10^{-6}$\\
    5&CNN-G& $3.13\times10^{-4}$& $5.80\times10^{-5}$& $1.82\times10^{-4}$& $7.50\times10^{-5}$\\
\midrule
    7&MWPM& & & & \\
    7&CNN& $7.70\times10^{-5}$& $1.80\times10^{-5}$& $2.30\times10^{-4}$& $3.00\times10^{-5}$\\
    7&CNN-S& $1.17\times10^{-4}$& $2.10\times10^{-5}$& $1.72\times10^{-4}$& $4.20\times