# Generalized Entanglement Purification Protocol Functions

### Helper Functions (tomography, density matrix extractor)

In [103]:
import itertools
import numpy as np
import stim
from typing import Sequence
import stim
from qutip import Qobj, ptrace
import matplotlib.pyplot as plt

def sample(circuit, shots, indices, inclusive=True):
    """
    Sample a circuit and return a selection of its measurement columns.

    Args:
        circuit: Object exposing ``compile_sampler()`` (e.g. a ``stim.Circuit``)
            whose sampler's ``sample(n)`` returns an ``(n, num_measurements)`` array.
        shots: Number of shots to draw.
        indices: A single int, or a sequence mixing single ints and
            ``[start, end]`` ranges (2-element list/tuple).
        inclusive: If True (default), ``end`` of a range is included.

    Returns:
        Dense int array of shape ``(shots, total_cols)``. The columns are the
        selected measurement bits, concatenated in the order given by ``indices``.
    """
    # Normalize a bare int into a one-element list.
    if isinstance(indices, (int, np.integer)):
        indices = [int(indices)]

    # Draw all shots in a single call. The column selection is the same for
    # every row, so it can be applied to the whole 2-D array at once. Every
    # row then automatically has the same number of columns.
    bits = np.asarray(circuit.compile_sampler().sample(shots))

    parts = []
    for idx in indices:
        if isinstance(idx, (list, tuple)) and len(idx) == 2:
            start, end = idx
            if inclusive:
                end = end + 1
            parts.append(bits[:, start:end])
        else:
            i = int(idx)
            parts.append(bits[:, i:i + 1])

    return np.concatenate(parts, axis=1).astype(int, copy=False)

def plot_syndromes(measurements, rows_to_plot, figsize=None, *, layout="auto"):
    """
    Plot per-data-qubit error counts decoded from [[7,1,3]] syndrome triplets.

    Each selected row ``r`` is labelled ``S{r//2+1}-N{r%2+1}``. It owns two
    3-bit syndromes (bit-flip and phase-flip). A non-zero triplet
    ``(b0, b1, b2)`` is decoded to data-qubit index ``b0 + 2*b1 + 4*b2 - 1``.

    Args:
        measurements: ``(shots, cols)`` int/bool array in one of two layouts:
          - FULL: triplets ordered
            [S1N1 bit, S1N1 phase, S1N2 bit, S1N2 phase,
             S2N1 bit, S2N1 phase, S2N2 bit, S2N2 phase, ...] (>= 24 cols);
          - PRE-SLICED: exactly ``6*len(rows_to_plot)`` cols,
            [row0 bit(3), row0 phase(3), row1 bit(3), row1 phase(3), ...].
        rows_to_plot: Row indices to plot, e.g. ``[0]``, ``[0, 1]``, ``[2, 3]``.
        figsize: Matplotlib figure size; defaults to ``(8, 3*len(rows_to_plot))``.
        layout: Which layout to assume:
          - ``"auto"`` (default) uses FULL when ``cols >= 24``, else
            PRE-SLICED when ``cols == 6*len(rows_to_plot)``;
          - ``"full"`` or ``"sliced"`` force one interpretation. This is needed
            when both rules match, e.g. 4 pre-sliced rows given out of order.

    Raises:
        ValueError: empty ``rows_to_plot``, unknown ``layout``, non-2-D input,
            a column count that fits neither layout, or a row index that is
            negative or not present in the FULL layout.
    """
    rows_to_plot = list(rows_to_plot)
    if not rows_to_plot:
        raise ValueError("rows_to_plot must not be empty.")
    if layout not in ("auto", "full", "sliced"):
        raise ValueError(f"layout must be 'auto', 'full' or 'sliced', got {layout!r}.")
    m = np.asarray(measurements, dtype=int)
    if m.ndim != 2:
        raise ValueError(f"measurements must be 2-D (shots, cols), got shape {m.shape}.")
    shots, cols = m.shape
    k = len(rows_to_plot)
    if cols % 3:
        raise ValueError("Columns must be a multiple of 3 (triplets).")
    if any(r < 0 for r in rows_to_plot):
        raise ValueError(f"Row indices must be non-negative, got {rows_to_plot}.")

    if layout == "auto":
        if cols >= 24:
            layout = "full"
        elif cols == 6 * k:
            layout = "sliced"
        else:
            raise ValueError("Provide either full (>=24) or pre-sliced (==6*len(rows_to_plot)) columns.")

    # Build (shots, 2*k, 3): the [bit, phase] triplets of each requested row.
    if layout == "full":
        missing = [r for r in rows_to_plot if 6 * (r + 1) > cols]
        if missing:
            raise ValueError(
                f"Rows {missing} are not present in the full layout ({cols // 6} rows available)."
            )
        triplet_ids = [g for r in rows_to_plot for g in (2 * r, 2 * r + 1)]
        synds = np.stack([m[:, 3 * g:3 * g + 3] for g in triplet_ids], axis=1)
    else:
        if cols != 6 * k:
            raise ValueError(f"Pre-sliced layout needs exactly {6 * k} columns, got {cols}.")
        synds = m.reshape(shots, 2 * k, 3)

    def label(r):
        return f"S{r//2+1}-N{r%2+1}"

    titles = [f"{label(r)} {t}" for r in rows_to_plot for t in ("bit", "phase")]
    weights = np.array([1, 2, 4])  # syndrome bit j contributes 2**j
    x = np.arange(7)

    if figsize is None:
        figsize = (8, 3*k)
    fig, axes = plt.subplots(k, 2, figsize=figsize, sharey=True)
    axes = np.atleast_2d(axes)  # subplots returns a 1-D array when k == 1

    for i in range(k):
        for side in (0, 1):
            ax = axes[i, side]
            codes = synds[:, 2*i + side, :] @ weights          # 0..7, 0 = no error
            errors = codes[codes > 0] - 1                       # 1..7 -> qubit 0..6
            counts = np.bincount(errors, minlength=7)
            ax.bar(x, counts, width=0.6)
            ax.set_xticks(x)
            ax.set_xticklabels([f"$q_{{{j}}}$" for j in x])
            # Avoid a degenerate (0, 0) y-range when no shots are given.
            ax.set_ylim(0, max(shots, 1))
            ax.set_title(titles[2*i + side])
            ax.set_ylabel("Error Count")
            ax.grid(True, linestyle=":", alpha=0.5)

    fig.suptitle(", ".join(map(label, rows_to_plot)), y=1.02, fontsize=14)
    plt.tight_layout()
    plt.show()

def tomography(circuit: stim.Circuit,
    qubits: tuple[int, int], shots: int = 10000
) -> np.ndarray:
    """
    Reconstruct the reduced density matrix on the specified pair of qubits
    (e.g. after entanglement swapping) by Pauli tomography.

    This is a two-qubit wrapper over ``tomography_dm``. It uses H / S_DAG basis
    rotations followed by Z measurements. The tomography outcomes are read from
    just after the circuit's own measurement record, so the base circuit may
    contain any number of measurements.

    Args:
        circuit: Base circuit whose output state is characterised.
        qubits: Tuple of two distinct qubit indices (e.g., (0, 3)).
        shots: Number of samples per Pauli setting.

    Returns:
        rho: 4×4 complex numpy array representing the density matrix on those
        qubits. ``qubits[0]`` is the more significant tensor factor.

    Raises:
        ValueError: if ``qubits`` does not contain exactly two entries.
    """
    qubits = tuple(qubits)
    if len(qubits) != 2:
        raise ValueError(f"tomography expects exactly two qubits, got {qubits!r}")
    # tomography_dm offsets by circuit.num_measurements. Reading a fixed column
    # index is only correct for one specific base-circuit measurement count.
    return tomography_dm(circuit, qubits, shots=shots, use_direct_pauli_meas=False)

def tomography_on_circuit(circuit: stim.Circuit, qubits: tuple[int, int], shots: int = 10000):
    """
    Perform 2-qubit Pauli tomography on the specified pair of qubits in the given circuit.
    Returns the reconstructed 4x4 density matrix.

    The work is delegated to ``tomography_dm`` using H / S_DAG basis rotations
    followed by Z measurements. Tomography outcomes are read after the
    circuit's own measurement record, so the circuit may already contain
    measurements. For a measurement-free circuit this is the same as reading
    columns 0 and 1.

    Args:
        circuit: Base stim.Circuit whose output state is characterised.
        qubits: A tuple (q0, q1) of distinct qubit indices to reconstruct;
            q0 is the more significant tensor factor of the result.
        shots: Number of samples per Pauli setting.

    Raises:
        ValueError: if ``qubits`` does not contain exactly two entries.
    """
    qubits = tuple(qubits)
    if len(qubits) != 2:
        raise ValueError(f"tomography_on_circuit expects exactly two qubits, got {qubits!r}")
    return tomography_dm(circuit, qubits, shots=shots, use_direct_pauli_meas=False)

def tomography_dm(
    circuit: stim.Circuit,
    qubits: list[int] | tuple[int, ...],
    shots: int = 10_000,
    use_direct_pauli_meas: bool = True,
) -> np.ndarray:
    """
    Reconstruct the reduced density matrix on `qubits` of a Stim circuit using
    full Pauli tomography and linear inversion.

    ρ = (1 / 2^k) * Σ_{P∈{I,X,Y,Z}^k} ⟨P⟩ P

    Each non-identity Pauli string gets its own copy of the circuit, with
    measurements appended. Its expectation is estimated from ``shots`` samples
    of those appended bits. They are located right after the circuit's existing
    ``num_measurements`` results.

    Args:
        circuit: The base stim.Circuit (may contain any gates/measurements).
        qubits:  Distinct qubit indices to reconstruct (length = k >= 1).
                 qubits[0] is the most significant tensor factor of ρ.
        shots:   Number of samples per Pauli setting.
        use_direct_pauli_meas:
                 True  -> use MX/MY/M (Z) instructions (cleaner, recommended).
                 False -> rotate with H/S_DAG then use M (Z) for all Pauli bases.

    Returns:
        rho: A (2**k × 2**k) complex numpy array (density matrix on `qubits`).

    Raises:
        ValueError: if ``qubits`` is empty or contains duplicates.
    """
    qubits = tuple(qubits)
    if not qubits:
        raise ValueError("qubits must be non-empty")
    if len(set(qubits)) != len(qubits):
        raise ValueError(f"qubits contains duplicates: {qubits!r}")

    paulis = ("I", "X", "Y", "Z")
    sigma = {
        "I": np.eye(2, dtype=complex),
        "X": np.array([[0, 1], [1, 0]], dtype=complex),
        "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
        "Z": np.array([[1, 0], [0, -1]], dtype=complex),
    }
    # Demolition measurement per basis, or the rotation mapping its eigenbasis onto Z.
    direct_gate = {"X": "MX", "Y": "MY", "Z": "M"}
    rotation = {"X": ("H",), "Y": ("S_DAG", "H"), "Z": ()}

    def kron_all(mats):
        out = mats[0]
        for m in mats[1:]:
            out = np.kron(out, m)
        return out

    k = len(qubits)
    dim = 2 ** k
    base_meas = circuit.num_measurements  # tomography bits are appended after these
    rho = np.zeros((dim, dim), dtype=complex)

    for setting in itertools.product(paulis, repeat=k):
        pauli = kron_all([sigma[p] for p in setting])

        # ⟨I⊗...⊗I⟩ = 1 exactly; nothing to sample.
        if all(p == "I" for p in setting):
            rho += pauli
            continue

        c = circuit.copy()
        measured = [(q, p) for q, p in zip(qubits, setting) if p != "I"]
        for q, p in measured:
            if use_direct_pauli_meas:
                c.append(direct_gate[p], [q])
            else:
                for gate in rotation[p]:
                    c.append(gate, [q])
                c.append("M", [q])

        bits = c.compile_sampler().sample(shots=shots)
        tomo_bits = np.asarray(bits[:, base_meas:base_meas + len(measured)], dtype=int)
        eig = 1 - 2 * tomo_bits  # bit 0 -> +1, bit 1 -> -1
        # The Pauli-string eigenvalue is the product over the measured qubits.
        rho += float(np.mean(np.prod(eig, axis=1))) * pauli

    rho /= dim
    return rho

def _reduced_density_matrix(psi: np.ndarray, keep_qubits: Sequence[int]) -> np.ndarray:
    """
    Return Tr_rest(|psi><psi|) on ``keep_qubits`` for a little-endian state vector.

    ``psi`` has length 2**n. Qubit q is stored in bit q of the basis-state index,
    which is stim's default ``state_vector()`` ordering. The result follows the
    order of ``keep_qubits``: ``keep_qubits[0]`` is the most significant tensor
    factor. Only O(2**n) memory is used; the full density matrix is never formed.
    """
    psi = np.asarray(psi, dtype=complex).ravel()
    n = psi.size.bit_length() - 1
    if psi.size != 2 ** n:
        raise ValueError(f"state vector length {psi.size} is not a power of two")
    # A C-order reshape makes axis 0 the most significant bit, i.e. qubit n-1.
    keep_axes = [n - 1 - int(q) for q in keep_qubits]
    rest_axes = [ax for ax in range(n) if ax not in keep_axes]
    amps = psi.reshape((2,) * n).transpose(keep_axes + rest_axes)
    amps = amps.reshape(2 ** len(keep_axes), -1)
    # rho[i, j] = sum_r amps[i, r] * conj(amps[j, r])
    return amps @ amps.conj().T


def get_exact_density_matrix(
    circuit: stim.Circuit,
    keep_qubits: Sequence[int]
) -> np.ndarray:
    """
    Given a Clifford circuit and a list of qubits to keep, return
    the exact reduced density matrix on those qubits.

    The circuit is run once in ``stim.TableauSimulator``. Any measurements it
    contains are sampled during that run, so the result describes that run's
    post-measurement state.

    Args:
        circuit: a stim.Circuit acting on N qubits.
        keep_qubits: distinct qubit indices in [0, N). The returned matrix follows
            this order: keep_qubits[0] is the most significant tensor factor,
            the same convention as ``np.kron`` and the tomography helpers.
    Returns:
        (2^k × 2^k) complex NumPy array for k = len(keep_qubits).
    Raises:
        ValueError: if keep_qubits is empty or contains duplicates.
        IndexError: if an entry is outside [0, N).
    """
    # --- 1) Validate inputs ---
    n = circuit.num_qubits
    keep = [int(q) for q in keep_qubits]
    if len(keep) == 0:
        raise ValueError("keep_qubits must be non‑empty")
    if any((q < 0 or q >= n) for q in keep):
        raise IndexError(f"keep_qubits entries must be in [0, {n-1}]")
    if len(set(keep)) != len(keep):
        raise ValueError("keep_qubits contains duplicates")

    # --- 2) Run the circuit in the tableau simulator ---
    sim = stim.TableauSimulator()
    sim.do(circuit)

    # --- 3) Statevector (little-endian: qubit 0 is the least significant bit) ---
    # Feeding this straight into Qobj(dims=[2]*n) would be wrong: QuTiP treats
    # subsystem 0 as the MOST significant factor, so it would trace out the
    # mirrored qubits (q <-> n-1-q). The helper handles the endianness explicitly.
    psi = np.asarray(sim.state_vector(), dtype=complex)

    # --- 4) Partial trace without forming the 2^n × 2^n matrix ---
    return _reduced_density_matrix(psi, keep)


### Generalized [[7,1,3]] encoding (Physical -> Logical)

In [3]:
import stim

def encode_713(circuit: stim.Circuit, qubit_indices: list[int]) -> stim.Circuit:
    """
    Append the [[7,1,3]] encoding network to a copy of ``circuit``.

    ``qubit_indices`` lists the seven physical qubits of the block. H is applied
    to positions 4, 5 and 6. A fixed sequence of CNOTs between block positions
    then follows. Position 0 acts as the input that is spread over the block;
    the other positions are presumably expected to start in |0> (not checked here).

    Returns:
        A new circuit; ``circuit`` itself is not modified.
    """
    # (control position, target position) inside the block, in application order.
    cnot_schedule = (
        (0, 1), (0, 2),
        (6, 3), (6, 1), (6, 0),
        (5, 3), (5, 2), (5, 0),
        (4, 3), (4, 2), (4, 1),
    )
    block = qubit_indices
    encoded = circuit.copy()
    encoded.append("H", [block[4], block[5], block[6]])
    for control, target in cnot_schedule:
        encoded.append("CX", [block[control], block[target]])
    return encoded

def decode_713(circuit: stim.Circuit, qubit_indices: list[int]) -> stim.Circuit:
    """
    Append the inverse of the [[7,1,3]] encoding network to a copy of ``circuit``.

    The CNOTs of ``encode_713`` are replayed in reverse order, followed by H on
    positions 4, 5 and 6. Every gate involved is self-inverse, so this undoes
    the encoder.

    Returns:
        A new circuit; ``circuit`` itself is not modified.
    """
    # (control position, target position) inside the block, in application order.
    cnot_schedule = (
        (4, 1), (4, 2), (4, 3),
        (5, 0), (5, 2), (5, 3),
        (6, 0), (6, 1), (6, 3),
        (0, 2), (0, 1),
    )
    block = qubit_indices
    decoded = circuit.copy()
    for control, target in cnot_schedule:
        decoded.append("CX", [block[control], block[target]])
    decoded.append("H", [block[4], block[5], block[6]])
    return decoded

### Generalized [[7,1,3]] Stabilizers

In [4]:
import stim

def stabalizers_713(circuit: stim.Circuit, data_qubit_indices: list[int], ancilla_qubit_indicies: list[int],
                    *, reset_ancillas: bool = True) -> stim.Circuit:
    """
    Append one round of [[7,1,3]] syndrome extraction to a copy of ``circuit``.

    Ancillas a[0..2] collect the bit-flip syndrome: data qubits control CX onto
    the ancilla. Ancillas a[3..5] collect the phase-flip syndrome: the ancilla
    is in |+> and controls CX onto the data qubits. Six measurements are
    appended in the order a[0], a[1], ..., a[5].

    Args:
        circuit: Circuit to extend (not modified).
        data_qubit_indices: The seven data qubits of the block.
        ancilla_qubit_indicies: Six ancilla qubits (3 bit-flip, 3 phase-flip).
        reset_ancillas: Reset the ancillas to |0> before use (default True).
            Stim's ``M`` does not reset, so without this a reused ancilla
            starts from its previous outcome and corrupts the new syndrome.

    Returns:
        A new circuit with the syndrome-extraction round appended.
    """
    c = circuit.copy()
    d = data_qubit_indices
    a = ancilla_qubit_indicies

    # Data positions covered by each parity check. Check j covers the
    # (1-indexed) positions whose binary label has bit j set, so a single
    # error's 3-bit syndrome spells out its position.
    checks = ((0, 2, 4, 6), (1, 2, 5, 6), (3, 4, 5, 6))

    if reset_ancillas:
        c.append("R", [a[j] for j in range(6)])

    # Bit-flip syndrome: ancilla a[j] accumulates the parity of its check.
    for j, support in enumerate(checks):
        for pos in support:
            c.append("CX", [d[pos], a[j]])

    # Phase-flip syndrome: ancillas a[3..5] in |+> act as controls.
    c.append("H", [a[3], a[4], a[5]])
    for j, support in enumerate(checks):
        for pos in support:
            c.append("CX", [a[3 + j], d[pos]])
    c.append("H", [a[3], a[4], a[5]])

    # Measurement record: bit syndromes a[0..2], then phase syndromes a[3..5].
    for j in range(6):
        c.append("M", [a[j]])

    return c

### Generalized T-CNOT Stim (Sampling)

In [5]:
import stim

def tcnot(circuit: stim.Circuit, qubit_dict: dict) -> stim.Circuit:
    """
    Applies a transversal teleported CNOT (T-CNOT) on the specified qubits.

    For every position i:
    - H is applied to n1_memory[i].
    - A Bell pair is made on the communication qubits (n1_comm[i], n2_comm[i]).
    - The pair is consumed to apply a CNOT from n1_memory[i] (control) onto
      n2_memory[i] (target).
    - The measurement-conditioned X / Z corrections are applied through
      ``rec`` targets.

    Args:
        circuit: A stim.Circuit object representing the quantum circuit (not modified).
        qubit_dict: A dictionary containing equal-length lists of qubit indices:
            - 'n1_memory': memory qubits of the first block.
            - 'n1_communcation': communication qubits of the first block.
            - 'n2_communcation': communication qubits of the second block.
            - 'n2_memory': memory qubits of the second block.
            The correctly spelled 'n1_communication' / 'n2_communication' keys
            are accepted as well; the historical spelling wins if both are present.
            Example:
                {
                    'n1_memory': [0, 1, 2],
                    'n1_communcation': [3, 4, 5],
                    'n2_communcation': [6, 7, 8],
                    'n2_memory': [9, 10, 11]
                }

    Returns:
        c: A new stim.Circuit with the T-CNOT operations appended.

    Raises:
        KeyError: if a required key is missing.
        ValueError: if the four lists do not all have the same length.
    """
    def lookup(*keys):
        # First present key wins; report the historical key name when missing.
        for key in keys:
            if key in qubit_dict:
                return qubit_dict[key]
        raise KeyError(keys[0])

    n1_memory = qubit_dict['n1_memory']
    n1_comm = lookup('n1_communcation', 'n1_communication')
    n2_comm = lookup('n2_communcation', 'n2_communication')
    n2_memory = qubit_dict['n2_memory']

    lengths = [len(n1_memory), len(n1_comm), len(n2_comm), len(n2_memory)]
    if len(set(lengths)) != 1:
        raise ValueError(f"T-CNOT qubit lists must have equal lengths, got {lengths}")

    c = circuit.copy()
    for m1, a1, a2, m2 in zip(n1_memory, n1_comm, n2_comm, n2_memory):
        c.append("H", [m1])
        c.append("H", [a1])
        c.append("CNOT", [a1, a2])   # Bell pair on the communication qubits
        c.append("CNOT", [m1, a1])
        c.append("CNOT", [a2, m2])
        c.append("M", [a1])
        # Z-outcome of a1 -> X correction on the target memory.
        c.append("CX", [stim.target_rec(-1), stim.GateTarget(m2)])
        c.append("H", [a2])
        c.append("M", [a2])
        # X-outcome of a2 -> Z correction on the control memory.
        c.append("CZ", [stim.target_rec(-1), stim.GateTarget(m1)])

    return c



### Generalized Entanglement Swapping Stim (Sampling)

In [6]:
import stim 

def entanglment_swapping(circuit: stim.Circuit, qubit_indices: list[int], total_qubits: int, bell_state_prepared = True) -> stim.Circuit:
    """
    Append an entanglement-swapping step to a copy of ``circuit``.

    With qubit_indices = [q0, q1, q2, q3], the pairs (q0, q1) and (q2, q3) are
    swapped into a pair (q0, q3):
    - Optionally prepare Bell pairs on consecutive index pairs.
    - Bell-measure q1 and q2 (CNOT, H, then Z-measure both).
    - Apply the Pauli-frame fix-ups on q3 through ``rec`` targets.

    Args:
        circuit: Circuit to extend (not modified).
        qubit_indices: Qubit indices; the first four are used by the swap.
        total_qubits: Accepted for interface compatibility; not used here.
        bell_state_prepared: If False, first prepare (H, CNOT) Bell pairs on
            (qubit_indices[0], [1]), ([2], [3]), ...

    Returns:
        c: stim.Circuit representing the entanglement swapping circuit.
    """
    swapped = circuit.copy()

    if not bell_state_prepared:
        for pos in range(0, len(qubit_indices), 2):
            left, right = qubit_indices[pos], qubit_indices[pos + 1]
            swapped.append("H", [left])
            swapped.append("CNOT", [left, right])

    inner_a, inner_b, far_end = qubit_indices[1], qubit_indices[2], qubit_indices[3]

    # Bell-basis measurement of the two inner qubits.
    swapped.append("CNOT", [inner_a, inner_b])
    swapped.append("H", [inner_a])
    swapped.append("M", [inner_a])   # X-type outcome -> rec[-2]
    swapped.append("M", [inner_b])   # Z-type outcome -> rec[-1]

    # Pauli-frame corrections on the far qubit.
    swapped.append('CX', [stim.target_rec(-1), stim.GateTarget(far_end)])
    swapped.append('CZ', [stim.target_rec(-2), stim.GateTarget(far_end)])

    return swapped

# Entanglement Purification Protocol

In [101]:
import stim
import random

def entaglement_purification_qec(error_probability: float = 0.1):
    """
    Build the encoded two-station entanglement-distribution circuit.

    Steps:
    1. Touch all ``total_qubits`` (80) qubits with an identity gate.
    2. [[7,1,3]]-encode the four memory blocks.
    3. With probability ``error_probability``, insert an X on qubit 0,
       run syndrome extraction on station-1 node-1, and then undo the X.
    4. Apply a T-CNOT inside each station.
    5. Perform transversal entanglement swapping across the stations.
    6. Decode station-1 'n1_memory' and station-2 'n2_memory'.

    The error-injection branch is drawn once per call with ``random.random()``;
    the returned circuit either contains it or not.

    Args:
        error_probability: Chance of including the X-error/syndrome branch
            (default 0.1, as before).

    Returns:
        The constructed stim.Circuit.
    """
    circuit = stim.Circuit()
    total_qubits = 80
    block_length = 7
    station1 = {
                'n1_memory': [0, 1, 2, 3, 4, 5, 6],
                'n1_communcation': [7, 8, 9, 10, 11, 12, 13],
                'n2_communcation': [14, 15, 16, 17, 18, 19, 20],
                'n2_memory': [21, 22, 23, 24, 25, 26, 27],
                'n1_ancilla' : [56, 57, 58, 59, 60, 61],
                'n2_ancilla' : [62, 63, 64, 65, 66, 67]
                }
    station2 = {
                'n1_memory' : [28, 29, 30, 31, 32, 33, 34],
                'n1_communcation' : [35, 36, 37, 38, 39, 40, 41],
                'n2_communcation' : [42, 43, 44, 45, 46, 47, 48],
                'n2_memory' : [49, 50, 51, 52, 53, 54, 55],
                'n1_ancilla' : [68, 69, 70, 71, 72, 73],
                'n2_ancilla' : [74, 75, 76, 77, 78, 79]
                }

    # Node id -> (memory block, syndrome ancillas) for stabilizer extraction.
    node_blocks = {
        0: (station1['n1_memory'], station1['n1_ancilla']),
        1: (station1['n2_memory'], station1['n2_ancilla']),
        2: (station2['n1_memory'], station2['n1_ancilla']),
        3: (station2['n2_memory'], station2['n2_ancilla']),
    }

    def stabalizers(circuit, nodes: list = None):
        """Append syndrome extraction for the node ids in ``nodes`` (all four if None)."""
        if nodes is None:
            nodes = list(node_blocks)
        c = circuit.copy()
        # Always in node-id order, regardless of the order of ``nodes``.
        for node in sorted(node_blocks):
            if node in nodes:
                c = stabalizers_713(c, *node_blocks[node])
        return c

    circuit.append("I", list(range(total_qubits)))
    for block in (station1['n1_memory'], station1['n2_memory'],
                  station2['n1_memory'], station2['n2_memory']):
        circuit = encode_713(circuit, block)

    if random.random() <= error_probability:
        # Insert an X error, record its syndrome, then remove the error again.
        circuit.append('X', [0])
        circuit = stabalizers(circuit, [0])
        circuit.append('X', [0])

    circuit = tcnot(circuit, station1)
    circuit = tcnot(circuit, station2)

    for i in range(block_length):
        qubit_indicies = [station1['n1_memory'][i], station1['n2_memory'][i],
                          station2['n1_memory'][i], station2['n2_memory'][i]]
        circuit = entanglment_swapping(circuit = circuit, qubit_indices = qubit_indicies, total_qubits = total_qubits)

    # station1 'n2_memory' and station2 'n1_memory' are measured by the swap,
    # so only the two surviving end blocks are decoded.
    circuit = decode_713(circuit, station1['n1_memory'])
    circuit = decode_713(circuit, station2['n2_memory'])

    return circuit


############

### Fidelity Calculation

In [88]:
import numpy as np

# Build one instance of the protocol circuit. Whether the X-error injection
# branch is included is drawn at random inside entaglement_purification_qec().
circuit = entaglement_purification_qec()
sampler = circuit.compile_sampler()
# Raw measurement samples; not used further in this cell.
events = sampler.sample(shots=10000) 
# Tomography on qubit 0 (position 0 of station1 'n1_memory') and qubit 49
# (position 0 of station2 'n2_memory'). These are the two blocks decoded at the
# end of the circuit; decode_713 presumably returns the logical state to position 0.
rho = tomography_dm(circuit, [0,49])
print(rho)

def fidelity_phi_plus(rho: np.ndarray) -> float:
    """
    Compute the fidelity F = ⟨Φ⁺| rho |Φ⁺⟩ of a two-qubit state ρ
    with the Bell state |Φ+> = (|00> + |11>)/√2.

    Args:
        rho: a 4x4 density matrix. Any array-like (numpy array, nested lists)
            is accepted and converted to a complex array.

    Returns:
        Fidelity F ∈ [0,1] (the real part of the overlap).

    Raises:
        ValueError: if ``rho`` is not 4x4.
    """
    rho = np.asarray(rho, dtype=complex)
    if rho.shape != (4, 4):
        raise ValueError(f"Expected a 4x4 density matrix, got shape {rho.shape}")

    # |Φ+> = (|00> + |11>)/√2
    phi_plus = np.array([1, 0, 0, 1], dtype=complex) / np.sqrt(2)

    # np.vdot conjugates its first argument, giving ⟨Φ+| rho |Φ+⟩.
    F = np.vdot(phi_plus, rho @ phi_plus)
    return float(np.real(F))

# Overlap of the reconstructed state with |Φ+>; 1.0 means a perfect Bell pair.
print(fidelity_phi_plus(rho)) 


[[ 4.988e-01+0.j      -1.000e-04-0.00345j -7.150e-03-0.00345j
   5.000e-01+0.0028j ]
 [-1.000e-04+0.00345j  3.300e-03+0.j       0.000e+00-0.001j
   5.950e-03+0.00655j]
 [-7.150e-03+0.00345j  0.000e+00+0.001j   -3.300e-03+0.j
  -7.400e-03+0.00025j]
 [ 5.000e-01-0.0028j   5.950e-03-0.00655j -7.400e-03-0.00025j
   5.012e-01+0.j     ]]
0.9999999999999998


# Barrett-Kok Entanglement Generation

import stim
from typing import Dict, Tuple, Optional
import numpy as np

class BarrettKokPhases:
    """
    Barrett-Kok protocol phases that modify a passed-in circuit.
    Each phase appends operations to the circuit and returns metadata.

    ``qubit_map`` must provide indices for 'mem_0', 'mem_1', 'photon_0' and
    'photon_1'. Photon loss and detector distinguishability are simulated
    classically with ``np.random``; they are never written into the circuit.
    """

    def __init__(self, qubit_map: Dict[str, int]):
        self.qubit_map = qubit_map
        # Not used by the methods below (presumably reserved for per-round logging).
        self.round_data = []

    def phase1_prepare_memories(self, circuit: stim.Circuit) -> Dict:
        """Phase 1: Initialize memories in |+⟩ state."""
        circuit.append("H", [self.qubit_map['mem_0'], self.qubit_map['mem_1']])

        return {
            'phase': 'prepare_memories',
            'operations': 'H on both memories',
            'qubits_modified': ['mem_0', 'mem_1']
        }

    def phase2_create_photons(self, circuit: stim.Circuit, round_num: int) -> Dict:
        """
        Phase 2: Create entangled photons from memories.

        The photon qubits are reset before emission. The previous round's BSM
        measured them with ``M``, which does not reset, and each round must
        start from fresh photons.
        """
        metadata = {'phase': f'create_photons_round_{round_num}'}
        photon_0 = self.qubit_map['photon_0']
        photon_1 = self.qubit_map['photon_1']

        # If round 2, apply X gates first (part of Barrett-Kok)
        if round_num == 2:
            circuit.append("X", [self.qubit_map['mem_0'], self.qubit_map['mem_1']])
            metadata['pre_operations'] = 'X gates applied'

        # Fresh photons for this round.
        circuit.append("R", [photon_0, photon_1])

        # Create entangled photons
        circuit.append("CNOT", [self.qubit_map['mem_0'], photon_0])
        circuit.append("CNOT", [self.qubit_map['mem_1'], photon_1])

        metadata['operations'] = 'CNOT from memories to photons'
        metadata['photons_created'] = True

        return metadata

    def simulate_photon_transmission(self, p_loss: float = 0.1) -> Tuple[bool, Dict]:
        """
        Simulate photon transmission (happens between phases).
        This is NOT added to circuit - it's classical simulation.
        Each photon independently arrives with probability 1 - p_loss.
        """
        photon_0_arrives = np.random.random() > p_loss
        photon_1_arrives = np.random.random() > p_loss

        metadata = {
            'event': 'photon_transmission',
            'photon_0_arrived': photon_0_arrives,
            'photon_1_arrived': photon_1_arrives,
            'both_arrived': photon_0_arrives and photon_1_arrives
        }

        return photon_0_arrives and photon_1_arrives, metadata

    def phase3_bsm_measurement(self, circuit: stim.Circuit) -> Dict:
        """
        Phase 3: Perform BSM on photons.

        The returned 'measurement_indices' are the absolute indices of the two
        new results in the circuit's measurement record.
        """
        circuit.append("CNOT", [self.qubit_map['photon_0'], self.qubit_map['photon_1']])
        circuit.append("H", [self.qubit_map['photon_0']])
        circuit.append("M", [self.qubit_map['photon_0'], self.qubit_map['photon_1']])

        # num_measurements is an int count; the two new results are the last two.
        total = circuit.num_measurements
        return {
            'phase': 'bsm_measurement',
            'operations': 'Bell measurement on photons',
            'measurement_indices': [total - 2, total - 1]
        }

    def simulate_bsm_result(self, circuit: stim.Circuit, p_indistinguishable: float = 0.5) -> Tuple[bool, Tuple[int, int], Dict]:
        """
        Simulate BSM outcome (distinguishable vs indistinguishable).
        Samples the actual measurement result from the circuit.

        Note: each call compiles and samples the whole circuit afresh. Outcomes
        from different calls are therefore independent draws, not one
        consistent run.
        """
        # First check if distinguishable
        is_distinguishable = np.random.random() > p_indistinguishable

        if not is_distinguishable:
            # Indistinguishable - both detectors fire (or neither)
            return False, (0, 0), {
                'distinguishable': False,
                'result': 'indistinguishable'
            }

        # Sample actual measurement outcome
        sampler = circuit.compile_sampler()
        result = sampler.sample(shots=1)[0]

        # Get last two measurements
        measurement_result = (int(result[-2]), int(result[-1]))

        return True, measurement_result, {
            'distinguishable': True,
            'detector_0': measurement_result[0],
            'detector_1': measurement_result[1]
        }

    def phase4_apply_corrections(self, circuit: stim.Circuit,
                                round1_result: Tuple[int, int],
                                round2_result: Tuple[int, int]) -> Dict:
        """Phase 4: Apply final corrections based on measurement outcomes."""
        # Primary always applies X
        circuit.append("X", [self.qubit_map['mem_1']])

        # Non-primary applies Z if the first detector bit differs between rounds
        corrections_applied = ['X on mem_1 (primary)']

        if round1_result[0] != round2_result[0]:
            circuit.append("Z", [self.qubit_map['mem_0']])
            corrections_applied.append('Z on mem_0 (measurement difference)')

        return {
            'phase': 'corrections',
            'corrections_applied': corrections_applied,
            'final_state': 'entangled'
        }

    def phase_reset(self, circuit: stim.Circuit) -> Dict:
        """Reset phase: Clean up for retry."""
        circuit.append("R", [
            self.qubit_map['mem_0'],
            self.qubit_map['mem_1'],
            self.qubit_map['photon_0'],
            self.qubit_map['photon_1']
        ])

        return {
            'phase': 'reset',
            'reason': 'protocol_failure',
            'qubits_reset': ['mem_0', 'mem_1', 'photon_0', 'photon_1']
        }


def run_adaptive_protocol(qubit_map: Dict[str, int], 
                         p_loss: float = 0.1,
                         p_indistinguishable: float = 0.5,
                         max_attempts: int = 10) -> Tuple[Optional[stim.Circuit], Dict]:
    """
    Run the Barrett-Kok protocol adaptively, building a fresh circuit per attempt.

    Each attempt prepares the memories and runs two rounds of photon creation,
    classically simulated transmission, and BSM. A lost photon or an
    indistinguishable BSM outcome aborts the attempt. After two successful
    rounds the corrections are appended and the circuit is returned.

    Args:
        qubit_map: Indices for 'mem_0', 'mem_1', 'photon_0', 'photon_1'.
        p_loss: Per-photon loss probability.
        p_indistinguishable: Probability that a BSM outcome is indistinguishable.
        max_attempts: Maximum number of attempts before giving up.

    Returns:
        On success: ``(circuit, metadata)``, where metadata describes the winning
        attempt.
        On failure: ``(None, {'error': 'max_attempts_exceeded', 'attempts':
        max_attempts, 'history': [...]})``. Here 'history' holds every failed
        attempt's metadata, including the round that caused the failure.
    """
    protocol = BarrettKokPhases(qubit_map)
    history = []

    for attempt in range(max_attempts):
        # Start fresh circuit for this attempt
        circuit = stim.Circuit()
        attempt_metadata = {
            'attempt': attempt + 1,
            'phases': [],
            'success': False
        }

        # Phase 1: Prepare
        attempt_metadata['phases'].append(protocol.phase1_prepare_memories(circuit))

        round_results = []
        failure_reason = None
        round_meta = None

        for round_num in [1, 2]:
            round_meta = {'round': round_num}

            # Phase 2: Create photons
            round_meta['photon_creation'] = protocol.phase2_create_photons(circuit, round_num)

            # Simulate transmission (not in circuit)
            both_arrived, transmission_meta = protocol.simulate_photon_transmission(p_loss)
            round_meta['transmission'] = transmission_meta
            if not both_arrived:
                failure_reason = f'photon_loss_round_{round_num}'
                break

            # Phase 3: BSM
            round_meta['bsm'] = protocol.phase3_bsm_measurement(circuit)

            # Simulate BSM result
            is_distinguishable, result, bsm_result_meta = protocol.simulate_bsm_result(circuit, p_indistinguishable)
            round_meta['bsm_result'] = bsm_result_meta
            if not is_distinguishable:
                failure_reason = f'indistinguishable_round_{round_num}'
                break

            round_results.append(result)
            attempt_metadata['phases'].append(round_meta)

        if failure_reason is None:
            # Both rounds succeeded: apply corrections and finish.
            corrections_meta = protocol.phase4_apply_corrections(
                circuit,
                round_results[0],
                round_results[1]
            )
            attempt_metadata['phases'].append(corrections_meta)
            attempt_metadata['success'] = True
            return circuit, attempt_metadata

        # Failure: keep the failed round's diagnostics, reset, and retry.
        attempt_metadata['phases'].append(round_meta)
        attempt_metadata['phases'].append(protocol.phase_reset(circuit))
        attempt_metadata['failure_reason'] = failure_reason
        history.append(attempt_metadata)

    # Failed after max attempts
    return None, {'error': 'max_attempts_exceeded', 'attempts': max_attempts, 'history': history}


# Usage example: one memory/photon pair per side, run with the
# demo loss and distinguishability parameters.
if __name__ == "__main__":
    qubit_map = dict(mem_0=0, mem_1=1, photon_0=2, photon_1=3)

    circuit, metadata = run_adaptive_protocol(
        qubit_map, p_loss=0.1, p_indistinguishable=0.5, max_attempts=10,
    )

    if not circuit:
        print(f"Failed: {metadata}")
    else:
        print(f"Success! Circuit has {circuit.num_measurements} measurements")
        print(f"Metadata: {metadata}")