A Physics-Informed Transformer Decoder for Quantum Error Correction on Rotated Surface Codes
We train a small (1.3M parameter) Transformer encoder to perform classification on syndrome data generated by STIM under a phenomenological noise model for the rotated surface code. Given a syndrome bit-string
The decoder is evaluated against the Minimum Weight Perfect Matching (MWPM) baseline (via PyMatching) across code distances
Model Overview: A 1.3M parameter JAX/Flax Transformer encoder (bfloat16 compute). The physical error rate
Physics-Informed (2+1)D RoPE: Standard 1D positional encodings discard the
-
Non-Lorentzian Frame: Under phenomenological noise, qubit decoherence (spatial correlation) and measurement errors (temporal correlation) operate on fundamentally distinct coordinate axes.
-
Capacity Allocation: We partition the RoPE frequency bands into a 3:1 spatial-to-temporal ratio, reflecting the richer structure of the 2D stabilizer lattice compared to the 1D temporal chain.
This decomposition induces an anisotropic attention kernel that evaluates spatial and temporal correlations additively while enforcing translational invariance:
-
Data Pipeline (STIM): 10M synthetic syndromes per distance (
$d \in {3,5,7}$ ) generated via STIM (surface_code:rotated_memory_z) under phenomenological noise. Error rates are geometrically sampled:$p \in [0.0009, 0.016]$ . -
Objective (Focal Loss): To counteract the extreme class imbalance at low
$p$ regimes (where logical errors constitute$\sim 0.02%$ of samples), we optimize using Focal Loss ($\gamma=2.0, \alpha=0.75$ ):$$\mathcal{L}_{\text{focal}}(p_t) = -\alpha_t ,(1 - p_t)^\gamma ,\log(p_t)$$
The Transformer decoder is evaluated against the MWPM baseline across three code distances. MWPM decoding is performed via PyMatching using STIM's detector error model.
Figure 1. Logical error rate vs. physical error rate for the Transformer decoder (solid) and MWPM (dashed) across code distances d = 3, 5, 7. The Transformer outperforms MWPM in the below-threshold regime at d = 3 and d = 5, achieving up to ~29% reduction in logical error rate.
Key results at
| Distance | MWPM |
Transformer |
Relative Improvement | |
|---|---|---|---|---|
| 24 | 19.5% Reduction | |||
| 120 | Match | |||
| 336 | Capacity Bound (4.5x worse) |
At
Figure 2. Log-log plot of logical error rates confirming correct decoder scaling. All three distances exhibit the expected monotonic decrease of pL with decreasing p. The d = 7 curve follows the correct scaling trend (see discussion).
Under the surface code power-law transformer_ler on p should yield slope
| d | Fitted slope | Theory |
Implied |
||
|---|---|---|---|---|---|
| 3 | 2.003 | 2.0 | ~3.8% | −4.2 | 0.9993 |
| 5 | 3.010 | 3.0 | ~3.2% | −7.3 | 0.9988 |
| 7 | 4.205 | 4.0 | ~2.8% | −9.9 | 0.9965 |
Data: results/evaluation_results.csv
Takeaways: Slopes land within ~5% of theory (
-
Exploiting
$Y$ -Errors ($d=3, 5$ ): The Transformer outperforms MWPM below threshold by learning correlated defect signatures of$Y$ errors, which standard MWPM strictly treats as independent$X$ and$Z$ defect pairs. -
OOD Homological Generalization: Extrapolating beyond the training distribution (
$p \le 0.016$ ) yields a$d=3/5$ pseudo-threshold at$p_{\text{th}} \approx 0.027$ . Approaching the theoretical asymptotic limit ($\sim 2.9%$ ) on unseen noise regimes indicates the model learns generalized topological homology rather than interpolating local error distributions. The threshold degradation observed at$d=5/7$ strictly bounds the capacity of the current parameter space. -
The
$d=7$ Capacity Bound: The model maintains correct topological scaling (exponential suppression of$p_L$ ) at$d=7$ , confirming the decoding mechanism is sound. However, the 1.3M parameter capacity hits a representational ceiling against the combinatorially richer 336-detector syndrome volume, establishing a clear trajectory for architectural scaling. - Circuit-Level Noise: Future work will swap the phenomenological assumption for full circuit-level simulations (matching the AlphaQubit baseline) to test the robustness of the (2+1)D RoPE inductive bias against spatial crosstalk and leakage.
-
Inference Latency: While
$O(L^2)$ attention poses asymptotic challenges compared to optimized$O(L^3)$ MWPM C++ implementations, future benchmarks will quantify the throughput-accuracy tradeoff using JAX batched inference on TPU hardware. -
Power-Law Scaling Confirmed: Log-log regression yields fitted slopes within ~5% of the theoretical
$(d+1)/2$ exponents ($R^2 > 0.996$ ), with implied thresholds in the expected ~1–4% range and prefactor$C$ shrinking rapidly with$d$ .
TransformerQEC/
├── notebooks/
│ ├── 01_data_exploration.ipynb # STIM circuit inspection, syndrome visualization,
│ │ # defect statistics, and noise model characterization
│ ├── 02_model_and_training.ipynb # Model architecture, (2+1)D RoPE, and focal loss
│ └── 03_evaluation.ipynb # MWPM comparison, threshold estimation
│ # Wilson CI, and result visualization
├── results/
│ ├── transformer_qec_d{3,5,7}.pkl # Trained checkpoints (params + config + coords)
│ ├── evaluation_results.csv # Numerical LER comparison across (d, p)
│ ├── transformer_vs_mwpm.png # Decoder comparison plot
│ ├── logical_error_rates.png # Scaling behavior plot
│ └── threshold_estimates.txt # Threshold crossing analysis
├── tests/
│ └── test_notebooks_compat.py # End-to-end integration tests (9 phases)
└── README.md
- AlphaQubit — Bausch, J. et al. "Learning to decode the surface code with a recurrent, transformer-based neural network." Nature (2024).
- RoPE — Su, J. et al. "RoFormer: Enhanced transformer with rotary position embedding." arXiv:2104.09864 (2021).
- STIM — Gidney, C. "Stim: a fast stabilizer circuit simulator." Quantum 5, 497 (2021).
- PyMatching — Higgott, O. "PyMatching: A Python package for decoding quantum codes with minimum-weight perfect matching." ACM Transactions on Quantum Computing (2022).
- Focal Loss — Lin, T.-Y. et al. "Focal loss for dense object detection." ICCV (2017).
- Surface Codes — Fowler, A. G. et al. "Surface codes: Towards practical large-scale quantum computation." Physical Review A 86, 032324 (2012).
Built with JAX on TPU. Synthetic data generated with STIM.

