In [1]:
import random
import math
from Crypto.Util import number
import multiprocessing
from joblib import Parallel, delayed

In [2]:
class keyGen:
    """Key material for a Diffie-Hellman style exchange.

    Instances expose three attributes, all set in ``__init__``:

    * ``q``          -- a prime obtained from ``number.getPrime(nbits)``
    * ``alpha``      -- the largest primitive root of ``q``
    * ``privateKey`` -- a random integer in ``[1, 9]``
    """

    _ROOTS_TO_SHOW = 10                               # how many of the highest primitive roots get printed

    def __init__(self, nbits):                        # constructor
        """Pick the prime, its primitive root and a private key.

        Args:
            nbits: bit length handed to ``number.getPrime``.
        """
        self.q = number.getPrime(nbits)               # q
        self.alpha = self.getalpha()                  # alpha
        # NOTE(review): the private key is drawn from only 9 values with the
        # non-cryptographic `random` module -- fine for a demo, not for real use.
        self.privateKey = random.randrange(1, 10, 1)  # random private key within a range

    def getalpha(self):
        """Return the primitive root of ``self.q`` chosen by ``getPrimitiveRoot``."""
        # getPrimitiveRoot either returns a root or raises, so no retry loop is needed.
        return self.getPrimitiveRoot(self.q)

    def power(self, x, y, p):
        """Return ``x**y mod p`` by square-and-multiply.

        Returns 1 when ``y <= 0`` (the loop body never runs).
        """
        res = 1
        x = x % p
        while y > 0:
            if y & 1:                                 # lowest exponent bit set: fold x into the result
                res = (res * x) % p
            y = y >> 1                                # move to the next exponent bit
            x = (x * x) % p
        return res

    def getPrimeFactors(self, phi):
        """Return the set of distinct prime factors of ``phi``.

        Args:
            phi: integer >= 1 (1 yields an empty set).

        Raises:
            ValueError: if ``phi < 1`` (0 would otherwise never terminate).
        """
        if phi < 1:
            raise ValueError('phi must be a positive integer')

        factors = set()
        while phi % 2 == 0:
            factors.add(2)
            phi = phi // 2

        # phi is odd from here on. The bound is re-evaluated against the
        # shrinking phi and is inclusive, so a divisor equal to the square
        # root (e.g. 3 for 9) is not missed; integer arithmetic avoids
        # float rounding from math.sqrt on large values.
        i = 3
        while i * i <= phi:
            while phi % i == 0:
                factors.add(i)
                phi = phi // i
            i += 2

        if phi > 2:                                   # whatever remains is a single prime factor
            factors.add(phi)

        return factors

    def getPrimitiveRoot(self, prime):
        """Print the highest primitive roots of ``prime`` and return the largest.

        Candidates are tested from ``prime - 1`` downwards and the search
        stops once ``_ROOTS_TO_SHOW`` roots are found, so only a handful of
        candidates are checked instead of all of ``2 .. prime - 1``.

        Raises:
            ValueError: if no candidate in ``2 .. prime - 1`` is a primitive root.
        """
        phi = prime - 1                               # Euler's totient of prime number p is p-1
        factors = self.getPrimeFactors(phi)           # get the prime factors of euler's totient

        highest_roots = []
        for candidate in range(phi, 1, -1):
            if self.checkPrimitivePrime(candidate, prime, factors) is not None:
                highest_roots.append(candidate)
                if len(highest_roots) == self._ROOTS_TO_SHOW:
                    break

        if not highest_roots:
            raise ValueError('no primitive root found for %r' % (prime,))

        print(highest_roots)                          # highest primitive roots, largest first
        return highest_roots[0]                       # select highest primitive root

    def checkPrimitivePrime(self, i, q, factors):
        """Return ``i`` if it is a primitive root modulo ``q``, else ``None``.

        ``factors`` must be the distinct prime factors of ``q - 1``; ``i`` is
        rejected as soon as ``i**((q-1)/f) mod q == 1`` for some factor ``f``.
        """
        phi = q - 1
        for factor in factors:
            if self.power(i, phi // factor, q) == 1:
                return None
        return i

In [3]:
# Build A's key material from a 20-bit prime, then report each chosen parameter.
keyA = keyGen(20)
for label, value in (
    ('Selected q:', keyA.q),
    ('Selected alpha:', keyA.alpha),
    ('Selected private key:', keyA.privateKey),
):
    print(label, value)

[1022499, 1022493, 1022491, 1022490, 1022489, 1022487, 1022483, 1022480, 1022475]
Selected q: 1022501
Selected alpha: 1022499
Selected private key: 9


In [4]:
# A's public key is alpha^privateKey mod q. Three-argument pow() reduces modulo q
# at every step instead of materialising the full power first, so it stays cheap
# even if the private key is made large.
publicKeyA = pow(keyA.alpha, keyA.privateKey, keyA.q)
print('Public Key of A:',publicKeyA)

Public Key of A: 1021989


In [5]:
import socket                               # Import socket module

host = socket.gethostname()                 # Get local machine name
port = 8001                                 # Reserve a port for your service.

# Serve exactly one client: send q, alpha and A's public key, then read B's
# public key. Both sockets are context-managed so they are closed even if the
# exchange raises, which also frees the port for a re-run of this cell.
with socket.socket() as s:                  # Create a socket object
    s.bind((host, port))                    # Bind to the port
    s.listen(5)                             # Now wait for client connection.
    c, addr = s.accept()                    # Establish connection with client.
    with c:
        print('Got connection from', addr)
        # NOTE(review): the three values are written back-to-back with no
        # delimiter or length prefix. TCP is a byte stream, so the client may
        # receive them merged into one chunk; confirm how the client splits
        # them before changing the wire format.
        # sendall() keeps writing until the whole buffer is sent; send() may stop early.
        c.sendall(bytes(str(keyA.q),'utf8'))
        c.sendall(bytes(str(keyA.alpha),'utf8'))
        c.sendall(bytes(str(publicKeyA),'utf8'))
        publicKeyB = c.recv(1024)           # receive public key from client

if not publicKeyB:
    # recv() returns b'' when the peer closed without sending anything.
    raise ConnectionError('client closed the connection before sending its public key')

Got connection from ('127.0.0.1', 44546)


In [6]:
# Generate Secret Key
# Decode only while publicKeyB still holds the raw bytes from the socket, so
# re-running this cell (when it is already an int) does not raise.
if isinstance(publicKeyB, bytes):
    publicKeyB = int(str(publicKeyB,'utf8'))
# Shared secret = B^privateKey mod q, computed with modular exponentiation
# rather than building the full power and reducing afterwards.
secretKey = pow(publicKeyB, keyA.privateKey, keyA.q)
secretKey

10562