# References

https://qiskit.org/textbook/ch-algorithms/quantum-key-distribution.html
    
Protecting Information: From Classical Error Correction to Quantum Cryptography, by Susan Loepp and William K. Wootters

# Without Interception

In [45]:
from qiskit import QuantumCircuit, Aer, transpile, assemble, execute
from qiskit.visualization import plot_histogram, plot_bloch_multivector
from numpy.random import randint
from qiskit import IBMQ
import random
import numpy as np

# To run on IBM hardware, read the API token from the environment rather than
# hard-coding it in the notebook (a token that was ever committed to source
# should be revoked and regenerated):
#IBMQ.save_account(os.environ["IBMQ_TOKEN"], overwrite=True)
#provider=IBMQ.load_account()
#backend=provider.get_backend('ibmq_belem')
simulator = Aer.get_backend('qasm_simulator')

In [46]:
def without_eve(bits, al_bases, b_bases):
    """Simulate BB84 transmission from Alice to Bob with no eavesdropper.

    Each of Alice's bits is encoded on its own single-qubit circuit: a 1 bit
    applies X, and basis 1 (the X basis) additionally applies H. Bob measures
    each qubit in his own basis (basis 1: apply H before measuring) with a
    single simulator shot.

    Args:
        bits: Alice's raw bits (0/1), one per qubit.
        al_bases: Alice's encoding basis per bit (0 = Z basis, 1 = X basis).
        b_bases: Bob's measurement basis per bit (0 = Z basis, 1 = X basis).

    Returns:
        list[int]: Bob's measured bit for every position.

    Raises:
        ValueError: if the three sequences differ in length.
    """
    if not (len(bits) == len(al_bases) == len(b_bases)):
        raise ValueError("bits, al_bases and b_bases must have the same length")

    measurements = []
    for bit, alice_basis, bob_basis in zip(bits, al_bases, b_bases):
        qc = QuantumCircuit(1, 1)
        # Alice's preparation.
        if bit == 1:
            qc.x(0)
        if alice_basis == 1:
            qc.h(0)
        qc.barrier()

        # Bob's measurement: rotate X basis onto Z before measuring.
        if bob_basis == 1:
            qc.h(0)
        qc.measure(0, 0)
        # One shot with memory=True so the single outcome can be read back.
        # Pass an IBMQ `backend` instead of `simulator` to run on hardware.
        result = execute(qc, backend=simulator, shots=1, memory=True).result()
        measurements.append(int(result.get_memory()[0]))
    return measurements

In [47]:
n = 20

# One random 0/1 draw per qubit: Alice's bits, then Alice's bases, then
# Bob's bases (drawn in that order).
alice_bits = [random.randint(0, 1) for _ in range(n)]
alice_bases = [random.randint(0, 1) for _ in range(n)]
bob_bases = [random.randint(0, 1) for _ in range(n)]

# Transmit without an eavesdropper; Bob's raw measurements.
key_initial = without_eve(alice_bits, alice_bases, bob_bases)

print(key_initial)

[0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0]


In [48]:
print("Bob measures =", key_initial)
print ("Alice sends = ",alice_bits)

print("Bob basis =   ", bob_bases)
print ("Alice basis  =",alice_bases)

# Sifting: keep only positions where Alice and Bob used the same basis.
# secure_keya is Alice's copy (the bits she sent) and secure_keyb is Bob's
# copy (his measurements), so they match the "alice"/"bob" labels printed below.
secure_keya = []
secure_keyb = []

for alice_bit, bob_bit, alice_basis, bob_basis in zip(
        alice_bits, key_initial, alice_bases, bob_bases):
    if alice_basis == bob_basis:
        secure_keya.append(alice_bit)
        secure_keyb.append(bob_bit)

print("Secure key with alice  =",secure_keya)
print("Secure key with bob    =",secure_keyb)

Bob measures = [0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0]
Alice sends =  [0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
Bob basis =    [1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0]
Alice basis  = [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0]
Secure key with alice  = [0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0]
Secure key with bob    = [0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0]


# With Interception

In [49]:
def _prepare_and_measure(bit, prep_basis, meas_basis):
    """Encode one bit on a fresh qubit, measure it once, and return the 0/1 outcome.

    Basis convention (as elsewhere in this notebook): 0 = Z basis, 1 = X basis,
    i.e. H is applied after encoding and again before measuring.
    """
    qc = QuantumCircuit(1, 1)
    if bit == 1:
        qc.x(0)
    if prep_basis == 1:
        qc.h(0)
    qc.barrier()
    if meas_basis == 1:
        qc.h(0)
    qc.measure(0, 0)
    # One shot with memory=True so the single outcome can be read back.
    result = execute(qc, backend=simulator, shots=1, memory=True).result()
    return int(result.get_memory()[0])


def with_eve(bits, al_bases, eve_bases, b_bases):
    """Simulate BB84 transmission with an intercept-and-resend eavesdropper.

    For every position, Eve measures Alice's qubit in her own basis. She then
    sends Bob a fresh qubit that encodes the bit she actually measured, in her
    basis, and Bob measures it in his basis. Building Bob's qubit from Eve's
    recorded outcome keeps the two results consistent. It also avoids
    mid-circuit measurement, so the same code can target real hardware.

    Args:
        bits: Alice's raw bits (0/1).
        al_bases: Alice's encoding basis per bit (0 = Z, 1 = X).
        eve_bases: Eve's measurement/re-encoding basis per bit (0 = Z, 1 = X).
        b_bases: Bob's measurement basis per bit (0 = Z, 1 = X).

    Returns:
        list[int]: Bob's measured bit for every position.

    Raises:
        ValueError: if the four sequences differ in length.
    """
    if not (len(bits) == len(al_bases) == len(eve_bases) == len(b_bases)):
        raise ValueError(
            "bits, al_bases, eve_bases and b_bases must have the same length")

    bob_measurements = []
    for bit, alice_basis, eve_basis, bob_basis in zip(
            bits, al_bases, eve_bases, b_bases):
        # Eve intercepts Alice's qubit and measures it in her basis.
        eve_bit = _prepare_and_measure(bit, alice_basis, eve_basis)
        # Eve resends her result encoded in her basis; Bob measures in his.
        bob_measurements.append(
            _prepare_and_measure(eve_bit, eve_basis, bob_basis))
    return bob_measurements

In [50]:
n = 50

# Random 0/1 draws, one per qubit, in this order: Alice's bits, Alice's
# bases, Eve's bases, Bob's bases.
alice_bits = [random.randint(0, 1) for _ in range(n)]
alice_bases = [random.randint(0, 1) for _ in range(n)]
eve_bases = [random.randint(0, 1) for _ in range(n)]
bob_bases = [random.randint(0, 1) for _ in range(n)]

# Transmit with Eve intercepting; Bob's raw measurements.
message_bob = with_eve(alice_bits, alice_bases, eve_bases, bob_bases)


In [51]:
#print("Bob measures =", message_bob)
#print ("Alice sends = ",alice_bits)

#print("Bob basis =   ", bob_bases)
#print ("Alice basis  =",alice_bases)

# Sifting after interception: keep positions where Alice's and Bob's bases
# agree. secure_keya is Alice's copy (the bits she sent) and secure_keyb is
# Bob's copy (his measurements), matching the "alice"/"bob" labels. Because
# error_correction rewrites the second key to match the first, this order
# makes Bob's copy the one corrected toward Alice's.
secure_keya = []
secure_keyb = []

for alice_bit, bob_bit, alice_basis, bob_basis in zip(
        alice_bits, message_bob, alice_bases, bob_bases):
    if alice_basis == bob_basis:
        secure_keya.append(alice_bit)
        secure_keyb.append(bob_bit)

#print("Secure key with alice  =",secure_keya)
#print("Secure key with bob    =",secure_keyb)

# Error Correction

In [57]:
def error_cal(secure_keya, secure_keyb):
    """Estimate the error rate between two keys from a random sample.

    About a fifth of the positions (``len // 5``, but at least one) are chosen
    at random. Each sampled position is compared and then removed from BOTH
    lists in place, so afterwards the caller holds only the unsampled bits.

    Args:
        secure_keya: first key, a list of bits (mutated).
        secure_keyb: second key, a list of bits of equal length (mutated).

    Returns:
        float: fraction of sampled positions where the keys disagree.

    Raises:
        ValueError: if the keys are empty or differ in length.
    """
    if len(secure_keya) != len(secure_keyb):
        raise ValueError("keys must have the same length")
    if not secure_keya:
        raise ValueError("cannot estimate an error rate from an empty key")

    # At least one sample, so keys shorter than 5 bits don't divide by zero.
    sample_size = max(1, len(secure_keya) // 5)
    errors = 0
    for _ in range(sample_size):
        # numpy's randint excludes the upper bound, so index is always valid.
        index = randint(0, len(secure_keya))
        if secure_keya[index] != secure_keyb[index]:
            errors += 1
        secure_keya.pop(index)
        secure_keyb.pop(index)

    return errors / sample_size
#print (secure_keya)
#print (secure_keyb)


def error_calt(secure_keya, secure_keyb):
    """Return the exact fraction of positions where the two keys differ.

    Each position of ``secure_keya`` is compared with the same position of
    ``secure_keyb``; neither list is modified.
    """
    mismatches = sum(
        1 for pos, bit_a in enumerate(secure_keya) if bit_a != secure_keyb[pos]
    )
    return mismatches / len(secure_keya)


#print("lenght =", len(secure_keya))


In [53]:
def shuffling(seca, secb):
    """Apply the same random permutation to two equal-length keys.

    Position i of both outputs comes from the same original position, so
    pairs of corresponding bits stay aligned.

    Args:
        seca: first key (sequence of bits).
        secb: second key, same length as ``seca``.

    Returns:
        tuple[tuple, tuple]: the shuffled keys, each as a tuple.

    Raises:
        ValueError: if the keys differ in length (zip would otherwise drop bits).
    """
    if len(seca) != len(secb):
        raise ValueError("keys must have the same length to be shuffled together")

    pairs = list(zip(seca, secb))
    random.shuffle(pairs)
    # zip(*[]) yields nothing, which cannot be unpacked into two names.
    if not pairs:
        return (), ()

    shuffle_seca, shuffle_secb = zip(*pairs)
    return shuffle_seca, shuffle_secb

def parity(sec):
    """Return the parity of bit string *sec*.

    For 0/1 bits this is 0 when the count of ones is even and 1 when it is
    odd. The sum is taken with NumPy, so the result is a NumPy scalar.
    """
    ones = np.sum(sec)
    return ones % 2

def error_correction(seca, secb, flag):
    """Find and fix parity-detectable errors by recursive binary search.

    Both keys are split in half, and the search recurses into every half
    whose parity differs between the two copies. At a single-bit leaf with
    mismatched bits, ``secb``'s bit is replaced by ``seca``'s (``seca`` is the
    reference copy). Slicing and concatenation keep the input sequence type
    (list or tuple).

    Errors masked by matching parity (an even number of errors within a half)
    are not found in one pass. Calling again after shuffling both keys the
    same way gives them another chance to surface.

    Args:
        seca: reference key (sequence of 0/1 bits).
        secb: key to correct, same length as ``seca``.
        flag: unused; kept for call compatibility.

    Returns:
        (corrected_seca, corrected_secb, f), where f is 1 if at least one bit
        was corrected and 0 otherwise. f is an indicator, not a count.
    """
    # Empty input: nothing to compare (previously fell through and returned None).
    if len(seca) == 0:
        return seca, secb, 0

    if len(seca) == 1:
        if parity(seca) == parity(secb):
            return seca, secb, 0
        return seca, seca, 1

    mid = len(seca) // 2
    seca_first, seca_second = seca[:mid], seca[mid:]
    secb_first, secb_second = secb[:mid], secb[mid:]

    # Compare each half's parity once instead of once per branch test.
    first_differs = parity(seca_first) != parity(secb_first)
    second_differs = parity(seca_second) != parity(secb_second)

    if not (first_differs or second_differs):
        return seca, secb, 0

    f = 0
    if first_differs:
        seca_first, secb_first, f_first = error_correction(
            seca_first, secb_first, flag)
        f = max(f, f_first)
    if second_differs:
        seca_second, secb_second, f_second = error_correction(
            seca_second, secb_second, flag)
        f = max(f, f_second)
    return seca_first + seca_second, secb_first + secb_second, f
    
        

In [54]:
# Estimate the error rate from a random sample. error_cal removes the sampled
# bits from both keys (presumably because they were disclosed when compared).
error_rate_sample = error_cal(secure_keya, secure_keyb)
# Exact mismatch rate over the remaining bits, for comparison in simulation.
error_rate_total = error_calt(secure_keya, secure_keyb)
print("Error rate sample = ",error_rate_sample)
print("Error rate total= ", error_rate_total)

# Estimated number of errors left in the (now shorter) keys. The correction
# loop decrements this counter as it fixes bits.
errors = int(error_rate_sample*len(secure_keya))
print("Length of key before error correction = " , len(secure_keya))
print("Errors = ", errors)


Error rate sample =  0.6
Error rate total=  0.2727272727272727
Length of key before error correction =  22
Errors =  13


In [55]:
sec_ka, sec_kb = secure_keya, secure_keyb
# Counts rounds in which error_correction left the shuffled keys unchanged;
# the loop gives up once more than 5 such rounds have occurred.
iterations = 0


# Each round applies one shared random permutation to both keys, presumably
# so that errors masked by matching parity can surface in a new arrangement.
# It then runs the recursive parity search. `errors` is the sampled estimate
# and is decremented once per round that corrected something. A round can fix
# more than one bit, so the stall counter is what ends the loop in that case.
while errors > 0:
    shuf_a, shuf_b = shuffling(sec_ka, sec_kb)
    sec_ka, sec_kb, flag = error_correction(shuf_a, shuf_b, 0)
    if flag:  # flag is a 0/1 indicator; test truthiness, not `== True`
        errors -= 1

    if shuf_a == sec_ka and shuf_b == sec_kb:
        iterations += 1

    if iterations > 5:
        break

    #print(sec_ka)
    #print(sec_kb)
    

In [59]:
#print("Secure key with alice  =",sec_ka)
#print("Secure key with bob    =",sec_kb)

# Copy the corrected keys into fresh lists (error_cal below pops from them).
secure_key_with_alice = [sec_ka[k] for k in range(len(sec_ka))]
secure_key_with_bob = [sec_kb[k] for k in range(len(sec_ka))]


print(len(secure_key_with_alice))

# error_cal removes its sample from both lists in place, so error_calt then
# measures only the remaining bits.
print("Error rate sample = ",error_cal(secure_key_with_alice, secure_key_with_bob))
print("Error rate total= ", error_calt(secure_key_with_alice, secure_key_with_bob))

22
Error rate sample =  0.0
Error rate total=  0.0


# Privacy Amplification

In [87]:
# Keep only a whole number of 4-bit blocks for privacy amplification.
new_size = len(secure_key_with_alice) // 4 * 4

sec_ka, sec_kb = secure_key_with_alice[:new_size], secure_key_with_bob[:new_size]


In [86]:
hamming_matrix = [[1, 0, 1, 0], [0, 1, 0, 1]]
# Key bits consumed per block (the matrix's column count, 4).
block_len = len(hamming_matrix[0])

#print(hamming_matrix)
print("Key with Alice before PA = ", sec_ka)
print("Key with Bob before PA   = ", sec_kb)

ampl_ka = []
ampl_kb = []

# Compress each full block of key bits into len(hamming_matrix) output bits
# with a mod-2 matrix product. Both keys use the same matrix, so equal inputs
# give equal outputs. The block count comes from the key itself rather than
# from a variable set in another cell.
for i in range(len(sec_ka) // block_len):
    sub_ka = sec_ka[i * block_len:(i + 1) * block_len]
    sub_kb = sec_kb[i * block_len:(i + 1) * block_len]
    # Compute each product once per block (not once per output bit), and
    # store plain ints so the printed keys read as 0/1 on any NumPy version.
    ampl_ka.extend(int(bit) for bit in np.matmul(hamming_matrix, sub_ka) % 2)
    ampl_kb.extend(int(bit) for bit in np.matmul(hamming_matrix, sub_kb) % 2)

print("Key with Alice after PA  = ", ampl_ka)
print("Key with Bob after PA    = ", ampl_kb)

Key with Alice before PA =  [1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
Key with Bob before PA   =  [1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1]
Key with Alice after PA  =  [1, 0, 0, 1, 1, 1, 0, 1]
Key with Bob after PA    =  [1, 0, 0, 1, 1, 1, 0, 1]
