In [3]:
#https://qiskit.org/textbook/ch-algorithms/quantum-key-distribution.html
# Simulated QKD based off this tutorial


from qiskit import QuantumCircuit, Aer, transpile, assemble
from qiskit.visualization import plot_histogram, plot_bloch_multivector
from numpy.random import randint
import numpy as np


# Creates list of quantumCircuits each representing a single qubit in the message
def encode_message(bits, bases):
    message = []
    for i in range(n):
        qc = QuantumCircuit(1,1)
        if bases[i] == 0: # Prepare qubit in Z-basis
            if bits[i] == 0:
                pass 
            else:
                qc.x(0)
        else: # Prepare qubit in X-basis
            if bits[i] == 0:
                qc.h(0)
            else:
                qc.x(0)
                qc.h(0)
        qc.barrier()
        message.append(qc)
    return message


#Applies corresponding measurement and simulates result of measuring each qubit. 
# result is stored in User_results
def measure_message(message, bases):
    backend = Aer.get_backend('aer_simulator')
    measurements = []
    for q in range(n):
        if bases[q] == 0: # measuring in Z-basis
            message[q].measure(0,0)
        if bases[q] == 1: # measuring in X-basis
            message[q].h(0)
            message[q].measure(0,0)
        aer_sim = Aer.get_backend('aer_simulator')
        qobj = assemble(message[q], shots=1, memory=True)
        result = aer_sim.run(qobj).result()
        measured_bit = int(result.get_memory()[0])
        measurements.append(measured_bit)
    return measurements

#Removes non-corresponding bits based of the measured bases. This data is useless to us
def remove_garbage(a_bases, b_bases, bits):
    good_bits = []
    for q in range(n):
        if a_bases[q] == b_bases[q]:
            # If both used the same basis, add
            # this to the list of 'good' bits
            good_bits.append(bits[q])
    return good_bits

#Selects random part to ensure protocol worked correctly
def sample_bits(bits, selection):
    sample = []
    for i in selection:
        # use np.mod to make sure the
        # bit we sample is always in 
        # the list range
        i = np.mod(i, len(bits))
        # pop(i) removes the element of the
        # list at index 'i'
        sample.append(bits.pop(i))
    return sample


## SIMULATING THE QKD BETWEEN ALICE AND BOB
np.random.seed(seed=0)

n = 20
## Step 1
# Alice generates bits
alice_bits = randint(2, size=n)
print("Alice's random bits: " , alice_bits)


## Step 2
# Create an array to tell us which qubits
# are encoded in which bases
alice_bases = randint(2, size=n)
print("Alice's bases: " , alice_bases)
message = encode_message(alice_bits, alice_bases)


# X basis 0
# Z basis 1

'''
print('bit = %i' % alice_bits[0])
print('basis = %i' % alice_bases[0])

print('bit = %i' % alice_bits[4])
print('basis = %i' % alice_bases[4])

'''


message[4].draw()
message[0].draw()


## BOB's Side
#Step 3

bob_bases = randint(2, size=n)
print("Bobs bases: ", bob_bases)
bob_results = measure_message(message, bob_bases)
print("Bob result: ", bob_results)


#Step 4
alice_key = remove_garbage(alice_bases, bob_bases, alice_bits)
print("Alice key: ", alice_key)
bob_key = remove_garbage(alice_bases, bob_bases, bob_results)
print("Bob key: ", bob_key)

Alice's random bits:  [0 1 1 0 1 1 1 1 1 1 1 0 0 1 0 0 0 0 0 1]
Alice's bases:  [0 1 1 0 0 1 1 1 1 0 1 0 1 0 1 1 0 1 1 0]
Bobs bases:  [0 1 0 1 1 1 1 1 0 1 0 1 1 1 1 0 1 0 0 1]
Bob result:  [0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0]
Alice key:  [0, 1, 1, 1, 1, 0, 0]
Bob key:  [0, 1, 1, 1, 1, 0, 0]
