# Coherent info

In [1]:
import numpy as np
from collections import Counter
from typing import Tuple, Iterable, Dict, List
from coherentinfo.moebius import MoebiusCode, MoebiusCodeOddPrime, MoebiusCodeQubit
from coherentinfo.linalg import (finite_field_gauss_jordan_elimination,
                                 finite_field_matrix_rank,
                                 finite_field_inverse,
                                 is_prime)
from coherentinfo.postprocess import aggregate_data

from numpy.typing import NDArray

from coherentinfo.errormodel import ErrorModel, ErrorModelBernoulli, ErrorModelBernoulliJax
# import galois
import scipy
import timeit
import jax
import jax.numpy as jnp

In [2]:
length = 5
width = 5
p = 3
moebius_code = MoebiusCodeOddPrime(length=length, width=width, d=2 * p)
moebius_code_qubit = MoebiusCodeQubit(length=length, width=width)
h_z = moebius_code.h_z
h_x = moebius_code.h_x
logical_x = moebius_code.logical_x
logical_z = moebius_code.logical_z

In [3]:
# gamma = 0.5
p_error = 0.1
# poisson_em_moebius = ErrorModelPoisson(moebius_code.num_edges, 2 * p, gamma)
em_moebius_qubit = ErrorModelBernoulli(moebius_code.num_edges, 2, p_error)
em_moebius_qubit_jax = ErrorModelBernoulliJax(
    moebius_code.num_edges, 2, p_error
)

In [4]:
num_samples = 1000000
vertex_result = moebius_code_qubit.compute_batched_vertex_syndrome_chi_z(
    num_samples, em_moebius_qubit_jax)
plaquette_result = moebius_code_qubit.compute_batched_plaquette_syndrome_chi_x(
    num_samples, em_moebius_qubit_jax)

In [12]:
vertex_syndrome, vertex_counts = aggregate_data(vertex_result)

In [15]:
index = 0
print(vertex_syndrome[index, :])
print(vertex_counts[index])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[9005    3]


The following returns the unique rows and their count. Note that it is not jit compatible, but it can be made so if needed, by fixing the output size.

In [None]:
unique_vertex_result, vertex_counts = jnp.unique(
    vertex_result,
    axis=0,             # Operate along the rows (axis 0)
    return_counts=True  # Return the count of each unique row
)

In [None]:
unique_plaquette_result, plaquette_counts = jnp.unique(
    plaquette_result,
    axis=0,             # Operate along the rows (axis 0)
    return_counts=True  # Return the count of each unique row
)

In [None]:
index = 4
unique_vertex[index, :], counts[index]

(Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],      dtype=int32),
 Array(1227, dtype=int32))

# Error model

In [51]:
%timeit syndrome_jax(errors[0], h_x_jax).block_until_ready()

54.7 μs ± 1.15 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [26]:
jitted_syndrome(errors[0], h_x_jax)
%timeit jitted_syndrome(errors[0], h_x_jax).block_until_ready()

130 μs ± 6.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [52]:
from functools import partial
syndrome_new = partial(syndrome_jax, h_x_jax=h_x_jax)
auto_syndrome = jax.vmap(syndrome_new, in_axes=0)

In [53]:
%timeit auto_syndrome(errors).block_until_ready()

6.55 ms ± 126 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Test random numbers jax

In [15]:
import jax
import jax.numpy as jnp

# 1. Initialize a PRNG Key (Seed)
# JAX requires an explicit Pseudo-Random Number Generator (PRNG) key for all random operations.
# The number 42 here is the seed, which you can change.
key = jax.random.PRNGKey(48090)

# 2. Define the parameters
# The shape of the array you want (length 10)
shape = (10,)
# The minimum value (inclusive)
low = 0
# The maximum value (exclusive, so we use 6 to include 5)
high = 6

# 3. Generate the random array
random_array = jax.random.randint(
    key,      # The PRNG key
    shape,    # The desired shape of the array (10,)
    low,      # The lower bound (inclusive)
    high      # The upper bound (exclusive)
)

print("JAX Random Array:")
print(random_array)
print("Data Type:")
print(random_array.dtype)

JAX Random Array:
[0 0 0 3 2 2 0 5 0 2]
Data Type:
int32


In [17]:
jax.random.split(key, 2)

Array([[ 452261207,  475151072],
       [2888444814, 1655919347]], dtype=uint32)

In [8]:
# 1. Initialize a PRNG key with a seed
seed = 42
key = jax.random.PRNGKey(seed)

# 2. Split the key to get a subkey for a random operation
key, subkey = jax.random.split(key)

# 3. Use the subkey for a random function (e.g., uniform, normal)
random_value = jax.random.uniform(subkey)
print(f"Random value: {random_value}")

# 4. For the next random number, split the main key again
key, next_subkey = jax.random.split(key)
next_random_value = jax.random.normal(next_subkey)
print(f"Next random value: {next_random_value}")

Random value: 0.7276642322540283
Next random value: -0.21089035272598267


In [11]:
jax.random.uniform(next_subkey, 3)

Array([0.41648638, 0.08647358, 0.4820521 ], dtype=float32)

In [22]:
np.random.rand(6)

array([0.1336215 , 0.00612622, 0.55107387, 0.39960868, 0.98579066,
       0.3507201 ])