In [None]:
%pip install zmq web3 cryptography mysql-connector-python toml requests

In [1]:
import zmq
import pickle

from web3 import Web3

from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

from collections import namedtuple

In [4]:
def encrypt(message):
    """Encrypt ``message`` directly with the RSA public key stored in ``keys/public.pem``.

    Uses RSA-OAEP with SHA-256 for both the OAEP digest and MGF1. This is the
    same padding ``decrypt`` uses to unwrap symmetric keys.

    Note: this is plain RSA-OAEP, not the hybrid format (RSA-wrapped AES key
    plus AES-CFB payload) that ``decrypt`` expects, so ``decrypt`` cannot
    reverse its output. RSA-OAEP also caps the plaintext length at
    ``key_size_bytes - 2 * 32 - 2`` (190 bytes for a 2048-bit key), so this is
    only suitable for short payloads such as a symmetric key.

    Args:
        message: Plaintext bytes to encrypt.

    Returns:
        The RSA ciphertext as bytes.
    """
    with open("keys/public.pem", "rb") as f:
        public_key = serialization.load_pem_public_key(
            f.read(),
            backend=default_backend()
        )

    # SHA256 must be qualified: only the ``hashes`` module is imported, and the
    # bare name ``SHA256`` is undefined (NameError on every call).
    encrypted_message = public_key.encrypt(
        message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None
        )
    )

    return encrypted_message

In [5]:
def decrypt(encrypted_message, encrypted_symmetric_key):
    """Reverse the client's hybrid encryption and return the plaintext bytes.

    The AES session key arrives wrapped with RSA-OAEP (SHA-256 digest and
    MGF1). It is unwrapped with the private key read from ``keys/private.pem``
    and then used to AES-CFB-decrypt the payload.

    Args:
        encrypted_message: 16-byte IV immediately followed by the AES-CFB ciphertext.
        encrypted_symmetric_key: AES key encrypted under the matching RSA public key.

    Returns:
        The decrypted payload bytes.
    """
    with open("keys/private.pem", "rb") as pem_file:
        pem_bytes = pem_file.read()
    rsa_private_key = serialization.load_pem_private_key(
        pem_bytes, password=None, backend=default_backend()
    )

    # Step 1: unwrap the per-message AES key.
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    session_key = rsa_private_key.decrypt(encrypted_symmetric_key, oaep)

    # Step 2: split off the IV (first 16 bytes) and decrypt the rest with AES-CFB.
    iv, ciphertext = encrypted_message[:16], encrypted_message[16:]
    aes_decryptor = Cipher(
        algorithms.AES(session_key), modes.CFB(iv), backend=default_backend()
    ).decryptor()
    return aes_decryptor.update(ciphertext) + aes_decryptor.finalize()

In [6]:
# Connect to the node's JSON-RPC endpoint and bind the deployed contract
# whose address and ABI were written to the ``contract/`` directory.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))

with open("contract/address", "r") as file:
    # strip(): address files usually end with a newline (e.g. written by
    # ``echo``), and that would fail web3's address validation.
    address = file.read().strip()

with open("contract/abi", "r") as file:
    abi = file.read()

contract_instance = w3.eth.contract(address=address, abi=abi)

In [7]:
def getStartRound():
    """Return the contract's ``getStartRound()`` value (read-only call)."""
    return contract_instance.functions.getStartRound().call()

def getCurrentRound():
    """Return the contract's ``getCurrentRound()`` value (read-only call)."""
    return contract_instance.functions.getCurrentRound().call()

def getIterations():
    """Return the contract's ``getIterations()`` value (read-only call)."""
    return contract_instance.functions.getIterations().call()

def isCurrentRoundValid(current_round):
    """Return whether ``current_round`` falls in the window accepted by the contract.

    The window is the half-open range
    ``[getStartRound(), getCurrentRound() + getIterations())``. Each bound is
    fetched with a separate contract call. If the lower-bound check fails, the
    upper-bound calls are skipped.

    Args:
        current_round: Round number to validate.

    Returns:
        True if the round is inside the window, otherwise False.
    """
    # NOTE(review): the upper bound is anchored on getCurrentRound(), not
    # getStartRound(), so the window moves as rounds advance. Confirm this is intended.
    return getStartRound() <= current_round < getCurrentRound() + getIterations()
    
def getLocalModelHash(address, round_num):
    """Fetch the model hash the contract stores for ``address`` at ``round_num``.

    Delegates to the contract's read-only ``getLocalModelHashAtRound`` function
    and returns whatever it yields unchanged.
    """
    view_call = contract_instance.functions.getLocalModelHashAtRound(address, round_num)
    return view_call.call()
    

In [8]:
# Envelope a client pickles and sends to the sandbox:
#   address                 - client address passed to getLocalModelHash
#   round                   - round number passed to getLocalModelHash
#   model                   - 16-byte IV + AES-CFB ciphertext of the model
#   encrypted_symmetric_key - RSA-OAEP-wrapped AES key for ``model``
ClientModel = namedtuple(
    "ClientModel",
    ["address", "round", "model", "encrypted_symmetric_key"],
)

In [None]:
class Sandbox:
    """ZeroMQ relay that vets client model updates before forwarding them to the server.

    Clients send requests to a REP socket bound at ``client_address``. The
    sandbox talks to the aggregation server through a REQ socket connected to
    ``server_address``. Both socket types are lock-step: each ``recv`` on the
    REP socket must be answered by exactly one ``send`` before the next
    ``recv``. For that reason, every client request gets a reply, whether it
    is accepted or rejected.
    """

    # Reply sent to a client whose update is rejected.
    # NOTE(review): the client-side protocol is not visible here. A pickled
    # ``None`` is assumed to be an acceptable "no global model" sentinel;
    # confirm against the client.
    REJECT_REPLY = pickle.dumps(None)

    def __init__(self, client_address="tcp://*:9102", server_address="tcp://127.0.0.1:9103"):
        self.context = zmq.Context()

        # Client communication: REPLY to client requests.
        self.client_socket = self.context.socket(zmq.REP)
        self.client_socket.bind(client_address)

        # Server communication: REQUEST to the aggregation server.
        self.server_socket = self.context.socket(zmq.REQ)
        self.server_socket.connect(server_address)

    def close(self):
        """Close both sockets and terminate the ZeroMQ context."""
        # linger=0 discards unsent messages so context.term() cannot block
        # forever waiting on an unreachable peer during shutdown.
        self.client_socket.close(linger=0)
        self.server_socket.close(linger=0)
        self.context.term()

    def forward_to_server(self, data):
        """Forward the model received from the client to the server."""
        print("Forwarding model to server...")
        self.server_socket.send(data)
        return self.server_socket.recv()  # Receive response from server

    def forward_to_client(self, data):
        """Forward the aggregated model from the server to the client."""
        print("Forwarding aggregated model to client...")
        self.client_socket.send(data)

    @staticmethod
    def _hashes_match(local_hash, onchain_hash):
        """Compare two keccak digests, ignoring type (bytes/str), case, and any ``0x`` prefix.

        ``HexBytes.hex()`` includes the ``0x`` prefix in some hexbytes versions
        and omits it in others. A bytes32 contract return value also arrives
        as ``bytes`` rather than ``str``. Normalising both sides keeps the
        comparison independent of those details.
        """
        def _normalize(value):
            if isinstance(value, (bytes, bytearray)):
                return bytes(value).hex()
            text = str(value).lower()
            return text[2:] if text.startswith("0x") else text

        return _normalize(local_hash) == _normalize(onchain_hash)

    def _process(self, client_message):
        """Validate one client request and forward it to the server if accepted.

        Returns:
            The server's reply bytes when the decrypted model's keccak hash
            matches the hash recorded on-chain for the client's address and
            round, or ``None`` when the request is rejected.
        """
        # SECURITY: pickle.loads on bytes received from the network can execute
        # arbitrary code. This is only acceptable if every client is fully
        # trusted; consider a data-only envelope format (JSON/msgpack) instead.
        client_data = pickle.loads(client_message)

        if not isinstance(client_data, ClientModel):
            print("Received invalid data from client.")
            return None

        model_bytes = decrypt(client_data.model, client_data.encrypted_symmetric_key)
        model_hash = Web3.keccak(model_bytes).hex()
        model_hash_on_contract = getLocalModelHash(client_data.address, client_data.round)

        if not self._hashes_match(model_hash, model_hash_on_contract):
            print("Model hash does not match the on-chain record; rejecting.")
            return None

        # TODO: perform cosine similarity
        # raw_model = pickle.loads(model_bytes)

        print("Forwarding to server.")
        return self.forward_to_server(client_message)

    def run(self):
        """Serve client requests until interrupted (Ctrl+C), then release all sockets."""
        print("TEE is running...")
        try:
            while True:
                # Receive encrypted model from client
                client_message = self.client_socket.recv()
                print("Received model from client.")

                try:
                    reply = self._process(client_message)
                except Exception as exc:
                    # One malformed or undecryptable request must not take the
                    # relay down. The REP socket still owes this client a reply.
                    # NOTE(review): if the failure happened between the REQ
                    # send and recv, the server socket may need to be recreated.
                    print(f"Failed to process client request: {exc!r}")
                    reply = None

                if reply is None:
                    # Always answer: skipping this send leaves the client
                    # blocked and makes the next recv() violate REP lock-step.
                    self.client_socket.send(self.REJECT_REPLY)
                else:
                    print("Forwarding global model to client.")
                    self.forward_to_client(reply)

        except KeyboardInterrupt:  # handle termination signal (Ctrl+C)
            pass
        finally:
            self.close()

if __name__ == "__main__":
    # Build the relay with its default endpoints; run() blocks until interrupted (Ctrl+C).
    sandbox_processor = Sandbox()
    sandbox_processor.run()
