In [None]:
import stim 
import torch
import numpy as np


def build_spatial_mapping(circuit: "stim.Circuit", padding=False, ):
    """Map every detector to its round index and to a flat (round, cell) index.

    Coordinates come from ``circuit.get_detector_coordinates()``. The first two
    coordinates are the (x, y) location. The third, if present, is the time
    coordinate; otherwise time is 0.

    Args:
        circuit: A stim circuit. Any object exposing ``num_detectors`` and
            ``get_detector_coordinates()`` works.
        padding: Controls how (x, y) locations are numbered.
            - False: distinct locations are numbered densely, ordered by
              (y, x); the per-round stride is the number of distinct locations.
            - True: location (x, y) maps to ``x // 2 + (y // 2) * side`` on a
              square grid of ``side * side`` cells, where ``side`` covers the
              larger of the x and y extents. This assumes non-negative
              coordinates spaced by 2.

    Returns:
        ``(flat_to_round, flat_to_spatial)``: two ``torch.long`` tensors of
        length ``circuit.num_detectors``.
        - ``flat_to_round[i]`` is the rank of detector i's time coordinate
          among all distinct time coordinates (0, 1, 2, ...).
        - ``flat_to_spatial[i]`` is
          ``cell_index + flat_to_round[i] * num_spatial_features``, an index
          into a flattened (num_rounds, num_spatial_features) array.

    Raises:
        ValueError: If a detector has fewer than two coordinates, or if
            ``padding=True`` and a coordinate is negative.
    """
    coords_dict = circuit.get_detector_coordinates()
    num_detectors = circuit.num_detectors

    # Collect the (x, y) location and the time coordinate of each detector, in detector order.
    xys = []
    times = []
    for i in range(num_detectors):
        coords = coords_dict.get(i, ())
        if len(coords) < 2:
            raise ValueError(
                f"detector {i} has no (x, y) coordinates (got {list(coords)}); "
                "build_spatial_mapping needs at least two coordinates per detector")
        xys.append((coords[0], coords[1]))
        times.append(coords[2] if len(coords) > 2 else 0)

    # Distinct locations in row-major order: by y first, then by x.
    sorted_spatial_locs = sorted(set(xys), key=lambda p: (p[1], p[0]))

    if padding:
        # Negative coordinates would produce negative indices that silently wrap in torch indexing.
        if any(x < 0 or y < 0 for x, y in sorted_spatial_locs):
            raise ValueError("padding=True requires non-negative detector (x, y) coordinates")
        # The side must cover both axes; otherwise, when y extends further than x,
        # cells would index past this round's block into the next round's block.
        largest = max((max(x, y) for x, y in sorted_spatial_locs), default=0)
        side = int(largest // 2) + 1
        loc_to_idx = {(x, y): int(x // 2 + (y // 2) * side) for x, y in sorted_spatial_locs}
        num_spatial_features = side * side
    else:
        loc_to_idx = {loc: idx for idx, loc in enumerate(sorted_spatial_locs)}
        num_spatial_features = len(sorted_spatial_locs)

    # Rank the distinct time coordinates so rounds are 0, 1, 2, ... even when the
    # circuit's time coordinates start elsewhere or are not consecutive integers.
    # The round rank (not the raw time value) is used as the per-round stride.
    time_to_round = {t: r for r, t in enumerate(sorted(set(times)))}
    round_of = [time_to_round[t] for t in times]

    flat_to_round = torch.tensor(round_of, dtype=torch.long)
    flat_to_spatial = torch.tensor(
        [loc_to_idx[xy] + r * num_spatial_features for xy, r in zip(xys, round_of)],
        dtype=torch.long)
    return flat_to_round, flat_to_spatial


def SurfaceDataReshape(dets, num_pixels, rounds, flat_to_spatial):
    """Scatter flat detector samples into a (batch, rounds + 1, num_pixels) grid.

    Args:
        dets: (batch, num_detectors) detection events. Accepts a torch tensor
            or anything ``torch.as_tensor`` accepts (e.g. a numpy array).
        num_pixels: Number of spatial cells per round. Must equal the per-round
            stride used when ``flat_to_spatial`` was built.
        rounds: The output has ``rounds + 1`` time slices.
        flat_to_spatial: (num_detectors,) integer indices. Column ``j`` of
            ``dets`` is written to flat position ``flat_to_spatial[j]`` of each
            sample.

    Returns:
        A tensor of shape (batch, rounds + 1, num_pixels) on ``dets``'s device.
        Cells that no detector maps to are 0. Floating-point inputs keep their
        dtype; other inputs (bool/int) become the default float dtype.
    """
    dets = torch.as_tensor(dets)
    batch_size = dets.shape[0]
    out_dtype = dets.dtype if dets.is_floating_point() else torch.get_default_dtype()
    index = torch.as_tensor(flat_to_spatial, dtype=torch.long, device=dets.device)

    # Build the buffer on the input's device so GPU inputs work.
    spatial_dets = torch.zeros(batch_size, num_pixels * (rounds + 1),
                               dtype=out_dtype, device=dets.device)
    spatial_dets[:, index] = dets.to(out_dtype)
    # Use an explicit trailing size instead of -1: with an empty batch, -1 is ambiguous and reshape fails.
    return spatial_dets.view(batch_size, rounds + 1, num_pixels)

In [30]:
# Experiment parameters.
d = 3
rounds = 5
error_rate = 0.1
SEED = 1234  # fixed so the sampled detection events are reproducible
NUM_SHOTS = 2

circuit = stim.Circuit.generated("surface_code:rotated_memory_z",
                                rounds=rounds,
                                distance=d,
                                after_clifford_depolarization=error_rate,
                                after_reset_flip_probability=error_rate,
                                before_measure_flip_probability=error_rate,
                                before_round_data_depolarization=error_rate)

flat_to_round, flat_to_spatial = build_spatial_mapping(circuit, padding=True)
print("Flat to Round Mapping:", flat_to_round)
print("Flat to Spatial Mapping:", flat_to_spatial)

dem = circuit.detector_error_model(flatten_loops=True, decompose_errors=True)
sampler = dem.compile_sampler(seed=SEED)
dets, _, _ = sampler.sample(NUM_SHOTS)

# Cells per round of the padded grid.
# NOTE(review): assumes the padded per-round stride is (d + 1)**2 for this circuit
# family (it is 16 for d=3 in the mapping printed above). The asserts below check it.
num_pixels = (d + 1) ** 2
assert len(set(flat_to_spatial.tolist())) == len(flat_to_spatial), "two detectors share a grid cell"
assert int(flat_to_spatial.max()) < num_pixels * (rounds + 1), "grid stride does not match num_pixels"

dets_t = torch.tensor(dets, dtype=torch.float32)
rdets = SurfaceDataReshape(dets_t, num_pixels, rounds, flat_to_spatial)

# Round trip: reading the grid back at flat_to_spatial must recover the original samples.
assert torch.equal(rdets.reshape(rdets.shape[0], -1)[:, flat_to_spatial], dets_t)

# Spot check round 2: padded grid row vs. the flat detectors assigned to that round.
print(rdets[0, 2, :])
print(dets_t[0, flat_to_round == 2])

Flat to Round Mapping: tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
        3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5])
Flat to Spatial Mapping: tensor([ 8,  5, 10,  7, 17, 21, 22, 23, 24, 25, 26, 30, 33, 37, 38, 39, 40, 41,
        42, 46, 49, 53, 54, 55, 56, 57, 58, 62, 65, 69, 70, 71, 72, 73, 74, 78,
        88, 85, 90, 87])
tensor([0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.])
[0. 1. 0. 1. 0. 1. 0. 0.]


In [6]:
# NOTE(review): build_compact_mapping is not defined in any visible cell. On a
# fresh kernel this cell raises NameError unless the function is defined
# elsewhere. TODO: add its definition above this cell.
# This cell also overwrites flat_to_round / flat_to_spatial, which were produced
# by the padded build_spatial_mapping call in the earlier cell.
flat_to_round, flat_to_spatial, _ = build_compact_mapping(circuit)
print(f"Flat to Round Mapping: {flat_to_round}")
print(f"Flat to Spatial Mapping: {flat_to_spatial}")

检测到唯一X坐标数: 4, 唯一Y坐标数: 4
最终生成的 Compact Grid 大小 (含Padding): 6 x 6
Flat to Round Mapping: tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,
        3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5])
Flat to Spatial Mapping: tensor([19, 14, 21, 16,  8, 14, 15, 16, 19, 20, 21, 27,  8, 14, 15, 16, 19, 20,
        21, 27,  8, 14, 15, 16, 19, 20, 21, 27,  8, 14, 15, 16, 19, 20, 21, 27,
        19, 14, 21, 16])
