In [80]:
import jax.numpy as jnp
from jax import random, jit, lax, value_and_grad
import optax
from tqdm import tqdm
import os

from src.neural_network_decoders import TransformerDecoder, print_params_structure
from qecsim.models.rotatedplanar import RotatedPlanarCode

key = random.key(0)

# Load data

In [2]:
data_dir = "../data_sets/stim_spin_3x3_r3/"
code = RotatedPlanarCode(3, 3)

data = {
    file_name.split(".")[0]: jnp.load(f"{data_dir}/{file_name}") for file_name in os.listdir(data_dir) if file_name.endswith(".npy")
}
for name, val in data.items():
    print(f"{name}: {val.shape}")

num_deformations, num_shots, num_rounds, num_syndromes = data["syndromes_rounds"].shape
x_init = jnp.append(
    data["syndromes_initial"], 
    jnp.tile(data["deformations"][:, None, :], (1, num_shots, 1))+2, 
    axis=-1
).reshape(num_deformations*num_shots, -1)
x = jnp.append(
    data["syndromes_rounds"], 
    jnp.tile(data["deformations"][:, None, None, :], (1, num_shots, num_rounds, 1))+2, 
    axis=-1
).reshape(num_deformations*num_shots, num_rounds, -1)
y = data["observables"].flatten()
bz = jnp.tile(data["is_using_the_z_basis"][:, None], (1, num_shots)).flatten()
print()
print(f"x_init: {x_init.shape}")
print(f"x: {x.shape}")
print(f"y: {y.shape}")
print(f"bz: {bz.shape}")

deformations: (2000, 9)
is_using_the_z_basis: (2000,)
observables: (2000, 1000)
syndromes_initial: (2000, 1000, 8)
syndromes_rounds: (2000, 1000, 3, 8)

x_init: (2000000, 17)
x: (2000000, 3, 17)
y: (2000000,)
bz: (2000000,)


In [3]:
def shuffle_data(key, *sets):
    """
    Shuffle the data along the first axis.
    
    Args:
        key: JAX random key.
        *sets: Variable number of arrays to be shuffled in unison along the first axis.
    Returns:
        Tuple:
            - new_key: JAX random key after splitting.
            - shuffled_data: List of shuffled arrays.
    """
    subkey, key = random.split(key)
    perm = random.permutation(subkey, sets[0].shape[0])
    shuffled_data = [set[perm] for set in sets]
    return key, shuffled_data

key, [x_init, x, y, bz] = shuffle_data(key, x_init, x, y, bz)

# Show data for first training sample
print("\nx_init:")
print(x_init[0])
print("\nx:")
print(x[0])
print("\ny:", y[0])
print("\nbz:", bz[0])


x_init:
[1 1 0 0 1 1 1 0 4 7 5 7 3 6 4 2 6]

x:
[[1 1 1 0 1 1 1 1 4 7 5 7 3 6 4 2 6]
 [1 1 1 0 1 1 1 1 4 7 5 7 3 6 4 2 6]
 [1 1 1 0 1 1 0 1 4 7 5 7 3 6 4 2 6]]

y: 0

bz: False


# Initialize model

In [117]:
batch_size = 64
num_epochs = 10

init_key, shuffle_key = random.split(random.key(0), num=2)

In [118]:
def cords_from_code(code: RotatedPlanarCode) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]:
    """Get the coordinates of the plaquettes and data qubits from a rotated planar code and returns them as two separate lists."""
    plaquette_coords = code._plaquette_indices
    data_qubit_coords = [(x-.5, y-.5) for y in range(code.size[0]) for x in range(code.size[1])]
    return plaquette_coords, data_qubit_coords

In [119]:
plaquette_coords, data_qubit_coords = cords_from_code(code)
model = TransformerDecoder(
    site_locations=jnp.array(plaquette_coords + data_qubit_coords),
    output_features=4,
    vocab_size=8,
    num_layers=2,
    heads=4,
    d_model=32,
    mlp_dim=128,
    training=False
)
model_params = model.init(init_key, x_init[:batch_size], x[:batch_size])  # Initialize model parameters
print_params_structure(model_params)

params
  embedder
    embedder
      embedding:	 shape (8, 32)
  transformer_first_round
    encoder_layers_0
      norm_attention
        scale:	 shape (32,)
        bias:	 shape (32,)
      attention
        query
          kernel:	 shape (32, 4, 8)
          bias:	 shape (4, 8)
        key
          kernel:	 shape (32, 4, 8)
          bias:	 shape (4, 8)
        value
          kernel:	 shape (32, 4, 8)
          bias:	 shape (4, 8)
        out
          kernel:	 shape (4, 8, 32)
          bias:	 shape (32,)
      norm_mlp
        scale:	 shape (32,)
        bias:	 shape (32,)
      gated_mlp
        fc_layer_0
          kernel:	 shape (32, 128)
          bias:	 shape (128,)
        fc_layer_1
          kernel:	 shape (64, 32)
          bias:	 shape (32,)
    encoder_layers_1
      norm_attention
        scale:	 shape (32,)
        bias:	 shape (32,)
      attention
        query
          kernel:	 shape (32, 4, 8)
          bias:	 shape (4, 8)
        key
          kernel:	 shape (

In [120]:
shuffle_key, [x_init, x, y, bz] = shuffle_data(shuffle_key, x_init, x, y, bz)
input = (x_init[:1], x[:1])  # Input for a single training sample
output = model.apply(model_params, *input)[0]  # Forward pass with a batch of data
print("\nModel output for first training sample:")
print(f"in:\n {input[0][0]}\n{input[1][0]}")
print(f"Out:\n [{', '.join(f'{p:.2%}' for p in output)}]")


Model output for first training sample:
in:
 [0 1 0 0 0 0 1 1 5 2 5 3 2 7 4 4 2]
[[0 1 0 0 0 0 1 1 5 2 5 3 2 7 4 4 2]
 [0 1 0 0 0 0 1 1 5 2 5 3 2 7 4 4 2]
 [0 1 0 0 0 0 1 1 5 2 5 3 2 7 4 4 2]]
Out:
 [0.76%, 2.36%, 54.59%, 42.28%]


# Evaluation functions

In [121]:
@jit
def binary_cross_entropy(y_true, y_pred):
    # Clip predictions to avoid log(0)
    y_pred = jnp.clip(y_pred, 1e-7, 1.0 - 1e-7)
    # Calculate binary cross-entropy
    return -jnp.mean(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))

@jit
def loss_fn(params, x_init, x, y, bz):
    # Get model predictions
    p = model.apply(params, x_init, x)
    p_I, p_X, p_Y, p_Z = p[:,0], p[:,1], p[:,2], p[:,3]
    # Calculate probability of a logical flip based on the basis
    p_flip_z = p_X + p_Y
    p_flip_x = p_Z + p_Y
    # Chose which flip to predict based on the basis used
    p_flip = bz * p_flip_z + (1.0 - bz) * p_flip_x
    # Calculate binary cross-entropy loss
    return binary_cross_entropy(y, p_flip)

loss_fn(model_params, x_init[:batch_size], x[:batch_size], y[:batch_size], bz[:batch_size])

Array(1.0337628, dtype=float32)

In [122]:
@jit
def estimate_ler(params, x_init, x, y, bz):
    probs = model.apply(params, x_init, x)
    # Get predicted logical Pauli error
    prediction = probs.argmax(axis=1) # Get the class with the highest probability (0: I, 1: X, 2: Y, 3: Z)
    err_I = prediction == 0
    err_X = prediction == 1
    err_Y = prediction == 2
    err_Z = prediction == 3
    # Determine if a logical flip has occurred based on the basis
    y_pred = (err_X & bz) | (err_Y) | (err_Z & ~bz)
    # Compare predictions to true labels and estimate the ler
    sucess_rate = jnp.mean(y_pred == y)
    ler = 1.0 - sucess_rate
    return ler

estimate_ler(model_params, x_init[:batch_size], x[:batch_size], y[:batch_size], bz[:batch_size])

Array(0.5, dtype=float32)

# Setup the optimizer

In [123]:
num_batches = num_epochs * x.shape[0] // batch_size
learning_rate = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=num_batches * 0.05,
    decay_steps=num_batches,
    end_value=1e-5
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate)
)
opt_state = optimizer.init(model_params)

# Training loop

In [124]:
@jit
def training_step(
    opt_state, 
    model_params: dict, 
    x_init: jnp.ndarray, 
    x: jnp.ndarray, 
    y: jnp.ndarray, 
    bz: jnp.ndarray, 
    batch_idx: int
) -> tuple[optax.OptState, dict, float]:
    """
    Perform a single training step: compute loss and gradients, update model parameters.

    Args:
        opt_state: Current state of the optimizer.
        model_params: Current model parameters.
        x_init (jnp.ndarray (batch_size, n_sites)): Syndrome measurements at initial round.
        x (jnp.ndarray (batch_size, n_rounds, n_sites)): Syndrome measurements for all rounds except the initial round.
        y (jnp.ndarray (batch_size,)): Observable outcomes.
        bz (jnp.ndarray (batch_size,)): Basis for logical state initialization and measurement (True for Z-basis, False for X-basis).
        batch_idx (int): Index of the current batch.

    Returns:
        Tuple:
            - new_opt_state: Updated optimizer state.
            - new_model_params: Updated model parameters.
            - loss: Computed loss for the batch.
    """
    # Get batch data
    start = batch_idx * batch_size
    batch_x_init = lax.dynamic_slice(x_init, (start, 0), (batch_size, x_init.shape[1]))
    batch_x = lax.dynamic_slice(x, (start, 0, 0), (batch_size, x.shape[1], x.shape[2]))
    batch_y = lax.dynamic_slice(y, (start,), (batch_size,))
    batch_bz = lax.dynamic_slice(bz, (start,), (batch_size,))
    # Compute loss and gradients
    loss, grads = value_and_grad(loss_fn)(model_params, batch_x_init, batch_x, batch_y, batch_bz)
    # Update model parameters
    updates, opt_state = optimizer.update(grads, opt_state, model_params)
    model_params = optax.apply_updates(model_params, updates)
    return opt_state, model_params, loss

In [125]:
for epoch in range(num_epochs):
    # Shuffle data at the start of each epoch
    shuffle_key, [x_init, x, y, bz] = shuffle_data(shuffle_key, x_init, x, y, bz)

    # Training loop
    num_batches = x.shape[0] // batch_size
    epoch_loss = 0.0
    for batch_idx in tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{num_epochs}", ncols=200):
        opt_state, model_params, loss = training_step(
            opt_state, model_params, x_init, x, y, bz, batch_idx
        )
        epoch_loss += loss
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}")

Epoch 1/10:   0%|                                                                                                                                                             | 0/31250 [00:00<?, ?it/s]

Epoch 1/10:   4%|██████                                                                                                                                            | 1306/31250 [01:31<34:53, 14.30it/s]


KeyboardInterrupt: 