Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[![Docs](https://img.shields.io/badge/docs-latest-blue?style=flat&logo=materialformkdocs)](https://queracomputing.github.io/tsim/latest)
[![Coverage](https://img.shields.io/codecov/c/github/QuEraComputing/tsim?style=flat&logo=codecov)](https://codecov.io/gh/QuEraComputing/tsim)
[![arXiv](https://img.shields.io/badge/arXiv-2403.06777-b31b1b.svg?style=flat&logo=arxiv)](https://arxiv.org/abs/2403.06777)
[![Like This? Leave a star](https://img.shields.io/github/stars/QuEraComputing/tsim?style=flat&label=Like%20Tsim%3F%20Leave%20a%20star&color=yellow&logo=github)](https://github.com/QuEraComputing/tsim)

# tsim

A GPU-accelerated circuit sampler via ZX-calculus stabilizer rank decomposition.
Expand Down
12 changes: 2 additions & 10 deletions docs/demos/from_stim_to_tsim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,7 @@
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"TSIM's samplers also have a `batch_size` argument, which does not exist in STIM. This parameter controls the number of shots sampled in parallel.\n",
"\n",
"To achieve maximum performance, it may be required to increase the `batch_size`. Especially when running on a GPU, it is recommended to increase the `batch_size` until VRAM is exhausted."
]
"source": "TSIM's samplers also have a `batch_size` argument, which does not exist in STIM. This parameter controls the number of shots sampled in parallel.\n\nWhen `batch_size` is not specified, it is automatically chosen based on available device memory (GPU VRAM or system RAM). To achieve maximum performance on a GPU, you can also set `batch_size` explicitly."
},
{
"cell_type": "code",
Expand All @@ -139,11 +135,7 @@
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"TSIM uses `jax` just-in-time compilation, which is triggered upon first execution of the `sample` function. This means that subsequent calls to `sample` with the same parameters will be faster. Note that recompilation is triggered whenever the `batch_size` is changed.\n",
"\n",
"When `batch_size` is not specified, it is set to `shots` by default."
]
"source": "TSIM uses `jax` just-in-time compilation, which is triggered upon first execution of the `sample` function. This means that subsequent calls to `sample` with the same parameters will be faster. Note that recompilation is triggered whenever the `batch_size` is changed."
},
{
"cell_type": "markdown",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"jax>=0.4.38",
"lxml>=5.0.0",
"numpy>=1.25.0",
"psutil>=5.9.0",
"pyzx-param>=0.9.2",
"stim>=1.0.0",
]
Expand Down
49 changes: 41 additions & 8 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import psutil

from tsim.compile.evaluate import evaluate
from tsim.compile.pipeline import compile_program
Expand Down Expand Up @@ -167,10 +168,40 @@ def __init__(
self.circuit = circuit
self._num_detectors = prepared.num_detectors

def _peak_bytes_per_sample(self) -> int:
"""Estimate peak device memory per sample from compiled program structure."""
peak = 0
for component in self._program.components:
for circuit in component.compiled_scalar_graphs:
G = circuit.num_graphs
max_a = circuit.a_const_phases.shape[1]
max_b = circuit.b_term_types.shape[1]
max_c = circuit.c_const_bits_a.shape[1]
max_d = circuit.d_const_alpha.shape[1]
largest = max(max_a * 16, max_b * 4, max_c * 4, max_d * 16)
peak = max(peak, G * largest * 3)
return max(peak, 1)

def _estimate_batch_size(self) -> int:
    """Estimate the largest batch size that fits in available device memory.

    On GPU, free memory is derived from the JAX device's memory stats;
    otherwise the host's available RAM (via psutil) is used. Only half of
    the free memory is budgeted, as a safety margin. Always returns >= 1.
    """
    primary_device = jax.devices()[0]
    if primary_device.platform == "gpu":
        # Free VRAM = reported limit minus what is already allocated.
        # NOTE(review): memory_stats() keys can vary by backend; the
        # 8 GiB fallback applies when no limit is reported — confirm
        # that this is sensible for non-CUDA backends.
        stats = primary_device.memory_stats()
        in_use = stats.get("bytes_in_use", 0)
        free_bytes = stats.get("bytes_limit", 8 * 1024**3) - in_use
    else:
        free_bytes = psutil.virtual_memory().available

    budget = int(free_bytes * 0.5)  # conservative: use half of what is free
    return max(1, budget // self._peak_bytes_per_sample())

def _sample_batches(self, shots: int, batch_size: int | None = None) -> np.ndarray:
"""Sample in batches and concatenate results."""
if batch_size is None:
batch_size = shots
max_batch_size = self._estimate_batch_size()
num_batches = max(1, ceil(shots / max_batch_size))
batch_size = ceil(shots / num_batches)

batches: list[jax.Array] = []
for _ in range(ceil(shots / batch_size)):
Expand Down Expand Up @@ -252,14 +283,15 @@ def __init__(self, circuit: Circuit, *, seed: int | None = None):
"""
super().__init__(circuit, sample_detectors=False, mode="sequential", seed=seed)

def sample(self, shots: int, *, batch_size: int = 1024) -> np.ndarray:
def sample(self, shots: int, *, batch_size: int | None = None) -> np.ndarray:
"""Sample measurement outcomes from the circuit.

Args:
shots: The number of times to sample every measurement in the circuit.
batch_size: The number of samples to process in each batch. When using a
GPU, it is recommended to increase this value until VRAM is fully
utilized for maximum performance.
batch_size: The number of samples to process in each batch. Defaults to
None, which automatically chooses a batch size based on available
memory. When using a GPU, setting this explicitly can help fully
utilize VRAM for maximum performance.

Returns:
A numpy array containing the measurement samples.
Expand Down Expand Up @@ -330,9 +362,10 @@ def sample(

Args:
shots: The number of times to sample every detector in the circuit.
batch_size: The number of samples to process in each batch. When using a
GPU, it is recommended to increase this value until VRAM is fully
utilized for maximum performance.
batch_size: The number of samples to process in each batch. Defaults to
None, which automatically chooses a batch size based on available
memory. When using a GPU, setting this explicitly can help fully
utilize VRAM for maximum performance.
separate_observables: Defaults to False. When set to True, the return value
is a (detection_events, observable_flips) tuple instead of a flat
detection_events array.
Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import patch

import numpy as np
import pytest

from tsim.circuit import Circuit

Expand Down Expand Up @@ -49,3 +52,32 @@ def test_sampler_repr():
repr_str = repr(sampler)
assert "CompiledMeasurementSampler" in repr_str
assert "2 error channel bits" in repr_str


@pytest.mark.parametrize(
    ("shots", "expected_batch_size"),
    [(100, 25), (101, 26)],
)
def test_auto_batch(shots, expected_batch_size):
    """Auto-batching splits `shots` into equally sized batches.

    With the memory estimate mocked to 30, 100 shots -> ceil(100/30) = 4
    batches of ceil(100/4) = 25, and 101 shots -> 4 batches of 26.
    """
    c = Circuit("""
H 0
M 0
""")
    sampler = c.compile_sampler(seed=42)

    # Mock _estimate_batch_size to return a small value so auto-batching kicks in.
    with (
        patch.object(type(sampler), "_estimate_batch_size", return_value=30),
        # Wrap the underlying sampler so we can count calls while keeping
        # the real sampling behavior.
        patch.object(
            sampler._channel_sampler,
            "sample",
            wraps=sampler._channel_sampler.sample,
        ) as channel_sample,
    ):
        result = sampler.sample(shots)

    assert result.shape == (shots, 1)
    assert channel_sample.call_count == 4  # 4 batches of equal size
    assert [call.args[0] for call in channel_sample.call_args_list] == [
        expected_batch_size
    ] * 4
18 changes: 10 additions & 8 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading