In [1]:
import numpy as np
import random
import math

In [2]:
class LinearPolarizer:
    def __init__(self):
        
        self.x = np.array([[1], [0]])
        self.y = np.array([[0], [1]])
    
    def horizontal_vertical(self, bit):
        if bit == 0:
            return self.x
        else:
            return self.y
    
    def diagonal_polarization(self, bit):
        jones = (1/np.sqrt(2))*np.array([[1,1],[1,-1]])
        
        if bit == 0:
            return np.dot(jones, self.x)
        else:
            return np.dot(jones, self.y)
    
    def general_polarization(self, angle, basis):
        """
            angle to be in degrees
        """
        angle = (math.pi/180) * (angle)
        jones = np.array([[np.cos(angle), np.sin(angle)], [np.sin(angle), -np.cos(angle)]])
        
        return np.dot(jones, basis)

class PolarizingBeamSplitter:
    def __init__(self):
        pass
    
    def measure(self, vector, basis):
        """
            basis  : basis chosen by bob to measure polarization encoded photon
                        0 -> horizontal/vertical 
                        1 -> diagonal
            vector : Jones vector for polarized photon
            
            returns a dictionary with probabilities of the encoded bit sent by Alice being 0 or 1
        """
        #horizontal-vertical can be clubbed into an identity matrix
        horizontal = np.array([[1, 0], [0, 0]])
        vertical = np.array([[0, 0], [0, 1]])
        plus_minus = (1/np.sqrt(2))*np.array([[1,1],[1,-1]])

        if basis == 0:
            zero = np.dot(horizontal, vector)[0]
            one = np.dot(vertical, vector)[1]
        
        elif basis == 1:
            zero = np.dot(plus_minus, vector)[0]
            one = np.dot(plus_minus, vector)[1]
        else:
            return None
            
        return {0: zero[0]**2, 1: one[0]**2}

In [3]:
class Alice:
    def __init__(self, n):
        self.n = n
        self.alice = {} #{no. : [bit encoded, basis chosen to encode in ]} no. is some unique number
                        #{1:[0,0],2:[0,1],3:[0,0]} Example

    def generate_and_encode(self): 
        """
            Will generate n bits randomly
            For each bit generated, a basis is chosen in which it is encoded
            Dependency for encoding: <class LinearPolarizer>
                0-> horizontal/vertical polarization
                1-> diagonal polarization
            
            Should generate a dictionary of the form self.alice mentioned above
        """
        LP = LinearPolarizer()
        encode = []
        count = self.n
        
        while count!= 0:
            self.alice[count] = [ random.randint(0,1), random.randint(0,1)]
            if self.alice[count][1] == 0:
                encode.append(LP.horizontal_vertical(self.alice[count][0]))
            else:
                encode.append(LP.diagonal_polarization(self.alice[count][0]))
            count-=1
        
        return encode
    

class Bob:
    def __init__(self, n):
        self.n = n
        self.bob = {} #{no. : [bit after measurement, basis chosen to measure in]}
                      #{1:[1,0],2:[0,0],3:[1,0]} Example
    
    def choose_basis_and_measure(self, received):
        """
            received : the data received by bob
            Dependency for measurement: <class PolarizingBeamSplitter>
            
                0-> horizontal/vertical polarization
                1-> diagonal polarization
            
            Should generate a dictionary of the form self.bob mentioned above
        """
        #self.bob[n][0] is the measured bit
        
        PBS = PolarizingBeamSplitter()
        count = self.n
        i = 0
        while count!= 0:
            self.bob[count] = [0, random.randint(0,1)]
            measure = PBS.measure(received[i], self.bob[count][1])
            if measure[0] == measure[1]:
                self.bob[count][0] = random.randint(0,1)
            elif measure[0] > measure[1]:
                self.bob[count][0] = 0
            else:
                self.bob[count][0] = 1
            i += 1
#             self.bob[count][0] = random.randint(0,1) if (abs(PBS.measure(received[i-count], self.bob[count][1])[0]) == abs(PBS.measure(received[i-count], self.bob[count][1])[1])) else (0 if (PBS.measure(received[i-count], self.bob[count][1])[0]> PBS.measure(received[i-count], self.bob[count][1])[1]) else 1)
            #Here,  picking randomly between 0 and 1 if wrong basis is choosen else 0 or 1 based of actual measurement
            count-=1

### Errors are caused by interference from Eve and noise in the channel
### Assumption: All errors are assumed to be caused by Eve   

    1. Alice generates a bit string d of (4+delta)n random bits
    2. Alice randomly chooses a basis to encoded each bit of her bit string and encodes it
        - Basis chosen and the encoded bit is kept track of 
    3. Alice sends the resulting encoded bits to Bob
    4. Bob recieves (4+delta)n bits from Alice (probably erroneous)
        - Bob then radnomly chooses a basis for each of these bits and measures in that basis
        - Basis chosen for measurement and the result of measurement are both kept track of
    5. Alice and Bob do basis reconciliation 
        - With high probability there are still 2n bits remaining, if not abort the protocol
    6. Alice randomly samples 2n bits from the remaining bits (>=2n) and announces which bits she picked
        (but not their values)
    7. Alice randomly selects n bits from these 2n bits as "check bits" and announces them AND their values
    8. Bob compares the bit values he measured for the n check bits selected by Alice and 
        announces the bits where they disagree
        - If more than an acceptable numberof these check bit values disagree, they abort the protocol 
    9. Alice now has an nbit string x, and Bob has an n-bit string x+ e1, 
        where e1 is the error caused by Eve’s interference and/or channel noise
    10. Error correction is done due to which Alice and Bob have the same keys
    11. Privacy amplification

In [4]:
class BB84:
    def __init__(self, n, delta, error_threshold):
        """
            Alice generates (4+delta)n bits 
            delta: small fraction less than one 
            error_threshold: if error while announcing n bits from 2n bits is greater than this
                                key generation is aborted 
        """
        if delta > 1:
            print("Value for delta should be lesser than 1")
            return 
        
        self.n = n
        self.total = math.ceil(4 + delta)*n
        self.alice = Alice(self.total)
        self.bob = Bob(self.total)
        
        self.error = error_threshold
        
    def distribute(self):
        encoded = self.alice.generate_and_encode()
        self.bob.choose_basis_and_measure(encoded)

        recon = Reconciliation(self.error, self.alice.alice, self.bob.bob, self.n)
        
        recon_alice, recon_bob = recon.basis_reconciliation(self.alice.alice, self.bob.bob)
        try:
            final_alice, final_bob = recon.error_correction(recon_alice, recon_bob)
            return final_alice, final_bob
            
        except:
            self.abort()
            return [], []
        
    
    def abort(self):
        print("Protocol aborted")
        return 

In [5]:
def calcRedundantBits(m): 
  
    # Use the formula 2 ^ r >= m + r + 1 
    # to calculate the no of redundant bits. 
    # Iterate over 0 .. m and return the value 
    # that satisfies the equation 
  
    for i in range(m): 
        if(2**i >= m + i + 1): 
            return i 
  
  
def posRedundantBits(data, r): 
   
    j = 0
    k = 1
    m = len(data) 
    res = '' 
   
    for i in range(1, m + r+1): 
        if(i == 2**j): 
            res = res + '0'
            j += 1
        else: 
            res = res + data[-1 * k] 
            k += 1
  
    return res[::-1] 
  
def calcParityBits(arr, r): 
    n = len(arr) 
  
    # For finding rth parity bit, iterate over 
    # 0 to r - 1 
    for i in range(r): 
        val = 0
        for j in range(1, n + 1): 
  
            # If position has 1 in ith significant 
            # position then Bitwise OR the array value 
            # to find parity bit value. 
            if(j & (2**i) == (2**i)): 
                val = val ^ int(arr[-1 * j]) 
                # -1 * j is given since array is reversed 
  
        # String Concatenation 
        # (0 to n - 2^r) + parity bit + (n - 2^r + 1 to n) 
        arr = arr[:n-(2**i)] + str(val) + arr[n-(2**i)+1:] 
    return arr 
  
def detectError(arr, nr): 
    n = len(arr) 
    res = 0
  
    # Calculate parity bits again 
    for i in range(nr): 
        val = 0
        for j in range(1, n + 1): 
            if(j & (2**i) == (2**i)): 
                val = val ^ int(arr[-1 * j]) 
  
        # Create a binary no by appending 
        # parity bits together. 
  
        res = res + val*(10**i) 
  
    return int(str(res), 2) 

In [6]:
class Reconciliation:
    def __init__(self, error_threshold, alice, bob, n):
        
        self.alice = alice
        self.bob = bob
        self.n = n
        self.error_threshold = error_threshold
        
        
    def basis_reconciliation(self, alice, bob):
        """
            alice: {no. : [bit encoded, basis chosen to encode in ]}
            bob  : {no. : [bit after measurement, basis chosen to measure in]}

            First check if the length of both lists are the same
                -> if yes, keep only those bits for alice and bob for which
                   the basis encoded in and measured in is the same. 
        """
        basis_bit_alice = list(alice.values())
        basis_bit_bob = list(bob.values())

        if len(basis_bit_alice) == len(basis_bit_bob):
            raw_key_alice = []
            raw_key_bob = []

            for i in range(len(basis_bit_alice)):
                if basis_bit_alice[i][1] == basis_bit_bob[i][1]:
                    raw_key_alice.append(basis_bit_alice[i][0])
                    raw_key_bob.append(basis_bit_bob[i][0])

            return raw_key_alice, raw_key_bob

        else:
            return None, None
    
    
    def abort(self):
        print("Protocol aborted here")
        return
    
    
    
    def sampling(self, raw_key_alice, raw_key_bob, n):
        
        sampled_key_alice, sampled_key_bob, sampled_key_index = [], [], []
        sampled_key_index = random.sample(list(enumerate(raw_key_alice)), n)
        indices = []
        
        for idx, val in sampled_key_index:
            sampled_key_alice.append(val)
            sampled_key_bob.append(raw_key_bob[idx])
            indices.append(idx)
        
        return sampled_key_alice, sampled_key_bob, indices
    
    def error_correction(self, raw_key_alice, raw_key_bob):
                
        if len(raw_key_alice)<2*self.n:
            self.abort()
        
        else:
            
            sampled_key_alice, sampled_key_bob, sample_indices = self.sampling(raw_key_alice, raw_key_bob, 2*self.n)

            check_alice, check_bob, indices = self.sampling(sampled_key_alice, sampled_key_bob, self.n)

            error = 0

            for i in range(len(check_alice)):
                if check_alice[i] != check_bob[i]:
                    error+=1

            error_rate = error/self.n
            
            if error_rate >= self.error_threshold:
                self.abort()
                
            else:
                
                req_alice = [sampled_key_alice[i] for i in range(len(sampled_key_alice)) if i not in indices]
                req_bob = [sampled_key_bob[i] for i in range(len(sampled_key_bob)) if i not in indices]
                
                if error_rate == 0.0:
                    return req_alice, req_bob
                
                string_alice = "".join(list(map(str, req_alice)))
                string_bob = "".join(list(map(str, req_bob)))
                
                m = len(string_bob)
                r = calcRedundantBits(m) 
                arr = posRedundantBits(string_bob, r) 
                arr = calcParityBits(arr, r) 
                
                correction = self.n - detectError(arr, r) - 1
                req_bob[correction] = int(not req_bob[correction])
                
                
                return req_alice, req_bob

In [12]:
random.seed(9)
bb84 = BB84(10, 0.6, 0.3)
a, b = bb84.distribute()
count = 0

for i in range(len(a)):
    if a[i] != b[i]:
        count += 1

print("Error", count/len(a))
print(a, b)

Error {%f} % 0.0
[1, 1, 0, 1, 0, 1, 0, 0, 0, 0] [1, 1, 0, 1, 0, 1, 0, 0, 0, 0]
