In [1]:
import stim
import numpy as np
from enum import Enum

In [2]:
from helper import Node, Pauli
from tree_code_helper import tree_code_physical_measure
from rgs import HalfRGS, RGS
from rgs_probability_calculation import prob_rgs_trial

## Data structure for `Node`
`Node` (imported from `helper`) is used for:
- referencing the qubits in the protocol
- building the measurement tree for each RGS arm

## Implementation Details

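Qubit indices:
- Qubits 0 and 1 sit at the end nodes (Alice and Bob, respectively) and anchor the two half-RGSs. The goal is to verify that the joint state of qubits 0 and 1 has the stabilizers XZ and ZX.
- Qubits 2 and 3 are the anchors used for building each biclique RGS. The RGSs are generated sequentially, so these qubits can be reused.
- Qubits 4 and 5 are further reusable ancillas (the emitter for the outer qubits and the temporary root used for tree encoding). The remaining qubits are numbered from 6 onward.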



In [3]:
def measurements_at_absa(
    t: stim.TableauSimulator, m: int, left_halfs: list[Node], right_halfs: list[Node]
) -> int:
    """Measure all qubits of the arms meeting at one ABSA (protocol step 1).

    For every arm index ``i`` whose two outer qubits (``left_halfs[i]`` and
    ``right_halfs[i]``) were both kept (not lost), a BSM is performed:
    - CZ is applied between the two qubits, then H on both.
    - Each qubit is measured with ``t.measure``.
    - Each outcome is stored on its node as ``measurement_result`` and ``eigenvalue``.
    - The node's ``measurement_basis`` is set to ``Pauli.X``.

    The first arm whose two outcomes differ is taken as the successful BSM. The inner
    qubits of that arm are measured with ``tree_code_physical_measure(..., True)``. The
    inner qubits of every other non-lost arm are measured with
    ``tree_code_physical_measure(..., False)``. Arms with a lost outer qubit are skipped
    entirely.

    Args:
        t: simulator holding the joint state of all sources.
        m: number of arms; must equal ``len(left_halfs)`` and ``len(right_halfs)``.
        left_halfs: outer-qubit nodes of the left source, one per arm.
        right_halfs: outer-qubit nodes of the right source, one per arm.

    Returns:
        Index of the first arm with a successful BSM, or -1 if no arm succeeded.

    Raises:
        ValueError: if ``m`` does not match the lengths of the two lists.
    """
    if m != len(left_halfs) or m != len(right_halfs):
        raise ValueError(f'number of arms {m} does not equal the input length of list of two halves {len(left_halfs)}, {len(right_halfs)}')

    success_arm_index = -1
    for i in range(m):
        unode = left_halfs[i]
        vnode = right_halfs[i]
        # A BSM needs both photons; an arm with a lost outer qubit is skipped.
        if unode.is_lost or vnode.is_lost:
            continue

        u = unode.qubit_index
        v = vnode.qubit_index
        # BSM circuit: CZ, then H on both qubits, then measure each qubit.
        t.cz(u, v)
        t.h(u, v)
        unode.measurement_result = unode.eigenvalue = t.measure(u)
        vnode.measurement_result = vnode.eigenvalue = t.measure(v)
        unode.measurement_basis = vnode.measurement_basis = Pauli.X

        # Only the first successful arm is kept as "the" successful BSM. Its inner
        # qubits get the `True` measurement pattern (presumably logical X -- confirm
        # in tree_code_helper); all other arms get `False`.
        if success_arm_index == -1 and unode.measurement_result != vnode.measurement_result:
            success_arm_index = i
            tree_code_physical_measure(t, unode, True)
            tree_code_physical_measure(t, vnode, True)
        else:
            tree_code_physical_measure(t, unode, False)
            tree_code_physical_measure(t, vnode, False)

    return success_arm_index

def update_tree_with_outer_qubits(left_tree_root: Node, right_tree_root: Node):
    """Propagate the BSM outcomes of two paired outer qubits into the opposite trees (protocol step 2).

    Eigenvalues are updated in place:
    - If ``left_tree_root.eigenvalue`` is truthy, the eigenvalue of every non-lost
      first-level child of ``right_tree_root`` is flipped.
    - If ``right_tree_root.eigenvalue`` is truthy, the same is done for the children
      of ``left_tree_root``.

    Args:
        left_tree_root: outer-qubit node on the left side of the successful BSM.
        right_tree_root: outer-qubit node on the right side of the successful BSM.

    Raises:
        RuntimeError: if either root was lost. A lost outer qubit never takes part in
            a BSM, so it has no outcome to propagate.
    """
    if left_tree_root.is_lost or right_tree_root.is_lost:
        raise RuntimeError("trying to update tree with outer qubits that were lost!")

    # The left outcome acts on the right tree's first level, and vice versa.
    if left_tree_root.eigenvalue:
        for u in right_tree_root.children:
            if u.is_lost:
                continue
            u.eigenvalue = not u.eigenvalue
    if right_tree_root.eigenvalue:
        for u in left_tree_root.children:
            if u.is_lost:
                continue
            u.eigenvalue = not u.eigenvalue

def compute_parity_for_end_nodes(m: int, left_logical_results: list[bool], right_logical_results: list[bool], successful_bsm_index: int) -> tuple[bool, bool]:
    """Combine the logical measurement results at one ABSA into two parities (protocol step 3).

    Every arm except the successful one contributes its own side's result to its own
    side's parity. The successful arm contributes crosswise: its right result goes into
    the left parity and its left result into the right parity.

    Returns:
        ``(left_parity, right_parity)``: the parities sent to the left and right, respectively.
    """
    unsuccessful_arms = [arm for arm in range(m) if arm != successful_bsm_index]

    left_parity = right_parity = False
    for arm in unsuccessful_arms:
        left_parity ^= left_logical_results[arm]
        right_parity ^= right_logical_results[arm]

    # Cross contribution of the arm that carried the successful BSM.
    left_parity ^= right_logical_results[successful_bsm_index]
    right_parity ^= left_logical_results[successful_bsm_index]

    return left_parity, right_parity

In [4]:
# Module-level photon bookkeeping; experiment_setup adds to both counters on every call.
lost_photons = total_photons = 0

# Shared generator handed to process_photon_loss (unseeded, so runs are not reproducible).
rng = np.random.default_rng()

def experiment_setup(
    number_of_hops: int,
    m: int,
    branching_parameters: list[int],
    loss_probability: float = 0,
    # photon_error_probability: float = 0,
    # emitter_error_probability: float = 0,
) -> tuple:
    """Run one trial of the biclique RGS protocol between Alice and Bob.

    Args:
        number_of_hops: number of ABSAs in the chain. ``number_of_hops - 1`` full RGSs
            are created between Alice's and Bob's half-RGSs. Must be >= 1.
        m: number of arms per (half-)RGS.
        branching_parameters: tree-code branching vector passed to ``RGS`` / ``HalfRGS``.
        loss_probability: photon-loss probability passed to ``process_photon_loss``.

    Returns:
        A 5-tuple ``(is_successful, exp_xz, exp_zx, canonical_stabilizers, total_parity)``.
        - On failure, every element after the first is ``None``. A trial fails when some
          ABSA has no successful BSM or when logical decoding fails.
        - ``exp_xz`` / ``exp_zx`` are ``t.peek_observable_expectation`` of ``XZ`` / ``ZX``
          on qubits (0, 1), i.e. (Alice, Bob).
        - ``canonical_stabilizers`` is ``t.canonical_stabilizers()``.
        - ``total_parity`` is the (Alice, Bob) pair of flags that decided whether a Z
          correction was applied.

    Raises:
        ValueError: if ``number_of_hops < 1``.

    Side effects:
        Adds to the module-level ``lost_photons`` / ``total_photons`` counters and
        passes the module-level ``rng`` to ``process_photon_loss``.
    """
    global total_photons, lost_photons
    if number_of_hops < 1:
        raise ValueError(f'number_of_hops must be at least 1, got {number_of_hops}')

    # Qubit layout. The sources are generated sequentially, so the ancillas 2-5 are
    # passed to (and reused by) every source:
    #   0 / 1 : Alice's / Bob's end-node qubits (given to the two half-RGSs)
    #   2 / 3 : left / right anchors used when building each full RGS
    #   4     : emitter for the outer qubits
    #   5     : temporary root used while encoding the trees
    #   6..   : remaining qubits, numbered by assign_qubit_indices
    alice = 0
    bob = 1
    anchor_left = 2
    anchor_right = 3
    outer_emitter = 4
    root_id = 5
    next_id = 6

    # number_of_hops - 1 full RGSs sit between the two half-RGSs.
    rgss = [RGS(m, branching_parameters) for _ in range(number_of_hops - 1)]
    half_alice = HalfRGS(m, branching_parameters, alice)
    half_bob = HalfRGS(m, branching_parameters, bob)
    # Every source, in the fixed processing order: full RGSs, then Alice, then Bob.
    sources = [*rgss, half_alice, half_bob]

    for source in sources:
        next_id = source.assign_qubit_indices(next_id)

    # RGS creation (full and half RGSs take different ancilla arguments).
    t = stim.TableauSimulator()
    for rgs in rgss:
        rgs.initialize_quantum_state(t, anchor_left, anchor_right, outer_emitter, root_id)
    half_alice.initialize_quantum_state(t, outer_emitter, root_id)
    half_bob.initialize_quantum_state(t, outer_emitter, root_id)

    # Photon loss, then accounting of how many photons were lost overall.
    for source in sources:
        source.process_photon_loss(t, loss_probability, rng)
    for source in sources:
        lost_ph, total_ph = source.count_lost_photons()
        lost_photons += lost_ph
        total_photons += total_ph

    # (Protocol step 1) ABSA measurements. ABSA k joins the arms of the k-th and
    # (k+1)-th source in the chain Alice, rgss[0], ..., rgss[-1], Bob.
    success_bsm_indices = [-1] * number_of_hops
    for i in range(len(rgss) - 1):
        success_bsm_indices[i + 1] = measurements_at_absa(t, m, rgss[i].right_arms, rgss[i + 1].left_arms)
        rgss[i].successful_right_arm_index = rgss[i + 1].successful_left_arm_index = success_bsm_indices[i + 1]
    if len(rgss) > 0:
        success_bsm_indices[0] = measurements_at_absa(t, m, half_alice.arms, rgss[0].left_arms)
        success_bsm_indices[-1] = measurements_at_absa(t, m, rgss[-1].right_arms, half_bob.arms)
        rgss[0].successful_left_arm_index = half_alice.successful_arm_index = success_bsm_indices[0]
        rgss[-1].successful_right_arm_index = half_bob.successful_arm_index = success_bsm_indices[-1]
    else:
        # Single hop: the two half-RGSs meet directly at one ABSA.
        success_bsm_indices[0] = measurements_at_absa(t, m, half_alice.arms, half_bob.arms)
        half_alice.successful_arm_index = half_bob.successful_arm_index = success_bsm_indices[0]

    # Any ABSA without a successful BSM breaks the chain.
    if -1 in success_bsm_indices:
        return False, None, None, None, None

    # (Protocol step 1) Update the measurement trees by assigning eigenvalues to the
    # nodes, taking side effects into account. This imitates the classical messages
    # sent from the RGSSs to the ABSAs. The successful BSM arm index marks the arm whose
    # inner qubits underwent logical X measurements.
    half_alice.update_measurement_with_side_effects()
    half_bob.update_measurement_with_side_effects()
    for rgs in rgss:
        rgs.update_measurements_with_side_effect()

    # (Protocol step 2) Propagate side effects of the BSMs of outer qubits into their
    # connected inner qubits. The list holds consecutive (left, right) pairs, one pair per ABSA.
    # TODO: replace this with get arm methods
    bsm_arm_pairs = [half_alice.arms[success_bsm_indices[0]]]
    for i in range(len(rgss)):
        bsm_arm_pairs.append(rgss[i].left_arms[success_bsm_indices[i]])
        bsm_arm_pairs.append(rgss[i].right_arms[success_bsm_indices[i + 1]])
    bsm_arm_pairs.append(half_bob.arms[success_bsm_indices[-1]])
    for i in range(0, len(bsm_arm_pairs), 2):
        update_tree_with_outer_qubits(bsm_arm_pairs[i], bsm_arm_pairs[i + 1])

    # (Protocol step 2/3?) Decode the logical measurements. `&=` (not `and`) is used
    # deliberately so every source is decoded even after a failure.
    is_trial_successful = half_alice.decode_logical_results()
    is_trial_successful &= half_bob.decode_logical_results()
    for rgs in rgss:
        is_trial_successful &= rgs.decode_logical_results()
    if not is_trial_successful:
        return False, None, None, None, None

    # (Protocol step 3) Compute the parity at each ABSA for Pauli-frame corrections.
    parities: list[tuple[bool, bool]] = []
    if number_of_hops == 1:
        parities.append(compute_parity_for_end_nodes(m, half_alice.logical_results, half_bob.logical_results, success_bsm_indices[0]))
    else:
        parities.append(compute_parity_for_end_nodes(m, half_alice.logical_results, rgss[0].left_logical_results, success_bsm_indices[0]))
        for i in range(len(rgss) - 1):
            parities.append(compute_parity_for_end_nodes(m, rgss[i].right_logical_results, rgss[i + 1].left_logical_results, success_bsm_indices[i + 1]))
        parities.append(compute_parity_for_end_nodes(m, rgss[-1].right_logical_results, half_bob.logical_results, success_bsm_indices[-1]))

    # (Protocol step 4) XOR the parities of all ABSAs and correct at the end nodes.
    total_parity = (False, False)
    for l, r in parities:
        total_parity = total_parity[0] ^ l, total_parity[1] ^ r
    if total_parity[0]:
        t.z(alice)
    if total_parity[1]:
        t.z(bob)

    # Verify the (Alice, Bob) state via the XZ / ZX stabilizers.
    return True, t.peek_observable_expectation(stim.PauliString("XZ")), t.peek_observable_expectation(stim.PauliString("ZX")), t.canonical_stabilizers(), total_parity

In [62]:
# Trial configuration and progress reporting.
required_runs = 10
num_ticks = 5  # number of progress lines printed over the run
progress_marks = [int(i * required_runs / num_ticks) for i in range(1, num_ticks)]
progress_marks.append(required_runs)

# Outcome counters. Only heralded-successful trials are classified:
#   good_count: heralded success and both <XZ> and <ZX> equal 1 on (Alice, Bob)
#   fail_count: heralded success but at least one of those expectations is not 1
fail_count = 0
good_count = 0
actual_run_count = 0

# experiment_setup adds to these module-level counters; reset them for this batch.
total_photons = 0
lost_photons = 0

# at loss_probability = 0.1, this means that the distance between RGSS and the ABSA is ~2.28787km

number_of_hops = 2
m = 14
bv = [10, 5]  # other branching vectors tried: [2, 3, 5], [3, 3]
photon_loss_prob = 0  # also tried: 0.05

print(f"theoretical success probability is {prob_rgs_trial(m, bv, (1-photon_loss_prob), number_of_hops)}")
print('----------------------')

for _ in range(required_runs):
    is_successful, exp_xz, exp_zx, cano_stabs, parity = experiment_setup(number_of_hops, m, bv, photon_loss_prob)
    if is_successful:
        if exp_xz == 1 and exp_zx == 1:
            good_count += 1
        else:
            fail_count += 1
    actual_run_count += 1
    if actual_run_count in progress_marks:
        print(f'    has been running for {actual_run_count} trials with {good_count} successful distribution.')

# Guard the ratios so an empty run reports nan instead of raising ZeroDivisionError.
loss_rate = lost_photons / total_photons if total_photons else float('nan')
success_probability = (fail_count + good_count) / actual_run_count if actual_run_count else float('nan')
print(f'photon accounting loss_rate = {loss_rate} ({lost_photons}, {total_photons})')
print(f'fail count = {fail_count}, good count = {good_count}')
print(f'success probability = {success_probability}')
print(f'ran {actual_run_count} times')

theoretical success probability is 0.9998779334127903
----------------------
    has been running for 2 trials with 2 successful distribution.
    has been running for 4 trials with 4 successful distribution.
    has been running for 6 trials with 6 successful distribution.
    has been running for 8 trials with 8 successful distribution.
    has been running for 10 trials with 10 successful distribution.
photon accounting loss_rate = 0.0 (0, 34160)
fail count = 0, good count = 10
success probability = 1.0
ran 10 times


In [6]:
# required_runs = 10

# fail_count = 0
# good_count = 0

# actual_run_count = 0

# ## for 2 hops, the protocol always fail, no entanglement is created at all
# ## when printed at ABSA before measurements, there seems to be no RGS generated?

# # at loss_probability = 0.1, this means that the distance between RGSS and the ABSA is ~2.28787km

# # for _ in range(required_runs):
# while good_count < required_runs:
#     is_successful, exp_xz, exp_zx, cano_stabs, parity = experiment_setup(4, 3, [10, 5], loss_probability=0.1)
#     if is_successful and (exp_xz != 1 or exp_zx!= 1):
#         fail_count += 1
#     elif is_successful and exp_zx == 1 and exp_xz == 1:
#         good_count += 1
#     actual_run_count += 1
#     if (actual_run_count % 10 == 0):
#         print(f'    has been running for {actual_run_count} trials with {good_count} successful distribution.')

# print(f'fail count = {fail_count}, good count = {good_count}')
# print(f'success probability = {(fail_count + good_count) / actual_run_count}')
# print(f'ran {actual_run_count} times')

Steps that we can test:
1. Test whether the RGS creation actually produces the intended graph state (up to Z corrections).
2. (TODO)

In [7]:
def experiment_run():
    """Placeholder outlining the steps of a single protocol trial; not implemented yet.

    Planned steps:
      0. experiment_setup
      1. RGS creation -- store the side effects (Z).
      2. Process photon loss.
      3. ABSA measurements (build the measurement trees) -- store the raw measurement results:
         a. BSM;
         b. single-qubit measurements.
      4. Protocol (according to Sec. 6 or 7?):
         a. Update the measurement trees with the side effects. This is actually done
            earlier, at RGS creation. Store the eigenvalues derived from the raw results
            and the side effects.
         b. Update from the BSM (the paired trees).
         c. The product of the Z parities of the left (right) side is sent to the right
            (left), while the X parity of the left (right) side is sent to the left (right).

    Currently does nothing and returns ``None``.
    """
    return None
    

In [8]:
import random

# Sanity check of loss sampling. Draw N Bernoulli(photon_loss_prob) samples with the
# stdlib RNG and report the empirical loss ratio. Uses photon_loss_prob from the
# experiment cell.
N = 1464  # number of samples
M = sum(1 for _ in range(N) if random.random() < photon_loss_prob)
print(f'ratio {M/N} ({M}, {N})')

ratio 0.0 (0, 1464)
