### Training a referee decoder

#### NOTE: Not finished!

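DeepQ requires a referee decoder to evaluate the actions taken during episodes. The paper does not describe the referee decoder in detail and only provides weights for the `d=5` referee. The goal of this notebook is to train a referee decoder for `d=7` on data sampled with Stim. So far it only clones the `d=5` architecture with a wider input layer and evaluates the existing `d=5` referee on Stim samples; the training step itself is still missing.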

In [74]:
import keras
from keras.models import load_model, clone_model

from deepq.Environments import *
from deepq.Function_Library import *
from deepq.Utils import *

import stim
import numpy as np

import os

In [75]:
# Load the existing d=5 referee decoder from disk; its architecture is reused below.
# The path is resolved relative to the current working directory.
static_decoder_path = os.path.join(os.getcwd(), "referee_decoders/nn_d5_X_p5")
static_decoder = load_model(filepath=static_decoder_path)

In [76]:
# Print the layer-by-layer summary (layer types, output shapes, parameter counts)
# of the loaded referee so its architecture is visible before it is cloned.
static_decoder.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 1000)              37000     
_________________________________________________________________
dropout_1 (Dropout)          (None, 1000)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 500)               500500    
_________________________________________________________________
dropout_2 (Dropout)          (None, 500)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 250)               125250    
_________________________________________________________________
dropout_3 (Dropout)          (None, 250)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 50)               

In [77]:
# Rebuild the referee architecture on top of a 64-wide input layer.
# Presumably 64 = (7 + 1) ** 2, i.e. the flattened syndrome grid for d=7 - TODO confirm.
# NOTE(review): per the Keras documentation, clone_model creates new layers with
# freshly initialised weights, so the clone still has to be trained.
referee_input = keras.Input(shape=(64,))
new_referee = clone_model(static_decoder, referee_input)

# Show the cloned architecture, then display the resulting input shape as the cell output.
new_referee.summary()
new_referee.input_shape

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 1000)              65000     
_________________________________________________________________
dropout_1 (Dropout)          (None, 1000)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 500)               500500    
_________________________________________________________________
dropout_2 (Dropout)          (None, 500)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 250)               125250    
_________________________________________________________________
dropout_3 (Dropout)          (None, 250)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 50)               

(None, 64)

In [94]:
# Sampling configuration.
batch_size = 10_000  # number of shots drawn from the circuit
rounds = 10          # number of syndrome rounds read out per shot
d = 5                # code distance
p = 0.009            # error rate; in the cells shown it is only used by the generator below

# Alternative circuit source, kept for reference: generate the circuit with stim
# instead of reading it from a file.
# circuit = stim.Circuit.generated(
#     "surface_code:rotated_memory_z",
#     rounds=rounds,
#     distance=d,
#     before_round_data_depolarization=p)

# The circuit actually used is read from a file in the working directory.
circuit = stim.Circuit.from_file("sf-d5-x-error.stim")

In [95]:
# Sample detection events from the circuit. The observable flips are appended
# as the trailing column(s) of every shot (append_observables=True).
detector_sampler = circuit.compile_detector_sampler()
shots = detector_sampler.sample(batch_size, append_observables=True)

# Split every shot into its detector columns and its observable column(s).
# NOTE(review): how the detector columns are distributed over rounds depends on the
# circuit file - the slicing in the syndrome-building cell assumes d**2 - 1 per round.
n_detectors = circuit.num_detectors
detector_parts, actual_observable_parts = (
    shots[:, :n_detectors],  # detection events
    shots[:, n_detectors:],  # logical observable(s), used as ground-truth labels
)

In [96]:
# Number of detector values consumed per round by the slicing further down.
mround = d**2 - 1

# Detector index -> coordinate list, as reported by the circuit.
detector_coords = circuit.get_detector_coordinates()

# Keep the first two coordinates of every detector (in detector-index order),
# cast them to int and halve them so they can be used as grid indices.
syndrome_coords = (
    np.array([detector_coords[idx][:2] for idx in range(len(detector_coords))], int)
    // 2
)

def get_syndrome(syndromes, coords, distance=None):
    """Place one round of detector values on a square syndrome grid.

    Args:
        syndromes: Sequence of detector values for a single round.
        coords: Sequence of ``(row, column)`` integer index pairs; ``coords[e]``
            is the grid cell that receives ``syndromes[e]``.
        distance: Code distance; the grid has shape
            ``(distance + 1, distance + 1)``. Defaults to the notebook-level
            ``d`` so that existing two-argument calls keep working.

    Returns:
        Integer array of shape ``(distance + 1, distance + 1)`` holding the
        detector values, with the row order reversed (``np.flipud``).

    Raises:
        IndexError: If ``coords`` has fewer entries than ``syndromes`` or a
            coordinate lies outside the grid.
    """
    if distance is None:
        # Fall back to the notebook-level code distance.
        distance = d

    grid = np.zeros((distance + 1, distance + 1), int)
    for e in range(len(syndromes)):
        row, col = coords[e]
        grid[row, col] = syndromes[e]

    # Reverse the row order (flip up/down), presumably to match the orientation
    # the referee decoder expects - TODO confirm. An earlier variant mirrored the
    # columns with np.fliplr instead.
    return np.flipud(grid)


# One (d + 1) x (d + 1) syndrome grid per shot and per round.
faulty_syndromes = np.zeros((batch_size, rounds, d + 1, d + 1), int)
for shot, rnd in np.ndindex(batch_size, rounds):
    # NOTE(review): the window skips the first `mround` detectors and, for the last
    # round, ends at (rounds + 1) * mround. If the circuit has fewer detectors the
    # slice is silently shortened - verify against circuit.num_detectors.
    window = slice((rnd + 1) * mround, (rnd + 2) * mround)
    faulty_syndromes[shot, rnd] = get_syndrome(
        detector_parts[shot][window],
        syndrome_coords[window],
    )

# Detectors report differences between consecutive time steps. A running sum
# modulo 2 along the round axis turns them into absolute values per time step.
faulty_syndromes = np.cumsum(faulty_syndromes, axis=1) % 2

In [97]:
# Keep only the last round (index -1 on axis 1) of every shot, reducing the array
# from (batch_size, rounds, d + 1, d + 1) to (batch_size, d + 1, d + 1).
# NOTE: this rebinds `faulty_syndromes`. Re-running this cell without first re-running
# the cell that builds the 4-D array fails, because the 3-D result cannot take four indices.
faulty_syndromes = faulty_syndromes[:,-1,:,:]

In [100]:
# Evaluate the referee decoder on the final-round syndromes of all shots.
#
# The whole batch is scored in a single predict() call instead of one call per
# shot, and the sampled labels are read without being modified in place.

# One flattened syndrome vector of length (d + 1) ** 2 per shot.
syndrome_vectors = np.reshape(faulty_syndromes, (batch_size, (d + 1) ** 2))
decoder_labels = static_decoder.predict(syndrome_vectors, verbose=0)
predicted_flips = np.argmax(decoder_labels, axis=1)

# The reshape raises unless there is exactly one observable per shot, which is
# what the label comparison below relies on.
true_flips = np.reshape(actual_observable_parts, batch_size).astype(int)

correct = predicted_flips == true_flips
# acc: shots where the prediction matches the sampled observable.
acc = int(np.count_nonzero(correct))
# one: shots whose observable was flipped (a logical error occurred).
one = int(np.count_nonzero(true_flips == 1))
# flipped: logical errors that the decoder also predicted as a flip.
flipped = int(np.count_nonzero(correct & (predicted_flips == 1)))

# Avoid a ZeroDivisionError when the batch contains no logical error at all.
flip_rate = flipped / one if one else float("nan")

print(f"accuracy: {acc/batch_size}")
print(f"predicted flip: {flip_rate}, number of logical errors: {one}, corrected: {flipped}")

accuracy: 0.8802
predicted flip: 0.6223300970873786, number of logical errors: 2060, corrected: 1282
