Import the required libraries

In [None]:
import os
from module import MLPModule
from utils import aggregate_weights, save_global_measure, load_client_data, calculate_global_measures
from tqdm import tqdm
import torch
from client import EMClient
from client_selector import ClientSelector
import time

Set the necessary parameters and hyperparameters for the experiment and model

In [None]:
# --- Experiment settings ---------------------------------------------------
seed = 54                  # forwarded to both hyperparameter dicts below
dataset_name = "n-baiot"   # must be a key of DATASET_INPUT_DIMS
selector_type = 'random'   # forwarded to the selector as 'type_name'
num_rounds = 25            # number of federated training rounds
n_mix_distribute = 2       # number of global models (one per mixture distribution)
select_num = 50            # handed to ClientSelector

# Model input width for each supported dataset.
DATASET_INPUT_DIMS = {
    "unsw-nb15": 47,
    "n-baiot": 115,
}

# Fail fast on an unsupported dataset name instead of silently carrying
# input_dim = 0 into the model construction.
if dataset_name not in DATASET_INPUT_DIMS:
    raise ValueError(
        "Unsupported dataset_name %r; expected one of %s"
        % (dataset_name, sorted(DATASET_INPUT_DIMS))
    )
input_dim = DATASET_INPUT_DIMS[dataset_name]

# Use CUDA when available, otherwise the CPU. The MPS backend is left disabled.
if torch.cuda.is_available():
    device = torch.device("cuda")
# elif torch.backends.mps.is_available():
#     device = torch.device("mps:0")
else:
    device = torch.device("cpu")

# Settings passed to every EMClient and to each MLPModule global model.
hyperparameters = {
    'input_dim': input_dim,
    'lr': 0.00001,
    'models_num': n_mix_distribute,
    'hidden_neurons_num': 512,
    'batch_size': 128,
    'seed': seed,
    'device': device
}

# Settings passed to ClientSelector.
# NOTE(review): the exact meaning of the DQN-related entries (epsilon_*, gamma,
# buffer_*, reward_lambda_value, ...) is defined in client_selector — confirm there.
select_hyperparameters = {
    'enable': True,
    'is_train': False,
    'type_name': selector_type,
    'lr': 0.001,
    'global_models_num': n_mix_distribute,
    'pca_n_components': 20,
    'buffer_batch_size': 16,
    'buffer_size': 1000,
    'reward_lambda_value': 64,
    'target_accuracy': 0.99,
    'epsilon_start': 0.8,
    'epsilon_end': 0.2,
    'epsilon_decay': 100,
    'gamma': 0.99,
    'model_path': "./dqn_models/05-12-2023_14-35-37_n-baiot",
    'seed': seed,
    'device': device
}

In [None]:
# Per-client data is read from the "<dataset_name>/split" directory and the
# target device is handed to the loader.
split_dir = os.path.join(dataset_name, "split")
client_data_list = load_client_data(split_dir, device)

Generate clients

In [None]:
def create_clients(client_data, hyperparameter):
    """Build one EMClient per entry of ``client_data``.

    Args:
        client_data: iterable of ``(train_set, test_set, val_set)`` triples,
            one triple per client.
        hyperparameter: settings object passed unchanged to every client as
            its ``hyperparameters`` argument.

    Returns:
        A list of EMClient instances in input order; each client's ``id`` is
        its 0-based position in ``client_data``.
    """
    return [
        EMClient(
            id=client_id,
            hyperparameters=hyperparameter,
            train_set=train_split,
            test_set=test_split,
            val_set=val_split,
        )
        for client_id, (train_split, test_split, val_split) in enumerate(client_data)
    ]
# Instantiate every client from the loaded data with the shared hyperparameters.
clients = create_clients(client_data=client_data_list, hyperparameter=hyperparameters)

Generate client selector

In [None]:
# Build the selector over the full client pool, passing select_num and the
# selector settings through unchanged.
# NOTE(review): presumably select_num is the number of clients sampled per
# round — confirm against ClientSelector.
selector = ClientSelector(clients, select_num, select_hyperparameters)

FedEM training process

In [None]:
# One global model per mixture distribution, all built from the same hyperparameters
global_models = [MLPModule(hyperparameters) for _ in range(n_mix_distribute)]

# Pre-training pass: every client (no selection yet) trains once starting
# from the freshly initialized global models.
client_n_samples_list = {}
model_weights_lists = [[] for _ in range(n_mix_distribute)]

for client in clients:
    client.receive_global_model(global_models)
    uploaded_parameters, _pretrain_measures, n_local_samples = client.local_fit_and_upload_parameters()
    client_n_samples_list[client.id] = n_local_samples

    # File this client's uploaded parameters under the matching model index
    for model_idx, weight_bucket in enumerate(model_weights_lists):
        weight_bucket.append(uploaded_parameters[model_idx])

# Overwrite each global model with the aggregate of the uploads for its index
for model_idx, global_model in enumerate(global_models):
    aggregated_parameters = aggregate_weights(model_weights_lists[model_idx], client_n_samples_list)
    global_model.set_parameters(aggregated_parameters)

# The selector fits its PCA on the aggregated models and the collected client weights
selector.fit_pca(global_models, model_weights_lists)

In [None]:
# Per-round training results (index = round number)
record_global_train_measures = [None] * num_rounds
record_global_val_measures = [None] * num_rounds

# Seconds elapsed since start_time, sampled at the end of every round
train_time_record = [None] * num_rounds

# The context manager closes the progress bar even when a round raises, so an
# interrupted run does not leave a dangling bar in the notebook output.
with tqdm(total=num_rounds) as pbar:
    start_time = time.time()

    # Perform num_rounds rounds of training
    for idx in range(num_rounds):
        client_n_samples_list = {}

        # measure name -> {client id -> value} for this round
        all_client_train_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
        all_client_val_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }

        # Selection must see the PREVIOUS round's client weights, so the
        # weight lists are only reset after sampling.
        selected_client = selector.sample_clients(global_models, model_weights_lists)
        model_weights_lists = [[] for _ in range(n_mix_distribute)]

        for client in selected_client:
            # Client receives new global models
            client.receive_global_model(global_models)

            # The client trains the models locally
            client_models_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
            client_n_samples_list[client.id] = client_n_samples

            for i in range(n_mix_distribute):
                model_weights_lists[i].append(client_models_parameters[i])

            for key in all_client_train_measures:
                all_client_train_measures[key][client.id] = client_measures['train'][key]
                all_client_val_measures[key][client.id] = client_measures['val'][key]

        global_train_measures, global_val_measures = calculate_global_measures(selected_client,
                                                                               all_client_train_measures,
                                                                               all_client_val_measures,
                                                                               display_result=True)

        record_global_train_measures[idx] = global_train_measures
        record_global_val_measures[idx] = global_val_measures

        # Aggregate new global models
        for i in range(n_mix_distribute):
            global_models[i].set_parameters(aggregate_weights(model_weights_lists[i], client_n_samples_list))

        # Feed the round's global validation accuracy back to the selector
        selector.update_dqn(global_models, model_weights_lists, global_val_measures['accuracy'])

        train_time_record[idx] = time.time() - start_time
        pbar.update(1)

Save experiment results

In [None]:
# Label under which this run's measures are saved.
# The previous version compared the (string) name against the integer 100,
# which is never equal, so only the two outcomes below were ever reachable.
if selector_type != "random":
    name = "FEDQ" + str(select_num)
else:
    name = "FedEM_random_" + str(select_num)

save_global_measure(record_global_train_measures, "train_measures_"+dataset_name+".csv", name)
save_global_measure(record_global_val_measures, "val_measures_"+dataset_name+".csv", name)

Save DQN model

In [None]:
# Ask the selector to persist its model, passing the dataset name along.
# NOTE(review): where and under which filename the model is written is decided
# inside ClientSelector.save_model — confirm there.
selector.save_model(dataset_name)

Test the performance of the model on the unseen test set

In [None]:
measure_keys = ('accuracy', 'fpr', 'tpr', 'ber', 'loss')

# measure name -> list of per-client values, in the same order as `clients`
all_client_measures = {key: [] for key in measure_keys}
client_n_samples_list = {}

# Iterate over clients and evaluate on each client's test set
for client in clients:
    client_n_samples = 0
    client_measures_sum = {key: 0 for key in measure_keys}

    # Client receives new global models
    client.receive_global_model(global_models)

    # NOTE(review): evaluate_test_set() receives no model argument, so this
    # runs the same call once per global model and then averages the results.
    # Kept as-is to preserve behavior; confirm whether one call is sufficient.
    for _unused_model in global_models:
        client_measures, current_n_samples = client.evaluate_test_set()
        client_n_samples += current_n_samples
        for key in client_measures:
            client_measures_sum[key] += client_measures[key]

    # Average the measures over all models
    for key in client_measures_sum:
        client_measures_sum[key] /= len(global_models)

    # The sample count is accumulated once per evaluation call above.
    client_n_samples_list[client.id] = client_n_samples

    for key in all_client_measures:
        all_client_measures[key].append(client_measures_sum[key])

# Sample-weighted average across clients. Values are paired with clients by
# position (both follow the order of `clients`) rather than by indexing the
# list with client.id, which is only correct when ids equal list positions.
total_samples = sum(client_n_samples_list.values())
global_measures = {}
for key, per_client_values in all_client_measures.items():
    weighted_sum = sum(
        value * client_n_samples_list[c.id]
        for value, c in zip(per_client_values, clients)
    )
    global_measures[key] = weighted_sum / total_samples

print("Global Test Set Loss:", global_measures['loss'])
print("Global Test Set Accuracy:", global_measures['accuracy'])
print("Global Test Set FPR:", global_measures['fpr'])
print("Global Test Set TPR:", global_measures['tpr'])
print("Global Test Set BER:", global_measures['ber'])