In [1]:
import stim
import numpy as np
import random

## Data structure for Node 
This is used for:
- referencing the qubits in the protocol
- representing the measurement tree for each RGS arm

## Implementation Details

Qubits indices:
- Qubits 0 and 1 are located at the end nodes (Alice and Bob, respectively) and anchor the half-RGSs. The goal is to verify the stabilizers of qubits 0 and 1, which we expect to be XZ and ZX.
- Qubits 2 and 3 are used for building the biclique RGS. The RGS is generated sequentially, so these qubits can be reused.



In [None]:
### helper functions

def num_qubits_per_rgs_arm(bvec: list[int]) -> int:
    """Count the qubits in one RGS arm described by a branching vector.

    Layer ``i`` of the arm holds ``bvec[0] * bvec[1] * ... * bvec[i]`` qubits;
    the result is the sum over all layers plus one extra qubit.

    Args:
        bvec: branching parameter of each layer. It is not modified.

    Returns:
        ``1 + sum of the cumulative products of bvec`` as a built-in ``int``
        (``1`` for an empty vector).
    """
    total = 1  # the one qubit counted in addition to the layers
    layer_size = 1
    for branching in bvec:
        layer_size *= branching
        total += layer_size
    # Normalise to a plain int so the declared return type holds even when
    # the entries of ``bvec`` are NumPy integers.
    return int(total)

In [None]:
from enum import Enum

class Node:
    """A tree node that refers to one qubit and records its measurement state.

    Nodes are linked through ``children`` to form a tree. ``parent_index == -1``
    is treated as "this node is the root" by ``add_z_side_effects``.
    """

    def __init__(self, qubit_index=-1, parent_index=-1):
        self.qubit_index = qubit_index  # for debugging
        self.parent_index = parent_index  # for debugging
        self.measurement_result: bool | None = (
            None  # this should be True and False if the qubit has been measured
        )
        # Toggle flag. NOTE(review): the original comment was cut off
        # ("use this to toggle and also to ...") -- confirm the intended use.
        self.eigenvalue: bool = False
        self.children: list[Node] = []
        self.is_lost = False  # this is used to denote whether the qubit is lost in the fiber or not
        self.has_z = False  # this is used to denote whether the qubit has Z side effect from the emission process or not

    def get_post_order(self) -> "list[Node]":
        """Return this subtree's nodes in post-order.

        Every child's subtree is listed (left to right) before the node itself,
        so ``self`` is always the last element.
        """
        ordered: list[Node] = []
        for child in self.children:
            ordered.extend(child.get_post_order())
        ordered.append(self)
        return ordered

    def get_level_traversal(self) -> "list[Node]":
        """Return this subtree's nodes level by level (breadth-first).

        ``self`` comes first, then its children in order, then their children,
        and so on. Each node appears exactly once.
        """
        visited: list[Node] = []
        frontier = [self]
        while frontier:
            visited.extend(frontier)
            frontier = [child for node in frontier for child in node.children]
        return visited

    def add_z_side_effects(self, circuit: "stim.Circuit", z_prob=0.5):
        """Randomly apply Z to the inner, non-root nodes of this subtree.

        Leaves (nodes without children) and nodes whose ``parent_index`` is -1
        are never given a Z. For every other node, with probability ``z_prob``
        a ``"Z"`` instruction on ``qubit_index`` is appended to ``circuit`` and
        ``has_z`` is toggled. Descendants are processed before the node itself.

        Args:
            circuit: object whose ``append(name, target)`` receives the gates.
            z_prob: probability of applying Z to each eligible node.
        """
        if len(self.children) == 0:
            return
        for child in self.children:
            child.add_z_side_effects(circuit, z_prob)
        # The root still recurses into its children above, but is skipped here.
        if self.parent_index == -1:
            return
        if random.random() < z_prob:
            circuit.append("Z", self.qubit_index)
            self.has_z = not self.has_z

    def get_indices_from_level(self, k: int) -> list[int]:
        """Return the qubit indices of all nodes at depth ``k`` below this node.

        Depth 0 is this node itself. An empty list is returned when ``k`` is
        outside the depth range of the subtree.
        """
        # Not that we need it right now?
        level = [self]
        depth = 0
        while level:
            # arrived at the requested level: report its qubit indices
            if depth == k:
                return [node.qubit_index for node in level]
            # otherwise descend one level
            level = [child for node in level for child in node.children]
            depth += 1
        # should change this to an error
        return []  # empty list indicating the level specified is out of range

In [None]:
from enum import Enum

# Functional Enum syntax. The first argument is the enum's class name; it is
# kept identical to the variable it is bound to so that repr() shows "Pauli.X"
# and members can be pickled.
Pauli = Enum('Pauli', ['X', 'Y', 'Z'])

class HalfRGS:
    """Currently this is only used at end nodes.

    Stores the number of arms ``m``, the branching parameters ``bv``, one fresh
    ``Node`` per arm, the anchor qubit index, and one (initially unset)
    measurement basis per arm.
    """

    def __init__(self, m: int, branching_params: list[int], anchor_index: int):
        self.m = m
        self.bv = branching_params
        # One independent Node per arm.
        arm_roots = []
        for _ in range(m):
            arm_roots.append(Node())
        self.arms = arm_roots
        self.anchor = anchor_index
        # No basis has been chosen for any arm yet.
        self.measurement_bases: list[Pauli | None] = [None] * m

    def assign_qubit_indices(self, starting_index: int) -> int:
        """Placeholder: not implemented yet, so a call currently returns ``None``."""

    def initialize_quantum_state(self, t: stim.TableauSimulator, outer_emitter: int, root_ancilla: int):
        """Placeholder: not implemented yet, so a call currently does nothing."""

class RGS:
    """Container for an RGS with ``m`` left arms and ``m`` right arms.

    Each arm starts as a fresh ``Node``; each side also keeps one (initially
    unset) measurement basis per arm. ``bv`` holds the branching parameters.
    """

    def __init__(self, m: int, branching_params: list[int]):
        self.m = m
        self.bv = branching_params
        # Build the two sides separately so no Node is shared between them.
        left_roots = []
        for _ in range(m):
            left_roots.append(Node())
        self.left_arms = left_roots
        right_roots = []
        for _ in range(m):
            right_roots.append(Node())
        self.right_arms = right_roots
        # No basis has been chosen for any arm yet.
        self.measurement_bases_left: list[Pauli | None] = [None] * m
        self.measurement_bases_right: list[Pauli | None] = [None] * m

    def assign_qubit_indices(self, starting_index: int) -> int:
        """Placeholder, not implemented yet (a call currently returns ``None``).

        Intended to return the next unused index.
        """

    def initialize_quantum_state(self, t: stim.TableauSimulator, anchor_left: int, anchor_right: int, outer_emitter: int, root_ancilla: int):
        """Placeholder, not implemented yet (a call currently does nothing).

        If nothing is wrong, the four ints should be 2, 3, 4, 5.
        """


# Intended contract: return True on success and False on failure.
def measurements_at_absa(t: stim.TableauSimulator, m: int, left_halfs: list[Node], right_halfs: list[Node]) -> bool:
    """Measurement of all qubits in the RGS (step 1).

    Placeholder: the body is empty, so a call currently returns ``None``
    rather than the intended success/failure boolean.
    """

def update_tree_with_outer_qubits(left_tree: Node, right_tree: Node):
    """Update in place with the BSM results (step 2).

    Intended behaviour: toggle the first-level results of each tree with the
    root of the other tree. Placeholder: not implemented yet, so a call
    currently does nothing and returns ``None``.
    """

def compute_parity_for_end_nodes(m: int, left_halfs: list[Node], right_halfs: list[Node]):
    """Apply the parity at end nodes (step 3).

    Intended behaviour: the product of the Z parities of the left (right) side
    is sent to the right (left), while the X parity of the left (right) side is
    sent to the left (right). Placeholder: not implemented yet, so a call
    currently does nothing and returns ``None``.
    """

In [None]:
def experiment_setup(
    number_of_hops: int,
    m: int,
    bv: list[int],
    loss_probability: float = 0,
    photon_error_probability: float = 0,
    emitter_error_probability: float = 0,
):
    """Assign the qubit index to all the objects we will create and manage.

    Work in progress: the function only computes the per-arm qubit count and
    names the reserved qubit indices; nothing is built or returned yet
    (the call returns ``None``).
    """
    # N below is the value returned by num_qubits_per_rgs_arm(bv).
    # NOTE(review): the layout notes treat N as the size of a whole half-RGS,
    # but the helper counts the qubits of a single arm -- confirm which is meant.
    qubits_per_arm = num_qubits_per_rgs_arm(bv)

    # Ancilla qubits we require (total 4):
    #   temporary anchor for tree encoding: 1
    #   emitter for outer qubit:            1
    #   anchor for the half-RGS:            2
    # Planned qubit ranges:
    #   Alice:      0, 4 to (N + 3)
    #   Bob:        1, (N + 4) to (2 * N + 3)
    #   first RGS:  (2 * N + 4) to (3 * N + 3)
    #   ...
    # NOTE(review): the indices reserved below already occupy 0..4, which
    # overlaps the "4 to ..." range listed for Alice -- confirm the layout.
    alice_anchor, bob_anchor, anchor_left, anchor_right, outer_emitter = range(5)

In [1]:
def experiment_run():
    """Run one trial of the experiment (outline only; nothing is executed yet).

    Steps to perform one trial:
    0. experiment_setup
    1. RGS creation
    2. ABSA measurements (create measurement trees)
        a. BSM
        b. Single qubit measurements
    3. Protocol (according to Sec 6 or 7?)
        a. measurement tree updated with the side effect (this is actually done prior at RGS creation)
        b. update from BSM (the pair tree)
        c. parity multiplied together of Z of left (right) is sent to right (left),
            while X of left (right) is sent to left (right)
    """
    # Planned order: setup -> RGS creation -> ABSA measurements
    