In [None]:
%pip install torch torchvision zmq web3 cryptography mysql-connector-python toml requests

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import zmq
import pickle
import toml
import mysql.connector
import datetime
import requests
import os
import time

from web3 import Web3
from eth_account import Account

from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

from collections import namedtuple

In [3]:
# Load the MySQL connection settings from the [database] table of config.toml.
config = toml.load("config.toml")

# Only host/username/password are read; no port is taken from the config, so the
# connector's default port is used when connecting.
db_host, db_username, db_password = (
    config["database"][setting] for setting in ("host", "username", "password")
)

In [4]:

# One shared MySQL connection for this notebook; insertModels() opens a cursor on it per call.
mydb = mysql.connector.connect(
    host=db_host,
    user=db_username,
    password=db_password,
)

def insertModels(account, model_url, round, is_local_model):
    """Insert one row into `ai_link`.`models` and commit it.

    Args:
        account: client address, stored in the `client_address` column.
        model_url: URL of the uploaded model, stored in `model_url`.
        round: training round number, stored in `round`.
        is_local_model: value stored in the `local_model` column (send_model_weights passes 1).

    The `creation_date` column is set to the current local time, formatted
    as "%Y-%m-%d %H:%M:%S".
    """
    sql = "INSERT INTO `ai_link`.`models` (`client_address`,`model_url`,`round`,`creation_date`,`local_model`) VALUES (%s, %s, %s, %s, %s);"
    created_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    val = (account, model_url, round, created_at, is_local_model)

    mycursor = mydb.cursor()
    try:
        # Parameterized query: the driver escapes every value.
        mycursor.execute(sql, val)
        mydb.commit()
    finally:
        # Always release the cursor, even when execute/commit raises; otherwise a cursor leaks per call.
        mycursor.close()


In [5]:
def uploadToCESS(model, round, url="https://uat-d.cess.network/file", timeout=60):
    """Pickle ``model`` to a local file and upload that file to the CESS gateway.

    Args:
        model: object to persist; send_model_weights passes the encrypted model bytes.
        round: round number, used in the local file name ``models/client1/model_{round}.mod``.
        url: upload endpoint (defaults to the CESS UAT gateway).
        timeout: seconds ``requests`` waits on the connection/response before giving up.
            Without a timeout, a stalled gateway would hang the training loop forever.

    Returns:
        The file URL from the JSON response (``data.url``) on HTTP 200; otherwise prints
        the status and body and returns None.
    """
    file_path = f"models/client1/model_{round}.mod"
    # The staging directory may not exist on a fresh checkout; open(..., "wb") would fail.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, "wb") as f:
        pickle.dump(model, f)

    with open(file_path, "rb") as f:
        files = {"file": f}
        response = requests.post(url, files=files, timeout=timeout)

    if response.status_code != 200:
        print(f"Upload failed! Status code: {response.status_code}")
        print("Error:", response.text)
        return None

    print("Upload successful!")
    response_data = response.json()
    return response_data['data']['url']

In [6]:
def generate_symmetric_key():
    """Return a fresh random 32-byte (256-bit) key for AES, taken from os.urandom."""
    key_length_bytes = 32  # 32 bytes -> AES-256
    fresh_key = os.urandom(key_length_bytes)
    return fresh_key

In [7]:
def encrypt(message):
    """Hybrid-encrypt ``message`` for the holder of the RSA private key matching keys/public.pem.

    A fresh AES key from generate_symmetric_key() encrypts the payload in CFB mode. That AES
    key is then wrapped with the RSA public key read from ``keys/public.pem``, using OAEP with
    SHA-256 for both MGF1 and the hash.

    Args:
        message: plaintext bytes.

    Returns:
        (ciphertext, wrapped_key):
        ``ciphertext`` is the 16-byte IV followed by the AES-CFB output.
        ``wrapped_key`` is the RSA-OAEP-encrypted AES key.

    NOTE(review): CFB gives confidentiality only. There is no MAC/AEAD tag, so ciphertext
    tampering goes undetected. Changing the mode requires a matching change on the decrypting side.
    """
    aes_key = generate_symmetric_key()

    iv = os.urandom(16)  # AES block size; prepended so the receiver can recover it
    aes_encryptor = Cipher(algorithms.AES(aes_key), modes.CFB(iv), backend=default_backend()).encryptor()
    body = aes_encryptor.update(message)
    body += aes_encryptor.finalize()
    ciphertext = iv + body

    with open("keys/public.pem", "rb") as pem_file:
        pem_bytes = pem_file.read()
    rsa_public_key = serialization.load_pem_public_key(pem_bytes, backend=default_backend())

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    wrapped_key = rsa_public_key.encrypt(aes_key, oaep)

    return ciphertext, wrapped_key

In [8]:
def decrypt(encrypted_message, encrypted_symmetric_key):
    """Inverse of ``encrypt``: unwrap the AES key with RSA, then AES-CFB-decrypt the payload.

    The RSA private key is read from ``keys/private.pem`` as an unencrypted PEM (password=None).

    Args:
        encrypted_message: bytes laid out as 16-byte IV || AES-CFB ciphertext.
        encrypted_symmetric_key: the AES key wrapped with RSA-OAEP (SHA-256 for MGF1 and hash).

    Returns:
        The decrypted plaintext bytes.
    """
    with open("keys/private.pem", "rb") as pem_file:
        pem_bytes = pem_file.read()
    rsa_private_key = serialization.load_pem_private_key(pem_bytes, password=None, backend=default_backend())

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    aes_key = rsa_private_key.decrypt(encrypted_symmetric_key, oaep)

    # Split off the IV that encrypt() prepended.
    iv, ciphertext = encrypted_message[:16], encrypted_message[16:]

    aes_decryptor = Cipher(algorithms.AES(aes_key), modes.CFB(iv), backend=default_backend()).decryptor()
    plaintext = aes_decryptor.update(ciphertext)
    plaintext += aes_decryptor.finalize()
    return plaintext

In [9]:
# JSON-RPC endpoint of the local chain node.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))

with open("contract/address", "r") as file:
    # Strip the trailing newline text files usually end with. Then normalize to the EIP-55
    # checksum form, because web3.py refuses non-checksummed contract addresses.
    address = Web3.to_checksum_address(file.read().strip())

with open("contract/abi", "r") as file:
    # The ABI is passed as a JSON string; web3.py parses it.
    abi = file.read()

contract_instance = w3.eth.contract(address=address, abi=abi)

In [10]:
with open("keys/client1", "r") as file:
    # Strip surrounding whitespace/newline; a trailing "\n" is not valid hex and makes
    # Account.from_key (and later transaction signing) fail to decode the key.
    private_key = file.read().strip()

# Local signing account for this client; its address identifies the client on-chain and in the DB.
account = Account.from_key(private_key)
sender_address = account.address

def submitLocalModelHash(modelHash: str, round: int):
    """Call the contract's ``submitLocalModelHash(modelHash, round)`` from this client's account.

    Builds the transaction with the account's current nonce, signs it locally with
    ``private_key``, broadcasts it, and blocks until a receipt is available.

    Args:
        modelHash: hex string of the model hash (send_model_weights passes a keccak digest).
        round: round number the hash belongs to.

    Returns:
        The transaction receipt. It is also printed. A warning is printed if the receipt
        status is not 1 (i.e. the transaction reverted and the hash was not recorded).
    """
    transaction = contract_instance.functions.submitLocalModelHash(modelHash, round).build_transaction({
        'from': sender_address,
        'nonce': w3.eth.get_transaction_count(sender_address),
    })

    signed_transaction = w3.eth.account.sign_transaction(transaction, private_key)
    tx_hash = w3.eth.send_raw_transaction(signed_transaction.raw_transaction)
    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

    print(tx_receipt)
    # A mined-but-reverted transaction still produces a receipt, so surface it explicitly
    # instead of treating every receipt as success.
    if tx_receipt["status"] != 1:
        print(f"WARNING: submitLocalModelHash transaction reverted for round {round}.")
    return tx_receipt

In [11]:
def getStartRound():
    """Return the contract's ``getStartRound()`` value (first round of the valid window)."""
    return contract_instance.functions.getStartRound().call()

def getCurrentRound():
    """Return the contract's ``getCurrentRound()`` value."""
    return contract_instance.functions.getCurrentRound().call()

def getIterations():
    """Return the contract's ``getIterations()`` value (number of rounds in the valid window)."""
    return contract_instance.functions.getIterations().call()

def isCurrentRoundValid(current_round, start_round=None, iterations=None):
    """Return True iff ``start_round <= current_round < start_round + iterations``.

    Args:
        current_round: round number to check.
        start_round: optional cached start round; queried from the contract when None.
        iterations: optional cached window length; queried from the contract when None.

    The start round is fetched once; the previous version issued the getStartRound RPC twice
    per call.
    """
    if start_round is None:
        start_round = getStartRound()
    if iterations is None:
        iterations = getIterations()
    return start_round <= current_round < start_round + iterations

def doesLocalModelHashExist(client, round):
    """Return the contract's ``doesLocalModelHashExist(client, round)`` result for this client/round."""
    return contract_instance.functions.doesLocalModelHashExist(client, round).call()

In [12]:
# Per-round message sent to the TEE: the client's address, the round number, the
# AES-encrypted model bytes, and the RSA-wrapped AES key needed to decrypt them.
ClientModel = namedtuple("ClientModel", "address round model encrypted_symmetric_key")

In [None]:
class SimpleNN(nn.Module):
    """Two-layer MLP: 28*28 inputs -> 128 hidden units (ReLU) -> 10 output logits.

    The attribute names ``fc1``/``fc2`` become the state_dict keys that are sent to and
    loaded from the aggregator, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order fixes parameter-initialisation order under a given seed.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten every 28x28 image into a 784-vector; any leading dims collapse into the batch.
        flat = x.view(-1, 28 * 28)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

def train(model, data_loader, optimizer, criterion, epochs=1):
    """Train ``model`` in place on the client's local data.

    Switches the model to training mode. Then, for ``epochs`` passes over ``data_loader``,
    runs zero_grad -> forward -> ``criterion`` loss -> backward -> ``optimizer.step()``
    on every ``(inputs, labels)`` batch. Returns None.
    """
    model.train()
    for _ in range(epochs):
        for batch_inputs, batch_labels in data_loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()

def send_model_weights(socket, model, round_num):
    """Publish this client's model for ``round_num`` and send it to the TEE.

    Steps:
      1. pickle ``model.state_dict()``;
      2. keccak-hash the plaintext pickle bytes;
      3. hybrid-encrypt the bytes (``encrypt``);
      4. upload the ciphertext to CESS and record the returned URL in the DB;
      5. submit the plaintext hash on-chain;
      6. send a pickled ``ClientModel`` (address, round, ciphertext, wrapped key) over ``socket``.

    If the CESS upload fails (``uploadToCESS`` returns None), the DB insert is skipped rather
    than storing a NULL/None URL. The hash submission and the socket send still happen.

    NOTE: the socket created in this notebook is a ZeroMQ REQ socket, so the caller must
    receive a reply before sending again.
    """
    model_bytes = pickle.dumps(model.state_dict())

    # The hash is taken over the plaintext bytes, before encryption.
    model_hash = Web3.keccak(model_bytes).hex()
    model_encrypt, encrypted_symmetric_key = encrypt(model_bytes)

    model_url = uploadToCESS(model_encrypt, round_num)
    if model_url is None:
        print(f"Skipping DB record for round {round_num}: CESS upload returned no URL.")
    else:
        insertModels(sender_address, model_url, round_num, 1)

    submitLocalModelHash(model_hash, round_num)

    client1 = ClientModel(sender_address, round_num, model_encrypt, encrypted_symmetric_key)
    socket.send(pickle.dumps(client1))

def receive_aggregated_model(socket, model):
    """Receive one reply from the TEE; if it carries weights, load them into ``model``.

    The reply is expected to be either the literal ``b"ACK"`` or a pickled state_dict.

    Args:
        socket: object with a ``recv()`` method returning bytes (a ZeroMQ socket here).
        model: ``nn.Module`` to update in place.

    Returns:
        True if aggregated weights were loaded into ``model``. False for an ACK or for a
        reply that could not be unpickled.

    Raises:
        Whatever ``model.load_state_dict`` raises when the unpickled object does not match
        the model (e.g. missing keys or a non-mapping payload).
    """
    message = socket.recv()

    # Handle the acknowledgment first, so it never goes through the unpickler.
    if message == b"ACK":
        print("Acknowledgment received from server.")
        return False

    try:
        # SECURITY: pickle.loads can execute arbitrary code embedded in the payload. This is
        # only acceptable if the TEE endpoint is fully trusted; prefer a data-only format.
        new_weights = pickle.loads(message)
    except (pickle.UnpicklingError, EOFError, ValueError, IndexError, AttributeError, ImportError):
        # pickle.loads is documented to raise several exception types for malformed input,
        # not only UnpicklingError (e.g. EOFError for an empty reply).
        print("Unexpected response from server.")
        return False

    model.load_state_dict(new_weights)
    print("Received aggregated model from server.")
    return True

if __name__ == "__main__":
    # Local MNIST data and model, normalized to roughly [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    model = SimpleNN()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Seconds to wait before re-polling the contract when there is nothing to do.
    poll_interval_seconds = 3

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect("tcp://127.0.0.1:9102")  # Connect to TEE

    try:
        while True:
            current_round = getCurrentRound()
            # Train only when the round is inside the contract's window AND this client has
            # not submitted for it yet. Short-circuiting skips the hash lookup for invalid rounds.
            if isCurrentRoundValid(current_round) and not doesLocalModelHashExist(sender_address, current_round):
                print(f"Training round {current_round}")
                train(model, train_loader, optimizer, criterion, epochs=1)

                print("Sending model to TEE...")
                send_model_weights(socket, model, current_round)

                print("Waiting for aggregated model from TEE...")
                # receive_aggregated_model prints what it actually received (weights, ACK, or junk).
                receive_aggregated_model(socket, model)
            else:
                # Also sleep when the round was already submitted; otherwise the loop would
                # re-query the RPC node with no pause until the round advances.
                time.sleep(poll_interval_seconds)
    finally:
        # Release ZeroMQ resources when the loop exits (e.g. kernel interrupt / exception).
        socket.close(linger=0)
        context.term()
