In [27]:
import random
from collections import defaultdict
import pandas as pd
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer import Aer
import numpy as np

# -------------------------------
# Parameters
# -------------------------------
singlets = 1024   # number of entangled pairs (protocol rounds)
shots = 1         # each circuit is sampled exactly once
backend = Aer.get_backend('qasm_simulator')

# -------------------------------
# Alice/Bob measurement choices
# -------------------------------
# Basis codes are party-specific:
#   Alice: 1 = X, 2 = W, 3 = Z
#   Bob:   1 = W, 2 = Z, 3 = V
# where W = (X + Z)/√2 and V = (Z - X)/√2.


def _draw_bases(n_rounds):
    """Return ``n_rounds`` independent uniform basis codes from {1, 2, 3}."""
    return [random.randint(1, 3) for _round in range(n_rounds)]


b = _draw_bases(singlets)        # Alice (drawn first)
b_prime = _draw_bases(singlets)  # Bob

# -------------------------------
# Functions to prepare the singlet and rotate measurement bases
# -------------------------------
def singlet():
    """Build a 2-qubit circuit prepared in the singlet state.

    The state is (|01> - |10>)/√2, with kets written as |q0 q1>. Qubit 0 is
    Alice's and qubit 1 is Bob's. The circuit has a 2-bit classical register
    named 'cr' for later measurements. No measurement is added here.
    """
    circuit = QuantumCircuit(QuantumRegister(2, 'qr'), ClassicalRegister(2, 'cr'))
    circuit.x(1)      # |00> -> |01>
    circuit.h(0)      # -> (|01> + |11>)/√2
    circuit.cx(0, 1)  # entangle -> (|01> + |10>)/√2
    circuit.z(0)      # relative phase -> (|01> - |10>)/√2
    return circuit

def measure_a(qc, basis, target_qubit=0):
    """Rotate Alice's qubit so a Z-basis readout measures the chosen observable.

    Alice's basis codes:
        1 = X
        2 = W = (X + Z)/√2
        3 = Z

    Gates are appended to ``qc`` in place. No measurement is added.

    Args:
        qc: circuit to modify.
        basis: 1, 2 or 3, as above.
        target_qubit: qubit to rotate (Alice's qubit, 0 by default).

    Returns:
        The same ``qc`` object, for chaining.

    Raises:
        ValueError: if ``basis`` is not 1, 2 or 3.
    """
    if basis == 1:  # X: H maps the X eigenbasis onto Z
        qc.h(target_qubit)
    elif basis == 2:  # W = (X+Z)/√2: S·H·T·H maps the W eigenbasis onto Z
        qc.s(target_qubit)
        qc.h(target_qubit)
        qc.t(target_qubit)
        qc.h(target_qubit)
    elif basis != 3:  # 3 = Z needs no rotation; anything else is invalid
        raise ValueError(f"Alice's basis must be 1 (X), 2 (W) or 3 (Z), got {basis!r}")
    return qc

def measure_b(qc, basis, target_qubit=1):
    """Rotate Bob's qubit so a Z-basis readout measures the chosen observable.

    Bob's basis codes:
        1 = W = (X + Z)/√2
        2 = Z
        3 = V = (Z - X)/√2

    Gates are appended to ``qc`` in place. No measurement is added.

    Args:
        qc: circuit to modify.
        basis: 1, 2 or 3, as above.
        target_qubit: qubit to rotate (Bob's qubit, 1 by default).

    Returns:
        The same ``qc`` object, for chaining.

    Raises:
        ValueError: if ``basis`` is not 1, 2 or 3.
    """
    if basis == 1:  # W: S·H·T·H maps the W eigenbasis onto Z
        qc.s(target_qubit)
        qc.h(target_qubit)
        qc.t(target_qubit)
        qc.h(target_qubit)
    elif basis == 2:  # Z: no rotation needed
        pass
    elif basis == 3:  # V: S·H·T†·H maps the V eigenbasis onto Z
        qc.s(target_qubit)
        qc.h(target_qubit)
        qc.tdg(target_qubit)
        qc.h(target_qubit)
    else:
        raise ValueError(f"Bob's basis must be 1 (W), 2 (Z) or 3 (V), got {basis!r}")
    return qc

# -------------------------------
# Execute measurements
# -------------------------------
# Build one circuit per singlet:
#   1. prepare the pair;
#   2. rotate Alice's qubit (0) and Bob's qubit (1) into their chosen bases;
#   3. measure qubit 0 -> clbit 0 and qubit 1 -> clbit 1.
# All circuits are sent to the simulator as one batched job rather than one
# job per singlet.
circuits = []
for alice_choice, bob_choice in zip(b, b_prime):
    qc = singlet()
    qc = measure_a(qc, alice_choice, 0)
    qc = measure_b(qc, bob_choice, 1)
    qc.measure([0, 1], [0, 1])
    circuits.append(qc)

job_result = backend.run(circuits, shots=shots).result()

results = []
for i, (alice_choice, bob_choice) in enumerate(zip(b, b_prime)):
    counts = job_result.get_counts(i)
    # With shots == 1 there is exactly one outcome string per circuit.
    bitstring = next(iter(counts))
    # Qiskit count strings are little-endian. The LAST character is clbit 0
    # (Alice) and the second-to-last is clbit 1 (Bob). Reading bitstring[0]
    # as Alice would swap the two parties.
    alice_bit = int(bitstring[-1])
    bob_bit = int(bitstring[-2])

    # Convert to ±1 eigenvalues (bit 0 -> +1, bit 1 -> -1)
    a = 1 - 2 * alice_bit
    a_prime = 1 - 2 * bob_bit

    results.append({
        "alice_bit": alice_bit,
        "bob_bit": bob_bit,
        "alice_basis": alice_choice,
        "bob_basis": bob_choice,
        "a": a,
        "a_prime": a_prime,
    })

# -------------------------------
# Key generation (compatible bases)
# -------------------------------
# Only rounds where both parties measured the same observable contribute a
# key bit:
#   - Alice basis 2 (W) with Bob basis 1 (W);
#   - Alice basis 3 (Z) with Bob basis 2 (Z).
# Singlet outcomes are anti-correlated, so Bob flips his bit to match Alice's.
compatible_pairs = {(2, 1), (3, 2)}

alice_key = []
bob_key = []
mismatches = 0

for r in results:
    if (r["alice_basis"], r["bob_basis"]) not in compatible_pairs:
        continue
    bob_corrected = 1 - r["bob_bit"]
    alice_key.append(r["alice_bit"])
    bob_key.append(bob_corrected)
    mismatches += int(r["alice_bit"] != bob_corrected)

print(f"Number of mismatched key bits: {mismatches}")

# -------------------------------
# Group results for CHSH
# -------------------------------
# Map (alice_basis, bob_basis) -> list of (a, a_prime) outcome pairs in ±1 form.
grouped = defaultdict(list)
for record in results:
    basis_pair = (record["alice_basis"], record["bob_basis"])
    grouped[basis_pair].append((record["a"], record["a_prime"]))

def countGroup(group):
    """Tally how often each ±1 outcome pair occurs in ``group``.

    Args:
        group: iterable of 2-element ``(a, a')`` pairs, each value +1 or -1.

    Returns:
        A dict with keys ordered (1, 1), (1, -1), (-1, 1), (-1, -1), each
        mapped to its count. Pairs outside that set raise ``KeyError``.
    """
    tally = dict.fromkeys(((1, 1), (1, -1), (-1, 1), (-1, -1)), 0)
    for alice_value, bob_value in group:
        tally[(alice_value, bob_value)] += 1
    return tally

def calculateExpectation(counts):
    """Return the empirical correlation E = sum over outcomes of a·a'·p(a, a').

    Args:
        counts: mapping ``(a, a') -> n`` with a, a' in {+1, -1}, such as the
            dict produced by ``countGroup``.

    Returns:
        The mean of the product a·a' over all counted rounds. If no rounds
        were counted (total 0), the correlation is undefined and
        ``float('nan')`` is returned instead of raising ZeroDivisionError.
    """
    total = sum(counts.values())
    if total == 0:
        # Empty basis-pair group: no data, so report NaN rather than crash.
        return float('nan')
    return sum(a * a_p * count for (a, a_p), count in counts.items()) / total

# -------------------------------
# Build CHSH table
# -------------------------------
# One row per (basis pair, outcome pair). Each group's expectation value is
# repeated on all four of its rows.
measurement_types = {"X ⊗ W": (1, 1), "X ⊗ V": (1, 3), "Z ⊗ W": (3, 1), "Z ⊗ V": (3, 3)}
table_rows = []

for meas_name, meas_pair in measurement_types.items():
    counts = countGroup(grouped[meas_pair])
    total = sum(counts.values())
    rounded_expectation = round(calculateExpectation(counts), 4)

    for outcome, n_ij in counts.items():
        p_ij = 0 if total <= 0 else n_ij / total
        sign = outcome[0] * outcome[1]
        table_rows.append(
            {
                "Measurement": meas_name,
                "(b,b')": meas_pair,
                "(a,a')": outcome,
                "n_ij": n_ij,
                "p_ij": round(p_ij, 4),
                "p·(a·a')": round(p_ij * sign, 4),
                "Expectation": rounded_expectation,
            }
        )

df_chsh = pd.DataFrame(table_rows)
print(df_chsh)

# -------------------------------
# Compute CHSH S value
# -------------------------------
# S = E(X,W) - E(X,V) + E(Z,W) + E(Z,V).
# Local hidden-variable models obey |S| <= 2. For the singlet with these
# settings, quantum mechanics predicts S = -2√2.
chsh_pairs = ((1, 1), (1, 3), (3, 1), (3, 3))  # (Alice basis, Bob basis)
E_XW, E_XV, E_ZW, E_ZV = [
    calculateExpectation(countGroup(grouped[pair])) for pair in chsh_pairs
]
S = ((E_XW - E_XV) + E_ZW) + E_ZV

print(f"CHSH correlation value: S = {S}")
print(f"Total key bits: {len(alice_key)}")


Number of mismatched key bits: 0
   Measurement  (b,b')    (a,a')  n_ij    p_ij  p·(a·a')  Expectation
0        X ⊗ W  (1, 1)    (1, 1)    10  0.0862    0.0862      -0.6897
1        X ⊗ W  (1, 1)   (1, -1)    40  0.3448   -0.3448      -0.6897
2        X ⊗ W  (1, 1)   (-1, 1)    58  0.5000   -0.5000      -0.6897
3        X ⊗ W  (1, 1)  (-1, -1)     8  0.0690    0.0690      -0.6897
4        X ⊗ V  (1, 3)    (1, 1)    42  0.3717    0.3717       0.6283
5        X ⊗ V  (1, 3)   (1, -1)    14  0.1239   -0.1239       0.6283
6        X ⊗ V  (1, 3)   (-1, 1)     7  0.0619   -0.0619       0.6283
7        X ⊗ V  (1, 3)  (-1, -1)    50  0.4425    0.4425       0.6283
8        Z ⊗ W  (3, 1)    (1, 1)    10  0.1000    0.1000      -0.7400
9        Z ⊗ W  (3, 1)   (1, -1)    46  0.4600   -0.4600      -0.7400
10       Z ⊗ W  (3, 1)   (-1, 1)    41  0.4100   -0.4100      -0.7400
11       Z ⊗ W  (3, 1)  (-1, -1)     3  0.0300    0.0300      -0.7400
12       Z ⊗ V  (3, 3)    (1, 1)     5  0.0431    0.0431 