In [None]:
%load_ext autoreload
%autoreload 2

import stim
import numpy as np
# import matplotlib.pyplot as plt
from stim_utils import *
from stim_simulation_utils import *
from circuit_commons import *
print(stim.__version__)

In [None]:
import importlib.util
import sys


def reload_file(name, path):
    """(Re)load module `name` by executing the source file at `path`.

    The freshly executed module replaces any existing ``sys.modules[name]``
    entry and is returned.

    Note: names previously bound with ``from <name> import *`` are NOT
    rebound by this call; re-run that import to pick up new definitions.

    Raises:
        ImportError: if no module spec/loader can be created for `path`
            (e.g. the file does not have a Python source suffix).
        Any exception raised while reading or executing the file propagates;
        in that case the previous ``sys.modules[name]`` entry (if any) is
        restored instead of leaving a half-initialised module behind.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create a module spec for {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(name)
    # Register before executing, as the regular import system does, so that
    # imports of `name` during execution resolve to this module.
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


# Re-execute the helper modules from disk and replace their sys.modules entries.
# NOTE: names bound earlier in this notebook via `from stim_utils import *` /
# `from stim_simulation_utils import *` still refer to the old objects until the
# import cell is re-run. circuit_commons is not reloaded here.
reload_file("stim_simulation_utils", "stim_simulation_utils.py")
reload_file("stim_utils", "stim_utils.py")


# Implementing the real encoder/decoder
## Encoding procedure for $|{+}{+}0000\rangle$
As in the original paper, [Demonstration of quantum computation and error correction with a tesseract code](http://arxiv.org/abs/2409.04628).

In [None]:
# Data qubits sit on a 4x4 grid; qubit (row r, column c) has index 4 * r + c.
# Each inner list holds the four data qubits supporting one row / column operator.
measurement_operators_rows = [[4 * r + c for c in range(4)] for r in range(4)]

measurement_operators_columns = [[4 * r + c for r in range(4)] for c in range(4)]

def encode_sub_circuit_quad(circuit, ancilla, qubits):
    """Fan qubits[0] out onto qubits[1], qubits[2] and qubits[3].

    The three fan-out CNOTs are bracketed by two identical CNOTs from
    qubits[0] onto the flag qubit `ancilla`. Only the first four entries of
    `qubits` are used.
    """
    root = qubits[0]
    circuit.append("CNOT", [root, ancilla])  # open flag
    for k in range(1, 4):
        circuit.append("CNOT", [root, qubits[k]])
    circuit.append("CNOT", [root, ancilla])  # close flag

def add_cnot_gates(circuit, start1, start2, num_gates=4):
    """
    Append `num_gates` CNOTs between two runs of consecutive qubit indices.

    Gate k (k = 0 .. num_gates - 1) is CNOT(start1 + k -> start2 + k). The gates
    are appended in increasing k, one ``circuit.append`` call per gate.

    Parameters:
    - circuit: object exposing ``append(name, targets)`` (the circuit being built).
    - start1: index of the first control qubit.
    - start2: index of the first target qubit.
    - num_gates: number of CNOTs to append (default 4); values <= 0 append nothing.
    """
    pairs = [(start1 + k, start2 + k) for k in range(num_gates)]
    for control, target in pairs:
        circuit.append("CNOT", [control, target])

def encode_manual(circuit):
    """Append the encoding circuit for |++0000> (after Fig. 9 of the paper).

    Qubit usage (data qubit (row r, column c) has index 4 * r + c):
      * 0-15  : data qubits.
      * 20-22 : flag ancillas around the fan-outs of rows 1-3.
      * 16-19 : flag ancillas around the column fan-out from row 0; 18 and 19
                are then reset and reused for the final row-0 block.

    NOTE(review): no ancilla used here is measured inside this function, so
    flag outcomes are never read and no post-selection takes place.
    Qubits 16 and 17 are left without a reset; get_qubits_and_ancillas()
    later hands them out as the X/Z measurement ancillas.
    Qubit 19 is put in |+> and used as the control of CNOTs onto row 0, but
    no closing H / measurement follows, so no stabilizer value is actually
    read out. TODO confirm against Fig. 9.

    Args:
        circuit: circuit the gates are appended to via ``circuit.append``
            (presumably a ``stim.Circuit``).
    """
    # Here we encode the state |++0000> as can be seen in Fig. 9 of that paper
    # Put row 0 (qubits 0-3) and the first qubit of rows 1-3 (4, 8, 12) in |+>.
    circuit.append("H", [0,1,2,3,4,8,12])

    # Fan qubit 4 / 8 / 12 out across the rest of its row, flagged by 20 / 21 / 22.
    encode_sub_circuit_quad(circuit, 20, [4, 5, 6, 7])
    encode_sub_circuit_quad(circuit, 21, [8, 9, 10, 11])
    encode_sub_circuit_quad(circuit, 22, [12, 13, 14, 15])

    add_cnot_gates(circuit, 0, 16) # opening flag CNOTs: 0->16, 1->17, 2->18, 3->19
    # Copy row 0 down each column: CNOT c -> 4+c, 8+c, 12+c for c = 0..3.
    add_cnot_gates(circuit, 0, 4)
    add_cnot_gates(circuit, 0, 8)
    add_cnot_gates(circuit, 0, 12)
    add_cnot_gates(circuit, 0, 16) # closing flag CNOTs: 0->16, 1->17, 2->18, 3->19

    circuit.append("R", [18, 19]) # Reset flag ancillas 18, 19 (their flag info is discarded, never measured) for reuse below
    circuit.append("H", [19])
    circuit.append("CNOT", [19, 18]) # cnot to flag qubit 18
    circuit.append("CNOT", [19, 0]) # controlled-X from 19 onto row 0 (intended as a row-0 X stabilizer measurement)
    circuit.append("CNOT", [19, 1]) # controlled-X from 19 onto row 0
    circuit.append("CNOT", [19, 2]) # controlled-X from 19 onto row 0
    circuit.append("CNOT", [19, 3]) # controlled-X from 19 onto row 0
    circuit.append("CNOT", [19, 18]) # cnot to flag qubit 18

## Error correction rules
Processing the rows and columns of the tesseract code. Here we rely mainly on Fig. 7 from the paper and extend it to all the other settings (row/column, X/Z).

In [None]:
# Correction rules for the flagged row/column passes of the tesseract code.
#
# Grid convention (same as measurement_operators_rows/_columns): data qubit
# (row r, column c) has index 4 * r + c.
#
# Each pass yields four outcomes, one per row or one per column. The rules
# treat four equal outcomes as "nothing to do". A single disagreeing outcome
# (sum 1 or 3) points at one line. A 2-2 split is only accepted in the
# 0011 / 1100 form, and only when a line is already flagged.
#
# correct_column_Z is the reference rule from Fig. 7 of the paper. The X and
# row variants are derived from it by symmetry.
# TODO: the X / row variants were only checked by hand. Validate them against
# the paper (or a single-fault simulation).


def _as_outcome_list(meas):
    """Return the four pass outcomes as Python ints (accepts lists, tuples, numpy arrays, bools)."""
    outcomes = [int(v) for v in meas]
    if len(outcomes) != 4:
        raise ValueError(f"expected 4 measurement outcomes, got {len(outcomes)}")
    return outcomes


def _odd_one_out(outcomes):
    """Index of the outcome that disagrees with the other three (requires sum 1 or 3)."""
    minority = 1 if sum(outcomes) == 1 else 0
    return outcomes.index(minority)


def _toggle(frame, qubits):
    """Flip, modulo 2, the Pauli-frame entries of `qubits` in place (works for lists and numpy arrays)."""
    for q in qubits:
        frame[q] = (frame[q] + 1) % 2


def _apply_pass(flag, meas, frame, flag_is_row):
    """Shared implementation of the four correct_* functions.

    flag_is_row=True:  column pass. `flag` is a row index from the preceding
                       row pass, and the outcomes are indexed by column.
    flag_is_row=False: row pass. `flag` is a column index from the preceding
                       column pass, and the outcomes are indexed by row.
    """
    outcomes = _as_outcome_list(meas)
    weight = sum(outcomes)

    if flag == -1:  # nothing flagged by the previous pass
        if weight == 2:
            return "reject"
        if weight in (1, 3):
            flag = _odd_one_out(outcomes)
        return flag, meas, frame

    def qubit(position):
        # Data qubit at `position` along the flagged line.
        return 4 * flag + position if flag_is_row else 4 * position + flag

    if weight in (1, 3):
        _toggle(frame, [qubit(_odd_one_out(outcomes))])
    elif weight == 2:
        if outcomes not in ([0, 0, 1, 1], [1, 1, 0, 0]):
            return "reject"
        # Weight-2 correction on the first two qubits of the flagged line (ZZII / XXII).
        _toggle(frame, [qubit(0), qubit(1)])
    return -1, meas, frame


def correct_column_Z(flagX: int, measX, frameZ):
    """Column pass, X-type outcomes -> Z corrections (reference rule, Fig. 7).

    Args:
        flagX: row flagged by the preceding row pass, or -1 if none.
        measX: the four column X outcomes (0/1 values or bools).
        frameZ: Z Pauli frame indexed by data qubit; updated in place, mod 2.

    Returns:
        "reject" if the outcomes cannot be decoded. Otherwise a tuple
        (new_flag, measX, frameZ). new_flag is the disagreeing column if no
        row was flagged, or -1 once a pending flag has been consumed.
    """
    return _apply_pass(flagX, measX, frameZ, flag_is_row=True)


def correct_column_X(flagZ: int, measZ, frameX):
    """Column pass, Z-type outcomes -> X corrections.

    Args:
        flagZ: row flagged by the preceding row pass, or -1 if none.
        measZ: the four column Z outcomes (0/1 values or bools).
        frameX: X Pauli frame indexed by data qubit; updated in place, mod 2.

    Returns:
        "reject", or a tuple (new_flag, measZ, frameX) as for correct_column_Z.
    """
    return _apply_pass(flagZ, measZ, frameX, flag_is_row=True)


def correct_row_Z(flagX: int, measX, frameZ):
    """Row pass, X-type outcomes -> Z corrections.

    Args:
        flagX: column flagged by the preceding column pass, or -1 if none.
        measX: the four row X outcomes (0/1 values or bools).
        frameZ: Z Pauli frame indexed by data qubit; updated in place, mod 2.

    Returns:
        "reject" if the outcomes cannot be decoded. Otherwise a tuple
        (new_flag, measX, frameZ). new_flag is the disagreeing row if no
        column was flagged, or -1 once a pending flag has been consumed.
    """
    return _apply_pass(flagX, measX, frameZ, flag_is_row=False)


def correct_row_X(flagZ: int, measZ, frameX):
    """Row pass, Z-type outcomes -> X corrections.

    Args:
        flagZ: column flagged by the preceding column pass, or -1 if none.
        measZ: the four row Z outcomes (0/1 values or bools).
        frameX: X Pauli frame indexed by data qubit; updated in place, mod 2.

    Returns:
        "reject", or a tuple (new_flag, measZ, frameX) as for correct_row_Z.
    """
    return _apply_pass(flagZ, measZ, frameX, flag_is_row=False)

## Error correction - manual implementation
Using the error correction rules, we implement a decoder. In every round, we measure rows and columns, correcting errors accordingly. Note that this code uses flag qubits, so errors are corrected on the subsequent row/column measurement and the Pauli frame is updated. See Fig. 8 in the paper for reference.

In [58]:
def append_detector_on_last_n_measurements(circuit, num_measurements=4):
    """Append one DETECTOR over the `num_measurements` most recent measurement results.

    Targets are listed newest first: rec[-1], rec[-2], ..., rec[-num_measurements].
    A non-positive count produces a DETECTOR with no targets.
    """
    lookbacks = range(-1, -num_measurements - 1, -1)
    targets = [stim.target_rec(k) for k in lookbacks]
    circuit.append("DETECTOR", targets)

def get_qubits_and_ancillas():
    """
    Return ``(data_qubits, x_ancilla, z_ancilla)`` for the error-correction rounds.

    data_qubits is a fresh list [0, ..., 15]. The two ancilla qubits are for
    the X and Z measurements, respectively (16 and 17). Note that
    encode_manual() also uses 16 and 17, as flag targets.
    """
    data_qubits = [q for q in range(16)]
    x_ancilla, z_ancilla = 16, 17
    return data_qubits, x_ancilla, z_ancilla


def measure_x_z_stabilizer(circuit, data_qubits, x_ancilla, z_ancilla):
    """
    Appends a circuit that measures the X-parity and the Z-parity of data_qubits.

    Based on Fig 4(d) of the paper.
    - X-parity: H on x_ancilla, CNOT x_ancilla -> each data qubit, H, measure.
    - Z-parity: CNOT each data qubit -> z_ancilla, measure.

    Both ancillas are reset before use, so the result does not depend on the
    state a previous user left them in. For example, encode_manual() CNOTs
    data qubits 0 and 1 onto qubits 16 and 17 and never resets them, and
    get_qubits_and_ancillas() hands those out as the x/z ancillas. The
    ancillas are measured and reset again at the end.

    Measurement record: each call appends two results, X outcome first
    (rec[-2]) and then Z outcome (rec[-1]).

    NOTE(review): all X-parity CNOTs are applied before any Z-parity CNOT.
    Confirm this ordering gives the Z measurement the flag role for the X
    measurement that Fig. 4(d) describes.
    """
    # Reset-before-use: do not trust the incoming ancilla state.
    circuit.append("R", [x_ancilla, z_ancilla])

    circuit.append("H", [x_ancilla])
    for q in data_qubits:
        circuit.append("CNOT", [x_ancilla, q])
    for q in data_qubits:
        circuit.append("CNOT", [q, z_ancilla])
    circuit.append("H", [x_ancilla])

    # In stim, the order of targets in a single M determines rec indexing:
    # M x_ancilla, z_ancilla  ->  x is rec(-2), z is rec(-1).
    circuit.append("M", [x_ancilla, z_ancilla])
    circuit.append("R", [x_ancilla, z_ancilla])


def error_correction_round_rows(circuit):
    """
    Append the row pass of one error-correction round.

    Rows 0..3 are handled in order. Each row gets an X- and Z-parity
    measurement on its four data qubits, which appends two results (X, then Z)
    per row to the measurement record.
    """
    x_anc, z_anc = get_qubits_and_ancillas()[1:]
    for row_qubits in measurement_operators_rows:
        measure_x_z_stabilizer(circuit, row_qubits, x_anc, z_anc)


def error_correction_round_columns(circuit):
    """
    Append the column pass of one error-correction round.

    Columns 0..3 are handled in order. Each column gets an X- and Z-parity
    measurement on its four data qubits, which appends two results (X, then Z)
    per column to the measurement record.
    """
    x_anc, z_anc = get_qubits_and_ancillas()[1:]
    for column_qubits in measurement_operators_columns:
        measure_x_z_stabilizer(circuit, column_qubits, x_anc, z_anc)


def error_correct_manual(circuit, rounds=3):
    """Append `rounds` error-correction rounds to `circuit`.

    Each round is a row pass followed by a column pass, so every round adds
    16 measurement results (4 rows x 2 + 4 columns x 2).
    """
    passes = (error_correction_round_rows, error_correction_round_columns)
    for _round in range(rounds):
        for append_pass in passes:
            append_pass(circuit)

In [59]:
def process_shot(shot_data, rounds, start_index=0):
    """
    Run the classical flag decoder over one shot's measurement record.

    Starting at ``start_index``, the record must contain ``rounds`` blocks of 16
    outcomes, laid out as error_correct_manual() produces them. Each block holds
    four row (X, Z) pairs followed by four column (X, Z) pairs.

    Args:
        shot_data: one shot of the measurement record, e.g. one row of
            ``compile_sampler().sample(...)``, or any sequence of 0/1 or bool.
        rounds: number of error-correction rounds in the record.
        start_index: index of the first error-correction measurement. The
            default of 0 assumes nothing is measured before the first round.

    Returns:
        ``("accept", frameX, frameZ)`` with the accumulated Pauli frames, or
        ``("reject", None, None)`` as soon as any pass cannot be decoded.

    Raises:
        ValueError: if ``start_index`` is negative or ``shot_data`` is too short
            to hold ``rounds`` rounds starting at ``start_index``.
    """
    outcomes_per_pass = 4 * 2                   # 4 rows (or columns) x (X, Z)
    outcomes_per_round = 2 * outcomes_per_pass  # row pass + column pass

    if start_index < 0:
        raise ValueError(f"start_index must be non-negative, got {start_index}")
    needed = start_index + rounds * outcomes_per_round
    if len(shot_data) < needed:
        raise ValueError(
            f"shot has {len(shot_data)} measurements but {needed} are needed "
            f"for {rounds} round(s) starting at index {start_index}"
        )

    flagX = -1
    flagZ = -1
    frameX = np.zeros(16, dtype=np.uint8)
    frameZ = np.zeros(16, dtype=np.uint8)

    # (rule for X outcomes -> Z corrections, rule for Z outcomes -> X corrections),
    # first for the row pass, then for the column pass.
    passes = (
        (correct_row_Z, correct_row_X),
        (correct_column_Z, correct_column_X),
    )

    for r in range(rounds):
        round_start = start_index + r * outcomes_per_round
        for p, (correct_z, correct_x) in enumerate(passes):
            lo = round_start + p * outcomes_per_pass
            outcomes = np.asarray(shot_data[lo : lo + outcomes_per_pass]).tolist()
            measX = outcomes[0::2]  # X results sit at even offsets
            measZ = outcomes[1::2]  # Z results sit at odd offsets

            # Correct Z errors based on X syndromes.
            result = correct_z(flagX, measX, frameZ)
            if isinstance(result, str) and result == "reject":
                return "reject", None, None
            flagX, _, frameZ = result

            # Correct X errors based on Z syndromes.
            result = correct_x(flagZ, measZ, frameX)
            if isinstance(result, str) and result == "reject":
                return "reject", None, None
            flagZ, _, frameX = result

    return "accept", frameX, frameZ


def run_manual_error_correction(circuit, shots, rounds):
    """
    Sample `circuit` `shots` times and run the classical decoder on each shot.

    Prints the accepted and rejected counts and the acceptance rate. A shot
    counts as rejected whenever process_shot() reports anything other than
    "accept". process_shot() is called without an offset, so the
    error-correction measurements must start at the beginning of the record.

    Returns:
        (accept_count, reject_count)
    """
    sampler = circuit.compile_sampler()
    shot_data_all = sampler.sample(shots=shots)

    accept_count = 0
    reject_count = 0
    for shot_data in shot_data_all:
        status, _, _ = process_shot(shot_data, rounds)
        if status == "accept":
            accept_count += 1
        else:
            reject_count += 1

    print(f"Accepted: {accept_count}/{shots}")
    print(f"Rejected: {reject_count}/{shots}")
    if shots > 0:
        print(f"Acceptance Rate: {accept_count/shots:.2%}")
    else:
        # Nothing was sampled; avoid a ZeroDivisionError.
        print("Acceptance Rate: n/a")

    return accept_count, reject_count

In [76]:
# --- Test Simulation ---
# Parameters for the test run.
ROUNDS = 10
SHOTS = 100000
# Noise strength passed to channel(). If channel() forwards it unchanged to
# stim's X_ERROR (as noise_type="X_ERROR" suggests), 0.1 is a 10% flip
# probability, not 0.1%. TODO confirm against channel() in the helper modules.
NOISE_LEVEL = 0.1

# 1. Build the full quantum circuit.
# 18 qubits are requested: 16 data qubits plus the 2 measurement ancillas (16, 17).
# NOTE: encode_manual() additionally uses qubits 18-22 as encoding ancillas.
# TODO confirm init_circuit()/channel() treat those as intended.
test_circuit = init_circuit(qubits=18)

# First, prepare a valid encoded state.
encode_manual(test_circuit)

# Now, apply noise to the encoded state.
channel(test_circuit, NOISE_LEVEL, noise_type="X_ERROR")

# Append the error-correction rounds to the circuit.
error_correct_manual(test_circuit, rounds=ROUNDS)

# 2. Run the simulation and the classical decoder.
print("--- Running Manual Error Correction Simulation (Corrected) ---")
print(f"Rounds: {ROUNDS}, Shots: {SHOTS}, Noise: {NOISE_LEVEL}")
run_manual_error_correction(test_circuit, shots=SHOTS, rounds=ROUNDS)


--- Running Manual Error Correction Simulation (Corrected) ---
Rounds: 10, Shots: 100000, Noise: 0.1
Accepted: 67571/100000
Rejected: 32429/100000
Acceptance Rate: 67.57%


(67571, 32429)

## Results

OK, it's still not working as expected (I think?), but hopefully we're getting there. I've also changed the number of qubits in the circuit to 18 to account for the ancillas.

Running with 0 rounds of error correction:
--- Running Manual Error Correction Simulation (Corrected) ---
Rounds: 0, Shots: 1000, Noise: 0.001
Accepted: 1000/1000
Rejected: 0/1000
Acceptance Rate: 100.00%

Running with 1 round:
--- Running Manual Error Correction Simulation (Corrected) ---
Rounds: 1, Shots: 1000, Noise: 0.001
Accepted: 126/1000
Rejected: 874/1000
Acceptance Rate: 12.60%
Running with 2 rounds:
--- Running Manual Error Correction Simulation (Corrected) ---
Rounds: 2, Shots: 1000, Noise: 0.001
Accepted: 10/1000
Rejected: 990/1000
Acceptance Rate: 1.00%

And this is with the noise channel disabled (no noise).

What could cause this?

## Creating a circuit with the "real" encoder/decoder/error correction

In [None]:
# Build the circuit used by the simulation cells below:
# encode the logical qubits, apply the noise channel, then append three
# rounds of row/column error correction.
circuit = init_circuit()
encode_manual(circuit)
channel(circuit, 0.1)
error_correct_manual(circuit, rounds=3)

# Interactive Crumble view of the circuit (alternatives: print(circuit), circuit.diagram()).
circuit.to_crumble_url()

## Simulations

### Sanity test

In [None]:
# Sanity check of the sampled circuit.
# None of the circuit-building helpers in the cells shown here append DETECTOR
# instructions (append_detector_on_last_n_measurements is defined but never
# called). Unless init_circuit()/channel() add some, the detector sample below
# therefore has zero columns. The raw measurement record is sampled as well so
# that the check actually shows data.
print(f"detectors: {circuit.num_detectors}, measurements: {circuit.num_measurements}")
sampler = circuit.compile_detector_sampler()
print(sampler.sample(shots=10))
print(circuit.compile_sampler().sample(shots=10))

### Flag qubits simulation
See reference: https://quantumcomputing.stackexchange.com/questions/22281/simulating-flag-qubits-and-conditional-branches-using-stim

In [None]:

# Run the whole circuit once on stim's tableau (stabilizer-state) simulator,
# which can be driven step by step for conditional logic, then show the
# measurement outcomes it recorded.
simulator = stim.TableauSimulator()
simulator.do_circuit(circuit)
simulator.current_measurement_record()

### Using FlipSimulator
Since `TableauSimulator` doesn't natively support detection events, we try `FlipSimulator` instead. See here for more details:
https://quantumcomputing.stackexchange.com/questions/34496/sample-detection-events-with-tableausimulator

In [None]:
flipSimulator = stim.FlipSimulator(batch_size=10)
flipSimulator.do(circuit)

# Without DETECTOR instructions in the circuit (none are added by the helpers
# in the cells shown here), this array has no detector entries.
detector_flips = flipSimulator.get_detector_flips()
print(f"detector flips shape: {detector_flips.shape}")

# Per-measurement flips (one entry per measurement and batch instance) carry
# the actual information until detectors are defined.
flipSimulator.get_measurement_flips()