In [2]:
import socket
import os
from _thread import *
from random import randrange, getrandbits
from math import gcd
import json
import time

### Generating big prime numbers using Miller Rabin Test
<a href='https://medium.com/@ntnprdhmm/how-to-generate-big-prime-numbers-miller-rabin-49e6e6af32fb'>Source</a>

In [3]:
def is_prime(n, k=128):
    """ Test if a number is prime
        Args:
            n -- int -- the number to test
            k -- int -- the number of tests to do
        return True if n is prime
    """
    # Test if n is not even.
    # But care, 2 is prime !
    if n == 2 or n == 3:
        return True
    if n <= 1 or n % 2 == 0:
        return False
    # find r and s
    s = 0
    r = n - 1
    while r & 1 == 0:
        s += 1
        r //= 2
    # do k tests
    for _ in range(k):
        a = randrange(2, n - 1)
        x = pow(a, r, n)
        if x != 1 and x != n - 1:
            j = 1
            while j < s and x != n - 1:
                x = pow(x, 2, n)
                if x == 1:
                    return False
                j += 1
            if x != n - 1:
                return False
    return True

def generate_prime_candidate(length):
    """ Generate an odd integer randomly
        Args:
            length -- int -- the length of the number to generate, in bits
        return a integer
    """
    # generate random bits
    p = getrandbits(length)
    # apply a mask to set MSB and LSB to 1
    p |= (1 << length - 1) | 1
    return p
def generate_prime_number(length=1024):
    """ Generate a prime
        Args:
            length -- int -- length of the prime to generate, in          bits
        return a prime
    """
    p = 4
    # keep generating while the primality test fail
    while not is_prime(p, 128):
        p = generate_prime_candidate(length)
    return p


### Key generation

In [4]:
def generate_public_key(phi):
    e = phi-1
    while True:
        if gcd(e, phi) == 1:
            return e
        e -= 1

In [5]:
def generate_key():
    p = generate_prime_number()
    q = generate_prime_number()
    n = p * q
    phi = (p-1) * (q-1)
    e = generate_public_key(phi)
    d = pow(e, -1, phi)
    return (e, n), (d, n)

### Encryption

In [6]:
def encrypt(pk, plaintext):
    # Unpack the key into it's components
    key, n = pk
    # Convert each letter in the plaintext to numbers based on the character using a^b mod m
    cipher = [pow(ord(char), key, n) for char in plaintext]
    # Return the array of bytes
    return cipher

### Decryption

In [7]:
def decrypt(pk, ciphertext):
    # Unpack the key into its components
    key, n = pk
    # Generate the plaintext based on the ciphertext and key using a^b mod m
    plain = [chr(pow(char, key, n)) for char in ciphertext]
    # Return the array of bytes as a string
    return ''.join(plain)

In [8]:
def get_timestamp():
    return int(time.time())

def get_nonce():
    return randrange(2**64)

def get_nonce_acknowledgment(nonce_received):
    # XORing with a constant value twice gives back the original value
    return nonce_received ^ 0xFFFF 

def get_validity_timestamp():
    return get_timestamp() + 5*60*1000

def check_validity_message(valid_till):
    return valid_till > get_timestamp()

In [9]:
def add_public_key_to_ring(public_key, public_key_ring, client_id):
    public_key_ring[client_id] = public_key

def get_public_key_from_ring(public_key_ring, client_id):
    return public_key_ring[client_id]

In [10]:
ServerSideSocket = socket.socket()
host = '127.0.0.1'
port = 8080
ThreadCount = 0
public_key, private_key = generate_key()
public_key_ring = {}
message_nonce_list = []
try:
    ServerSideSocket.bind((host, port))
except socket.error as e:
    print(str(e))
print('Socket is listening..')
ServerSideSocket.listen(5)
def multi_threaded_client(connection):
    client_handshake_message_str = connection.recv(2048).decode('utf-8')
    client_handshake_message = json.loads(client_handshake_message_str)
    if not check_validity_message(client_handshake_message['valid_till']):
        print('Invalid message, message expired')
        return
    else:
        print('Validity check passed')
        client_id = client_handshake_message['id']
        client_public_key = client_handshake_message['public_key']
        nonce_received = client_handshake_message['nonce']
        add_public_key_to_ring(client_public_key, public_key_ring, client_id)

        message_nonce = get_nonce()
        message_nonce_list.append(message_nonce)

        server_handshake_message = {
            'type':'handshake',
            'id': 'public key distribution authority', # server id
            'public_key':public_key,
            'timestamp': get_timestamp(),
            'nonce': message_nonce, # message nonce of the server handshake message
            'acknowledgement_nonce': get_nonce_acknowledgment(nonce_received), # acknowledgment_nonce of the client handshake message,
            'valid_till': get_validity_timestamp()
        }

        server_handshake_message_str = json.dumps(server_handshake_message)
        connection.sendall(str.encode(server_handshake_message_str))
    connection.close()

    print(public_key_ring)
    
while True:
    Client, address = ServerSideSocket.accept()
    print('Connected to: ' + address[0] + ':' + str(address[1]))
    start_new_thread(multi_threaded_client, (Client, ))
    ThreadCount += 1
    print('Thread Number: ' + str(ThreadCount))
ServerSideSocket.close()

Socket is listening..
Connected to: 127.0.0.1:56689
Thread Number: 1
Validity check passed
{'c88451de68345346a911447ff0935760': [27066444483047106755758558391298117089038987444058711146891607668182248374475584424558386142964543410199568270921772655601693971925263028349563896637502134035465672269666535197770407563767384574074461003454776501499457637237746671924979368849484196271475912621209745361356002901260105571288079731160627870084047384010778220116220727281806851658915502889287505124515981990667947316421516193344933196436698402203732287990101942446496400442889632903725390243431080343777267057440252197280678481613682264045525284655381556557848354375216948708986395961453442132049116940993930511326414982021610158243435783514228706632163, 270664444830471067557585583912981170890389874440587111468916076681822483744755844245583861429645434101995682709217726556016939719252630283495638966375021340354656722696665351977704075637673845740744610034547765014994576372377466719249793688494841962714