Import the necessary dependent libraries

In [1]:
import os
from module import MLPModule
from utils import aggregate_weights, save_global_measure, load_client_data, calculate_global_measures
from tqdm import tqdm
import torch
from client import Client

Set the necessary parameters and hyperparameters for the experiment and model

In [2]:
seed = 54
dataset_name = "n-baiot"
num_rounds = 25
input_dim = 0

if torch.cuda.is_available():
    device = torch.device("cuda")
# elif torch.backends.mps.is_available():
#     device = torch.device("mps:0")
else:
    device = torch.device("cpu")

hyperparameters = {
    'input_dim': input_dim,
    'lr': 0.001,
    'hidden_neurons_num': 512,
    'batch_size': 128,
    'seed': seed,
    'device': device
}

In [3]:
client_data_list = load_client_data(os.path.join(dataset_name, "split"), device)

In [4]:
def process_client_data(data_list, keep_columns):
    processed_data_list = []

    for client_data in data_list:
        X_train, Y_train = client_data[0]
        processed_X_train = X_train[:, keep_columns]
        processed_train = (processed_X_train, Y_train)

        X_val, Y_val = client_data[1]
        processed_X_val = X_val[:, keep_columns]
        processed_val = (processed_X_val, Y_val)

        X_test, Y_test = client_data[2]
        processed_X_test = X_test[:, keep_columns]
        processed_test = (processed_X_test, Y_test)

        processed_client_data = (processed_train, processed_val, processed_test)
        processed_data_list.append(processed_client_data)

    return processed_data_list

In [5]:
columns_to_keep = [0, 2, 4, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 28, 29, 38, 39]
if dataset_name == "n-baiot":
    columns_to_keep = [0, 3, 6, 30, 37, 44, 31, 38, 45, 32, 39, 46, 60, 63, 66, 75, 82, 89, 76, 83, 90]
client_data_list = process_client_data(client_data_list, columns_to_keep)
hyperparameters['input_dim'] = len(columns_to_keep)

Generate clients

In [6]:
def create_clients(client_data, hyperparameter):
    client_list = []
    for i, (train_set, test_set, val_set) in enumerate(client_data):
        c = Client(
            id=i,
            hyperparameters=hyperparameter,
            train_set=train_set,
            test_set=test_set,
            val_set=val_set
        )
        client_list.append(c)
    return client_list
clients = create_clients(client_data_list, hyperparameters)

FedAvg training process

In [7]:
# Initialize model
global_model = MLPModule(hyperparameters)

# Pre-trained
model_weights_list = []
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
    client_n_samples_list[client.id] = client_n_samples
    model_weights_list.append(client_model_parameters)
    
global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))

In [8]:
# Record training results
record_global_train_measures = [None] * num_rounds
record_global_val_measures = [None] * num_rounds

pbar = tqdm(total=num_rounds)
# Perform num_rounds rounds of training
for idx in range(num_rounds):
    model_weights_list = []
    client_n_samples_list = {}
    all_client_train_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    all_client_val_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    
    for client in clients:
        # Client receives new global model
        client.receive_global_model(global_model)

        # The client trains the model locally
        client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
        client_n_samples_list[client.id] = client_n_samples
        model_weights_list.append(client_model_parameters)
        for key in all_client_train_measures:
            all_client_train_measures[key][client.id] = (client_measures['train'][key])
            all_client_val_measures[key][client.id] = (client_measures['val'][key])

    global_train_measures, global_val_measures = calculate_global_measures(clients,
                                                                           all_client_train_measures,
                                                                           all_client_val_measures,
                                                                           display_result=True)
    record_global_train_measures[idx] = global_train_measures
    record_global_val_measures[idx] = global_val_measures

    pbar.update(1)

    # Aggregate new global model
    global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))
    
pbar.reset()

  4%|▍         | 1/25 [00:40<16:07, 40.32s/it]

Global Train Set Loss: 0.3245374543008779
Global Train Set Accuracy: 0.8512858724198814
Global Train Set FPR: 0.580530196545936
Global Train Set TPR: 0.8561349479387258
Global Train Set BER: 0.3571976243036049
Global Validation Set Loss: 0.20218206264819666
Global Validation Set Accuracy: 0.8896984106989878
Global Validation Set FPR: 0.34437682093677296
Global Validation Set TPR: 0.8743494384217513
Global Validation Set BER: 0.1900136912575108


  8%|▊         | 2/25 [01:20<15:24, 40.21s/it]

Global Train Set Loss: 0.2847062147513999
Global Train Set Accuracy: 0.9021369464907845
Global Train Set FPR: 0.48741504999435165
Global Train Set TPR: 0.8372423002624305
Global Train Set BER: 0.3200863748659607
Global Validation Set Loss: 0.1515230894777387
Global Validation Set Accuracy: 0.9432002210104682
Global Validation Set FPR: 0.20128610347914144
Global Validation Set TPR: 0.8955144610440631
Global Validation Set BER: 0.10788582121753913


 12%|█▏        | 3/25 [02:01<14:53, 40.63s/it]

Global Train Set Loss: 0.28079344930778055
Global Train Set Accuracy: 0.9052725853806952
Global Train Set FPR: 0.48100684482365247
Global Train Set TPR: 0.8339968061173814
Global Train Set BER: 0.31850501935313547
Global Validation Set Loss: 0.14623026470576306
Global Validation Set Accuracy: 0.9485111479099027
Global Validation Set FPR: 0.1774785444783252
Global Validation Set TPR: 0.8993179128818237
Global Validation Set BER: 0.09408031579825087


 16%|█▌        | 4/25 [02:43<14:23, 41.11s/it]

Global Train Set Loss: 0.28021027125266845
Global Train Set Accuracy: 0.9060273244937427
Global Train Set FPR: 0.47993540326679424
Global Train Set TPR: 0.8343631535960552
Global Train Set BER: 0.31778612483536944
Global Validation Set Loss: 0.1462634313142812
Global Validation Set Accuracy: 0.9488533325706715
Global Validation Set FPR: 0.17218968891179337
Global Validation Set TPR: 0.9025106577532805
Global Validation Set BER: 0.08983951557925651


 20%|██        | 5/25 [03:24<13:43, 41.18s/it]

Global Train Set Loss: 0.28008824244315317
Global Train Set Accuracy: 0.9058858040696292
Global Train Set FPR: 0.4792111039487969
Global Train Set TPR: 0.8337583461751169
Global Train Set BER: 0.31772637888684013
Global Validation Set Loss: 0.1468104397718333
Global Validation Set Accuracy: 0.9480771923646951
Global Validation Set FPR: 0.17200023619415997
Global Validation Set TPR: 0.9028403073241126
Global Validation Set BER: 0.08957996443502383


 24%|██▍       | 6/25 [04:05<13:00, 41.07s/it]

Global Train Set Loss: 0.2786917693018257
Global Train Set Accuracy: 0.9072706796273279
Global Train Set FPR: 0.47704928015625614
Global Train Set TPR: 0.833899542425976
Global Train Set BER: 0.31657486886514
Global Validation Set Loss: 0.14584627760285115
Global Validation Set Accuracy: 0.9490895112787041
Global Validation Set FPR: 0.16790918947959352
Global Validation Set TPR: 0.9025828088443114
Global Validation Set BER: 0.08766319031764107


 28%|██▊       | 7/25 [04:46<12:18, 41.01s/it]

Global Train Set Loss: 0.276761580456839
Global Train Set Accuracy: 0.9092522705886082
Global Train Set FPR: 0.4739607785094276
Global Train Set TPR: 0.8339780202621755
Global Train Set BER: 0.31499137912362596
Global Validation Set Loss: 0.14373213970383217
Global Validation Set Accuracy: 0.9517546750170481
Global Validation Set FPR: 0.16428498863433727
Global Validation Set TPR: 0.9026205012934537
Global Validation Set BER: 0.08583224367044183


 32%|███▏      | 8/25 [05:28<11:41, 41.27s/it]

Global Train Set Loss: 0.27622551359957775
Global Train Set Accuracy: 0.9102109724537364
Global Train Set FPR: 0.4724376589454429
Global Train Set TPR: 0.8334633148436048
Global Train Set BER: 0.3144871720509188
Global Validation Set Loss: 0.14320954188693982
Global Validation Set Accuracy: 0.952319241138106
Global Validation Set FPR: 0.17136177271162584
Global Validation Set TPR: 0.9065850946294229
Global Validation Set BER: 0.08738833904110145


 36%|███▌      | 9/25 [06:10<11:03, 41.46s/it]

Global Train Set Loss: 0.2755909646628037
Global Train Set Accuracy: 0.9107962548790829
Global Train Set FPR: 0.47128068070789125
Global Train Set TPR: 0.8332510842553791
Global Train Set BER: 0.314014798226256
Global Validation Set Loss: 0.14247056579354356
Global Validation Set Accuracy: 0.9547552338420229
Global Validation Set FPR: 0.17891329500431322
Global Validation Set TPR: 0.9066913467295394
Global Validation Set BER: 0.09111097413738686


 40%|████      | 10/25 [06:51<10:21, 41.42s/it]

Global Train Set Loss: 0.2744897994274657
Global Train Set Accuracy: 0.9118232042155351
Global Train Set FPR: 0.4693568493241044
Global Train Set TPR: 0.832497704556795
Global Train Set BER: 0.3134295723836548
Global Validation Set Loss: 0.1411437806610648
Global Validation Set Accuracy: 0.9557974189226749
Global Validation Set FPR: 0.17682274959374567
Global Validation Set TPR: 0.9067501383933638
Global Validation Set BER: 0.0900363056001908


 44%|████▍     | 11/25 [07:32<09:37, 41.28s/it]

Global Train Set Loss: 0.27428092133812076
Global Train Set Accuracy: 0.9122984740723491
Global Train Set FPR: 0.46822056109018617
Global Train Set TPR: 0.8325820040445205
Global Train Set BER: 0.3128192785228328
Global Validation Set Loss: 0.14087051006632872
Global Validation Set Accuracy: 0.957030439688182
Global Validation Set FPR: 0.1756124380781046
Global Validation Set TPR: 0.9067616162955631
Global Validation Set BER: 0.08942541089127082


 48%|████▊     | 12/25 [08:13<08:56, 41.24s/it]

Global Train Set Loss: 0.27309092303745136
Global Train Set Accuracy: 0.9130815330893906
Global Train Set FPR: 0.46685660850343624
Global Train Set TPR: 0.8324141823943343
Global Train Set BER: 0.3122212130545511
Global Validation Set Loss: 0.1395104721546826
Global Validation Set Accuracy: 0.9583447865179625
Global Validation Set FPR: 0.17302736022910278
Global Validation Set TPR: 0.9070462427151047
Global Validation Set BER: 0.08799055875699917


 52%|█████▏    | 13/25 [08:54<08:14, 41.24s/it]

Global Train Set Loss: 0.27252826467815794
Global Train Set Accuracy: 0.9138857532164195
Global Train Set FPR: 0.465467225982769
Global Train Set TPR: 0.8314873623572242
Global Train Set BER: 0.3119899318127722
Global Validation Set Loss: 0.13875212204330756
Global Validation Set Accuracy: 0.9635998132310841
Global Validation Set FPR: 0.1673060043380995
Global Validation Set TPR: 0.9070825789223754
Global Validation Set BER: 0.08511171270786218


 56%|█████▌    | 14/25 [09:34<07:29, 40.87s/it]

Global Train Set Loss: 0.2725031420438702
Global Train Set Accuracy: 0.9142302765002128
Global Train Set FPR: 0.46507680795575296
Global Train Set TPR: 0.8315466103228323
Global Train Set BER: 0.3117650988164604
Global Validation Set Loss: 0.13862688561644979
Global Validation Set Accuracy: 0.9636196177198957
Global Validation Set FPR: 0.16698559302115867
Global Validation Set TPR: 0.9070988670760814
Global Validation Set BER: 0.08494336297253864


 60%|██████    | 15/25 [10:16<06:50, 41.05s/it]

Global Train Set Loss: 0.27171095306350024
Global Train Set Accuracy: 0.9147297487301514
Global Train Set FPR: 0.4639465098187946
Global Train Set TPR: 0.8315532636877543
Global Train Set BER: 0.31119662306552026
Global Validation Set Loss: 0.13782325369053394
Global Validation Set Accuracy: 0.9642167594114245
Global Validation Set FPR: 0.16645760235085663
Global Validation Set TPR: 0.9071043604261934
Global Validation Set BER: 0.08467662096233176


 64%|██████▍   | 16/25 [10:58<06:11, 41.24s/it]

Global Train Set Loss: 0.27161503312250734
Global Train Set Accuracy: 0.9149642316001639
Global Train Set FPR: 0.46363680318121003
Global Train Set TPR: 0.8309216567094907
Global Train Set BER: 0.31135757323585983
Global Validation Set Loss: 0.1375988665591816
Global Validation Set Accuracy: 0.9644411920385817
Global Validation Set FPR: 0.1650526440408674
Global Validation Set TPR: 0.9071426790638664
Global Validation Set BER: 0.08395498248850077


 68%|██████▊   | 17/25 [11:39<05:29, 41.16s/it]

Global Train Set Loss: 0.2710841416668369
Global Train Set Accuracy: 0.9153530231460708
Global Train Set FPR: 0.46318946730558
Global Train Set TPR: 0.8309475397669249
Global Train Set BER: 0.31112096376932746
Global Validation Set Loss: 0.13715862047043978
Global Validation Set Accuracy: 0.9644212239674943
Global Validation Set FPR: 0.15517221154753233
Global Validation Set TPR: 0.9071569964475159
Global Validation Set BER: 0.07900760755000821


 72%|███████▏  | 18/25 [12:19<04:46, 40.90s/it]

Global Train Set Loss: 0.27108042137650523
Global Train Set Accuracy: 0.9170191068359179
Global Train Set FPR: 0.46143721779282026
Global Train Set TPR: 0.8310872891977801
Global Train Set BER: 0.31017496429752023
Global Validation Set Loss: 0.13720044860975347
Global Validation Set Accuracy: 0.9645141550969822
Global Validation Set FPR: 0.154577763829086
Global Validation Set TPR: 0.9071566872162066
Global Validation Set BER: 0.07871053830643981


 76%|███████▌  | 19/25 [13:01<04:06, 41.16s/it]

Global Train Set Loss: 0.27042530902357953
Global Train Set Accuracy: 0.9171379500781482
Global Train Set FPR: 0.461315964534918
Global Train Set TPR: 0.831162294897282
Global Train Set BER: 0.3100768348188179
Global Validation Set Loss: 0.1366984167810743
Global Validation Set Accuracy: 0.9653252104346376
Global Validation Set FPR: 0.15325511430179842
Global Validation Set TPR: 0.9071677988396948
Global Validation Set BER: 0.07804365773105199


 80%|████████  | 20/25 [13:42<03:26, 41.26s/it]

Global Train Set Loss: 0.27014431284379403
Global Train Set Accuracy: 0.9172165570458121
Global Train Set FPR: 0.46091409386165444
Global Train Set TPR: 0.8313398730255284
Global Train Set BER: 0.30978711041806284
Global Validation Set Loss: 0.13663736199747128
Global Validation Set Accuracy: 0.9653212940638949
Global Validation Set FPR: 0.15382410654792808
Global Validation Set TPR: 0.9071843110206456
Global Validation Set BER: 0.07831989776364154


 84%|████████▍ | 21/25 [14:23<02:44, 41.10s/it]

Global Train Set Loss: 0.2700989795548234
Global Train Set Accuracy: 0.9173808090423475
Global Train Set FPR: 0.46068594437111493
Global Train Set TPR: 0.8313452486948905
Global Train Set BER: 0.309670347838112
Global Validation Set Loss: 0.13632270593250906
Global Validation Set Accuracy: 0.9654801326965037
Global Validation Set FPR: 0.15353005451201607
Global Validation Set TPR: 0.9071901175344818
Global Validation Set BER: 0.0781699684887673


 88%|████████▊ | 22/25 [15:03<02:02, 40.70s/it]

Global Train Set Loss: 0.2701341386898728
Global Train Set Accuracy: 0.9173349370633455
Global Train Set FPR: 0.46071836754795414
Global Train Set TPR: 0.8313612317550373
Global Train Set BER: 0.30967856789645803
Global Validation Set Loss: 0.13653878659620605
Global Validation Set Accuracy: 0.9654515782451919
Global Validation Set FPR: 0.1537904716153782
Global Validation Set TPR: 0.9071898493739066
Global Validation Set BER: 0.07830031112073593


 92%|█████████▏| 23/25 [15:43<01:21, 40.52s/it]

Global Train Set Loss: 0.2698731307750343
Global Train Set Accuracy: 0.9174765802523787
Global Train Set FPR: 0.46057171784117
Global Train Set TPR: 0.8313701948377787
Global Train Set BER: 0.30960076150169547
Global Validation Set Loss: 0.1360894244675942
Global Validation Set Accuracy: 0.9655199079116893
Global Validation Set FPR: 0.15351055997263147
Global Validation Set TPR: 0.9072288487420355
Global Validation Set BER: 0.07814085561529821


 96%|█████████▌| 24/25 [16:23<00:40, 40.37s/it]

Global Train Set Loss: 0.2695671130087738
Global Train Set Accuracy: 0.9171365041632137
Global Train Set FPR: 0.4605950282537839
Global Train Set TPR: 0.8279835615213393
Global Train Set BER: 0.31130573336622236
Global Validation Set Loss: 0.13573431949817677
Global Validation Set Accuracy: 0.9656930586367203
Global Validation Set FPR: 0.15246560498089287
Global Validation Set TPR: 0.9072335421656093
Global Validation Set BER: 0.07761603140764191


  0%|          | 0/25 [00:00<?, ?it/s]         

Global Train Set Loss: 0.2691223174125145
Global Train Set Accuracy: 0.9171483345249776
Global Train Set FPR: 0.460502339819925
Global Train Set TPR: 0.8281471421072323
Global Train Set BER: 0.31117759885634655
Global Validation Set Loss: 0.1352994129072618
Global Validation Set Accuracy: 0.9662335754345753
Global Validation Set FPR: 0.1515168935718104
Global Validation Set TPR: 0.9072362405745185
Global Validation Set BER: 0.07714032649864602


Save experiment results

In [9]:
save_global_measure(record_global_train_measures, "train_measures_"+dataset_name+".csv", "FedTrust")
save_global_measure(record_global_val_measures, "val_measures_"+dataset_name+".csv", "FedTrust")

Saved measures to Experimental_results/FedTrust/train_measures_n-baiot.csv
Saved measures to Experimental_results/FedTrust/val_measures_n-baiot.csv


Test the performance of the model on the unseen test set

In [10]:
all_client_measures = {
    'accuracy': [],
    'fpr': [],
    'tpr': [],
    'ber': [],
    'loss': []
}
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_measures, client_n_samples = client.evaluate_test_set()
    client_n_samples_list[client.id] = client_n_samples

    for key in all_client_measures:
        all_client_measures[key].append(client_measures[key])

global_measures = {}
for key in all_client_measures:
    global_measures[key] = sum([all_client_measures[key][c.id] * client_n_samples_list[c.id] for c in clients]) / sum(client_n_samples_list.values())

print("Global Test Set Loss:", global_measures['loss'])
print("Global Test Set Accuracy:", global_measures['accuracy'])
print("Global Test Set FPR:", global_measures['fpr'])
print("Global Test Set TPR:", global_measures['tpr'])
print("Global Test Set BER:", global_measures['ber'])

Global Test Set Loss: 0.0992762065479118
Global Test Set Accuracy: 0.9822230937664373
Global Test Set FPR: 0.18820480251049926
Global Test Set TPR: 0.9887116955918488
Global Test Set BER: 0.09441157017008425
