### Required Imports

In [None]:
%pip install qiskit==1.2.4
%pip install qiskit-aer==0.15.1
%pip install pylatexenc==2.10

from qiskit import QuantumCircuit
from qiskit.converters import circuit_to_gate
from qiskit.visualization import array_to_latex
from qiskit.quantum_info import Operator
from qiskit.quantum_info import Statevector
from qiskit import transpile 
from qiskit.providers.basic_provider import BasicSimulator
from qiskit.visualization import plot_histogram
from qiskit.circuit import ControlledGate
import math 

# The aim of the assignment is to simulate the Ekert91 key distribution protocol.
# This notebook is for a simulation of the protocol with an attacker, to demonstrate that the attacker can be detected.

### Circuit Construction
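Defines the change-of-basis unitaries for the W and V measurements. It also builds the circuit for each measurement: an entangled pair, optionally intercepted by an attacker, followed by the basis-change gates for Alice's and Bob's chosen operators.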

In [11]:
# Normalisation constants: the vectors (-1, 1+sqrt2) and (1, sqrt2-1) have squared
# lengths 4+2*sqrt2 and 4-2*sqrt2 respectively, so dividing by these gives unit rows.
root2 = math.sqrt(2)
denom1 = math.sqrt(4 + 2 * root2)
denom2 = math.sqrt(4 - 2 * root2)

# Change-of-basis unitary applied before a computational-basis measurement when
# the chosen operator is "w" (see construct_circuit).
W_transform_matrix = [
    [-1 / denom1, (1 + root2) / denom1],
    [1 / denom2, (root2 - 1) / denom2],
]

# Same rows as W but with the first column negated; used for operator "v".
# Negation is exact in floating point, so these values equal 1/denom1 and -1/denom2.
V_transform_matrix = [[-first, second] for first, second in W_transform_matrix]

#---------------------------------------------------------------------------------
def get_entangled_pair(disrupted = False):
    """Build a circuit that prepares the Bell pair (|00> + |11>)/sqrt(2) on qubits 0 and 1.

    Args:
        disrupted: when True, a third qubit (qubit 2) is added and a CNOT from
            qubit 0 onto it models an eavesdropper copying qubit 0's
            computational-basis value. The state becomes (|000> + |111>)/sqrt(2).

    Returns:
        A new QuantumCircuit with 2 qubits, or 3 qubits when disrupted.
    """
    pair_circuit = QuantumCircuit(3 if disrupted else 2)
    pair_circuit.h(0)
    pair_circuit.cx(0, 1)
    if disrupted:
        # boo!! -- the eavesdropper entangles their own qubit with qubit 0
        pair_circuit.cx(0, 2)
    return pair_circuit
#---------------------------------------------------------------------------------
def construct_circuit(entangled_pair, operators):
    """Append basis-change gates so a later Z-basis measurement reads each chosen operator.

    Args:
        entangled_pair: QuantumCircuit to extend. It is modified IN PLACE.
        operators: sequence whose k-th entry selects the basis for qubit k.
            "z" adds nothing, "x" adds a Hadamard, and "w"/"v" add the
            W_transform_matrix / V_transform_matrix unitaries. Any other
            value is silently ignored.

    Returns:
        The same circuit object that was passed in.
    """
    for qubit_index, op_name in enumerate(operators):
        if op_name == "x":
            entangled_pair.h(qubit_index)
        elif op_name in ("w", "v"):
            basis_change = W_transform_matrix if op_name == "w" else V_transform_matrix
            entangled_pair.unitary(basis_change, [qubit_index])
        # "z" (and anything unrecognised) needs no gate

    return entangled_pair

### Circuit Measurement
`simulate_measurement(circuit, shots=1)` simulates measuring a circuit and returns the measurement counts. The optional `shots` parameter sets how many times the circuit is sampled (default: one).

In [3]:
def simulate_measurement(circuit, shots = 1):
    """Sample `circuit` on Qiskit's BasicSimulator and return the measurement counts.

    Args:
        circuit: a measured QuantumCircuit.
        shots: number of times to sample the circuit (default 1).

    Returns:
        The counts mapping (measured bitstring -> number of occurrences) for the
        transpiled circuit.
    """
    simulator = BasicSimulator()
    transpiled_circuit = transpile(circuit, simulator)
    return simulator.run(transpiled_circuit, shots=shots).result().get_counts(transpiled_circuit)

### Random Entry From Array
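Chooses an entry uniformly at random using simulated quantum measurements. At index `i`, `T_transform_matrix` prepares a qubit that measures 0 with probability 1/(len(arr) − i). On 0 that entry is returned; otherwise the choice moves to the next index. For example, when len(arr) = 3 and i = 0, the first entry is chosen with probability 1/3.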

In [4]:
def get_random_entry(arr, i = 0):
    """Return one entry of arr[i:], chosen uniformly at random via simulated qubit measurements.

    Walk forward from index `i`. At each position with n entries remaining, a
    single qubit is rotated so that measuring "0" has probability 1/n. On "0"
    the current entry is returned; otherwise the walk moves to the next index.
    The final remaining entry is returned without a measurement. Every entry
    therefore has overall probability 1/(len(arr) - i).

    Args:
        arr: indexable sequence to choose from.
        i: index at which the candidate range starts (default 0).

    Raises:
        IndexError: when no entries remain (len(arr) == i).
    """
    position = i
    while True:
        remaining = len(arr) - position
        if remaining == 0:
            raise IndexError("tried to get entry from empty array")
        if remaining == 1:
            return arr[position]

        # Amplitudes giving P("0") = 1/remaining; the matrix is real and unitary.
        amp_select = 1 / math.sqrt(remaining)
        amp_skip = math.sqrt((remaining - 1) / remaining)
        selector = QuantumCircuit(1)
        selector.unitary([[amp_select, amp_skip],
                          [amp_skip, -amp_select]], [0])
        selector.measure_all()

        if simulate_measurement(selector).get("0", 0) == 1:
            return arr[position]
        position += 1

### Entanglement Testing

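`average` reduces the measurements for each operator pair to their mean correlation over all repetitions, with each outcome mapped to ±1. `entanglement_test(S)` combines these averages into |S|. It raises an `AssertionError` if |S| falls more than `allowance` below $2\sqrt{2}$; with `no_allowance=True` it only prints |S|.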

In [21]:
# Maps Alice/Bob outcome bit pairs to the product of their +-1 eigenvalues:
# equal bits -> +1, different bits -> -1.
measurement_map = {"00":1, "01": -1, "10":-1, "11":1}

def average(circuits, bits = 2):
    """Return the mean +-1 correlation over `circuits`, using only the lowest `bits` qubits.

    Each circuit is simulated once (one shot). Qiskit count keys are
    little-endian: the RIGHTMOST character is qubit 0. Alice's qubit is 0 and
    Bob's is 1 (see construct_circuit), so they are the LAST `bits` characters
    of the key. Slicing from the left (`key[:bits]`) wrongly picked up qubit 2,
    the eavesdropper's, in the disrupted 3-qubit circuit.

    Args:
        circuits: non-empty list of measured circuits.
        bits: number of low-order qubits to read (default 2).

    Returns:
        Mean of measurement_map values. Keys missing from the map count as 0.

    Raises:
        ZeroDivisionError: if `circuits` is empty.
    """
    total = 0
    for circuit in circuits:
        outcome = next(iter(simulate_measurement(circuit)))
        # max(..., 0) keeps the slice sane if bits exceeds the key length
        low_order = outcome[max(len(outcome) - bits, 0):]
        total += measurement_map.get(low_order, 0)
    return total / len(circuits)
    
def entanglement_test(S, allowance = (2*math.sqrt(2)-2), no_allowance = False):
    """Combine per-pair correlation averages into |S| and check it against 2*sqrt(2).

    S = E(x,w) - E(x,v) + E(z,w) + E(z,v): the ('x', 'v') average is negated and
    the rest are added.

    Args:
        S: dict mapping (alice_op, bob_op) -> list of measured circuits.
        allowance: how far |S| may fall below the ideal 2*sqrt(2) and still pass.
            The default puts the threshold at (approximately) 2.
        no_allowance: if True, only print |S| and return True without checking.

    Returns:
        True when the check passes or is skipped.

    Raises:
        AssertionError: if |S| < 2*sqrt(2) - allowance ("entanglement has been disrupted.").
    """
    # Average each component of S, negating the ('x', 'v') term
    averages = [average(circuits) * (-1 if (a, b) == ('x', 'v') else 1)
                for (a, b), circuits in S.items()]

    abs_s = abs(sum(averages))
    print(f"|S|: {abs_s}")

    if no_allowance: return True

    # Raise explicitly instead of using `assert`, so the check survives `python -O`.
    # `not (>=)` keeps the original's behaviour of failing on NaN.
    if not abs_s >= (2*math.sqrt(2)) - allowance:
        raise AssertionError("entanglement has been disrupted.")
    return True

### Ekert 91 Protocol

In [22]:
# Measurement bases each party chooses from; a1..a3 / b1..b3 name the individual entries.
alice_operators = (a1, a2, a3) = ('x', 'w', 'z')
bob_operators = (b1, b2, b3) = ('w', 'z', 'v')

def get_alice_bit(circuit):
    """Simulate `circuit` once and return Alice's measured bit as an int (0 or 1).

    Alice's qubit is qubit 0 (see construct_circuit). Qiskit count keys are
    little-endian, so qubit 0 is the RIGHTMOST character. The previous `[0]`
    read the highest-index qubit instead: Bob's in the 2-qubit circuit, and the
    eavesdropper's in the disrupted 3-qubit circuit.
    """
    outcome = next(iter(simulate_measurement(circuit)))
    return int(outcome[-1])

def ekert91_protocol(key_length, disrupted = True):
    """Run one simulated Ekert91 session, printing |S| and the sifted shared key.

    Args:
        key_length: target key length. int(9 * key_length / 2) rounds are run
            because 2 of the 9 equally likely basis pairs, (w, w) and (z, z),
            match and contribute a key bit.
        disrupted: passed to get_entangled_pair. The default (True) keeps this
            notebook's attacker scenario; pass False for an undisturbed run.

    Returns:
        None. Results are printed.
    """
    # Rounds: prepare a pair, choose random bases, build and measure the circuit
    rounds = []
    for _ in range(int(9 * key_length / 2)):
        #1 entangled pair (intercepted by the attacker when disrupted)
        ab_pair = get_entangled_pair(disrupted=disrupted)

        #2-3 random basis choices
        a_op = get_random_entry(alice_operators)
        b_op = get_random_entry(bob_operators)

        #4-5 basis change and measurement
        circuit = construct_circuit(ab_pair, (a_op, b_op))
        circuit.measure_all()

        rounds.append((a_op, b_op, circuit))

    # Sifting: matching bases form the key, and the four CHSH pairs form S
    chsh_pairs = ((a1, b1), (a1, b3), (a3, b1), (a3, b3))
    S = {}
    key_bits = []
    for a_op, b_op, circuit in rounds:
        pair = (a_op, b_op)
        if a_op == b_op:
            key_bits.append(str(get_alice_bit(circuit)))
        elif pair in chsh_pairs:
            S.setdefault(pair, []).append(circuit)
    shared_key = "".join(key_bits)

    entanglement_test(S, no_allowance=True)
    print(f"shared key of length {len(shared_key)}: {shared_key}\n")
    
# Nine independent sessions, each targeting a 32-bit key (#1 .. #9).
for session_number in range(1, 10):
    ekert91_protocol(32)

|S|: 1.640749601275917
shared key of length 28: 1100110010011011000101011010

|S|: 1.0546218487394956
shared key of length 30: 001100000101011100000110000100

|S|: 1.1999999999999997
shared key of length 40: 1001110000010011111110100001010100000111

|S|: 1.1073232323232323
shared key of length 26: 10011101100011100001010011

|S|: 2.1142857142857143
shared key of length 34: 1000010110000001011101000011011001

|S|: 1.4981299402352033
shared key of length 20: 00100010011101001010

|S|: 1.7714285714285714
shared key of length 30: 001011101000000000111000000111

|S|: 1.383537229361028
shared key of length 24: 101100001101101000001110

|S|: 1.5027777777777778
shared key of length 30: 111101101101101100010001000101

