In [16]:
%pip install qiskit==1.2.4
%pip install qiskit-aer==0.15.1
%pip install pylatexenc==2.10

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [17]:
from qiskit import QuantumCircuit
from qiskit.converters import circuit_to_gate
from qiskit.visualization import array_to_latex
from qiskit.quantum_info import Operator
from qiskit.quantum_info import Statevector
from qiskit import transpile 
from qiskit.providers.basic_provider import BasicSimulator
from qiskit.visualization import plot_histogram
from qiskit.circuit import ControlledGate
import math 

# The aim of the assignment is to simulate the Ekert 1991 (E91) entanglement-based quantum key distribution protocol.

# This notebook simulates the protocol without an attacker (no eavesdropper).

In [18]:
def prepare_singlet(qc, a, b):
    """Append the gates that prepare the singlet (|01> - |10>)/sqrt(2) on (a, b).

    Kets in this docstring are written in (a, b) order and assume both qubits
    start in |0>. Gates are appended to ``qc`` in place; nothing is returned.

    Args:
        qc: circuit to append to (only ``x``, ``h``, ``cx`` and ``z`` are used).
        a: index of Alice's qubit.
        b: index of Bob's qubit.
    """
    # |0>_a |1>_b
    qc.x(b)
    # (|0> + |1>)_a |1>_b / sqrt(2)
    qc.h(a)
    # CNOT flips b when a = 1: (|01> + |10>) / sqrt(2)
    qc.cx(a, b)
    # Z on a puts a minus sign on the |1>_a term: (|01> - |10>) / sqrt(2)
    qc.z(a)

def measure_in_basis(qc, qubit, cbit, basis):
    """Rotate ``qubit`` so that ``basis`` maps onto Z, then measure it into ``cbit``.

    All supported bases lie in the X-Z plane of the Bloch sphere:
        'Z' -- computational basis, no pre-rotation
        'X' -- Hadamard before measuring
        'W' -- RY(-pi/4) before measuring (axis halfway between +Z and +X)
        'V' -- RY(+pi/4) before measuring (axis halfway between +Z and -X)

    Raises:
        ValueError: if ``basis`` is not one of the labels above; no gate is
            appended to ``qc`` in that case.
    """
    if basis not in ('Z', 'X', 'W', 'V'):
        raise ValueError("basis must be one of 'Z','X','W','V'")

    if basis == 'X':
        qc.h(qubit)
    elif basis in ('W', 'V'):
        angle = math.pi/4 if basis == 'V' else -math.pi/4
        qc.ry(angle, qubit)
    # 'Z' needs no pre-rotation.
    qc.measure(qubit, cbit)

def bit_to_pm1(bit):
    """Map a measurement bit to its eigenvalue: 0 -> +1, any other value -> -1."""
    if bit == 0:
        return 1
    return -1

def expectation_from_counts(counts):
    """Estimate the correlator <A*B> from a Qiskit counts dictionary.

    Each outcome bit is mapped to an eigenvalue (0 -> +1, 1 -> -1). A shot
    therefore contributes +1 when Alice's and Bob's bits agree and -1 when they
    differ.

    Alice's result is read from classical bit 0 and Bob's from classical bit 1,
    the same wiring ``run_one_round`` uses. Qiskit count keys are little-endian
    (classical bit 0 is the rightmost character), so bits are indexed from the
    right; this keeps the reading correct for registers wider than two bits.
    Spaces that Qiskit inserts between separate classical registers are ignored.

    Args:
        counts: mapping from outcome bitstring to number of shots.

    Returns:
        The shot-weighted average of the +/-1 products. Returns ``nan`` when
        ``counts`` holds no shots, the same convention ``e91_plain`` uses for an
        unsampled correlator; the original raised ZeroDivisionError here.
    """
    total = sum(counts.values())
    if total == 0:
        return float('nan')

    acc = 0
    for key, n in counts.items():
        bits = key.replace(' ', '')
        a_bit = int(bits[-1])  # classical bit 0 (Alice)
        b_bit = int(bits[-2])  # classical bit 1 (Bob)
        # (+/-1) * (+/-1) is +1 exactly when the two bits are equal
        acc += (1 if a_bit == b_bit else -1) * n
    return acc / total

# Ry(theta)|0> = cos(theta/2)|0> + sin(theta/2)|1>. Choosing cos^2(theta/2) = 1/3
# makes outcome 0 appear with probability 1/3.
THETA_1_OVER_3 = 2 * math.acos(math.sqrt(1/3)) # cos^2(theta/2)=1/3

def sample_trit_1_over_3():
    """Return 0, 1 or 2, each with probability 1/3, from a single circuit shot.

    Qubit 0 is a biased coin (Ry by THETA_1_OVER_3, so P(c0 = 0) = 1/3). Qubit 1
    is a fair coin (Hadamard). The outcomes map as follows:
        c0 == 0             -> 0   (prob 1/3)
        c0 == 1 and c1 == 0 -> 1   (prob 2/3 * 1/2 = 1/3)
        c0 == 1 and c1 == 1 -> 2   (prob 1/3)
    """
    sim = BasicSimulator()
    circuit = QuantumCircuit(2, 2)

    # q0: biased coin
    circuit.ry(THETA_1_OVER_3, 0)
    circuit.measure(0, 0)
    # q1: fair coin
    circuit.h(1)
    circuit.measure(1, 1)

    counts = sim.run(transpile(circuit, sim), shots=1).result().get_counts()
    outcome = next(iter(counts))  # single shot -> single key, written 'c1c0'
    c1, c0 = int(outcome[0]), int(outcome[1])

    if c0 == 0:
        return 0
    return 1 + c1

# Basis labels available to each party. A uniform trit selects the index.
ALICE_BASES = ['X', 'W', 'Z']
BOB_BASES   = ['W', 'Z', 'V']

def sample_alice_basis():
    """Return Alice's basis label, drawn uniformly from ALICE_BASES."""
    index = sample_trit_1_over_3()
    return ALICE_BASES[index]

def sample_bob_basis():
    """Return Bob's basis label, drawn uniformly from BOB_BASES."""
    index = sample_trit_1_over_3()
    return BOB_BASES[index]


In [19]:
def run_one_round():
    """Simulate one E91 round on a freshly prepared singlet pair.

    Each party draws a random basis. Alice measures qubit 0 into clbit 0, Bob
    measures qubit 1 into clbit 1, and the circuit runs for a single shot.

    Returns:
        (a_basis, b_basis, a_bit, b_bit): the chosen basis labels and the raw
        0/1 outcomes for Alice and Bob.
    """
    a_basis = sample_alice_basis()
    b_basis = sample_bob_basis()

    sim = BasicSimulator()
    circuit = QuantumCircuit(2, 2)
    prepare_singlet(circuit, 0, 1)
    measure_in_basis(circuit, 0, 0, a_basis)
    measure_in_basis(circuit, 1, 1, b_basis)

    counts = sim.run(transpile(circuit, sim), shots=1).result().get_counts()
    outcome = next(iter(counts))  # one shot -> one key, written 'c1c0'
    b_bit, a_bit = int(outcome[0]), int(outcome[1])
    return a_basis, b_basis, a_bit, b_bit

def e91_plain(N=80, min_chsh_each=200, max_rounds=20000, verbose=True):
    """Run the E91 protocol without an eavesdropper and summarise the results.

    Rounds are simulated until both targets are met: Alice and Bob hold an
    ``N``-bit sifted key, and each CHSH correlator has at least
    ``min_chsh_each`` samples. The run also stops once ``max_rounds`` rounds
    have executed.

    Basis pairs (Alice, Bob) are used as follows:
      * W-W and Z-Z: key rounds. Outcomes of the singlet in equal bases are
        anti-correlated, so Bob flips his bit before appending it.
      * X-W, X-V, Z-W, Z-V: CHSH rounds, accumulated as +/-1 products.
      * every other pair: discarded.

    Args:
        N: target key length. Key rounds beyond N are not appended.
        min_chsh_each: minimum number of samples per CHSH correlator.
        max_rounds: hard cap on the number of rounds executed.
        verbose: print a human-readable summary when True.

    Returns:
        dict with these keys:
          * "rounds": rounds actually executed.
          * "key_len": length of the sifted key.
          * "key_match": True if Alice's key equals Bob's flipped key.
          * "counts": samples per CHSH pair.
          * "E": correlator estimates (``nan`` if a pair was never sampled).
          * "S": |E_XW - E_XV + E_ZW + E_ZV|.
          * "completed": False if ``max_rounds`` was reached before both
            targets were met.
    """
    key_pairs = {('W', 'W'), ('Z', 'Z')}
    chsh_tags = {('X', 'W'): 'XW', ('X', 'V'): 'XV',
                 ('Z', 'W'): 'ZW', ('Z', 'V'): 'ZV'}

    key_alice, key_bob = [], []
    sums = {tag: 0 for tag in chsh_tags.values()}
    cnts = {tag: 0 for tag in chsh_tags.values()}
    discards = 0
    rounds = 0

    def targets_met():
        return (len(key_alice) >= N
                and all(c >= min_chsh_each for c in cnts.values()))

    while not targets_met():
        # Check the cap *before* counting the round, so `rounds` is the number
        # of rounds actually executed. The original reported max_rounds + 1.
        if rounds >= max_rounds:
            break
        rounds += 1

        a_basis, b_basis, a_bit, b_bit = run_one_round()
        pair = (a_basis, b_basis)

        if pair in key_pairs:
            if len(key_alice) < N:
                key_alice.append(a_bit)
                key_bob.append(1 - b_bit)  # undo the singlet's anti-correlation
            continue

        tag = chsh_tags.get(pair)
        if tag is None:
            discards += 1
            continue
        sums[tag] += bit_to_pm1(a_bit) * bit_to_pm1(b_bit)
        cnts[tag] += 1

    completed = targets_met()

    def E(tag):
        return sums[tag] / cnts[tag] if cnts[tag] else float('nan')

    EXW, EXV, EZW, EZV = E('XW'), E('XV'), E('ZW'), E('ZV')
    S = abs(EXW - EXV + EZW + EZV)
    E_values = {"XW": EXW, "XV": EXV, "ZW": EZW, "ZV": EZV}

    ok_key = (key_alice == key_bob)

    if verbose:
        print(f"Rounds used: {rounds}")
        if not completed:
            print(f"Warning: stopped at max_rounds={max_rounds} before reaching the key/CHSH targets")
        print(f"Key length: {len(key_alice)} (target {N})")
        print(f"Key matches after Bob flip? {ok_key}")
        print("CHSH sample counts:", cnts, " discards:", discards)
        print("E values:", E_values)
        print(f"S = {S:.4f}   (expected ≈ {2*math.sqrt(2):.4f})")

    return {"rounds": rounds, "key_len": len(key_alice), "key_match": ok_key,
            "counts": cnts, "E": E_values, "S": S, "completed": completed}

In [20]:
# Bind the summary to a descriptive name. Assigning to `_` would stop IPython
# from updating its last-output variables. Then check the two invariants
# expected without an attacker.
plain_result = e91_plain(N=80, min_chsh_each=200, verbose=True)
assert plain_result["key_match"], "Alice's and Bob's sifted keys should agree without an attacker"
assert plain_result["S"] > 2, "CHSH value should exceed the classical bound of 2"

Rounds used: 1852
Key length: 80 (target 80)
Key matches after Bob flip? True
CHSH sample counts: {'XW': 223, 'XV': 242, 'ZW': 200, 'ZV': 201}  discards: 586
E values: {'XW': -0.7219730941704036, 'XV': 0.6859504132231405, 'ZW': -0.67, 'ZV': -0.7213930348258707}
S = 2.7993   (expected ≈ 2.8284)


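With no attacker present, the measured CHSH value (S ≈ 2.80 in this run) exceeds the classical bound of 2 and is close to the quantum maximum 2√2 ≈ 2.83. This confirms that the shared pairs are entangled. Alice's and Bob's sifted keys also agree bit for bit.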