Import the necessary dependent libraries

In [1]:
import os
from module import MLPModule
from utils import aggregate_weights, save_global_measure, load_client_data, calculate_global_measures
from tqdm import tqdm
import torch
from client import Client
import time

Set the necessary parameters and hyperparameters for the experiment and model

In [2]:
seed = 54
dataset_name = "n-baiot"
num_rounds = 25
input_dim = 0

if torch.cuda.is_available():
    device = torch.device("cuda")
# elif torch.backends.mps.is_available():
#     device = torch.device("mps:0")
else:
    device = torch.device("cpu")

hyperparameters = {
    'input_dim': input_dim,
    'lr': 0.001,
    'hidden_neurons_num': 512,
    'batch_size': 128,
    'seed': seed,
    'device': device
}

In [3]:
client_data_list = load_client_data(os.path.join(dataset_name, "split"), device)

In [4]:
def process_client_data(data_list, keep_columns):
    processed_data_list = []

    for client_data in data_list:
        X_train, Y_train = client_data[0]
        processed_X_train = X_train[:, keep_columns]
        processed_train = (processed_X_train, Y_train)

        X_val, Y_val = client_data[1]
        processed_X_val = X_val[:, keep_columns]
        processed_val = (processed_X_val, Y_val)

        X_test, Y_test = client_data[2]
        processed_X_test = X_test[:, keep_columns]
        processed_test = (processed_X_test, Y_test)

        processed_client_data = (processed_train, processed_val, processed_test)
        processed_data_list.append(processed_client_data)

    return processed_data_list

In [5]:
columns_to_keep = [0, 2, 4, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 28, 29, 38, 39]
if dataset_name == "n-baiot":
    columns_to_keep = [0, 3, 6, 30, 37, 44, 31, 38, 45, 32, 39, 46, 60, 63, 66, 75, 82, 89, 76, 83, 90]
client_data_list = process_client_data(client_data_list, columns_to_keep)
hyperparameters['input_dim'] = len(columns_to_keep)

Generate clients

In [6]:
def create_clients(client_data, hyperparameter):
    client_list = []
    for i, (train_set, test_set, val_set) in enumerate(client_data):
        c = Client(
            id=i,
            hyperparameters=hyperparameter,
            train_set=train_set,
            test_set=test_set,
            val_set=val_set
        )
        client_list.append(c)
    return client_list
clients = create_clients(client_data_list, hyperparameters)

FedAvg training process

In [7]:
# Initialize model
global_model = MLPModule(hyperparameters)

# Pre-trained
model_weights_list = []
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
    client_n_samples_list[client.id] = client_n_samples
    model_weights_list.append(client_model_parameters)
    
global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))

In [8]:
# Record training results
record_global_train_measures = [None] * num_rounds
record_global_val_measures = [None] * num_rounds

train_time_record = [None] * num_rounds
start_time = time.time()

pbar = tqdm(total=num_rounds)
# Perform num_rounds rounds of training
for idx in range(num_rounds):
    model_weights_list = []
    client_n_samples_list = {}
    all_client_train_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    all_client_val_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    
    for client in clients:
        # Client receives new global model
        client.receive_global_model(global_model)

        # The client trains the model locally
        client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
        client_n_samples_list[client.id] = client_n_samples
        model_weights_list.append(client_model_parameters)
        for key in all_client_train_measures:
            all_client_train_measures[key][client.id] = (client_measures['train'][key])
            all_client_val_measures[key][client.id] = (client_measures['val'][key])

    global_train_measures, global_val_measures = calculate_global_measures(clients,
                                                                           all_client_train_measures,
                                                                           all_client_val_measures,
                                                                           display_result=True)
    record_global_train_measures[idx] = global_train_measures
    record_global_val_measures[idx] = global_val_measures

    # Aggregate new global model
    global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))

    train_time_record[idx] = time.time() - start_time
    pbar.update(1)
    
pbar.reset()

  4%|▍         | 1/25 [00:55<22:10, 55.46s/it]

Global Train Set Loss: 0.35237836485052554
Global Train Set Accuracy: 0.8855949277949189
Global Train Set FPR: 0.5663305512451771
Global Train Set TPR: 0.8884891691654901
Global Train Set BER: 0.33892069103984346
Global Validation Set Loss: 0.14491426150545966
Global Validation Set Accuracy: 0.9835914303582998
Global Validation Set FPR: 0.16968653007012763
Global Validation Set TPR: 0.9743422630046432
Global Validation Set BER: 0.09767213353274229


  8%|▊         | 2/25 [01:44<19:42, 51.42s/it]

Global Train Set Loss: 0.3424606733703211
Global Train Set Accuracy: 0.8921381058838201
Global Train Set FPR: 0.5506329713276363
Global Train Set TPR: 0.886084858639
Global Train Set BER: 0.3322740563443183
Global Validation Set Loss: 0.12495363773821426
Global Validation Set Accuracy: 0.9963657295287512
Global Validation Set FPR: 0.09348543133328198
Global Validation Set TPR: 0.9781195267649561
Global Validation Set BER: 0.05768295228416307


KeyboardInterrupt: 

Save experiment results

In [None]:
save_global_measure(record_global_train_measures, "train_measures_"+dataset_name+".csv", "FedTrust")
save_global_measure(record_global_val_measures, "val_measures_"+dataset_name+".csv", "FedTrust")

Test the performance of the model on the unseen test set

In [None]:
all_client_measures = {
    'accuracy': [],
    'fpr': [],
    'tpr': [],
    'ber': [],
    'loss': []
}
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_measures, client_n_samples = client.evaluate_test_set()
    client_n_samples_list[client.id] = client_n_samples

    for key in all_client_measures:
        all_client_measures[key].append(client_measures[key])

global_measures = {}
for key in all_client_measures:
    global_measures[key] = sum([all_client_measures[key][c.id] * client_n_samples_list[c.id] for c in clients]) / sum(client_n_samples_list.values())

print("Global Test Set Loss:", global_measures['loss'])
print("Global Test Set Accuracy:", global_measures['accuracy'])
print("Global Test Set FPR:", global_measures['fpr'])
print("Global Test Set TPR:", global_measures['tpr'])
print("Global Test Set BER:", global_measures['ber'])