In [2]:
import stim
import numpy as np
import random

## Data structure for Node 
This is used for:
- referencing the qubits in the protocol
- the measurement tree of each RGS arm

## Implementation Details

Qubit indices:
- Qubits 0 and 1 are the end-node qubits of Alice and Bob, respectively, and anchor their half-RGSs. The goal is to verify the stabilizers of qubits 0 and 1, expecting to see XZ and ZX.
- Qubits 2 and 3 are used to build the biclique RGS. Because the RGSs are generated sequentially, these qubits can be reused.
- Qubits 4 and 5 are the outer emitter and the root ancilla, shared by every arm; photonic qubits start at index 6.



In [3]:
### helper functions

def num_qubits_per_rgs_arm(bvec: list[int]) -> int:
    """Return the number of qubits in one RGS arm.

    An arm is a root (outer) qubit plus a tree whose level i (0-based) contains
    bvec[0] * bvec[1] * ... * bvec[i] nodes, so the total is
    1 + sum over i of prod(bvec[:i + 1]).

    Args:
        bvec: branching vector; bvec[i] is the number of children each node
            at level i-1 (or the root, for i == 0) has. Not modified.

    Returns:
        The qubit count as a plain Python ``int`` (1 for an empty ``bvec``).
    """
    total = 1  # the root / outer qubit
    layer_size = 1
    for branching in bvec:
        # number of nodes on the next level = previous level size * branching factor
        layer_size *= branching
        total += layer_size
    return total

In [4]:
from enum import Enum

class Node:
    """A qubit in an RGS arm tree.

    Nodes are used both to reference qubit indices in the protocol and as the
    measurement tree of each RGS arm; ``children`` holds the next tree level.
    """

    def __init__(self, qubit_index=-1, parent_index=-1):
        self.qubit_index = qubit_index  # simulator qubit index (-1 = not yet assigned)
        self.parent_index = parent_index  # qubit index of the parent (-1 = this node is a root)
        # None until the qubit has been measured, then the measurement outcome
        self.measurement_result: bool | None = None
        # outcome bit; measurements_at_absa sets it to the measurement result.
        # NOTE(review): the original comment was truncated ("use this to toggle
        # and also to ...") — confirm the intended use.
        self.eigenvalue: bool = False
        self.children: list[Node] = []
        self.is_lost = False  # True if the qubit was lost in the fiber
        self.has_z = False  # True if a Z side effect is pending on this qubit

    def get_post_order(self) -> "list[Node]":
        """Return this subtree's nodes in post-order.

        Each child's subtree comes first (children in order), followed by this node.
        """
        ordered: list[Node] = []
        for child in self.children:
            ordered.extend(child.get_post_order())
        ordered.append(self)
        return ordered

    def get_level_traversal(self) -> "list[Node]":
        """Return this subtree's nodes in breadth-first (level) order.

        Starts with ``self``; every node appears exactly once.
        """
        ordered: list[Node] = []
        level = [self]
        while level:
            ordered.extend(level)
            level = [child for node in level for child in node.children]
        return ordered

    def add_z_side_effects(self, circuit: "stim.Circuit", z_prob=0.5):
        """Randomly append Z gates to the internal nodes of this subtree.

        Recurses into the children first. Then, if this node has children and a
        parent (``parent_index != -1``), it appends ``Z`` on its qubit to ``circuit``
        with probability ``z_prob`` and toggles ``has_z``. Leaves and roots are
        never touched.
        """
        if not self.children:
            return
        for child in self.children:
            child.add_z_side_effects(circuit, z_prob)
        if self.parent_index == -1:
            return
        if random.random() < z_prob:
            circuit.append("Z", self.qubit_index)
            self.has_z = not self.has_z

    def get_indices_from_level(self, k: int) -> list[int]:
        """Return the qubit indices of all nodes at depth ``k`` (``self`` is depth 0).

        Returns an empty list if ``k`` is negative or deeper than the tree.
        """
        if k < 0:
            return []
        level = [self]
        for _ in range(k):
            level = [child for node in level for child in node.children]
            if not level:
                return []
        return [node.qubit_index for node in level]

In [10]:
from enum import Enum

# Single-qubit Pauli bases, built with Enum's functional syntax. The class name
# must match the binding name so reprs read "Pauli.X" and pickling can find it.
Pauli = Enum('Pauli', ['I', 'X', 'Y', 'Z'])

def helper_assign_qubit_indices(root: Node, bv: list[int], starting_index: int) -> int:
    """Grow the tree under ``root`` level by level and number its qubits consecutively.

    ``root`` receives ``starting_index``. Each node on level i-1 then gets ``bv[i-1]``
    new children, numbered breadth-first, and each child records its parent's
    qubit index as ``parent_index``.

    New children are appended to the existing ``children`` lists, so calling this
    twice on the same root duplicates the tree.

    Returns:
        The next unused qubit index.
    """
    next_index = starting_index
    root.qubit_index = next_index
    next_index += 1
    level = [root]
    for branching in bv:
        next_level = []
        for parent in level:
            for _ in range(branching):
                child = Node(next_index, parent.qubit_index)
                next_index += 1
                parent.children.append(child)
                next_level.append(child)
        level = next_level
    return next_index

def helper_initialize_rgs_arm(t: stim.TableauSimulator, root: Node, anchor: int, outer_emitter: int, root_ancilla: int) -> bool:
    """Prepare one RGS arm (outer qubit ``root`` plus its tree) and attach it to ``anchor``.

    Operations applied to the simulator ``t``, in order:
      1. reset ``outer_emitter`` and ``root_ancilla``;
      2. H on every first-level node; H on the root qubit and ``outer_emitter``,
         then CZ between them;
      3. H on ``root_ancilla``, then CZ from it to every first-level node;
      4. level by level, H on each deeper node and CZ with its parent;
      5. for tree nodes below the root that have children, apply Z with
         probability 0.5 and toggle their ``has_z``;
      6. CZ(anchor, outer_emitter) and CZ(root_ancilla, outer_emitter), then H on
         both and measure ``outer_emitter`` followed by ``root_ancilla``.

    Measurement byproducts are recorded in ``has_z``: outcome 1 on ``outer_emitter``
    toggles the first-level nodes, and outcome 1 on ``root_ancilla`` toggles ``root``.

    Assumes the caller has already applied a Hadamard to ``anchor``. Uses the
    module-level ``random`` generator, so results depend on its state.

    Returns:
        The ``root_ancilla`` measurement outcome, i.e. whether the caller should
        account for a Z side effect on the anchor.
    """
    # return whether the anchor should be flipped or not (side effects to the anchor)

    # make sure the qubits are properly initialized
    t.reset(outer_emitter, root_ancilla)

    # joining root with the first level nodes in the tree
    queue = root.children # nodes in the first level
    for u in queue:
        t.h(u.qubit_index)
    
    # outer qubit generation: entangle the root (outer) qubit with the outer emitter
    t.h(root.qubit_index, outer_emitter)
    t.cz(root.qubit_index, outer_emitter)
    
    # generate inner qubit tree: connect root_ancilla to every first-level node
    t.h(root_ancilla)
    for u in queue:
        t.cz(root_ancilla, u.qubit_index)
    # assumes the caller has already applied a Hadamard to the anchor
    # grow the deeper levels: H on each child, then CZ it with its parent
    while len(queue) > 0:
        temp_queue = []
        for u in queue:
            for v in u.children:
                t.h(v.qubit_index)
                t.cz(u.qubit_index, v.qubit_index)
                temp_queue.append(v)
        queue = temp_queue
    
    # add random side effects to nodes in the tree except the leaves.
    # The break stops at the first leaf in a level; that is exact for trees built
    # by helper_assign_qubit_indices, where all nodes of a level have the same arity.
    queue = root.children
    while len(queue) > 0:
        temp_queue = []
        for u in queue:
            if len(u.children) == 0:
                break
            if random.random() < 0.5:
                t.z(u.qubit_index)
                u.has_z = not u.has_z
            temp_queue.extend(u.children)
        queue = temp_queue

    # join inner and outer qubits through outer_emitter, then measure both
    # ancillas in the X basis (H followed by a Z-basis measurement)
    t.cz(anchor, outer_emitter)
    t.cz(root_ancilla, outer_emitter)
    t.h(outer_emitter, root_ancilla)
    meas_outer = t.measure(outer_emitter)
    meas_root = t.measure(root_ancilla)
    if meas_outer:
        # flip first level qubits
        for u in root.children:
            u.has_z = not u.has_z

    if meas_root:
        # flip outer qubits (and return the flip to the anchor)
        root.has_z = not root.has_z

    return meas_root


class HalfRGS:
    """Half of a repeater graph state: ``m`` arms attached to a single anchor qubit.

    Currently this is only used at end nodes, where the anchor is Alice's or
    Bob's qubit.
    """

    def __init__(self, m: int, branching_params: list[int], anchor_index: int):
        self.m = m  # number of arms
        self.bv = branching_params  # branching vector shared by every arm tree
        self.arms = [Node() for _ in range(m)]  # root node of each arm
        self.anchor = anchor_index  # qubit index of the anchor
        # per-arm measurement basis; None until chosen
        self.measurement_bases: list[Pauli | None] = [None for _ in range(m)]

    def assign_qubit_indices(self, starting_index: int) -> int:
        """Number every arm's qubits consecutively, starting at ``starting_index``.

        Returns:
            The next unused qubit index.
        """
        cur_index = starting_index
        for root in self.arms:
            cur_index = helper_assign_qubit_indices(root, self.bv, cur_index)
        return cur_index

    def initialize_quantum_state(self, t: stim.TableauSimulator, outer_emitter: int, root_ancilla: int):
        """Prepare the anchor in |+>, grow every arm around it, and apply the net Z correction.

        ``outer_emitter`` and ``root_ancilla`` are scratch qubits; they are reset
        and measured for each arm.
        """
        # Reset the anchor (as RGS.initialize_quantum_state does for its anchors)
        # so the Hadamard below prepares |+> even if the qubit was used before.
        t.reset(self.anchor)
        t.h(self.anchor)
        anchor_has_z = False
        for root in self.arms:
            # each arm reports whether it left a Z side effect on the anchor;
            # two such side effects cancel, so the flags are XOR-ed
            anchor_has_z ^= helper_initialize_rgs_arm(t, root, self.anchor, outer_emitter, root_ancilla)
        if anchor_has_z:
            t.z(self.anchor)

class RGS:
    """Biclique repeater graph state with ``m`` left arms and ``m`` right arms.

    Each half is grown around its own temporary anchor qubit; the two anchors are
    then measured to fuse the halves (see ``initialize_quantum_state``).
    """

    def __init__(self, m: int, branching_params: list[int]):
        self.m = m  # number of arms per half
        self.bv = branching_params  # branching vector shared by every arm tree
        self.left_arms = [Node() for _ in range(m)]  # root node of each left arm
        self.right_arms = [Node() for _ in range(m)]  # root node of each right arm
        # per-arm measurement bases; None until chosen
        self.measurement_bases_left: list[Pauli | None] = [None for _ in range(m)]
        self.measurement_bases_right: list[Pauli | None] = [None for _ in range(m)]

    def assign_qubit_indices(self, starting_index: int) -> int:
        """Number every arm's qubits consecutively: left arms first, then right arms.

        Returns:
            The next unused qubit index.
        """
        cur_index = starting_index
        for root in self.left_arms + self.right_arms:
            cur_index = helper_assign_qubit_indices(root, self.bv, cur_index)
        return cur_index

    def _initialize_half(self, t: "stim.TableauSimulator", arms: "list[Node]", anchor: int, outer_emitter: int, root_ancilla: int) -> None:
        """Grow every arm in ``arms`` around ``anchor``, then apply the net Z correction to it."""
        t.h(anchor)
        anchor_has_z = False
        for root in arms:
            # each arm reports whether it left a Z side effect on the anchor;
            # two such side effects cancel, so the flags are XOR-ed
            anchor_has_z ^= helper_initialize_rgs_arm(t, root, anchor, outer_emitter, root_ancilla)
        if anchor_has_z:
            t.z(anchor)

    def initialize_quantum_state(self, t: stim.TableauSimulator, anchor_left: int, anchor_right: int, outer_emitter: int, root_ancilla: int):
        """Build both halves, then fuse them with a Bell measurement on the two anchors.

        The left arms are grown around ``anchor_left`` and the right arms around
        ``anchor_right``. The anchors then get CZ, H on both, and a measurement.
        Outcome 1 on the left anchor toggles ``has_z`` on the first-level nodes of
        every right arm, and outcome 1 on the right anchor does the same for the
        left arms.
        """
        # make sure the anchors are properly initialized
        t.reset(anchor_left, anchor_right)

        self._initialize_half(t, self.left_arms, anchor_left, outer_emitter, root_ancilla)
        self._initialize_half(t, self.right_arms, anchor_right, outer_emitter, root_ancilla)

        # join the two halves
        t.cz(anchor_left, anchor_right)
        t.h(anchor_left, anchor_right)
        meas_left = t.measure(anchor_left)
        meas_right = t.measure(anchor_right)

        # tracking the side effects (toggling first level nodes of all arms)
        if meas_left:
            for root in self.right_arms:
                for u in root.children:
                    u.has_z = not u.has_z
        if meas_right:
            for root in self.left_arms:
                for u in root.children:
                    u.has_z = not u.has_z


def process_photon_loss(t: "stim.TableauSimulator", root: "Node", loss_probability: float):
    """Apply independent photon loss to every qubit in the tree rooted at ``root``.

    Nodes are visited breadth-first, starting with ``root``. Each qubit is lost with
    probability ``loss_probability``. A lost qubit:
      * is marked ``is_lost = True``;
      * gets a Pauli drawn uniformly from {I, X, Y, Z} (so with probability 1/4
        nothing is applied);
      * is then measured in the Z basis, and the outcome is discarded.

    Randomness comes from the module-level ``random`` generator.
    """
    level = [root]
    while level:
        next_level = []
        for node in level:
            next_level.extend(node.children)
            if random.random() < loss_probability:
                q = node.qubit_index
                node.is_lost = True
                pauli_op = random.choice(["I", "X", "Y", "Z"])
                if pauli_op == "X":
                    t.x(q)
                elif pauli_op == "Y":
                    t.y(q)
                elif pauli_op == "Z":
                    t.z(q)
                # "I": identity, nothing to apply
                t.measure(q)
        level = next_level


def measurements_at_absa(t: "stim.TableauSimulator", m: int, left_halfs: "list[Node]", right_halfs: "list[Node]") -> int:
    """Measure every arm pair at the ABSA with a Bell-state measurement (step 1).

    Arm ``i`` pairs ``left_halfs[i]`` with ``right_halfs[i]``. Pairs where either
    qubit was lost are skipped. For every other pair, the two qubits get CZ and H
    and are then measured. Each outcome is stored in the node's
    ``measurement_result`` and ``eigenvalue``. All non-lost arms are measured, even
    after a successful arm has been found.

    Returns:
        The index of the first measured arm whose two outcomes differ, or -1 if
        there is none. NOTE(review): "outcomes differ" is the success criterion as
        written here — confirm it is the intended BSM success condition.

    Raises:
        ValueError: if ``m`` differs from the length of either list.
    """
    if m != len(left_halfs) or m != len(right_halfs):
        raise ValueError(f'number of arms {m} does not equal the input length of list of two halves {len(left_halfs)}, {len(right_halfs)}')

    success_arm_index = -1
    for i in range(m):
        unode = left_halfs[i]
        vnode = right_halfs[i]
        # a pair with a lost photon cannot be measured
        if unode.is_lost or vnode.is_lost:
            continue

        u = unode.qubit_index
        v = vnode.qubit_index
        t.cz(u, v)
        t.h(u, v)
        unode.measurement_result = unode.eigenvalue = t.measure(u)
        vnode.measurement_result = vnode.eigenvalue = t.measure(v)

        # record only the first successful arm (the old `!= -1` guard never fired)
        if success_arm_index == -1 and unode.measurement_result != vnode.measurement_result:
            success_arm_index = i

    return success_arm_index


def update_tree_with_outer_qubits(left_tree: Node, right_tree: Node):
    """Step 2 (not implemented yet): fold the BSM outcomes into the two trees in place.

    The intended update toggles each tree's first-level results using the root
    (outer qubit) of the other tree.
    """
    # TODO: implement; this is currently a no-op.
    return None

def compute_parity_for_end_nodes(m: int, left_halfs: list[Node], right_halfs: list[Node]):
    """Step 3 (not implemented yet): apply the parity corrections at the end nodes.

    Plan: the product of the Z parities from the left (right) side is sent to the
    right (left) end node, while the X parity of the left (right) side is sent to
    the left (right) end node.
    """
    # TODO: implement; this is currently a no-op.
    return None

In [None]:
def experiment_setup(
    number_of_hops: int,
    m: int,
    branching_parameters: list[int],
    loss_probability: float = 0,
    # photon_error_probability: float = 0,
    # emitter_error_probability: float = 0,
):
    """Create the RGS objects for one repeater chain and assign every qubit index.

    Qubit layout:
        0        Alice's end-node qubit (anchor of her half-RGS)
        1        Bob's end-node qubit (anchor of his half-RGS)
        2, 3     left / right anchors used while building a biclique RGS
        4        outer emitter, reused for every arm
        5        root ancilla, reused for every arm
        6, ...   photonic tree qubits: each intermediate RGS in order, then
                 Alice's half-RGS, then Bob's half-RGS

    Args:
        number_of_hops: number of hops in the chain; (number_of_hops - 1)
            intermediate RGSs are created.
        m: number of arms per (half-)RGS.
        branching_parameters: branching vector for every arm tree.
        loss_probability: not used here yet. NOTE(review): presumably meant to be
            applied later via process_photon_loss — confirm.

    Returns:
        ``(rgss, half_alice, half_bob, num_qubits)``, where ``num_qubits`` is one
        past the largest assigned qubit index.
    """
    alice = 0
    bob = 1
    # indices 2-5 are reserved for the RGS anchors, outer emitter and root ancilla
    next_id = 6  # first index available for photonic qubits

    # we need (hop - 1) RGS
    rgss = [RGS(m, branching_parameters) for _ in range(number_of_hops - 1)]
    half_alice = HalfRGS(m, branching_parameters, alice)
    half_bob = HalfRGS(m, branching_parameters, bob)

    for rgs in rgss:
        next_id = rgs.assign_qubit_indices(next_id)
    next_id = half_alice.assign_qubit_indices(next_id)
    next_id = half_bob.assign_qubit_indices(next_id)

    return rgss, half_alice, half_bob, next_id


In [1]:
def experiment_run():
    """Run a single trial of the protocol (placeholder; not implemented yet).

    Planned steps:
    0. experiment_setup
    1. RGS creation
    2. Process photon loss
    3. ABSA measurements (create measurement trees)
        a. BSM
        b. Single-qubit measurements
    4. Protocol (according to Sec 6 or 7?)
        a. Update the measurement tree with the side effects (this is actually
           done earlier, at RGS creation)
        b. Update from the BSM (the paired tree)
        c. The product of the Z parities of the left (right) side is sent to the
           right (left), while the X parity of the left (right) side is sent to
           the left (right)
    """
    # TODO: implement the steps listed above.
    return None
    