In [1]:
import socket
import tensorflow as tf
import numpy as np
import pandas as pd
import pickle

# secure aggregation
def decrypt_weights(encrypted_weights, encryption_key):
    # This is a placeholder for a real decryption function
    return [w / encryption_key for w in encrypted_weights]

def build_model():
    tf.keras.backend.clear_session()  # Clear any previous session.
    # NN model creation with input parameters
    model = tf.keras.Sequential([
        # First hidden layer with ReLU activation
        # weather 5 + roadcond 3 + dayofweek 7 + month 12 = e.g. 26 features
        tf.keras.layers.Dense(64, activation='relu', input_shape=(13,)),
        # Second hidden layer
        tf.keras.layers.Dense(32, activation='relu'),
        # Output layer with softmax activation
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    # Compile the model with optimizer, loss function, and metrics to evaluate
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # print("model weights")
    # print (model.get_weights())
    weights = model.get_weights()
    # Calculate the total size of the weights
    total_elements = sum([np.prod(w.shape) for w in weights])
    print("Total number of elements in the GM weights:", total_elements)

    print("GM has been created and compiled successfully.")
    return model

def send_model(client_connection, model):
    # Serialize and send the model structure (architecture) to the client
    # print("Serializing the model structure to JSON format.")
    model_json = model.to_json()
    client_connection.sendall(model_json.encode('utf-8'))
    #print("Model structure sent to client. Awaiting acknowledgment...")
    # Wait for acknowledgment from the client before sending weights
    ack = client_connection.recv(1024).decode()
    if ack == "ACK_MODEL":
        # print("Acknowledgment received. Now sending initial model weights.")
        # Serialize and send the initial model weights
        initial_weights_data = pickle.dumps(model.get_weights())#initial_weights = pickle.dumps(model.get_weights())
        #print("***************************************************************")
        #print("initial_weights")
        #print(initial_weights)
        client_connection.sendall(len(initial_weights_data).to_bytes(8, byteorder='big'))
        client_connection.sendall(initial_weights_data)#client_connection.send(initial_weights)
        #print("Initial model weights have been sent to the client.")

def receive_and_update_model(client_connection, model):
    #print("Receiving updated model weights from the client.")
    data_size = int.from_bytes(client_connection.recv(8), byteorder='big')
    #Now receive the actual pickled data based on the size
    updated_weights_data = client_connection.recv(data_size)
    while len(updated_weights_data) < data_size:
        packet = client_connection.recv(data_size - len(updated_weights_data))
        if not packet: break
        updated_weights_data += packet

    updated_weights = pickle.loads(updated_weights_data)
    model.set_weights(updated_weights)
    print("Server model has been updated with the client's weights.")
    return updated_weights

#secure aggregation
def aggregate_weights(encrypted_weight_list):
    # Assuming encrypted_weight_list is a list of encrypted weights from each client
    # and that direct addition is meaningful for the encryption scheme
    aggregated_encrypted_weights = [sum(weights) for weights in zip(*encrypted_weight_list)]
    print("aggregated_encrypted_weights")
    print(aggregated_encrypted_weights)
    return aggregated_encrypted_weights

def encrypt_weights(weights, encryption_key):
    # This is a placeholder for a real encryption function
    return [w * encryption_key for w in weights]

def val_weights(weights_list):
    for i, weights in enumerate(weights_list):
        for layer_weights in weights:
            if np.isnan(layer_weights).any():
                print(f"NaN found in weights from client {i}")
                return False  
    return True  # no nan values
    
# Server setup
host = '127.0.0.1' # change this back later '10.0.0.189'
port = 5300
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((host, port))
server_socket.listen(2)
print(f"Server is listening on port {port}.")


client_counter = 0

# Build a model and send it to the client
print('Building GM')
model = build_model()
print("***************************************************************")
print("Clients please connect to server")

client_weights_list = []  # List to store weights from each client
client_conn = []  # List to store client connections

while client_counter < 3 :#change to devices 3
    conn, address = server_socket.accept()
    client_conn.append((conn, address))
    print("*************************************************")
    print(f"Established connection with client at {address}.")

    # send initial model to client
    send_model(conn, model)
    
    print("*******************")
    print(f"Initial model sent to client at {address}.")
    
    # Increment the client counter 
    client_counter += 1
    print(f"Ready for a new connection if any.")
    print("*******")


for conn, address in client_conn:    
    # Receive updated model weights from the client and update the server model
    print(f'Receiving client {address} local trained model - weights and updating the server model')
    updated_weights = receive_and_update_model(conn, model)
    print(updated_weights)
    client_weights_list.append(updated_weights)

all_weights_valid = val_weights(client_weights_list)
    
if client_counter == 3: #change to devices 3   
    print("""Secure Aggregation from multiple models.""")
    # After collecting weights from all clients, aggregate them
    if all_weights_valid:
        aggregated_encrypted_weights = aggregate_weights(client_weights_list)
        decryption_key = 5 # this has to be match with encryptionkey
        aggregated_weights = decrypt_weights(aggregated_encrypted_weights, decryption_key)

        # Update the global model with the aggregated weights
        model.set_weights(aggregated_weights)
        print("Updated the global model with the aggregated weights")
        print(model)

        # send the global model back to each client
        for conn, address in client_conn:
            print("***************")
            print(f"global model weights sent to client {address}.")
            send_model(conn, model)
            conn.close()
        model.save('Globalmodel_Secure.h5')

server_socket.close()
print("Server program ending. All clients have been served.")



Server is listening on port 5300.
Building GM



  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Total number of elements in the GM weights: 3009
GM has been created and compiled successfully.
***************************************************************
Clients please connect to server
*************************************************
Established connection with client at ('127.0.0.1', 60929).
*******************
Initial model sent to client at ('127.0.0.1', 60929).
Ready for a new connection if any.
*******
*************************************************
Established connection with client at ('127.0.0.1', 60931).
*******************
Initial model sent to client at ('127.0.0.1', 60931).
Ready for a new connection if any.
*******
*************************************************
Established connection with client at ('127.0.0.1', 60932).
*******************
Initial model sent to client at ('127.0.0.1', 60932).
Ready for a new connection if any.
*******
Receiving client ('127.0.0.1', 60929) local trained model - weights and updating the server model
Server model has been update



Server model has been updated with the client's weights.
[array([[-1.8111117e+00, -1.2437429e+00,  1.1376231e+00,  1.4037071e+00,
         1.0025606e+00,  4.6795887e-01, -3.9175093e+00,  1.0688372e+00,
         3.9555006e+00,  5.8305717e+00,  1.8330342e+00,  2.4407227e+00,
         8.8134855e-02, -3.4069622e-01,  2.5695477e+00,  6.5080743e+00,
         6.8872517e-01, -7.8906655e+00,  6.0597095e+00,  2.1229432e+00,
         3.8777828e-01,  1.4769058e+00, -9.5103449e-01, -3.1020341e+00,
         3.9648727e-01,  7.0222390e-01,  3.5338216e+00, -1.7550540e+00,
         2.3916264e+00, -4.4166923e+00,  1.7102113e+00, -3.6432102e+00,
        -4.0883770e+00,  6.0152483e+00,  1.2864788e+00,  1.3492774e+00,
         3.6707499e+00,  5.0827932e-01,  4.9050730e-01,  2.4401755e+00,
        -7.4623758e-01, -4.3728108e+00, -1.2335676e+00,  5.1704731e+00,
        -4.9189550e-01, -1.0171847e-02,  5.0576359e-01,  1.2759831e+00,
         2.2152746e+00,  1.0542576e-01, -3.6646981e+00,  5.6955175e+00,
      