In [6]:
import os

import jax.numpy as jnp
import numpy as np
import stim
from jax import random
from qecsim.models.rotatedplanar import RotatedPlanarCode
from tqdm import tqdm

from src.circuit_builder import build_stim_circuit_v2
from src.stim_noise_models import get_noise_model

# JAX PRNG key used only to draw the random deformations below;
# stim's shot sampling is not driven by this key.
deformation_seed = 0
key = random.key(deformation_seed)

# Collect data

In [8]:
unique_deformations = 1000
deformation_options = jnp.array([0,1,2,3,4,5])
samples_per_deformation = 1000
code_distance = 3
num_rounds = 3
noise_model_name = "spin"

num_stabilizers = code_distance**2-1
code = RotatedPlanarCode(code_distance, code_distance)
noise_model = get_noise_model(noise_model_name)

# Each deformation yields two dataset entries: index 2*i holds the Z-basis circuit
# and index 2*i + 1 the X-basis circuit (see `is_using_the_z_basis`).
num_entries = unique_deformations * 2
# Per shot: (num_rounds + 1) syndrome rounds of num_stabilizers bits each, followed by
# code_distance final measurements whose parity is stored as the observable.
num_measurements = (num_rounds + 1) * num_stabilizers + code_distance

# Optional base seed for stim's shot sampler (`key` only drives the deformations).
# None keeps stim's default of seeding from system entropy; an int makes the shots
# reproducible for the same stim version on the same machine.
sampler_seed = None

# Fill preallocated NumPy buffers in place and convert them to JAX once at the end.
# Calling `.at[idx].set(...)` on JAX arrays outside of `jit` returns a full copy every
# time, so the previous version copied every buffer on every iteration.
buffers = {
    "deformations": np.empty((num_entries, code_distance**2), dtype=np.int8),
    "syndromes_initial": np.empty((num_entries, samples_per_deformation, num_stabilizers), dtype=np.int8),
    "syndromes_rounds": np.empty((num_entries, samples_per_deformation, num_rounds, num_stabilizers), dtype=np.int8),
    "observables": np.empty((num_entries, samples_per_deformation), dtype=np.int8),
    "is_using_the_z_basis": np.empty((num_entries,), dtype=np.bool_),
}

# Split a local copy of `key` so that re-running this cell regenerates the same
# deformations, instead of continuing from wherever a previous run left the key.
deformation_key = key
for i in tqdm(range(unique_deformations), desc="Generating data", ncols=150):
    subkey, deformation_key = random.split(deformation_key)
    # One deformation value per entry of a code_distance**2 vector.
    deformation = random.choice(subkey, deformation_options, shape=(code_distance**2,))
    x_basis_stim_circuit, z_basis_stim_circuit = build_stim_circuit_v2(code, deformation, num_rounds+1)
    for j, circ in enumerate([z_basis_stim_circuit, x_basis_stim_circuit]):
        idx = i * 2 + j
        noisy_circ: stim.Circuit = noise_model(circ)
        shot_seed = None if sampler_seed is None else sampler_seed + idx
        results = noisy_circ.compile_sampler(seed=shot_seed).sample(shots=samples_per_deformation)
        if results.shape[1] != num_measurements:
            raise ValueError(
                f"Expected {num_measurements} measurements per shot, got {results.shape[1]}."
            )
        syndromes = results[:, :-code_distance].reshape(samples_per_deformation, num_rounds+1, num_stabilizers)
        # NumPy assignment casts the bool measurement results to the buffers' int8 dtype.
        buffers["deformations"][idx] = np.asarray(deformation)
        buffers["syndromes_initial"][idx] = syndromes[:, 0]
        buffers["syndromes_rounds"][idx] = syndromes[:, 1:]  # Exclude initial round
        buffers["observables"][idx] = results[:, -code_distance:].sum(axis=1) % 2
        buffers["is_using_the_z_basis"][idx] = j == 0

# Same keys and dtypes as before, now as JAX arrays for the save cell and downstream use.
data = {name: jnp.asarray(values) for name, values in buffers.items()}

Generating data: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:45<00:00, 21.91it/s]


# Save the data

In [None]:
# Write every array in `data` to ../data_sets/<save_as>/<name>.npy, refusing to overwrite
# an existing dataset directory. The path is relative to the kernel's working directory.
save_as = f"stim_{noise_model_name}_{code_distance}x{code_distance}_r{num_rounds}"
save_dir = f"../data_sets/{save_as}"
try:
    # makedirs raises FileExistsError itself, so there is no check-then-create race.
    os.makedirs(save_dir)
except FileExistsError:
    raise FileExistsError(
        f"Directory {save_dir} already exists. Please choose a different name to avoid "
        "overwriting existing data or manually delete the directory."
    ) from None
for name, val in data.items():
    jnp.save(f"{save_dir}/{name}.npy", val)