# Coherent info

In [1]:
import numpy as np
from coherentinfo.moebius import MoebiusCodeOddPrime, MoebiusCodeQubit
from coherentinfo.postprocess import (aggregate_data,
                                      aggregate_data_jax,
                                      compute_conditional_entropy_term,
                                      miller_madow_conditional_entropy)


from coherentinfo.errormodel import ErrorModel, ErrorModelBernoulli, ErrorModelBernoulliJax, ErrorModelPoissonJax
# import galois
import scipy
import timeit
import jax
import jax.numpy as jnp

In [2]:
# Lattice size handed to both code constructors below.
length, width = 5, 5
# p only enters through the `d = 2 * p` argument of the odd-prime code
# (and the Poisson error model built in a later cell).
p = 3

# Two codes on the same length x width lattice: one built with d = 2 * p,
# one built with the qubit constructor.
moebius_code = MoebiusCodeOddPrime(
    length=length,
    width=width,
    d=2 * p,
)
moebius_code_qubit = MoebiusCodeQubit(
    length=length,
    width=width,
)

# Convenience handles on attributes of the d = 2 * p code
# (presumably its check matrices and logical operators -- TODO confirm).
h_z, h_x = moebius_code.h_z, moebius_code.h_x
logical_x, logical_z = moebius_code.logical_x, moebius_code.logical_z

In [3]:
# gamma parameterizes the Poisson model; p_error parameterizes both
# Bernoulli models.
gamma, p_error = 0.01, 0.1

# Every model is sized by the edge count of the d = 2 * p code.
# The Poisson model receives 2 * p as its second argument, the Bernoulli
# models receive 2.
em_moebius_poisson_jax = ErrorModelPoissonJax(
    moebius_code.num_edges,
    2 * p,
    gamma,
)
em_moebius_qubit = ErrorModelBernoulli(
    moebius_code.num_edges,
    2,
    p_error,
)
em_moebius_qubit_jax = ErrorModelBernoulliJax(
    moebius_code.num_edges,
    2,
    p_error,
)

In [None]:
# Sample count; the aggregation, entropy and plaquette cells reuse it.
num_samples = 10**6

# Disabled alternative (qubit code with the Bernoulli JAX model):
#   vertex_result = moebius_code_qubit.compute_batched_vertex_syndrome_chi_z(
#       num_samples, em_moebius_qubit_jax)

# Active choice: d = 2 * p code sampled with the Poisson JAX error model.
vertex_result = moebius_code.compute_batched_vertex_syndrome_chi_z(
    num_samples,
    em_moebius_poisson_jax,
)

In [None]:
# JIT-compile the aggregation. Positional argument 1 (num_samples) is
# declared static, so it has to be passed as a plain hashable Python value.
aggregate_data_jit = jax.jit(aggregate_data_jax, static_argnums=1)

# Vector of -1 entries, one element shorter than the second axis of
# vertex_result.
vertex_pads = -jnp.ones(vertex_result.shape[1] - 1)

# Only the second element of the returned pair is kept.
_, vertex_probs = aggregate_data_jit(vertex_result, num_samples, vertex_pads)

# Drop the references to the raw samples and the pads once aggregated.
del vertex_result, vertex_pads

In [None]:
# Conditional-entropy estimate for the vertex syndrome, computed by
# miller_madow_conditional_entropy from the aggregated probabilities and
# the sample count.
vertex_conditional_entropy = miller_madow_conditional_entropy(
    vertex_probs,
    num_samples,
)

# Drop the reference to the probabilities once the estimate exists.
del vertex_probs

In [8]:
# Plaquette syndrome samples from the qubit code with the Bernoulli JAX model.
# NOTE(review): the vertex syndrome is sampled from `moebius_code` with the
# Poisson model, whereas this uses `moebius_code_qubit` with the Bernoulli
# model -- confirm that mixing the two is intended.
plaquette_result = (
    moebius_code_qubit.compute_batched_plaquette_syndrome_chi_x(
        num_samples,
        em_moebius_qubit_jax,
    )
)

In [9]:
# Vector of -1 entries, one element shorter than the second axis of
# plaquette_result.
plaquette_pads = -jnp.ones(plaquette_result.shape[1] - 1)

# num_samples is passed as a plain Python int; aggregate_data_jit was
# created with that argument position marked static.
# Only the second element of the returned pair is kept.
_, plaquette_probs = aggregate_data_jit(
    plaquette_result, num_samples, plaquette_pads)

# Drop the references to the raw samples and the pads once aggregated.
del plaquette_result, plaquette_pads

In [10]:
# Conditional-entropy estimate for the plaquette syndrome, computed by
# miller_madow_conditional_entropy from the aggregated probabilities and
# the sample count.
plaquette_conditional_entropy = miller_madow_conditional_entropy(
    plaquette_probs,
    num_samples,
)

# Drop the reference to the probabilities once the estimate exists.
del plaquette_probs

In [11]:
# Subtract both conditional-entropy estimates from 1.0
# (presumably the entropy of one logical qubit in bits -- TODO confirm).
# NOTE(review): the recorded run of this cell printed `nan`; check that both
# entropy terms are finite before trusting the result.
coherent_info = (
    1.0
    - vertex_conditional_entropy
    - plaquette_conditional_entropy
)
print("Coherent information = {}".format(coherent_info))

Coherent information = nan
