## Dependencies

Three libraries are required:
* ZeroMQ for socket communication
* Sympy for working with prime numbers
* PyCryptoDome for AES implementation

In [48]:
!pip3 install pyzmq sympy pycryptodome

Defaulting to user installation because normal site-packages is not writeable


# Initialization

In [49]:
import random
import base64
import threading

import sympy
import zmq
from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Random import get_random_bytes

In [50]:
# math
BITS_OF_PRIME_GROUP = 512
BITS_GARBLED_TABLES = 16

# socket
LOCAL_PORT = 4080
SERVER_HOST = "localhost"
SERVER_PORT = 4080

# Basics and Primitives

## Primes numbers, Modular arithmetics, Bitwise operations

Basic operations that used to handle prime numbers, modular arithmetics for groups and bitwise operations.

In [51]:
def get_random_prime():
    """
    Generate a random prime number of amount of bits specified number
    in the constant defined above.
    """

    lower_bound = 2**(BITS_OF_PRIME_GROUP - 1)
    upper_bound = 2**BITS_OF_PRIME_GROUP - 1

    prime = sympy.randprime(lower_bound, upper_bound)
    
    return prime

def find_generator(prime):
    """
    Given a prime number find a generator for its group. Also some tests to avoid
    generators for which some attacks are known are performed. Implementation of
    this is inspired from the implementation in PyCryptoDome:

    https://github.com/Legrandin/pycryptodome/blob/master/lib/Crypto/PublicKey/ElGamal.py    
    """

    while 1:
        # generate candidate
        candidate = pow(random.randint(2, prime - 1), 2, prime)

        # avoid g=2 because of Bleichenbacher's attack described # in "Generating
        # ElGamal signatures without knowning the secret key", 1996
        if candidate in (1, 2):
            continue

        # Discard g if it divides p-1 because of the attack described
        # in Note 11.67 (iii) in HAC
        if (prime - 1) % candidate == 0:
            continue

        # g^{-1} must not divide p-1 because of Khadir's attack described in
        # "Conditions of the generator for forging ElGamal signature", 2011
        ginv = inverse(candidate, prime)
        if (prime - 1) % ginv == 0:
            continue

        # if found exit from the cycle
        break

    return candidate

def inverse(num, prime):
    """
    Implementation of the nultiplicative inverse operation in
    modular arithmetic.
    """

    return pow(num, prime - 2, prime)

def bytes_to_int(x):
    """
    Function that converts a sequence of bytes into an int.
    """

    return int.from_bytes(x, byteorder="big")

def int_to_bytes(x, length = None):
    """
    Function that converts an int into a sequence of bytes.
    """

    if length is None:
        length = (x.bit_length() + 7) // 8
    return x.to_bytes(length, byteorder='big')

def bitwise_xor(x, y):
    """
    Implementation of the bitwise XOR given two byte sequences.
    """

    x_int = int.from_bytes(x, byteorder="big")
    y_int = int.from_bytes(y, byteorder="big")

    return (x_int ^ y_int).to_bytes(len(x), byteorder="big")

## Secret Key Encryption

The Secret Key Encryption used exploits AES in CTR mode.

In [52]:
###############################################
# IMPLEMENTATION OF AES KEY ENCRYPTION SCHEME #
###############################################

def aes_generate_key():
    """
    Generate a 256-bit (32-byte) secret key for AES encryption.
    https://github.com/Legrandin/pycryptodome/blob/master/lib/Crypto/Cipher/AES.py
    https://pycryptodome.readthedocs.io/en/latest/src/cipher/aes.html
    """

    # AES-256 requires a 32-byte key
    key = get_random_bytes(32)

    return key

def pad(msg):
    """
    Pad message to be a multiple of AES block size (16 bytes).
    """

    # get how much padding
    padding_length = AES.block_size - len(msg) % AES.block_size

    # apply padding
    return msg + bytes([padding_length] * padding_length)

def unpad(msg):
    """
    Remove padding from decrypted message.
    """

    return msg[:-msg[-1]]

def aes_encrypt(msg, key):
    """
    Encrypts a message with a key using AES-256 in CTR mode.
    """

    # pad message
    msg = pad(msg)

    if len(key) < 32:
        print("KEY TOO SHORT")
        print(f"key: {key}")

    # init AES
    cipher = AES.new(key, AES.MODE_CTR)

    # encrypt message
    encrypted_msg = cipher.encrypt(msg)

    # return concatenation with nonce
    return cipher.nonce + encrypted_msg
    

def aes_decrypt(token, key):
    """
    Decrypts a message with a key using AES-256 in CBC mode.
    """

    # extract nonce and encrypted message
    nonce = token[:AES.block_size // 2]
    encrypted_msg = token[AES.block_size // 2:]

    # init AES
    cipher = AES.new(key, AES.MODE_CTR, nonce=nonce)

    # decrypt
    decrypted_msg = cipher.decrypt(encrypted_msg)

    # unpad
    decrypted_msg = unpad(decrypted_msg)
    return decrypted_msg

In [53]:
#####################################
# TEST OF AES KEY ENCRYPTION SCHEME #
#####################################

# test message
input_msg = b"Hello, World!"

# generate random key
key = aes_generate_key()

# encrypt the test message
token = aes_encrypt(input_msg, key)

# decrypt the test message
output_msg = aes_decrypt(token, key)

# show results
print("Original message: ", input_msg.decode("utf-8"))
print("Encrypted message: ", base64.b64encode(token))
print("Decrypted message: ", output_msg.decode("utf-8"))

Original message:  Hello, World!
Encrypted message:  b'Wd043I+kSoIWSgATm9MDWo293IbMlEJs'
Decrypted message:  Hello, World!


## Verifiable Encryption

Expoits SKE to create the encryption scheme defined in slides 67-68 of the Multi-Party Computation chapter of the course. This "veriable encryption" is an ecryption scheme for which it is easy to test if the ciphertext has been decrypted with the correct key.

Given a vector $x$ of length $n$ we want to encrypt we pick a random noise vector $r$ of the same length and apply to it a pseudo-random function $F(\cdot, k)$ that maps $n$ bits to $2n$ bits. The output $s$ of this function is then XORed with the vector $x$ concatenated with other $n$ zeros. The ciphertext is $c=(r,s)$.

One can verify if a key $k$ is correct knowing $r$ and $s$: first $F(k, r)$ is computed and then XORed with $s$ to perform the decryption. If the last n bits of the output are all zeros the key was correct.

In [54]:
###########################################
# IMPLEMENTATION OF VERIFIABLE ENCRYPTION #
###########################################

def aes_as_doubling_prf(msg, key):
    """
    Use AES as a pseudo-random function that doubles the bits in input.
    """

    if len(key) < 16:
        print("KEY TOO SHORT")
        print(f"key: {key}")

    # init counters
    counter1 = Counter.new(112, prefix=b'\x00\x01', initial_value=1)
    counter2 = Counter.new(112, prefix=b'\x00\x02', initial_value=2)

    # init AES
    cipher1 = AES.new(key, AES.MODE_CTR, counter=counter1)
    cipher2 = AES.new(key, AES.MODE_CTR, counter=counter2)

    # encrypt message
    encrypted_msg1 = cipher1.encrypt(msg)
    encrypted_msg2 = cipher2.encrypt(msg)

    # return concatenation
    return encrypted_msg1 + encrypted_msg2

def verifiable_encrypt(x, key):
    """
    Perform encryption verifiable by adding padding with zeros in such a way that when
    decrypting you can understand if the decryption has been successful or not.

    Nomenclature is the same used in the slides.
    """

    # create random noise
    r = get_random_bytes(len(x))

    # add padding
    x += b'\x00' * len(x)

    # compute PRF
    prf_output = aes_as_doubling_prf(r, key)

    # encrypt message
    s = bitwise_xor(prf_output, x)

    # return concatenation
    return r + s

def verifiable_decrypt(c, key):
    """
    Perform verifiable decryption

    Nomenclature is the same used in the slides.
    """

    # unpack inputs
    r = c[:len(c) // 3]
    s = c[len(c) // 3:]

    # recompute PRF
    prf_output = aes_as_doubling_prf(r, key)

    # decrypt function
    x = bitwise_xor(prf_output, s)


    # check if decryption is correct, if not return None

    n = len(x) // 2
    if x[n:] == b'\x00' * n:
        return x[:n]

    else:
        return None

In [55]:
#################################
# TEST OF VERIFIABLE ENCRYPTION #
#################################

# generate random message
msg = get_random_bytes(16)

# generate two keys
key1 = get_random_bytes(32)
key2 = get_random_bytes(32)

# encrypt with key1
ciphertext = verifiable_encrypt(msg, key1)

# decrypt with both keys
decryption1 = verifiable_decrypt(ciphertext, key1)
decryption2 = verifiable_decrypt(ciphertext, key2)

# show results
print(f"Original message: {msg}")
print(f"Decryption with correct key: {decryption1}")
print(f"Decryption with wrong key: {decryption2}")

Original message: b'y?)d\xba\x8b\xe2B\xbd\xe4v\xc9\xbb\xc9J\x10'
Decryption with correct key: b'y?)d\xba\x8b\xe2B\xbd\xe4v\xc9\xbb\xc9J\x10'
Decryption with wrong key: None


## Double Verifiable Encryption

Implement the encryption needed for the Garbled Circuits's Garbled Tables.

In [56]:
####################################################
# IMPLEMENTATION OF DOUBLE VERIFIABLE ENCRCRYPTION #
####################################################

def double_verifiable_encrypt(msg, key_x, key_y):
    """
    Function for verifiable double encryption that will be needed for creating the
    garbled tables. The order of encryption is first y then x.
    """

    # first perform encryption with y
    mid_ciphertext = verifiable_encrypt(msg, key_y)

    # then perform encryption with x
    return verifiable_encrypt(mid_ciphertext, key_x)

def double_verifiable_decrypt(ciphertext, key_x, key_y):
    """
    Function for verifiable double decryption that will be needed for computing the
    output of the garbled tables. The order of decryption is first x then y.
    """

    # first perform decryption with x
    mid_ciphertext = verifiable_decrypt(ciphertext, key_x)

    if mid_ciphertext is None:
        # if first decryption goes wrong return None
        return None
    
    else:
        # then perform decryption wih y
        return verifiable_decrypt(mid_ciphertext, key_y)

In [57]:
########################################
# TEST OF VERIFIABLE DOUBLE ENCRYPTION #
########################################

# generate random message
msg = get_random_bytes(16)

# generate two keys
key1 = get_random_bytes(32)
key2 = get_random_bytes(32)

# perform double encryption
ciphertext = double_verifiable_encrypt(msg, key1, key2)

# perform double decryption
decryption = double_verifiable_decrypt(ciphertext, key1, key2)

# show results
print(f"Original message: {msg}")
print(f"Decryption with correct key: {decryption}")

Original message: b'\x94M\xceE\xd05J\xb5S7I\x11x\xf0\x1f\x80'
Decryption with correct key: b'\x94M\xceE\xd05J\xb5S7I\x11x\xf0\x1f\x80'


## Public Key Encryption

Implementation of ElGamal encryption scheme as described in the second chapter of the course.

In [58]:
###############################################
# IMPLEMENTATION OF ELGAMAL ENCRYPTION SCHEME #
###############################################

def elgamal_prime_and_generator():
    """
    Generate random prime and genrator for ElGamal encryption scheme.
    The naming scheme follows what has been used in the course slides.
    """

    # generate random prime that defines the group
    p = get_random_prime()

    # find a generator for the group
    g = find_generator(p)

    return p, g

def elgamal_generate_keys(p_g = None):
    """
    Generate public and private keys for the ElGamal encryption scheme.
    The naming scheme follows what has been used in the course slides.
    """

    # get prime and generator
    if p_g is None:
        p, g = elgamal_prime_and_generator()

    else:
        p, g = p_g

    # generate a private key as a random number
    x = random.randint(1, p - 1)

    # calculate the public key as g^x mod p
    h = pow(g, x, p)

    return int_to_bytes(g) + int_to_bytes(p) + int_to_bytes(h), int_to_bytes(x)

def elgamal_encrypt(msg, public_key):
    """
    Encrypt a message with the ElGamal encryption scheme.
    """

    # unpack public key
    g = bytes_to_int(public_key[: len(public_key) // 3])
    p = bytes_to_int(public_key[len(public_key) // 3 : 2 * len(public_key) // 3])
    h = bytes_to_int(public_key[len(public_key) // 3 :])

    # from bytes to integer
    msg = bytes_to_int(msg)

    # generate a random number
    r = random.randint(1, p - 1)

    # calculate the two powers of the ciphertext
    gr = pow(g, r, p)
    hr = pow(h, r, p)

    # calculate the ciphertext
    return int_to_bytes(gr) + int_to_bytes((hr * msg) % p)

def elgamal_decrypt(c, private_key, public_key):
    """
    Decrypt a message with the ElGamal encryption scheme.
    """

    # unpack private key
    x = bytes_to_int(private_key)

    # unpack used part of public key
    p = bytes_to_int(public_key[len(public_key) // 3 : 2 * len(public_key) // 3])

    # unpack ciphertext
    c1 = bytes_to_int(c[: len(c) // 2])
    c2 = bytes_to_int(c[len(c) // 2 :])

    # decrypt message
    m = (c2 * inverse(pow(c1, x, p), p)) % p

    # from integer to bytes
    return int_to_bytes(m)

In [59]:
#####################################
# TEST OF ELGAMAL ENCRYPTION SCHEME #
#####################################

# test message
input_msg = b"Hello, World!"

# generate keys
public_key, private_key = elgamal_generate_keys()

# encrypt message
ciphertext = elgamal_encrypt(input_msg, public_key)

# decrypt message
output_msg = elgamal_decrypt(ciphertext, private_key, public_key)

# show results
print("Original message:", input_msg.decode())
print("Encrypted message:", ciphertext)
print("Decrypted message:", output_msg)

Original message: Hello, World!
Encrypted message: b'a\xcc\x1f\xe2\xbe\x9d\x98\xb3\xcb\x81l\x9a/\xa0\x18\x10\x9e\xb3A\xcfR,\x87$\xd8\xb9\x9a\xb3(-\x1d\x14\xf0@\x83.\xe1#5\xec2s\x0c\x94)xf\xa0\x03L\x0e\x86\x8c\xf4J\xdc\x99\xd0o[l\xb5\xce\xad\xc6\xfb\xe0<Bo\xb5\xba\xc3\x93G\xa9\xfa\xb6\xbcV#j\x01\xc1:e{V\xfdY\xc6\xa04\x05\xa4\xc3E\x16\xce\xc8z\xb9\x8b\xf7\x06\xde?\xeb\x90\x97Q0\xe1R\xb5\x8c\x1bEqM_G\x10\x9b\x92;n}'
Decrypted message: b'Hello, World!'


In [60]:
input_msg = b"Hello, World!"

p_g = elgamal_prime_and_generator()

public_key1, private_key1 = elgamal_generate_keys(p_g)
public_key2, private_key2 = elgamal_generate_keys(p_g)

# encrypt message
ciphertext = elgamal_encrypt(input_msg, public_key1)

# decrypt message
output_msg1 = elgamal_decrypt(ciphertext, private_key1, public_key1)
output_msg2 = elgamal_decrypt(ciphertext, private_key2, public_key2)

print(input_msg)
print(output_msg1)
print(output_msg2)

b'Hello, World!'
b'Hello, World!'
b'\x1b\xaeN\xb2o/\x9f\xf4r\xc0I\x96\x14F\xfb\x7fE\xfd\xfc-\x08\x12\xc1\xd2\xc3@\xbb\xd8\x15\x07\xa9\xacL\x1c\xd8\xd8\xe9\x9a\xf9\xcd\xf9\x05\xdd\x1b\xc1\xe4\x01\xa9\x9e\xd6J+\x0bV\x15\x8e#\xfc\xbd\xddA\xd2;\x1c'


## Socket Communication

Implementation of message exchange between sockets to simulate the communication between two users for Multi-Party Computation.

In [61]:
##########################################
# IMPLEMENTATION OF SOCKET COMMUNICATION #
##########################################

class ReturningThread(threading.Thread):
    """
    Custom Thread class to return the output values of the function.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, verbose=None):
        super().__init__(group, target, name, args, kwargs)
        self._return = None

    def run(self):
        if self._target is not None:
            self._return = self._target(*self._args, **self._kwargs)

    def join(self):
        super().join()
        return self._return

class Socket:
    def __init__(self, socket_type, address, is_server):
        """
        Initialize a ZeroMQ Socket.
        socket_type can be zmq.REQ, zmq.REP, zmq.PUB, zmq.SUB
        """

        # init socket settings
        self.context = zmq.Context.instance()
        self.socket = self.context.socket(socket_type)
        self.address = address

        if is_server:
            self.socket.bind(self.address)
        else:
            self.socket.connect(self.address)
    
    def send(self, msg):
        """
        Send a message through the socket.
        """

        self.socket.send_pyobj(msg)
    
    def receive(self):
        """
        Receive a message from the socket.
        """

        return self.socket.recv_pyobj()
    
    def close(self):
        """
        Close the socket.
        """

        self.socket.close()

def start_sockets():
    server = Socket(socket_type=zmq.PAIR, address=f"tcp://{SERVER_HOST}:{SERVER_PORT}", is_server=True)
    client = Socket(socket_type=zmq.PAIR, address=f"tcp://{SERVER_HOST}:{SERVER_PORT}", is_server=False)

    return server, client

In [62]:
################################
# TEST OF SOCKET COMMUNICATION #
################################

# create a server and a client sockets
server, client = start_sockets()

# send a message from client to server
client.send("Hello, Server!")
print("Client received:", server.receive())

# send a message from server to client
server.send("Hello, Client!")
print("Server received:", client.receive())

# close the sockets
server.close()
client.close()

Client received: Hello, Server!
Server received: Hello, Client!


## Circuit parsing

Function to convert files containing circuits into a format usable for the next code. The circuits are contained inside text files in a format called "Bristol Format", which was designed exactly for Multi-Party Computation. More information about the Bristol Format and some example circuits are available here: https://nigelsmart.github.io/MPC-Circuits/old-circuits.html

The circuits used for examples in this work are taken from: https://tudatalib.ulb.tu-darmstadt.de/handle/tudatalib/3776

In [63]:
#####################################
# IMPLEMENTATION OF CIRCUIT PARSING #
#####################################

def parse_circuit_file(filepath):
    """
    Parse a circuit written inside a text file into a dictionary of informations
    about the circuit itself and an ordered list of gates.
    """

    # open and read file
    with open(filepath, 'r') as file:
        lines = file.readlines()

    # first line contains amount of gates and of wires
    n_gates, n_wires = map(int, lines[0].split())

    # second line contains amount of wires of input1, input2 and output 
    n_input1, n_input2, n_output = map(int, lines[1].split())

    # process gates lines
    gates = []

    for line in lines[2:]:
        parts = line.split()

        if len(parts) == 0: # if empty line
            continue

        elif parts[0] == 1: # if gate with 1 input (INV)
            # get components
            _, _, id_input, id_output, op = parts

            # converts ids to int
            id_input, id_output = map(int, (id_input, id_output))

            # append gate
            gates.append((op, (id_input, id_output)))

        else: # if gate with 2 inputs
            # get components
            _, _, id_input1, id_input2, id_output, op = parts

            # converts ids to int
            id_input1, id_input2, id_output = map(int, (id_input1, id_input2, id_output))

            # append gate
            gates.append((op, (id_input1, id_input2, id_output)))

    # return dictionary of informationa about the circuit
    return {"n_gates": n_gates,
            "n_wires": n_wires,
            "n_input1": n_input1,
            "n_input2": n_input2,
            "n_output": n_output,
            "gates": gates}

In [64]:
##############################
# EXAMPLE OF CIRCUIT PARSING #
##############################

filepath = r"circuits\fullAdder.bristol"

circuit = parse_circuit_file(filepath)
circuit

{'n_gates': 5,
 'n_wires': 8,
 'n_input1': 2,
 'n_input2': 1,
 'n_output': 2,
 'gates': [('XOR', (0, 1, 3)),
  ('XOR', (2, 3, 7)),
  ('AND', (2, 3, 4)),
  ('AND', (0, 1, 5)),
  ('OR', (4, 5, 6))]}

# Core Algorithms

## Oblivious Transfer

### Passive Security

Implementation of a 1-out-of-2 Oblivious Transfer protocol with Passive Security as described in the slides of the course.

In [65]:
##############################################################
# IMPLEMENTATION OF OBLIVIOUS TRANSFER WITH PASSIVE SECURITY #
##############################################################

class ObliviousTransfer_PS:
    """
    Implement the oblivious transfer protocol with passive security.

    In passive security, Bob sends both public keys to Alice, the one he wants to receive
    normally and the other obliviously. Alice sends the ciphertexts of both messages to Bob,
    encrypted with the corresponding public key. Bob decrypts the chosen ciphertext,
    receiving the message he wants.
    """

    def __init__(self, socket):
        # to send and receive messages
        self.socket = socket

    def receiver_side(self, choice):
        """
        Implement Bob's side of the oblivious transfer protocol with passive security.
        """

        # Bob wants message 0
        if choice == 0:
            pk0, sk0 = elgamal_generate_keys()
            pk1, _ = elgamal_generate_keys() # oblivious of secret key
            
        # Bob wants message 1
        else:
            pk0, _ = elgamal_generate_keys() # oblivious of secret key
            pk1, sk1 = elgamal_generate_keys()

        # send public keys to Alice
        self.socket.send((pk0, pk1))

        # receive the ciphertexts
        c0, c1 = self.socket.receive()

        # decrypt the chosen ciphertext
        if choice == 0:
            m = elgamal_decrypt(c0, sk0, pk0)
        else:
            m = elgamal_decrypt(c1, sk1, pk1)

        return m
    
    def sender_side(self, msg0, msg1):
        """
        Implement Alice's side of the oblivious transfer protocol with passive security.
        """

        # get the public keys from Bob
        h0, h1 = self.socket.receive()

        # encrypt the messages
        c0 = elgamal_encrypt(msg0, h0)
        c1 = elgamal_encrypt(msg1, h1)

        # send the ciphertexts to Bob
        self.socket.send((c0, c1))

In [66]:
#######################################################
# EXAMPLE OF OBLIVIOUS TRANSFER WITH PASSIVE SECURITY #
#######################################################

socket_alice, socket_bob = start_sockets()

def run_bob(socket_bob):
    # initialize Bob's side
    ot_bob = ObliviousTransfer_PS(socket_bob)

    # Bob wants message 1
    choice = 1

    # run Bob's side of the protocol
    received_message = ot_bob.receiver_side(choice).decode("utf-8")

    # close the socket
    socket_bob.close()

    return received_message

def run_alice(socket_alice):
    # initialize Alice's side
    ot_alice = ObliviousTransfer_PS(socket_alice)

    # get Alice's messages
    msg0 = b"Hello"
    msg1 = b"World"

    # run Alice's side of the protocol
    ot_alice.sender_side(msg0, msg1)

    # close the socket
    socket_alice.close()

# run Alice and Bob in separate threads
bob_thread = ReturningThread(target=run_bob, args=(socket_bob,))
alice_thread = ReturningThread(target=run_alice, args=(socket_alice,))

bob_thread.start()
alice_thread.start()

output = bob_thread.join()
alice_thread.join()

print(f"Bob received: {output}")

Bob received: World


### Active Security

Implementation of a 1-out-of-2 Oblivious Transfer protocol with active security as described in the slides of the course. This is the protocol that will be used to give to Bob its inputs for the Garbled Circuit.

In [67]:
#############################################################
# IMPLEMENTATION OF OBLIVIOUS TRANSFER WITH ACTIVE SECURITY #
#############################################################

class ObliviousTransfer_AS:
    """
    Implement the oblivious transfer protocol with active security.

    In active security, Alice sends two public keys to Bob, Bob will use the one linked to the message
    he wants to recieve to encrypt his key. Alice will decrypt it using both keys, obtaining two different
    keys. Alice will then encrypt both messages with the two keys and send them to Bob. Bob will be able
    to decrypt only the message linked to the key he chose.
    """

    def __init__(self, socket):
        # to send and receive messages
        self.socket = socket

    def sender_side(self, msg0, msg1):
        """
        Implement Alice's side of the oblivious transfer protocol with active security.
        """

        # generate two key pairs
        p_g = elgamal_prime_and_generator()
        pk0, sk0 = elgamal_generate_keys(p_g)
        pk1, sk1 = elgamal_generate_keys(p_g)

        # send public keys to Bob
        self.socket.send((pk0, pk1))

        # recieve encrypted key from Bob
        c = self.socket.receive()

        # decrypt the key with both private keys
        k0 = elgamal_decrypt(c, sk0, pk0)[:32]
        k1 = elgamal_decrypt(c, sk1, pk1)[:32]

        # encrypt the messages with the keys
        c0 = aes_encrypt(msg0, k0)
        c1 = aes_encrypt(msg1, k1)

        # send the ciphertexts to Bob
        self.socket.send((c0, c1))

    def receiver_side(self, choice):
        """
        Implement Bob's side of the oblivious transfer protocol with active security.
        """

        # generate secret key
        k = aes_generate_key()

        # receive the public keys
        pk0, pk1 = self.socket.receive()

        # encrypt with the chosen public key
        if choice == 0:
            c = elgamal_encrypt(k, pk0)
        
        else:
            c = elgamal_encrypt(k, pk1)

        # send the encrypted key to Alice
        self.socket.send(c)

        # receive the ciphertexts
        c0, c1 = self.socket.receive()

        # decrypt the chosen ciphertext
        if choice == 0:
            m = aes_decrypt(c0, k)

        else:
            m = aes_decrypt(c1, k)

        return m

In [68]:
######################################################
# EXAMPLE OF OBLIVIOUS TRANSFER WITH ACTIVE SECURITY #
######################################################

socket_bob, socket_alice = start_sockets()

def run_bob(socket_bob):
    # initialize Bob's side
    ot_bob = ObliviousTransfer_AS(socket_bob)

    # Bob wants message 1
    choice = 1

    # run Bob's side of the protocol
    received_message = ot_bob.receiver_side(choice).decode("utf-8")

    # close the socket
    socket_bob.close()

    return received_message

def run_alice(socket_alice):
    # initialize Alice's side
    ot_alice = ObliviousTransfer_AS(socket_alice)

    # get Alice's messages
    msg0 = b"Hello"
    msg1 = b"World"

    # run Alice's side of the protocol
    ot_alice.sender_side(msg0, msg1)

    # close the socket
    socket_alice.close()

# run Alice and Bob in separate threads
bob_thread = ReturningThread(target=run_bob, args=(socket_bob,))
alice_thread = ReturningThread(target=run_alice, args=(socket_alice,))

bob_thread.start()
alice_thread.start()

output = bob_thread.join()
alice_thread.join()

print(f"Bob received: {output}")

Bob received: World


## Garbled Circuits

Implementation first of a single Garbled gate then of a whole Garbled circuit as described by the course slides.

In [69]:
############################
# EXAMPLE OF GARBLED GATES #
############################

class GarbledGate:
    def __init__(self, op, wire_keys):
        """
        Takes in input the operation, the wires and their keys and outputs
        a garbled table for the gate.
        """

        # functions
        self.encrypt = verifiable_encrypt
        self.double_encrypt = double_verifiable_encrypt

        # tables
        truth_table = self.generate_truth_table(op)
        key_table = self.generate_key_table(op, truth_table, wire_keys)
        self.garbled_table = self.generate_garbled_table(op, key_table)

    def generate_key_table(self, op, truth_table, wire_keys):
        """
        Take in input the truth table and the wire keeys and produce a table of encryption keys.
        """

        key_table = []

        if op == "INV":
            for entry in truth_table:
                key_table.append((wire_keys[0][entry[0]], wire_keys[1][entry[1]]))

        else:
            for entry in truth_table:
                key_table.append((wire_keys[0][entry[0]], wire_keys[1][entry[1]], wire_keys[2][entry[2]]))

        return key_table

    def generate_garbled_table(self, op, key_table):
        """
        Take in input the key table and return the garbled table.
        """

        garbled_table = []

        if op == "INV":
            for entry in key_table:
                garbled_table.append(self.encrypt(entry[1], entry[0]))

        else:
            for entry in key_table:
                garbled_table.append(self.double_encrypt(entry[2], entry[0], entry[1]))

        random.shuffle(garbled_table)
        
        return garbled_table

    def generate_truth_table(self, op):
        """
        Given operation name return truth table.
        """

        if op == "AND":
            return [(0, 0, 0),
                    (0, 1, 0),
                    (1, 0, 0),
                    (1, 1, 1)]
        
        elif op == "XOR":
            return [(0, 0, 0),
                    (0, 1, 1),
                    (1, 0, 1),
                    (1, 1, 0)]

        elif op == "OR":
            return [(0, 0, 0),
                    (1, 0, 1),
                    (0, 1, 1),
                    (1, 1, 1)]

        elif op == "INV":
            return [(0, 1),
                    (1, 0)]
        
        else:
            print(f"Gate not supported: {op}")

In [70]:
###############################
# EXAMPLE OF GARBLED CIRCUITS #
###############################

class GarbledCircuit:
    def __init__(self, circuit):
        """
        Given a circuit initialize the Garbled circuit
        """

        # setup input data and functions
        self.circuit = circuit
        self.generate_key = get_random_bytes
        self.garbled_table_bytes = BITS_GARBLED_TABLES

        # generate all keys linked to the wires
        self.wire_keys = self.generate_wire_keys(self.circuit["n_wires"])

        # generate all garble gates
        self.garbled_sequence = self.generate_garbled_gates(self.circuit["gates"])

    def generate_wire_keys(self, n_wires):
        """
        Creates for each wire in the circuit a tuple of two keys, one for bit 1 and
        the other for bit 0.
        """

        return [(self.generate_key(self.garbled_table_bytes), self.generate_key(self.garbled_table_bytes)) for _ in range(n_wires)]
    
    def generate_garbled_gates(self, gates):
        """
        Given the list of gates and their information creates a list of garbled gates.
        """

        # initialize garbled gates list
        garbled_sequence = []

        # populate the list
        for gate in gates:
            # get specifics from gate
            op, wires = gate

            # get keys for the required wires
            wire_keys = [self.wire_keys[wire] for wire in wires]

            # create GarbledGate object
            garbled_sequence.append((op, wires, GarbledGate(op, wire_keys).garbled_table))
        
        return garbled_sequence
    
    def get_garbled_circuit(self):
        return {'n_gates': self.circuit["n_gates"],
                'n_wires': self.circuit["n_wires"],
                'n_input1': self.circuit["n_input1"],
                'n_input2': self.circuit["n_input2"],
                'n_output': self.circuit["n_output"],
                'garbled_sequence': self.garbled_sequence}

In [71]:
######################################
# EXAMPLE OF GARBLED GATE DECRYPTION #
######################################

# take example circuit
filepath = r"circuits\fullAdder.bristol"

# parse circuit
circuit = parse_circuit_file(filepath)

# create garbled circuit
garbled_circuit = GarbledCircuit(circuit)

# take first gate (XOR) and show key for output = 0
print(garbled_circuit.wire_keys[3][0])

# take first garbled circuit and show key obtained with input (0,0)
print(double_verifiable_decrypt(garbled_circuit.garbled_sequence[0][2][0], garbled_circuit.wire_keys[0][0], garbled_circuit.wire_keys[1][0]))

# try to ungarble the same entry of the table using wrong bit for first input
print(double_verifiable_decrypt(garbled_circuit.garbled_sequence[0][2][0], garbled_circuit.wire_keys[0][1], garbled_circuit.wire_keys[1][0]))

b"K&\x1a\xd4\x86\x01?'Rrb\x9a\xf2\x16[\xd3"
None
b'\xf6M\x16\xf4\x97\x92\t\x9c\x95\x8c\xd5\xc5Vk\xbc\x10'


## Garbled Circuits Evaluation

In [72]:
##########################################
# EXAMPLE OF GARBLED CIRCUITS EVALUATION #
##########################################

def garbled_gate_evaluation(garbled_table, inputs):
    """
    Given the garbled gate and its inputs returns the output
    of the gate.
    """
    
    if len(garbled_table) == 2: # INV gate
        # INV gate has only one input
        x = inputs

        for entry in garbled_table:
            decryption = verifiable_decrypt(entry, x)

            if decryption is not None:
                return decryption

    else: # other gate
        # all other gates have two inputs
        x, y = inputs

        for entry in garbled_table:
            decryption = double_verifiable_decrypt(entry, x, y)

            if decryption is not None:
                return decryption
            
def garbled_circuit_evaluation(garbled_sequence, mem):
    """
    Given the garbled sequence and the starting input values perform the
    computation of the circuit.
    """

    # execute gates in sequence
    for op, wires, garbled_table in garbled_sequence:

        if op == 'INV': # INV has only one input
            # prepare input
            gate_inputs = mem[wires[0]]

            # perform gate and save result
            mem[wires[1]] = garbled_gate_evaluation(garbled_table, gate_inputs)

        else: # all other operations have two inputs
            # prepare inputs
            gate_inputs = (mem[wires[0]], mem[wires[1]])

            # perform gate and save results
            mem[wires[2]] = garbled_gate_evaluation(garbled_table, gate_inputs)
    
    return mem

def garbled_circuit_local_execution(filepath, input1, input2):

    # parse circuit
    circuit = parse_circuit_file(filepath)

    # get circuit object
    gc = GarbledCircuit(circuit)

    # get circuit dict
    garbled_circuit = gc.get_garbled_circuit()

    # initialize memory with the inputs
    mem = dict()
    n_inputs = 0
    for i in range(garbled_circuit["n_input1"]):
        # add element to memory
        mem[n_inputs] = gc.wire_keys[n_inputs][input1[i]]
        n_inputs += 1

    for i in range(garbled_circuit["n_input2"]):
        # add element to memory
        mem[n_inputs] = gc.wire_keys[n_inputs][input2[i]]
        n_inputs += 1

    # evalute circuit
    mem = garbled_circuit_evaluation(garbled_circuit["garbled_sequence"], mem)

    # sort memory
    mem = sorted(mem.items(), key=lambda item: item[0], reverse=False)
    
    # get outputs from memory
    encrypted_output = [(key, value) for key, value in mem[-garbled_circuit["n_output"]:]]

    # decrypt outputs
    output = []
    for key, value in encrypted_output:

        if gc.wire_keys[key][0] == value:
            output.append(0)
        
        else:
            output.append(1)

    return output

In [73]:
#########################################
# EXAMPLE OF GARBLED CIRCUIT EVALUATION #
#########################################

# example circuit (full adder)
filepath = r"circuits\fullAdder.bristol"

# example inputs
input1 = [0, 1]
input2 = [1]

# execution of garbled circuit
garbled_circuit_local_execution(filepath, input1, input2)

[1, 0]

## Oblivious Transfer of Inputs

In [74]:
def input_OT_alice(socket, key0, key1):
    """
    Implement garbler's side of one iteration of oblivious transfer
    for one input of the garbled circuit.
    """

    # initialie Alice's side
    ot = ObliviousTransfer_AS(socket)

    # run Alice's side of the protocol
    ot.sender_side(key0, key1)

def input_OT_bob(socket, choice):
    """
    Implement garbler's side of one iteration of oblivious transfer
    for one input of the garbled circuit.
    """

    # initialize Bob's side
    ot = ObliviousTransfer_AS(socket)

    # run Bob's side of the protocol
    key = ot.receiver_side(choice)

    return key

## Yao's Protocol Implementation

In [75]:
def yao_alice(socket, filepath, input1):
    """
    Implements Alice's side of Yao's protocol.
    """

    # parse circuit
    circuit = parse_circuit_file(filepath)

    # get circuit object
    gc = GarbledCircuit(circuit)

    # get circuit dict
    garbled_circuit = gc.get_garbled_circuit()
    n_input1 = garbled_circuit["n_input1"]
    n_input2 = garbled_circuit["n_input2"]

    # transfer circuit to Bob
    socket.send(garbled_circuit)

    # encrypt Alice's inputs
    encrypted_input1 = []
    for i in range(n_input1):
        encrypted_input1.append(gc.wire_keys[i][input1[i]])

    # transfer Alice's encrypted inputs to Bob
    socket.send(encrypted_input1)

    # perform oblivious transfer of Bob inputs
    for i in range(n_input1, n_input1 + n_input2):
        key0, key1 = gc.wire_keys[i]

        input_OT_alice(socket, key0, key1)

    # get results from Bob
    encrypted_output = socket.receive()

    # decrypt outputs
    output = []
    for key, value in encrypted_output:

        if gc.wire_keys[key][0] == value:
            output.append(0)
        
        else:
            output.append(1)

    # send decrypted output to Bob
    socket.send(output)

    return output

def yao_bob(socket, input2):
    """
    Implements Bob's side of Yao's protocol.
    """

    # receive garbled circuit from Alice
    garbled_circuit = socket.receive()
    n_input1 = garbled_circuit["n_input1"]
    n_input2 = garbled_circuit["n_input2"]

    # receive Alice's inputs
    encrypted_input1 = socket.receive()

    # perform oblivious transfer to get Bob's inputs
    encrypted_input2 = []
    for i in range(n_input2):
        encrypted_input2.append(input_OT_bob(socket, input2[i]))

    # build starting memory
    mem = dict()
    n_inputs = 0
    for elem in encrypted_input1:
        # add element to memory
        mem[n_inputs] = elem
        n_inputs += 1

    for elem in encrypted_input2:
        # add element to memory
        mem[n_inputs] = elem
        n_inputs += 1

    # evalute circuit
    mem = garbled_circuit_evaluation(garbled_circuit["garbled_sequence"], mem)

    # sort memory
    mem = sorted(mem.items(), key=lambda item: item[0], reverse=False)

    # get outputs from memory
    encrypted_output = [(key, value) for key, value in mem[-garbled_circuit["n_output"]:]]

    # give encrypted output to Alice
    socket.send(encrypted_output)
    
    # recieve decrypted output from Alice
    output = socket.receive()

    return output

In [76]:
def run_garbled_circuit_simulation(filepath, input1, input2):
    # create a server and a client sockets
    socket_bob, socket_alice = start_sockets()

    # run Alice and Bob in separate threads
    bob_thread = ReturningThread(target=yao_bob, args=(socket_bob, input2))
    alice_thread = ReturningThread(target=yao_alice, args=(socket_alice, filepath, input1))

    bob_thread.start()
    alice_thread.start()

    output_bob = bob_thread.join()
    output_alice = alice_thread.join()

    # close sockets
    socket_alice.close()
    socket_bob.close()

    return output_bob, output_alice

filepath = r'circuits\fullAdder.bristol'
input1 = [0, 1]
input2 = [1]

output_bob, output_alice = run_garbled_circuit_simulation(filepath, input1, input2)

print(f"Output of Bob: {output_bob}")
print(f"Output of Alice: {output_alice}")

Output of Bob: [1, 0]
Output of Alice: [1, 0]


# Examples of Complex Circuits

## 64 BIT Adder

In [None]:
def int_to_bits(n):
    """
    Convert Python int to sequence of 64 bits.
    """

    return [int(bit) for bit in format(n & 0xFFFFFFFFFFFFFFFF, '064b')]

def bits_to_int(bits):
    """
    Convert sequence of 64 bits to Python int
    """

    # merge list to sequence
    bit_string = ''.join(map(str, bits))

    # convert to int
    num = int(bit_string, 2)

    # convert to signed
    if bits[0] == 1:
        num -= 1 << 64

    return num

filepath = r"circuits\int_add64_size.bristol"
num1 = 640
num2 = 360
input1 = int_to_bits(num1)
input2 = int_to_bits(num2)

output_bob, output_alice = run_garbled_circuit_simulation(filepath, input1, input2)

print(f"Real output: {num1 + num2}")
print(f"Output of Bob: {bits_to_int(output_bob)}")
print(f"Output of Alice: {bits_to_int(output_alice)}")

Exception in thread Thread-56 (yao_bob):
Traceback (most recent call last):
  File "c:\ProgramData\anaconda3\Lib\threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\1718333950.py", line 16, in run
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\1276862266.py", line 84, in yao_bob
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\2537481725.py", line 52, in garbled_circuit_evaluation
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\2537481725.py", line 26, in garbled_gate_evaluation
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\321325838.py", line 24, in double_verifiable_decrypt
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\2329936175.py", line 64, in verifiable_decrypt
  File "C:\Users\Utente\AppData\Local\Temp\ipykernel_10060\2329936175.py", line 19, in aes_as_doubling_prf
  File "C:\Users\Utente\AppData\Roaming\Python\Python312\site-packages\Crypto\Cipher\AES.p

## Floating Point Adder

In [78]:
# TODO

## AES

In [79]:
# TODO

# From Semi-Honest to Malicious attacker

## Cut and Choose

In [80]:
# TODO

## Against Input Consistency Attack

In [81]:
# TODO

## Against Selective Failure Attack

In [82]:
# TODO