In [1]:
import socket
import os
from _thread import *
from random import randrange, getrandbits
from math import gcd
import json
import time

In [2]:
def is_prime(n, k=8):
    """ Test if a number is prime
        Args:
            n -- int -- the number to test
            k -- int -- the number of tests to do
        return True if n is prime
    """
    # Test if n is not even.
    # But care, 2 is prime !
    if n == 2 or n == 3:
        return True
    if n <= 1 or n % 2 == 0:
        return False
    # find r and s
    s = 0
    r = n - 1
    while r & 1 == 0:
        s += 1
        r //= 2
    # do k tests
    for _ in range(k):
        a = randrange(2, n - 1)
        x = pow(a, r, n)
        if x != 1 and x != n - 1:
            j = 1
            while j < s and x != n - 1:
                x = pow(x, 2, n)
                if x == 1:
                    return False
                j += 1
            if x != n - 1:
                return False
    return True

def generate_prime_candidate(length):
    """ Generate an odd integer randomly
        Args:
            length -- int -- the length of the number to generate, in bits
        return a integer
    """
    # generate random bits
    p = getrandbits(length)
    # apply a mask to set MSB and LSB to 1
    p |= (1 << length - 1) | 1
    return p
def generate_prime_number(prev, length=8):
    """ Generate a prime
        Args:
            length -- int -- length of the prime to generate, in          bits
        return a prime
    """
    p = 4
    # keep generating while the primality test fail
    while not is_prime(p, 4) or p == prev:
        p = generate_prime_candidate(length)
    return p


In [3]:
def generate_public_key(phi):
    e = 2
    while True:
        if gcd(e, phi) == 1:
            return e
        e += 1

In [4]:
def generate_key():
    p = generate_prime_number(1,8)
    q = generate_prime_number(p,8)
    print(p,q)
    n = p * q
    phi = (p-1) * (q-1)
    e = generate_public_key(phi)
    d = pow(e, -1, phi)
    return (e, n), (d, n)

In [5]:
def encrypt(pk, plaintext):
    # Unpack the key into it's components
    key, n = pk
    # Convert each letter in the plaintext to numbers based on the character using a^b mod m
    cipher = [pow(ord(char), key, n) for char in plaintext]
    # Return the array of bytes
    return cipher

In [6]:
def decrypt(pk, ciphertext):
    # Unpack the key into its components
    key, n = pk
    # Generate the plaintext based on the ciphertext and key using a^b mod m
    plain = [chr(pow(char, key, n)) for char in ciphertext]
    # Return the array of bytes as a string
    return ''.join(plain)

In [7]:
def generate_id():
    return os.urandom(16).hex()

def get_timestamp():
    return int(time.time())

def get_nonce():
    # generate a random 64-bit number between 2^32 and 2^64-1
    return randrange(2**16)

def get_nonce_acknowledgment(nonce_received):
    # XORing with a constant value twice gives back the original value
    return nonce_received ^ 0xFFFF 

def get_validity_timestamp():
    return get_timestamp() + 5*60*1000

def check_validity_message(valid_till):
    return valid_till > get_timestamp()


In [8]:
def add_client_info(client_name, public_key, client_info_dict, client_id, host_address, port_number):
    client_info_dict[client_name] = {
        "public_key": public_key,
        "client_id": client_id,
        "host_address": host_address,
        "port_number": port_number
    }

def get_client_info(client_name, client_info_dict):
    return client_info_dict[client_name]

In [9]:
def multi_threaded_fellow_client(connection, address):
    while True:
        data = connection.recv(1024)
        if not data:
            break
        print("Received from fellow client: ", data.decode())
        data = input(' -> ')
        connection.send(data.encode())
    connection.close()

In [10]:
def server_talk_thread(ClientMultiSocket, client_id, input_client_name, public_key_ring):
    while True:
        inputFromUser = input('Input client name for connecting : ')
        if inputFromUser == 'exit':
            break
        else:
            latest_message_nonce = get_nonce()
            client_request_message = {
                "type":'request',
                "id": client_id,
                'client_name':input_client_name,
                'requested_client': inputFromUser,
                'timestamp': get_timestamp(),
                'nonce':latest_message_nonce,
                'valid_till': get_validity_timestamp()
                # this message doesn't have an acknowledgement nonce because it is the first message
            }
            client_request_message_str = json.dumps(client_request_message)
            ClientMultiSocket.send(str.encode(client_request_message_str))
            while True:
                server_response_message_str = ClientMultiSocket.recv(131072).decode('utf-8')
                if not server_response_message_str:
                    break
                else:
                    server_response_message = json.loads(server_response_message_str)
                    if not check_validity_message(server_response_message['valid_till']):
                        print('Invalid message, message expired')
                    else:
                        print('Server responded to the public key request ')
                        encrypted_message = server_response_message['encrypted_message']
                        pkda_info = get_client_info('pkda', public_key_ring)
                        decrypted_message = decrypt(pkda_info['public_key'], encrypted_message)
                        response_payload = json.loads(decrypted_message)
                        requested_client_info = response_payload['requested_client_info']
                        add_client_info(inputFromUser, requested_client_info['public_key'], public_key_ring, requested_client_info['client_id'], requested_client_info['host_address'], requested_client_info['port_number'])
                        print('Client info added to the public key ring', requested_client_info['host_address'], requested_client_info['port_number'], public_key_ring)

In [11]:
ClientMultiSocket = socket.socket()
host = '127.0.0.1'
port = 3000
# generate random number greater than 1024 for the port number
self_port = 3000
while self_port <= 1024 or self_port == 3000: 
    self_port = randrange(2**16)
print("Self port number: ", self_port)
input_client_name = input('Input your client name for identification: ')
public_key, private_key = generate_key()
id = generate_id()
public_key_ring = {}
## public_key_ring structure
# {
#     'client_name': { # unique string name of the client
#         'host_address': host_address, # host address of the client
#         'port_number': port, # port of the client to connect to
#         'public_key': public_key, # public key of the client
#         'id': client_id # unique 16 byte id of the client
#     }
# }
try:
    ClientMultiSocket.connect((host, port))
except socket.error as e:
    print(str(e))

## message structure
# {
#     "type":'handshake', # type of message
#     'id':'', # unique id of the sender
#     'client_name': '', # name of the sender
#     'public_key': '', # public key of the sender
#     'timestamp': '', # timestamp of the message
#     'nonce': '', # nonce of the message
#     'valid_till': '', # validity timestamp of the message
#     'acknowledgement_nonce': '' # f(previous_message_nonce) XORed with a constant value
# }

## first message in the channel 
latest_message_nonce = get_nonce()
client_handshake_message = {
    "type":'handshake',
    "id": id,
    'client_name': input_client_name,
    "public_key": public_key,
    "client_host_address": host,
    "client_port_number": self_port,
    'timestamp': get_timestamp(),
    'nonce':latest_message_nonce,
    'valid_till': get_validity_timestamp()
    # this message doesn't have an acknowledgement nonce because it is the first message
}

client_handshake_message_str = json.dumps(client_handshake_message)
ClientMultiSocket.send(str.encode(client_handshake_message_str))

server_handshake_message_str = ClientMultiSocket.recv(2048).decode('utf-8')
server_handshake_message = json.loads(server_handshake_message_str)

if not check_validity_message(server_handshake_message['valid_till']):
    print('Invalid message, message expired')
else:
    print(server_handshake_message)
    acknowledgement_nonce = server_handshake_message['acknowledgement_nonce']
    if latest_message_nonce != get_nonce_acknowledgment(acknowledgement_nonce):
        print('The server did not acknowledge the handshake message, nonce mismatch')
    else:
        print('Handshake successful')
        add_client_info('pkda', server_handshake_message["public_key"], public_key_ring, server_handshake_message["id"], host, port)

ServerMultiSocket = socket.socket()
# take user input
try:
    ServerMultiSocket.bind((host, self_port))
except socket.error as e:
    print(str(e))
print('Client is accepting connections now....')
ServerMultiSocket.listen(5)
start_new_thread(server_talk_thread, (ClientMultiSocket, id, input_client_name, public_key_ring))
while True:
    Client, address = ServerMultiSocket.accept()
    print('Connected to fellow client: ' + address[0] + ':' + str(address[1]))
    start_new_thread(multi_threaded_fellow_client, (Client, address))
ClientMultiSocket.close()


Self port number:  19769
{'type': 'handshake', 'id': 'public key distribution authority', 'public_key': [3, 121], 'timestamp': 1680410876, 'nonce': 42274, 'acknowledgement_nonce': 5213, 'valid_till': 1680710876}
Handshake successful
Client is accepting connections now....


Exception ignored in thread started by: <function server_talk_thread at 0x00000250BE149E40>
Traceback (most recent call last):
  File "C:\Users\SAATVIK\AppData\Local\Temp\ipykernel_9824\2669586614.py", line 21, in server_talk_thread
ConnectionAbortedError: [WinError 10053] An established connection was aborted by the software in your host machine
