# Coherent information of the Moebius code

In [1]:
import numpy as np
from collections import Counter
from typing import Tuple, Iterable, Dict, List
from coherentinfo.moebius import MoebiusCode
from coherentinfo.linalg import (finite_field_gauss_jordan_elimination,
                                 finite_field_matrix_rank,
                                 finite_field_inverse,
                                 is_prime)

from numpy.typing import NDArray

from coherentinfo.errormodel import ErrorModel, ErrorModelPoisson
# import galois
import scipy
import timeit
import jax
import jax.numpy as jnp

In [2]:
length = 31
width = 31
p = 3
# Seed NumPy's global RNG so the np.random.randint draws in the checks below are
# reproducible on Restart & Run All.
# NOTE(review): whether ErrorModelPoisson also draws from this global RNG is not
# visible here — confirm before relying on full reproducibility of its samples.
seed = 1234
np.random.seed(seed)

# All syndromes in this notebook are reduced modulo d = 2 * p.
moebius_code = MoebiusCode(length=length, width=width, d=2 * p)
h_z = moebius_code.h_z
h_x = moebius_code.h_x
logical_x = moebius_code.logical_x
logical_z = moebius_code.logical_z

# Vertex syndrome

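Here we check how to construct a candidate error vector that reproduces a given vertex syndrome. We start with a basic example: for random errors, the candidate must have the same vertex syndrome as the original error. Note that syndromes are computed modulo $2p$.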

In [3]:
# Round-trip check over random errors: the candidate error returned for a vertex
# syndrome must reproduce that same syndrome (mod 2 * p). A message is printed
# and the loop stops at the first mismatch.
for _ in range(1000):
    error = np.random.randint(2 * p, size=moebius_code.num_edges)
    syndrome = (moebius_code.h_z @ error.T) % (2 * p)
    candidate_error = moebius_code.get_vertex_candidate_error(syndrome)
    syndrome_candidate = (moebius_code.h_z @ candidate_error.T) % (2 * p)
    mismatch = np.any(syndrome != syndrome_candidate)
    if mismatch:
        print("Syndromes do not match")
        break

In [4]:
# Draw a fresh error and build a candidate with the *same* vertex syndrome.
# The syndrome must be recomputed here: reusing `syndrome` from the loop above
# would pair this new error with a candidate built for a different error.
error = np.random.randint(2 * p, size=moebius_code.num_edges)
syndrome = moebius_code.h_z @ error.T % (2 * p)
candidate_error = moebius_code.get_vertex_candidate_error(syndrome)
# The difference of two errors with equal vertex syndrome is paired with
# logical_z; the check accepts a result of 0 or p (mod 2 * p).
error_diff = error - candidate_error
res_com = error_diff @ logical_z.T % (2 * p)
print(res_com == 0 or res_com == p)

True


# Plaquette syndrome

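Here we check how to construct a candidate error vector that reproduces a given plaquette syndrome. We start with a basic example: for random errors, the candidate must have the same plaquette syndrome as the original error. Note that syndromes are computed modulo $2p$.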

In [5]:
# Round-trip check: the candidate error built from a plaquette syndrome must
# reproduce that syndrome (mod 2 * p). Stops at the first mismatch.
for _ in range(100):
    error = np.random.randint(2 * p, size=moebius_code.num_edges)
    syndrome = moebius_code.h_x @ error.T % (2 * p)
    candidate_error = moebius_code.get_plaquette_candidate_error(syndrome)
    syndrome_candidate = moebius_code.h_x @ candidate_error.T % (2 * p)
    if np.count_nonzero(syndrome - syndrome_candidate) != 0:
        print("Syndromes do not match")
        break

# Compare a fresh error with a candidate that has the *same* plaquette syndrome.
# The syndrome must be recomputed from this new error; reusing the loop's last
# `syndrome` would build the candidate for a different error.
error = np.random.randint(2 * p, size=moebius_code.num_edges)
syndrome = moebius_code.h_x @ error.T % (2 * p)
candidate_error = moebius_code.get_plaquette_candidate_error(syndrome)
error_diff = error - candidate_error
print(error_diff @ logical_x.T % (2 * p))

0


# Error model

In [6]:
# Small Poisson error model (num_errors entries, values mod 2 * p) used for the
# sampling sanity check that follows.
gamma, num_errors = 0.01, 100
poisson_em = ErrorModelPoisson(num_errors, 2 * p, gamma)

In [7]:
# Probabilities assigned by the error model; the printed output below has one
# entry per value in range(2 * p).
probs = poisson_em.get_probabilities()
print(probs)

[9.80198673e-01 1.96039735e-02 1.96039735e-04 1.30693156e-06
 6.53465782e-09 2.61386313e-11]


In [8]:
# Draw one random error vector from the Poisson model defined above.
my_error = poisson_em.generate_random_error()

In [9]:
# Empirical frequency of each error value 0..2p-1 in the sampled error, to compare
# with the model probabilities printed above. Normalize by the sample's own size
# so the result does not silently depend on the global `num_errors` staying in
# sync with `my_error`.
sampled_frequencies = np.array([
    np.count_nonzero(my_error == x) / my_error.size
    for x in range(2 * p)
])
print(sampled_frequencies)

[0.99 0.01 0.   0.   0.   0.  ]


# Conditional entropy and coherent information

In [10]:
num_samples = 1000
# Poisson error model over every edge of the Moebius code, with the same
# alphabet size (2 * p) and gamma as the small example above.
poisson_em_moebius = ErrorModelPoisson(moebius_code.num_edges, 2 * p, gamma)
result = moebius_code.compute_vertex_conditional_entropy(
    num_samples=num_samples,
    error_model=poisson_em_moebius,
)
print(result)

0.0


In [11]:
# Coherent information of the Moebius code under the edge-wise Poisson model,
# estimated with the same `num_samples` as the conditional entropy above.
coherent_info = moebius_code.compute_coherent_information(
    error_model=poisson_em_moebius,
    num_samples=num_samples,
)
print("Coherent Information: {}".format(coherent_info))

Coherent Information: 1.0


In [13]:
# Time one large plaquette conditional-entropy computation.
# timeit.timeit needs a zero-argument callable, so the method call is wrapped in
# a lambda. timeit discards the method's return value and reports only the
# elapsed time.
timing_num_samples = 100_000
execution_time = timeit.timeit(
    stmt=lambda: moebius_code.compute_plaquette_conditional_entropy(
        num_samples=timing_num_samples, error_model=poisson_em_moebius),
    number=1,  # a single run: the recorded run below already took ~15 s
)

print(f"Time taken for 1 run: {execution_time} seconds")

Time taken for 1 run: 15.30926394700009 seconds


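# JAX

Below we benchmark the plaquette-syndrome computation with JAX. We time an eager call, a `jax.jit`-compiled call, and a `jax.vmap`-batched version over many sampled errors.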



In [13]:
def syndrome_jax(error, h_x_jax, modulus=None):
    """Plaquette syndrome ``h_x_jax @ error.T`` reduced modulo ``modulus``.

    Args:
        error: A 1-D error vector (e.g. ``errors[0]``) or a 2-D batch of error
            rows; for a batch the result has one column per row of ``error``.
        h_x_jax: Check matrix as a JAX array.
        modulus: Modulus of the syndrome. When omitted it defaults to ``2 * p``,
            using the notebook-level ``p`` read at call time (the value that
            was previously hard-coded).

    Returns:
        ``jnp.mod(h_x_jax @ error.T, modulus)``.
    """
    if modulus is None:
        # Resolved lazily so existing callers (the jit/vmap wrappers, which never
        # pass a modulus) keep the original behaviour.
        modulus = 2 * p
    return jnp.mod(h_x_jax @ error.T, modulus)

In [23]:
h_x_jax = jnp.array(h_x)

# Pre-sample a batch of Poisson errors on the Moebius code for the JAX benchmarks.
# A dedicated name is used so the `num_errors` of the error-model section is not
# overwritten, and exactly `num_jax_errors` samples are generated.
num_jax_errors = 100_009
# Fill a preallocated int16 array directly instead of first collecting ~1e5
# separate arrays in a Python list.
# NOTE(review): assumes generate_random_error() returns one value per edge
# (the model was built with moebius_code.num_edges) — a shape mismatch raises.
errors_np = np.empty((num_jax_errors, moebius_code.num_edges), dtype=np.int16)
for i in range(num_jax_errors):
    errors_np[i] = poisson_em_moebius.generate_random_error()
errors = jnp.asarray(errors_np)
del errors_np

jitted_syndrome = jax.jit(syndrome_jax)

In [24]:
# First call of the jitted function on one error vector: this traces and
# compiles it, and displays the resulting plaquette syndrome.
jitted_syndrome(errors[0], h_x_jax)

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0,
       0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0,

In [25]:
# Baseline: un-jitted (eager) call on a single error vector. block_until_ready()
# waits for JAX's asynchronous dispatch so the timing includes the computation.
%timeit syndrome_jax(errors[0], h_x_jax).block_until_ready()

155 μs ± 5.76 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [26]:
# Warm-up call so any tracing/compilation of the jitted function happens before
# timing; %timeit then measures only the compiled execution.
jitted_syndrome(errors[0], h_x_jax)
%timeit jitted_syndrome(errors[0], h_x_jax).block_until_ready()

130 μs ± 6.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [27]:
from functools import partial

# Bind the check matrix so the function takes a single error vector, then let
# jax.vmap map it over the leading (sample) axis of a batch of errors.
syndrome_new = partial(syndrome_jax, h_x_jax=h_x_jax)
auto_syndrome = jax.vmap(syndrome_new)

In [28]:
%timeit auto_syndrome(errors).block_until_ready()

KeyboardInterrupt: 