In [1]:
import sys, os
# Put the `src` directory (resolved against the current working directory)
# on the import path.
# NOTE(review): this depends on the kernel's working directory — presumably
# the notebook is started from the repository root; confirm.
src_dir = os.path.abspath(os.path.join(os.getcwd(), "src"))
# Guarded so that re-running this cell does not pile up duplicate entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [2]:
from src.data_gen import sample_errors, sample_error_batch, noise_permutations_from_deformation
from src.ModifiedRotatedPlanarRMPSDecoder import ModifiedRotatedPlanarRMPSDecoder
from src.recursive_mwpm import recursive_mwpm, recursive_mwpm_batch
import jax.numpy as jnp
from jax import random, vmap
from qecsim.models.rotatedplanar import RotatedPlanarCode
from qecsim.models.generic import BiasedDepolarizingErrorModel
from multiprocessing import Pool, cpu_count
from functools import partial


# Root PRNG key for the notebook; later cells split it to obtain fresh keys.
SEED = 0
key = random.key(SEED)



In [3]:
# --- Code and noise configuration -------------------------------------------
code = RotatedPlanarCode(3, 3)
noise_model = BiasedDepolarizingErrorModel(bias=10.0, axis='Y')
error_probability = 0.1

# --- Decoder instance used in the later decoding cell -----------------------
bsv_decoder = ModifiedRotatedPlanarRMPSDecoder(chi=6)

# --- Random deformation ------------------------------------------------------
# Split the running key: `subkey` is spent on the draw below, `key` carries on.
subkey, key = random.split(key, 2)
# Integers drawn from the half-open range [0, 6); the array shape is whatever
# `code.size` holds.
deformation = random.randint(subkey, code.size, 0, 6)
noise_permutations = noise_permutations_from_deformation(deformation)

In [4]:
# Number of error samples drawn in one batch.
batch_size = 1000

# A single split yields the new running key (row 0) plus `batch_size` more keys.
keys = random.split(key, num=batch_size + 1)
key, subkeys = keys[0], keys[1:]

# NOTE(review): the sampler is handed `key` itself and `subkeys` is not used in
# this cell. If `key` is split or consumed again further down, that would be
# key reuse — confirm, or pass a dedicated subkey to the sampler instead.
errors = sample_error_batch(
    key,
    batch_size,
    code,
    noise_model,
    error_probability,
    noise_permutations,
)
errors

Array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0]], dtype=int32)

In [5]:
def _single_syndrome(code, error):
    """Return ``code.stabilizers @ error`` reduced modulo 2 for one error vector."""
    return (code.stabilizers @ error) % 2


# Batched form: `code` is shared by every batch element (in_axes entry None),
# errors are mapped over their leading axis, and results are stacked on axis 0.
get_syndromes = vmap(_single_syndrome, in_axes=(None, 0), out_axes=0)

syndromes = get_syndromes(code, errors)
syndromes

Array([[0, 0, 0, ..., 0, 0, 0],
       [1, 0, 1, ..., 0, 0, 0],
       [1, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0]], dtype=int32)

In [8]:
# Run the batched recursive-MWPM routine on the syndromes computed above.
mwpm_recoveries = recursive_mwpm_batch(
    code,
    syndromes,
    noise_model,
    error_probability,
    noise_permutations,
)
# Print the shape separately: the array repr shown below is elided.
print(mwpm_recoveries.shape)
mwpm_recoveries

(1000, 18)


Array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 1, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0]], dtype=int32)

In [9]:
# Decode the same syndromes with the `bsv_decoder` instance, passing the same
# argument list that was given to `recursive_mwpm_batch`.
bsv_recoveries = bsv_decoder.decode_batch(
    code,
    syndromes,
    noise_model,
    error_probability,
    noise_permutations,
)
# Print the shape separately: the array repr shown below is elided.
print(bsv_recoveries.shape)
bsv_recoveries

(1000, 18)


array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0]], shape=(1000, 18), dtype=int32)