In [1]:

import socket
import tensorflow as tf
import numpy as np
import pandas as pd
import pickle


def build_model():
    tf.keras.backend.clear_session()  # Clear any previous session.
    # NN model creation with input parameters
    model = tf.keras.Sequential([
        # First hidden layer with ReLU activation
        # weather 5 + roadcond 3 + dayofweek 7 + month 12 = e.g. 26 features
        tf.keras.layers.Dense(64, activation='relu', input_shape=(13,)),
        # Second hidden layer
        tf.keras.layers.Dense(32, activation='relu'),
        # Output layer with softmax activation
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    # Compile the model with optimizer, loss function, and metrics to evaluate
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # print("model weights")
    # print (model.get_weights())
    weights = model.get_weights()
    # Calculate the total size of the weights
    total_elements = sum([np.prod(w.shape) for w in weights])
    print("Total number of elements in the GM weights:", total_elements)

    print("GM has been created and compiled successfully.")
    return model

def send_model(client_connection, model):
    # Serialize and send the model structure (architecture) to the client
    # print("Serializing the model structure to JSON format.")
    model_json = model.to_json()
    client_connection.sendall(model_json.encode('utf-8'))
    #print("Model structure sent to client. Awaiting acknowledgment...")
    # Wait for acknowledgment from the client before sending weights
    ack = client_connection.recv(1024).decode()
    if ack == "ACK_MODEL":
        # print("Acknowledgment received. Now sending initial model weights.")
        # Serialize and send the initial model weights
        initial_weights_data = pickle.dumps(model.get_weights())
        #initial_weights = pickle.dumps(model.get_weights())
        #print("***************************************************************")
        #print("initial_weights")
        #print(initial_weights)
        client_connection.sendall(len(initial_weights_data).to_bytes(8, byteorder='big'))
        client_connection.sendall(initial_weights_data)#client_connection.send(initial_weights)
        #print("Initial model weights have been sent to the client.")

def receive_and_update_model(client_connection, model):
    #print("Receiving updated model weights from the client.")
    data_size = int.from_bytes(client_connection.recv(8), byteorder='big')
    #Now receive the actual pickled data based on the size
    updated_weights_data = client_connection.recv(data_size)
    while len(updated_weights_data) < data_size:
        packet = client_connection.recv(data_size - len(updated_weights_data))
        if not packet: break
        updated_weights_data += packet

    updated_weights = pickle.loads(updated_weights_data)
    model.set_weights(updated_weights)
    print("Server model has been updated with the client's weights.")
    return updated_weights

#The function aims to aggregate (combine) the model weights received from multiple clients
# into a single set of weights. This aggregation is typically done by averaging the weights. 
# The process ensures that the global model learns from all the local datasets without actually having access to them.
def aggregate_weights(weight_list):
    # Check if weight_list is empty or contains None elements
    if not weight_list or None in weight_list:
        raise ValueError("Weight list is empty or contains None elements.")
        
    # Calculate the average weights
    avg_weight = [np.mean(np.array([client_weights[layer] for client_weights in weight_list]), axis=0) 
                  for layer in range(len(weight_list[0]))]
    print("avg_weight")
    print(avg_weight)
    return avg_weight

def val_weights(weights_list):
    for i, weights in enumerate(weights_list):
        for layer_weights in weights:
            if np.isnan(layer_weights).any():
                print(f"NaN found in weights from client {i}")
                return False  
    return True  # no nan values
            
# Server setup
host ='127.0.0.1' 
port = 5300
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((host, port))
server_socket.listen(2)
print(f"Server is listening on port {port}.")

client_counter = 0

# Build a model and send it to the client
print('Building GM')
model = build_model()
print("***************************************************************")
print("Clients please connect to server")

client_weights_list = []  # List to store weights from each client
client_conn = []  # List to store client connections

while client_counter < 3 :#change to devices 3
    conn, address = server_socket.accept()
    client_conn.append((conn, address))
    print("*************************************************")
    print(f"Established connection with client at {address}.")

    # send initial model to client
    send_model(conn, model)
    
    print("*******************")
    print(f"Initial model sent to client at {address}.")
    
    # Increment the client counter 
    client_counter += 1
    print(f"Ready for a new connection if any.")
    print("*******")


for conn, address in client_conn:    
    # Receive updated model weights from the client and update the server model
    print(f'Receiving client {address} local trained model - weights and updating the server model')
    updated_weights = receive_and_update_model(conn, model)
    print(updated_weights)
    client_weights_list.append(updated_weights)

all_weights_valid = val_weights(client_weights_list)
    
if client_counter == 3: #change to devices 3   
    print("""Aggregate weights from multiple models.""")
    # After collecting weights from all clients, aggregate them
    if all_weights_valid:
        aggregated_weights = aggregate_weights(client_weights_list)

        # Update the global model with the aggregated weights
        model.set_weights(aggregated_weights)
        print("Updated the global model with the aggregated weights")
        print(model)

        # send the global model back to each client
        for conn, address in client_conn:
            print("***************")
            print(f"global model weights sent to client {address}.")
            send_model(conn, model)
            conn.close()
        model.save('Globalmodel_FederatedAveraging.h5')

server_socket.close()
print("Server program ending. All clients have been served.")



Server is listening on port 5300.
Building GM

Total number of elements in the GM weights: 3009
GM has been created and compiled successfully.
***************************************************************
Clients please connect to server


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


*************************************************
Established connection with client at ('127.0.0.1', 53762).
*******************
Initial model sent to client at ('127.0.0.1', 53762).
Ready for a new connection if any.
*******
*************************************************
Established connection with client at ('127.0.0.1', 53764).
*******************
Initial model sent to client at ('127.0.0.1', 53764).
Ready for a new connection if any.
*******
*************************************************
Established connection with client at ('127.0.0.1', 53765).
*******************
Initial model sent to client at ('127.0.0.1', 53765).
Ready for a new connection if any.
*******
Receiving client ('127.0.0.1', 53762) local trained model - weights and updating the server model
Server model has been updated with the client's weights.
[array([[-2.59900272e-01,  3.75413835e-01, -9.94405895e-02,
        -1.39454708e-01, -5.84986925e-01,  3.85507569e-02,
         6.80262923e-01, -7.63291195e-02, -1.



Server model has been updated with the client's weights.
[array([[-7.05169439e-01, -1.31258920e-01, -9.94405895e-02,
        -6.27082407e-01, -3.50991338e-01,  1.71255115e-02,
         5.83024323e-01,  3.19090605e-01,  8.65967691e-01,
         1.19633777e-02,  6.07071221e-01,  4.61080045e-01,
         9.02332485e-01, -3.18920642e-01,  3.32198411e-01,
         1.31995523e+00, -8.58501852e-01,  6.89025149e-02,
         6.23924792e-01, -1.24453738e-01, -3.08475614e-01,
         8.78603995e-01,  5.16844206e-02, -1.22600794e-01,
        -2.67618209e-01,  1.03057587e+00, -6.47677243e-01,
         7.12868214e-01,  3.22134376e-01,  4.72961694e-01,
        -7.89819241e-01, -3.16446126e-02, -8.71115923e-01,
        -2.48121455e-01,  3.04158349e-02, -2.56643355e-01,
         3.48857671e-01, -1.10874558e-02, -6.73042536e-01,
         6.95871785e-02, -1.20655894e-01,  1.51917353e-01,
         1.20873079e-02,  3.89152467e-01, -4.97280806e-02,
         5.85169196e-01, -4.91707832e-01, -5.11557162e-01