# Worm Moebius qubit

In [20]:
import numpy as np
from coherentinfo.moebius_qubit import MoebiusCodeQubit


from coherentinfo.errormodel import ErrorModelBernoulli
import scipy
import timeit
import jax
import jax.numpy as jnp

In [121]:
length = 5
width = 5
moebius_code_qubit = MoebiusCodeQubit(length=length, width=width)
h_z = moebius_code_qubit.h_z
h_x = moebius_code_qubit.h_x
logical_x = moebius_code_qubit.logical_x
logical_z = moebius_code_qubit.logical_z
p_error = 0.1
em_moebius_qubit = ErrorModelBernoulli(
    moebius_code_qubit.num_edges, 2, p_error
)

base_key = jax.random.PRNGKey(109)
initial_error = em_moebius_qubit.generate_random_error(base_key)
# Here we consider the full syndrome including the plaquette
# we usually remove because of the constraint as this simplified the
# coding of the worm algorithm. In fact, in this the syndromes will
# always be annihilated in pairs, and the total number of syndromes is
# always even as one can check numerically.
syndrome = moebius_code_qubit.get_full_plaquette_syndrome(initial_error)
jnp.sum(syndrome)

Array(6, dtype=int32)

Our first goal is given a vertex syndrome as above to run the worm algorithm to find a candidate errors. We want this to be compatible with JAX

In [133]:
def run_worm_plaquette(
    syndrome: jax.Array,
    base_key: jax.Array,
    moebius_code_qubit: MoebiusCodeQubit,
    p_error: int,
    max_worms: int,
    max_steps_per_worm: int
) -> jax.Array:
    """We code it to work for the plaquettes first 
    because it is more challenging due to the fact that the 
    number of edges associated with one plaquette can be 3 or 4"""
    # For now I just define a p_error inside the function
    # Later we need to pass an error model
    num_plaquette = moebius_code_qubit.num_plaquette_checks
    num_edges = moebius_code_qubit.num_edges
    # master_keys = jax.random.split(base_key, max_steps_per_worm * max_worms)
    p_error = 0.7
    # H_X full plaquette stabilizers including the plaquette we usually
    # remove
    h_x = jnp.mod(moebius_code_qubit.h_x, 2)

    def random_edge_boundary(key):
        return jax.random.randint(key, 1, 0, 3)

    def random_edge_bulk(key):
        return jax.random.randint(key, 1, 0, 4)

    # Note this is the standard form of a function to be passed
    # to jax.lax.scan. Similarly, this function needs to return
    # a tuple of two elements, where the second one is the output
    # to be collected at each step, which is needed only to track
    # how the computation goes
    def worm_step(carry_worm_step, x):
        # Get the current state
        error, syndrome, head, tail, continue_worm, key = carry_worm_step
        # Split the key into a new key that will get passed at
        # the end and a subkey that will be used in this step
        key, subkey = jax.random.split(key)
        # Etract the locations, i.e., plaquette stabilizers
        # with nonzero syndrome. We pad -1 to make it compatible
        # with jax.jit.

        # Note that the following approach works only in the case where
        # the error is the same for every edge. It would have to be modified
        # if that is not the case.

        def reject(args):
            error, syndrome, head, continue_worm = args
            return error, syndrome, head, continue_worm

        def accept(args):
            error, syndrome, head, continue_worm = args

            syndrome_locations = jnp.nonzero(
                syndrome, size=num_plaquette, fill_value=-1)[0]
            # We extract the edges that take part in the stabilizer
            # identified by head. These are in total either 3 or 4
            # dependeing on whether the stabilizer is at the rough boundary
            # or not. Again we pad with -1 in case it is 3 to make the code
            # compatible with jax.jit

            head_edges = jnp.nonzero(h_x[head], size=4, fill_value=-1)[0]
            # We sample a random integer between either 3 or 4 possible
            # values depening on whether head stabilizer has 3 or 4 edges
            random_integer = jax.lax.cond(
                head_edges[-1] == -1, random_edge_boundary, random_edge_bulk, subkey)

            # We select the candidate edge based on the random integer we samples
            candidate_edge = head_edges[random_integer]

            # Note that now for the vertex case we would need to check whether
            # we are at the boundary or not while this is not necessary for the plaquette
            # as syndromes will always be annihilated in pairs

            # We know that the new candidate head is associated with the
            # candidate_edge which has vertices the current head and the
            # new one. The following identifies the candidate_head in
            # a way compatible with JAX
            candidate_head_vec = jnp.nonzero(
                h_x[:, candidate_edge], size=num_edges, fill_value=-1)[0]

            def candidate_head_is_first(candidate_head_vec):
                return candidate_head_vec[0]

            def candidate_head_is_second(candidate_head_vec):
                return candidate_head_vec[1]

            candidate_head = jax.lax.cond(
                candidate_head_vec[1] == head,
                candidate_head_is_first,
                candidate_head_is_second,
                candidate_head_vec,
            )

            error = error.at[candidate_edge].set(
                jnp.mod(error[candidate_edge] + 1, 2))

            head = candidate_head
            success = jnp.logical_and(
                jnp.any(syndrome_locations == head), head != tail)

            def worm_success(args):
                syndrome, head, continue_worm = args
                continue_worm = False
                jax.debug.print("syndrome 1: {}", syndrome)
                error_syndrome = moebius_code_qubit.get_full_plaquette_syndrome(
                    error)
                syndrome = jnp.mod(syndrome + error_syndrome, 2)
                jax.debug.print("syndrome 2: {}", syndrome)

                return syndrome, head, continue_worm

            def worm_fail(args):
                syndrome, head, continue_worm = args
                return syndrome, head, continue_worm

            syndrome, head, continue_worm = jax.lax.cond(
                success, worm_success, worm_fail, (syndrome, head, continue_worm))

            return error, syndrome, head, continue_worm

        random_number = jax.random.uniform(subkey)

        condition = jnp.logical_and(random_number < p_error, continue_worm)

        error, syndrome, head, continue_worm = jax.lax.cond(
            condition, accept, reject, (error, syndrome, head, continue_worm))

        return (error, syndrome, head, tail, continue_worm, key), None

    def run_worm(carry_worm, x):
        error, syndrome, key = carry_worm

        syndrome_locations = jnp.nonzero(
            syndrome, size=num_plaquette, fill_value=-1)[0]
        jax.debug.print("syndrome locations: {}", syndrome_locations)
        # I do not think there is any problem in starting always from
        # the first non-zero syndrome.
        head = syndrome_locations[0]
        # jax.debug.print("head: {}", head)
        tail = head.copy()

        continue_worm = True
        new_error = jnp.zeros(num_edges, dtype=jnp.int32)
        initial_carry_worm_step = (
            new_error, syndrome, head, tail, continue_worm, key)

        new_error, syndrome, head, tail, continue_worm, key = \
            jax.lax.scan(worm_step, initial_carry_worm_step,
                         jnp.arange(max_steps_per_worm))[0]

        # This line seems crucial...and in fact this is wrong.
        # It is just setting the syndrome to the one generated
        # by the error. It is dead wrong.
        # syndrome = jnp.mod(h_x @ error, 2)

        error = jnp.mod(error + new_error, 2)

        return (error, syndrome, key), None

    error = jnp.zeros(num_edges, dtype=jnp.int32)
    initial_carry_worm = (error, syndrome, base_key)

    error, syndrome, key = jax.lax.scan(
        run_worm, initial_carry_worm, jnp.arange(max_worms))[0]

    return error, syndrome, key


static_argnames = ["moebius_code_qubit",
                   "p_error", "max_worms", "max_steps_per_worm"]
# jitted_run_worm_plaquette = jax.jit(
#     run_worm_plaquette, static_argnames=static_argnames)
new_error, new_syndrome, key = run_worm_plaquette(
    syndrome, base_key, moebius_code_qubit, p_error, 3, 1000)

syndrome locations: [ 0  1 10 15 16 24 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1]
syndrome 1: [1 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 0 1]
syndrome 2: [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1]
syndrome locations: [ 1 15 16 24 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1]
syndrome 1: [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1]
syndrome 2: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1]
syndrome locations: [16 24 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1]
syndrome 1: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1]
syndrome 2: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [134]:
jnp.sum(syndrome)

Array(6, dtype=int32)

In [135]:
jnp.sum(new_syndrome)

Array(0, dtype=int32)

In [137]:
error_syndrome = moebius_code_qubit.get_full_plaquette_syndrome(new_error)
print(error_syndrome - syndrome)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [79]:
error_syndrome = moebius_code_qubit.get_full_plaquette_syndrome(new_error)
pippo_syndrome = jnp.mod(syndrome + error_syndrome, 2)

In [80]:
new_syndrome - pippo_syndrome

Array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0], dtype=int32)

In [47]:
new_error

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0], dtype=int32)

In [17]:
syndrome_worm_error = moebius_code_qubit.get_full_plaquette_syndrome(
    initial_error + new_error)
print(syndrome_worm_error)

[0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0
 0 0 0 0 0 0 0 0 0 1 0 0]


In [18]:
syndrome

Array([0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 1, 0, 1], dtype=int32)

In [10]:
jnp.any(syndrome - syndrome_worm_error == 0)

Array(True, dtype=bool)

In [None]:
def fun(carry, x):
    return carry + x, x


jax.lax.scan(fun, 0.0, jnp.arange(4))

(Array(6., dtype=float32, weak_type=True), Array([0, 1, 2, 3], dtype=int32))

In [55]:
head

Array(2, dtype=int32)

In [11]:
0

0

In [27]:
head_edges

Array([ 1, 21, 22, -1], dtype=int32)

In [19]:
syndrome[0]

Array(1, dtype=int32)

In [18]:
h_x[head]

Array([ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, -1, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int16)

In [11]:
head_edges

Array([ 1, 21, 22, -2], dtype=int32)

In [24]:
jnp.nonzero(h_x[head], size=4, fill_value=-2)[0]

Array([ 0, 20, 21, -2], dtype=int32)

In [8]:
h_x[head]

Array([ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, -1, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int16)

In [86]:
jnp.nonzero(h_x[head], size=4, fill_value=-1)

(Array([ 1, 21, 22, -1], dtype=int32),)

In [None]:
h_x[]

In [66]:
candidate_edge

Array([20], dtype=int32)

In [70]:
h_x[head, 22]

Array(1, dtype=int16)

In [37]:
pippo

Array([0], dtype=int32)

In [25]:
jnp.nonzero(syndrome == 1, size=syndrome.shape[0], fill_value=-1)[0]

Array([ 0,  3,  4, 14, 15, 18, 23, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1], dtype=int32)

In [23]:
syndrome[3]

Array(1, dtype=int32)

In [15]:
jitted_run_worm(syndrome, base_key, moebius_code_qubit, 100, 100)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function run_worm at /tmp/ipykernel_27040/3071337049.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument syndrome.

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [None]:
syndrome_locations = jnp.where(syndrome != 0)[0]
# I do not think there is any problem in starting always from
# the first non-zero syndrome.
head = syndrome_locations[0]
# most likely tail is not needed.
tail = head
h_z = moebius_code_qubit.h_z
head_edges = jnp.where(h_z[head, :] != 0)[0]

In [116]:
base_key = jax.random.PRNGKey(18)
jax.random.randint(base_key, 1, 0, 3)

Array([2], dtype=int32)

In [None]:
def fun_true(x):
    return x + 2


def fun_false(x):
    return x + 10


jax.lax.cond(False, fun_true, fun_false, jnp.array([0]))

Array([10], dtype=int32)

In [101]:
head_edges

Array([ 0, 19, 20, 25], dtype=int32)

In [76]:
h_x[head, :]

Array([ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, -1, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int16)

In [90]:
syndrome.shape

(24,)