# Worm Moebius qubit

In [1]:
import numpy as np
from coherentinfo.moebius_qubit import MoebiusCodeQubit
from coherentinfo.errormodel import ErrorModelBernoulli
from coherentinfo.worm_qubit import run_worm_plaquette_qubit
import scipy
import timeit
import jax
import jax.numpy as jnp
from typing import Tuple

In [7]:
length = 11
width = 11
moebius_code_qubit = MoebiusCodeQubit(length=length, width=width)
h_z = moebius_code_qubit.h_z
h_x = moebius_code_qubit.h_x
logical_x = moebius_code_qubit.logical_x
logical_z = moebius_code_qubit.logical_z
p_error = 0.1
em_moebius_qubit = ErrorModelBernoulli(
    moebius_code_qubit.num_edges, 2, p_error
)

base_key = jax.random.PRNGKey(434)
initial_error = em_moebius_qubit.generate_random_error(base_key)
# Here we use the full plaquette syndrome, including the plaquette that is
# usually removed because of the constraint: keeping it simplifies the
# coding of the worm algorithm. With the full syndrome, defects are always
# annihilated in pairs, so the total number of nonzero syndromes is always
# even (asserted below rather than left to a manual numerical check).
syndrome = moebius_code_qubit.get_full_plaquette_syndrome(initial_error)
num_plaquette, num_edges = h_x.shape

# Sanity checks before running the worm algorithm:
# - the full syndrome must contain an even number of defects;
# - the parity-check matrix and the error model must agree on the edge count.
assert int(jnp.count_nonzero(syndrome)) % 2 == 0, \
    "full plaquette syndrome should have an even number of defects"
assert num_edges == moebius_code_qubit.num_edges, \
    "h_x column count must match the number of edges of the code"

Our first goal is, given a plaquette syndrome as computed above, to run the worm algorithm to find a candidate error that reproduces it. We want this to be compatible with JAX, i.e. the worm run must be jit-compilable.

In [None]:
# `max_worms` and `max_steps_per_worm` are listed in `static_argnames`, so they
# must be plain hashable Python ints, not JAX arrays (and `jnp.int` does not
# exist). ceil(num_plaquette / 2) worms is the bound used here, presumably one
# worm per pair of defects.
max_worms = (num_plaquette + 1) // 2
max_steps_per_worm = 1000
# Derive a fresh key for the worm run rather than reusing `base_key`, which
# already generated `initial_error`; reusing a JAX PRNG key correlates draws.
_, worm_key = jax.random.split(base_key)
static_argnames = ["num_plaquette", "num_edges",
                   "p_error", "max_worms", "max_steps_per_worm"]
jitted_run_worm_plaquette_qubit = jax.jit(
    run_worm_plaquette_qubit, static_argnames=static_argnames)
new_error, new_syndrome = jitted_run_worm_plaquette_qubit(
    syndrome, worm_key, jnp.mod(h_x, 2), num_plaquette, num_edges, p_error,
    max_worms, max_steps_per_worm)

In [17]:
# Check the worm run: the residual syndrome should be empty, and the error it
# produced must reproduce the original syndrome entry by entry.
print("Initial number of nonzero syndromes: {}".format(jnp.sum(syndrome)))
print("Final number of nonzero syndromes after the worm run: {}".format(
    jnp.sum(new_syndrome)))
new_error_syndrome = moebius_code_qubit.get_full_plaquette_syndrome(new_error)
# `jnp.any(a == b)` would report success if a single entry happened to match;
# every entry must agree, so compare the whole arrays.
is_same_syndrome = jnp.array_equal(new_error_syndrome, syndrome)
print("Success: {}".format(is_same_syndrome))

Initial number of nonzero syndromes: 36
Final number of nonzero syndromes after the worm run: 0
Success: True
