In [3]:
import warnings
# Silence Python-level warnings for the whole process. This runs before the
# TensorFlow import, but TensorFlow's native (C++) log lines are not Python
# warnings and still print, as the captured cell output below shows.
warnings.filterwarnings("ignore")
import argparse
import concurrent.futures
import contextlib
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle as pkl
import seaborn as sns
import shutil
import spur
import sys
import tensorflow as tf
import time
# Apply seaborn's default plot styling.
# NOTE(review): contextlib, matplotlib, pandas, pickle, shutil and sys (and seaborn
# beyond this call) are not used by the code in this notebook section — confirm
# before removing.
sns.set()

2022-05-26 10:44:51.267163: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/slurm-20.11.0/lib64:
2022-05-26 10:44:51.267211: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [24]:
def create_keras_model(number_of_features=27, seed=0):
    """Build and compile the logistic-regression (LR) model as a Keras network.

    The model is a single sigmoid unit on top of ``number_of_features`` inputs.
    Its kernel is drawn from a Glorot-normal initializer seeded with ``seed`` so
    that runs start from reproducible weights.

    Args:
        number_of_features: input width; defaults to the 27 features used here.
        seed: seed for the kernel initializer.

    Returns:
        A ``tf.keras.Sequential`` compiled with Adam, binary cross-entropy loss
        and an AUC metric.
    """
    number_of_classes = 1  # one sigmoid output => binary classification
    initializer = tf.keras.initializers.GlorotNormal(seed=seed)
    model = tf.keras.Sequential()
    # The seeded initializer has to be handed to the layer; otherwise Dense uses
    # its default (unseeded) kernel initializer and the seed has no effect.
    model.add(tf.keras.layers.Dense(number_of_classes,
                                    activation='sigmoid',
                                    input_dim=number_of_features,
                                    kernel_initializer=initializer))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['AUC'])
    return model

In [22]:
def clear_clients(client, clientServer, usr_client, pwd_client,
                  path='/gpfs/commons/groups/gursoy_lab/aelhussein/DCI_FL/'):
    """Remove everything inside one client's ``model/`` directory over SSH.

    Args:
        client: name of the client's directory under ``path`` (e.g. ``'client_1'``).
        clientServer: hostname of the machine that holds the client's files.
        usr_client: SSH user name on ``clientServer``.
        pwd_client: SSH password for ``usr_client``.
        path: project root on the client host, ending in ``/``; defaults to the
            root ``main`` uses.

    Raises:
        spur.RunProcessError: if the remote command exits with a non-zero status.
    """
    import shlex

    model_dir = f'{path}{client}/model'
    # spur shell-quotes every argv element, so a bare `*` argument would be passed
    # literally instead of being globbed. Run the deletion through `sh -c` and
    # quote only the directory so the attached `/*` still expands. Keep `/*`
    # attached (no space): a detached `*` would wipe the remote working directory.
    command = ['sh', '-c', f'rm -rf {shlex.quote(model_dir)}/*']
    shell = spur.SshShell(hostname=clientServer, username=usr_client, password=pwd_client)
    with shell:  # closes the SSH connection when done
        shell.run(command)

In [20]:
def run_clients(client, clientServer, centralServer, usr_client, pwd_client, usr_server, pwd_server, epochs,
                path='/gpfs/commons/groups/gursoy_lab/aelhussein/DCI_FL/'):
    """Send the current federated model to one client and run its training script.

    Args:
        client: client directory name under ``path`` (e.g. ``'client_1'``).
        clientServer: hostname of the client's machine.
        centralServer: hostname of the central server, passed to the client script.
        usr_client: user name on ``clientServer`` (used for scp and SSH).
        pwd_client: password for ``usr_client``.
        usr_server: central-server user name, passed to the client script.
        pwd_server: central-server password, passed to the client script.
        epochs: number of training epochs, passed to the client script.
        path: project root, ending in ``/``; defaults to the root ``main`` uses.

    Returns:
        list[str]: the client script's stdout split on whitespace. ``main``
        converts every token with ``float``.
    """
    # 1. Copy the model main saves as current_model into the client's model dir.
    source = f'{path}server/model/server_models/current_model'
    destination = f'{usr_client}@{clientServer}:{path}{client}/model'
    # `sshpass -e` reads the password from $SSHPASS, which keeps it out of the
    # local process list (unlike `sshpass -p <password>`).
    spur.LocalShell().run(['sshpass', '-e', 'scp', '-r', source, destination],
                          update_env={'SSHPASS': pwd_client or ''})

    # 2. Run the client's training script remotely. spur needs every argv
    #    element to be a string, hence str() (epochs is a number).
    # NOTE(review): the script is launched with `python` despite its .ipynb
    # extension, and pwd_server is visible in the client host's process list.
    script = f'{path}{client}/clientServer.ipynb'
    argv = ['python', script, client, usr_server, pwd_server, centralServer, epochs]
    remote = spur.SshShell(hostname=clientServer, username=usr_client, password=pwd_client)
    with remote:
        result = remote.run([str(arg) for arg in argv])
    # split() rather than split(' ') so a trailing newline or repeated spaces do
    # not yield tokens that float() cannot parse.
    return result.output.decode('utf-8').split()

In [28]:
def fedAvg(server_1_model, server_2_model, relative_weights):
    """Federated averaging of two models' weights.

    Args:
        server_1_model, server_2_model: objects whose ``get_weights()`` returns a
            list of arrays (e.g. Keras models of the same architecture); the two
            lists must have the same length and matching shapes.
        relative_weights: one weight per model (two values); ``np.average``
            normalises them, so they must not sum to zero.

    Returns:
        A list with one array per weight tensor, holding the element-wise weighted
        mean of the two models' tensors and ready for ``model.set_weights``.

    Raises:
        ValueError: if the models expose different numbers of weight tensors or
            mismatched tensor shapes.
    """
    weights_1 = server_1_model.get_weights()
    weights_2 = server_2_model.get_weights()
    if len(weights_1) != len(weights_2):
        raise ValueError(
            f'models have different numbers of weight tensors: {len(weights_1)} vs {len(weights_2)}')

    averaged = []
    for tensor_1, tensor_2 in zip(weights_1, weights_2):
        # Stack along a new leading axis and average it away. This works for any
        # rank, including 0-d tensors.
        stacked = np.stack([np.asarray(tensor_1), np.asarray(tensor_2)])
        averaged.append(np.average(stacked, axis=0, weights=relative_weights))
    return averaged

In [29]:
def validateResults(server_1_response, server_2_response):
    """Turn the two clients' numeric reports into averaging weights and one score.

    Each response is an indexable sequence of numbers. Index 0 is used as the
    client's weight (presumably its sample count; TODO confirm against the client
    script) and index 3 as its validation metric. Nothing is compared here: the
    caller decides whether the score improved.

    Returns:
        tuple: ``(relative_weights, current_validation)``, where
        ``relative_weights`` is a two-item list of each client's share of the
        index-0 total, and ``current_validation`` is the index-3 values averaged
        with those shares.

    Raises:
        ZeroDivisionError: if both index-0 values are zero.
    """
    count_1 = server_1_response[0]
    count_2 = server_2_response[0]
    total = count_1 + count_2
    share_1 = count_1 / total
    share_2 = count_2 / total

    combined = server_1_response[3] * share_1 + server_2_response[3] * share_2
    return [share_1, share_2], combined

In [None]:
def main():
    """Run federated training of the logistic-regression model across two clients.

    Parses SSH credentials from the command line, clears the clients' model
    directories, and saves a freshly initialised model as both the current and
    the best federated model. It then repeats rounds of:

    1. ``run_clients`` for both clients in parallel, parsing each client's
       whitespace-separated response into floats;
    2. waiting until both clients' trained models exist on the server;
    3. ``validateResults`` to get averaging weights and a combined validation
       score, saving the federated model as ``best_model`` whenever the score
       is the lowest so far;
    4. ``fedAvg`` of the two client models into the federated model, which is
       then saved as ``current_model`` for the next round.

    Training stops once the score has failed to improve for more than
    ``patience`` consecutive rounds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-u0', '--userNameServer', default='aelhussein')
    parser.add_argument('-p0', '--passwordServer')
    parser.add_argument('-u1', '--userNameClient1', default='aelhussein')
    parser.add_argument('-p1', '--passwordClient1')
    parser.add_argument('-u2', '--userNameClient2', default='aelhussein')
    parser.add_argument('-p2', '--passwordClient2')
    # parse_known_args tolerates the extra `-f <connection file>` argument that a
    # Jupyter kernel leaves in sys.argv when this runs inside the notebook.
    args, _ = parser.parse_known_args()

    usr0, pwd0 = args.userNameServer, args.passwordServer
    usr1, pwd1 = args.userNameClient1, args.passwordClient1
    usr2, pwd2 = args.userNameClient2, args.passwordClient2
    client_1 = 'client_1'
    client_2 = 'client_2'
    centralServer = 'pe2cc3-002'
    clientServer_1 = 'pe2cc3'
    clientServer_2 = 'pe2cc3'

    # Hyperparameters
    patience = 10
    epochs = 100
    poll_seconds = 5  # how often to check for the clients' uploaded models

    # Shared project root; every model location below hangs off it.
    path = '/gpfs/commons/groups/gursoy_lab/aelhussein/DCI_FL/'
    server_models_dir = f'{path}server/model/server_models'
    client_model_dirs = [f'{path}server/model/client_models/{client}/model'
                         for client in (client_1, client_2)]
    client_model_files = [os.path.join(model_dir, 'saved_model.pb') for model_dir in client_model_dirs]

    # Delete past client data. Calling result() re-raises any SSH failure here
    # instead of letting the executor silently discard it.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        clear_futures = [
            executor.submit(clear_clients, client_1, clientServer_1, usr1, pwd1),
            executor.submit(clear_clients, client_2, clientServer_2, usr2, pwd2),
        ]
        for future in clear_futures:
            future.result()

    # Model architecture
    federated_model = create_keras_model()
    federated_model.save(f'{server_models_dir}/current_model')
    federated_model.save(f'{server_models_dir}/best_model')

    # Set runtime parameters
    patience_counter = 0
    lowest_validation = float('inf')

    while True:
        # Train on both clients in parallel and parse their numeric responses.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            command_1 = executor.submit(run_clients, client_1, clientServer_1, centralServer, usr1, pwd1,
                                        usr0, pwd0, epochs)
            command_2 = executor.submit(run_clients, client_2, clientServer_2, centralServer, usr2, pwd2,
                                        usr0, pwd0, epochs)
            server_1_response = [float(j) for j in command_1.result()]
            server_2_response = [float(j) for j in command_2.result()]

        # Wait until BOTH clients' models are present on the server.
        # NOTE(review): nothing here deletes these files between rounds, so from the
        # second round on this check passes immediately on the previous round's
        # files. Confirm that the clients finish uploading before run_clients returns.
        while not all(os.path.isfile(model_file) for model_file in client_model_files):
            time.sleep(poll_seconds)

        relative_weights, current_validation = validateResults(server_1_response, server_2_response)

        if current_validation < lowest_validation:
            patience_counter = 0
            # Same location as the initial best_model save above.
            federated_model.save(f'{server_models_dir}/best_model')
            lowest_validation = current_validation
            print(f'Validation Loss: {current_validation}')
        else:
            patience_counter += 1

        # Stop after more than `patience` consecutive rounds without improvement.
        if patience_counter > patience:
            break

        # Federated averaging: fold the clients' trained weights into the global model.
        server_1_model = tf.keras.models.load_model(client_model_dirs[0])
        server_2_model = tf.keras.models.load_model(client_model_dirs[1])
        federated_model.set_weights(fedAvg(server_1_model, server_2_model, relative_weights))
        # Overwrite the initial current_model location so the next round trains
        # from the averaged weights rather than the initial model.
        federated_model.save(f'{server_models_dir}/current_model')

# Script entry point. A Jupyter kernel also executes cells in a module named
# '__main__', so running this cell in the notebook starts training as well.
if __name__ == '__main__':
    main()