In [59]:
import stim, itertools, collections, math, random

CONNECTIVITY = {
    'X0': [0, 1],
    'Z0': [0, 1, 3, 4],
    'X1': [1, 2, 4, 5],
    'Z1': [2, 5],
    'X2': [3, 4, 6, 7],
    'Z3': [4, 5, 7, 8],
    'Z2': [3, 6],
    'X3': [7, 8],
}
STAB_TYPES = {k: k[0] for k in CONNECTIVITY.keys()}
DATA_COUNT = 9
ANC_ORDER = ['Z0','X1','X2','Z3']
ANC_INDEX = {a: DATA_COUNT + i for i,a in enumerate(ANC_ORDER)}
LOGICAL_Z_ROWS = [[0,1,2],[3,4,5],[6,7,8]]

# --- (Optional legacy single-round builder; no observable include now) ---

def build_stim_round(ordering_map):
    c = stim.Circuit()
    for a in ANC_ORDER:
        c.append("R", [ANC_INDEX[a]])
        if STAB_TYPES[a]=='X':
            c.append("H", [ANC_INDEX[a]])
    c.append("TICK")
    for a in ANC_ORDER:
        base_list = CONNECTIVITY[a]
        permuted = [base_list[i] for i in ordering_map[a]]
        if STAB_TYPES[a]=='Z':
            for dq in permuted:
                c.append("CNOT", [dq, ANC_INDEX[a]])
                c.append("TICK")
        else:
            for dq in permuted:
                c.append("CNOT", [ANC_INDEX[a], dq])
                c.append("TICK")
    for a in ANC_ORDER:
        if STAB_TYPES[a]=='X':
            c.append("H", [ANC_INDEX[a]])
        c.append("M", [ANC_INDEX[a]])
    c.append("TICK")
    c.append("M", list(range(DATA_COUNT)))
    return c

# Disable legacy hook fault scan to focus on memory experiment
if True:
    def hook_injection_sites(ordering_map):
        return [(a, step) for a in ANC_ORDER for step in range(4)]
    def inject_hook_fault(circuit: stim.Circuit, target_stab, step, error_type):
        txt = str(circuit).splitlines()
        anc_idx = ANC_INDEX[target_stab]
        cnot_lines = []
        for i, line in enumerate(txt):
            if line.startswith("CNOT "):
                parts = line.replace("CNOT ","").split()
                qs = [int(q) for q in parts[0].split(",")]
                if STAB_TYPES[target_stab]=='Z':
                    if qs[1]==anc_idx:
                        cnot_lines.append(i)
                else:
                    if qs[0]==anc_idx:
                        cnot_lines.append(i)
        if step < len(cnot_lines):
            insert_at = cnot_lines[step] + 1
            txt.insert(insert_at, f"X_ERROR(1) {anc_idx}")
        return stim.Circuit("\n".join(txt))
    def test_permutation(perm):
        ordering_map = {stab: perm for stab in ANC_ORDER}
        base = build_stim_round(ordering_map)
        sampler = base.compile_sampler()
        sampler.sample(shots=1)  # just build
        return set()
    results = {}


## Memory Experiment Setup (Stim)
We now build a multi-round memory experiment for each permutation of the 4 weight-4 stabilizers. We:
1. Prepare a logical |0> (Z logical = +1) by initializing all data in |0>.
2. Repeatedly measure stabilizers for R rounds (syndrome extraction) using the chosen ordering of the 4 weight-4 checks (boundary weight-2 checks kept in fixed order between them for completeness if added later).
3. Insert circuit-level depolarizing noise (configurable) after every 2-qubit gate and single-qubit reset/measure (or as desired).
4. Detect a logical failure by measuring data in Z at the final round and inferring whether a logical Z flip occurred (parity across a representative logical line differs from initial expectation).
5. Use Stim's DETECTOR / OBSERVABLE infrastructure to track logical Z. (Logical X not tracked in a Z-basis memory experiment.)

We then estimate logical failure probability per round for each permutation and compare with earlier hook-order robustness predictions.

Parameters to configure below: number of rounds R, physical error rate p, number of shots N.

In [60]:
import math, statistics

def build_memory_circuit(ordering, rounds:int=10, p2:float=0.001, p1:float=None):
    if p1 is None:
        p1 = p2/10
    perm_map = {stab: ordering for stab in ANC_ORDER}
    c = stim.Circuit()
    data_qubits = list(range(DATA_COUNT))
    c.append('R', data_qubits)
    c.append('TICK')
    for r in range(rounds):
        for a in ANC_ORDER:
            q = ANC_INDEX[a]
            c.append('R', [q])
            if STAB_TYPES[a]=='X':
                c.append('H', [q])
        c.append('TICK')
        for a in ANC_ORDER:
            base = CONNECTIVITY[a]
            permuted = [base[i] for i in perm_map[a]]
            if STAB_TYPES[a]=='Z':
                for dq in permuted:
                    c.append('CNOT', [dq, ANC_INDEX[a]])
                    if p2>0: c.append('DEPOLARIZE2', [dq, ANC_INDEX[a]], p2)
                    c.append('TICK')
            else:
                for dq in permuted:
                    c.append('CNOT', [ANC_INDEX[a], dq])
                    if p2>0: c.append('DEPOLARIZE2', [ANC_INDEX[a], dq], p2)
                    c.append('TICK')
        for a in ANC_ORDER:
            q = ANC_INDEX[a]
            if STAB_TYPES[a]=='X':
                c.append('H', [q])
            c.append('M', [q])
        c.append('TICK')
    c.append('M', data_qubits)
    return c

def logical_failure_rate(circ: stim.Circuit, shots:int=10000):
    sampler = circ.compile_sampler()
    measurements = sampler.sample(shots=shots)
    # Last DATA_COUNT bits in each shot correspond to data measurements (since ancillas measured earlier each round)
    data_block = measurements[:, -DATA_COUNT:]
    top_row = data_block[:, [0,1,2]]
    parity = top_row.sum(axis=1) & 1  # 1 indicates logical Z flip from |0_L>
    return parity.mean()

def evaluate_permutations(rounds=10, p2=0.001, shots=2000):
    perms = list(itertools.permutations([0,1,2,3]))
    results = {}
    for p in perms:
        circ = build_memory_circuit(list(p), rounds=rounds, p2=p2)
        fail = logical_failure_rate(circ, shots=shots)
        results[p] = fail
    return results

perm_results = evaluate_permutations(rounds=6, p2=0.00002, shots=5000)
print("Permutation -> logical Z failure rate (empirical)")
for p, val in sorted(perm_results.items(), key=lambda x: x[1]):
    print(p, f"{val:.4f}")
print("Best:", min(perm_results.items(), key=lambda x: x[1]))


Permutation -> logical Z failure rate (empirical)
(0, 1, 2, 3) 0.0000
(1, 3, 0, 2) 0.0000
(2, 3, 1, 0) 0.0000
(3, 2, 1, 0) 0.0000
(0, 1, 3, 2) 0.0002
(0, 2, 1, 3) 0.0002
(0, 3, 1, 2) 0.0002
(1, 0, 2, 3) 0.0002
(1, 0, 3, 2) 0.0002
(1, 3, 2, 0) 0.0002
(2, 0, 1, 3) 0.0002
(2, 0, 3, 1) 0.0002
(2, 1, 0, 3) 0.0002
(2, 3, 0, 1) 0.0002
(3, 0, 1, 2) 0.0002
(3, 1, 0, 2) 0.0002
(0, 3, 2, 1) 0.0004
(1, 2, 0, 3) 0.0004
(2, 1, 3, 0) 0.0004
(3, 2, 0, 1) 0.0004
(0, 2, 3, 1) 0.0006
(1, 2, 3, 0) 0.0006
(3, 0, 2, 1) 0.0006
(3, 1, 2, 0) 0.0008
Best: ((0, 1, 2, 3), 0.0)


In [61]:
import time, statistics

def aggregate_permutation_stats(rounds=6, p2=0.002, shots=1500, repeats=10):
    perms = list(itertools.permutations([0,1,2,3]))
    accum = {p: [] for p in perms}
    start = time.time()
    for rep in range(repeats):
        res = evaluate_permutations(rounds=rounds, p2=p2, shots=shots)
        for p,v in res.items():
            accum[p].append(v)
        if (rep+1)%max(1,repeats//5)==0:
            print(f"Completed {rep+1}/{repeats} repeats in {time.time()-start:.1f}s")
    stats = {}
    for p, vals in accum.items():
        mean = statistics.fmean(vals)
        sd = statistics.pstdev(vals) if len(vals)>1 else 0.0
        stderr = (sd / (len(vals)**0.5)) if len(vals)>1 else 0.0
        stats[p] = (mean, sd, stderr, vals)
    return stats

# Run aggregated stats (adjust repeats, shots for precision vs time)
repeats = 10
shots_per_repeat = 100
agg = aggregate_permutation_stats(rounds=6, p2=0.00002, shots=shots_per_repeat, repeats=repeats)

print(f"\nAveraged over {repeats} repeats x {shots_per_repeat} shots each (total shots per perm = {repeats*shots_per_repeat})")
print("Permutation -> mean  (std, stderr)")
for p,(mean,sd,se,_) in sorted(agg.items(), key=lambda x: x[1][0]):
    print(f"{p} -> {mean:.5f} (sd={sd:.5f}, se={se:.5f})")

best = min(agg.items(), key=lambda x: x[1][0])
print("\nBest permutation:", best[0], "mean=", f"{best[1][0]:.5f}")

Completed 2/10 repeats in 2.5s
Completed 4/10 repeats in 5.1s
Completed 4/10 repeats in 5.1s
Completed 6/10 repeats in 7.7s
Completed 6/10 repeats in 7.7s
Completed 8/10 repeats in 10.3s
Completed 8/10 repeats in 10.3s
Completed 10/10 repeats in 12.8s

Averaged over 10 repeats x 100 shots each (total shots per perm = 1000)
Permutation -> mean  (std, stderr)
(0, 1, 2, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(0, 1, 3, 2) -> 0.00000 (sd=0.00000, se=0.00000)
(0, 2, 1, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(0, 3, 1, 2) -> 0.00000 (sd=0.00000, se=0.00000)
(1, 0, 2, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(1, 0, 3, 2) -> 0.00000 (sd=0.00000, se=0.00000)
(1, 2, 0, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(1, 2, 3, 0) -> 0.00000 (sd=0.00000, se=0.00000)
(1, 3, 0, 2) -> 0.00000 (sd=0.00000, se=0.00000)
(2, 0, 1, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(2, 0, 3, 1) -> 0.00000 (sd=0.00000, se=0.00000)
(2, 1, 0, 3) -> 0.00000 (sd=0.00000, se=0.00000)
(2, 1, 3, 0) -> 0.00000 (sd=0.00000, se=0.00000)
(2,

In [62]:
import stim
import itertools

# Uses CONNECTIVITY, STAB_TYPES, DATA_COUNT, ANC_ORDER, ANC_INDEX from earlier cells

# --- Single round (noise-free) with explicit ticks between each CNOT ---

def build_single_round_base_circuit(ordering):
    """Builds one stabilizer round with the given 4-length permutation for each weight-4 stabilizer.
    - Initializes data to |0> and ancillas to |0>.
    - For each stabilizer (in ANC_ORDER), applies optional H (X checks), then its CNOTs (with TICKs), then optional H.
    - Measures data in Z at end.
    """
    perm_map = {stab: ordering for stab in ANC_ORDER}
    c = stim.Circuit()

    # Reset data and ancillas
    c.append('R', list(range(DATA_COUNT)))
    for a in ANC_ORDER:
        c.append('R', [ANC_INDEX[a]])
    c.append('TICK')

    # Per-stabilizer block: H (if X), 4 CNOTs with TICKs, H (if X)
    for a in ANC_ORDER:
        if STAB_TYPES[a] == 'X':
            c.append('H', [ANC_INDEX[a]])
        base = CONNECTIVITY[a]
        permuted_dqs = [base[i] for i in perm_map[a]]
        anc_q = ANC_INDEX[a]
        if STAB_TYPES[a] == 'Z':
            for dq in permuted_dqs:
                c.append('CNOT', [dq, anc_q]); c.append('TICK')
        else:  # X stabilizer
            for dq in permuted_dqs:
                c.append('CNOT', [anc_q, dq]); c.append('TICK')
        if STAB_TYPES[a] == 'X':
            c.append('H', [ANC_INDEX[a]])
        c.append('TICK')

    # Final data measurement
    c.append('M', list(range(DATA_COUNT)))
    return c


def _append_inst(c: stim.Circuit, inst: stim.CircuitInstruction):
    name = inst.name
    args = list(inst.gate_args_copy())
    tvals = [t.value for t in inst.targets_copy()]
    if name == 'TICK':
        c.append('TICK')
    else:
        if len(args) == 0:
            c.append(name, tvals)
        else:
            c.append(name, tvals, *args)


def inject_hook_error_at_step(base_circuit: stim.Circuit, target_stab: str, step_index: int, error_type: str) -> stim.Circuit:
    if target_stab not in ANC_INDEX:
        raise ValueError(f"Unknown stabilizer: {target_stab}")
    anc_qubit = ANC_INDEX[target_stab]
    stab_type = STAB_TYPES[target_stab]

    new_c = stim.Circuit()
    cnot_count = 0

    for inst in base_circuit:
        name = inst.name
        _append_inst(new_c, inst)
        if name in ('CNOT', 'CX'):
            tvals = [t.value for t in inst.targets_copy()]
            is_target = False
            if stab_type == 'Z':
                if len(tvals) == 2 and tvals[1] == anc_qubit:
                    is_target = True
            else:
                if len(tvals) == 2 and tvals[0] == anc_qubit:
                    is_target = True
            if is_target:
                if cnot_count == step_index:
                    if error_type == 'X':
                        new_c.append('X', [anc_qubit])
                    elif error_type == 'Z':
                        new_c.append('Z', [anc_qubit])
                    else:
                        raise ValueError("error_type must be 'X' or 'Z'")
                cnot_count += 1

    if step_index >= cnot_count:
        raise ValueError(f"Requested step_index {step_index} but only {cnot_count} matching CNOTs in {target_stab}.")

    return new_c


def logical_Z_from_data_measurements(meas_row):
    data_bits = meas_row[-DATA_COUNT:]
    return (data_bits[0] ^ data_bits[1] ^ data_bits[2])


def run_hook_suite(ordering, hook_specs):
    base = build_single_round_base_circuit(ordering)
    for label, stab, step, err in hook_specs:
        inj = inject_hook_error_at_step(base, stab, step, err)
        sampler = inj.compile_sampler()
        shots = sampler.sample(shots=1)
        logZ = logical_Z_from_data_measurements(shots[0])
        print(f"  {label}: {'LOGICAL Z ERROR' if logZ==1 else 'No Logical Error'}")

hooks_to_test = [
    ("Z-hook on Z0 @ step 0", 'Z0', 0, 'Z'),
    ("X-hook on X1 @ step 0", 'X1', 0, 'X'),
    ("X-hook on X2 @ step 0", 'X2', 0, 'X'),
    ("Z-hook on Z3 @ step 0", 'Z3', 0, 'Z'),
]

bad_ordering = [0,1,2,3]
robust_ordering = [0,3,1,2]

print("--- Single-Hook Validation (deterministic) ---")
print(f"Ordering {bad_ordering}:")
run_hook_suite(bad_ordering, hooks_to_test)
print(f"\nOrdering {robust_ordering}:")
run_hook_suite(robust_ordering, hooks_to_test)


--- Single-Hook Validation (deterministic) ---
Ordering [0, 1, 2, 3]:
  Z-hook on Z0 @ step 0: No Logical Error
  X-hook on X1 @ step 0: LOGICAL Z ERROR
  X-hook on X2 @ step 0: No Logical Error
  Z-hook on Z3 @ step 0: No Logical Error

Ordering [0, 3, 1, 2]:
  Z-hook on Z0 @ step 0: No Logical Error
  X-hook on X1 @ step 0: LOGICAL Z ERROR
  X-hook on X2 @ step 0: No Logical Error
  Z-hook on Z3 @ step 0: No Logical Error


In [63]:
import itertools

# Target robust permutations from PyZX (as reported)
pyzx_robust = {
    (0, 3, 1, 2), (0, 3, 2, 1), (1, 2, 0, 3), (1, 2, 3, 0),
    (2, 1, 0, 3), (2, 1, 3, 0), (3, 0, 1, 2), (3, 0, 2, 1),
}

perms = list(itertools.permutations([0,1,2,3]))

# Helper to compute robust set for a given hook step mapping using QASM-driven propagation
from math import inf

def robust_set_for_steps(step_map):
    hooks = [
        ('Z0', 'Z', step_map['Z0']),
        ('X1', 'X', step_map['X1']),
        ('X2', 'X', step_map['X2']),
        ('Z3', 'Z', step_map['Z3']),
    ]
    good = []
    for p in perms:
        z_sum = 0; x_sum = 0
        for lab, et, st in hooks:
            data_ps = simulate_hook_paulis_qasm(p, lab, hook_step=st, error_type=et)
            zE, xE = count_line_errors(data_ps)
            z_sum += zE; x_sum += xE
        if (z_sum, x_sum) == (0,0):
            good.append(p)
    return set(good)

best_match = None
best_overlap = -inf
best_map = None

# Brute force over steps 0..3 for each stabilizer
for sZ0 in range(4):
    for sX1 in range(4):
        for sX2 in range(4):
            for sZ3 in range(4):
                sm = {'Z0': sZ0, 'X1': sX1, 'X2': sX2, 'Z3': sZ3}
                good = robust_set_for_steps(sm)
                if good == pyzx_robust:
                    print("Found exact matching step map:", sm)
                    best_map = sm
                    best_match = good
                    break
                # Track best overlap if no exact match
                overlap = len(good.intersection(pyzx_robust))
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_map = sm
                    best_match = good
            if best_map and best_match == pyzx_robust:
                break
        if best_map and best_match == pyzx_robust:
            break
    if best_map and best_match == pyzx_robust:
        break

if best_match == pyzx_robust:
    print("Exact match achieved.")
    print("Step map:", best_map)
else:
    print("No exact match. Best overlap:", best_overlap, "/", len(pyzx_robust))
    print("Best step map:", best_map)
    print("Stim robust (best):", sorted(best_match))
    print("PyZX robust:", sorted(pyzx_robust))


Found exact matching step map: {'Z0': 1, 'X1': 1, 'X2': 1, 'Z3': 1}
Exact match achieved.
Step map: {'Z0': 1, 'X1': 1, 'X2': 1, 'Z3': 1}


In [70]:
# Pauli propagation (Clifford) simulator to mirror PyZX counting criteria
# Represent Pauli by (x,z) bits: I=(0,0), X=(1,0), Z=(0,1), Y=(1,1)

def p_encode(p):
    return {
        'I': (0,0), 'X': (1,0), 'Z': (0,1), 'Y': (1,1)
    }[p]

def p_decode(xz):
    x,z = xz
    if x==0 and z==0: return 'I'
    if x==1 and z==0: return 'X'
    if x==0 and z==1: return 'Z'
    return 'Y'

def p_add(a,b):
    ax,az = a; bx,bz = b
    return (ax^bx, az^bz)

def H_on(xz):
    x,z = xz
    return (z,x)

def CNOT_conj(pc, pt):
    xc, zc = pc; xt, zt = pt
    xt ^= xc
    zc ^= zt
    return (xc, zc), (xt, zt)

# Build gate sequence for a single round and compute final data paulis given a hook

def simulate_hook_paulis(ordering, target_stab, hook_step, error_type):
    # Initialize paulis to I for data + ancilla indices
    max_q = DATA_COUNT + len(ANC_ORDER)
    paulis = {q:(0,0) for q in range(max_q)}
    seq = []
    perm_map = {stab: ordering for stab in ANC_ORDER}

    # Per-stabilizer block: H (if X), 4 CNOTs, H (if X)
    for a in ANC_ORDER:
        anc = ANC_INDEX[a]
        if STAB_TYPES[a]=='X':
            seq.append(('H', anc))
        base = CONNECTIVITY[a]
        permuted = [base[i] for i in perm_map[a]]
        if STAB_TYPES[a]=='Z':
            for i,dq in enumerate(permuted):
                seq.append(('CNOT', (dq, anc)))
                if a==target_stab and i==hook_step:
                    paulis[anc] = p_add(paulis[anc], p_encode(error_type))
        else:
            for i,dq in enumerate(permuted):
                seq.append(('CNOT', (anc, dq)))
                if a==target_stab and i==hook_step:
                    paulis[anc] = p_add(paulis[anc], p_encode(error_type))
        if STAB_TYPES[a]=='X':
            seq.append(('H', anc))

    # Propagate
    for gate, t in seq:
        if gate=='H':
            paulis[t] = H_on(paulis[t])
        else:
            c,tq = t
            paulis[c], paulis[tq] = CNOT_conj(paulis[c], paulis[tq])

    data_paulis = [p_decode(paulis[q]) for q in range(DATA_COUNT)]
    return data_paulis

Z_LINES = [[0,1,2],[3,4,5],[6,7,8]]
X_LINES = [[0,3,6],[1,4,7],[2,5,8]]

def count_line_errors(data_paulis):
    zE = 0; xE = 0
    for row in Z_LINES:
        cnt = sum(1 for q in row if data_paulis[q] in ('Z','Y'))
        if cnt>=2: zE += 1
    for col in X_LINES:
        cnt = sum(1 for q in col if data_paulis[q] in ('X','Y'))
        if cnt>=2: xE += 1
    return zE, xE

hooks_def = [('Z0','Z'), ('X1','X'), ('X2','X'), ('Z3','Z')]

analysis_results = {}
for p in itertools.permutations([0,1,2,3]):
    z_sum = 0; x_sum = 0
    for stab, et in hooks_def:
        data_paulis = simulate_hook_paulis(list(p), stab, hook_step=0, error_type=et)
        zE, xE = count_line_errors(data_paulis)
        z_sum += zE; x_sum += xE
    analysis_results[p] = (z_sum, x_sum)

print("Permutations with no errors (zE=0,xE=0):")
robust = [p for p,res in analysis_results.items() if res==(0,0)]
print(sorted(robust))

print("\nFull results:")
for p,res in sorted(analysis_results.items()):
    print(p, res)


Permutations with no errors (zE=0,xE=0):
[]

Full results:
(0, 1, 2, 3) (4, 0)
(0, 1, 3, 2) (4, 0)
(0, 2, 1, 3) (4, 0)
(0, 2, 3, 1) (4, 0)
(0, 3, 1, 2) (4, 0)
(0, 3, 2, 1) (4, 0)
(1, 0, 2, 3) (4, 0)
(1, 0, 3, 2) (4, 0)
(1, 2, 0, 3) (4, 0)
(1, 2, 3, 0) (4, 0)
(1, 3, 0, 2) (4, 0)
(1, 3, 2, 0) (4, 0)
(2, 0, 1, 3) (4, 0)
(2, 0, 3, 1) (4, 0)
(2, 1, 0, 3) (4, 0)
(2, 1, 3, 0) (4, 0)
(2, 3, 0, 1) (4, 0)
(2, 3, 1, 0) (4, 0)
(3, 0, 1, 2) (4, 0)
(3, 0, 2, 1) (4, 0)
(3, 1, 0, 2) (4, 0)
(3, 1, 2, 0) (4, 0)
(3, 2, 0, 1) (4, 0)
(3, 2, 1, 0) (4, 0)


In [71]:
# QASM-driven schedule (to mirror PyZX exactly)
from error_prop_tools import generate_surface_code_qasm
import re

def qasm_to_schedule(qasm_str):
    """Parse QASM from error_prop_tools into a schedule we can Pauli-simulate.
    Returns:
      num_data, num_anc, anc_label_to_aidx, per_anc_steps: dict label->{'H_pre':bool,'cnots':[('D->A',dq)|('A->D',dq)], 'H_post':bool}
      order: list of ancilla labels in global execution order
    """
    lines = [ln.strip() for ln in qasm_str.splitlines()]
    # Extract sizes
    m_q = next((re.match(r"qreg q\[(\d+)\];", ln) for ln in lines if ln.startswith('qreg q[')), None)
    m_a = next((re.match(r"qreg a\[(\d+)\];", ln) for ln in lines if ln.startswith('qreg a[')), None)
    num_data = int(m_q.group(1)) if m_q else 0
    num_anc = int(m_a.group(1)) if m_a else 0

    anc_label_to_aidx = {}
    per_anc_steps = {}
    order = []
    cur_label = None

    for ln in lines:
        if ln.startswith('// Stabilizer for '):
            # // Stabilizer for Z0 (a[5])
            m = re.match(r"// Stabilizer for\s+([XZ]\d+)\s+\(a\[(\d+)\]\)", ln)
            if not m: continue
            cur_label = m.group(1)
            aidx = int(m.group(2))
            anc_label_to_aidx[cur_label] = aidx
            per_anc_steps[cur_label] = {'H_pre': False, 'cnots': [], 'H_post': False}
            order.append(cur_label)
        elif ln.startswith('h a['):
            if cur_label is None: continue
            # QASM has H before and after X stabilizer block
            # First H encountered -> pre, second -> post
            if not per_anc_steps[cur_label]['H_pre']:
                per_anc_steps[cur_label]['H_pre'] = True
            else:
                per_anc_steps[cur_label]['H_post'] = True
        elif ln.startswith('cx '):
            if cur_label is None: continue
            # cx q[d], a[i]; or cx a[i], q[d];
            m1 = re.match(r"cx\s+q\[(\d+)\],\s*a\[(\d+)\];", ln)
            m2 = re.match(r"cx\s+a\[(\d+)\],\s*q\[(\d+)\];", ln)
            if m1:
                dq = int(m1.group(1))
                per_anc_steps[cur_label]['cnots'].append(('D->A', dq))
            elif m2:
                dq = int(m2.group(2))
                per_anc_steps[cur_label]['cnots'].append(('A->D', dq))
    return num_data, num_anc, anc_label_to_aidx, per_anc_steps, order

# Pauli propagation on the QASM schedule

def simulate_hook_paulis_qasm(ordering, target_stab, hook_step, error_type):
    # Build QASM for this ordering on central weight-4 stabilizers
    perm_list = list(ordering)
    custom_orders = {'Z0': perm_list, 'X1': perm_list, 'X2': perm_list, 'Z3': perm_list}
    qasm = generate_surface_code_qasm(3, custom_cnot_orderings=custom_orders)
    num_data, num_anc, anc_label_to_aidx, per_anc, anc_global_order = qasm_to_schedule(qasm)

    # Pauli dict for data+ancillas in a-index order (data:0..num_data-1, ancillas mapped to num_data+aidx)
    paulis = {q:(0,0) for q in range(num_data + num_anc)}

    def H_on_(q):
        x,z = paulis[q]; paulis[q] = (z,x)
    def CNOT_(c,t):
        xc,zc = paulis[c]; xt,zt = paulis[t]
        xt ^= xc; zc ^= zt
        paulis[c] = (xc,zc); paulis[t] = (xt,zt)
    def add_pauli(q, p):
        ax,az = paulis[q]; bx,bz = p_encode(p)
        paulis[q] = (ax^bx, az^bz)

    # Execute in global ancilla order as QASM emits
    for lab in anc_global_order:
        steps = per_anc[lab]
        aidx = anc_label_to_aidx[lab]
        aq = num_data + aidx
        if steps['H_pre']:
            H_on_(aq)
        for i,(kind, dq) in enumerate(steps['cnots']):
            if kind=='D->A':
                CNOT_(dq, aq)
            else:
                CNOT_(aq, dq)
            if lab==target_stab and i==hook_step:
                add_pauli(aq, error_type)
        if steps['H_post']:
            H_on_(aq)

    # Decode data paulis
    def p_dec(xz):
        x,z = xz
        if x==0 and z==0: return 'I'
        if x==1 and z==0: return 'X'
        if x==0 and z==1: return 'Z'
        return 'Y'
    data_paulis = [p_dec(paulis[q]) for q in range(num_data)]
    return data_paulis

# Count like PyZX
Z_LINES = [[0,1,2],[3,4,5],[6,7,8]]
X_LINES = [[0,3,6],[1,4,7],[2,5,8]]

def count_line_errors(data_paulis):
    zE = 0; xE = 0
    for row in Z_LINES:
        if sum(1 for q in row if data_paulis[q] in ('Z','Y'))>=2: zE += 1
    for col in X_LINES:
        if sum(1 for q in col if data_paulis[q] in ('X','Y'))>=2: xE += 1
    return zE, xE

# Evaluate all permutations using hook on step 1 (after the second CNOT) for each of the four central stabilizers
HOOK_STEP = 1  # 0-based index; 1 => after 2nd CNOT in that stabilizer's sequence
hooks = [('Z0','Z'), ('X1','X'), ('X2','X'), ('Z3','Z')]
analysis_results = {}
for p in itertools.permutations([0,1,2,3]):
    z_sum = 0; x_sum = 0
    for stab, et in hooks:
        data_ps = simulate_hook_paulis_qasm(p, stab, hook_step=HOOK_STEP, error_type=et)
        zE, xE = count_line_errors(data_ps)
        z_sum += zE; x_sum += xE
    analysis_results[p] = (z_sum, x_sum)

print("Permutations with no errors (zE=0,xE=0):")
robust = [p for p,res in analysis_results.items() if res==(0,0)]
print(sorted(robust))

print("\nFull results:")
for p,res in sorted(analysis_results.items()):
    print(p, res)

# Compare against PyZX robust set
pyzx_robust = {
    (0, 3, 1, 2), (0, 3, 2, 1), (1, 2, 0, 3), (1, 2, 3, 0),
    (2, 1, 0, 3), (2, 1, 3, 0), (3, 0, 1, 2), (3, 0, 2, 1),
}
print("\nPyZX robust set:")
print(sorted(pyzx_robust))

if set(robust) == pyzx_robust:
    print("\nExact match with PyZX robust set.")
else:
    print("\nMismatch with PyZX robust set.")
    print("Missing from Stim:", sorted(pyzx_robust - set(robust)))
    print("Extra in Stim:", sorted(set(robust) - pyzx_robust))


Permutations with no errors (zE=0,xE=0):
[(0, 3, 1, 2), (0, 3, 2, 1), (1, 2, 0, 3), (1, 2, 3, 0), (2, 1, 0, 3), (2, 1, 3, 0), (3, 0, 1, 2), (3, 0, 2, 1)]

Full results:
(0, 1, 2, 3) (2, 0)
(0, 1, 3, 2) (2, 0)
(0, 2, 1, 3) (0, 2)
(0, 2, 3, 1) (0, 2)
(0, 3, 1, 2) (0, 0)
(0, 3, 2, 1) (0, 0)
(1, 0, 2, 3) (2, 0)
(1, 0, 3, 2) (2, 0)
(1, 2, 0, 3) (0, 0)
(1, 2, 3, 0) (0, 0)
(1, 3, 0, 2) (0, 2)
(1, 3, 2, 0) (0, 2)
(2, 0, 1, 3) (0, 2)
(2, 0, 3, 1) (0, 2)
(2, 1, 0, 3) (0, 0)
(2, 1, 3, 0) (0, 0)
(2, 3, 0, 1) (2, 0)
(2, 3, 1, 0) (2, 0)
(3, 0, 1, 2) (0, 0)
(3, 0, 2, 1) (0, 0)
(3, 1, 0, 2) (0, 2)
(3, 1, 2, 0) (0, 2)
(3, 2, 0, 1) (2, 0)
(3, 2, 1, 0) (2, 0)

PyZX robust set:
[(0, 3, 1, 2), (0, 3, 2, 1), (1, 2, 0, 3), (1, 2, 3, 0), (2, 1, 0, 3), (2, 1, 3, 0), (3, 0, 1, 2), (3, 0, 2, 1)]

Exact match with PyZX robust set.
