In [None]:
%pip install zmq web3

In [2]:
import zmq
import pickle
import os

from web3 import Web3
from eth_account import Account

from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

from collections import namedtuple

In [3]:
# Connect to the local JSON-RPC node and bind the deployed contract.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))

# The address file may end with a newline (editors / deploy scripts); strip it
# so the value is a bare 0x-prefixed hex string that web3 accepts.
with open("contract/address", "r") as file:
    address = file.read().strip()

with open("contract/abi", "r") as file:
    abi = file.read()

contract_instance = w3.eth.contract(address=address, abi=abi)

In [4]:
# Ethereum account that signs this server's on-chain submissions.
with open("keys/server", "r") as file:
    # Strip trailing whitespace/newline: the hex decoder behind
    # Account.from_key rejects any non-hex character.
    private_key = file.read().strip()

account = Account.from_key(private_key)
sender_address = account.address

def submitAggregateModelHash(modelHash: str, round: int):
    """Record the aggregated model hash for ``round`` on-chain and wait until it is mined.

    Calls the contract's ``submitGlobalModelHash`` from ``sender_address``,
    signing with the module-level ``private_key``.

    Args:
        modelHash: hash string to store (``run()`` passes the hex keccak of the
            pickled aggregated model).
        round: training round the aggregate belongs to.

    Returns:
        The transaction receipt.

    Raises:
        RuntimeError: if the transaction was mined but reverted (status 0),
            i.e. the hash was NOT recorded on-chain.
    """
    transaction = contract_instance.functions.submitGlobalModelHash(modelHash, round).build_transaction({
        'from': sender_address,
        # Include pending transactions so a submission made while another tx from
        # this account is still unmined does not reuse its nonce.
        'nonce': w3.eth.get_transaction_count(sender_address, 'pending'),
    })

    signed_transaction = w3.eth.account.sign_transaction(transaction, private_key)
    tx_hash = w3.eth.send_raw_transaction(signed_transaction.raw_transaction)
    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

    print(tx_receipt)

    # A reverted tx still produces a receipt; don't let it pass silently, since
    # the aggregate's hash would then be missing from the contract.
    if tx_receipt.get('status') == 0:
        raise RuntimeError(
            f"submitGlobalModelHash reverted for round {round} (tx {tx_hash.hex()})"
        )
    return tx_receipt

In [5]:
def getStartRound():
    """Return the start round as reported by the contract's ``getStartRound`` view."""
    start_round_call = contract_instance.functions.getStartRound()
    return start_round_call.call()

def getCurrentRound():
    """Return the current round as reported by the contract's ``getCurrentRound`` view."""
    current_round_call = contract_instance.functions.getCurrentRound()
    return current_round_call.call()

def getIterations():
    """Return the iteration count as reported by the contract's ``getIterations`` view."""
    iterations_call = contract_instance.functions.getIterations()
    return iterations_call.call()

def isCurrentRoundValid(current_round):
    """Return True if ``current_round`` falls inside the training window.

    The window is ``[getStartRound(), getStartRound() + getIterations())``.

    NOTE(review): the window is inferred from the getter names. An upper bound of
    ``getCurrentRound() + getIterations()`` would be trivially satisfied when
    ``current_round`` itself comes from ``getCurrentRound()`` (as in ``run()``),
    so it is anchored at the start round instead. Confirm against the contract.
    """
    # Fetch the start round once: one RPC instead of two, and both bounds use
    # the same snapshot.
    start_round = getStartRound()
    return start_round <= current_round < start_round + getIterations()

In [None]:
def generate_symmetric_key():
    """Return 32 fresh random bytes from ``os.urandom`` (an AES-256 key for ``encrypt``)."""
    key_length_bytes = 32  # 256-bit key
    key = os.urandom(key_length_bytes)
    return key

In [6]:
def encrypt(message):
    """Hybrid-encrypt ``message``: AES-CFB for the payload, RSA-OAEP for the key.

    A fresh AES key from ``generate_symmetric_key`` and a fresh 16-byte IV encrypt
    ``message``. The AES key is then wrapped with the RSA public key loaded from
    ``keys/public.pem`` using OAEP with MGF1/SHA-256.

    Args:
        message: plaintext bytes.

    Returns:
        A tuple ``(iv + ciphertext, wrapped_key)``. The first 16 bytes of the first
        element are the IV, which is the layout ``decrypt`` expects.

    NOTE(review): CFB is unauthenticated (no MAC), so tampering is not detected
    on decryption. Moving to an AEAD mode such as AES-GCM would also require
    changing the peer that decrypts this data.
    """
    aes_key = generate_symmetric_key()

    # Fresh random IV per message, prepended to the ciphertext.
    iv = os.urandom(16)
    aes_encryptor = Cipher(algorithms.AES(aes_key), modes.CFB(iv), backend=default_backend()).encryptor()
    ciphertext = aes_encryptor.update(message) + aes_encryptor.finalize()
    encrypted_file_data = iv + ciphertext

    with open("keys/public.pem", "rb") as pem_file:
        rsa_public_key = serialization.load_pem_public_key(pem_file.read(), backend=default_backend())

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    wrapped_key = rsa_public_key.encrypt(aes_key, oaep)

    return encrypted_file_data, wrapped_key

In [7]:
def decrypt(encrypted_message, encrypted_symmetric_key):
    """Reverse ``encrypt``: unwrap the AES key with RSA, then AES-CFB-decrypt the payload.

    Args:
        encrypted_message: 16-byte IV followed by the AES-CFB ciphertext.
        encrypted_symmetric_key: AES key wrapped with RSA-OAEP (MGF1/SHA-256).

    Returns:
        The decrypted plaintext bytes.

    The RSA private key is read from ``keys/private.pem`` (unencrypted PEM) on
    every call. CFB is unauthenticated, so a tampered ciphertext decrypts to
    garbage rather than raising.
    """
    with open("keys/private.pem", "rb") as pem_file:
        # Named rsa_private_key so it is not confused with the module-level
        # Ethereum `private_key` string.
        rsa_private_key = serialization.load_pem_private_key(
            pem_file.read(),
            password=None,
            backend=default_backend(),
        )

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    aes_key = rsa_private_key.decrypt(encrypted_symmetric_key, oaep)

    # Layout produced by encrypt(): IV first, ciphertext after.
    iv, ciphertext = encrypted_message[:16], encrypted_message[16:]

    aes_decryptor = Cipher(algorithms.AES(aes_key), modes.CFB(iv), backend=default_backend()).decryptor()
    return aes_decryptor.update(ciphertext) + aes_decryptor.finalize()

In [8]:
# Request payload the server unpickles: `model` is ciphertext that `decrypt`
# opens using the RSA-wrapped AES key in `encrypted_symmetric_key`.
ClientModel = namedtuple(
    'ClientModel',
    'address round model encrypted_symmetric_key',
)

In [None]:
class FederatedServer:
    """ZeroMQ REP server that collects encrypted client model updates and averages them.

    Each request is expected to be a pickled ``ClientModel`` whose ``model`` field
    is ciphertext that ``decrypt`` turns into a pickled model mapping. Once
    ``min_models`` updates have been collected, they are averaged. The keccak hash
    of the pickled aggregate is submitted on-chain for the current round, and the
    pickled aggregate is sent as the reply to the request that completed the
    batch. Every other request is answered with ``b"ACK"``.
    """

    def __init__(self, address="tcp://*:9103", min_models=2):
        """Create a ZeroMQ context and a REP socket bound to ``address``.

        Args:
            address: ZeroMQ endpoint to bind.
            min_models: number of client updates to collect before aggregating
                (default 2).

        Raises:
            ValueError: if ``min_models`` is less than 1.
        """
        if min_models < 1:
            raise ValueError("min_models must be at least 1")
        self.min_models = min_models
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        try:
            self.socket.bind(address)
        except Exception:
            # Release the socket and context instead of leaking them when the
            # endpoint cannot be bound.
            self.close()
            raise

    def close(self):
        """Close the socket, then terminate the ZeroMQ context."""
        self.socket.close()
        self.context.term()

    def aggregate_models(self, model_list):
        """Return the element-wise mean of a list of models.

        Each model is a mapping (for example a state dict). Keys are taken from
        the first model, and every other model must contain them. The values
        must support ``+`` and ``/`` by an int. The input models are not modified.

        Args:
            model_list: non-empty list of model mappings.

        Returns:
            A new mapping of the same type as ``model_list[0]`` holding the means.

        Raises:
            ValueError: if ``model_list`` is empty.
        """
        import copy

        if not model_list:
            raise ValueError("model_list must contain at least one model")

        first = model_list[0]
        count = len(model_list)
        # Shallow copy keeps the first model's mapping type (e.g. OrderedDict)
        # while leaving the caller's object untouched.
        avg_model = copy.copy(first)
        for key in first.keys():
            total = first[key]
            for model in model_list[1:]:
                # Out-of-place add: `+=` would modify the first client's
                # arrays/tensors in place.
                total = total + model[key]
            avg_model[key] = total / count
        return avg_model

    def _decode_client_update(self, message):
        """Unpickle a request and decrypt its model; return the model, or None if unusable."""
        # SECURITY: pickle.loads on bytes received from the network can execute
        # arbitrary code, and the AES-CFB payload is unauthenticated. Only expose
        # this endpoint to trusted peers, or move to a safe serialization format
        # plus authenticated encryption on both ends.
        try:
            client_update = pickle.loads(message)
            if not isinstance(client_update, ClientModel):
                return None
            model_bytes = decrypt(client_update.model, client_update.encrypted_symmetric_key)
            return pickle.loads(model_bytes)
        except Exception as exc:
            # A malformed request must not kill the server mid REP cycle.
            print(f"Rejected malformed client update: {exc!r}")
            return None

    def run(self):
        """Serve requests until interrupted, aggregating every ``min_models`` valid updates.

        A REP socket enforces strict recv/send alternation, so each loop
        iteration receives exactly one request and sends exactly one reply.
        The socket and context are closed on exit.
        """
        print("Server is running...")
        model_updates = []
        try:
            while True:
                # Always receive first: send() on a REP socket with no pending
                # request raises zmq.ZMQError.
                message = self.socket.recv()

                # Read the round after the blocking recv so it reflects when this
                # update arrived, not when the loop started waiting.
                current_round = getCurrentRound()
                if not isCurrentRoundValid(current_round):
                    self.socket.send(b"ACK")
                    continue

                print("Received model from sandbox.")
                raw_model = self._decode_client_update(message)
                if raw_model is None:
                    self.socket.send(b"ACK")
                    continue

                # NOTE(review): updates are not filtered by ClientModel.round, so a
                # batch can mix rounds if the round advances mid-batch.
                model_updates.append(raw_model)
                if len(model_updates) < self.min_models:
                    self.socket.send(b"ACK")
                    continue

                print("Aggregating models...")
                aggregated_model = self.aggregate_models(model_updates)
                aggregated_model_bytes = pickle.dumps(aggregated_model)
                aggregated_model_hash = Web3.keccak(aggregated_model_bytes).hex()

                submitAggregateModelHash(aggregated_model_hash, current_round)
                self.socket.send(aggregated_model_bytes)
                print("Aggregated model sent to TEE.")

                model_updates = []
        except KeyboardInterrupt:
            # Ctrl-C / kernel interrupt is treated as a normal shutdown.
            pass
        finally:
            self.close()

if __name__ == "__main__":
    # Blocks serving requests until interrupted; run() closes the socket and
    # context on the way out.
    server = FederatedServer(address="tcp://*:9103")
    server.run()
