In [14]:
# Import necessary Qiskit components
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer import AerSimulator # Use AerSimulator for simulation
import numpy as np
import pymatching # <<< Added
import networkx as nx # Optional: For graph visualization if needed
import matplotlib.pyplot as plt # Optional: For graph visualization

# --- Configuration for d=3 Rotated Surface Code (Same as before) ---
TOTAL_QUBITS = 17
NUM_SYNDROME_BITS = 8 # 4 Z-stabilizers, 4 X-stabilizers

# Data Qubits (9 total)
DATA_QUBITS = [1, 3, 4, 5, 7, 9, 11, 12, 13]

# Measure Qubits (Ancillas - 8 total)
Z_ANCILLAS = [0, 2, 8, 10] # Indices for Z-stabilizer measurements
X_ANCILLAS = [6, 14, 15, 16] # Indices for X-stabilizer measurements
MEASURE_QUBITS = Z_ANCILLAS + X_ANCILLAS

# Stabilizer definitions (Same as before)
stabilizers = {
    0:  {'type': 'Z', 'data_qubits': [1, 3]}, 2:  {'type': 'Z', 'data_qubits': [1, 4, 5, 7]},
    8:  {'type': 'Z', 'data_qubits': [7, 9, 11, 12]}, 10: {'type': 'Z', 'data_qubits': [9, 13]},
    6:  {'type': 'X', 'data_qubits': [3, 4, 9, 11]}, 14: {'type': 'X', 'data_qubits': [5, 12]},
    15: {'type': 'X', 'data_qubits': [11, 13]}, 16: {'type': 'X', 'data_qubits': [7, 12]},
}

# Map ancilla index to syndrome bit index (Same as before)
syndrome_bit_map = {
    0: 0, 2: 1, 8: 2, 10: 3, # Z ancillas -> syndrome bits 0-3
    6: 4, 14: 5, 15: 6, 16: 7 # X ancillas -> syndrome bits 4-7
}
ancilla_order_for_syndrome = Z_ANCILLAS + X_ANCILLAS

# --- Helper Function for Stabilizer Measurement (Same as before) ---
def measure_stabilizer(qc, measure_qubit, stab_info, cl_bit_index):
    stab_type = stab_info['type']
    data_qubits = stab_info['data_qubits']
    qc.reset(measure_qubit) # Reset ancilla before use
    qc.h(measure_qubit)
    if stab_type == 'Z':
        for dq in data_qubits: qc.cz(measure_qubit, dq)
    elif stab_type == 'X':
        for dq in data_qubits: qc.cx(measure_qubit, dq)
    qc.h(measure_qubit)
    qc.measure(measure_qubit, cl_bit_index)

# --- REVISED: Function to Build Matching Graphs ---
def build_matching_graphs(stabilizers, data_qubits, z_ancillas, x_ancillas):
    """
    Constructs the matching graphs (stabilizer nodes only) and identifies boundary nodes.

    Returns:
        tuple: (graph_x, graph_z, node_map_x, node_map_z, boundary_nodes_x, boundary_nodes_z)
            graph_x: nx.Graph for X errors (nodes are Z stabilizers)
            graph_z: nx.Graph for Z errors (nodes are X stabilizers)
            node_map_x: Maps Z ancilla index to graph_x node index
            node_map_z: Maps X ancilla index to graph_z node index
            boundary_nodes_x: List of node indices in graph_x connected to boundary
            boundary_nodes_z: List of node indices in graph_z connected to boundary
    """
    # --- X-Error Graph (based on Z stabilizers) ---
    graph_x = nx.Graph()
    node_map_x = {anc: i for i, anc in enumerate(z_ancillas)} # Map Z ancilla index -> node index 0,1,2,3
    num_x_nodes = len(z_ancillas)
    boundary_nodes_x_set = set() # Store indices of nodes connected to boundary

    # Add nodes for Z stabilizers ONLY
    for anc_idx, graph_node_idx in node_map_x.items():
        graph_x.add_node(graph_node_idx, qubit_ancilla_idx=anc_idx, type='Z-Stabilizer')

    # Add edges based on single X errors on data qubits
    for dq in data_qubits:
        flipped_ancillas = [anc for anc in z_ancillas if dq in stabilizers[anc]['data_qubits']]

        if len(flipped_ancillas) == 1:
            # X error on dq connects one Z stabilizer node TO THE BOUNDARY
            anc_node_idx = node_map_x[flipped_ancillas[0]]
            boundary_nodes_x_set.add(anc_node_idx)
            # We might store which qubit causes this boundary connection if needed for interpretation later
            # graph_x.nodes[anc_node_idx].setdefault('boundary_error_qubits', []).append(dq)
        elif len(flipped_ancillas) == 2:
            # X error on dq connects two Z stabilizers
            anc_node1 = node_map_x[flipped_ancillas[0]]
            anc_node2 = node_map_x[flipped_ancillas[1]]
            # Add edge only if it doesn't exist, or update metadata if needed
            if not graph_x.has_edge(anc_node1, anc_node2):
                 graph_x.add_edge(anc_node1, anc_node2, weight=1, error_data_qubit=[dq], error_type='X')
            else:
                 # If edge exists, perhaps append the data qubit causing it (if interpretation needs it)
                 graph_x[anc_node1][anc_node2].setdefault('error_data_qubit', []).append(dq)


    # --- Z-Error Graph (based on X stabilizers) ---
    graph_z = nx.Graph()
    node_map_z = {anc: i for i, anc in enumerate(x_ancillas)} # Map X ancilla index -> node index 0,1,2,3
    num_z_nodes = len(x_ancillas)
    boundary_nodes_z_set = set()

    # Add nodes for X stabilizers ONLY
    for anc_idx, graph_node_idx in node_map_z.items():
        graph_z.add_node(graph_node_idx, qubit_ancilla_idx=anc_idx, type='X-Stabilizer')

    # Add edges based on single Z errors on data qubits
    for dq in data_qubits:
        flipped_ancillas = [anc for anc in x_ancillas if dq in stabilizers[anc]['data_qubits']]

        if len(flipped_ancillas) == 1:
            anc_node_idx = node_map_z[flipped_ancillas[0]]
            boundary_nodes_z_set.add(anc_node_idx)
            # graph_z.nodes[anc_node_idx].setdefault('boundary_error_qubits', []).append(dq)
        elif len(flipped_ancillas) == 2:
            anc_node1 = node_map_z[flipped_ancillas[0]]
            anc_node2 = node_map_z[flipped_ancillas[1]]
            if not graph_z.has_edge(anc_node1, anc_node2):
                graph_z.add_edge(anc_node1, anc_node2, weight=1, error_data_qubit=[dq], error_type='Z')
            else:
                graph_z[anc_node1][anc_node2].setdefault('error_data_qubit', []).append(dq)


    return graph_x, graph_z, node_map_x, node_map_z, list(boundary_nodes_x_set), list(boundary_nodes_z_set)

# --- Build the graphs ONCE ---
print("--- Building Matching Graphs (Stabilizer Nodes Only) ---")
graph_x, graph_z, node_map_x, node_map_z, boundary_x, boundary_z = build_matching_graphs(
    stabilizers, DATA_QUBITS, Z_ANCILLAS, X_ANCILLAS
)
print(f"X-Error Graph (Z-Stabs): {graph_x.number_of_nodes()} nodes, {graph_x.number_of_edges()} edges. Boundary nodes: {boundary_x}")
print(f"Z-Error Graph (X-Stabs): {graph_z.number_of_nodes()} nodes, {graph_z.number_of_edges()} edges. Boundary nodes: {boundary_z}")

# Optional: Visualize the graphs
# try:
#     plt.figure(1)
#     pos = nx.spring_layout(graph_x) # Or choose a better layout
#     nx.draw(graph_x, pos, with_labels=True, font_weight='bold')
#     nx.draw_networkx_nodes(graph_x, pos, nodelist=boundary_x, node_color='red') # Highlight boundary nodes
#     plt.title("X-Error Matching Graph (Nodes are Z-Stabilizers)")
#     plt.figure(2)
#     pos = nx.spring_layout(graph_z)
#     nx.draw(graph_z, pos, with_labels=True, font_weight='bold')
#     nx.draw_networkx_nodes(graph_z, pos, nodelist=boundary_z, node_color='red')
#     plt.title("Z-Error Matching Graph (Nodes are X-Stabilizers)")
#     plt.show()
# except Exception as e:
#     print(f"Install networkx and matplotlib to visualize graphs. Error: {e}")


# --- Create the Main Quantum Circuit (Similar to before) ---
qreg = QuantumRegister(TOTAL_QUBITS, 'q')
creg_syndrome = ClassicalRegister(NUM_SYNDROME_BITS, 'syndrome')
qc = QuantumCircuit(qreg, creg_syndrome)

# 1. Initialize Logical State (|0>) - Default

# 2. Introduce a Single Qubit Error
print("\n--- Introducing Error ---")
# --- Try different errors ---
error_qubit_index = 4 # Data qubit index (e.g., 1, 3, 4, 5, 7, 9, 11, 12, 13)
error_type = 'X'   # Try 'X', 'Z', 'Y', or None
# ---

if error_type and error_qubit_index in DATA_QUBITS:
    print(f"Applying {error_type} error on data qubit {error_qubit_index}")
    if error_type == 'X': qc.x(error_qubit_index)
    elif error_type == 'Z': qc.z(error_qubit_index)
    elif error_type == 'Y': qc.y(error_qubit_index)
    qc.barrier()
elif error_qubit_index not in DATA_QUBITS and error_type is not None:
     print(f"Warning: Qubit {error_qubit_index} is an ancilla or invalid. No error applied.")
     error_type = None
else:
    print("No error applied.")
    error_type = None

# 3. Measure Stabilizers (Syndrome Measurement)
print("\n--- Measuring Stabilizers ---")
for measure_qubit_index in ancilla_order_for_syndrome:
    stab_info = stabilizers[measure_qubit_index]
    cl_bit_index = syndrome_bit_map[measure_qubit_index]
    measure_stabilizer(qc, measure_qubit_index, stab_info, cl_bit_index)
    # qc.barrier() # Optional barrier


# --- Simulate the Circuit (Same as before) ---
print("\n--- Simulating ---")
simulator = AerSimulator()
job = simulator.run(qc, shots=1)
result = job.result()
counts = result.get_counts(qc)
print(f"Simulation Counts: {counts}")

# --- Decode the Syndrome using PyMatching ---
print("\n--- Decoding with PyMatching (MWPM) ---")

# Get the measured syndrome string and convert to list/tuple
syndrome_str_qiskit = list(counts.keys())[0]
syndrome_str = syndrome_str_qiskit[::-1] # Order: Z0..Z3, X0..X3
syndrome_full = np.array([int(bit) for bit in syndrome_str])

print(f"Measured Syndrome (Z0..Z3, X0..X3): {syndrome_full.tolist()}")

# --- Prepare inputs for PyMatching (REVISED) ---

# Syndrome for X-errors (uses Z-stabilizer outcomes: bits 0-3)
# Array size matches number of nodes in graph_x (which is num Z stabilizers)
syndrome_x_part = np.zeros(graph_x.number_of_nodes(), dtype=np.uint8)
for anc_idx, graph_node_idx in node_map_x.items():
    syndrome_bit_idx = syndrome_bit_map[anc_idx] # Get index (0-3) in full syndrome
    syndrome_x_part[graph_node_idx] = syndrome_full[syndrome_bit_idx]
print(f"Syndrome for X-Error Graph (Nodes 0..{graph_x.number_of_nodes()-1}=Z-Stabs): {syndrome_x_part.tolist()}")

# Syndrome for Z-errors (uses X-stabilizer outcomes: bits 4-7)
# Array size matches number of nodes in graph_z (which is num X stabilizers)
syndrome_z_part = np.zeros(graph_z.number_of_nodes(), dtype=np.uint8)
for anc_idx, graph_node_idx in node_map_z.items():
    syndrome_bit_idx = syndrome_bit_map[anc_idx] # Get index (4-7) in full syndrome
    syndrome_z_part[graph_node_idx] = syndrome_full[syndrome_bit_idx]
print(f"Syndrome for Z-Error Graph (Nodes 0..{graph_z.number_of_nodes()-1}=X-Stabs): {syndrome_z_part.tolist()}")


# --- Initialize PyMatching Objects (REVISED) ---
# Pass the graph and the list of boundary node indices
matcher_x = pymatching.Matching(graph_x, boundary_nodes=boundary_x)
matcher_z = pymatching.Matching(graph_z, boundary_nodes=boundary_z)

# --- Decode ---
# Pass the syndrome array corresponding ONLY to the stabilizer nodes
correction_x_indices = matcher_x.decode(syndrome_x_part, return_corrected_Jeg_indices=True)
correction_z_indices = matcher_z.decode(syndrome_z_part, return_corrected_Jeg_indices=True)

# The returned indices refer to edges in the graph's internal representation.
# It's often easier to get the correction in terms of which *nodes* need flipping.
# Let's decode again to get the correction array (size = num_nodes = num_stabilizers)
# correction_array[i] = 1 means node i needs correction relative to the MWPM solution
correction_array_x = matcher_x.decode(syndrome_x_part)
correction_array_z = matcher_z.decode(syndrome_z_part)

print(f"MWPM Correction Array (X-Error Graph Nodes): {correction_array_x.tolist()}")
print(f"MWPM Correction Array (Z-Error Graph Nodes): {correction_array_z.tolist()}")

# --- Interpret the Correction (Needs Careful Thought) ---
# This part is tricky. The correction_array tells us which stabilizer nodes
# were involved in the MWPM solution. We need to map this back to the most likely
# physical error(s) on data qubits.

# Strategy:
# 1. Find which nodes are indicated by correction_array_x and correction_array_z.
# 2. Identify the original syndrome nodes.
# 3. The difference between corrected nodes and syndrome nodes might imply internal edges used.
# 4. Boundary connections are implicit. If a syndrome node is also a corrected node AND a boundary node,
#    it likely implies a boundary edge was used in the matching.
# 5. Map edges/boundary connections back to data qubits.

# Let's try a simpler interpretation based on the original syndrome:
# If syndrome_x_part has a single '1' at index `i`, and `i` is in `boundary_x`, assume an X error on a data qubit flipping only stab `i`.
# If syndrome_z_part has a single '1' at index `j`, and `j` is in `boundary_z`, assume a Z error on a data qubit flipping only stab `j`.
# If syndrome_x_part has two '1's at `i1`, `i2`, assume an X error on a data qubit flipping stabs `i1`, `i2`.
# etc.

predicted_x_error_qubits = set()
predicted_z_error_qubits = set()

syndrome_nodes_x_indices = np.where(syndrome_x_part == 1)[0]
syndrome_nodes_z_indices = np.where(syndrome_z_part == 1)[0]

# Simple interpretation for SINGLE errors (most likely for d=3)
if len(syndrome_nodes_x_indices) == 1:
    node_idx = syndrome_nodes_x_indices[0]
    if node_idx in boundary_x:
        # Find data qubits flipping ONLY this Z-stabilizer
        anc_idx = z_ancillas[node_idx] # Get the original ancilla index
        for dq in data_qubits:
            flipped_ancillas = [anc for anc in z_ancillas if dq in stabilizers[anc]['data_qubits']]
            if len(flipped_ancillas) == 1 and flipped_ancillas[0] == anc_idx:
                 predicted_x_error_qubits.add(dq) # Could be multiple dq's if degenerate
elif len(syndrome_nodes_x_indices) == 2:
    node1, node2 = syndrome_nodes_x_indices
    # Find data qubits flipping BOTH these Z-stabilizers
    anc1 = z_ancillas[node1]
    anc2 = z_ancillas[node2]
    for dq in data_qubits:
        flipped_ancillas = [anc for anc in z_ancillas if dq in stabilizers[anc]['data_qubits']]
        if len(flipped_ancillas) == 2 and set(flipped_ancillas) == {anc1, anc2}:
            predicted_x_error_qubits.add(dq)

if len(syndrome_nodes_z_indices) == 1:
    node_idx = syndrome_nodes_z_indices[0]
    if node_idx in boundary_z:
        # Find data qubits flipping ONLY this X-stabilizer
        anc_idx = x_ancillas[node_idx] # Get the original ancilla index
        for dq in data_qubits:
            flipped_ancillas = [anc for anc in x_ancillas if dq in stabilizers[anc]['data_qubits']]
            if len(flipped_ancillas) == 1 and flipped_ancillas[0] == anc_idx:
                 predicted_z_error_qubits.add(dq)
elif len(syndrome_nodes_z_indices) == 2:
    node1, node2 = syndrome_nodes_z_indices
    # Find data qubits flipping BOTH these X-stabilizers
    anc1 = x_ancillas[node1]
    anc2 = x_ancillas[node2]
    for dq in data_qubits:
        flipped_ancillas = [anc for anc in x_ancillas if dq in stabilizers[anc]['data_qubits']]
        if len(flipped_ancillas) == 2 and set(flipped_ancillas) == {anc1, anc2}:
            predicted_z_error_qubits.add(dq)

# (This interpretation assumes MWPM finds the single error path for single syndromes)

# Combine Predictions
corrections = []
common_qubits = predicted_x_error_qubits.intersection(predicted_z_error_qubits)
only_x_qubits = predicted_x_error_qubits - common_qubits
only_z_qubits = predicted_z_error_qubits - common_qubits

for q in common_qubits: corrections.append(f"Y on Q{q}")
for q in only_x_qubits: corrections.append(f"X on Q{q}")
for q in only_z_qubits: corrections.append(f"Z on Q{q}")

# Handle potential degeneracies (multiple qubits predicted for one syndrome part)
# For d=3 single error, expect only one qubit total. If more, pick one arbitrarily or flag.
if len(corrections) > 1:
     predicted_error = f"Ambiguous ({', '.join(corrections)})"
     predicted_type = "Ambiguous"
     predicted_qubit = "N/A"
elif len(corrections) == 1:
     predicted_error = corrections[0]
     parts = predicted_error.split()
     predicted_type = parts[0]
     predicted_qubit = int(parts[-1][1:])
else:
     predicted_error = "None"
     predicted_type = None
     predicted_qubit = None


print(f"\nDecoder Prediction (Simplified Interpretation): {predicted_error}")

# --- Explanation (Same as before, checks against prediction) ---
print("\n--- Explanation ---")
is_error_detected = np.any(syndrome_full)

if not is_error_detected:
    if error_type is None:
        print("Correct: No error introduced, syndrome is zero.")
    else:
        print(f"Result: Error ({error_type} on Q{error_qubit_index}) introduced, but syndrome is zero (Undetected - Logical Error or Boundary Issue).")
else: # Error detected
    if error_type is None:
        print(f"Result: No error introduced, but non-zero syndrome {syndrome_full.tolist()} detected (Problem!).")
    else: # Error introduced and detected
        print(f"Correct: Error ({error_type} on Q{error_qubit_index}) introduced, non-zero syndrome {syndrome_full.tolist()} detected.")
        if predicted_error == "None":
            print("Decoder failed to find a likely single error cause based on simple interpretation.")
        elif predicted_type == "Ambiguous":
            print(f"Decoder Prediction is Ambiguous: {predicted_error}. Input was {error_type} on Q{error_qubit_index}.")
            print("(This suggests the stabilizer definitions might lead to degenerate syndromes for single errors).")
        else:
            print(f"Decoder suggests correction: {predicted_error}")
            # Check if prediction matches input (for single qubit errors)
            if predicted_type == error_type and predicted_qubit == error_qubit_index:
                print(f"Success: Decoder correctly identified the single {error_type} error on qubit {error_qubit_index}.")
            else:
                 print(f"Mismatch: Decoder predicted {predicted_type} on {predicted_qubit}, but input was {error_type} on {error_qubit_index}.")
                 print("(This might indicate graph definition issues, interpretation limitations, or MWPM finding an equivalent error).")

--- Building Matching Graphs (Stabilizer Nodes Only) ---
X-Error Graph (Z-Stabs): 4 nodes, 3 edges. Boundary nodes: [0, 1, 2, 3]
Z-Error Graph (X-Stabs): 4 nodes, 2 edges. Boundary nodes: [0, 1, 2, 3]

--- Introducing Error ---
Applying X error on data qubit 4

--- Measuring Stabilizers ---

--- Simulating ---
Simulation Counts: {'01110010': 1}

--- Decoding with PyMatching (MWPM) ---
Measured Syndrome (Z0..Z3, X0..X3): [0, 1, 0, 0, 1, 1, 1, 0]
Syndrome for X-Error Graph (Nodes 0..3=Z-Stabs): [0, 1, 0, 0]
Syndrome for Z-Error Graph (Nodes 0..3=X-Stabs): [1, 1, 1, 0]


ValueError: No perfect matching could be found. This likely means that the syndrome has odd parity in the support of a connected component without a boundary.