In [None]:
import stim
from tsim.circuit import Circuit
import matplotlib.pyplot as plt
import numpy as np
import time

In [None]:
# Noisy distance-3 repetition-code memory experiment over 2 rounds.
# Every noise channel is a multiple of the single base error rate `p`.
p = 0.01
stim_circ = stim.Circuit.generated(
    "repetition_code:memory",  # alternative: "surface_code:rotated_memory_z"
    rounds=2,
    distance=3,
    before_round_data_depolarization=3 * p,
    after_reset_flip_probability=2 * p,
    after_clifford_depolarization=p,
    before_measure_flip_probability=p,
)
stim_circ.diagram("timeline-svg")

In [None]:
# Convert the stim circuit into a tsim Circuit and show its diagram both with
# and without noise. Jupyter only auto-renders a cell's LAST expression, so the
# noisy diagram needs an explicit display() call or it is silently dropped.
c = Circuit.from_stim_circuit(stim_circ)
display(c.diagram())
c.without_noise().diagram()

In [None]:
# Compile tsim's measurement sampler for the converted circuit and print it
# (print uses str(), not the rich notebook repr).
sampler = c.compile_sampler()
print(sampler)

In [None]:
# Draw a small batch of measurement samples from the tsim sampler.
n_samples = 200
# NOTE(review): the meaning of the second positional argument (100) is not
# visible here — presumably a batch size or similar; TODO pass it by keyword.
sampler.sample(n_samples, 100)

In [None]:
# Reference output: the same number of shots drawn from stim's own sampler.
stim_sampler = stim_circ.compile_sampler()
stim_sampler.sample(shots=n_samples)

In [None]:
# Detector sampler for the tsim circuit. Per the author's note it so far only
# supports observables, i.e. it is used here for observable outcomes rather
# than detection events — NOTE(review): confirm against tsim's documentation.
det_sampler = c.compile_detector_sampler()

In [None]:
# Sample from the tsim detector sampler (observables only, per the note on its
# construction) for comparison with stim's observable flips.
det_sampler.sample(n_samples)

In [None]:
# Stim reference: detection events and observable flips come back as two
# separate arrays; only the observable flips are shown, to match the tsim output.
stim_det_sampler = stim_circ.compile_detector_sampler()
dets, obs = stim_det_sampler.sample(shots=n_samples, separate_observables=True)
obs

In [None]:
# Compare the per-shot number of measurement outcomes equal to 1 between the
# tsim (ZX) sampler and stim's reference sampler.
# Assumes both samplers return arrays shaped (shots, measurements) — stim does;
# TODO confirm for tsim.
n_samples = 5_000
samples = sampler.sample(n_samples, 100)
stim_samples = stim_sampler.sample(n_samples)

zx_counts = np.count_nonzero(samples, axis=1)
stim_counts = np.count_nonzero(stim_samples, axis=1)

# Share one set of integer-centred bin edges between both histograms. Separate
# plt.hist calls would each pick their own edges from their own data range,
# misaligning the bars and making the comparison misleading.
max_count = int(max(zx_counts.max(initial=0), stim_counts.max(initial=0)))
bins = np.arange(max_count + 2) - 0.5

fig, ax = plt.subplots()
ax.hist(zx_counts, bins=bins, alpha=0.5, label="ZX")
ax.hist(stim_counts, bins=bins, alpha=0.5, label="Stim")
ax.set(
    xlabel="Measurement outcomes equal to 1 per shot",
    ylabel="Number of shots",
    title=f"tsim (ZX) vs Stim measurement samples ({n_samples} shots)",
)
ax.legend();

In [None]:
# Time both detector samplers and compare their observable flip rates.
n_samples = 5_000

start = time.perf_counter()
# tsim's detector sampler is used for observable outcomes only.
obs_samples = det_sampler.sample(n_samples)
duration_zx = time.perf_counter() - start

start = time.perf_counter()
# Stim's detector sampler returns detection events by default; ask for the
# observable flips separately so both sides measure the same quantity.
stim_det_events, obs_stim_samples = stim_det_sampler.sample(
    n_samples, separate_observables=True
)
duration_stim = time.perf_counter() - start

# Flip rate = fraction of observable bits that are set, over all shots.
print(f"(ZX)   Observable flip rate: {np.count_nonzero(obs_samples) / obs_samples.size}")
print(f"(Stim) Observable flip rate: {np.count_nonzero(obs_stim_samples) / obs_stim_samples.size}")

print("\nTime per sample:")
print(f"(ZX)   {duration_zx / n_samples:.2e} seconds")
print(f"(Stim) {duration_stim / n_samples:.2e} seconds")

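## Magic state distillation

Sample a $d=3$ magic-state-distillation circuit (loaded from `msd_circuits/d=3_X.stim`) after preparing $HTH$-rotated input states on five qubits.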

In [None]:
# Build a tsim circuit for d=3 magic-state distillation: reset every qubit,
# apply H·T·H to one qubit per block, then append the noiseless distillation
# circuit read from the stim file.
stim_circ = stim.Circuit.from_file("msd_circuits/d=3_X.stim")
num_qubits = stim_circ.num_qubits
# Qubits are treated as 5 equal-size blocks.
# NOTE(review): assumes num_qubits is a multiple of 5; any remainder is ignored.
block_size = num_qubits // 5

c = Circuit()
for i in range(num_qubits):
    c.r(i)

# Offset added to each block's start index: 6 when blocks have 7 qubits, 7 otherwise.
# NOTE(review): the origin of this offset is not documented here — TODO explain.
targets = np.arange(5) * block_size + (6 if block_size == 7 else 7)

for t in targets.tolist():  # tolist() yields plain Python ints for the tsim API
    c.h(t)
    c.t(t)
    c.h(t)

# To keep the file's noise channels, append `stim_circ` itself instead.
c.append_stim_circuit(stim_circ.without_noise())
c.diagram(labels=False)

In [None]:
# Compile a tsim detector sampler for the distillation circuit and print it.
# NOTE(review): this rebinds `sampler`, which earlier held the repetition-code
# measurement sampler; re-running the earlier histogram cell after this one
# would silently use this detector sampler instead.
sampler = c.compile_detector_sampler()
print(sampler)

In [None]:
# Draw samples from the distillation detector sampler (presumably 40 shots).
# NOTE(review): the second positional argument (20) is unexplained, like the
# 100 passed in the repetition-code cells; TODO document it or pass by keyword.
sampler.sample(40, 20)
