33 changes: 25 additions & 8 deletions README.md
@@ -1,6 +1,6 @@
# AI pre-decoder for surface-code memory circuits
# Ising Decoding

This repo implements a **pre-decoder** for surface-code memory experiments:
This repo offers AI training frameworks and recipes to build, customize and deploy scalable quantum error correction **decoders**:

- A neural network consumes detector syndromes across space **and** time
- It predicts corrections that reduce syndrome density / improve decoding
@@ -92,8 +92,8 @@ pip install -r code/requirements_public_inference.txt

2. **Get the pre-trained models**
This repo ships two pre-trained model files (tracked with Git LFS):
- `models/PreDecoderModelMemory_r9_v1.0.77.pt` (receptive field R=9, checkpoint 77)
- `models/PreDecoderModelMemory_r13_v1.0.86.pt` (receptive field R=13, checkpoint 86)
- `models/Ising-Decoder-SurfaceCode-1-Fast.pt` (receptive field R=9)
- `models/Ising-Decoder-SurfaceCode-1-Accurate.pt` (receptive field R=13)

Clones get the files via `git lfs pull`. Optionally, set `PREDECODER_MODEL_URL` to the LFS/raw URL to fetch files when not in the working tree (e.g. in a minimal checkout or CI).

@@ -138,8 +138,16 @@ The pre-trained public models use `--model-id 1` (R=9) and `--model-id 4` (R=13)
After training (or starting from the shipped `.safetensors` files), you can export the model to
ONNX and optionally apply INT8 or FP8 post-training quantization for deployment.

Set the `ONNX_WORKFLOW` and (optionally) `QUANT_FORMAT` environment variables before running
inference with `local_run.sh`:
You may also change the surface code distance and the number of rounds at inference
time. You are not required to retrain the model when changing either of these
parameters: because the model is a 3D convolutional neural network, it is simply
run over the new decoding volume.

- To run with a new distance, simply add `DISTANCE=<your distance>` to the commands below.
- To run with a new number of rounds, simply add `N_ROUNDS=<your number of rounds>` to the commands below.
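
The reason no retraining is needed can be illustrated with a minimal sketch. It is not this repo's model: `conv3d_same`, the single random kernel, and the `(rounds, distance, distance)` layout are assumptions made only for illustration. The point is that a convolution's weights are tied to the kernel size, not to the input size, so the same weights slide over any decoding volume.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv3d_same(volume, kernel):
    """Zero-padded 'same' 3D cross-correlation of a single-channel volume.

    Hypothetical helper for illustration; assumes odd kernel sizes.
    """
    pads = [(k // 2, k // 2) for k in kernel.shape]
    padded = np.pad(volume, pads)
    # windows has shape volume.shape + kernel.shape
    windows = sliding_window_view(padded, kernel.shape)
    return np.einsum("tijabc,abc->tij", windows, kernel)


rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 3))  # stand-in for trained weights

# The same kernel is applied to volumes of different (rounds, distance) sizes.
for n_rounds, distance in [(9, 9), (13, 13), (5, 21)]:
    syndromes = rng.integers(0, 2, size=(n_rounds, distance, distance)).astype(np.float64)
    out = conv3d_same(syndromes, kernel)
    assert out.shape == syndromes.shape
```

The output always has the shape of the input volume, which is why changing `DISTANCE` or `N_ROUNDS` only changes the volume the network is run over.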

Set `ONNX_WORKFLOW` and, optionally, `QUANT_FORMAT`, `DISTANCE`, and `N_ROUNDS`
as environment variables before running inference with `local_run.sh`:

| `ONNX_WORKFLOW` | Behavior |
|---|---|
@@ -169,7 +177,16 @@ ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh
| `QUANT_FORMAT` | unset | `int8` or `fp8`. Unset means no quantization (FP32 ONNX). |
| `QUANT_CALIB_SAMPLES` | `256` | Calibration samples for INT8/FP8 post-training quantization. |

**Circuit variables:**

| Variable | Default | Description |
|---|---|---|
| `CONFIG_NAME` | `config_public` | Use the defaults from the `conf/$CONFIG_NAME.yaml` file |
| `DISTANCE` | Use the distance specified in the `conf/$CONFIG_NAME.yaml` file | surface code distance |
| `N_ROUNDS` | Use the number of rounds specified in the `conf/$CONFIG_NAME.yaml` file | number of rounds in memory experiment |

Notes:

- TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`.
- FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently.
- ONNX and engine files are written to the current working directory.
@@ -215,7 +232,7 @@ Results are written to `outputs/<EXPERIMENT_NAME>/plots/`.
| Decoder | Source | Notes |
|---|---|---|
| No-op | — | Pre-decoder output only, no global correction |
| Union-Find | `ldpc` | Fast, sub-optimal |
| Union-Find | `ldpc` | Fast, sub-optimal logical error rate (LER) |
| BP-only | `ldpc` | Belief propagation, no OSD |
| BP+LSD-0 | `ldpc` | BP with localized statistics decoding |
| Uncorr-PM | PyMatching | Uncorrelated minimum-weight perfect matching |
@@ -556,4 +573,4 @@ Presence of these headers is enforced automatically by the `spdx-header-check` C
`.github/workflows/ci.yml`).

Third-party open source components bundled with or required by this project are listed with their
respective copyright notices and license texts in [NOTICE](NOTICE).
respective copyright notices and license texts in [NOTICE](NOTICE).
2 changes: 1 addition & 1 deletion TRAINING.md
@@ -7,7 +7,7 @@ For local single-machine usage, see `README.md`.
## Prerequisites

- Docker with NVIDIA GPU support (`nvidia-docker` / `--gpus`)
- One or more NVIDIA GPUs (H100, A100, or similar)
- One or more NVIDIA GPUs (B200, H200 or similar)
- A persistent directory for checkpoints and logs

## Quick start (Docker — recommended)
50 changes: 30 additions & 20 deletions code/evaluation/failure_analysis.py
@@ -18,6 +18,7 @@
"""
import os
import random
import warnings

import numpy as np
import torch
@@ -179,27 +180,39 @@ def _build_cudaq_decoders(det_model):

def _decode_cudaq_batch(decoder, L_dense, syndromes_np):
    """
    Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder (single-shot loop).
    Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder.
    Returns (obs, stats) where:
    - obs: observable predictions as np.ndarray of shape (B,)
    - stats: dict with per-sample convergence flags, iteration counts
    The decoder.decode() takes list[float] and returns DecoderResult with .result (list[float]).
    """
    B = syndromes_np.shape[0]
    obs = np.zeros(B, dtype=np.uint8)
    n_bits = L_dense.shape[1]
    converged_flags = np.zeros(B, dtype=bool)
    iter_counts = np.zeros(B, dtype=np.int32)
    for i in range(B):
        syndrome_list = syndromes_np[i].astype(np.float64).tolist()
        result = decoder.decode(syndrome_list)
        correction = np.array(result.result, dtype=np.uint8)
        obs[i] = int((L_dense @ correction).item() %
                     2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
    corrections = np.empty((B, n_bits), dtype=np.uint8)
    syndromes_f64 = np.ascontiguousarray(syndromes_np, dtype=np.float64)

    def _unpack(i, result):
        corrections[i] = np.array(result.result, dtype=np.uint8)
        converged_flags[i] = result.converged
        # Collect iteration count if available via opt_results
        opt = getattr(result, 'opt_results', None)
        if opt and isinstance(opt, dict) and 'num_iter' in opt:
            iter_counts[i] = opt['num_iter']

    def _loop_decode():
        for i in range(B):
            _unpack(i, decoder.decode(syndromes_f64[i].tolist()))

    try:
        results = decoder.decode_batch(syndromes_f64.tolist())
    except Exception as exc:
        warnings.warn(f"decode_batch failed ({exc}); falling back to per-sample loop")
        _loop_decode()
    else:
        for i, result in enumerate(results):
            _unpack(i, result)

    obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
    return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts}
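
The batch-first control flow in this hunk can be tried in isolation with a toy decoder. The sketch below is only an illustration of the pattern: `LoopOnlyDecoder` and `decode_all` are hypothetical names, and the stand-in decoder does not model the cudaq-qec API beyond having `decode` and `decode_batch` methods.

```python
import warnings


class LoopOnlyDecoder:
    """Hypothetical stand-in decoder whose batch entry point always fails."""

    def decode(self, syndrome):
        # Toy rule: echo the syndrome back as the "correction".
        return list(syndrome)

    def decode_batch(self, syndromes):
        raise RuntimeError("batch decoding unavailable")


def decode_all(decoder, syndromes):
    """Try one batched call; fall back to a per-sample loop on any failure."""
    try:
        return list(decoder.decode_batch(syndromes))
    except Exception as exc:
        warnings.warn(f"decode_batch failed ({exc}); falling back to per-sample loop")
        return [decoder.decode(s) for s in syndromes]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    results = decode_all(LoopOnlyDecoder(), [[0, 1], [1, 1]])
```

Catching the broad `Exception` mirrors the diff. It keeps the evaluation running when the batched path is missing or fails, at the cost of reducing a genuine decoder error to a warning.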


@@ -249,20 +262,17 @@ def _build_ldpc_decoders(det_model):

def _decode_ldpc_batch(decoder, L_dense, syndromes_np):
    """
    Decode a batch of syndromes with an ldpc decoder (single-shot loop).
    Decode a batch of syndromes with an ldpc decoder.
    Returns observable predictions as np.ndarray of shape (B,).
    """
    B = syndromes_np.shape[0]
    obs = np.zeros(B, dtype=np.uint8)
    n_bits = L_dense.shape[1]
    syndromes_c = np.ascontiguousarray(syndromes_np, dtype=np.uint8)
    corrections = np.empty((B, n_bits), dtype=np.uint8)
    for i in range(B):
        # Get the most-likely error configuration from the decoder for this syndrome.
        correction = decoder.decode(syndromes_np[i])
        # Project the correction onto the logical observable via L_dense (mod 2).
        # L_dense has shape (num_obs, num_errors); the first observable row is used.
        obs[i] = (
            int((L_dense @ correction).item() %
                2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
        )
        corrections[i] = decoder.decode(syndromes_c[i])

    obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
    return obs
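
Both decode helpers replace the per-sample observable projection with a single matrix product. The following sketch checks that the two forms agree on random data. `project_observable` is a hypothetical wrapper around the expression from the diff, and the shapes are arbitrary.

```python
import numpy as np


def project_observable(corrections, L_dense):
    """Batched parity of the first logical observable, as in the diff.

    corrections: (B, n_bits) uint8; L_dense: (num_obs, n_bits) uint8.
    """
    # Cast to int32 so the dot products accumulate in a wider integer type.
    prod = corrections.astype(np.int32) @ L_dense.T.astype(np.int32)
    return (prod[:, 0] % 2).astype(np.uint8)


rng = np.random.default_rng(0)
L_dense = rng.integers(0, 2, size=(1, 12), dtype=np.uint8)
corrections = rng.integers(0, 2, size=(8, 12), dtype=np.uint8)

# Per-sample form used before this change: project one correction at a time.
obs_loop = np.array([int((L_dense @ c)[0] % 2) for c in corrections], dtype=np.uint8)
obs_batched = project_observable(corrections, L_dense)

assert np.array_equal(obs_loop, obs_batched)
```

Only the first observable row is read in both forms, so the batched version keeps the behavior of the old loop when `L_dense` has more than one row.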


4 changes: 2 additions & 2 deletions code/evaluation/logical_error_rate.py
@@ -200,7 +200,7 @@ def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.nd
class _DetCalibReader(CalibrationDataReader):

def __init__(self, data):
self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))]
self._rows = [{"dets": data[i:i + 1]} for i in range(len(data))]
self._iter = iter(self._rows)

def get_next(self):
@@ -1202,7 +1202,7 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
mq.quantize(
onnx_path=fp32_onnx_path,
quantize_mode=quant_format,
calibration_data={"dets": calib_dets.astype("float32")},
calibration_data={"dets": calib_dets},
output_path=onnx_path,
**quant_kwargs,
)
4 changes: 2 additions & 2 deletions code/export/checkpoint_to_safetensors.py
@@ -20,11 +20,11 @@

Usage:
PYTHONPATH=code python code/export/checkpoint_to_safetensors.py \\
--checkpoint models/PreDecoderModelMemory_r9_v1.0.77.pt \\
--checkpoint models/Ising-Decoder-SurfaceCode-1-Fast.pt \\
--model-id 1 [--fp16]

Then run inference with:
PREDECODER_SAFETENSORS_CHECKPOINT=models/PreDecoderModelMemory_r9_v1.0.77_fp16.safetensors \\
PREDECODER_SAFETENSORS_CHECKPOINT=models/Ising-Decoder-SurfaceCode-1-Fast_fp16.safetensors \\
WORKFLOW=inference DISTANCE=9 N_ROUNDS=9 EXPERIMENT_NAME=predecoder_model_1 \\
bash code/scripts/local_run.sh
"""