In [None]:
!pip install torch

In [34]:
from src.GNN_Decoder import GNN_Decoder
import src.gnn_models as gnn
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
import stim
import numpy as np

In [13]:
# Checkpoint file of the trained decoder; it is read with torch.load further down.
_PATH = (
    "models/circuit_level_noise/d3/d3_d_t_3.pt"
)

In [14]:
# Set and export parameters

# ---- Code / experiment size -------------------------------------------------
_CODE_SIZE = 3
_REPETITIONS = 3

# ---- Training ---------------------------------------------------------------
_NUM_ITERATIONS = 3
_BATCH_SIZE = 10
_LEARNING_RATE = 1e-5
_MANUAL_SEED = 12345

# ---- Benchmark --------------------------------------------------------------
_BENCHMARK = 1

# ---- Replay buffer ----------------------------------------------------------
_BUFFER_SIZE = 10
_REPLACEMENTS_PER_ITERATION = 2
# NOTE(review): presumably the total number of test samples is
# len(error_rate) * _BATCH_SIZE * _TEST_SIZE — confirm against the training code.
_TEST_SIZE = 1

# ---- Graph construction -----------------------------------------------------
_NUM_NODE_FEATURES = 5
_EDGE_WEIGHT_POWER = 2
_M_NEAREST_NODES = 6
_USE_CUDA = 0
_USE_VALIDATION = 1

# ---- Loss -------------------------------------------------------------------
_CRITERION = torch.nn.BCEWithLogitsLoss()

In [15]:
# Configuration consumed by GNN_Decoder: model hyper-parameters, graph-building
# options, device flag and where checkpoints are written.
GNN_params = dict(
    model={
        'class': gnn.GNN_7,
        # Single output logit, paired with the BCEWithLogitsLoss criterion.
        'num_classes': 1,
        'loss': _CRITERION,
        'num_node_features': _NUM_NODE_FEATURES,
        'initial_learning_rate': _LEARNING_RATE,
        'manual_seed': _MANUAL_SEED,
    },
    graph=dict(
        num_node_features=_NUM_NODE_FEATURES,
        m_nearest_nodes=_M_NEAREST_NODES,
        power=_EDGE_WEIGHT_POWER,
    ),
    cuda=_USE_CUDA,
    save_path='test/test.pt',
    save_prefix='test',
)

In [31]:
# Restore the trained decoder from the checkpoint at _PATH, mapped onto the CPU.
# NOTE(review): weights_only=False unpickles arbitrary objects (code-execution
# risk) — only load checkpoints from trusted sources.
attributes = torch.load(
    _PATH,
    map_location=torch.device('cpu'),
    weights_only=False,
)
decoder = GNN_Decoder(GNN_params)
decoder.load_training_history(attributes)
# Switch to inference mode; as the cell's last expression this also displays the model.
decoder.model.eval()

GNN_7(
  (graph1): GraphConv(5, 32)
  (graph2): GraphConv(32, 128)
  (graph3): GraphConv(128, 256)
  (graph4): GraphConv(256, 512)
  (graph5): GraphConv(512, 512)
  (graph6): GraphConv(512, 256)
  (graph7): GraphConv(256, 256)
  (lin1): Linear(in_features=256, out_features=256, bias=True)
  (lin2): Linear(in_features=256, out_features=128, bias=True)
  (lin3): Linear(in_features=128, out_features=64, bias=True)
  (lin4): Linear(in_features=64, out_features=1, bias=True)
)

In [None]:
def evaulate_model(decoder, graph_data, save_to=None, batch_size=1000):
    '''Run the decoder's model over every graph in `graph_data`.

    Args:
        decoder: object exposing a `model` that is called as
            model(x, edge_index, edge_attr, batch).
        graph_data: sequence of torch_geometric `Data` graphs (e.g. the list
            built by `generate_test_batch`).
        save_to: optional path; when given, the `(prediction, target)` tuple is
            written there with `torch.save` and None is returned.
        batch_size: number of graphs per forward pass.

    Returns:
        `(prediction, target)` — model outputs and `y` labels concatenated over
        ALL batches — when `save_to` is None; otherwise None.
    '''
    loader = DataLoader(graph_data, batch_size=batch_size)
    decoder.model.eval()

    predictions, targets = [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in loader:
            batch_index = data.batch.to(data.x.device)
            predictions.append(
                decoder.model(data.x, data.edge_index, data.edge_attr, batch_index)
            )
            targets.append(data.y)

    if predictions:
        prediction = torch.cat(predictions)
        target = torch.cat(targets)
    else:
        # Empty input: nothing to evaluate.
        prediction = torch.empty(0)
        target = torch.empty(0)

    if save_to is not None:
        torch.save((prediction, target), save_to)
        return None
    return prediction, target


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
evaluate_model = evaulate_model



def generate_test_batch(test_size):
    '''Sample one test batch of syndrome graphs from the module-level stim sampler.

    Draws `test_size` shots from `sampler`, sets aside shots whose detector
    syndrome is all zeros (decoded trivially), and converts the remaining shots
    into torch_geometric `Data` graphs.

    Depends on module-level names: `sampler`, `generate_batch`,
    `detector_coordinates`, `mask`, `m_nearest_nodes`, `power`, `device`.
    NOTE(review): `generate_batch`, `mask`, `m_nearest_nodes`, `power` and
    `device` are not defined in the visible notebook cells — confirm they exist
    before calling this.

    Args:
        test_size: number of shots to sample.

    Returns:
        (test_batch, correct_predictions_trivial): `test_batch` is a list of
        `Data` graphs, one per shot with a non-empty syndrome;
        `correct_predictions_trivial` is the number of empty-syndrome shots for
        which the trivial "no observable flip" prediction is actually correct.
    '''
    stim_data, observable_flips = sampler.sample(shots=test_size, separate_observables=True)

    # Shots with at least one fired detector go to the GNN.
    non_empty = np.any(stim_data, axis=1)

    # An empty syndrome is predicted as "no flip". That is only correct when no
    # observable flipped; an undetectable logical error is a failure, not a success.
    empty_shot_flips = observable_flips[~non_empty]
    correct_predictions_trivial = int(np.count_nonzero(~np.any(empty_shot_flips, axis=1)))

    stim_data_list = list(stim_data[non_empty])
    observable_flips_list = list(observable_flips[non_empty])
    buffer = generate_batch(stim_data_list, observable_flips_list,
                            detector_coordinates, mask, m_nearest_nodes, power)

    # Each buffer entry holds numpy arrays for (x, edge_index, edge_attr, y);
    # convert them to tensors on `device` and wrap them as Data objects.
    test_batch = [
        Data(
            x=torch.from_numpy(item[0]).to(device),
            edge_index=torch.from_numpy(item[1]).to(device),
            edge_attr=torch.from_numpy(item[2]).to(device),
            y=torch.from_numpy(item[3]).to(device),
        )
        for item in buffer
    ]
    return test_batch, correct_predictions_trivial

In [32]:
## Experiment variables
# Code distance; the number of time steps is set equal to it.
_time_steps = _distance = 3
# Physical error rate and number of shots to sample.
_error_rate, _test_size = 0.1, 10000

In [None]:
# Rotated surface-code memory experiment (Z basis) with circuit-level noise;
# every noise channel uses the same physical error rate `_error_rate`.
circuit = stim.Circuit.generated(
    "surface_code:rotated:memory_z",
    rounds=_time_steps,
    distance=_distance,
    after_clifford_depolarization=_error_rate,
    after_reset_flip_probability=_error_rate,
    before_measure_flip_probability=_error_rate,
    before_round_data_depolarization=_error_rate,
)

# Detector coordinates ordered by detector index, so row i lines up with
# column i of the sampled detection events.
detector_coordinates = circuit.get_detector_coordinates()
detector_coordinates = np.array(
    [detector_coordinates[k] for k in sorted(detector_coordinates)]
)
# Rescale the two space-like coordinates by 1/2.
detector_coordinates[:, :2] = detector_coordinates[:, :2] / 2
# Convert to integers.
# NOTE(review): astype(np.uint8) truncates; this assumes the rescaled
# coordinates are whole numbers in [0, 255].
detector_coordinates = detector_coordinates.astype(np.uint8)

sampler = circuit.compile_detector_sampler()

# NOTE(review): `factor` is not used in any visible cell.
factor = 50

stim_data, observable_flips = sampler.sample(shots=_test_size, separate_observables=True)
