In [None]:
import pathlib
from datetime import datetime
from typing import Optional

import numpy as np
import xarray as xr
import keras
from matplotlib import pyplot as plt

from qrennd import get_model, Config, Layout
from qrennd.layouts.plotter import plot

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [None]:
def get_syndromes(anc_meas: xr.DataArray) -> xr.DataArray:
    """XOR every ancilla outcome with the outcome of the preceding round.

    The measurements are shifted by one position along ``qec_round`` (the
    vacated first position is filled with 0) and XOR-ed element-wise with the
    unshifted values.

    Parameters
    ----------
    anc_meas
        Ancilla measurement outcomes with a ``qec_round`` dimension.

    Returns
    -------
    xr.DataArray
        The XOR result, named ``"syndromes"``.
    """
    prev_round_meas = anc_meas.shift(qec_round=1, fill_value=0)
    result = anc_meas ^ prev_round_meas
    result.name = "syndromes"
    return result


def get_defects(
    syndromes: xr.DataArray, frame: Optional[xr.DataArray] = None
) -> xr.DataArray:
    """XOR every syndrome with the syndrome of the preceding round.

    Parameters
    ----------
    syndromes
        Syndromes with a ``qec_round`` dimension.
    frame
        Optional reference written into position 0 (along ``qec_round``) of
        the shifted copy, i.e. the value the first round is compared against.
        When omitted, the first round is compared against 0.

    Returns
    -------
    xr.DataArray
        The XOR result, named ``"defects"``.
    """
    # ``shift`` returns a new array, so the assignment below never touches
    # the caller's ``syndromes``.
    prev_round_syn = syndromes.shift(qec_round=1, fill_value=0)

    if frame is not None:
        prev_round_syn[{"qec_round": 0}] = frame

    result = syndromes ^ prev_round_syn
    result.name = "defects"
    return result


def get_final_defects(
    syndromes: xr.DataArray,
    proj_syndrome: xr.DataArray,
) -> xr.DataArray:
    """XOR the last-round syndromes with ``proj_syndrome``.

    Only the ancilla qubits listed in ``proj_syndrome.anc_qubit`` are selected
    from the last round (index -1 along ``qec_round``) before the XOR.

    Parameters
    ----------
    syndromes
        Syndromes with ``qec_round`` and ``anc_qubit`` dimensions.
    proj_syndrome
        Syndrome to compare against; must expose an ``anc_qubit`` coordinate.

    Returns
    -------
    xr.DataArray
        The XOR result, named ``"final_defects"``.
    """
    final_round_syn = syndromes.isel(qec_round=-1).sel(
        anc_qubit=proj_syndrome.anc_qubit
    )

    result = final_round_syn ^ proj_syndrome
    result.name = "final_defects"
    return result


def preprocess_data(dataset, proj_mat, input_kind="syndromes"):
    """Build the model inputs and the labels from a loaded dataset.

    Parameters
    ----------
    dataset
        Dataset providing ``anc_meas``, ``data_meas`` and ``init_state``;
        ``data_meas`` and ``init_state`` must have a ``data_qubit`` dimension.
    proj_mat
        Matrix contracted with ``data_meas`` (modulo 2) to obtain the projected
        syndrome. Only used when ``input_kind == "defects"``.
    input_kind
        Which encoding is fed to the model:

        * ``"syndromes"`` (default): syndromes from ``get_syndromes`` and the
          raw ``data_meas``.
        * ``"defects"``: defects from ``get_defects`` and the final defects
          from ``get_final_defects``.
        * ``"measurements"``: raw ``anc_meas`` and raw ``data_meas``.

    Returns
    -------
    inputs : dict
        Keys ``"defects"`` and ``"final_defects"`` mapped to the underlying
        arrays (``.data``) of the selected encoding.
    outputs
        Underlying array of the labels: parity of ``data_meas`` over
        ``data_qubit`` XOR parity of ``init_state`` over ``data_qubit``.

    Raises
    ------
    ValueError
        If ``input_kind`` is not one of the supported values.
    """
    supported_kinds = ("syndromes", "defects", "measurements")
    if input_kind not in supported_kinds:
        raise ValueError(
            f"Unknown input_kind {input_kind!r}; expected one of {supported_kinds}."
        )

    # Only derive the quantities the selected encoding actually needs; the
    # intermediate arrays are as large as the dataset itself.
    if input_kind == "measurements":
        inputs = dict(
            defects=dataset.anc_meas.data,
            final_defects=dataset.data_meas.data,
        )
    else:
        syndromes = get_syndromes(dataset.anc_meas)
        if input_kind == "syndromes":
            inputs = dict(
                defects=syndromes.data,
                final_defects=dataset.data_meas.data,
            )
        else:
            defects = get_defects(syndromes)
            proj_syndrome = (dataset.data_meas @ proj_mat) % 2
            final_defects = get_final_defects(syndromes, proj_syndrome)
            inputs = dict(
                defects=defects.data,
                final_defects=final_defects.data,
            )

    # Label: parity of the measured data qubits XOR parity of the initial state.
    init_states = dataset.init_state.sum(dim="data_qubit") % 2
    log_states = dataset.data_meas.sum(dim="data_qubit") % 2
    labels = log_states.astype(int) ^ init_states

    outputs = labels.data
    return inputs, outputs

# Load the datasets

In [None]:
# All paths are resolved against the current working directory. This is the
# notebook's own folder only when the kernel was started there.
NOTEBOOK_DIR = pathlib.Path.cwd()

LAYOUT_DIR = NOTEBOOK_DIR / "layouts"
CONFIG_DIR = NOTEBOOK_DIR / "configs"
# The train/dev/test data directories are located in the local data directory
DATA_DIR = NOTEBOOK_DIR / "data"

# Fail early, and name the offending path, if a required input directory is
# missing. ``is_dir`` also rejects a regular file that happens to carry the
# expected name, which ``exists`` would let through.
for _label, _required_dir in (
    ("Layout", LAYOUT_DIR),
    ("Config", CONFIG_DIR),
    ("Data", DATA_DIR),
):
    if not _required_dir.is_dir():
        raise ValueError(f"{_label} directory does not exist: {_required_dir}")

# Each run logs into its own timestamped sub-directory.
cur_datetime = datetime.now()
datetime_str = cur_datetime.strftime("%Y%m%d-%H%M%S")

LOG_DIR = NOTEBOOK_DIR / f"logs/{datetime_str}"
LOG_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_DIR = NOTEBOOK_DIR / "tmp/checkpoint"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the layout description and draw it with qubit labels and patches.
LAYOUT_FILE = "d3_layout.yaml"
layout_path = LAYOUT_DIR / LAYOUT_FILE
layout = Layout.from_yaml(layout_path)

fig, ax = plt.subplots(figsize=(4, 4))
plot(layout, axis=ax, label_qubits=True, draw_patches=True)
plt.tight_layout()
plt.show()

In [None]:
# Read the model/training configuration from the configs directory.
CONFIG_FILE = "base_config.yaml"
config_path = CONFIG_DIR / CONFIG_FILE
config = Config.from_yaml(config_path)

In [None]:
# Projection matrix for the z-type stabilizers of the loaded layout.
proj_mat = layout.projection_matrix(stab_type="z_type")

# Training split.
train_path = (
    DATA_DIR / "train/d3_surf_code_seq_round_state_0_shots_1000000_rounds_40.nc"
)
train_dataset = xr.load_dataset(train_path)
train_input, train_output = preprocess_data(train_dataset, proj_mat)

# Development (validation) split.
dev_path = DATA_DIR / "dev/d3_surf_code_seq_round_state_0_shots_20000_rounds_40.nc"
dev_dataset = xr.load_dataset(dev_path)
dev_input, dev_output = preprocess_data(dev_dataset, proj_mat)

In [None]:
# Metrics tracked for the main output: binary accuracy and ROC AUC.
main_metrics = [
    keras.metrics.BinaryAccuracy(name="acc"),
    keras.metrics.AUC(name="AUC", curve="ROC", num_thresholds=100),
]

# Metrics tracked for the auxiliary output: binary accuracy only.
aux_metrics = [keras.metrics.BinaryAccuracy(name="acc")]

In [None]:
num_rounds = train_dataset.qec_round.size
num_anc = train_dataset.anc_qubit.size

# Per-sample shape of the "final_defects" input, read from the prepared
# training array instead of a hard-coded size, so the model always matches
# the data it is fed. Assumes axis 0 is the sample (batch) axis, as Keras
# expects for array inputs.
final_defects_shape = train_input["final_defects"].shape[1:]

model = get_model(
    defects_shape=(num_rounds, num_anc),
    final_defects_shape=final_defects_shape,
    config=config,
    metrics=dict(
        main_output=main_metrics,
        aux_output=aux_metrics,
    ),
)

In [None]:
# Print the model's summary.
model.summary()

# Training

In [None]:
#  %tensorboard --logdir={LOG_DIR}

In [None]:
# Save to the checkpoint directory only when the validation loss improves.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT_DIR / "weights.hdf5",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

# Write TensorBoard logs (with histograms every epoch) to this run's log dir.
tensorboard_cb = keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)

# Stop once the validation loss has not improved for 3 consecutive epochs.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0,
    patience=3,
)

callbacks = [checkpoint_cb, tensorboard_cb, early_stopping_cb]


In [None]:
# Train on the training split and validate on the dev split after each epoch.
# Keras documents ``validation_data`` as a tuple ``(x_val, y_val)``; a list is
# not unpacked the same way by every Keras/TF release, so pass a tuple.
history = model.fit(
    x=train_input,
    y=train_output,
    validation_data=(dev_input, dev_output),
    batch_size=64,
    epochs=5,
    callbacks=callbacks,
)

In [None]:
# Load and preprocess the first test set.
test_path = DATA_DIR / "test/d3_surf_code_seq_round_state_0_shots_20000_rounds_20.nc"
test_dataset = xr.load_dataset(test_path)
test_input, test_output = preprocess_data(test_dataset, proj_mat)

In [None]:
# Evaluate the trained model on the currently loaded test set.
eval_output = model.evaluate(
    x=test_input,
    y=test_output,
    batch_size=64,
)

In [None]:
# Load and preprocess the second ("v2") test set; this replaces the previously
# loaded test variables.
test_path = (
    DATA_DIR / "test/d3_surf_code_seq_round_state_0_shots_20000_rounds_20_v2.nc"
)
test_dataset = xr.load_dataset(test_path)
test_input, test_output = preprocess_data(test_dataset, proj_mat)

In [None]:
# Evaluate the trained model on the currently loaded test set.
eval_output = model.evaluate(
    x=test_input,
    y=test_output,
    batch_size=64,
)