Merged
24 commits
d8d99cc
Make cuStabilizer the sole DEM sampling backend and consolidate tests
kvmto Mar 31, 2026
9e26338
fixed license
kvmto Mar 31, 2026
8032945
fix: use cuquantum-python-cu12 wheel to avoid pkg_resources build fai…
kvmto Mar 31, 2026
0815796
lazy imports for safe separation between training and inference
kvmto Mar 31, 2026
5d4b98a
quick fix to CI
kvmto Mar 31, 2026
485ea80
route cuQuantum dem_sampling tests to GPU CI
kvmto Mar 31, 2026
f7d7349
left behind change
kvmto Mar 31, 2026
e8b30d6
missing bash session
kvmto Mar 31, 2026
a2f559e
Make CUDA major version specific requirements files and use custabili…
bmhowe23 Mar 31, 2026
96d37ba
Revert some changes to test files that are hopefully no longer needed
bmhowe23 Mar 31, 2026
5517d84
Revert REQUIRE_CUQUANTUM changes
bmhowe23 Mar 31, 2026
f3d6ff3
Change custabilizer version to 0.3.0
bmhowe23 Mar 31, 2026
3bd1740
Change custabilizer back to cuquantum-python
bmhowe23 Apr 1, 2026
ee15114
Skip test_dem_sampling.py if required deps are not present
bmhowe23 Apr 1, 2026
a43279e
Try again
bmhowe23 Apr 1, 2026
3b8015d
Skip a few more tests if cuquantum-python not installed
bmhowe23 Apr 1, 2026
ce055a2
Revert CUDA major version specific requirements files
bmhowe23 Apr 1, 2026
05e92f8
Revert "Revert CUDA major version specific requirements files"
bmhowe23 Apr 1, 2026
91b0fa3
small torch device object bug fix for nccl
kvmto Apr 1, 2026
0e7150e
overcome custab device id limitation
kvmto Apr 1, 2026
5a5da42
added tiny logging
kvmto Apr 1, 2026
894e508
linted
kvmto Apr 1, 2026
84c814b
Revert "Revert "Revert CUDA major version specific requirements files""
bmhowe23 Apr 1, 2026
413d1f8
Revert "Revert "Revert "Revert CUDA major version specific requiremen…
bmhowe23 Apr 1, 2026
6 changes: 4 additions & 2 deletions .github/workflows/ci-gpu.yml
@@ -154,7 +154,8 @@ jobs:
           python3.13 -m venv .venv_mid
           . .venv_mid/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt

       - name: Mid-tier training + inference with LER check (32k train, 2 epochs)
         shell: bash
@@ -212,7 +213,8 @@ jobs:
           python3 -m venv .venv_gpu_cov
           . .venv_gpu_cov/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_inference.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt
           pip install -r code/requirements_ci.txt

       - name: Run tests with GPU coverage
6 changes: 4 additions & 2 deletions .github/workflows/long-running-tests.yml
@@ -166,7 +166,8 @@ jobs:
           python${{ env.PYTHON_VERSION }} -m venv .venv
           . .venv/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt

       - name: Verify GPU
         run: |
@@ -302,7 +303,8 @@ jobs:
           python${{ env.PYTHON_VERSION }} -m venv .venv
           . .venv/bin/activate
           python -m pip install --upgrade pip setuptools wheel
-          pip install -r code/requirements_public_train.txt
+          # TODO: matrix by CUDA major version [cu12, cu13]
+          pip install -r code/requirements_public_train-cu12.txt

       - name: Verify GPU
         run: |
13 changes: 9 additions & 4 deletions Dockerfile
@@ -35,11 +35,16 @@ RUN python${PYTHON_VERSION} -m venv /opt/venv
 ENV PATH="/opt/venv/bin:$PATH"

 COPY code/requirements_public_inference.txt /tmp/requirements_public_inference.txt
-COPY code/requirements_public_train.txt /tmp/requirements_public_train.txt
-
-RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
+COPY code/requirements_public_train-cu*.txt /tmp/
+
+# Derive the CUDA major version from the base image's $CUDA_VERSION env var
+# (e.g. "12.1.0" -> "12") and install the matching requirements file.
+RUN CUDA_MAJOR_VERSION=$(echo "${CUDA_VERSION}" | cut -d. -f1) && \
+    echo "Detected CUDA major version: ${CUDA_MAJOR_VERSION}" && \
+    echo "export CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}" >> /etc/bash.bashrc && \
+    pip install --no-cache-dir --upgrade pip setuptools wheel && \
     pip install --no-cache-dir \
-        -r /tmp/requirements_public_train.txt \
+        -r /tmp/requirements_public_train-cu${CUDA_MAJOR_VERSION}.txt \
         --index-url "https://download.pytorch.org/whl/${TORCH_CUDA}" \
         --extra-index-url https://pypi.org/simple && \
     python -c "import torch; print('PyTorch', torch.__version__, '(CUDA build:', torch.version.cuda, ')')"
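
Reviewer sketch (not part of this PR): the Dockerfile step above maps a CUDA version string to a requirements file by keeping the text before the first dot. The same mapping could be expressed in Python for local setup scripts, roughly as below. The helper name `requirements_file_for` and the `supported` tuple are assumptions made for illustration; the supported majors (12 and 13) are taken from the README change in this PR.

```python
import re


def requirements_file_for(cuda_version, supported=("12", "13")):
    """Map a CUDA version string such as "12.1.0" to a requirements file path.

    Hypothetical helper mirroring the Dockerfile's `cut -d. -f1` step.
    """
    # Accept "12", "12.1", "12.1.0"; reject empty or malformed strings.
    match = re.match(r"(\d+)(?:\.|$)", cuda_version.strip())
    if match is None:
        raise ValueError(f"unrecognized CUDA version: {cuda_version!r}")
    major = match.group(1)
    if major not in supported:
        raise ValueError(f"no requirements file for CUDA major version {major}")
    return f"code/requirements_public_train-cu{major}.txt"


print(requirements_file_for("12.1.0"))  # code/requirements_public_train-cu12.txt
```

One difference from the shell version: this sketch raises on an empty or unexpected version string, whereas the `cut` pipeline would pass an empty major version straight into the file name.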
6 changes: 3 additions & 3 deletions README.md
@@ -25,7 +25,7 @@ Target Python versions: **3.11, 3.12, 3.13**.
 Two minimal requirements files are provided:

 - `code/requirements_public_inference.txt` (Stim + PyTorch path)
-- `code/requirements_public_train.txt` (training path)
+- `code/requirements_public_train-cuXY.txt` (training path, where XY = 12 or 13)

 Install examples (virtual environment is optional but recommended):

@@ -41,8 +41,8 @@ export TORCH_CUDA=cu130
 # Inference-only (training install is a superset)
 pip install -r code/requirements_public_inference.txt

-# Training (includes inference deps)
-pip install -r code/requirements_public_train.txt
+# Training (includes inference deps, adjust to cu13 as appropriate)
+pip install -r code/requirements_public_train-cu12.txt

 bash code/scripts/check_python_compat.sh
 ```
177 changes: 73 additions & 104 deletions code/qec/dem_sampling.py
@@ -15,18 +15,15 @@
 """
 DEM sampling utilities for training data generation.

-When cuQuantum's cuStabilizer (BitMatrixSampler) is installed the sampling
-runs on the GPU via the cuST sparse sampler with optional CuPy zero-copy
-DLPack transfers. When cuST is absent or disabled (USE_CUSTAB=0) the module
-falls back to a pure-torch implementation.
+Sampling runs on the GPU via cuQuantum's cuStabilizer BitMatrixSampler with
+optional CuPy zero-copy DLPack transfers. cuquantum>=26.3.0 is required.

 This module provides the sampling functions needed by MemoryCircuitTorch
 to generate training batches from precomputed DEM matrices (H, p, A).
 """

 from __future__ import annotations

-import os
 import time
 from collections import deque

Expand All @@ -38,6 +35,9 @@
from cuquantum.stabilizer.simulator import Options
_CUSTAB_AVAILABLE = True
except ImportError:
# This should only happen if cuquantum is not installed. That is expected
# for some test environments that don't need DEM sampling, so handle that
# gracefully here.
BitMatrixSampler = None # type: ignore[misc, assignment]
Options = None # type: ignore[misc, assignment]
_CUSTAB_AVAILABLE = False
Expand All @@ -54,14 +54,14 @@ def _custab_available() -> bool:
return _CUSTAB_AVAILABLE


_cached_sampler: "BitMatrixSampler | None" = None
_cached_H_id: int | None = None
_cached_sampler = None
_cached_H: "torch.Tensor | None" = None
_cached_HT: "torch.Tensor | None" = None
_cached_max_shots: int = 0
_cached_device_id: int | None = None

_DEM_TIMINGS_S: deque[float] = deque(maxlen=200)
_use_custab_cached: bool | None = None
_custab_path_logged: bool = False
_fallback_path_logged: bool = False

_MIN_MAX_SHOTS = 1024

Expand All @@ -75,42 +75,75 @@ def get_dem_sampling_avg_ms() -> float:

def _reset_sampler_cache() -> None:
"""Reset the module-level sampler cache."""
global _cached_sampler, _cached_H_id, _cached_max_shots
global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots, _cached_device_id
_cached_sampler = None
_cached_H_id = None
_cached_H = None
_cached_HT = None
_cached_max_shots = 0
_cached_device_id = None


def custab_matrix_sampling(
H: torch.Tensor, p: torch.Tensor, batch_size: int, device_id: int = 0
def dem_sampling(
H: torch.Tensor,
p: torch.Tensor,
batch_size: int,
device_id: int | None = None
) -> torch.Tensor:
"""
Sample from a DEM using cuST BitMatrixSampler. H must be (errors, result) layout.
Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.

Args:
H: (2*num_detectors, num_errors) uint8 - Detector-error incidence matrix
p: (num_errors,) float32 - Per-error probabilities
batch_size: int - Number of samples to generate
device_id: Optional int - Device ID for cuST. If omitted, infer from
H.device when H is on CUDA.

When CuPy is available the entire pipeline stays on GPU:
torch CUDA -> CuPy (zero-copy DLPack) -> cuStabilizer -> CuPy -> torch (zero-copy DLPack)
Returns:
frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
"""
if not _CUSTAB_AVAILABLE or BitMatrixSampler is None or Options is None:
raise RuntimeError("custab_matrix_sampling requires cuquantum.stabilizer")
from cuquantum.stabilizer.dem_sampling import BitMatrixSampler
from cuquantum.stabilizer.simulator import Options

global _cached_sampler, _cached_H, _cached_HT, _cached_max_shots
global _cached_device_id, _custab_path_logged

if H.ndim != 2:
raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
if p.ndim != 1:
raise ValueError(f"p must be 1-D, got ndim={p.ndim}")
if H.shape[1] != p.shape[0]:
raise ValueError(f"H has {H.shape[1]} columns but p has {p.shape[0]} entries")

if device_id is None:
if H.is_cuda:
device_index = H.device.index
device_id = int(torch.cuda.current_device() if device_index is None else device_index)
else:
device_id = 0

gpu_native = _CUPY_AVAILABLE and H.is_cuda

global _cached_sampler, _cached_H_id, _cached_max_shots, _custab_path_logged
if _cached_H is not H:
_cached_HT = H.T
_cached_H = H
_cached_sampler = None
_cached_device_id = None

# id(H) is the tensor's memory address — fast but not content-based.
# Safe in training loops where H is a long-lived tensor; a content hash
# (like cuda-qx-g uses) would be more robust but slower on every call.
H_id = id(H)
need_new = (_cached_sampler is None or _cached_H_id != H_id or batch_size > _cached_max_shots)
need_new = (
_cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
)

if need_new:
max_shots = max(batch_size, _MIN_MAX_SHOTS)
gpu_native = _CUPY_AVAILABLE and H.is_cuda
if gpu_native:
import cupy as cp
H_in = cp.from_dlpack(H.detach())
p_in = cp.from_dlpack(p.detach().to(torch.float64))
with cp.cuda.Device(device_id):
H_in = cp.from_dlpack(_cached_HT.detach())
p_in = cp.from_dlpack(p.detach().to(torch.float64))
pkg = "cupy"
else:
H_in = H.detach().cpu().numpy().astype(np.uint8)
H_in = _cached_HT.detach().cpu().numpy().astype(np.uint8)
p_in = p.detach().cpu().numpy().astype(np.float64)
pkg = "numpy"
_cached_sampler = BitMatrixSampler(
@@ -120,98 +153,34 @@ def custab_matrix_sampling(
             package=pkg,
             options=Options(device_id=device_id),
         )
-        _cached_H_id = H_id
         _cached_max_shots = max_shots
+        _cached_device_id = device_id

-    _cached_sampler.sample(batch_size)
-
-    out = _cached_sampler.get_outcomes(bit_packed=False)
+    t0 = time.perf_counter()
+    if gpu_native:
+        import cupy as cp
+        with cp.cuda.Device(device_id):
+            _cached_sampler.sample(batch_size)
+            out = _cached_sampler.get_outcomes(bit_packed=False)
+    else:
+        _cached_sampler.sample(batch_size)
+        out = _cached_sampler.get_outcomes(bit_packed=False)
     if isinstance(out, np.ndarray):
         out = torch.as_tensor(out, device=H.device).to(dtype=torch.uint8)
     else:
         out = torch.from_dlpack(out).to(dtype=torch.uint8)
+    _DEM_TIMINGS_S.append(time.perf_counter() - t0)

     if not _custab_path_logged:
         print(
             f"---- [dem_sampling] Using cuST BitMatrixSampler path "
-            f"(max_shots={_cached_max_shots}, gpu_native={_CUPY_AVAILABLE})"
+            f"(max_shots={_cached_max_shots}, gpu_native={gpu_native}, device_id={device_id})"
         )
         _custab_path_logged = True

     return out


-def _use_custab() -> bool:
-    """Use cuST if available and not disabled by USE_CUSTAB=0. Cached after first call."""
-    global _use_custab_cached
-    if _use_custab_cached is None:
-        if not _CUSTAB_AVAILABLE:
-            _use_custab_cached = False
-        else:
-            v = os.environ.get("USE_CUSTAB", "1").strip().lower()
-            _use_custab_cached = v not in ("0", "false", "no", "off")
-    return _use_custab_cached
-
-
-def _reset_use_custab_cache() -> None:
-    """Reset the _use_custab cache (e.g. after changing USE_CUSTAB in tests)."""
-    global _use_custab_cached
-    _use_custab_cached = None
-
-
-def dem_sampling(
-    H: torch.Tensor, p: torch.Tensor, batch_size: int, device_id: int = 0
-) -> torch.Tensor:
-    """
-    Sample errors from a detector error model (DEM) using precomputed H and p matrices.
-    Uses cuST BitMatrixSampler when available; if cuST is not present or USE_CUSTAB=0,
-    uses the torch fallback.
-
-    Args:
-        H: (2*num_detectors, num_errors) uint8 - Detector-error incidence matrix
-        p: (num_errors,) float32 - Per-error probabilities
-        batch_size: int - Number of samples to generate
-        device_id: int - Device ID for cuST (ignored by torch path).
-
-    Returns:
-        frames_xz: (batch_size, 2*num_detectors) uint8 - Detector outcomes
-    """
-    if H.ndim != 2:
-        raise ValueError(f"H must be 2-D, got ndim={H.ndim}")
-    if p.ndim != 1:
-        raise ValueError(f"p must be 1-D, got ndim={p.ndim}")
-    if H.shape[1] != p.shape[0]:
-        raise ValueError(f"H has {H.shape[1]} columns but p has {p.shape[0]} entries")
-
-    global _fallback_path_logged
-    t0 = time.perf_counter()
-
-    if _use_custab():
-        # cuST expects (errors, result); dem_sampling H is (result, errors) -> pass H.T
-        out = custab_matrix_sampling(H.T, p, batch_size, device_id)
-        _DEM_TIMINGS_S.append(time.perf_counter() - t0)
-        return out
-
-    num_errors = int(H.shape[1])
-    device = H.device
-
-    # Sample errors according to their probabilities (independent Bernoulli)
-    rand_vals = torch.rand(batch_size, num_errors, device=device, dtype=torch.float32)
-    errors = (rand_vals < p[None, :]).to(torch.uint8)  # (batch_size, num_errors)
-
-    # Matrix multiply H @ errors^T to get detector outcomes
-    # H is (2*num_detectors, num_errors), errors is (batch_size, num_errors)
-    frames_xz = torch.matmul(errors.to(torch.float32), H.T.to(torch.float32))
-    frames_xz = frames_xz.to(torch.uint8) % 2  # Binary GF(2) arithmetic
-
-    _DEM_TIMINGS_S.append(time.perf_counter() - t0)
-    if not _fallback_path_logged:
-        print("Used fallback torch path for dem_sampling")
-        _fallback_path_logged = True
-
-    return frames_xz
-
-
 def measure_from_stacked_frames(
     frames_xz: torch.Tensor,
     meas_qubits: torch.Tensor,
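
Reviewer sketch (not part of this PR): with the torch fallback removed, environments without cuQuantum no longer have a way to produce samples. For sanity checks on tiny matrices, the semantics of the deleted fallback (independent Bernoulli draws per error, then parity of the incident errors over GF(2)) can be written in pure Python. The name `reference_dem_sample` and the list-based inputs are assumptions made for illustration; this is far too slow for training and makes no claim about matching cuStabilizer's random stream.

```python
import random


def reference_dem_sample(H, p, batch_size, seed=None):
    """Pure-Python reference for DEM sampling (hypothetical helper, not in the PR).

    H: list of rows, shape (num_outcomes, num_errors), entries 0 or 1.
    p: list of per-error probabilities, length num_errors.
    Returns a list of batch_size rows, each of length num_outcomes.
    """
    if any(len(row) != len(p) for row in H):
        raise ValueError("every row of H must have len(p) entries")
    rng = random.Random(seed)
    frames = []
    for _ in range(batch_size):
        # Independent Bernoulli draw per error mechanism.
        errors = [1 if rng.random() < prob else 0 for prob in p]
        # Each outcome is the parity (sum mod 2) of the errors incident to it.
        frames.append([sum(h & e for h, e in zip(row, errors)) % 2 for row in H])
    return frames


# Deterministic corner case: error 0 always fires, error 1 never fires.
H = [[1, 0], [1, 1], [0, 1]]
frames = reference_dem_sample(H, [1.0, 0.0], batch_size=4, seed=0)
print(frames[0])  # [1, 1, 0]
```

Deterministic probabilities (0.0 and 1.0) make the expected frames exact, which is convenient for shape and layout checks against the `(batch_size, 2*num_detectors)` output documented above.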
27 changes: 21 additions & 6 deletions code/qec/surface_code/homological_equivalence_torch.py
@@ -334,6 +334,10 @@ def _simplify_time_w1_step_nobreak(
     return err_out, syn_out


+_INT8_GEMM_OK: dict[str, bool] = {}
+_INT8_GEMM_WARNED: set[str] = set()
+
+
 def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tensor:
     """
     Weight reduction (parallel within disjoint stabilizer layers).
@@ -355,7 +359,11 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
     # at most 4 (stabilizer support size) and act1/act2 are bool→int8 with at
     # most L ones, so intermediate sums stay well within int8 range as long as
     # L < 128 (true for practical surface code distances).
-    _use_int8 = True
+    #
+    # _INT8_GEMM_OK caches per-device results so after one failure on a given
+    # device we skip int8 permanently (no repeated exceptions / warnings).
+    dev_key = str(cfg.device)
+    _use_int8 = _INT8_GEMM_OK.get(dev_key, True)

     for layer_idx in cache.layers:
         if layer_idx.numel() == 0:
@@ -372,12 +380,16 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
                 flip_mask = ((act2.to(torch.int8) @ masks_i8).to(torch.int32)
                              > 0) & (~set_to_zero_mask)
             except RuntimeError as exc:
-                warnings.warn(
-                    f"Int8 GEMM failed, falling back to float32 for weight reduction: {exc}",
-                    RuntimeWarning,
-                    stacklevel=2,
-                )
                 _use_int8 = False
+                _INT8_GEMM_OK[dev_key] = False
+                if dev_key not in _INT8_GEMM_WARNED:
+                    _INT8_GEMM_WARNED.add(dev_key)
+                    warnings.warn(
+                        f"Int8 GEMM failed on {dev_key}, permanently falling back to "
+                        f"float32 for weight reduction: {exc}",
+                        RuntimeWarning,
+                        stacklevel=2,
+                    )
                 masks_f = cache.support_masks.to(torch.float32).index_select(0, layer_idx)
                 error_counts = (cfg.to(torch.float32) @ masks_f.t()).to(torch.int32)
                 act1 = (error_counts == 4) | ((error_counts == 2) & (sizes.unsqueeze(0) == 2))
@@ -396,6 +408,9 @@ def _weight_reduction(cfg: torch.Tensor, cache: SpacelikeHECache) -> torch.Tenso
         cfg = cfg ^ flip_mask.to(cfg.dtype)
         cfg_i8 = cfg.to(torch.int8)

+    if _use_int8 and dev_key not in _INT8_GEMM_OK:
+        _INT8_GEMM_OK[dev_key] = True
+
     return cfg

