In [12]:
from typing import Optional

import numpy as np
from qiskit.circuit import QuantumCircuit, Qubit, Clbit
from qiskit.quantum_info import Operator
from qiskit_aer import AerSimulator


def beam_splitter(r: float) -> np.ndarray:
    """
    Returns the beam splitter matrix.

    The matrix is [[r, t], [t, -r]] with t = sqrt(1 - r**2). For any
    |r| <= 1 it is real, symmetric and squares to the identity, so it is
    orthogonal (and therefore unitary).

    Args:
        - r (float): The reflection coefficient of the beam splitter.
          Must satisfy -1 <= r <= 1.
    Returns:
        - (np.ndarray): 2 x 2 matrix that represents the beam
        splitter matrix.
    Raises:
        - ValueError: if r lies outside [-1, 1] or is NaN, where the
          transmission coefficient t would not be a real number.
    """
    # Reject out-of-range input up front. Otherwise np.sqrt of a negative
    # number returns nan (with only a RuntimeWarning) and the resulting
    # matrix is not unitary.
    if not -1.0 <= r <= 1.0:
        raise ValueError(f"reflection coefficient must lie in [-1, 1], got {r}")
    t = np.sqrt(1 - r**2)
    return np.array([[r, t], [t, -r]])


def mz_interferometer(
    r: float, shots: int = 2**20, seed: Optional[int] = None
) -> np.ndarray:
    """
    This quantum circuit returns the probability that either A or C
    detect a photon, and the probability that D detects a photon.

    The single qubit starts in |0>, passes a beam splitter and is measured
    into clbit 0. An outcome of 0 is counted at detector A. An outcome of 1
    sends the qubit through a second beam splitter, after which it is
    measured into clbits 1 and 2. Outcome 0 is counted at C and outcome 1
    at D.

    Args:
        - r (float): The reflection coefficient of the beam splitters.
        - shots (int): Number of simulator shots (default 2**20). Must be > 0.
        - seed (int | None): Optional simulator seed, for reproducible
          estimates. None (default) leaves the simulator unseeded.
    Returns:
        - np.ndarray(float): An array of shape (2,), where the first
        element is the estimated probability of detection at A or C,
        and the second element is the estimated probability of
        detection at D (relative shot frequencies).
    Raises:
        - ValueError: if shots is not positive.
    """
    if shots <= 0:
        raise ValueError(f"shots must be positive, got {shots}")

    # One qubit plus three clbits: clbit 0 holds the first measurement,
    # clbits 1 and 2 both hold the second one.
    bits = [Qubit(), Clbit(), Clbit(), Clbit()]
    qc = QuantumCircuit(bits)

    splitter_op = Operator(beam_splitter(r))
    qc.unitary(splitter_op, 0, label="Beam Splitter")
    qc.measure(0, 0)

    # bits[1] is clbit 0. Only a first outcome of 1 reaches the second splitter.
    with qc.if_test((bits[1], 0)) as else_:
        pass
    with else_:
        qc.unitary(splitter_op, 0, label="Beam Splitter")
        qc.measure(0, 1)
        qc.measure(0, 2)

    run_options = {"shots": shots}
    if seed is not None:
        run_options["seed_simulator"] = seed

    backend = AerSimulator()
    job = backend.run(qc, **run_options)
    counts = job.result().get_counts()

    # Count keys are little-endian (clbit 0 is the rightmost character).
    # Outcomes that never occur are absent from `counts` (e.g. r = 0 only
    # yields '001', r = 1 only yields '000'), so default them to 0 rather
    # than raising KeyError.
    det_a = counts.get('000', 0)
    det_c = counts.get('001', 0)
    det_d = counts.get('111', 0)

    p_ac = (det_a + det_c) / shots
    p_d = det_d / shots

    return np.array([p_ac, p_d])

In [11]:
import threading

def print_result(r: float) -> None:
    """Print the [P(A or C), P(D)] estimate from mz_interferometer for a given r."""
    print(f"r: {r} -> {mz_interferometer(r)}\n")


# Each entry is (r as text, reference result as text). The reference strings
# are kept for comparison by eye only; nothing below checks against them.
# NOTE(review): for r = 0.3124456 and r = 0.9 the references differ from the
# analytic values r**2 + (1 - r**2)**2 and r**2 * (1 - r**2) by far more than
# 2**20-shot sampling noise. They look like lower-shot estimates; confirm.
test_cases = [
    ('0.1', '[0.990234375, 0.009765625]'),
    ('0.3124456', '[0.915283203125, 0.084716796875]'),
    ('0.5', '[0.8125, 0.1875]'),
    ('0.577350269', '[0.777778, 0.222222]'),
    ('0.9', '[0.842529296875, 0.157470703125]')
]

# Run every case sequentially in the main thread. Exceptions propagate
# normally, and new entries in test_cases are picked up automatically.
for r_text, _reference in test_cases:
    print_result(float(r_text))


r: 0.1 -> [0.99012852 0.00987148]

r: 0.3124456 -> [0.91197968 0.08802032]

r: 0.5 -> [0.8128376 0.1871624]

r: 0.577350269 -> [0.77794933 0.22205067]

r: 0.9 -> [0.84581852 0.15418148]

