# Imports

In [None]:
#basics
import math
import numpy as np
from tqdm import tqdm
#Quantum circuit simulator
import cirq
import stim
#plotting
import matplotlib.pyplot as plt
#pathing
import sys 
import os
sys.path.append(os.path.abspath(r"..."))
#Decoders
import pymatching as pm
from beliefmatching import BeliefMatching
#Data
import csv
#Mygates
import Shift
import Phase
import SUM
import pickle
import QFT
import Mul
import Id
import Dchannel
import TwoDchannel
import BFChannel

# Useful functions

In [None]:
def calculate_average(numbers):
    """Return the arithmetic mean of ``numbers``, or ``None`` if it is empty.

    ``numbers`` must support truth testing, ``sum`` and ``len`` (e.g. a list
    or tuple).
    """
    if numbers:
        return sum(numbers) / len(numbers)
    return None  # nothing to average

# Save the results to a single CSV file, change these however you like
def save_results_to_csv(output_folder, filename, result2, result3, result5):
    """Write a fresh CSV file holding a header row and one row of results.

    The file ``output_folder/filename`` is opened in write mode, so any
    existing content is replaced. A confirmation message is printed.
    """
    filepath = os.path.join(output_folder, filename)
    rows = (
        ['result2', 'result3', 'result5'],  # header row
        [result2, result3, result5],        # the single data row
    )

    with open(filepath, mode='w', newline='') as handle:
        csv.writer(handle).writerows(rows)

    print(f"Results saved to {filepath}")
    
# Function to append new results to the CSV file, change these however you like
def append_results_to_csv(output_folder, filename, result2, result3, result5):
    """Append one row ``[result2, result3, result5]`` to a CSV file.

    The file ``output_folder/filename`` is opened in append mode, so earlier
    rows are kept. No header is written. A confirmation message is printed.
    """
    filepath = os.path.join(output_folder, filename)
    new_row = [result2, result3, result5]

    with open(filepath, mode='a', newline='') as handle:
        csv.writer(handle).writerow(new_row)

    print(f"New data appended to {filepath}")
    
def load_results(input_folder, filename):
    """Unpickle and return the object stored in ``input_folder/filename``.

    NOTE: ``pickle.load`` can run arbitrary code while loading, so only use
    this on files that come from a trusted source.
    """
    source_path = os.path.join(input_folder, filename)
    with open(source_path, 'rb') as handle:
        return pickle.load(handle)

# Glue code + decoders

The next pieces of code generate the error syndromes of all possible single Pauli errors on the 5 data qudits of dimension d and tabulate the results. This table is used to create the matching graph and the detector error model string. It is a brute-force approach; for larger-distance codes it should be replaced by a function that takes the parity check matrix as input (or by a detector error model function that takes the circuit as input).

In [None]:
# Generate all Pauli errors for a qudit of dimension d
def create_ordered_string_list(d):
    """List the labels of all single-qudit errors on the 5 data qudits.

    The list starts with '' (no error), followed by

    * 'X{a}{b}' for every qudit index a in 0..4 and power b in 1..d-1,
    * 'Z{a}{b}' over the same ranges,
    * 'Y{a}{b}{c}' for every qudit index a and powers b, c in 1..d-1,

    each group ordered by qudit index first and power(s) second.
    """
    qudit_indices = range(5)
    powers = range(1, d)

    labels = ['']
    labels += [f'X{a}{b}' for a in qudit_indices for b in powers]
    labels += [f'Z{a}{b}' for a in qudit_indices for b in powers]
    labels += [f'Y{a}{b}{c}' for a in qudit_indices for b in powers for c in powers]
    return labels

#Function that creates the 5 qudit code circuit with specific errors inserted.
def create_code(d,error): 
    """Build the 5-qudit code circuit with one optional error inserted.

    The circuit uses 9 line qudits of dimension ``d``. It consists of the
    encoding gates, the error selected by ``error``, and one round of parity
    checks in which qudits 5-8 are measured under the keys 'a1', 'b1', 'c1'
    and 'd1'.

    ``error`` is a label such as '', 'X21', 'Z03' or 'Y412': a letter, the
    index of the qudit that is hit, and the power(s) passed to
    ``Shift.Shift`` / ``Phase.Phase``. Each field is read as one character.

    Returns the two-element list ``[qudits, circ]``.
    """
    qudits = list(cirq.LineQid.range(9, dimension=d))
    circ = cirq.Circuit()

    def add_moment(*operations):
        # Every step is appended as its own explicit moment.
        circ.append(cirq.Moment(list(operations)))

    #Encode
    add_moment(*(QFT.Hdag(d).on(qudits[n]) for n in range(4)))
    encoder_steps = (
        ('SUM', 3, 4), ('Hdag', 4), ('Mdag', 3),
        ('SUM', 2, 3), ('SUM', 2, 4), ('Hdag', 3), ('Mdag', 2), ('Mdag', 4),
        ('SUM', 1, 2), ('SUM', 1, 3), ('Hdag', 2), ('Mdag', 1), ('Hdag', 3),
        ('SUM', 0, 1), ('SUM', 0, 2), ('SUM', 0, 4),
        ('Hdag', 2), ('Hdag', 4), ('Mdag', 4), ('Mdag', 1),
    )
    for gate_name, *targets in encoder_steps:
        if gate_name == 'SUM':
            gate = SUM.SUM(d, d)
        elif gate_name == 'Hdag':
            gate = QFT.Hdag(d)
        else:
            gate = Mul.Mdag(d, d - 1)
        add_moment(gate.on(*(qudits[t] for t in targets)))

    #Error
    # X-type error: a shift on the selected qudit
    if 'X' in error:
        hit, power = int(error[1]), int(error[2])
        add_moment(Shift.Shift(d, power).on(qudits[hit]))

    # Z-type error: a phase on the selected qudit
    if 'Z' in error:
        hit, power = int(error[1]), int(error[2])
        add_moment(Phase.Phase(d, power).on(qudits[hit]))

    # Y-type error: a shift followed by a phase on the selected qudit
    if 'Y' in error:
        hit, x_power, z_power = int(error[1]), int(error[2]), int(error[3])
        add_moment(Shift.Shift(d, x_power).on(qudits[hit]))
        add_moment(Phase.Phase(d, z_power).on(qudits[hit]))

    #parity check
    ancillas = qudits[5:9]
    add_moment(*(QFT.H(d).on(ancilla) for ancilla in ancillas))

    for i, ancilla in enumerate(ancillas):
        # The two data qudits that are basis-changed around this check.
        rotated_a, rotated_b = qudits[(i + 2) % 5], qudits[(i + 4) % 5]
        add_moment(QFT.Hdag(d).on(rotated_a), QFT.Hdag(d).on(rotated_b))
        add_moment(SUM.SUM(d, d).on(ancilla, qudits[i]))
        add_moment(SUM.SUMdag(d, d).on(ancilla, qudits[(i + 1) % 5]))
        add_moment(SUM.SUMdag(d, d).on(ancilla, rotated_a))
        add_moment(SUM.SUM(d, d).on(ancilla, rotated_b))
        add_moment(QFT.H(d).on(rotated_a), QFT.H(d).on(rotated_b))

    add_moment(*(QFT.Hdag(d).on(ancilla) for ancilla in ancillas))
    add_moment(*(cirq.measure(ancilla, key=key)
                 for ancilla, key in zip(ancillas, ('a1', 'b1', 'c1', 'd1'))))

    return [qudits,circ]

# Function that connects errors with their error syndromes
def error_mapping(d):
    """Simulate every labelled error once and record the measured syndrome.

    For each label from ``create_ordered_string_list(d)`` the circuit from
    ``create_code`` is simulated and the first value stored under each of the
    measurement keys 'a1', 'b1', 'c1', 'd1' is collected.

    Returns a list of dicts ``{'Error': label, 'Result': [a1, b1, c1, d1]}``
    in the same order as the labels.
    """
    simulator = cirq.Simulator()
    keys = ('a1', 'b1', 'c1', 'd1')

    table = []
    for label in create_ordered_string_list(d):
        _, circuit = create_code(d, label)
        outcome = simulator.simulate(circuit)
        syndrome = [outcome.measurements[key][0] for key in keys]
        table.append({'Error': label, 'Result': syndrome})

    return table

# Function that sorts and converts the error syndromes into 'nodes'. The nodes represent the ancillas they trigger.
def extract_full_fault_ids(d):
    """Reorder the error table and turn every syndrome into detector nodes.

    Steps, applied to the output of ``error_mapping(d)``:

    1. Entries whose syndrome has exactly one non-zero ancilla are moved to
       the front, ordered by value first (1, 2, ...) and ancilla second
       (0..3). Only the first entry found for each such syndrome is moved.
    2. The first entry that is left over is dropped (presumably the
       error-free '' entry, which heads the table - confirm), and the rest
       is appended in its original order.
    3. Every entry gets an 'Index' (its position in the new order) and a
       'Node' list that replaces 'Result': ancilla ``a`` measuring value
       ``v != 0`` becomes node ``a + 4 * (v - 1)``.

    Returns the list of ``{'Error', 'Index', 'Node'}`` dicts.
    """
    remaining = error_mapping(d).copy()

    # Target syndromes: value 1 on ancilla 0..3, then value 2 on ancilla 0..3, ...
    single_node_syndromes = [
        [value if position == slot else 0 for position in range(4)]
        for value in range(1, d + 1)
        for slot in range(4)
    ]

    ordered = []
    for syndrome in single_node_syndromes:
        match = next(
            (pos for pos, item in enumerate(remaining) if item['Result'] == syndrome),
            None,
        )
        if match is not None:
            ordered.append(remaining.pop(match))

    # Discard the first leftover entry, then keep everything else in order.
    if remaining:
        remaining.pop(0)
    ordered.extend(remaining)

    for position, item in enumerate(ordered):
        item['Index'] = position
        item['Node'] = [
            slot + 4 * (value - 1)
            for slot, value in enumerate(item.pop('Result'))
            if value != 0
        ]

    return ordered

In [None]:
#Example: table of error labels and their measured syndromes for d = 2
error_mapping(2)

In [None]:
#Example: the reordered table for d = 2, with 'Index' and 'Node' in place of 'Result'
extract_full_fault_ids(2)

In [None]:
# A function that converts higher dimensional syndrome measurements into longer binary syndromes
def process_list(input_list,d):
    """Expand d-level syndrome values into a longer binary syndrome.

    The input is read in groups of 4 values (one per ancilla). Each group is
    turned into ``d - 1`` binary groups: the k-th one (k = 0 .. d-2) holds a
    1 wherever the measured value equals ``k + 1`` and a 0 elsewhere.

    Example for d = 3: ``[1, 2, 0, 2]`` -> ``[1, 0, 0, 0, 0, 1, 0, 1]``.
    """
    levels = d - 1
    block = 4 * levels

    # First pass: repeat every group of 4 measurements once per level.
    replicated = []
    for start in range(0, len(input_list), 4):
        quartet = input_list[start:start + 4]
        for _ in range(levels):
            replicated.extend(quartet)

    # Second pass: inside each block, the k-th copy flags the value k + 1.
    binary = []
    for start in range(0, len(replicated), block):
        chunk = replicated[start:start + block]
        for level in range(levels):
            copy = chunk[4 * level:4 * (level + 1)]
            binary.extend(1 if value == level + 1 else 0 for value in copy)

    return binary

# A function to XOR error syndromes before feeding them to a decoder. UPDATE: does not work for circuit-level noise
def xor_list(lst,d):
    """XOR every syndrome bit with the bit one block earlier.

    A block is ``4 * (d - 1)`` bits long. The first block is copied
    unchanged; every later bit is replaced by ``lst[i] ^ lst[i - block]``.
    A new list is returned and ``lst`` itself is left untouched.
    """
    block = 4 * (d - 1)
    result = lst[:block]  # the first block has nothing to be compared with
    result.extend(lst[pos] ^ lst[pos - block] for pos in range(block, len(lst)))
    return result

#Updated XOR function that works for circuit-level noise
def xor_check_blocks_with_prev(lst, d):
    """Filter a binary syndrome block by block against the previous block.

    The list is cut into blocks of ``4 * (d - 1)`` bits; a short final block
    is padded with zeros. For each block the XOR with the preceding block is
    formed (the first block is used as is). The block is copied to the output
    when

    * that XOR equals the "check block", i.e. the most recently accepted
      block (all zeros at the start), or
    * it is the final block and the check block is still all zeros.

    An accepted block becomes the new check block. Every other block is
    replaced by zeros in the output and leaves the check block unchanged.
    """
    block_size = 4 * (d - 1)
    zero_block = [0] * block_size
    last_start = len(lst) - block_size

    accepted = zero_block  # the check block
    previous = None
    filtered = []

    for start in range(0, len(lst), block_size):
        current = lst[start:start + block_size]
        if len(current) < block_size:
            current += [0] * (block_size - len(current))

        if previous is None:
            delta = current  # nothing earlier to XOR with
        else:
            delta = [a ^ b for a, b in zip(current, previous)]

        is_final = start >= last_start
        check_is_zero = all(x == 0 for x in accepted)

        if delta == accepted or (is_final and check_is_zero):
            filtered.extend(current)
            accepted = current
        else:
            filtered.extend(zero_block)

        previous = current

    return filtered

In [None]:
# A function that divides the probability on an error across all pauli errors and assigns them their respective weights.
def Dep_weights_MWPM(p,pm,d):
    """Build the list of edge weights used for the matching graph.

    The list holds ``2 * 5 * (d - 1)`` copies of ``log((1 - p) / p)``
    followed by one final entry ``log((1 - pm) / pm)``.
    """
    data_weight = np.log((1 - p) / p)
    measurement_weight = np.log((1 - pm) / pm)
    return [data_weight] * (2 * 5 * (d - 1)) + [measurement_weight]

def get_errors_by_index(fault_data, index_list):
    """Translate fault indices into error labels.

    ``fault_data`` is a list of dicts with 'Index' and 'Error' keys. The
    labels are returned in the order of ``index_list``; indices that do not
    occur in ``fault_data`` are skipped.
    """
    label_of = {}
    for entry in fault_data:
        label_of[entry['Index']] = entry['Error']

    selected = []
    for idx in index_list:
        if idx in label_of:
            selected.append(label_of[idx])

    return selected

#Create the matching graph for the 5 qudit code, this function takes any distribution of weights
def create_matching_graph(d,cycles,id_list,weights):
    """Build the PyMatching graph of the 5-qudit code for ``cycles`` rounds.

    Parameters:
        d: qudit dimension; every cycle owns ``4 * (d - 1)`` detector nodes.
        cycles: number of syndrome-extraction rounds.
        id_list: fault table with 'Index' and 'Node' keys. Entries
            ``0 .. 4*(d-1) - 1`` are used as boundary edges on the node with
            the same position (assumes entry j triggers only node j -
            confirm against the table), entries up to ``2*5*(d-1) - 1`` as
            edges between their two nodes.
        weights: sequence indexed by fault 'Index'; its last element is the
            weight of the edges that link the same node in consecutive cycles.

    Returns:
        The assembled ``pm.Matching`` object.
    """
    merged_graph = pm.Matching()

    NPC = 4*(d-1) #nodes per cycle
    n_data_faults = 2*5 * (d - 1)

    for i in range(cycles):
        offset = i*NPC

        # Faults that trigger a single node: edge from that node to the boundary
        for j in range(NPC):
            fault = id_list[j]['Index']
            merged_graph.add_boundary_edge(j+offset, fault_ids={fault}, weight=weights[fault])

        # Faults that trigger two nodes: edge between them. The weight is
        # looked up with this edge's own fault index, not with the index
        # left over from the boundary loop.
        for k in range(NPC, n_data_faults):
            fault = id_list[k]['Index']
            node_a = id_list[k]['Node'][0]
            node_b = id_list[k]['Node'][1]
            merged_graph.add_edge(offset+node_a, offset+node_b, fault_ids={fault}, weight=weights[fault])

        # Link each node to the same node of the previous cycle (not for the first cycle)
        if i > 0:
            for node in range(NPC):
                merged_graph.add_edge(node + (i-1)*NPC, node + i*NPC, fault_ids={len(id_list)+1}, weight=weights[-1])

    return merged_graph

# Create the stim error model string for the 5 qudit code, this version of the function is specifically made for standard depolarization noise.
# Without an automated way for getting detector error models from qudit circuits this has to be made manually for every noise model.
def create_stim_error_model_string(data_list, d, p,cycles):
    """Assemble the text of a detector error model from the fault table.

    One ``error(0.05) ...`` line is produced per entry of ``data_list``
    (dicts with 'Error', 'Node' and 'Index'):

    * X and Z entries list their nodes as ``D<n>`` followed by ``L<index>``.
    * A 'Yijk' entry is written as the targets of 'Zik', then ``^``, then the
      targets of 'Xij'. A part whose X/Z entry is missing is left empty.

    For ``cycles > 1`` the lines are wrapped in ``repeat <cycles> { ... }``
    and followed by ``shift_detectors 4*(d-1)`` inside the block.

    NOTE(review): ``p`` is not used; the probability written to every line is
    the fixed value 0.05 - confirm whether that is intended.
    """
    # Look-up tables from X / Z labels to their (nodes, index) pair.
    x_lookup = {}
    z_lookup = {}
    for entry in data_list:
        label = entry['Error']
        if 'X' in label:
            x_lookup[label] = (entry['Node'], entry['Index'])
        if 'Z' in label:
            z_lookup[label] = (entry['Node'], entry['Index'])

    def component(nodes, observable):
        # Targets of one X or Z part of a Y error; empty when it is unknown.
        if observable is None:
            return ''
        return ' '.join(f'D{node}' for node in nodes) + f' L{observable}'

    repeated = cycles > 1
    lines = []
    if repeated:
        lines.append(f"repeat {cycles} {{")

    for entry in data_list:
        nodes = entry['Node']
        index = entry['Index']
        label = entry['Error']

        if 'Y' in label:
            # 'Yijk' = qudit i, X power j, Z power k
            qudit, x_power, z_power = label[1], label[2], label[3]
            x_nodes, x_index = x_lookup.get(f'X{qudit}{x_power}', ([], None))
            z_nodes, z_index = z_lookup.get(f'Z{qudit}{z_power}', ([], None))
            targets = f"{component(z_nodes, z_index)} ^ {component(x_nodes, x_index)}"
        else:
            targets = ' '.join(f'D{node}' for node in nodes) + f' L{index}'

        lines.append(f"error(0.05) {targets}")

    if repeated:
        lines.append(f"shift_detectors {4*(d-1)}")
        lines.append("}")

    return "\n".join(lines)

In [None]:
# Correction circuit used to apply the correction on the simulated final state vector to compare whether we correct to the initial state.
def C_circ(errors,d):
    """Build the correction circuit for a list of decoded error labels.

    The circuit acts on 9 line qudits of dimension ``d``. For every label
    that contains 'X' a ``Shift.Shift`` with power ``d - n`` is applied to
    the named qudit, and for every label that contains 'Z' a ``Phase.Phase``
    with power ``d - n``; here the label's second character is the qudit
    index and its third character is ``n``. Other labels add nothing.

    Identity gates are placed on all 9 qudits so that each one appears in
    the circuit.
    """
    qudits = list(cirq.LineQid.range(9, dimension=d))
    correction_circuit = cirq.Circuit()

    # IMPORTANT: all qudits need to be acted on - start with the data qudits
    for data_qudit in qudits[:5]:
        correction_circuit.append(cirq.Moment([Id.I(d).on(data_qudit)]))

    for error in errors:
        # Undo an X error with the complementary shift
        if 'X' in error:
            hit = int(error[1])
            power = d - int(error[2])
            correction_circuit.append(cirq.Moment([Shift.Shift(d, power).on(qudits[hit])]))

        # Undo a Z error with the complementary phase
        if 'Z' in error:
            hit = int(error[1])
            power = d - int(error[2])
            correction_circuit.append(cirq.Moment([Phase.Phase(d, power).on(qudits[hit])]))

    # Ancillas only get identities (a reset can't be used because the state
    # vector is evolved from a given initial state vector)
    for ancilla in qudits[5:9]:
        correction_circuit.append(cirq.Moment([Id.I(d).on(ancilla)]))

    return correction_circuit

# Simulations

In [None]:
#The initial encoded state vector of the 5 qudit code
def Initial_state(d,L=0): 
    """Return the state vector of the encoded 5-qudit code state.

    The encoding gates are applied to 9 line qudits of dimension ``d``. When
    ``L > 0`` a ``Shift.Shift(d, L)`` is additionally applied to each of the
    qudits 0-4. The circuit is simulated with ``complex128`` precision and
    the final state vector is returned.
    """
    qudits = list(cirq.LineQid.range(9, dimension=d))
    circ = cirq.Circuit()

    def add_moment(*operations):
        # Every step is appended as its own explicit moment.
        circ.append(cirq.Moment(list(operations)))

    #encoding
    add_moment(*(QFT.Hdag(d).on(qudits[n]) for n in range(4)))
    encoder_steps = (
        ('SUM', 3, 4), ('Hdag', 4), ('Mdag', 3),
        ('SUM', 2, 3), ('SUM', 2, 4), ('Hdag', 3), ('Mdag', 2), ('Mdag', 4),
        ('SUM', 1, 2), ('SUM', 1, 3), ('Hdag', 2), ('Mdag', 1), ('Hdag', 3),
        ('SUM', 0, 1), ('SUM', 0, 2), ('SUM', 0, 4),
        ('Hdag', 2), ('Hdag', 4), ('Mdag', 4), ('Mdag', 1),
    )
    for gate_name, *targets in encoder_steps:
        if gate_name == 'SUM':
            gate = SUM.SUM(d, d)
        elif gate_name == 'Hdag':
            gate = QFT.Hdag(d)
        else:
            gate = Mul.Mdag(d, d - 1)
        add_moment(gate.on(*(qudits[t] for t in targets)))

    # A shift by d on each of the qudits 5-8 (presumably a no-op that only
    # makes these qudits part of the simulated circuit - confirm)
    for ancilla in qudits[5:9]:
        add_moment(Shift.Shift(d, d).on(ancilla))

    if L>0:
        for data_qudit in qudits[:5]:
            add_moment(Shift.Shift(d, L).on(data_qudit))

    simulator = cirq.Simulator(dtype = np.complex128)
    return simulator.simulate(circ).final_state_vector

#Comparing whether state vectors are the same. This is way faster than computing fidelities and basically does the same thing.
def compareStateVectors(v1, v2):
    """Tell whether two state vectors agree up to a common scalar factor.

    The entry where ``|v1 - v2|`` is largest decides the outcome:

    * if the difference there is exactly 0, the vectors are equal -> True;
    * if ``v2`` is exactly 0 there, -> False;
    * otherwise the ratio ``v1 / v2`` at that entry is taken as the factor
      and ``v1`` is compared with ``factor * v2`` through ``np.allclose``
      (rtol = atol = 1e-5).

    Raises ``ValueError`` when the vectors differ in length.
    """
    if len(v1) != len(v2):
        raise ValueError("vectors must have equal length")

    deviation = np.abs(v1 - v2)
    worst = np.argmax(deviation)

    if deviation[worst] == 0:
        return True
    if v2[worst] == 0:
        return False

    factor = v1[worst] / v2[worst]
    return np.allclose(v1, factor * v2, rtol=1e-5, atol=1e-5)

In [None]:
# Quantum circuit for the 5 qudit code with standard depolarization noise
def dep_circ(cycles,p,d): 
    """Build the 5-qudit code circuit with depolarizing noise per cycle.

    After the encoding gates, each of the ``cycles`` rounds applies the
    channel ``Dchannel.depolarizeQudit(p, d)`` to the qudits 0-4, runs the
    four parity checks, measures the qudits 5-8 under the keys 'a<i>',
    'b<i>', 'c<i>', 'd<i>' (i = 1 .. cycles) and resets them.

    Returns the two-element list ``[qudits, circ]``.
    """
    Dep = Dchannel.depolarizeQudit(p,d)
    qudits = list(cirq.LineQid.range(9, dimension=d))
    data, ancillas = qudits[:5], qudits[5:9]
    circ = cirq.Circuit()

    def add_moment(*operations):
        # Every step is appended as its own explicit moment.
        circ.append(cirq.Moment(list(operations)))

    #Encode
    add_moment(*(QFT.Hdag(d).on(qudits[n]) for n in range(4)))
    encoder_steps = (
        ('SUM', 3, 4), ('Hdag', 4), ('Mdag', 3),
        ('SUM', 2, 3), ('SUM', 2, 4), ('Hdag', 3), ('Mdag', 2), ('Mdag', 4),
        ('SUM', 1, 2), ('SUM', 1, 3), ('Hdag', 2), ('Mdag', 1), ('Hdag', 3),
        ('SUM', 0, 1), ('SUM', 0, 2), ('SUM', 0, 4),
        ('Hdag', 2), ('Hdag', 4), ('Mdag', 4), ('Mdag', 1),
    )
    for gate_name, *targets in encoder_steps:
        if gate_name == 'SUM':
            gate = SUM.SUM(d, d)
        elif gate_name == 'Hdag':
            gate = QFT.Hdag(d)
        else:
            gate = Mul.Mdag(d, d - 1)
        add_moment(gate.on(*(qudits[t] for t in targets)))

    for cycle in range(1, cycles + 1):
        # Noise on all five data qudits at the start of the round
        add_moment(*(Dep.on(data_qudit) for data_qudit in data))

        #parity check
        add_moment(*(QFT.H(d).on(ancilla) for ancilla in ancillas))

        for j, ancilla in enumerate(ancillas):
            # The two data qudits that are basis-changed around this check.
            rotated_a, rotated_b = qudits[(j + 2) % 5], qudits[(j + 4) % 5]
            add_moment(QFT.Hdag(d).on(rotated_a), QFT.Hdag(d).on(rotated_b))
            add_moment(SUM.SUM(d, d).on(ancilla, qudits[j]))
            add_moment(SUM.SUMdag(d, d).on(ancilla, qudits[(j + 1) % 5]))
            add_moment(SUM.SUMdag(d, d).on(ancilla, rotated_a))
            add_moment(SUM.SUM(d, d).on(ancilla, rotated_b))
            add_moment(QFT.H(d).on(rotated_a), QFT.H(d).on(rotated_b))

        add_moment(*(QFT.Hdag(d).on(ancilla) for ancilla in ancillas))
        add_moment(*(cirq.measure(ancilla, key=f'{letter}{cycle}')
                     for letter, ancilla in zip('abcd', ancillas)))
        add_moment(*(cirq.reset(ancilla) for ancilla in ancillas))

    return [qudits,circ]

In [None]:
# Function to simulate and extract logical error rates. Circuit simulation --> error syndromes --> decoding by two decoders --> correction--> logical error rates
def Simulate(circ,cycles,d,samples,id_list,p):
    """Sample the circuit, decode each shot with two decoders and score them.

    For each of the ``samples`` shots:

    1. ``circ[1]`` is simulated and the measurements 'a<i>' .. 'd<i>' of all
       cycles are collected;
    2. they are expanded with ``process_list`` and differenced with
       ``xor_list``;
    3. the syndrome is decoded by BeliefMatching (built from the detector
       error model string) and by the matching graph;
    4. the predicted faults are translated to error labels, the matching
       ``C_circ`` correction is applied to the final state vector, and the
       result is compared with ``Initial_state(d)``.

    Returns ``(average for BeliefMatching, average for MWPM)``, where each
    average is taken over the per-shot comparison results (``None`` for both
    when ``samples`` is 0).
    """
    from beliefmatching import BeliefMatching

    #Initialize
    outcomes_bm = []
    outcomes_mwpm = []
    sim = cirq.Simulator(dtype=np.complex64)
    correct_state = Initial_state(d)

    #Initialize decoders
    dem_text = create_stim_error_model_string(id_list, d, p,cycles)
    bp_decoder = BeliefMatching(stim.DetectorErrorModel(f"""{dem_text}"""), max_bp_iters=30)
    mwpm_graph = create_matching_graph(d,cycles,id_list,Dep_weights_MWPM(p,0.00000000001,d))

    # Measurement keys in the order they are read: a1, b1, c1, d1, a2, ...
    keys = [f'{letter}{cycle}'
            for cycle in range(1, cycles + 1)
            for letter in ['a', 'b', 'c', 'd']]

    for _ in tqdm(range(samples), desc="Simulating", unit="sample"):

        #Sample syndromes and store the state vector
        result = sim.simulate(circ[1])
        measurements = [result.measurements[key][0] for key in keys]
        syndrome = xor_list(process_list(measurements,d),d)
        rho = result.final_state_vector

        #Decoding
        prediction_bm = bp_decoder.decode(np.array(syndrome))
        prediction_mwpm = mwpm_graph.decode(syndrome)
        flagged_bm = [pos for pos, bit in enumerate(prediction_bm) if bit == 1]
        flagged_mwpm = [pos for pos, bit in enumerate(prediction_mwpm) if bit == 1]
        errors_bm = get_errors_by_index(id_list, flagged_bm)
        errors_mwpm = get_errors_by_index(id_list, flagged_mwpm)

        #Correction and comparison with the encoded state
        corrected_bm = cirq.final_state_vector(program=C_circ(errors_bm,d), initial_state=rho)
        corrected_mwpm = cirq.final_state_vector(program=C_circ(errors_mwpm,d), initial_state=rho)
        outcomes_mwpm.append(compareStateVectors(corrected_mwpm, correct_state))
        outcomes_bm.append(compareStateVectors(corrected_bm, correct_state))

    return calculate_average(outcomes_bm), calculate_average(outcomes_mwpm)

# Update: introduction of a flag qudit to correct hook errors

In [None]:
#New function that stores the error rates for flag qudit simulations
def append_results_to_csv_flag(output_folder, filename, result, p, samples):
    """Append ``result[0] / samples`` as a new one-value row of a CSV file.

    The value is written with 12 decimals. ``p`` is only used in the printed
    confirmation. Any exception raised while writing is caught and reported
    with a printed message instead of being propagated.
    """
    new_value = result[0] / (samples)
    filepath = os.path.join(output_folder, filename)

    try:
        with open(filepath, mode='a', newline='') as handle:
            formatted = f"{new_value:.12f}"
            csv.writer(handle).writerow([formatted])
        print(f"Appended {formatted} for p={p} to {filepath}")
    except Exception as e:
        print(f"Error appending to {filepath}: {e}")

#Initial state vector to use when a flag qudit is part of the circuitry
def Initial_state_Flag(d,L=0): 
    """Return the reference state vector for circuits with a flag qudit.

    The circuit acts on 10 line qudits of dimension ``d``: the encoding
    gates, one round of the four parity checks (without measurements), and a
    ``Shift.Shift(d, d)`` on qudit 9. When ``L > 0`` a ``Shift.Shift(d, L)``
    is additionally applied to each of the qudits 0-4. The circuit is
    simulated with ``complex128`` precision and the final state vector is
    returned.
    """
    qudits = list(cirq.LineQid.range(10, dimension=d))
    ancillas = qudits[5:9]
    circ = cirq.Circuit()

    def add_moment(*operations):
        # Every step is appended as its own explicit moment.
        circ.append(cirq.Moment(list(operations)))

    #encoding
    add_moment(*(QFT.Hdag(d).on(qudits[n]) for n in range(4)))
    encoder_steps = (
        ('SUM', 3, 4), ('Hdag', 4), ('Mdag', 3),
        ('SUM', 2, 3), ('SUM', 2, 4), ('Hdag', 3), ('Mdag', 2), ('Mdag', 4),
        ('SUM', 1, 2), ('SUM', 1, 3), ('Hdag', 2), ('Mdag', 1), ('Hdag', 3),
        ('SUM', 0, 1), ('SUM', 0, 2), ('SUM', 0, 4),
        ('Hdag', 2), ('Hdag', 4), ('Mdag', 4), ('Mdag', 1),
    )
    for gate_name, *targets in encoder_steps:
        if gate_name == 'SUM':
            gate = SUM.SUM(d, d)
        elif gate_name == 'Hdag':
            gate = QFT.Hdag(d)
        else:
            gate = Mul.Mdag(d, d - 1)
        add_moment(gate.on(*(qudits[t] for t in targets)))

    # One round of parity checks
    add_moment(*(QFT.H(d).on(ancilla) for ancilla in ancillas))

    for i, ancilla in enumerate(ancillas):
        # The two data qudits that are basis-changed around this check.
        rotated_a, rotated_b = qudits[(i + 2) % 5], qudits[(i + 4) % 5]
        add_moment(QFT.Hdag(d).on(rotated_a), QFT.Hdag(d).on(rotated_b))
        add_moment(SUM.SUM(d, d).on(ancilla, qudits[i]))
        add_moment(SUM.SUMdag(d, d).on(ancilla, qudits[(i + 1) % 5]))
        add_moment(SUM.SUMdag(d, d).on(ancilla, rotated_a))
        add_moment(SUM.SUM(d, d).on(ancilla, rotated_b))
        add_moment(QFT.H(d).on(rotated_a), QFT.H(d).on(rotated_b))

    add_moment(*(QFT.Hdag(d).on(ancilla) for ancilla in ancillas))

    # A shift by d on qudit 9 (presumably a no-op that only makes the flag
    # qudit part of the simulated circuit - confirm)
    add_moment(Shift.Shift(d, d).on(qudits[9]))

    if L>0:
        for data_qudit in qudits[:5]:
            add_moment(Shift.Shift(d, L).on(data_qudit))

    simulator = cirq.Simulator(dtype = np.complex128)
    return simulator.simulate(circ).final_state_vector

def C_circ_Flag(errors,d):
    """Build the correction circuit for the 10-qudit (flag) layout.

    For each label in ``errors`` (second character = qudit index, following
    characters = powers):

    * 'Z' labels add ``Phase.Phase(d, d - n)``,
    * 'X' labels add ``Shift.Shift(d, d - n)``,
    * 'Y' labels add ``Phase.Phase(d, d - n2)`` followed by
      ``Shift.Shift(d, d - n)``, with ``n`` the third and ``n2`` the fourth
      character of the label.

    Identity gates are placed on all 10 qudits so that each one appears in
    the circuit.
    """
    qudits = list(cirq.LineQid.range(10, dimension=d))
    correction_circuit = cirq.Circuit()

    def add_moment(operation):
        correction_circuit.append(cirq.Moment([operation]))

    # IMPORTANT: all qudits need to be acted on - data qudits first, then qudit 9
    for position in (0, 1, 2, 3, 4, 9):
        add_moment(Id.I(d).on(qudits[position]))

    for error in errors:

        # Undo a Z error with the complementary phase
        if 'Z' in error:
            hit = int(error[1])
            power = d - int(error[2])
            add_moment(Phase.Phase(d, power).on(qudits[hit]))

        # Undo an X error with the complementary shift
        if 'X' in error:
            hit = int(error[1])
            power = d - int(error[2])
            add_moment(Shift.Shift(d, power).on(qudits[hit]))

        # Undo a Y error: complementary phase first, then complementary shift
        if 'Y' in error:
            hit = int(error[1])
            shift_power = d - int(error[2])
            phase_power = d - int(error[3])
            add_moment(Phase.Phase(d, phase_power).on(qudits[hit]))
            add_moment(Shift.Shift(d, shift_power).on(qudits[hit]))

    # Ancillas only get identities (a reset can't be used because the state
    # vector is evolved from a given initial state vector)
    for ancilla in qudits[5:9]:
        add_moment(Id.I(d).on(ancilla))

    return correction_circuit

#An example of a circuit for the 5-qudit code with circuit-level noise and an extra flag qudit.
def Real_circ(cycles,p,d): 
    """Build the flagged 5-qudit-code circuit with circuit-level noise.

    The circuit consists of an encoding step on qudits 0-4, followed by
    ``cycles - 1`` noisy syndrome-extraction rounds and one final round
    that contains no noise channels and no flag qudit.

    Args:
        cycles: total number of syndrome rounds.  Round ``i`` (1-based)
            stores its four ancilla outcomes under the measurement keys
            ``a{i}``, ``b{i}``, ``c{i}`` and ``d{i}``.
        p: parameter passed to every noise-channel constructor
            (single-qudit, two-qudit, idle and measurement channels).
        d: dimension of every qudit in the circuit.

    Returns:
        A two-element list ``[qudits, circ]``: the ten ``cirq.LineQid``
        objects and the assembled ``cirq.Circuit``.
    """
    # Noise channels; all of them are built from the same p and d.
    # Idle uses the same constructor as the single-qudit gate error.
    OneGErr = Dchannel.depolarizeQudit(p,d)
    TwoGErr = TwoDchannel.depolarizeTwoQudit(p,d)
    Idle = Dchannel.depolarizeQudit(p,d)
    MErr = BFChannel.BFd(p,d)
    # Qudit roles as used below: 0-4 are acted on by the encoder,
    # 5-8 are the ancillas measured under keys a-d,
    # 9 is the qudit measured under the 'flag...' keys.
    qudits = []
    q0, q1, q2, q3, q4, q5, q6, q7, q8, q9 = cirq.LineQid.range(10, dimension=d)
    qudits.append(q0)
    qudits.append(q1)
    qudits.append(q2)
    qudits.append(q3)
    qudits.append(q4)
    qudits.append(q5)
    qudits.append(q6)
    qudits.append(q7)
    qudits.append(q8)
    qudits.append(q9)
    circ = cirq.Circuit()

    #Encode (no noise channels are inserted in this part)
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[0]),QFT.Hdag(d).on(qudits[1]),QFT.Hdag(d).on(qudits[2]),QFT.Hdag(d).on(qudits[3])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[3],qudits[4])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[4])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[3])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[2],qudits[3])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[2],qudits[4])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[3])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[2])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[4])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[1],qudits[2])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[1],qudits[3])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[2])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[1])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[3])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[0],qudits[1])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[0],qudits[2])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[0],qudits[4])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[2])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[4])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[4])]))
    circ.append(cirq.Moment([Mul.Mdag(d,d-1).on(qudits[1])]))
    
    # Noisy rounds 1 .. cycles-1.  Every two-qudit gate is preceded by a
    # TwoGErr moment on the same pair of qudits.  The flag qudit is measured
    # four times per round; its keys are numbered consecutively across rounds,
    # so round i uses flag{4*(i-1)} .. flag{4*(i-1)+3}.
    for i in range(1,cycles):
        cycle = f'{i}'
        #parity check
        circ.append(cirq.Moment([Idle.on(qudits[0]), Idle.on(qudits[1]),OneGErr.on(qudits[2]),Idle.on(qudits[3]),OneGErr.on(qudits[4]),OneGErr.on(qudits[5]),OneGErr.on(qudits[6]),OneGErr.on(qudits[7]),OneGErr.on(qudits[8])]))
        #Gates
        circ.append(cirq.Moment([QFT.H(d).on(qudits[5]),QFT.H(d).on(qudits[6]),QFT.H(d).on(qudits[7]),QFT.H(d).on(qudits[8]),QFT.Hdag(d).on(qudits[2]), QFT.Hdag(d).on(qudits[4])]))
        #Error time slice
        
        #Two qubit gates + Error slices
        # Ancilla 5 with data 0, 1, 2, 4.  The flag qudit 9 is coupled in after
        # the first data gate and coupled again before the last one, then
        # measured (after a measurement-error channel) and reset.
        circ.append(cirq.Moment([TwoGErr.on(qudits[0],qudits[5])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[5], qudits[0])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[5])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[5],qudits[9])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[1],qudits[5])]))        
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[5], qudits[1])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[2],qudits[5])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[5], qudits[2])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[5])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[5],qudits[9])]))
        circ.append(cirq.Moment([MErr.on(qudits[9])]))
        circ.append(cirq.Moment([cirq.measure(qudits[9], key = f'flag{(i-1)*4+0}')])) 
        circ.append(cirq.Moment([cirq.reset(qudits[9])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[4],qudits[5])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[5], qudits[4])]))
        
        # Ancilla 6 with data 1, 2, 3, 0 (same flag pattern as above).
        # The first data gate shares its moment with the single-qudit layer.
        circ.append(cirq.Moment([Idle.on(qudits[1]),Idle.on(qudits[2]),Idle.on(qudits[4])]))
        circ.append(cirq.Moment([OneGErr.on(qudits[0]), TwoGErr.on(qudits[1],qudits[6]),OneGErr.on(qudits[2]),OneGErr.on(qudits[3]),OneGErr.on(qudits[4])]))
        circ.append(cirq.Moment([QFT.H(d).on(qudits[2]), QFT.H(d).on(qudits[4]),QFT.Hdag(d).on(qudits[3]), QFT.Hdag(d).on(qudits[0]),SUM.SUM(d,d).on(qudits[6], qudits[1])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[6])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[6],qudits[9])]))
        
        circ.append(cirq.Moment([TwoGErr.on(qudits[2],qudits[6])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[6], qudits[2])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[3],qudits[6])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[6], qudits[3])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[6])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[6],qudits[9])]))
        circ.append(cirq.Moment([MErr.on(qudits[9])]))
        circ.append(cirq.Moment([cirq.measure(qudits[9], key = f'flag{(i-1)*4+1}')])) 
        circ.append(cirq.Moment([cirq.reset(qudits[9])]))  
        circ.append(cirq.Moment([TwoGErr.on(qudits[0],qudits[6])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[6], qudits[0])]))
        
        # Ancilla 7 with data 2, 3, 4, 1 (same flag pattern as above).
        circ.append(cirq.Moment([OneGErr.on(qudits[0]), OneGErr.on(qudits[1]),TwoGErr.on(qudits[2],qudits[7]),OneGErr.on(qudits[3]),OneGErr.on(qudits[4])]))
        circ.append(cirq.Moment([QFT.H(d).on(qudits[3]), QFT.H(d).on(qudits[0]),QFT.Hdag(d).on(qudits[4]), QFT.Hdag(d).on(qudits[1]),SUM.SUM(d,d).on(qudits[7], qudits[2])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[7])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[7],qudits[9])]))
        
        circ.append(cirq.Moment([TwoGErr.on(qudits[3],qudits[7])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[7], qudits[3])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[4],qudits[7])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[7], qudits[4])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[7])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[7],qudits[9])]))
        circ.append(cirq.Moment([MErr.on(qudits[9])]))
        circ.append(cirq.Moment([cirq.measure(qudits[9], key = f'flag{(i-1)*4+2}')])) 
        circ.append(cirq.Moment([cirq.reset(qudits[9])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[1],qudits[7])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[7], qudits[1])]))
        
        # Ancilla 8 with data 3, 4, 0, 2 (same flag pattern as above).
        circ.append(cirq.Moment([Idle.on(qudits[0]),Idle.on(qudits[3]),Idle.on(qudits[4])]))
        circ.append(cirq.Moment([OneGErr.on(qudits[0]), OneGErr.on(qudits[1]),OneGErr.on(qudits[2]),TwoGErr.on(qudits[3],qudits[8]),OneGErr.on(qudits[4])]))
        circ.append(cirq.Moment([QFT.H(d).on(qudits[4]), QFT.H(d).on(qudits[1]),QFT.Hdag(d).on(qudits[0]), QFT.Hdag(d).on(qudits[2]),SUM.SUM(d,d).on(qudits[8], qudits[3])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[8])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[8],qudits[9])]))
        
        circ.append(cirq.Moment([TwoGErr.on(qudits[4],qudits[8])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[8], qudits[4])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[0],qudits[8])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[8], qudits[0])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[9],qudits[8])]))
        circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[8],qudits[9])]))
        circ.append(cirq.Moment([MErr.on(qudits[9])]))
        circ.append(cirq.Moment([cirq.measure(qudits[9], key = f'flag{(i-1)*4+3}')])) 
        circ.append(cirq.Moment([cirq.reset(qudits[9])]))
        circ.append(cirq.Moment([TwoGErr.on(qudits[2],qudits[8])]))
        circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[8], qudits[2])]))
        
        #One qubit gates + idles
        circ.append(cirq.Moment([OneGErr.on(qudits[0]), Idle.on(qudits[1]),OneGErr.on(qudits[2]),Idle.on(qudits[3]),Idle.on(qudits[4]),OneGErr.on(qudits[5]),OneGErr.on(qudits[6]),OneGErr.on(qudits[7]),OneGErr.on(qudits[8])]))
        circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[5]),QFT.Hdag(d).on(qudits[6]),QFT.Hdag(d).on(qudits[7]),QFT.Hdag(d).on(qudits[8]),QFT.H(d).on(qudits[0]), QFT.H(d).on(qudits[2])]))
        
        
        #Measurement error
        circ.append(cirq.Moment([MErr.on(qudits[5]), MErr.on(qudits[6]),MErr.on(qudits[7]),MErr.on(qudits[8]),Idle.on(qudits[0]), Idle.on(qudits[1]),Idle.on(qudits[2]),Idle.on(qudits[3]),Idle.on(qudits[4])]))
        #Measurements + idle
        circ.append(cirq.Moment([cirq.measure(qudits[5], key = 'a'+cycle),cirq.measure(qudits[6], key = 'b'+cycle),cirq.measure(qudits[7], key =  'c'+cycle),cirq.measure(qudits[8], key =  'd'+cycle)]))
        circ.append(cirq.Moment([cirq.reset(qudits[5]),cirq.reset(qudits[6]),cirq.reset(qudits[7]),cirq.reset(qudits[8])]))
   
    # Final round (number `cycles`): the same ancilla-data gates as in the loop,
    # but with no noise channels and no flag-qudit gates or measurements.
    cycle = f'{cycles}'
    #parity check
    #Gates
    circ.append(cirq.Moment([QFT.H(d).on(qudits[5]),QFT.H(d).on(qudits[6]),QFT.H(d).on(qudits[7]),QFT.H(d).on(qudits[8]),QFT.Hdag(d).on(qudits[2]), QFT.Hdag(d).on(qudits[4])]))
    #Error time slice

    #Two qubit gates + Error slices
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[5], qudits[0])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[5], qudits[1])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[5], qudits[2])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[5], qudits[4])]))
    circ.append(cirq.Moment([QFT.H(d).on(qudits[2]), QFT.H(d).on(qudits[4]),QFT.Hdag(d).on(qudits[3]), QFT.Hdag(d).on(qudits[0]),SUM.SUM(d,d).on(qudits[6], qudits[1])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[6], qudits[2])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[6], qudits[3])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[6], qudits[0])]))
    circ.append(cirq.Moment([QFT.H(d).on(qudits[3]), QFT.H(d).on(qudits[0]),QFT.Hdag(d).on(qudits[4]), QFT.Hdag(d).on(qudits[1]),SUM.SUM(d,d).on(qudits[7], qudits[2])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[7], qudits[3])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[7], qudits[4])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[7], qudits[1])]))
    circ.append(cirq.Moment([QFT.H(d).on(qudits[4]), QFT.H(d).on(qudits[1]),QFT.Hdag(d).on(qudits[0]), QFT.Hdag(d).on(qudits[2]),SUM.SUM(d,d).on(qudits[8], qudits[3])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[8], qudits[4])]))
    circ.append(cirq.Moment([SUM.SUMdag(d,d).on(qudits[8], qudits[0])]))
    circ.append(cirq.Moment([SUM.SUM(d,d).on(qudits[8], qudits[2])]))
    circ.append(cirq.Moment([QFT.Hdag(d).on(qudits[5]),QFT.Hdag(d).on(qudits[6]),QFT.Hdag(d).on(qudits[7]),QFT.Hdag(d).on(qudits[8]),QFT.H(d).on(qudits[0]), QFT.H(d).on(qudits[2])]))
    #Measurements
    circ.append(cirq.Moment([cirq.measure(qudits[5], key = 'a'+cycle),cirq.measure(qudits[6], key = 'b'+cycle),cirq.measure(qudits[7], key =  'c'+cycle),cirq.measure(qudits[8], key =  'd'+cycle)]))
    circ.append(cirq.Moment([cirq.reset(qudits[5]),cirq.reset(qudits[6]),cirq.reset(qudits[7]),cirq.reset(qudits[8])]))       
   
    # Callers index the result: element 0 is the qudit list, element 1 the circuit.
    return [qudits,circ]

#A function that reads the lookup-tables to correct hook errors
def find_correction_from_flags(circuit_dict, flags_input, measurements_input):
    """Look up the correction stored for a flag/measurement pattern.

    The values of ``circuit_dict`` are scanned in iteration order.  The
    first entry whose ``'flags'`` and ``'measurements'`` both equal the
    given inputs supplies the result: its ``'correction'`` value, or an
    empty list when that key is absent.  An empty list is also returned
    when no entry matches.
    """
    # Lazy generator: the scan stops at the first matching entry.
    candidates = (
        entry.get('correction', [])
        for entry in circuit_dict.values()
        if entry['flags'] == flags_input
        and entry['measurements'] == measurements_input
    )
    return next(candidates, [])

#Updated function to simulate certain quantum circuits with a flag qudit. Note that it now returns the number of errors instead of the fidelity.
def Simulate_Flag(circ,cycles,d,samples,id_list,p,hook_map):
    """Sample a flagged circuit and count the shots that stay uncorrected.

    Each sample simulates the circuit once and decodes the syndrome
    outcomes with BeliefMatching.  When any flag outcome is non-zero, a
    second correction is looked up in ``hook_map``.  A sample counts as an
    error only when ``compareStateVectors`` returns a falsy value for both
    corrected states compared against ``Initial_state_Flag(d)``.

    Args:
        circ: ``[qudits, circuit]`` pair as returned by ``Real_circ``; only
            ``circ[1]`` is simulated.
        cycles: number of syndrome rounds.  Ancilla outcomes are read from
            the keys ``a{i}``, ``b{i}``, ``c{i}``, ``d{i}`` for
            ``i = 1 .. cycles`` and flag outcomes from ``flag0`` up to
            ``flag{4*(cycles-1)-1}``.
        d: qudit dimension.
        samples: number of independent simulations to run.
        id_list: fault identifiers used to build the detector error model
            and to translate the decoder output into corrections.
        p: error probability passed to the error-model builder.
        hook_map: lookup table handed to ``find_correction_from_flags``.

    Returns:
        Tuple ``(errors, flags)``: the number of samples left uncorrected
        and the number of samples in which at least one flag was raised.
    """
    # Measurement keys are the same for every sample, so build them once.
    syndrome_keys = [
        f'{letter}{i}'
        for i in range(1, cycles + 1)
        for letter in ['a', 'b', 'c', 'd']
    ]
    # Real_circ measures the flag qudit four times in each of its
    # cycles - 1 noisy rounds and not at all in the final round, so the
    # number of flag keys depends on cycles (8 for cycles == 3).
    flag_keys = [f'flag{k}' for k in range(4 * (cycles - 1))]

    # Initialize counters and simulator
    errors = 0
    flags = 0
    flags_corrected = 0

    sim = cirq.Simulator(dtype=np.complex64)
    correct_state = Initial_state_Flag(d)

    # Initialize the decoder from the detector error model
    model_string = create_stim_error_model_string(id_list, d, p,cycles)
    model = stim.DetectorErrorModel(f"""{model_string}""")
    bmD = BeliefMatching(model, max_bp_iters=50)

    for _ in range(samples):

        # Sample one run and extract the recorded outcomes
        result = sim.simulate(circ[1])
        measurements = [result.measurements[key][0] for key in syndrome_keys]
        flagsmeas = [result.measurements[key][0] for key in flag_keys]

        extended_meas = process_list(measurements,d)
        measXOR = xor_check_blocks_with_prev(extended_meas,d)
        rho = result.final_state_vector

        # Decode with BeliefMatching and apply the correction to the final state
        decodingBM = bmD.decode(np.array(measXOR))
        CposBM = [index for index, value in enumerate(decodingBM) if value == 1]
        errorBM = get_errors_by_index(id_list, CposBM)
        final_state_vectorBM = cirq.final_state_vector(program=C_circ_Flag(errorBM,d), initial_state=rho)
        fidelityBM = compareStateVectors(final_state_vectorBM, correct_state)

        # Only flagged samples get the additional lookup-table correction
        fidelityHook = False
        if sum(flagsmeas)>0:
            hook = find_correction_from_flags(hook_map, flagsmeas,measurements)
            hook_state_vector = cirq.final_state_vector(program=C_circ_Flag(hook,d), initial_state=rho)
            fidelityHook = compareStateVectors(hook_state_vector, correct_state)
            if fidelityHook or fidelityBM:
                flags_corrected+=1
            flags+=1

        # A sample fails only if neither correction succeeded
        if (not fidelityBM and not fidelityHook):
            errors += 1

    if flags>0:
        print(f'flags: {flags}')
        print(f'flags corrected: {flags_corrected}')

    return errors,flags

# Execute

In [None]:
#Execute a simple standard depolarization noise model simulation
output_folder = r"..."

#Load or calculate the following lists
id_list2 = extract_full_fault_ids(2)
# id_list3 = extract_full_fault_ids(3)
# id_list5 = extract_full_fault_ids(5)
# id_listd = load_results(fault_id_folder, f'id_list{d}.pkl')

# Define parameters
samples = 10
cycles = 3
p = 0.01

# Create circuits for different dimensions
circ2 = dep_circ(cycles,p,2)
# circ3 = dep_circ(cycles,p,3)
# circ5 = dep_circ(cycles,p,5)

result2 = Simulate(circ2, cycles, 2, samples, id_list2, p)
# result3 = Simulate(circ3, cycles, 3, samples, id_list3, p)
# result5 = Simulate(circ5, cycles, 5, samples, id_list5, p)

# Bare expression so the notebook displays the result of this cell
result2
#Perform this for multiple 'p's to get logical error rate plots as a function of physical error rates.
# NOTE(review): judging from the file names below, index 1 presumably holds the MWPM
# result and index 0 the BeliefMatching result -- confirm against Simulate.
# save_results_to_csv(output_folder, f'TutorialMWPM{p}.csv', result2[1], result3[1], result5[1])
# save_results_to_csv(output_folder, f'TutorialBM{p}.csv', result2[0], result3[0], result5[0])
# append_results_to_csv(output_folder, f'TutorialMWPM{p}.csv', result2[1], result3[1], result5[1])
# append_results_to_csv(output_folder, f'TutorialBM{p}.csv', result2[0], result3[0], result5[0])

In [None]:
#Execute a circuit-level noise model simulation with flag qudits to correct hook errors
output_folder = r"..."
input_folder = r"C:\Users\Keppen32\OneDrive - imec\Documents\Github\5QuditCodeTutorial"

# Define parameters
d=2
samples = 10
cycles = 3
#circuit-level noise now, per step.
p = 0.001

# Create the circuit for qudit dimension d
circd = Real_circ(cycles,p,d)

#load or calculate these lists/dictionaries
# The fault list must belong to the same dimension d as the circuit, so it is
# computed here from d instead of reusing a list left over from another cell.
# Alternative: id_listd = load_results(fault_id_folder, f'id_list{d}.pkl')
id_listd = extract_full_fault_ids(d)
# load_results unpickles the file; only point it at files you trust.
hook_mapd = load_results(input_folder, f'hook_map{d}.pkl')

result = Simulate_Flag(circd, cycles, d, samples, id_listd, p,hook_mapd)

#Perform this for multiple 'p's to get logical error rate plots as a function of physical error rates.
append_results_to_csv_flag(output_folder, '....csv', result,p, samples)