# Coherent info

In [1]:
import numpy as np
from coherentinfo.moebius_qubit import MoebiusCodeQubit
from coherentinfo.moebius_odd_prime import MoebiusCodeOddPrime
from coherentinfo.postprocess import (aggregate_data,
                                      aggregate_data_jax,
                                      update_aggregated_data_jax,
                                      compute_conditional_entropy_term,
                                      miller_madow_conditional_entropy)


from coherentinfo.errormodel import ErrorModelBernoulli, ErrorModelPoisson
# import galois
import scipy
import timeit
import jax
import jax.numpy as jnp

In [2]:
# Lattice dimensions shared by both code objects built below.
length, width = 7, 7
p = 3

# Two code instances on the same lattice: one built with d = 2 * p and the
# qubit variant.
moebius_code = MoebiusCodeOddPrime(length=length, width=width, d=2 * p)
moebius_code_qubit = MoebiusCodeQubit(length=length, width=width)

# Convenience handles onto attributes of the d = 2 * p code.
h_z, h_x = moebius_code.h_z, moebius_code.h_x
logical_x, logical_z = moebius_code.logical_x, moebius_code.logical_z

In [3]:
# Noise parameters: `gamma` feeds the Poisson model, `p_error` the Bernoulli one.
gamma, p_error = 0.01, 0.1

# Both models are built over the edges of `moebius_code`; the second
# constructor argument is 2 * p for the Poisson model and 2 for the
# Bernoulli model.
em_moebius_poisson_jax = ErrorModelPoisson(moebius_code.num_edges, 2 * p, gamma)
em_moebius_qubit = ErrorModelBernoulli(moebius_code.num_edges, 2, p_error)

In [15]:
# Total number of samples and how many batches they are split into.
num_samples = 1_000_000
num_batch = 1

# Every batch draws the same number of samples, and `num_samples` is used
# later as the normalisation constant, so the split has to be exact.
# Integer division also avoids the float round-trip of int(a / b).
assert num_batch > 0, "num_batch must be positive"
assert num_samples % num_batch == 0, (
    "num_samples must be a multiple of num_batch"
)
num_samples_batch = num_samples // num_batch

# Capacity of the unique-syndrome table.
max_different_syndromes = 1_000_000  # upper bounded by num_samples

# One PRNG key per batch, all derived from a single fixed base key so the
# run is reproducible.
base_key = jax.random.PRNGKey(0)
master_vertex_keys = jax.random.split(base_key, num_batch)

In [16]:
# vertex_result = moebius_code.compute_batched_vertex_syndrome_chi_z(
#     num_samples_batch, em_moebius_poisson_jax, master_vertex_keys[0])

In [17]:
# Disabled one-shot alternative, kept for reference:
#   aggregate_data_jit = jax.jit(aggregate_data_jax, static_argnums=(1,))
#   vertex_pads = -1 * jnp.ones(vertex_result.shape[1] - 1)
#   unique_syndromes, vertex_counts = aggregate_data_jit(
#       vertex_result, max_different_syndromes, vertex_pads)

# Empty accumulators for the batched aggregation: every syndrome slot starts
# at -1 and every count starts at 0.
init_syndromes = jnp.full(
    (max_different_syndromes, moebius_code.num_vertex_checks),
    -1,
    dtype=jnp.int32,
)
init_counts = jnp.zeros((max_different_syndromes, 2), dtype=jnp.int32)

# Padding array with the same shape as the syndrome table.
# NOTE(review): this is floating point (default dtype of jnp.ones) while the
# table is int32 -- confirm that update_aggregated_data_jax expects this.
vertex_pads = -jnp.ones(init_syndromes.shape)

# Jitted single-batch update; argument 3 (the table capacity) is static.
update_aggregated_data_jit = jax.jit(
    update_aggregated_data_jax,
    static_argnums=(3,),
)

In [18]:
# for batch in range(0, num_batch):
#     vertex_result = moebius_code.compute_batched_vertex_syndrome_chi_z(
#         num_samples_batch,
#         em_moebius_poisson_jax,
#         master_vertex_keys[batch])

#     unique_syndromes, vertex_counts = \
#         update_aggregated_data_jit(
#             unique_syndromes,
#             vertex_counts,
#             vertex_result,
#             max_different_syndromes,
#             vertex_pads
#         )

#     del vertex_result

def run_compiled_simulation(
    master_keys,
    num_samples_batch,
    max_different_syndromes,
    vertex_pads,
    init_syndromes,
    init_counts
):
    """
    Sample and aggregate vertex syndromes batch by batch with `jax.lax.scan`.

    One batch is drawn per entry of `master_keys`; each batch is merged into
    the running (syndromes, counts) pair, and only the final pair is kept.
    The sampler and error model are the notebook-level `moebius_code` and
    `em_moebius_poisson_jax`.

    Args:
        master_keys: Stack of PRNG keys, scanned over its leading axis.
        num_samples_batch: Number of samples drawn per batch.
        max_different_syndromes: Capacity forwarded to
            `update_aggregated_data_jax`.
        vertex_pads: Padding array forwarded to `update_aggregated_data_jax`.
        init_syndromes: Initial syndrome table (start of the scan carry).
        init_counts: Initial count table (start of the scan carry).

    Returns:
        Tuple `(final_syndromes, final_counts)` after the last batch.
    """

    def merge_one_batch(state, batch_key):
        # The carry is the running (syndromes, counts) pair.
        syndromes_so_far, counts_so_far = state

        # Draw one batch of vertex syndromes; this call is traced by scan,
        # so it has to be JAX-compatible.
        batch_syndromes = moebius_code.compute_batched_vertex_syndrome_chi_z(
            num_samples_batch, em_moebius_poisson_jax, batch_key)

        # Merge the batch into the running tables. The un-jitted update is
        # called here; it gets traced as part of the enclosing scan.
        merged_syndromes, merged_counts = update_aggregated_data_jax(
            syndromes_so_far,
            counts_so_far,
            batch_syndromes,
            max_different_syndromes,
            vertex_pads,
        )

        # No per-step output is stacked (second element is None), so only
        # the carry is kept between iterations.
        return (merged_syndromes, merged_counts), None

    final_state, _ = jax.lax.scan(
        merge_one_batch,
        init=(init_syndromes, init_counts),
        xs=master_keys,
    )
    final_syndromes, final_counts = final_state

    return final_syndromes, final_counts


# Jit the whole loop; `num_samples_batch` and `max_different_syndromes`
# (positions 1 and 2) are static and must be passed positionally.
run_simulation_jit = jax.jit(
    run_compiled_simulation,
    static_argnums=(1, 2),
)

In [19]:
# Run all batches through the jitted scan. The two static arguments
# (batch size and table capacity) are passed positionally, matching
# static_argnums=(1, 2) on the jitted function.
final_s, final_c = run_simulation_jit(
    master_vertex_keys, num_samples_batch, max_different_syndromes,
    vertex_pads, init_syndromes, init_counts,
)

In [20]:
# @jax.jit
def count_unique_syndromes(unique_syndromes: jax.Array):
    """
    Count the rows of a syndrome table that precede the first padding row.

    A row is treated as padding when its first entry equals -1; only column 0
    is inspected. The result is the index of the first such row, which equals
    the number of stored syndromes provided all padding rows sit at the end
    of the table (assumed, not checked).

    Args:
        unique_syndromes: 2-D array with one syndrome per row.

    Returns:
        0-d integer array: index of the first padding row, or the number of
        rows when no row is padding. A table with zero rows yields 0.
    """
    num_rows = unique_syndromes.shape[0]

    # argmax raises on an empty axis, so handle the empty table explicitly.
    # The shape is static, which keeps this branch valid under jax.jit.
    if num_rows == 0:
        return jnp.zeros((), dtype=int)

    is_pad = unique_syndromes[:, 0] == -1

    # argmax over booleans gives the first True index, but also returns 0
    # when nothing is True; re-reading the mask at that index tells the two
    # cases apart.
    first_pad_idx = jnp.argmax(is_pad)
    has_any_pad = is_pad[first_pad_idx]

    return jnp.where(has_any_pad, first_pad_idx, num_rows)

In [21]:
# count_unique_syndromes(unique_syndromes)

In [22]:
# unique_syndromes.shape

In [23]:
# Normalise the aggregated counts by the total number of samples and pass
# them to the Miller-Madow conditional-entropy estimator.
vertex_conditional_entropy = miller_madow_conditional_entropy(
    final_c / num_samples,
    num_samples,
)

In [24]:
# Show the conditional-entropy estimate obtained from the vertex syndromes.
print(vertex_conditional_entropy)

-0.3984625


In [14]:
# Number of syndrome rows in the final table before the first -1 padding row.
count_unique_syndromes(final_s)

Array(368349, dtype=int32)

In [9]:
# Sample plaquette syndromes of the qubit code from a fixed PRNG key.
# The Bernoulli model defined earlier is named `em_moebius_qubit`; the
# previous `em_moebius_qubit_jax` is not defined anywhere in the notebook
# and raised NameError on a fresh kernel.
master_plaquette_key = jax.random.PRNGKey(687090)
plaquette_result = moebius_code_qubit.compute_batched_plaquette_syndrome_chi_x(
    num_samples, em_moebius_qubit, master_plaquette_key)

In [10]:
# No live cell defines `aggregate_data_jit` (its only definition is commented
# out), so build it here; argument 1 (the table capacity) is static.
aggregate_data_jit = jax.jit(aggregate_data_jax, static_argnums=(1,))

# Padding vector with one entry fewer than the number of result columns.
plaquette_pads = -1 * jnp.ones(plaquette_result.shape[1] - 1)
_, plaquette_probs = aggregate_data_jit(
    plaquette_result,
    num_samples,  # Passed as the static (Python int) argument
    plaquette_pads
)

# Release the large intermediates; this cell needs the sampling cell to be
# re-run before it can be executed again.
del plaquette_result
del plaquette_pads

In [11]:
# Miller-Madow conditional-entropy estimate for the plaquette syndromes.
# NOTE(review): the vertex branch divides its aggregated counts by
# num_samples before this call, whereas the aggregate output is passed here
# unchanged -- confirm that aggregate_data_jax already returns probabilities.
plaquette_conditional_entropy = miller_madow_conditional_entropy(
    plaquette_probs,
    num_samples,
)
del plaquette_probs  # free the aggregated table once it has been consumed

In [12]:
# Coherent information estimate: 1 minus the two conditional-entropy terms.
# NOTE(review): the recorded output of this cell is about -65402, far outside
# the scale of the vertex term printed above (-0.398) -- presumably the
# plaquette term was not normalised when that output was produced; verify
# after a clean Restart & Run All.
coherent_info = 1.0 - vertex_conditional_entropy - plaquette_conditional_entropy
print("Coherent information = {}".format(coherent_info))

Coherent information = -65402.00390625
