In [1]:
import sys, os
# Put <cwd>/core on the module search path, in addition to whatever is
# already there, so files inside that directory can be imported by bare name.
core_dir = os.path.abspath(os.getcwd() + "/core")
sys.path.append(core_dir)

In [2]:
from core.quantum_error_correction_code import SurfaceCode
from core.neural_network import CNNDecoder, CNNDual, load_params
from core.perfect_maximum_likelihood_decoder import PMLD

import jax.numpy as jnp
from jax import random, vmap, jit

In [3]:
# Values of L passed to SurfaceCode(L) in the evaluation loop.
distances = [3, 5, 7, 9]

# Total error probability p split over three channels with bias nu:
# the first two channels each get p / (2 * (nu + 1)) and the third gets
# nu * p / (nu + 1), so the three entries add up to p.
p, nu = .01, 500
minor_prob = 1 / (2 * (nu + 1)) * p
major_prob = nu / (nu + 1) * p
ERROR_PROBS = jnp.array([minor_prob, minor_prob, major_prob])

In [4]:
def get_data(
    data_key,
    code: SurfaceCode,
    batch_size: int,
    parity_info: tuple[jnp.ndarray],
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, any]:
    """Sample a batch of errors and derive the decoder inputs from them.

    The incoming key is split into ``batch_size + 1`` keys: all but the last
    drive one ``code.error`` draw each (using the module-level
    ``ERROR_PROBS``), and the last one is handed back so the caller can keep
    sampling from it.

    Args:
        data_key: PRNG key accepted by ``jax.random.split``.
        code: Object providing ``error``, ``syndrome_img`` and ``syndrome``.
        batch_size: Number of errors to sample.
        parity_info: Passed unchanged (not batched) to ``code.syndrome_img``
            and ``code.syndrome``.

    Returns:
        A 4-tuple ``(imgs, syndromes, logicals, next_key)`` where ``imgs`` is
        the first output of ``code.syndrome_img`` with a singleton axis
        inserted at position 1, ``syndromes`` is the first output of
        ``code.syndrome``, ``logicals`` is the second output of
        ``code.syndrome_img`` and ``next_key`` is the unused last split key.
        The second output of ``code.syndrome`` is discarded.
    """
    def over_batch(fn):
        # Map over the leading axis of the first argument only.
        return vmap(fn, in_axes=(0, None), out_axes=0)

    keys = random.split(data_key, num=batch_size + 1)
    sample_keys, next_key = keys[:-1], keys[-1]

    sampled_errors = over_batch(code.error)(sample_keys, ERROR_PROBS)
    syndrome_imgs, true_logicals = over_batch(code.syndrome_img)(sampled_errors, parity_info)
    flat_syndromes, _ = over_batch(code.syndrome)(sampled_errors, parity_info)

    return syndrome_imgs[:, None, :, :], flat_syndromes, true_logicals, next_key

def logicals_of_recovery(
    code: SurfaceCode,
    recovery: jnp.ndarray,
    parity_info: tuple[jnp.ndarray],
) -> jnp.ndarray:
    """Return the second output of ``code.syndrome`` for a batch of recoveries.

    Args:
        code: Object providing a ``syndrome`` method.
        recovery: Batch of recovery operations; mapped over its leading axis.
        parity_info: Passed unchanged (not batched) to ``code.syndrome``.

    Returns:
        The batched second element of ``code.syndrome(recovery_i, parity_info)``;
        the first element is dropped.
    """
    batched_syndrome = vmap(code.syndrome, in_axes=(0, None), out_axes=0)
    outputs = batched_syndrome(recovery, parity_info)
    _, recovery_logicals = outputs
    return recovery_logicals

In [None]:
# One row per (distance, decoder) pair, one entry per deformation.
# An entry is the fraction of samples where the decoder's logicals differ from
# the true ones, or None when no result is available for that combination.
table = []

for L in distances:
    code = SurfaceCode(L)

    # Deformations to evaluate, built once per distance. Their order fixes the
    # column order of `table` (rendered later as "CSS", "XZZX", "XY", "C1").
    # Each is an integer array with L**2 entries; the meaning of the integer
    # codes is defined by SurfaceCode.deformation_parity_info.
    deformations = [
        jnp.zeros(L**2, dtype=jnp.int32),
        jnp.zeros(L**2, dtype=jnp.int32).at[::2].set(3),
        jnp.zeros(L**2, dtype=jnp.int32).at[:].set(2),
        jnp.zeros((L, L), dtype=jnp.int32).at[1::2, ::2].set(3).flatten().at[::2].set(2)
    ]

    for decoder in ["PML", "MWPM", "CNN", "CNN-S", "CNN-G"]:
        print(f"\nDistance {L} with decoder {decoder} ", end='')
        table.append([])

        # Restart from the same seed for every decoder so that all decoders
        # are scored on the identical sequence of sampled errors. The key is
        # advanced by get_data, so it must be reset here and not once per
        # distance.
        key = random.key(723)

        for deformation in deformations:
            print(".", end='')

            parity_info = code.deformation_parity_info(deformation)
            # Sample unconditionally (even when this decoder turns out to be
            # unavailable) so the key advances the same way for every decoder.
            imgs, syndromes, logicals, key = get_data(key, code, 1000, parity_info)

            # Reset per iteration: a value left over from a previous decoder
            # must never be scored against this iteration's logicals.
            decoder_logicals = None

            if decoder == "PML" and L == 3:
                # Only run for L == 3 (presumably because of cost - confirm).
                perfect_decoder = PMLD(code, ERROR_PROBS, parity_info)
                decoder_logicals = perfect_decoder.decode_batch(syndromes)
            elif decoder in ("CNN", "CNN-S", "CNN-G"):
                if decoder == "CNN-G":
                    # One model per distance, shared by all deformations.
                    model_name = f"data/CNN-G-{L}.json"
                else:
                    # One model per (distance, deformation) pair.
                    deformation_tag = ''.join([str(d) for d in deformation])
                    model_name = f"data/{decoder}-{L}-{deformation_tag}.json"
                try:
                    settings, model_params = load_params(model_name)
                except FileNotFoundError:
                    table[-1].append(None)
                    continue
                if decoder == "CNN":
                    # Bound to `model`, not `decoder`: rebinding the loop
                    # variable would break the string comparisons above for
                    # the remaining deformations.
                    model = CNNDecoder(
                        input_shape = (1, L+1, L+1),
                        conv_layers = jnp.array(settings["CONV_LAYERS"]),
                        fc_layers = jnp.array(settings["FC_LAYERS"]),
                    )
                # TODO: run the loaded network on `imgs` and assign the result
                # to `decoder_logicals`. The inference call is not implemented
                # in this cell, so these entries are recorded as None below.

            if decoder_logicals is None:
                # No decoder output for this combination (not implemented,
                # or PML at a distance other than 3).
                table[-1].append(None)
                continue

            # Fraction of samples where any logical differs from the truth.
            table[-1].append(float((logicals != decoder_logicals).any(axis=1).mean()))


Distance 3 with decoder PML ....
Distance 3 with decoder MWPM ....
Distance 3 with decoder CNN ....
Distance 3 with decoder CNN-S ....
Distance 3 with decoder CNN-G ....
Distance 5 with decoder PML ....
Distance 5 with decoder MWPM ....
Distance 5 with decoder CNN ....
Distance 5 with decoder CNN-S ....
Distance 5 with decoder CNN-G ....
Distance 7 with decoder PML ....
Distance 7 with decoder MWPM ....
Distance 7 with decoder CNN ....
Distance 7 with decoder CNN-S ....
Distance 7 with decoder CNN-G ....
Distance 9 with decoder PML ....
Distance 9 with decoder MWPM ....
Distance 9 with decoder CNN ....
Distance 9 with decoder CNN-S .

FileNotFoundError: [Errno 2] No such file or directory: 'data/CNN-S-9-000000000000000000000000000000000000000000000000000000000000000000000000000000000.json'

In [None]:
import pandas as pd

# Row labels follow the order in which the evaluation loop appends rows:
# distances on the outside, decoders on the inside.
decoder_names = ["PML", "MWPM", "CNN", "CNN-S", "CNN-G"]
row_labels = []
for dist in distances:
    for name in decoder_names:
        row_labels.append(f"{name} - {dist}")

# One column per deformation, in the order they were evaluated.
column_labels = ["CSS", "XZZX", "XY", "C1"]

df = pd.DataFrame(table)
df.columns = column_labels
df.index = row_labels
df

Unnamed: 0,CSS,XZZX,XY,C1
PML - 3,0.002,0.002,0.0,0.0
MWPM - 3,,,,
CNN - 3,,,,
CNN-S - 3,,,,
CNN-G - 3,,,,
PML - 5,,,,
MWPM - 5,,,,
CNN - 5,,,,
CNN-S - 5,,,,
CNN-G - 5,,,,
