In [1]:
from typing import List, Callable, TYPE_CHECKING
from pathlib import Path
from copy import copy, deepcopy
from pprint import pprint

if TYPE_CHECKING:
    from src.components.photon import Photon

from json import dump
import numpy as np

from src.kernel.event import Event
from src.kernel.process import Process
from src.kernel.timeline import Timeline
from src.kernel.quantum_manager import FOCK_DENSITY_MATRIX_FORMALISM
from src.components.detector import QSDetectorFockDirect, QSDetectorFockInterference
from src.components.light_source import SPDCSource
from src.components.memory import AbsorptiveMemory
from src.components.optical_channel import QuantumChannel
from src.components.photon import Photon
from src.topology.node import Node
from src.protocol import Protocol
from src.kernel.quantum_utils import *  # only for manual calculation and should not be used in simulation

## Parameters and Utils

In [2]:
# quantum manager
TRUNCATION = 1  # truncation of Fock space (=dimension-1)

# photon sources
TELECOM_WAVELENGTH = 1436  # telecom band wavelength of SPDC source idler photon (presumably in nm -- TODO confirm)
WAVELENGTH = 606  # wavelength of AFC memory resonant absorption, of SPDC source signal photon (presumably in nm -- TODO confirm)
SPDC_FREQUENCY = 80e6  # frequency of both SPDC sources' photon creation (same as memory frequency and detector count rate)
MEAN_PHOTON_NUM1 = 0.1  # mean photon number of SPDC source on node 1
MEAN_PHOTON_NUM2 = 0.1  # mean photon number of SPDC source on node 2

# detectors
BSM_DET1_EFFICIENCY = 0.6  # efficiency of detector 1 of BSM
BSM_DET2_EFFICIENCY = 0.6  # efficiency of detector 2 of BSM
BSM_DET1_DARK = 150  # Dark count rate (Hz) of detector 1 of BSM
BSM_DET2_DARK = 150  # Dark count rate (Hz) of detector 2 of BSM
MEAS_DET1_EFFICIENCY = 0.6  # efficiency of detector 1 of DM measurement
MEAS_DET2_EFFICIENCY = 0.6  # efficiency of detector 2 of DM measurement
MEAS_DET1_DARK = 150  # Dark count rate (Hz) of detector 1 of DM measurement
MEAS_DET2_DARK = 150  # Dark count rate (Hz) of detector 2 of DM measurement

# fibers
DIST_ANL_ERC = 20  # distance between ANL and ERC, in km
DIST_HC_ERC = 20  # distance between HC and ERC, in km
ATTENUATION = 0.2  # attenuation rate of optical fibre (in dB/km)
DELAY_CLASSICAL = 5e-3  # delay for classical communication between BSM node and memory nodes (in s)

# memories
MODE_NUM = 100  # number of temporal modes of AFC memory (same for both memories)
MEMO_FREQUENCY1 = SPDC_FREQUENCY  # frequency of memory 1
MEMO_FREQUENCY2 = SPDC_FREQUENCY  # frequency of memory 2
ABS_EFFICIENCY1 = 0.35  # absorption efficiency of AFC memory 1
ABS_EFFICIENCY2 = 0.35  # absorption efficiency of AFC memory 2
PREPARE_TIME1 = 0  # time required for AFC structure preparation of memory 1
PREPARE_TIME2 = 0  # time required for AFC structure preparation of memory 2
COHERENCE_TIME1 = -1  # spin coherence time for AFC memory 1 spinwave storage, -1 means infinite time
COHERENCE_TIME2 = -1  # spin coherence time for AFC memory 2 spinwave storage, -1 means infinite time
AFC_LIFETIME1 = -1  # AFC structure lifetime of memory 1, -1 means infinite time
AFC_LIFETIME2 = -1  # AFC structure lifetime of memory 2, -1 means infinite time
DECAY_RATE1 = 4.3e-8  # retrieval efficiency decay rate for memory 1
DECAY_RATE2 = 4.3e-8  # retrieval efficiency decay rate for memory 2

# experiment settings
# `time` is the first argument handed to Timeline(...) when the network is built.
# The detector binning code converts with 1e12 time units per second, so this is
# presumably 1 s expressed in picoseconds -- TODO confirm against Timeline.
time = int(1e12)
calculate_fidelity_direct = True
calculate_rate_direct = True
num_direct_trials = 200  # number of trials for the direct measurement
num_bs_trials_per_phase = 50  # number of beamsplitter-measurement trials for each phase setting
phase_settings = np.linspace(0, 2*np.pi, num=20, endpoint=False)  # 20 equally spaced phases in [0, 2*pi)


# function to generate standard pure Bell state for fidelity calculation
def build_bell_state(truncation, sign, phase=0, formalism="dm"):
    """Generate standard Bell state which is heralded in ideal BSM.

    For comparison with results from imperfect parameter choices. The state is
    (|1,0> +/- exp(i*phase) |0,1>) / sqrt(2) on two Fock modes, each of
    dimension ``truncation + 1``.

    Args:
        truncation (int): Fock space truncation (dimension - 1) of each mode.
            Must be at least 1 so that the single-photon level exists.
        sign (str): "plus" or "minus", the relative sign between the two terms.
        phase (float): additional relative phase in radians (default 0).
        formalism (str): "dm" to return the density matrix, "ket" to return
            the state vector (default "dm").

    Returns:
        numpy.ndarray: density matrix of shape ``((truncation+1)**2, (truncation+1)**2)``
        for "dm", or state vector of length ``(truncation+1)**2`` for "ket".

    Raises:
        ValueError: if ``sign`` or ``formalism`` is not one of the accepted values.
    """

    # single-mode basis vectors |0> and |1>
    basis0 = np.zeros(truncation+1)
    basis0[0] = 1
    basis1 = np.zeros(truncation+1)
    basis1[1] = 1
    # two-mode product states |1,0> and |0,1>
    basis10 = np.kron(basis1, basis0)
    basis01 = np.kron(basis0, basis1)

    if sign == "plus":
        ket = (basis10 + np.exp(1j*phase)*basis01)/np.sqrt(2)
    elif sign == "minus":
        ket = (basis10 - np.exp(1j*phase)*basis01)/np.sqrt(2)
    else:
        # f-string (not "+") so a non-string argument still produces the intended
        # ValueError instead of a TypeError from string concatenation
        raise ValueError(f"Invalid Bell state sign type {sign}")

    if formalism == "dm":
        return np.outer(ket, ket.conj())
    elif formalism == "ket":
        return ket
    else:
        raise ValueError(f"Invalid quantum state formalism {formalism}")


# retrieval efficiency as function of storage time for absorptive quantum memory, using exponential decay model
def efficiency1(t: int) -> float:
    """Retrieval efficiency of memory 1 after storage time ``t``.

    Exponential decay model: ``exp(-t * DECAY_RATE1)``.
    """
    decay_exponent = -t * DECAY_RATE1
    return np.exp(decay_exponent)


def efficiency2(t: int) -> float:
    """Retrieval efficiency of memory 2 after storage time ``t``.

    Exponential decay model: ``exp(-t * DECAY_RATE2)``.
    """
    decay_exponent = -t * DECAY_RATE2
    return np.exp(decay_exponent)


def add_channel(node1: Node, node2: Node, timeline: Timeline, **kwargs):
    """Create a quantum channel from ``node1`` to ``node2`` and register its ends.

    Args:
        node1 (Node): node passed to the channel as its first end.
        node2 (Node): node whose name is passed to the channel as its second end.
        timeline (Timeline): timeline the channel is created on.
        **kwargs: forwarded unchanged to the ``QuantumChannel`` constructor.

    Returns:
        QuantumChannel: the new channel, named ``qc_<node1.name>_<node2.name>``.
    """
    channel_name = "qc_" + node1.name + "_" + node2.name
    channel = QuantumChannel(channel_name, timeline, **kwargs)
    channel.set_ends(node1, node2.name)
    return channel


# protocol to control photon emission on end node
class EmitProtocol(Protocol):
    """Protocol that drives photon emission from the light source on an end node."""

    def __init__(self, own: "EndNode", name: str, other_node: str, photon_pair_num: int,
                 source_name: str, memory_name: str):
        """Constructor for Emission protocol.

        Args:
            own (EndNode): node on which the protocol is located.
            name (str): name of the protocol instance.
            other_node (str): name of the other node to generate entanglement with
            photon_pair_num (int): number of output photon pulses to send in one execution.
            source_name (str): name of the light source on the node.
            memory_name (str): name of the memory on the node.
        """

        super().__init__(own, name)
        self.other_node = other_node
        self.num_output = photon_pair_num
        self.source_name = source_name
        self.memory_name = memory_name

    def start(self):
        """Prepare the memory if it is not prepared yet, then emit ``num_output`` pulses."""
        components = self.own.components

        memory = components[self.memory_name]
        if not memory.is_prepared:
            memory._prepare_AFC()

        # for Fock encoding only the list length matters; the elements are placeholders
        placeholder_states = [None for _ in range(self.num_output)]
        components[self.source_name].emit(placeholder_states)

    def received_message(self, src: str, msg):
        """Incoming messages are ignored by this protocol."""
        pass


class EndNode(Node):
    """Node for each end of the network (the memory node).

    Holds one SPDC photon source and one absorptive quantum memory; the
    parameters of both devices are supplied per node through the constructor.
    """

    def __init__(self, name: str, timeline: "Timeline", other_node: str, bsm_node: str, measure_node: str,
                 mean_photon_num: float, spdc_frequency: float, memo_frequency: float, abs_effi: float,
                 afc_efficiency: Callable, mode_number: int):
        """Build the node hardware and its emission protocol.

        Args:
            name (str): name of the node.
            timeline (Timeline): simulation timeline.
            other_node (str): name of the other end node.
            bsm_node (str): name of the node photons from the source are sent to.
            measure_node (str): name of the node set as the memory's destination.
            mean_photon_num (float): mean photon number of the SPDC source.
            spdc_frequency (float): frequency of the SPDC source.
            memo_frequency (float): frequency of the memory.
            abs_effi (float): absorption efficiency of the memory.
            afc_efficiency (Callable): efficiency function handed to the memory.
            mode_number (int): number of memory modes; also the number of pulses
                emitted per protocol execution.
        """
        super().__init__(name, timeline)

        self.bsm_name = bsm_node
        self.meas_name = measure_node
        self.spdc_name = f"{name}.spdc_source"
        self.memo_name = f"{name}.memory"

        # hardware setup
        source = SPDCSource(
            self.spdc_name,
            timeline,
            wavelengths=[TELECOM_WAVELENGTH, WAVELENGTH],
            frequency=spdc_frequency,
            mean_photon_num=mean_photon_num,
        )
        absorber = AbsorptiveMemory(
            self.memo_name,
            timeline,
            frequency=memo_frequency,
            absorption_efficiency=abs_effi,
            afc_efficiency=afc_efficiency,
            mode_number=mode_number,
            wavelength=WAVELENGTH,
            destination=measure_node,
        )
        for device in (source, absorber):
            self.add_component(device)

        # receiver wiring: the source outputs to this node first and to the memory second;
        # the memory outputs back to this node
        source.add_receiver(self)
        source.add_receiver(absorber)
        absorber.add_receiver(self)

        # protocols
        self.emit_protocol = EmitProtocol(
            self, f"{name}.emit_protocol", other_node, mode_number, self.spdc_name, self.memo_name
        )

    def get(self, photon: "Photon", **kwargs):
        """Forward a photon received from a local component.

        Without a ``dst`` keyword (photon from the SPDC source) the photon goes to
        the BSM node; otherwise (photon from the memory) it goes to ``dst``.
        """
        dst = kwargs.get("dst")
        target = self.bsm_name if dst is None else dst
        self.send_qubit(target, photon)


class EntangleNode(Node):
    """Node hosting the interference detector used for the Bell-state measurement (BSM)."""

    def __init__(self, name: str, timeline: "Timeline", src_list: List[str]):
        """Create the BSM detector and configure both of its photon detectors.

        Args:
            name (str): name of the node.
            timeline (Timeline): simulation timeline.
            src_list (List[str]): names of the source nodes, passed to the detector.
        """
        super().__init__(name, timeline)

        # hardware setup
        self.bsm_name = name + ".bsm"
        # assume no relative phase between two input optical paths
        bsm = QSDetectorFockInterference(self.bsm_name, timeline, src_list)
        self.add_component(bsm)
        bsm.attach(self)
        self.set_first_component(self.bsm_name)
        # coarsest time resolution among the detectors; used as the binning tolerance
        self.resolution = max(d.time_resolution for d in bsm.detectors)

        # detector parameter setup (each detector uses its own efficiency AND dark count constants)
        bsm.set_detector(0, efficiency=BSM_DET1_EFFICIENCY, count_rate=SPDC_FREQUENCY, dark_count=BSM_DET1_DARK)
        bsm.set_detector(1, efficiency=BSM_DET2_EFFICIENCY, count_rate=SPDC_FREQUENCY, dark_count=BSM_DET2_DARK)

    def receive_qubit(self, src: str, qubit) -> None:
        """Hand an incoming qubit to the node's first component, tagged with its source."""
        self.components[self.first_component_name].get(qubit, src=src)

    def get_detector_entries(self, detector_name: str, start_time: int, num_bins: int, frequency: float):
        """Returns detection events for density matrix measurement. Used to determine BSM result.

        Each trigger time is assigned to the nearest arrival bin. A click is counted
        only if it lies within ``self.resolution`` of that bin's expected time and the
        bin index is inside ``[0, num_bins)``. Clicks of detector 0 add 1 to the bin,
        clicks of detector 1 add 2.

        Args:
            detector_name (str): name of detector to get measurements from.
            start_time (int): simulation start time of when photons received.
            num_bins (int): number of arrival bins
            frequency (float): frequency of photon arrival (in Hz).

        Returns:
            List[int]: list of length (num_bins) with result for each time bin.
        """

        trigger_times = self.components[detector_name].get_photon_times()
        return_res = [0] * num_bins

        for weight, detector_times in ((1, trigger_times[0]), (2, trigger_times[1])):
            for trigger_time in detector_times:
                # the 1e-12 / 1e12 factors convert between simulation time units and seconds
                closest_bin = int(round((trigger_time - start_time) * frequency * 1e-12))
                expected_time = (float(closest_bin) * 1e12 / frequency) + start_time
                if abs(expected_time - trigger_time) < self.resolution and 0 <= closest_bin < num_bins:
                    return_res[closest_bin] += weight

        return return_res


class MeasureNode(Node):
    """Node hosting the detectors used to measure the photons retrieved from the memories.

    It carries a direct detector (set as the first component) and an interference
    (beamsplitter) detector whose phase can be changed with ``set_phase``.
    """

    def __init__(self, name: str, timeline: "Timeline", other_nodes: List[str]):
        """Create both detector systems and configure their photon detectors.

        Args:
            name (str): name of the node.
            timeline (Timeline): simulation timeline.
            other_nodes (List[str]): names of the source nodes, passed to both detectors.
        """
        super().__init__(name, timeline)

        self.direct_detector_name = f"{name}.direct"
        direct = QSDetectorFockDirect(self.direct_detector_name, timeline, other_nodes)
        self.add_component(direct)
        direct.attach(self)

        self.bs_detector_name = f"{name}.bs"
        interference = QSDetectorFockInterference(self.bs_detector_name, timeline, other_nodes)
        self.add_component(interference)
        interference.add_receiver(self)

        self.set_first_component(self.direct_detector_name)

        # time resolution of SPDs: the coarsest one over all four detectors
        all_detectors = direct.detectors + interference.detectors
        self.resolution = max(spd.time_resolution for spd in all_detectors)

        # detector parameter setup: identical settings for the direct and the interference detector
        for detector_system in (direct, interference):
            detector_system.set_detector(0, efficiency=MEAS_DET1_EFFICIENCY, count_rate=SPDC_FREQUENCY,
                                         dark_count=MEAS_DET1_DARK)
            detector_system.set_detector(1, efficiency=MEAS_DET2_EFFICIENCY, count_rate=SPDC_FREQUENCY,
                                         dark_count=MEAS_DET2_DARK)

    def receive_qubit(self, src: str, qubit) -> None:
        """Hand an incoming qubit to the node's first component, tagged with its source."""
        first = self.components[self.first_component_name]
        first.get(qubit, src=src)

    def set_phase(self, phase: float):
        """Set the phase of the interference (beamsplitter) detector."""
        interference = self.components[self.bs_detector_name]
        interference.set_phase(phase)

    def get_detector_entries(self, detector_name: str, start_time: int, num_bins: int, frequency: float):
        """Returns detection events for density matrix measurement.

        Every trigger time is mapped to its nearest arrival bin and counted only when
        it is within ``self.resolution`` of that bin's expected time and the bin index
        is in range. Detector 0 adds 1 to its bin, detector 1 adds 2.

        Args:
            detector_name (str): name of detector to get measurements from.
            start_time (int): simulation start time of when photons received.
            num_bins (int): number of arrival bins
            frequency (float): frequency of photon arrival (in Hz).

        Returns:
            List[int]: list of length (num_bins) with result for each time bin.
        """

        trigger_times = self.components[detector_name].get_photon_times()
        bins = [0] * num_bins

        for detector_index in (0, 1):
            increment = detector_index + 1
            for click_time in trigger_times[detector_index]:
                closest_bin = int(round((click_time - start_time) * frequency * 1e-12))
                expected_time = (float(closest_bin) * 1e12 / frequency) + start_time
                on_time = abs(expected_time - click_time) < self.resolution
                if on_time and 0 <= closest_bin < num_bins:
                    bins[closest_bin] += increment

        return bins
    
# Build the simulation timeline using the Fock density-matrix formalism.
tl = Timeline(time, formalism=FOCK_DENSITY_MATRIX_FORMALISM, truncation=TRUNCATION)

# node names
anl_name = "Argonne"
hc_name = "Harper Court"
erc_name = "Eckhardt Research Center BSM"
erc_2_name = "Eckhardt Research Center Measurement"
seeds = [1, 2, 3, 4]  # one seed per node, paired in order with [anl, hc, erc, erc_2] below
src_list = [anl_name, hc_name]  # the list of sources, note the order

# two end nodes (source + memory), one BSM node and one measurement node
anl = EndNode(anl_name, tl, hc_name, erc_name, erc_2_name, mean_photon_num=MEAN_PHOTON_NUM1,
                spdc_frequency=SPDC_FREQUENCY, memo_frequency=MEMO_FREQUENCY1, abs_effi=ABS_EFFICIENCY1,
                afc_efficiency=efficiency1, mode_number=MODE_NUM)
hc = EndNode(hc_name, tl, anl_name, erc_name, erc_2_name, mean_photon_num=MEAN_PHOTON_NUM2,
                spdc_frequency=SPDC_FREQUENCY, memo_frequency=MEMO_FREQUENCY2, abs_effi=ABS_EFFICIENCY2,
                afc_efficiency=efficiency2, mode_number=MODE_NUM)
erc = EntangleNode(erc_name, tl, src_list)
erc_2 = MeasureNode(erc_2_name, tl, src_list)

# fixed per-node seeds
for seed, node in zip(seeds, [anl, hc, erc, erc_2]):
    node.set_seed(seed)

# extend fiber lengths to be equivalent
fiber_length = max(DIST_ANL_ERC, DIST_HC_ERC)

# NOTE(review): fiber_length is in km and ATTENUATION in dB/km per the parameter comments;
# confirm that QuantumChannel expects these units for `distance` and `attenuation`.
qc1 = add_channel(anl, erc, tl, distance=fiber_length, attenuation=ATTENUATION)
qc2 = add_channel(hc, erc, tl, distance=fiber_length, attenuation=ATTENUATION)
qc3 = add_channel(anl, erc_2, tl, distance=fiber_length, attenuation=ATTENUATION)
qc4 = add_channel(hc, erc_2, tl, distance=fiber_length, attenuation=ATTENUATION)

tl.init()

# calculate start time for protocol
# since fiber lengths equal, both start at 0
start_time_anl = start_time_hc = 0

# calculations for when to start recording measurements
# (both delays are read from the channels to the measurement node; all four channels
# are built with the same length, and the assert guards the equal-delay assumption)
delay_anl = anl.qchannels[erc_2_name].delay
delay_hc = hc.qchannels[erc_2_name].delay
assert delay_anl == delay_hc
start_time_bsm = start_time_anl + delay_anl
mem = anl.components[anl.memo_name]
total_time = mem.total_time
# measurement recording starts after the memory's total_time plus the channel delay
start_time_meas = start_time_anl + total_time + delay_anl

# result containers: one flat list for direct measurement, one list per phase setting
results_direct_measurement = []
results_bs_measurement = [[] for _ in phase_settings]


## Main function

### Setup


In [3]:

"""Pre-simulation explicit calculation of effective entanglement fidelity upon successful BSM"""

# use non-transmitted Photon as interface with existing methods in SeQUeNCe
# Short handles to the hardware components used by the manual calculation below.
spdc_anl = anl.components[anl.spdc_name]  # SPDC source on the ANL node
spdc_hc = hc.components[hc.spdc_name]  # SPDC source on the HC node
memo_anl = anl.components[anl.memo_name]  # absorptive memory on the ANL node
memo_hc = hc.components[hc.memo_name]  # absorptive memory on the HC node
channel_anl = anl.qchannels[erc_name]  # quantum channel from ANL to the BSM node
channel_hc = hc.qchannels[erc_name]  # quantum channel from HC to the BSM node
bsm = erc.components[erc.bsm_name]  # interference detector on the BSM node



### Initial state
Here, we create the initial state for the simulation (TMSV state). How it works:
\begin{equation}
\frac{1}{\sqrt{\mu + 1}}\sum^{\infty}_{n=0}{\left(\sqrt{\frac{\mu}{\mu + 1}}\right)^n}\ket{n,n}
\end{equation}
(See the notes to see how this came about. Straightforward)
We create the array of coefficients first. Next, we construct the $\ket{n,n}$ state, which is just $\ket{n}\otimes\ket{n}$, where $\ket{n}$ is the basis vector whose $n^{th}$ element is 1 and all others are 0. We use the truncation here since we cannot include all states up to infinity, so each basis vector has size truncation+1 (the +1 accounts for the vacuum state). \
\
See the _generate_tmsv_state function for reference. \
\
We take those states and assign them to the photon objects at the ANL and HC nodes.

In [4]:
# set shared state to squeezed state
# photon0: idler, photon1: signal
# The two ANL photons are created at the source's two wavelengths (index 0 and 1) and are
# used only as handles to quantum-manager keys; they are never transmitted.
photon0_anl = Photon("", tl, wavelength=spdc_anl.wavelengths[0], location=spdc_anl,
                        encoding_type=spdc_anl.encoding_type, use_qm=True)
photon1_anl = Photon("", tl, wavelength=spdc_anl.wavelengths[1], location=spdc_anl,
                        encoding_type=spdc_anl.encoding_type, use_qm=True)

# assign the source's TMSV state jointly to the (idler, signal) keys of the ANL pair
state_spdc_anl = spdc_anl._generate_tmsv_state()
keys = [photon0_anl.quantum_state, photon1_anl.quantum_state]
tl.quantum_manager.set(keys, state_spdc_anl)

# show the generated state and the quantum-manager entry stored under the idler key
print("ANL TMSV state:", state_spdc_anl)
print("density matrix:")
print(tl.quantum_manager.states[keys[0]])

# same construction for the HC node
photon0_hc = Photon("", tl, wavelength=spdc_hc.wavelengths[0], location=spdc_hc,
                    encoding_type=spdc_hc.encoding_type, use_qm=True)
photon1_hc = Photon("", tl, wavelength=spdc_hc.wavelengths[1], location=spdc_hc,
                    encoding_type=spdc_hc.encoding_type, use_qm=True)

# set shared state to squeezed state
state_spdc_hc = spdc_hc._generate_tmsv_state()
keys = [photon0_hc.quantum_state, photon1_hc.quantum_state]  # NOTE: `keys` now refers to the HC pair
tl.quantum_manager.set(keys, state_spdc_hc)



ANL TMSV state: [0.95346259 0.         0.         0.30151134]
density matrix:
Keys:
[0, 1]
State:
[[0.90909091+0.j 0.        +0.j 0.        +0.j 0.28747979+0.j]
 [0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j]
 [0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j]
 [0.28747979+0.j 0.        +0.j 0.        +0.j 0.09090909+0.j]]


Add losses due to transmission and memory absorption. We simulate each loss as a beamsplitter (refer to the notes): we find the transmissivity of the channel and use that ratio to find the beamsplitter angle. The resulting state is a linear combination (superposition) weighted by binomial coefficients. Compare the output state after applying the loss Kraus operators.\
\
You can see from the output state that new coherences were created as a result of the loss operation and the diagonal elements have all decreased and new diagonal entries have been created

In [5]:

# photon loss upon absorption by memories:
# each signal photon (photon1) is lost with probability 1 - absorption efficiency of its memory
key_anl = photon1_anl.quantum_state
loss_anl = 1 - memo_anl.absorption_efficiency
# verbose=True is passed for the ANL signal photon only (presumably prints the post-loss state -- confirm)
tl.quantum_manager.add_loss(key_anl, loss_anl, verbose=True)
key_hc = photon1_hc.quantum_state
loss_hc = 1 - memo_hc.absorption_efficiency
tl.quantum_manager.add_loss(key_hc, loss_hc)

# transmission loss through optical fibres:
# each idler photon (photon0) is attenuated by the loss of its own node's channel to the BSM node
key_anl = photon0_anl.quantum_state
loss_anl = channel_anl.loss
tl.quantum_manager.add_loss(key_anl, loss_anl)
key_hc = photon0_hc.quantum_state
# use the HC channel here; the original read channel_anl.loss (copy-paste slip), which only
# gave the right number because both fibres currently have the same length
loss_hc = channel_hc.loss
tl.quantum_manager.add_loss(key_hc, loss_hc)



[[0.90909091+0.j 0.        +0.j 0.        +0.j 0.17007534+0.j]
 [0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j]
 [0.        +0.j 0.        +0.j 0.05909091+0.j 0.        +0.j]
 [0.17007534+0.j 0.        +0.j 0.        +0.j 0.03181818+0.j]]


Finally, we perform the measurements. The results are ordered as follows:\
The first matrix corresponds to the case where neither detector clicked. This means that the entanglement attempt was unsuccessful; hence, the density matrix is purely diagonal. \
The second and third matrices each correspond to exactly one of the detectors clicking. These are the entangled states that we are after. 

In [13]:
# QSDetector measurement
# Take the first group of POVM operators from the BSM detector and convert every operator
# to nested tuples, the form passed to the cached measurement helper below.
povms = bsm.povms[0]
povm_tuple = tuple([tuple(map(tuple, povm)) for povm in povms])
# measure the two idler photons (the ones sent to the BSM node)
keys = [photon0_anl.quantum_state, photon0_hc.quantum_state]
# joint state covering `keys`, plus the ordered list of all keys that state spans
new_state, all_keys = tl.quantum_manager._prepare_state(keys)
# positions of the measured keys inside all_keys
indices = tuple([all_keys.index(key) for key in keys])
state_tuple = tuple(map(tuple, new_state))
# debugging aid: argument types and number of POVM operators handed to the helper
print(type(state_tuple), type(indices), type(len(all_keys)), len(povm_tuple),type(tl.quantum_manager.truncation))
# post-measurement states and outcome probabilities, one per POVM operator
states, probs = measure_multiple_with_cache_fock_density(state_tuple, indices, len(all_keys), povm_tuple,tl.quantum_manager.truncation)


# effective Bell state generated 
def effective_state(state):
    """Return a renormalized copy of ``state`` with its [0][0] entry removed.

    The input is left untouched: a copy is made, its first diagonal element is set
    to zero, and the copy is divided by its trace.
    """
    conditioned = copy(state)
    conditioned[0][0] = 0
    norm = np.trace(conditioned)
    return conditioned / norm


# For every measurement outcome: trace out the measured (idler) modes, apply
# effective_state to what remains, and print it together with the outcome probability.
# NOTE: which effective_state runs depends on cell execution order, because a later cell
# in this notebook redefines the function under the same name.
for state,prob in zip(states, probs):
    indices = tuple([all_keys.index(key) for key in keys])
    new_state_tuple = tuple(map(tuple, state))
    remaining_state = density_partial_trace(new_state_tuple, indices, len(all_keys),
                                            tl.quantum_manager.truncation)

    remaining_state_eff = effective_state(remaining_state)
    print("state with probability:", prob)
    pprint(remaining_state_eff)
# NOTE(review): in the recorded output states[1] has NEGATIVE off-diagonal terms, i.e. it
# resembles bell_minus even though it is bound to the name state_plus; the fidelity below is
# accordingly computed against bell_minus. Confirm the intended naming.
state_plus, state_minus = states[1], states[2]

# calculate remaining state
indices = tuple([all_keys.index(key) for key in keys])
new_state_tuple = tuple(map(tuple, state_plus))
remaining_state = density_partial_trace(new_state_tuple, indices, len(all_keys),
                                        tl.quantum_manager.truncation)
# keys of the modes that were not measured (not used further in this cell)
remaining_keys = [key for key in all_keys if key not in keys]

print("intermediate state:")
print(remaining_state)

remaining_state_eff = effective_state(remaining_state)

# calculate the fidelity with reference Bell state
bell_plus = build_bell_state(tl.quantum_manager.truncation, "plus")
bell_minus = build_bell_state(tl.quantum_manager.truncation, "minus")
print("bell_plus state:")
print(bell_plus)

print("bell_minus state:")
print(bell_minus)

# fidelity as Tr(rho * |bell><bell|), valid because the reference state is pure
fidelity = np.trace(remaining_state_eff.dot(bell_minus)).real

print("Directly calculated effective fidelity:", fidelity)


<class 'tuple'> <class 'tuple'> <class 'int'> 4 <class 'int'>
state with probability: 0.9570416653191935
array([[0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.49373356+0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        -0.j, 0.49373356-0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, 0.01253289-0.j]])
state with probability: 0.021243398105161033
array([[ 0.        +0.j,  0.        +0.j,  0.        +0.j,  0.        +0.j],
       [ 0.        +0.j,  0.48816629+0.j, -0.44421251+0.j,  0.        +0.j],
       [ 0.        +0.j, -0.44421251+0.j,  0.48816629+0.j,  0.        +0.j],
       [ 0.        +0.j,  0.        +0.j,  0.        +0.j,  0.02366742+0.j]])
state with probability: 0.021243398105161033
array([[0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.48816629+0.j, 0.44421251+0.j, 0.        +0.j],
       [0.        +0.j, 0.44421251+0.j, 0.48816629+0.j, 0

In [8]:
# What if we have false heralding and we need to trace out only one of the idler photons?
# We create a list of idler photons to get the entire state of both the entangled photons. 
keys = [photon0_anl.quantum_state, photon0_hc.quantum_state]

# Suppose we just want to trace out the idler photon of the anl node.
trace_keys = [photon0_anl.quantum_state]

# Get the joint state and the ordered list of all keys it spans. We also convert the state into a tuple
# since that is how it will be used in the next steps
new_state, all_keys = tl.quantum_manager._prepare_state(keys)
state_tuple = tuple(map(tuple, new_state))

# Get a list of indices for which key corresponds to which index in the all_keys array which is 
# how the new_state matrix is prepared
# (`indices` is not used further in this cell; only `trace_indices` is passed on)
indices = tuple([all_keys.index(key) for key in keys])
# Do the same thing for the trace keys
trace_indices = tuple([all_keys.index(key) for key in trace_keys])

# Now, we perform the actual partial tracing. 
# Only the ANL idler mode is traced out, so the result still spans the remaining modes.
false_heralding_state = density_partial_trace(state_tuple, trace_indices, len(all_keys), tl.quantum_manager.truncation)

# Printing results
print("False heralding state:")
pprint(false_heralding_state)


False heralding state:
array([[0.91460003+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.10389586+0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.01854178+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.03005728+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.00341442+0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, 0.00060935+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.02277601+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.10389586+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.012264  +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j, 0.00074851+0.j, 0.        +0.j],
       [0.        

## EXPERIMENTAL AREA!!!!!

In [9]:
from scipy.linalg import fractional_matrix_power
from math import factorial
def generate_povms(bsm):
    """Generate the click / no-click POVM operators of a two-detector BSM.

    For each detector the "click" operator is the truncated series

        sum_{k=1}^{truncation} (-1)^(k-1) * (a_dagger)^k @ a^k / k!

    built from that detector's transformed ladder operators, and the "no click"
    operator is the identity minus it. Operators are then combined for every
    on/off configuration of the two detectors; a detector that is switched off
    contributes the identity.

    Args:
        bsm: interference detector exposing ``timeline.quantum_manager.truncation``
            and ``_generate_transformed_ladders()`` (returning create1, destroy1,
            create2, destroy2 as square matrices of size ``(truncation+1)**2``).

    Returns:
        list: four groups of operators, in this order:
            ``[povm00, povm01, povm10, povm11]`` -- both detectors active
            (first digit: detector 1, second digit: detector 2; 1 = click);
            ``[povm0_, povm1_]`` -- detector 2 off;
            ``[povm_0, povm_1]`` -- detector 1 off;
            ``[identity]`` -- both detectors off.
    """

    # assume using Fock quantum manager
    truncation = bsm.timeline.quantum_manager.truncation
    create1, destroy1, create2, destroy2 = bsm._generate_transformed_ladders()

    # np.eye is used explicitly so the function does not depend on a bare `eye`
    # leaking in through a wildcard import
    identity = np.eye((truncation + 1) ** 2)

    def click_operator(create, destroy):
        # Truncated series for "at least one photon detected"; integer powers are
        # computed with matrix_power (exact repeated multiplication).
        return sum(
            (-1) ** (k - 1)
            * np.linalg.matrix_power(create, k).dot(np.linalg.matrix_power(destroy, k))
            / factorial(k)
            for k in range(1, truncation + 1)
        )

    # for detector1 (index 0)
    povm1_1 = click_operator(create1, destroy1)
    povm0_1 = identity - povm1_1

    # for detector2 (index 1)
    povm1_2 = click_operator(create2, destroy2)
    povm0_2 = identity - povm1_2

    # POVM operators for 4 possible outcomes (When both detectors active)
    # Note: povm01 and povm10 are relevant to BSM
    povm00 = povm0_1 @ povm0_2
    povm01 = povm0_1 @ povm1_2
    povm10 = povm1_1 @ povm0_2
    povm11 = povm1_1 @ povm1_2

    # POVM when detector 1 off:
    povm_0 = identity @ povm0_2
    povm_1 = identity @ povm1_2

    # POVM when detector 2 off:
    povm0_ = povm0_1 @ identity
    povm1_ = povm1_1 @ identity

    # POVM when both detectors are off:
    povm = identity

    return [[povm00, povm01, povm10, povm11], [povm0_, povm1_], [povm_0, povm_1], [povm]]

In [19]:
# effective Bell state generated 
def effective_state(state):
    """Return a plain copy of ``state`` without any conditioning.

    This definition replaces the earlier ``effective_state`` of this notebook once
    the cell is run: the removal of the [0][0] entry and the renormalization by the
    trace are intentionally disabled here, so the state is passed through as-is.
    """
    # disabled on purpose: zeroing state[0][0] and dividing by the trace
    return copy(state)

# Run the idler-photon measurement once per POVM group returned by generate_povms
# (one group per detector on/off configuration) and print every outcome.
povms = generate_povms(bsm)
for i, povm in enumerate(povms):
    print("\nPOVM Case", i)
    # NOTE: the comprehension variable is also called `povm`; it is local to the
    # comprehension and iterates over the operators of the current group.
    povm_tuple = tuple([tuple(map(tuple, povm)) for povm in povm])
    # measure the two idler photons; the joint state is re-read for every group
    keys = [photon0_anl.quantum_state, photon0_hc.quantum_state]
    new_state, all_keys = tl.quantum_manager._prepare_state(keys)
    indices = tuple([all_keys.index(key) for key in keys])
    state_tuple = tuple(map(tuple, new_state))

    # post-measurement states and probabilities, one per operator in the group
    states, probs = measure_multiple_with_cache_fock_density(state_tuple, indices, len(all_keys), povm_tuple,
                                                                tl.quantum_manager.truncation)


    # trace out the measured modes and print what remains for each outcome
    for state,prob in zip(states, probs):
        indices = tuple([all_keys.index(key) for key in keys])
        new_state_tuple = tuple(map(tuple, state))
        remaining_state = density_partial_trace(new_state_tuple, indices, len(all_keys),
                                                tl.quantum_manager.truncation)

        remaining_state_eff = effective_state(remaining_state)
        print("state with probability:", prob)
        pprint(remaining_state_eff)


POVM Case 0
state with probability: 0.9570416653191935
array([[0.9511018 -0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.02414268+0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        -0.j, 0.02414268-0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, 0.00061284-0.j]])
state with probability: 0.021243398105161033
array([[ 0.63390891+0.j,  0.        +0.j,  0.        +0.j,  0.        +0.j],
       [ 0.        +0.j,  0.17871333+0.j, -0.16262224+0.j,  0.        +0.j],
       [ 0.        +0.j, -0.16262224+0.j,  0.17871333+0.j,  0.        +0.j],
       [ 0.        +0.j,  0.        +0.j,  0.        +0.j,  0.00866443+0.j]])
state with probability: 0.021243398105161033
array([[0.63390891+0.j, 0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.17871333+0.j, 0.16262224+0.j, 0.        +0.j],
       [0.        +0.j, 0.16262224+0.j, 0.17871333+0.j, 0.        +0.j],
       [0.        +0.j, 0.       