In [1]:
import datetime
import json
import socket
import threading
import time
from _thread import *
from math import gcd
from random import randrange, getrandbits

from dateutil import parser

In [2]:
def is_prime(n, k=8):
    """ Probabilistic Miller-Rabin primality test
        Args:
            n -- int -- the candidate to test
            k -- int -- how many random witnesses to try
        return True if n is (probably) prime, False if n is certainly composite
    """
    # 2 and 3 are prime; anything below 2 or even (other than 2) is not.
    if n in (2, 3):
        return True
    if n <= 1 or n % 2 == 0:
        return False

    # Factor n - 1 as odd_part * 2**twos.
    odd_part, twos = n - 1, 0
    while (odd_part & 1) == 0:
        odd_part //= 2
        twos += 1

    for _ in range(k):
        witness = randrange(2, n - 1)
        value = pow(witness, odd_part, n)
        if value in (1, n - 1):
            continue  # this witness says "probably prime"
        # Square up to twos - 1 more times looking for n - 1.
        for _ in range(twos - 1):
            value = pow(value, 2, n)
            if value == 1:
                return False  # non-trivial square root of 1: composite
            if value == n - 1:
                break
        else:
            return False  # never reached n - 1: composite
    return True

def generate_prime_candidate(length):
    """ Draw a random odd integer of exactly ``length`` bits
        Args:
            length -- int -- bit length of the number to generate (must be >= 1)
        return an odd int whose most significant bit (bit length-1) is set
    """
    random_bits = getrandbits(length)
    # Forcing the top bit fixes the bit length; forcing bit 0 makes it odd.
    top_bit = 1 << (length - 1)
    return random_bits | top_bit | 1

def generate_prime_number(prev, length=8):
    """ Generate a random prime that differs from ``prev``
        Args:
            prev -- int -- a value the result must not equal (e.g. an already chosen prime)
            length -- int -- length of the prime to generate, in bits
        return a prime of exactly ``length`` bits (tested with 4 Miller-Rabin rounds)
    """
    while True:
        candidate = generate_prime_candidate(length)
        # Primality is checked first so random numbers are consumed in the same order as before.
        if is_prime(candidate, 4) and candidate != prev:
            return candidate


def generate_public_key(phi):
    """ Pick the public exponent for an RSA key
        Args:
            phi -- int -- Euler's totient of the modulus, (p-1)*(q-1)
        return the smallest integer e >= 2 with gcd(e, phi) == 1
    """
    exponent = 2
    while gcd(exponent, phi) != 1:
        exponent += 1
    return exponent

def generate_rsa_key(length=8):
    """ Generate a (toy) RSA key pair
        Args:
            length -- int -- bit length of each prime p and q. The default 8 keeps the
                             original behaviour: both primes are >= 128, so n >= 2**14,
                             which is large enough for per-character encryption of ASCII
                             text but trivially breakable. Must be >= 3 so that at least
                             two distinct primes of that size exist.
        return ((e, n), (d, n)) -- the public key and the private key
    """
    p = generate_prime_number(1, length)
    # Passing p as ``prev`` guarantees q != p.
    q = generate_prime_number(p, length)
    n = p * q
    phi = (p - 1) * (q - 1)
    e = generate_public_key(phi)
    # Modular inverse of e mod phi (three-argument pow with exponent -1 needs Python 3.8+).
    d = pow(e, -1, phi)
    return (e, n), (d, n)


def encrypt_rsa(pk, plaintext):
    """ Encrypt text one character at a time with textbook RSA
        Args:
            pk -- (exponent, modulus) -- the key to apply
            plaintext -- str -- text to encrypt; every code point should be < modulus
                                to be recoverable
        return a list of ints, one value ord(char)**exponent % modulus per character
    """
    exponent, modulus = pk
    return [pow(code_point, exponent, modulus) for code_point in map(ord, plaintext)]


def decrypt_rsa(pk, ciphertext):
    """ Invert ``encrypt_rsa``: map each int back to a character
        Args:
            pk -- (exponent, modulus) -- the key to apply
            ciphertext -- iterable of int -- values to decrypt
        return the decoded str, chr(value**exponent % modulus) for each value
    """
    exponent, modulus = pk
    return ''.join(chr(pow(value, exponent, modulus)) for value in ciphertext)

In [3]:
def add_client_info(public_key, key_ring, client_id):
    """ Store (or overwrite) ``client_id``'s public key in ``key_ring``; returns None """
    key_ring.update({client_id: public_key})
    
def get_client_info(key_ring, client_id):
    """ Look up ``client_id``'s public key; raises KeyError if it was never registered """
    stored_key = key_ring[client_id]
    return stored_key

In [4]:
# Serializes request/response pairs on the clock-server socket (see docstring below).
_clock_sync_lock = threading.Lock()


def get_synchronized_time(clock_server_connection, id):
    """ Ask the clock server for the time and correct it with Cristian's algorithm
        Args:
            clock_server_connection -- connected socket to the clock server
            id -- str -- identifier of this node, sent with the request
        return float -- POSIX timestamp estimated as server_time + round_trip / 2
        raises ConnectionError if the clock server closed the connection

        The same clock socket may be shared by several threads (the server in this
        notebook passes one connection to every client thread), so the send/recv pair
        is done under a lock; otherwise two threads' replies could interleave or be
        read together by a single recv().
    """
    message = {
        "id": id,
        "message": "SYNC_TIME"
    }
    with _clock_sync_lock:
        # T0 must be taken *before* sending: the round trip includes the send.
        # A monotonic clock keeps the measurement immune to wall-clock adjustments.
        request_started = time.monotonic()  # T0
        clock_server_connection.send(str.encode(json.dumps(message)))
        synced_time_str = clock_server_connection.recv(2048).decode()
        round_trip = time.monotonic() - request_started  # T1 - T0, measured before parsing
    if not synced_time_str:
        raise ConnectionError('clock server closed the connection')
    clock_data_recv = parser.parse(synced_time_str)
    # Assume the reply spent half the round trip in flight.
    client_time = clock_data_recv + datetime.timedelta(seconds=round_trip / 2)
    return client_time.timestamp()

In [5]:

def _perform_handshake(connection, clock_connection, id, public_key_ring, public_key):
    """ Register the client's public key and reply with this server's id, time and public key
        return False if the client disconnected before sending a handshake, else True
    """
    client_handshake_message_str = connection.recv(2048).decode('utf-8')
    if not client_handshake_message_str:
        return False
    client_handshake_message = json.loads(client_handshake_message_str)
    client_id = client_handshake_message['id']
    client_public_key = client_handshake_message['public_key']
    add_client_info(client_public_key, public_key_ring, client_id)
    server_handshake_message = {
        'id': id,
        'timestamp': get_synchronized_time(clock_connection, id),
        'public_key': public_key,
    }
    connection.sendall(str.encode(json.dumps(server_handshake_message)))
    return True


def _answer_key_request(raw_request, clock_connection, id, public_key_ring):
    """ Decrypt one public-key request and return the JSON response string to send back
        raises KeyError if the sender or the requested id has no registered public key
    """
    client_request_message = json.loads(raw_request)
    client_id = client_request_message['id']
    encrypted_payload = client_request_message['encrypted_payload']
    # Look the sender's key up once so decryption and encryption use the same key.
    client_key = get_client_info(public_key_ring, client_id)
    # The payload is decrypted with the sender's *public* key, i.e. the sender applied its
    # private key; presumably this is meant as sender authentication (weak at toy key sizes).
    decrypted_payload = json.loads(decrypt_rsa(client_key, encrypted_payload))
    requested_id = decrypted_payload['requested_id']
    print('Public key request received from id: {} for id {} '.format(client_id, requested_id))
    message_payload = {
        'requested_id': requested_id,
        'requested_public_key': get_client_info(public_key_ring, requested_id),
    }
    server_response_message = {
        'id': id,
        'timestamp': get_synchronized_time(clock_connection, id),
        # Encrypted with the sender's public key so only the sender can read the reply.
        'encrypted_payload': encrypt_rsa(client_key, json.dumps(message_payload)),
    }
    return json.dumps(server_response_message)


def multi_threaded_client(connection, clock_connection, address, id, public_key_ring, public_key):
    """ Serve one connected client until it disconnects
        Args:
            connection -- socket -- the accepted client connection (always closed on return)
            clock_connection -- socket -- connection to the clock server, for timestamps
            address -- the client's address (not used here)
            id -- str -- this server's identifier, included in every reply
            public_key_ring -- dict -- client_id -> public key, shared between threads
            public_key -- this server's public key, sent during the handshake

        Protocol: one plaintext JSON handshake, then a loop of encrypted public-key
        requests, each answered with an encrypted JSON response.
    """
    try:
        if not _perform_handshake(connection, clock_connection, id, public_key_ring, public_key):
            return
        while True:
            client_request_message_encrypted_str = connection.recv(2048).decode('utf-8')
            if not client_request_message_encrypted_str:
                break  # client closed the connection
            response = _answer_key_request(client_request_message_encrypted_str,
                                           clock_connection, id, public_key_ring)
            connection.sendall(str.encode(response))
    finally:
        # Close even when a request is malformed or names an unregistered id (KeyError);
        # previously the socket leaked and the client blocked forever waiting for a reply.
        connection.close()

In [6]:
ServerSideSocket = socket.socket()
# Allow re-binding immediately when this cell is re-run (e.g. after interrupting the
# kernel); otherwise bind() can fail with "Address already in use".
ServerSideSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
host = '127.0.0.1'
port = 6666
clock_port = 8080
ThreadCount = 0
id = 'kdc'  # this server's identifier (note: shadows the built-in id())
# NOTE(review): private_key is not used by any visible code in this notebook.
public_key, private_key = generate_rsa_key()
public_key_ring = {}  # client_id -> public key, shared by all client threads

clock_server_connection = socket.socket()
try:
    clock_server_connection.connect((host, clock_port))

    try:
        ServerSideSocket.bind((host, port))
    except socket.error as e:
        print(str(e))
        # Re-raise: carrying on would listen() on an unbound socket, not on `port`.
        raise
    ServerSideSocket.listen(5)
    print('Socket is listening..')

    while True:
        Client, address = ServerSideSocket.accept()
        print('Connected to: ' + address[0] + ':' + str(address[1]))
        start_new_thread(multi_threaded_client, (Client, clock_server_connection, address, id, public_key_ring, public_key))
        ThreadCount += 1
        print('Thread Number: ' + str(ThreadCount))
finally:
    # Release both sockets when the loop is interrupted so the cell can be re-run cleanly.
    ServerSideSocket.close()
    clock_server_connection.close()

Socket is listening..
Connected to: 127.0.0.1:52186
Thread Number: 1
Connected to: 127.0.0.1:52199
Thread Number: 2
Public key request received from id: 91a5fe2417c2dbd256209d3e75d677bb for id degree-granting-server 
Public key request received from id: degree-granting-server for id 91a5fe2417c2dbd256209d3e75d677bb 
