In [1]:
%load_ext autoreload
%autoreload 2

import stim
import numpy as np
# import matplotlib.pyplot as plt
from stim_utils import *
from stim_simulation_utils import *
from circuit_commons import *
print(stim.__version__)

1.14.0


In [2]:
import importlib.util
import sys


def reload_file(name, path):
    """Load (or re-load) a module from a source file and register it in ``sys.modules``.

    Parameters:
    - name: module name to register under ``sys.modules[name]``.
    - path: path of the source file to execute.

    Returns:
    - The freshly executed module object.

    Raises:
    - ImportError if no loader can be derived for ``path`` (e.g. unsupported suffix).
    - Any exception raised while executing the module body. In that case the
      previous ``sys.modules[name]`` entry (if any) is restored, as the regular
      import system would do, instead of leaving a half-initialised module behind.

    NOTE: names already bound in a caller's namespace via ``from name import *``
    still refer to the old objects; re-run that import to pick up the new ones.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    had_previous = name in sys.modules
    previous = sys.modules.get(name)
    # Register before executing so self/circular imports during exec resolve to this module.
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if had_previous:
            sys.modules[name] = previous
        else:
            sys.modules.pop(name, None)
        raise
    return module


# Re-execute the helper modules from disk and re-register them in sys.modules.
# NOTE(review): this does not rebind names already copied into this namespace by the
# `from ... import *` lines in the first cell; re-run that cell to pick up edits.
for _module_name, _module_path in (
    ("stim_simulation_utils", "stim_simulation_utils.py"),
    ("stim_utils", "stim_utils.py"),
):
    reload_file(_module_name, _module_path)
del _module_name, _module_path


In [3]:
# Attempting to implement the actual encoder and decoders from http://arxiv.org/abs/2409.04628 (Demonstration of quantum computation and error correction with a tesseract code)

# Qubit index sets for the row- and column-wise measurement operators on a 4x4 grid,
# where qubit q sits at row q // 4 and column q % 4.
measurement_operators_rows = [[4 * r + c for c in range(4)] for r in range(4)]

measurement_operators_columns = [[4 * r + c for r in range(4)] for c in range(4)]

def encode_sub_circuit_quad(circuit, ancilla, qubits):
    """Fan ``qubits[0]`` out onto ``qubits[1..3]``, bracketed by two flag CNOTs.

    Appends, in order: CNOT(q0, ancilla), CNOT(q0, q1), CNOT(q0, q2),
    CNOT(q0, q3), CNOT(q0, ancilla). Only the first four entries of ``qubits``
    are used.

    Parameters:
    - circuit: object with an ``append(name, targets)`` method; modified in place.
    - ancilla: index of the flag qubit.
    - qubits: sequence whose first four entries are the qubits to entangle.
    """
    source = qubits[0]
    circuit.append("CNOT", [source, ancilla])  # open the flag
    for position in (1, 2, 3):
        circuit.append("CNOT", [source, qubits[position]])
    circuit.append("CNOT", [source, ancilla])  # close the flag

def add_cnot_gates(circuit, start1, start2, num_gates=4):
    """
    Append ``num_gates`` pairwise CNOTs ``start1 + k -> start2 + k`` for k = 0, 1, ...

    Parameters:
    - circuit: object with an ``append(name, targets)`` method (e.g. a stim
      circuit); modified in place.
    - start1: control qubit of the first CNOT; subsequent controls count up by one.
    - start2: target qubit of the first CNOT; subsequent targets count up by one.
    - num_gates: how many consecutive control/target pairs to connect (default 4).
    """
    controls = range(start1, start1 + num_gates)
    targets = range(start2, start2 + num_gates)
    for control, target in zip(controls, targets):
        circuit.append("CNOT", [control, target])

def encode_manual(circuit):
    """Append the hand-written encoding circuit (after Fig. 9 of arXiv:2409.04628).

    Data qubits 0-15 are read as a 4x4 grid (qubit 4 * row + col). Qubits
    16-19 and 20-22 are used as flag/ancilla qubits.

    Parameters:
    - circuit: object with an ``append(name, targets)`` method (presumably a
      stim circuit); modified in place. Returns None.

    NOTE(review): within this function no flag or ancilla qubit is measured:
    - flags 20-22 (used by ``encode_sub_circuit_quad``) and 16-17 are neither
      measured nor reset here;
    - flags 18-19 are reset with ``R`` without being measured first, which
      discards whatever they recorded;
    - ancilla 19 is prepared with ``H`` and coupled to qubits 0-3, but there is
      no closing ``H``/``M`` on it here, so that stabilizer is not read out.
    Confirm against the paper before relying on this encoder.
    """
    # Here we encode the state |++0000> as can be seen in Fig. 9 of that paper
    # NOTE(review): the state label above has six symbols; confirm it against Fig. 9.
    # initialize qubits: H on row 0 (qubits 0-3) and on the first qubit of rows 1-3 (4, 8, 12)
    circuit.append("H", [0,1,2,3,4,8,12])

    # Fan each of rows 1-3 out from its first qubit, each flagged by its own ancilla (20, 21, 22).
    encode_sub_circuit_quad(circuit, 20, [4, 5, 6, 7])
    encode_sub_circuit_quad(circuit, 21, [8, 9, 10, 11])
    encode_sub_circuit_quad(circuit, 22, [12, 13, 14, 15])

    # CNOT row 0 (qubits 0-3) onto flags 16-19, then onto rows 1, 2, 3, then onto
    # flags 16-19 again; the second copy onto the flags closes the flag check.
    add_cnot_gates(circuit, 0, 16) # working on ancilla qubits as flag qubits
    add_cnot_gates(circuit, 0, 4)
    add_cnot_gates(circuit, 0, 8)
    add_cnot_gates(circuit, 0, 12)
    add_cnot_gates(circuit, 0, 16) # working on ancilla qubits as flag qubits

    circuit.append("R", [18, 19]) # Reset ancillas 18, 19: their flag role is done and they are reused below
    circuit.append("H", [19])
    circuit.append("CNOT", [19, 18]) # cnot to flag qubit
    circuit.append("CNOT", [19, 0]) # couple ancilla 19 to data qubit (stabilizer-measurement step)
    circuit.append("CNOT", [19, 1]) # couple ancilla 19 to data qubit (stabilizer-measurement step)
    circuit.append("CNOT", [19, 2]) # couple ancilla 19 to data qubit (stabilizer-measurement step)
    circuit.append("CNOT", [19, 3]) # couple ancilla 19 to data qubit (stabilizer-measurement step)
    circuit.append("CNOT", [19, 18]) # cnot to flag qubit

def error_correction_round_rows(circuit):
    """Append X- then Z-type stabilizer measurements for every grid row.

    For each qubit list in ``measurement_operators_rows`` (in order), calls
    ``append_stabilizer(circuit, "X", row)`` followed by
    ``append_stabilizer(circuit, "Z", row)``. ``append_stabilizer`` is presumably
    provided by the star-imported helper modules; its gates are not visible here.
    """
    # TODO consider changing this to manual implementation as seen in Fig. 4(d)
    for row_qubits in measurement_operators_rows:
        for pauli in ("X", "Z"):
            append_stabilizer(circuit, pauli, row_qubits)

def error_correction_round_columns(circuit):
    """Append X- then Z-type stabilizer measurements for every grid column.

    For each qubit list in ``measurement_operators_columns`` (in order), calls
    ``append_stabilizer(circuit, "X", column)`` followed by
    ``append_stabilizer(circuit, "Z", column)``. ``append_stabilizer`` is
    presumably provided by the star-imported helper modules.
    """
    # TODO consider changing this to manual implementation as seen in Fig. 4(d)
    for column_qubits in measurement_operators_columns:
        for pauli in ("X", "Z"):
            append_stabilizer(circuit, pauli, column_qubits)


def error_correct_manual(circuit):
    """Append one error-correction round: all row stabilizers, then all column stabilizers.

    TODO: implement the flag-based correction on top of these two rounds. It will
    need per-round ``flagX`` / ``flagZ`` state, starting at -1 (nothing flagged).
    """
    for append_round in (error_correction_round_rows, error_correction_round_columns):
        append_round(circuit)



# Reference code for correcting Z error in column:
def correct_column_Z(flagX: int, measX, frameZ):
    """Apply the flagged Z-correction rule to one set of four X-type outcomes.

    Qubit ``4 * r + c`` is taken to sit at row ``r`` and column ``c``.

    Rule:
    - No row flagged (``flagX == -1``):
      - weight 2 -> "reject";
      - weight 1 or 3 -> flag the position of the single disagreeing outcome;
      - weight 0 or 4 -> no-op.
    - Row ``flagX`` flagged:
      - weight 1 or 3 -> Z on the disagreeing column of the flagged row;
      - weight 2 with pattern 0011 or 1100 -> ZZII on the flagged row;
      - any other weight-2 pattern -> "reject".
      In the non-reject cases the flag is then cleared.

    Parameters:
    - flagX: row index (0-3) flagged by an earlier call, or -1 for none.
    - measX: four 0/1 (or bool) outcomes. Any sequence is accepted, including a
      numpy array.
    - frameZ: integer-indexable Z frame with at least 16 entries (list or numpy
      array). It is updated in place; entries are incremented, not reduced mod 2.

    Returns:
    - "reject", or the tuple ``(new_flag, measX, frameZ)``.
    """
    # Normalise to a list of ints so .index() and the pattern comparison below
    # also work for numpy arrays and tuples.
    outcomes = [int(m) for m in measX]
    weight = sum(outcomes)
    if flagX == -1:  # no row flagged already
        if weight == 2:
            return "reject"
        if weight in (1, 3):
            # The disagreeing outcome is the lone 1 (weight 1) or the lone 0 (weight 3).
            flagX = outcomes.index(1 if weight == 1 else 0)
    else:  # row flagX in (0, 1, 2, 3) flagged
        if weight in (1, 3):
            col = outcomes.index(1 if weight == 1 else 0)
            frameZ[4 * flagX + col] += 1  # Z correction
        elif weight == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # ZZII on the flagged row; element-wise so plain lists work too.
                for qubit in (4 * flagX, 4 * flagX + 1):
                    frameZ[qubit] += 1
            else:
                return "reject"
        flagX = -1
    return flagX, measX, frameZ

# Derived from the reference example above (same rule with the X/Z and/or row/column roles swapped).
# TODO: test the following corrections - I checked them manually and they look correct as far as I understand.
def correct_column_X(flagZ: int, measZ, frameX):
    """Apply the flagged X-correction rule to one set of four Z-type outcomes.

    Qubit ``4 * r + c`` is taken to sit at row ``r`` and column ``c``.

    Rule:
    - No row flagged (``flagZ == -1``):
      - weight 2 -> "reject";
      - weight 1 or 3 -> flag the position of the single disagreeing outcome;
      - weight 0 or 4 -> no-op.
    - Row ``flagZ`` flagged:
      - weight 1 or 3 -> X on the disagreeing column of the flagged row;
      - weight 2 with pattern 0011 or 1100 -> XXII on the flagged row;
      - any other weight-2 pattern -> "reject".
      In the non-reject cases the flag is then cleared.

    Parameters:
    - flagZ: row index (0-3) flagged by an earlier call, or -1 for none.
    - measZ: four 0/1 (or bool) outcomes. Any sequence is accepted, including a
      numpy array.
    - frameX: integer-indexable X frame with at least 16 entries (list or numpy
      array). It is updated in place; entries are incremented, not reduced mod 2.

    Returns:
    - "reject", or the tuple ``(new_flag, measZ, frameX)``.
    """
    # Normalise to a list of ints so .index() and the pattern comparison below
    # also work for numpy arrays and tuples.
    outcomes = [int(m) for m in measZ]
    weight = sum(outcomes)
    if flagZ == -1:  # No row flagged already
        if weight == 2:
            return "reject"
        if weight in (1, 3):
            # The disagreeing outcome is the lone 1 (weight 1) or the lone 0 (weight 3).
            flagZ = outcomes.index(1 if weight == 1 else 0)
    else:  # Row flagZ in (0, 1, 2, 3) flagged
        if weight in (1, 3):
            col = outcomes.index(1 if weight == 1 else 0)
            frameX[4 * flagZ + col] += 1  # Apply X correction
        elif weight == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # XXII on the flagged row; element-wise so plain lists work too.
                for qubit in (4 * flagZ, 4 * flagZ + 1):
                    frameX[qubit] += 1
            else:
                return "reject"
        flagZ = -1
    return flagZ, measZ, frameX

def correct_row_Z(flagX: int, measX, frameZ):
    """Apply the flagged Z-correction rule to one set of four X-type outcomes, column-wise.

    Qubit ``4 * r + c`` is taken to sit at row ``r`` and column ``c``. Here the
    flag names a column and the outcomes index rows.

    Rule:
    - No column flagged (``flagX == -1``):
      - weight 2 -> "reject";
      - weight 1 or 3 -> flag the position of the single disagreeing outcome;
      - weight 0 or 4 -> no-op.
    - Column ``flagX`` flagged:
      - weight 1 or 3 -> Z on the disagreeing row of the flagged column;
      - weight 2 with pattern 0011 or 1100 -> Z on rows 0 and 1 of the flagged column;
      - any other weight-2 pattern -> "reject".
      In the non-reject cases the flag is then cleared.

    Parameters:
    - flagX: column index (0-3) flagged by an earlier call, or -1 for none.
    - measX: four 0/1 (or bool) outcomes. Any sequence is accepted, including a
      numpy array.
    - frameZ: integer-indexable Z frame with at least 16 entries (list or numpy
      array). It is updated in place; entries are incremented, not reduced mod 2.

    Returns:
    - "reject", or the tuple ``(new_flag, measX, frameZ)``.
    """
    # Normalise to a list of ints so .index() and the pattern comparison below
    # also work for numpy arrays and tuples.
    outcomes = [int(m) for m in measX]
    weight = sum(outcomes)
    if flagX == -1:  # No column flagged already
        if weight == 2:
            return "reject"
        if weight in (1, 3):
            # The disagreeing outcome is the lone 1 (weight 1) or the lone 0 (weight 3).
            flagX = outcomes.index(1 if weight == 1 else 0)
    else:  # Column flagX in (0, 1, 2, 3) flagged
        if weight in (1, 3):
            row = outcomes.index(1 if weight == 1 else 0)
            frameZ[4 * row + flagX] += 1  # Apply Z correction
        elif weight == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # ZZII on the flagged column (rows 0 and 1); element-wise so plain lists work too.
                for qubit in (4 * 0 + flagX, 4 * 1 + flagX):
                    frameZ[qubit] += 1
            else:
                return "reject"
        flagX = -1
    return flagX, measX, frameZ

def correct_row_X(flagZ: int, measZ, frameX):
    """Apply the flagged X-correction rule to one set of four Z-type outcomes, column-wise.

    Qubit ``4 * r + c`` is taken to sit at row ``r`` and column ``c``. Here the
    flag names a column and the outcomes index rows.

    Rule:
    - No column flagged (``flagZ == -1``):
      - weight 2 -> "reject";
      - weight 1 or 3 -> flag the position of the single disagreeing outcome;
      - weight 0 or 4 -> no-op.
    - Column ``flagZ`` flagged:
      - weight 1 or 3 -> X on the disagreeing row of the flagged column;
      - weight 2 with pattern 0011 or 1100 -> X on rows 0 and 1 of the flagged column;
      - any other weight-2 pattern -> "reject".
      In the non-reject cases the flag is then cleared.

    Parameters:
    - flagZ: column index (0-3) flagged by an earlier call, or -1 for none.
    - measZ: four 0/1 (or bool) outcomes. Any sequence is accepted, including a
      numpy array.
    - frameX: integer-indexable X frame with at least 16 entries (list or numpy
      array). It is updated in place; entries are incremented, not reduced mod 2.

    Returns:
    - "reject", or the tuple ``(new_flag, measZ, frameX)``.
    """
    # Normalise to a list of ints so .index() and the pattern comparison below
    # also work for numpy arrays and tuples.
    outcomes = [int(m) for m in measZ]
    weight = sum(outcomes)
    if flagZ == -1:  # No column flagged already
        if weight == 2:
            return "reject"
        if weight in (1, 3):
            # The disagreeing outcome is the lone 1 (weight 1) or the lone 0 (weight 3).
            flagZ = outcomes.index(1 if weight == 1 else 0)
    else:  # Column flagZ in (0, 1, 2, 3) flagged
        if weight in (1, 3):
            row = outcomes.index(1 if weight == 1 else 0)
            frameX[4 * row + flagZ] += 1  # Apply X correction
        elif weight == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # XXII on the flagged column (rows 0 and 1); element-wise so plain lists work too.
                for qubit in (4 * 0 + flagZ, 4 * 1 + flagZ):
                    frameX[qubit] += 1
            else:
                return "reject"
        flagZ = -1
    return flagZ, measZ, frameX


## Creating a circuit with the "real" encoder/decoder/error-correction

In [4]:
# Build the circuit: hand-written encoder, a noise channel, then one error-correction round.
NOISE_PROBABILITY = 0.5  # passed to `channel` (from the helper modules)
NOISY_QUBITS = [0]       # NOTE(review): assumes channel(circuit, p, qubits) - confirm the signature

# Initialize the circuit
circuit = init_circuit()

# Encode the logical state
encode_manual(circuit)

# Apply noise channel
channel(circuit, NOISE_PROBABILITY, NOISY_QUBITS)

# Perform error correction
error_correct_manual(circuit)

# Show the final circuit for verification
circuit.diagram()

In [5]:
# Flag qubits simulation - see reference: https://quantumcomputing.stackexchange.com/questions/22281/simulating-flag-qubits-and-conditional-branches-using-stim
# The circuit contains a noise channel, so seed the simulator to make the
# sampled measurement record reproducible across Restart & Run All.
SIMULATOR_SEED = 0
simulator = stim.TableauSimulator(seed=SIMULATOR_SEED)
simulator.do(circuit)

simulator.current_measurement_record()

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True]