# Coherent info

In [3]:
import numpy as np
from collections import Counter
from typing import Tuple, Iterable, Dict, List
from coherentinfo.moebius import MoebiusCode, MoebiusCodeOddPrime, MoebiusCodeQubit
from coherentinfo.linalg import (finite_field_gauss_jordan_elimination,
                                 finite_field_matrix_rank,
                                 finite_field_inverse,
                                 is_prime)

from numpy.typing import NDArray

from coherentinfo.errormodel import ErrorModel, ErrorModelPoisson
# import galois
from functools import partial
import scipy
import timeit
import jax
import jax.numpy as jnp

In [12]:
length = 5
width = 5
p = 3
moebius_code = MoebiusCodeOddPrime(length=length, width=width, d=2 * p)
moebius_code_qubit = MoebiusCodeQubit(length=length, width=width)
h_z = moebius_code.h_z
h_x = moebius_code.h_x
logical_x = moebius_code.logical_x
logical_z = moebius_code.logical_z

In [16]:
gamma = 0.1

poisson_em_moebius = ErrorModelPoisson(moebius_code.num_edges, 2 * p, gamma)
em_moebius_qubit = ErrorModelPoisson(moebius_code.num_edges, 2, gamma)
my_error = em_moebius_qubit.generate_random_error()

type(moebius_code_qubit.get_vertex_syndrome(my_error))

jaxlib._jax.ArrayImpl

In [None]:
func_vertex_result = partial(
    moebius_code_qubit.compute_vertex_syndrome_chi_probabilities, error_model=em_moebius_qubit)
num_samples_vec = jnp.array([100, 200, 300])
my_result = jax.vmap(func_vertex_result)(num_samples_vec)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
This BatchTracer with object id 140503986410544 was created on line:
  /tmp/ipykernel_12816/3463451019.py:3:12 (<module>)
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError

# Error model

In [5]:
gamma = 0.1
num_errors = 1000
poisson_em = ErrorModelPoisson(num_errors, 2 * p, gamma)

In [32]:
probs = poisson_em.get_probabilities()
print(probs)

[8.18730753e-01 1.63746151e-01 1.63746151e-02 1.09164100e-03
 5.45820502e-05 2.18328201e-06]


In [33]:
my_error = poisson_em.generate_random_error()

In [34]:
sampled_frequencies = np.array([
    np.argwhere(my_error == x).shape[0] / num_errors
    for x in range(2 * p)
])
print(sampled_frequencies)

[0.837 0.152 0.011 0.    0.    0.   ]


# Compute results conditional entropy

In [41]:
poisson_em_moebius = ErrorModelPoisson(moebius_code.num_edges, 2 * p, gamma)
em_moebius_qubit = ErrorModelPoisson(moebius_code.num_edges, 2, gamma)
num_samples = 1000
result = moebius_code.compute_vertex_conditional_entropy(
    num_samples=num_samples, error_model=poisson_em_moebius)
print(result)

0.2093580813571175


In [42]:
coherent_info = moebius_code.compute_coherent_information(
    num_samples=num_samples, error_model=poisson_em_moebius)
print("Coherent Information: {}".format(coherent_info))

Coherent Information: 0.7924636880589835


In [44]:
coherent_info_qubit = moebius_code_qubit.compute_coherent_information(
    num_samples=num_samples * 1000, error_model=em_moebius_qubit)
print("Coherent Information: {}".format(coherent_info_qubit))

Coherent Information: -0.5290947317020129


In [None]:
res_pq = moebius_code_qubit.compute_plaquette_syndrome_chi_probabilities(
    num_samples, em_moebius_qubit)
print(res_pq)

{'1_0_0_1_1_1_0_1': [0.0, 0.001], '0_0_1_1_1_0_0_1': [0.008, 0.001], '1_1_1_0_1_1_0_1': [0.001, 0.001], '0_1_0_0_0_0_0_0': [0.0, 0.015000000000000006], '0_1_1_0_0_1_1_0': [0.004, 0.0], '1_1_0_0_0_0_0_0': [0.004, 0.002], '0_0_0_1_0_1_1_1': [0.0, 0.002], '1_0_1_0_0_1_1_0': [0.002, 0.001], '0_0_0_0_0_0_0_0': [0.06900000000000005, 0.001], '0_0_1_0_0_0_1_0': [0.005, 0.005], '0_1_0_1_1_1_0_0': [0.003, 0.004], '1_0_0_0_1_0_0_1': [0.005, 0.001], '0_0_1_0_1_1_1_1': [0.0, 0.003], '0_1_0_1_0_0_1_1': [0.005, 0.001], '0_0_1_0_0_0_0_0': [0.014000000000000005, 0.0], '0_1_0_1_0_0_1_0': [0.003, 0.006], '0_0_0_1_1_0_0_0': [0.011000000000000003, 0.001], '1_0_0_1_1_0_0_0': [0.004, 0.0], '1_0_1_0_0_0_0_0': [0.009000000000000001, 0.001], '0_0_1_0_0_1_1_0': [0.003, 0.0], '0_1_0_0_0_0_0_1': [0.001, 0.002], '1_0_1_1_1_1_0_1': [0.003, 0.001], '0_1_0_0_1_0_0_0': [0.012000000000000004, 0.0], '0_0_0_0_1_1_0_0': [0.001, 0.009000000000000001], '1_0_0_0_0_0_1_1': [0.004, 0.001], '1_0_0_1_0_0_0_1': [0.002, 0.004], '0_

In [14]:
# 2. Use lambda to wrap the method call on the instance
execution_time = timeit.timeit(
    # The lambda function is the callable that timeit executes
    stmt=lambda: moebius_code.compute_plaquette_conditional_entropy(
        num_samples=100000, error_model=poisson_em_moebius),
    number=1  # Run the entire thing 10 times
)

print(f"Time taken for 1 run: {execution_time} seconds")
# Note: The result of the method call itself is ignored by timeit.timeit,
# which only records the execution time.

KeyboardInterrupt: 

# JAX



In [48]:
def syndrome_jax(error, h_x_jax):
    return jnp.mod(h_x_jax @ error.T, 2 * p)

In [49]:
h_x_jax = jnp.array(h_x)
errors = []
num_errors = 100009
for _ in range(num_errors - 1):
    error = poisson_em_moebius.generate_random_error()
    errors.append(error)
errors = jnp.array(errors, dtype=np.int16)
jitted_syndrome = jax.jit(syndrome_jax)

In [50]:
jitted_syndrome(errors[0], h_x_jax)

Array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int16)

In [51]:
%timeit syndrome_jax(errors[0], h_x_jax).block_until_ready()

54.7 μs ± 1.15 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [26]:
jitted_syndrome(errors[0], h_x_jax)
%timeit jitted_syndrome(errors[0], h_x_jax).block_until_ready()

130 μs ± 6.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [52]:
from functools import partial
syndrome_new = partial(syndrome_jax, h_x_jax=h_x_jax)
auto_syndrome = jax.vmap(syndrome_new, in_axes=0)

In [53]:
%timeit auto_syndrome(errors).block_until_ready()

6.55 ms ± 126 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
