# Dependencies

In [1]:
!pip3 install pyzmq cryptography sympy

Defaulting to user installation because normal site-packages is not writeable


# Initialization

In [1]:
import random
import secrets
import threading

import sympy
import zmq
from cryptography.fernet import Fernet

In [2]:
# --- math ---------------------------------------------------------------
# Bit length of the prime p that defines the ElGamal group Z_p^*.
# NOTE: 64 bits keeps factoring p - 1 (needed to find a generator) cheap, but it
# is far too small for real-world security and limits plaintexts to m < p.
BITS_OF_PRIME_GROUP = 64

# --- socket -------------------------------------------------------------
# ZeroMQ endpoint shared by the demo cells.
SERVER_HOST, SERVER_PORT = "localhost", 4080
LOCAL_PORT = 4080  # NOTE(review): not referenced in the cells shown here — confirm it is needed

# Basics

## Primes and Modular Arithmetic

In [3]:
def get_random_prime(bits=None):
    """
    Return a random prime with exactly `bits` bits, i.e. a prime in [2**(bits-1), 2**bits).

    Parameters:
        bits: bit length of the prime. When omitted, the notebook-wide
              BITS_OF_PRIME_GROUP is used (read at call time, so changing the
              constant and re-running later cells takes effect).

    Returns:
        A prime integer of the requested bit length.

    Raises:
        ValueError: if bits < 2 (there is no 1-bit prime).

    NOTE(review): sympy.randprime is not documented as a cryptographically
    secure source of randomness.
    """

    if bits is None:
        bits = BITS_OF_PRIME_GROUP
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")

    lower_bound = 2 ** (bits - 1)
    # randprime samples from the half-open interval [a, b), so 2**bits as the
    # exclusive upper bound covers every `bits`-bit integer.
    upper_bound = 2 ** bits

    return sympy.randprime(lower_bound, upper_bound)

def find_generator(prime):
    """
    Return a random generator of the multiplicative group Z_prime^*.

    A candidate g generates the group exactly when g^((p-1)/q) != 1 (mod p)
    for every prime divisor q of p - 1. Random candidates are drawn until
    one passes this test for all prime divisors.
    """

    order = prime - 1
    prime_divisors = sympy.primefactors(order)

    def _is_generator(g):
        # short-circuits on the first divisor that proves g is not a generator
        return all(pow(g, order // q, prime) != 1 for q in prime_divisors)

    while True:
        g = random.randint(1, order)
        if _is_generator(g):
            return g

## Secret Key Encryption

The `cryptography` library provides Fernet, an implementation of symmetric (secret-key) authenticated encryption.

In [4]:
def secret_generate_key():
    """
    Return a freshly generated Fernet secret key.
    """

    return Fernet.generate_key()

def secret_encrypt(msg, key):
    """
    Encrypt the bytes `msg` under the Fernet key `key` and return the token.
    https://cryptography.io/en/latest/fernet/
    """

    cipher = Fernet(key)
    return cipher.encrypt(msg)

def secret_decrypt(token, key):
    """
    Decrypt a Fernet `token` with `key` and return the original bytes.
    https://cryptography.io/en/latest/fernet/
    """

    cipher = Fernet(key)
    return cipher.decrypt(token)

In [5]:
test_string = b"Hello, World!"

# generate a random key
key = secret_generate_key()

# encrypt, then decrypt, the test string
token = secret_encrypt(test_string, key)
msg = secret_decrypt(token, key)

# the round trip must reproduce the original plaintext
assert msg == test_string, "Fernet round trip failed"

print("Original message: ", test_string)
print("Encrypted message: ", token)
print("Decrypted message: ", msg)
print("Key: ", key)

Original message:  b'Hello, World!'
Encrypted message:  b'gAAAAABnmnn2qSR56vtZYP18YI50-t6rMt3XLQmvVqohF_XilhxCiwCTxJbsjsIEzxogegOwSauADNgFmkuS42vkkRz4M_z3PQ=='
Decrypted message:  b'Hello, World!'
Key:  b'ysZYET0OwnuekzOMFT9cPfwU-8ntlvwCQzi61CUe7Es='


## Public Key Encryption

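Implementation of the ElGamal encryption scheme over $\mathbb{Z}_p^*$. A message is encoded as a big-endian integer $m$, which must satisfy $m < p$. With the 64-bit primes used here, only plaintexts of at most 7 bytes are guaranteed to fit.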

In [13]:
###############################################
# IMPLEMENTATION OF ELGAMAL ENCRYPTION SCHEME #
###############################################

def elgamal_generate_keys():
    """
    Generate public and private keys for the ElGamal encryption scheme.
    The naming scheme follows what has been used in the course slides.

    Returns:
        ((g, p, h), x) where p is a random prime from get_random_prime(),
        g is a generator of Z_p^*, x is the private exponent and h = g^x mod p.
    """

    # generate random prime that defines the group
    p = get_random_prime()

    # find a generator for the group
    g = find_generator(p)

    # Private key x is drawn uniformly from [1, p - 2] with a CSPRNG (`random` is
    # not suitable for key material). x = p - 1 must be excluded: g^(p-1) = 1,
    # so h would be 1 and every ciphertext would carry the plaintext in the clear.
    x = secrets.randbelow(p - 2) + 1

    # calculate the public key as g^x mod p
    h = pow(g, x, p)

    return (g, p, h), x

def elgamal_encrypt(msg, public_key):
    """
    Encrypt a message with the ElGamal encryption scheme.

    Parameters:
        msg: plaintext bytes, interpreted as a big-endian integer m. m must be
             smaller than p; otherwise the reduction mod p would lose
             information and decryption would return a different message.
             Leading zero bytes are not preserved by this integer encoding.
        public_key: tuple (g, p, h) as returned by elgamal_generate_keys.

    Returns:
        Ciphertext tuple (g^r mod p, h^r * m mod p) for a fresh random r.

    Raises:
        ValueError: if the encoded message is not smaller than p.
    """

    # from bytes to integer
    m = int.from_bytes(msg, byteorder="big")

    # unpack public key
    g, p, h = public_key

    if m >= p:
        raise ValueError(
            f"message too long for this key: it encodes to {m.bit_length()} bits "
            f"but must be smaller than p ({p.bit_length()} bits)"
        )

    # Ephemeral exponent r in [1, p - 2] from a CSPRNG. r = p - 1 would make
    # both g^r and h^r equal 1 and leak the plaintext as the second component.
    r = secrets.randbelow(p - 2) + 1

    # calculate the two powers of the ciphertext
    gr = pow(g, r, p)
    hr = pow(h, r, p)

    # calculate the ciphertext
    return (gr, hr * m % p)

def elgamal_decrypt(c, private_key, public_key):
    """
    Recover the plaintext bytes from an ElGamal ciphertext.

    c is the pair (c1, c2) produced by elgamal_encrypt, private_key is the
    exponent x and public_key is the tuple (g, p, h); only p is needed here.
    The mask c1^x is removed by multiplying with c1^(p-1-x), which equals
    c1^(-x) mod p by Fermat's little theorem (for c1 coprime to p).
    The integer result is turned back into the shortest big-endian byte
    string, so a zero result decodes to b"".
    """

    _, p, _ = public_key
    c1, c2 = c

    unmasked = c2 * pow(c1, p - 1 - private_key, p) % p

    n_bytes = (unmasked.bit_length() + 7) // 8
    return unmasked.to_bytes(n_bytes, byteorder="big")

In [14]:
#####################################
# TEST OF ELGAMAL ENCRYPTION SCHEME #
#####################################

# generate keys
public_key, private_key = elgamal_generate_keys()

# message (must encode to an integer smaller than the prime p)
message = b"Hello"

# encrypt message
ciphertext = elgamal_encrypt(message, public_key)

# decrypt message
retrieved = elgamal_decrypt(ciphertext, private_key, public_key)

# decryption must invert encryption
assert retrieved == message, "ElGamal round trip failed"

print("Original message:", message.decode("utf-8"))
print("Encrypted message:", ciphertext)
print("Decrypted message:", retrieved.decode("utf-8"))

Original message: Hello
Encrypted message: (2274466495891982890, 1068617230306194410)
Decrypted message: Hello


## Socket Communication

In [8]:
class Socket:
    """
    Thin wrapper around a ZeroMQ socket that sends and receives Python objects.

    REP sockets bind to `address` and REQ sockets connect to it on
    construction; for any other socket type call bind() or connect()
    explicitly. Instances can be used as context managers so the socket is
    closed even when an exception occurs.
    """

    def __init__(self, socket_type, address):
        """
        Initialize a ZeroMQ socket.
        socket_type can be zmq.REQ, zmq.REP, zmq.PUB, zmq.SUB
        (only REP is bound and only REQ is connected automatically).
        """

        # init socket settings
        self.context = zmq.Context.instance()
        self.socket = self.context.socket(socket_type)
        self.address = address

        try:
            if socket_type == zmq.REP:
                self.bind()
            elif socket_type == zmq.REQ:
                self.connect()
        except Exception:
            # e.g. "Address already in use" when a cell is re-run: release the
            # socket right away instead of leaking it until garbage collection
            self.close()
            raise

    def bind(self):
        """
        Bind the socket to given address.
        """

        self.socket.bind(self.address)

    def connect(self):
        """
        Connect the socket to a given address.
        """

        self.socket.connect(self.address)

    def send(self, msg):
        """
        Send a Python object through the socket (serialized by send_pyobj).
        """

        self.socket.send_pyobj(msg)

    def receive(self):
        """
        Receive a Python object from the socket.

        SECURITY: recv_pyobj unpickles the payload, and unpickling data from an
        untrusted peer can execute arbitrary code. Only talk to trusted peers.
        """

        return self.socket.recv_pyobj()

    def close(self):
        """
        Close the socket.

        Safe to call more than once and on a partially-initialized instance
        (e.g. from __del__ after a failed __init__).
        """

        sock = getattr(self, "socket", None)
        if sock is not None and not sock.closed:
            sock.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __del__(self):
        """
        Destructor to close the socket.
        """

        self.close()

In [9]:
server = Socket(socket_type=zmq.REP, address=f"tcp://{SERVER_HOST}:{SERVER_PORT}")
client = Socket(socket_type=zmq.REQ, address=f"tcp://{SERVER_HOST}:{SERVER_PORT}")

try:
    # the REQ side sends first; the REP side receives it and then replies
    client.send("Hello, Server!")
    print("Server received:", server.receive())

    server.send("Hello, Client!")
    print("Client received:", client.receive())
finally:
    # always release the port, even if a send/receive fails
    server.close()
    client.close()

Client received: Hello, Server!
Server received: Hello, Client!


# Oblivious Transfer - Passive Security

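Implementation of a 1-out-of-2 Oblivious Transfer (OT) protocol with passive (semi-honest) security. Alice (the sender) holds two messages $m_0, m_1$, and Bob (the receiver) holds a choice bit $b$. Bob sends Alice two ElGamal public keys, and he knows the secret key only for $pk_b$. Alice encrypts $m_0$ under the first key and $m_1$ under the second. Bob can therefore decrypt only $m_b$, and Alice learns nothing about $b$.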

In [10]:
class ObliviousTransfer:
    """
    1-out-of-2 oblivious transfer with passive (semi-honest) security, built on ElGamal.

    Bob (receiver_side) obtains exactly one of Alice's two messages; Alice
    (sender_side) does not learn which one. The caller owns `socket` and is
    responsible for closing it.
    """

    def __init__(self, socket):
        # to send and receive messages
        self.socket = socket

    @staticmethod
    def _oblivious_public_key():
        """
        Sample an ElGamal public key (g, p, h) without ever knowing its secret key.

        h is drawn directly from Z_p^* instead of being computed as g^x, so
        no discrete logarithm of h is known to the sampler. Because g^x for a
        random x is (essentially) uniform over the group as well, Alice cannot
        tell this key apart from a real one. h = 1 is excluded because that
        key would encrypt in the clear.
        """

        p = get_random_prime()
        g = find_generator(p)
        h = secrets.randbelow(p - 2) + 2  # uniform in [2, p - 1]
        return (g, p, h)

    def receiver_side(self, choice):
        """
        Implement Bob's side of the oblivious transfer protocol with passive security.

        Parameters:
            choice: 0 or 1, the index of the message Bob wants.

        Returns:
            The bytes of Alice's message number `choice`.

        Raises:
            ValueError: if choice is not 0 or 1.
        """

        if choice not in (0, 1):
            raise ValueError(f"choice must be 0 or 1, got {choice!r}")

        # A real key pair for the wanted message, and a key whose secret Bob
        # never learns for the other one (a semi-honest Bob must not be able to
        # decrypt both messages).
        real_pk, real_sk = elgamal_generate_keys()
        oblivious_pk = self._oblivious_public_key()

        if choice == 0:
            pk0, pk1 = real_pk, oblivious_pk
        else:
            pk0, pk1 = oblivious_pk, real_pk

        # send public keys to Alice
        self.socket.send((pk0, pk1))

        # receive the ciphertexts
        c0, c1 = self.socket.receive()

        # only the ciphertext under the real key can be decrypted
        chosen = c0 if choice == 0 else c1
        return elgamal_decrypt(chosen, real_sk, real_pk)

    def sender_side(self, msg0, msg1):
        """
        Implement Alice's side of the oblivious transfer protocol with passive security.

        Parameters:
            msg0, msg1: plaintext bytes; each must encode to an integer smaller
                        than the prime p of the key it is encrypted under.
        """

        # get the two public keys (g, p, h) from Bob
        pk0, pk1 = self.socket.receive()

        # encrypt the messages
        c0 = elgamal_encrypt(msg0, pk0)
        c1 = elgamal_encrypt(msg1, pk1)

        # send the ciphertexts to Bob
        self.socket.send((c0, c1))

In [15]:
# Example of Oblivious Transfer with passive security
# Bob's REQ socket may connect before Alice's REP socket binds; ZeroMQ queues the connection.
socket_bob = Socket(address=f"tcp://{SERVER_HOST}:{SERVER_PORT}", socket_type=zmq.REQ)
socket_alice = Socket(address=f"tcp://{SERVER_HOST}:{SERVER_PORT}", socket_type=zmq.REP)

def run_bob(socket_bob, choice=1):
    """
    Run Bob's (receiver's) side of the OT demo and print the received message.

    Parameters:
        socket_bob: the REQ Socket connected to Alice; closed when this returns.
        choice: index (0 or 1) of the message Bob wants; defaults to 1.
    """

    try:
        # initialize Bob's side and run the protocol
        ot_bob = ObliviousTransfer(socket_bob)
        received_message = ot_bob.receiver_side(choice)

        # show the received message
        print("Bob received:", received_message.decode("utf-8"))
    finally:
        # close even if the protocol fails, so the port is not left in use
        socket_bob.close()

def run_alice(socket_alice, msg0=b"Hello", msg1=b"World"):
    """
    Run Alice's (sender's) side of the OT demo.

    Parameters:
        socket_alice: the bound REP Socket; closed when this returns.
        msg0, msg1: Alice's two messages; default to b"Hello" and b"World".
    """

    try:
        # initialize Alice's side and run the protocol
        ot_alice = ObliviousTransfer(socket_alice)
        ot_alice.sender_side(msg0, msg1)
    finally:
        # close even if the protocol fails, so the port is not left in use
        socket_alice.close()

# Alice and Bob block on each other's messages, so each party runs in its own thread.
bob_thread = threading.Thread(target=run_bob, args=(socket_bob,))
alice_thread = threading.Thread(target=run_alice, args=(socket_alice,))

for worker in (bob_thread, alice_thread):
    worker.start()

for worker in (bob_thread, alice_thread):
    worker.join()

Bob received: World
