Import the required libraries

In [1]:
import os
from module import MLPModule
from utils import aggregate_weights, save_global_measure, load_client_data, calculate_global_measures
from tqdm import tqdm
import torch
from client import EMClient
from client_selector import ClientSelector

Set the necessary parameters and hyperparameters for the experiment and model

In [2]:
# ---- Experiment-level settings -------------------------------------------
seed = 54
dataset_name = "n-baiot"
selector_type = 'dqn'
num_rounds = 25          # number of federated rounds run by the training cell
n_mix_distribute = 2     # number of global models (one per mixture component)
select_num = 30          # passed to ClientSelector as the selection size

# Input feature count for each supported dataset. An unknown dataset name
# fails loudly here instead of silently building a model with input_dim == 0.
dataset_input_dims = {
    "unsw-nb15": 47,
    "n-baiot": 115,
}
if dataset_name not in dataset_input_dims:
    raise ValueError(
        "Unsupported dataset_name %r; expected one of %s"
        % (dataset_name, sorted(dataset_input_dims))
    )
input_dim = dataset_input_dims[dataset_name]

# Prefer CUDA when available, otherwise fall back to CPU.
# (An Apple MPS branch existed here but was commented out, so MPS is never selected.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings handed to every EMClient and to each MLPModule global model.
hyperparameters = {
    'input_dim': input_dim,
    'lr': 0.005,
    'models_num': n_mix_distribute,
    'hidden_neurons_num': 512,
    'batch_size': 128,
    'seed': seed,
    'device': device
}

# Settings handed to ClientSelector.
# NOTE(review): 'model_path' points at a checkpoint whose name contains
# "unsw-nb15" while dataset_name is "n-baiot" — confirm this reuse is intended.
# NOTE(review): 'is_train' is False, yet later cells call selector.update_dqn()
# and selector.save_model() — confirm how the selector treats this flag.
select_hyperparameters = {
    'enable': True,
    'is_train': False,
    'type_name': selector_type,
    'lr': 0.001,
    'global_models_num': n_mix_distribute,
    'pca_n_components': 15,
    'buffer_batch_size': 16,
    'buffer_size': 1000,
    'reward_lambda_value': 64,
    'target_accuracy': 0.99,
    'epsilon_start': 0.8,
    'epsilon_end': 0.2,
    'epsilon_decay': 100,
    'gamma': 0.99,
    'model_path': "./dqn_models/17-10-2023_16-04-52_unsw-nb15",
    'seed': seed,
    'device': device
}

In [3]:
# Load the per-client data from the "<dataset_name>/split" directory; `device`
# is forwarded to the loader. Each entry is later unpacked by create_clients
# as a (train_set, test_set, val_set) triple.
client_data_list = load_client_data(os.path.join(dataset_name, "split"), device)

Generate clients

In [4]:
def create_clients(client_data, hyperparameter):
    """Build one EMClient per entry of ``client_data``.

    Args:
        client_data: iterable of ``(train_set, test_set, val_set)`` triples,
            one per client.
        hyperparameter: settings object passed unchanged to every client as
            ``hyperparameters``.

    Returns:
        list of EMClient instances in input order; each client's ``id`` is
        its zero-based position in ``client_data``.
    """
    return [
        EMClient(
            id=position,
            hyperparameters=hyperparameter,
            train_set=train_split,
            test_set=test_split,
            val_set=val_split,
        )
        for position, (train_split, test_split, val_split) in enumerate(client_data)
    ]
# Instantiate one client per loaded data entry; all clients share the same hyperparameters dict.
clients = create_clients(client_data_list, hyperparameters)

Generate client selector

In [5]:
# Build the client selector over the full client pool, using select_num and
# select_hyperparameters from the configuration cell.
selector = ClientSelector(clients, select_num, select_hyperparameters)

DQN selection of clients
Successfully loaded model


FedEM training process

In [6]:
# One global model per mixture component
global_models = [MLPModule(hyperparameters) for _ in range(n_mix_distribute)]

# Warm-up pass: every client (no selection yet) trains once from the freshly
# initialised global models and uploads its parameters.
# model_weights_lists[k] collects the uploads for global model k;
# client_n_samples_list maps client.id -> sample count reported by that client.
model_weights_lists = [[] for _ in range(n_mix_distribute)]
client_n_samples_list = {}

for client in clients:
    client.receive_global_model(global_models)
    client_models_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
    client_n_samples_list[client.id] = client_n_samples

    for model_idx, uploaded_weights in enumerate(model_weights_lists):
        uploaded_weights.append(client_models_parameters[model_idx])

# Replace each global model's parameters with the aggregate of the uploads for it
for global_model, uploaded_weights in zip(global_models, model_weights_lists):
    global_model.set_parameters(aggregate_weights(uploaded_weights, client_n_samples_list))

# Fit the selector's PCA on the aggregated global models and the warm-up uploads
selector.fit_pca(global_models, model_weights_lists)

In [7]:
# Record training results: one aggregated-measures entry per round
record_global_train_measures = [None] * num_rounds
record_global_val_measures = [None] * num_rounds

# Metric names read from each client's 'train' and 'val' measures
measure_names = ('accuracy', 'fpr', 'tpr', 'ber', 'loss')

# Perform num_rounds rounds of training.
# The context manager closes the progress bar even if a round raises, so an
# interrupted run does not leave a dangling bar behind in the notebook.
with tqdm(total=num_rounds) as pbar:
    for idx in range(num_rounds):
        # Per-round bookkeeping, all keyed by client.id
        client_n_samples_list = {}
        all_client_train_measures = {name: {} for name in measure_names}
        all_client_val_measures = {name: {} for name in measure_names}

        # Selection is based on the current global models and the uploads of
        # the previous step (the warm-up pass on the first round); the upload
        # buffers are cleared only after the selection has been made.
        selected_client = selector.sample_clients(global_models, model_weights_lists)
        model_weights_lists = [[] for _ in range(n_mix_distribute)]

        for client in selected_client:
            # Client receives new global models
            client.receive_global_model(global_models)

            # The client trains the models locally
            client_models_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
            client_n_samples_list[client.id] = client_n_samples

            for i in range(n_mix_distribute):
                model_weights_lists[i].append(client_models_parameters[i])

            for key in all_client_train_measures:
                all_client_train_measures[key][client.id] = client_measures['train'][key]
                all_client_val_measures[key][client.id] = client_measures['val'][key]

        # Combine the selected clients' measures into global train/val measures
        # (display_result=True also prints them, see the recorded output).
        global_train_measures, global_val_measures = calculate_global_measures(selected_client,
                                                                               all_client_train_measures,
                                                                               all_client_val_measures,
                                                                               display_result=True)
        record_global_train_measures[idx] = global_train_measures
        record_global_val_measures[idx] = global_val_measures

        # Aggregate new global models
        for i in range(n_mix_distribute):
            global_models[i].set_parameters(aggregate_weights(model_weights_lists[i], client_n_samples_list))

        # Feed this round's outcome (global validation accuracy) back to the selector.
        # NOTE(review): select_hyperparameters['is_train'] is False — confirm whether
        # update_dqn is expected to be a no-op in that mode.
        selector.update_dqn(global_models, model_weights_lists, global_val_measures['accuracy'])
        pbar.update(1)

  4%|▍         | 1/25 [00:21<08:45, 21.90s/it]

Global Train Set Loss: 0.3564854662808674
Global Train Set Accuracy: 0.8074460676269654
Global Train Set FPR: 0.5821795848608596
Global Train Set TPR: 0.8908942604195896
Global Train Set BER: 0.3456426622206351
Global Validation Set Loss: 0.25452951711527405
Global Validation Set Accuracy: 0.8398555383362363
Global Validation Set FPR: 0.3709891613759337
Global Validation Set TPR: 0.9312680529362287
Global Validation Set BER: 0.18652722088651916


  8%|▊         | 2/25 [00:42<08:10, 21.31s/it]

Global Train Set Loss: 0.22129732094064142
Global Train Set Accuracy: 0.9428741362823874
Global Train Set FPR: 0.34944328064585195
Global Train Set TPR: 0.8691362780749348
Global Train Set BER: 0.24015350128545848
Global Validation Set Loss: 0.09967576938431169
Global Validation Set Accuracy: 0.9952607477672912
Global Validation Set FPR: 0.05776326187657201
Global Validation Set TPR: 0.9643562250510924
Global Validation Set BER: 0.030036851746073158


 12%|█▏        | 3/25 [01:03<07:46, 21.23s/it]

Global Train Set Loss: 0.21526188774339824
Global Train Set Accuracy: 0.9453413645112726
Global Train Set FPR: 0.3435514733614559
Global Train Set TPR: 0.8694139087791487
Global Train Set BER: 0.23706878229115352
Global Validation Set Loss: 0.08625092949801491
Global Validation Set Accuracy: 0.9950362082489543
Global Validation Set FPR: 0.03706381221428087
Global Validation Set TPR: 0.9645711706154017
Global Validation Set BER: 0.019579654132772958


 16%|█▌        | 4/25 [01:24<07:21, 21.05s/it]

Global Train Set Loss: 0.21419352962889812
Global Train Set Accuracy: 0.9463027063085808
Global Train Set FPR: 0.34201624263722347
Global Train Set TPR: 0.8695523892510562
Global Train Set BER: 0.23623192669308357
Global Validation Set Loss: 0.08510445876467422
Global Validation Set Accuracy: 0.9953384374577837
Global Validation Set FPR: 0.022551467907389246
Global Validation Set TPR: 0.964681256980356
Global Validation Set BER: 0.012268438796849917


 20%|██        | 5/25 [01:46<07:03, 21.16s/it]

Global Train Set Loss: 0.21338968541188008
Global Train Set Accuracy: 0.9452770512702543
Global Train Set FPR: 0.343894463206482
Global Train Set TPR: 0.869626107550008
Global Train Set BER: 0.23713417782823684
Global Validation Set Loss: 0.08607530484643702
Global Validation Set Accuracy: 0.9931124351554086
Global Validation Set FPR: 0.0268382217937738
Global Validation Set TPR: 0.9625821371859342
Global Validation Set BER: 0.015461375637253106


 24%|██▍       | 6/25 [02:07<06:41, 21.12s/it]

Global Train Set Loss: 0.211812629881529
Global Train Set Accuracy: 0.9454562331488591
Global Train Set FPR: 0.3774119127775973
Global Train Set TPR: 0.8803133881590128
Global Train Set BER: 0.2485492623092923
Global Validation Set Loss: 0.08765403590602307
Global Validation Set Accuracy: 0.992054491606967
Global Validation Set FPR: 0.02314706568511696
Global Validation Set TPR: 0.9615759570394214
Global Validation Set BER: 0.014118887656181123


 28%|██▊       | 7/25 [02:27<06:16, 20.93s/it]

Global Train Set Loss: 0.21220826326170586
Global Train Set Accuracy: 0.9454555391103746
Global Train Set FPR: 0.34468691658259465
Global Train Set TPR: 0.8697516294176723
Global Train Set BER: 0.2374676435824611
Global Validation Set Loss: 0.09129583252449727
Global Validation Set Accuracy: 0.9916447980716279
Global Validation Set FPR: 0.027462958083538477
Global Validation Set TPR: 0.961502921193128
Global Validation Set BER: 0.016313351778538612


 32%|███▏      | 8/25 [02:48<05:55, 20.92s/it]

Global Train Set Loss: 0.21213437914909405
Global Train Set Accuracy: 0.9455511797863927
Global Train Set FPR: 0.3444551431453817
Global Train Set TPR: 0.8698024963017458
Global Train Set BER: 0.23732632342181795
Global Validation Set Loss: 0.09746663376773784
Global Validation Set Accuracy: 0.9905874118567098
Global Validation Set FPR: 0.023489329858328224
Global Validation Set TPR: 0.9604007574520939
Global Validation Set BER: 0.01487761953645048


 36%|███▌      | 9/25 [03:09<05:34, 20.90s/it]

Global Train Set Loss: 0.2115440629212113
Global Train Set Accuracy: 0.9465040821919454
Global Train Set FPR: 0.342535058819511
Global Train Set TPR: 0.8698082066002261
Global Train Set BER: 0.2363634261096424
Global Validation Set Loss: 0.10909144649256113
Global Validation Set Accuracy: 0.9906096212036061
Global Validation Set FPR: 0.024521912473203795
Global Validation Set TPR: 0.9604116492239092
Global Validation Set BER: 0.01538846495798059


 40%|████      | 10/25 [03:30<05:14, 20.99s/it]

Global Train Set Loss: 0.21136740006568897
Global Train Set Accuracy: 0.9466621899238191
Global Train Set FPR: 0.3423334857442228
Global Train Set TPR: 0.8698613635386262
Global Train Set BER: 0.23623606110279816
Global Validation Set Loss: 0.112258648178918
Global Validation Set Accuracy: 0.9915900110302014
Global Validation Set FPR: 0.021804388048280523
Global Validation Set TPR: 0.9604049316865697
Global Validation Set BER: 0.014033061514188748


 44%|████▍     | 11/25 [03:52<04:55, 21.14s/it]

Global Train Set Loss: 0.21062813081705092
Global Train Set Accuracy: 0.9468783147932351
Global Train Set FPR: 0.3418186332048563
Global Train Set TPR: 0.8698556050312564
Global Train Set BER: 0.23598151408679985
Global Validation Set Loss: 0.11906940933893963
Global Validation Set Accuracy: 0.990681668933867
Global Validation Set FPR: 0.017259881565647204
Global Validation Set TPR: 0.9604705216495744
Global Validation Set BER: 0.011728013291369662


 48%|████▊     | 12/25 [04:13<04:34, 21.15s/it]

Global Train Set Loss: 0.20969129186088048
Global Train Set Accuracy: 0.9467123621416182
Global Train Set FPR: 0.3417104054759205
Global Train Set TPR: 0.8695854513391426
Global Train Set BER: 0.23606247706838884
Global Validation Set Loss: 0.12265733682574903
Global Validation Set Accuracy: 0.9916650654622845
Global Validation Set FPR: 0.02015017243861104
Global Validation Set TPR: 0.9604817188625656
Global Validation Set BER: 0.013167560121355993


 52%|█████▏    | 13/25 [04:34<04:13, 21.08s/it]

Global Train Set Loss: 0.20990297418191742
Global Train Set Accuracy: 0.9467131365054942
Global Train Set FPR: 0.34155705146657145
Global Train Set TPR: 0.869704824640775
Global Train Set BER: 0.23592611341289813
Global Validation Set Loss: 0.1311061877286283
Global Validation Set Accuracy: 0.9909458760535708
Global Validation Set FPR: 0.017170755901008805
Global Validation Set TPR: 0.9604999678135484
Global Validation Set BER: 0.01166872737706351


 56%|█████▌    | 14/25 [04:54<03:50, 20.97s/it]

Global Train Set Loss: 0.21039174673067768
Global Train Set Accuracy: 0.9472541499990944
Global Train Set FPR: 0.3378880886921797
Global Train Set TPR: 0.8695126021130756
Global Train Set BER: 0.23418774328955194
Global Validation Set Loss: 0.1442340853250174
Global Validation Set Accuracy: 0.9883982523273128
Global Validation Set FPR: 0.020836249527244274
Global Validation Set TPR: 0.9602831586694582
Global Validation Set BER: 0.013609878762226324


 60%|██████    | 15/25 [05:13<03:23, 20.35s/it]

Global Train Set Loss: 0.21087426690854066
Global Train Set Accuracy: 0.9471610230180981
Global Train Set FPR: 0.33792747418165203
Global Train Set TPR: 0.8696216412198093
Global Train Set BER: 0.23415291648092126
Global Validation Set Loss: 0.15065405593258355
Global Validation Set Accuracy: 0.9870306746651789
Global Validation Set FPR: 0.020478209204196386
Global Validation Set TPR: 0.9591672839882854
Global Validation Set BER: 0.013988795941288783


 64%|██████▍   | 16/25 [05:34<03:04, 20.49s/it]

Global Train Set Loss: 0.21167469197361377
Global Train Set Accuracy: 0.9473111409302221
Global Train Set FPR: 0.33723050181526754
Global Train Set TPR: 0.8694147876712133
Global Train Set BER: 0.23390785707202708
Global Validation Set Loss: 0.1534388425353501
Global Validation Set Accuracy: 0.9871902411502717
Global Validation Set FPR: 0.040962161968008844
Global Validation Set TPR: 0.959177385870326
Global Validation Set BER: 0.02422572138217476


 68%|██████▊   | 17/25 [05:55<02:43, 20.47s/it]

Global Train Set Loss: 0.21289327572589398
Global Train Set Accuracy: 0.9470965659462834
Global Train Set FPR: 0.3374370188345567
Global Train Set TPR: 0.8694388528045527
Global Train Set BER: 0.233999083015002
Global Validation Set Loss: 0.1561425908800935
Global Validation Set Accuracy: 0.9859203563370365
Global Validation Set FPR: 0.025972708220721293
Global Validation Set TPR: 0.9591759970177728
Global Validation Set BER: 0.016731688934807677


 72%|███████▏  | 18/25 [06:16<02:24, 20.70s/it]

Global Train Set Loss: 0.21304215690812853
Global Train Set Accuracy: 0.9470265662389136
Global Train Set FPR: 0.33747552566874905
Global Train Set TPR: 0.8694694804158452
Global Train Set BER: 0.2340030226264519
Global Validation Set Loss: 0.1565021090404249
Global Validation Set Accuracy: 0.985838719835437
Global Validation Set FPR: 0.02523036561198499
Global Validation Set TPR: 0.9591411417598565
Global Validation Set BER: 0.016377945259397435


 76%|███████▌  | 19/25 [06:37<02:04, 20.80s/it]

Global Train Set Loss: 0.213470105129567
Global Train Set Accuracy: 0.9467748530757412
Global Train Set FPR: 0.33798619934529045
Global Train Set TPR: 0.8695619964238853
Global Train Set BER: 0.23421210146070254
Global Validation Set Loss: 0.15745141655336856
Global Validation Set Accuracy: 0.9855345609524611
Global Validation Set FPR: 0.026457480456519536
Global Validation Set TPR: 0.9591734746164996
Global Validation Set BER: 0.01697533625334322


 80%|████████  | 20/25 [06:57<01:42, 20.57s/it]

Global Train Set Loss: 0.21479440406936906
Global Train Set Accuracy: 0.9464244626975804
Global Train Set FPR: 0.33854548654336103
Global Train Set TPR: 0.8698517998474558
Global Train Set BER: 0.2343468433479527
Global Validation Set Loss: 0.16020947899568092
Global Validation Set Accuracy: 0.9832384329162319
Global Validation Set FPR: 0.03001486276785884
Global Validation Set TPR: 0.9591549118568091
Global Validation Set BER: 0.018763308788858162


 84%|████████▍ | 21/25 [07:18<01:23, 20.84s/it]

Global Train Set Loss: 0.21428868550842906
Global Train Set Accuracy: 0.9464509678320611
Global Train Set FPR: 0.3385968653285344
Global Train Set TPR: 0.8699416079680494
Global Train Set BER: 0.23432762868024254
Global Validation Set Loss: 0.15946946443083201
Global Validation Set Accuracy: 0.9836389416309438
Global Validation Set FPR: 0.04999525034209147
Global Validation Set TPR: 0.9591733493873732
Global Validation Set BER: 0.028744283810692453


 88%|████████▊ | 22/25 [07:40<01:02, 20.96s/it]

Global Train Set Loss: 0.21588448749315198
Global Train Set Accuracy: 0.9461198705462833
Global Train Set FPR: 0.33930963196491354
Global Train Set TPR: 0.8703451156370685
Global Train Set BER: 0.23448225816392243
Global Validation Set Loss: 0.16027811082807983
Global Validation Set Accuracy: 0.982906178785262
Global Validation Set FPR: 0.03142908091908945
Global Validation Set TPR: 0.9592288337781255
Global Validation Set BER: 0.01943345690381527


 92%|█████████▏| 23/25 [08:00<00:41, 20.81s/it]

Global Train Set Loss: 0.21560443451910913
Global Train Set Accuracy: 0.9464728508122583
Global Train Set FPR: 0.3388939024790989
Global Train Set TPR: 0.8704639327196433
Global Train Set BER: 0.23421498487972767
Global Validation Set Loss: 0.15882652595539376
Global Validation Set Accuracy: 0.9833311383997937
Global Validation Set FPR: 0.03107081687610602
Global Validation Set TPR: 0.9592018157448625
Global Validation Set BER: 0.019267833898955075


 96%|█████████▌| 24/25 [08:20<00:20, 20.71s/it]

Global Train Set Loss: 0.21494612249685738
Global Train Set Accuracy: 0.9464854550972207
Global Train Set FPR: 0.33879513100336217
Global Train Set TPR: 0.870159209107618
Global Train Set BER: 0.23431796094787202
Global Validation Set Loss: 0.1602703723943003
Global Validation Set Accuracy: 0.9858141197690318
Global Validation Set FPR: 0.02641452271092049
Global Validation Set TPR: 0.9591596476185734
Global Validation Set BER: 0.01696077087950692


100%|██████████| 25/25 [08:42<00:00, 20.87s/it]

Global Train Set Loss: 0.21650740987837333
Global Train Set Accuracy: 0.9451772991465828
Global Train Set FPR: 0.3414562242732902
Global Train Set TPR: 0.8705391109380488
Global Train Set BER: 0.23545855666762067
Global Validation Set Loss: 0.15998834145939908
Global Validation Set Accuracy: 0.9829222690265278
Global Validation Set FPR: 0.04704352234344312
Global Validation Set TPR: 0.9590679939481014
Global Validation Set BER: 0.0273210975310042


Save experiment results

In [8]:
# Choose the results sub-folder suffix:
#   select_num == 100              -> no suffix
#   otherwise, selector != random  -> "_dqn_<select_num>"
#   otherwise                      -> "_random_<select_num>"
if select_num == 100:
    suffix_name = ""
elif selector_type != "random":
    suffix_name = "_dqn_"+str(select_num)
else:
    suffix_name = "_random_"+str(select_num)

# Write the per-round train and validation measures under "FedEM<suffix>"
for recorded_measures, file_prefix in (
    (record_global_train_measures, "train_measures_"),
    (record_global_val_measures, "val_measures_"),
):
    save_global_measure(recorded_measures, file_prefix+dataset_name+".csv", "FedEM"+suffix_name)

Saved measures to Experimental_results/FedEM_dqn_30/train_measures_n-baiot.csv
Saved measures to Experimental_results/FedEM_dqn_30/val_measures_n-baiot.csv


Save DQN model

In [9]:
# Save the selector's model, passing the dataset name along.
# NOTE(review): select_hyperparameters['is_train'] is False in the config cell —
# confirm that saving the model is intended in this mode.
selector.save_model(dataset_name)

Test the performance of the model on the unseen test set

In [10]:
# Metrics accumulated for each client on its test set
measure_keys = ('accuracy', 'fpr', 'tpr', 'ber', 'loss')

# Per-client results are keyed by client.id (the same layout the training cell
# uses), so the weighted average below does not rely on client ids matching
# list positions.
all_client_measures = {key: {} for key in measure_keys}
client_n_samples_list = {}

for client in clients:
    client_n_samples = 0
    client_measures_sum = {key: 0 for key in measure_keys}

    # Client receives the final global models
    client.receive_global_model(global_models)

    # NOTE(review): the loop variable is not passed to evaluate_test_set(), so
    # each iteration issues the same call; the results are summed and then
    # averaged over len(global_models). Confirm whether a single evaluation
    # per client would be sufficient.
    for _global_model in global_models:
        client_measures, current_n_samples = client.evaluate_test_set()
        client_n_samples += current_n_samples
        for key in client_measures:
            client_measures_sum[key] += client_measures[key]

    # Average the measures over all evaluations
    for key in client_measures_sum:
        client_measures_sum[key] /= len(global_models)

    # The sample count is summed across the repeated evaluations; the common
    # scale factor cancels out in the weighted average below.
    client_n_samples_list[client.id] = client_n_samples

    for key in all_client_measures:
        all_client_measures[key][client.id] = client_measures_sum[key]

# Sample-weighted average of every metric across clients
total_n_samples = sum(client_n_samples_list.values())
global_measures = {}
for key in all_client_measures:
    global_measures[key] = sum(
        all_client_measures[key][c.id] * client_n_samples_list[c.id] for c in clients
    ) / total_n_samples

print("Global Test Set Loss:", global_measures['loss'])
print("Global Test Set Accuracy:", global_measures['accuracy'])
print("Global Test Set FPR:", global_measures['fpr'])
print("Global Test Set TPR:", global_measures['tpr'])
print("Global Test Set BER:", global_measures['ber'])

Global Test Set Loss: 0.07198309101843382
Global Test Set Accuracy: 0.9941475176682385
Global Test Set FPR: 0.026303188812686725
Global Test Set TPR: 0.9864316147533704
Global Test Set BER: 0.014600803740417173
