# In [None]:
"""
surface_code_z_matching.py

Construct a PyMatching Matching graph for correcting ONLY X errors (i.e. decoding
Z-check detectors) for a rotated surface code of distance d, with T rounds.

Modelled error sources:
 - p_meas: measurement flip (ancilla readout) -> time-like edges between same check across rounds.
 - p_1q: single-qubit X error probability per data qubit per round (from 1q gates / idle).
 - p_cx: per-CX probability that the CX produces a data-X on the involved data qubit.
 - p_cx_corr: per-pair-CX-correlated-X probability (two neighboring data qubits simultaneously get X)
                 — used to add extra edges between detector nodes corresponding to correlated faults.

Outputs:
 - M: pymatching.Matching for Z-check detectors only
 - detmap: mapping (face_r, face_c, t) -> detector index
 - helper data structures for translating syndromes

Notes/assumptions:
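 - Rotated surface code layout (faces are checks). Z-checks are the faces with (r+c)%2 == 0.
   Only the (d-1)x(d-1) interior faces are modelled; the weight-2 boundary checks of a
   standard rotated surface code are not included, so boundary behaviour is approximate.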
 - We place detectors at round boundaries t = 0..T (so there are T+1 detector layers).
 - Data errors occurring "during" a round are attached to the detectors at that round boundary.
 - CX schedule: each Z-ancilla performs CXs with its up-to-four neighboring data qubits.
   The code assumes each such CX is counted once per measurement round for that check.
 - This is an approximate phenomenological mapping from circuit faults to detector edges;
   it is flexible and parameterized so you can tune probabilities to match a more detailed
   circuit-noise model or results from stim/experiments.
"""

import math
from typing import Dict, List, Optional, Tuple

import pymatching as pm


def rotated_z_sublattice_detectors(d: int, T: int):
    """Number the Z-check detectors of a distance-d rotated layout over rounds 0..T.

    A face (r, c) with 0 <= r, c < d - 1 is a Z-check when (r + c) is even.
    Detector indices are assigned consecutively: layer t = 0 first, and within a
    layer row-major over (r, c).

    NOTE(review): only interior faces are enumerated; the weight-2 boundary
    checks of a standard rotated surface code have no detector here.

    Returns:
        dict mapping (r, c, t) -> detector index.
    """
    assert d % 2 == 1 and d >= 3
    z_faces = [
        (row, col)
        for row in range(d - 1)
        for col in range(d - 1)
        if (row + col) % 2 == 0
    ]
    ordered_keys = [(row, col, layer) for layer in range(T + 1) for (row, col) in z_faces]
    return {key: index for index, key in enumerate(ordered_keys)}


def add_boundary_or_edge(
    M: "pm.Matching",
    u: int,
    v: Optional[int],
    p: float,
    merge_strategy: Optional[str] = "independent",
):
    """Add a fault edge u--v (or a boundary edge on u when v is None) to M.

    PyMatching's decoder uses the edge ``weight``, which defaults to 1.0.
    ``error_probability`` is only used for noise sampling, so the
    log-likelihood weight ln((1 - p) / p) is passed explicitly as well.

    Args:
        M: the matching graph to extend.
        u: detector index of the first endpoint.
        v: detector index of the second endpoint, or None for a boundary edge.
        p: probability that this fault occurs; must lie in [0, 1].
        merge_strategy: forwarded to PyMatching (>= 2). "independent" merges a
            repeated edge as an independent fault mechanism instead of raising.
            PyMatching 2's own default ("disallow") raises ValueError on a
            repeated edge. Pass None to omit the keyword, e.g. for older
            PyMatching versions that do not accept it.

    Raises:
        ValueError: if p is outside [0, 1] (or NaN).
    """
    kwargs = {"weight": _probability_to_weight(p), "error_probability": p}
    if merge_strategy is not None:
        kwargs["merge_strategy"] = merge_strategy
    if v is None:
        M.add_boundary_edge(u, **kwargs)
    else:
        M.add_edge(u, v, **kwargs)


def _probability_to_weight(p: float, eps: float = 1e-15) -> float:
    """Return the matching weight ln((1 - p) / p) for fault probability p.

    p is clamped into [eps, 1 - eps] so p == 0 or p == 1 yields a large finite
    weight instead of a division by zero or log(0).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"edge probability must be in [0, 1], got {p!r}")
    q = min(max(p, eps), 1.0 - eps)
    return math.log((1.0 - q) / q)


def compute_combined_single_data_error_prob(p_1q: float, p_cx: float, n_cx: int):
    """
    Probability that a data qubit picks up an X error in one round, treating all
    contributing faults as independent.

    Args:
      p_1q: probability of an X from single-qubit gates / idling.
      p_cx: probability that any one CX leaves an X on the data qubit.
      n_cx: how many CXs touch the data qubit in the round.

    The qubit stays clean only if every source stays clean, so the result is
        1 - (1 - p_1q) * (1 - p_cx) ** n_cx
    """
    p_no_1q_fault = 1.0 - p_1q
    p_no_cx_fault = (1.0 - p_cx) ** n_cx
    p_clean = p_no_1q_fault * p_no_cx_fault
    return 1.0 - p_clean


def build_z_matching_for_x_errors(
    d: int,
    T: int,
    p_meas: float,
    p_1q: float,
    p_cx: float,
    p_cx_corr: float,
    include_correlated_pairs: bool = True,
):
    """
    Build a pymatching.Matching that decodes X errors (i.e. uses Z-check detectors).

    Each modelled fault is first reduced to the set of detectors it flips and
    recorded in an edge table:

      1. measurement flips: edge between the same Z-check at layers t and t+1,
         probability p_meas;
      2. single-data X errors at layer t: edge between the Z-checks adjacent to the
         data qubit (boundary edge if only one is adjacent), probability from
         compute_combined_single_data_error_prob with n_cx = #adjacent Z-checks;
      3. (if include_correlated_pairs and p_cx_corr > 0) correlated X on every pair
         of data qubits sharing a Z-check face, probability p_cx_corr.

    A multi-qubit fault flips the symmetric difference (XOR) of the qubits'
    adjacent checks. A check touched by both qubits sees two flips and does not
    fire. Faults that land on the same detector pair, or on the same single
    detector (a boundary edge), are merged as independent mechanisms:
    p = p1 * (1 - p2) + p2 * (1 - p1). Each distinct edge is then added to the
    Matching exactly once, because PyMatching 2 rejects repeated edges by
    default. Several data qubits share each boundary check, so repeats do occur.

    Args:
      d: code distance (odd, >= 3).
      T: number of rounds (>= 1); detector layers are t = 0..T.
      p_meas, p_1q, p_cx, p_cx_corr: fault probabilities (see module docstring).
      include_correlated_pairs: whether to add the step-3 correlated-pair edges.

    Returns:
      M, detmap, metadata
      where metadata = {"data_adj_checks": {(i, j): [adjacent Z-check faces (r, c)]},
                        "detmap": the same detmap that is returned}.
    """
    assert d % 2 == 1 and d >= 3
    assert T >= 1

    detmap = rotated_z_sublattice_detectors(d, T)
    M = pm.Matching()

    # Z-check faces in row-major order (same rule as rotated_z_sublattice_detectors).
    z_faces = [(r, c) for r in range(d - 1) for c in range(d - 1) if (r + c) % 2 == 0]

    # Data qubits sit on grid vertices (i, j), 0 <= i, j < d; a vertex touches up
    # to four faces, of which at most two (a diagonal pair) are Z-checks.
    data_adj_checks: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}
    for i in range(d):
        for j in range(d):
            data_adj_checks[(i, j)] = [
                (fr, fc)
                for fr, fc in ((i - 1, j - 1), (i - 1, j), (i, j - 1), (i, j))
                if 0 <= fr < d - 1 and 0 <= fc < d - 1 and (fr + fc) % 2 == 0
            ]

    # Edge table: (u, v) with u < v for a regular edge, (u, None) for a boundary
    # edge -> merged probability of all faults mapping to that edge.
    edges: Dict[tuple, float] = {}

    def add_fault(nodes, p):
        """Record a fault with probability p that flips exactly the detectors in nodes."""
        nodes = sorted(nodes)
        if not nodes:
            return  # invisible to Z-checks: contributes no edge
        if len(nodes) == 1:
            keys = [(nodes[0], None)]
        elif len(nodes) == 2:
            keys = [(nodes[0], nodes[1])]
        else:
            # Hyperedges cannot be matched directly. As a conservative fallback,
            # split p over all node pairs. This does not arise in this layout.
            keys = [(a, b) for k, a in enumerate(nodes) for b in nodes[k + 1:]]
            p = p / len(keys)
        for key in keys:
            prev = edges.get(key, 0.0)
            edges[key] = prev * (1.0 - p) + p * (1.0 - prev)

    # 1) Time-like edges from measurement errors between layers t and t+1.
    for t in range(T):
        for r, c in z_faces:
            add_fault([detmap[(r, c, t)], detmap[(r, c, t + 1)]], p_meas)

    # 2) Space-like edges from single-data X errors at each detector layer.
    #    Every Z-check is assumed to do one CX with each adjacent data qubit per
    #    round, so n_cx equals the number of adjacent Z-checks.
    for t in range(T + 1):
        for faces in data_adj_checks.values():
            if not faces:
                continue
            p_data = compute_combined_single_data_error_prob(p_1q, p_cx, len(faces))
            add_fault([detmap[(fr, fc, t)] for fr, fc in faces], p_data)

    # 3) Correlated X on two data qubits that share a Z-check face.
    if include_correlated_pairs and p_cx_corr > 0.0:
        for r, c in z_faces:
            # The four corners of an interior face are always valid data qubits.
            corners = [(r, c), (r, c + 1), (r + 1, c), (r + 1, c + 1)]
            for a_idx, qa in enumerate(corners):
                for qb in corners[a_idx + 1:]:
                    # XOR: the shared face (and any other common check) cancels.
                    flipped = set(data_adj_checks[qa]) ^ set(data_adj_checks[qb])
                    for t in range(T + 1):
                        add_fault([detmap[(fr, fc, t)] for fr, fc in flipped], p_cx_corr)

    # Insert every distinct edge exactly once.
    for (u, v), p in edges.items():
        add_boundary_or_edge(M, u, v, p)

    metadata = {"data_adj_checks": data_adj_checks, "detmap": detmap}
    return M, detmap, metadata


# In [None]:
# -------------------------
# Example usage
# -------------------------
# Guarded so importing this module has no side effects. In a notebook cell
# __name__ == "__main__", so it still runs, and the names below stay module globals.
if __name__ == "__main__":
    # Example parameters
    d = 5
    T = 8
    p_meas = 0.01     # measurement flip probability
    p_1q = 1e-4       # single-qubit X prob from idles/1q gates per round
    p_cx = 1e-3       # per-CX probability of producing an X on the data qubit
    p_cx_corr = 5e-4  # correlated two-qubit X prob per CX-pair (tunable)

    M, detmap, meta = build_z_matching_for_x_errors(
        d=d,
        T=T,
        p_meas=p_meas,
        p_1q=p_1q,
        p_cx=p_cx,
        p_cx_corr=p_cx_corr,
        include_correlated_pairs=True,
    )

    # Show some stats
    print("Matching created:")
    print("  number of detector nodes:", len(detmap))
    # `M.graph` is not part of PyMatching's documented Matching API; use num_edges.
    print("  number of edges in matching graph:", M.num_edges)
    # Decode with a 0/1 syndrome vector of length len(detmap).
    # NOTE(review): no fault_ids are attached to the edges, so M.decode(...)
    # predicts no observables. To inspect the matching itself, use
    # M.decode_to_edges_array(syndrome) (PyMatching >= 2.1); confirm against the
    # installed version.