Import the necessary dependent libraries

In [1]:
import os
from module import MLPModule
from utils import aggregate_weights, save_global_measure, load_client_data, calculate_global_measures
from tqdm import tqdm
import torch
from client import Client

Set the necessary parameters and hyperparameters for the experiment and model

In [2]:
seed = 54
dataset_name = "unsw-nb15"
num_rounds = 25
input_dim = 0

if torch.cuda.is_available():
    device = torch.device("cuda")
# elif torch.backends.mps.is_available():
#     device = torch.device("mps:0")
else:
    device = torch.device("cpu")

hyperparameters = {
    'input_dim': input_dim,
    'lr': 0.001,
    'hidden_neurons_num': 512,
    'batch_size': 128,
    'seed': seed,
    'device': device
}

In [3]:
client_data_list = load_client_data(os.path.join(dataset_name, "split"), device)

In [4]:
def process_client_data(data_list, keep_columns):
    processed_data_list = []

    for client_data in data_list:
        X_train, Y_train = client_data[0]
        processed_X_train = X_train[:, keep_columns]
        processed_train = (processed_X_train, Y_train)

        X_val, Y_val = client_data[1]
        processed_X_val = X_val[:, keep_columns]
        processed_val = (processed_X_val, Y_val)

        X_test, Y_test = client_data[2]
        processed_X_test = X_test[:, keep_columns]
        processed_test = (processed_X_test, Y_test)

        processed_client_data = (processed_train, processed_val, processed_test)
        processed_data_list.append(processed_client_data)

    return processed_data_list

In [5]:
columns_to_keep = [0, 2, 4, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 28, 29, 38, 39]
if dataset_name == "n-baiot":
    columns_to_keep = [0, 3, 6, 30, 37, 44, 31, 38, 45, 32, 39, 46, 60, 63, 66, 75, 82, 89, 76, 83, 90]
client_data_list = process_client_data(client_data_list, columns_to_keep)
hyperparameters['input_dim'] = len(columns_to_keep)

Generate clients

In [6]:
def create_clients(client_data, hyperparameter):
    client_list = []
    for i, (train_set, test_set, val_set) in enumerate(client_data):
        c = Client(
            id=i,
            hyperparameters=hyperparameter,
            train_set=train_set,
            test_set=test_set,
            val_set=val_set
        )
        client_list.append(c)
    return client_list
clients = create_clients(client_data_list, hyperparameters)

FedAvg training process

In [7]:
# Initialize model
global_model = MLPModule(hyperparameters)

# Pre-trained
model_weights_list = []
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
    client_n_samples_list[client.id] = client_n_samples
    model_weights_list.append(client_model_parameters)
    
global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))

In [8]:
# Record training results
record_global_train_measures = [None] * num_rounds
record_global_val_measures = [None] * num_rounds

pbar = tqdm(total=num_rounds)
# Perform num_rounds rounds of training
for idx in range(num_rounds):
    model_weights_list = []
    client_n_samples_list = {}
    all_client_train_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    all_client_val_measures = { 'accuracy': {}, 'fpr': {}, 'tpr': {}, 'ber': {}, 'loss': {} }
    
    for client in clients:
        # Client receives new global model
        client.receive_global_model(global_model)

        # The client trains the model locally
        client_model_parameters, client_measures, client_n_samples = client.local_fit_and_upload_parameters()
        client_n_samples_list[client.id] = client_n_samples
        model_weights_list.append(client_model_parameters)
        for key in all_client_train_measures:
            all_client_train_measures[key][client.id] = (client_measures['train'][key])
            all_client_val_measures[key][client.id] = (client_measures['val'][key])

    global_train_measures, global_val_measures = calculate_global_measures(clients,
                                                                           all_client_train_measures,
                                                                           all_client_val_measures,
                                                                           display_result=True)
    record_global_train_measures[idx] = global_train_measures
    record_global_val_measures[idx] = global_val_measures

    pbar.update(1)

    # Aggregate new global model
    global_model.set_parameters(aggregate_weights(model_weights_list, client_n_samples_list))
    
pbar.reset()

  4%|▍         | 1/25 [00:13<05:28, 13.70s/it]

Global Train Set Loss: 0.3996258521731434
Global Train Set Accuracy: 0.8272163144367685
Global Train Set FPR: 0.14281994768518083
Global Train Set TPR: 0.27600130164572706
Global Train Set BER: 0.4284093230197268
Global Validation Set Loss: 0.3157757683821993
Global Validation Set Accuracy: 0.8580285725913508
Global Validation Set FPR: 0.04747662970694575
Global Validation Set TPR: 0.30399923825843767
Global Validation Set BER: 0.2617386957242541


  8%|▊         | 2/25 [00:27<05:14, 13.67s/it]

Global Train Set Loss: 0.36509207318031534
Global Train Set Accuracy: 0.8330123014877523
Global Train Set FPR: 0.1444068370324117
Global Train Set TPR: 0.3031451460711775
Global Train Set BER: 0.41563084548061724
Global Validation Set Loss: 0.27932165986643365
Global Validation Set Accuracy: 0.8644856985398723
Global Validation Set FPR: 0.038813899221871624
Global Validation Set TPR: 0.339988874192472
Global Validation Set BER: 0.23941251251469986


 12%|█▏        | 3/25 [00:41<05:01, 13.72s/it]

Global Train Set Loss: 0.3584143579023467
Global Train Set Accuracy: 0.8540228607528628
Global Train Set FPR: 0.14527350993406332
Global Train Set TPR: 0.3430147642894148
Global Train Set BER: 0.3961293728223241
Global Validation Set Loss: 0.2735769113579475
Global Validation Set Accuracy: 0.8812678562202569
Global Validation Set FPR: 0.034591573418789545
Global Validation Set TPR: 0.3774758846045562
Global Validation Set BER: 0.2185578444071166


 16%|█▌        | 4/25 [00:54<04:49, 13.78s/it]

Global Train Set Loss: 0.3565718728965069
Global Train Set Accuracy: 0.8641044473079205
Global Train Set FPR: 0.14638718507440301
Global Train Set TPR: 0.3617194055281835
Global Train Set BER: 0.38733388977310973
Global Validation Set Loss: 0.2723357422169368
Global Validation Set Accuracy: 0.8955631141126845
Global Validation Set FPR: 0.03294604599531307
Global Validation Set TPR: 0.40396972471186887
Global Validation Set BER: 0.2044881606417221


 20%|██        | 5/25 [01:08<04:35, 13.79s/it]

Global Train Set Loss: 0.355455017293341
Global Train Set Accuracy: 0.8664338258243519
Global Train Set FPR: 0.14522848948570155
Global Train Set TPR: 0.3670972588903216
Global Train Set BER: 0.38406561529768984
Global Validation Set Loss: 0.27159518471326033
Global Validation Set Accuracy: 0.8985898391255941
Global Validation Set FPR: 0.03155008006640576
Global Validation Set TPR: 0.410647628360258
Global Validation Set BER: 0.2004512258530739


 24%|██▍       | 6/25 [01:22<04:22, 13.83s/it]

Global Train Set Loss: 0.35471136969827993
Global Train Set Accuracy: 0.866895166384828
Global Train Set FPR: 0.1444319856492493
Global Train Set TPR: 0.3729405964615782
Global Train Set BER: 0.38074569459383534
Global Validation Set Loss: 0.2708706025207707
Global Validation Set Accuracy: 0.898749047593112
Global Validation Set FPR: 0.03001577204223672
Global Validation Set TPR: 0.4164058711166959
Global Validation Set BER: 0.19680495046277044


 28%|██▊       | 7/25 [01:36<04:10, 13.90s/it]

Global Train Set Loss: 0.3528543787873918
Global Train Set Accuracy: 0.868184114868398
Global Train Set FPR: 0.14374137892010275
Global Train Set TPR: 0.37904672833598246
Global Train Set BER: 0.37734732529206
Global Validation Set Loss: 0.2689472590271431
Global Validation Set Accuracy: 0.8997275718497169
Global Validation Set FPR: 0.030422496177538876
Global Validation Set TPR: 0.4275418337175198
Global Validation Set BER: 0.1914403312300096


 32%|███▏      | 8/25 [01:50<03:55, 13.88s/it]

Global Train Set Loss: 0.35184663568241936
Global Train Set Accuracy: 0.8704295883499418
Global Train Set FPR: 0.14337029588831818
Global Train Set TPR: 0.3854535170631123
Global Train Set BER: 0.3739583894126028
Global Validation Set Loss: 0.2681979379274411
Global Validation Set Accuracy: 0.9019999776466199
Global Validation Set FPR: 0.029667527339643646
Global Validation Set TPR: 0.43442242370931267
Global Validation Set BER: 0.18762255181516554


 36%|███▌      | 9/25 [02:04<03:41, 13.82s/it]

Global Train Set Loss: 0.35026905077880405
Global Train Set Accuracy: 0.8712897726556126
Global Train Set FPR: 0.14225394277600728
Global Train Set TPR: 0.3883504974658378
Global Train Set BER: 0.37195172265508475
Global Validation Set Loss: 0.26676039757687275
Global Validation Set Accuracy: 0.9025244410877191
Global Validation Set FPR: 0.024073755920077114
Global Validation Set TPR: 0.43811836252826225
Global Validation Set BER: 0.1829776966959075


 40%|████      | 10/25 [02:18<03:27, 13.86s/it]

Global Train Set Loss: 0.3481804911333083
Global Train Set Accuracy: 0.8721036445848115
Global Train Set FPR: 0.1410847746326042
Global Train Set TPR: 0.39067817266436544
Global Train Set BER: 0.3702033009841192
Global Validation Set Loss: 0.26474568352293715
Global Validation Set Accuracy: 0.904501213237712
Global Validation Set FPR: 0.021342421949902504
Global Validation Set TPR: 0.4426863704189137
Global Validation Set BER: 0.17932802576549445


 44%|████▍     | 11/25 [02:32<03:14, 13.91s/it]

Global Train Set Loss: 0.3469085130678567
Global Train Set Accuracy: 0.8723260724168982
Global Train Set FPR: 0.1412896750743188
Global Train Set TPR: 0.39219377879130446
Global Train Set BER: 0.36954794814150693
Global Validation Set Loss: 0.26363401477471626
Global Validation Set Accuracy: 0.9066096095574298
Global Validation Set FPR: 0.019759044669404802
Global Validation Set TPR: 0.4465594422440041
Global Validation Set BER: 0.17659980121270027


 48%|████▊     | 12/25 [02:46<03:01, 14.00s/it]

Global Train Set Loss: 0.3451123958722555
Global Train Set Accuracy: 0.8732955319895539
Global Train Set FPR: 0.14129778999111475
Global Train Set TPR: 0.39639461818437466
Global Train Set BER: 0.3674515859033699
Global Validation Set Loss: 0.26207196706733676
Global Validation Set Accuracy: 0.9072126036320414
Global Validation Set FPR: 0.019591968562025993
Global Validation Set TPR: 0.4515922245588019
Global Validation Set BER: 0.17399987200161207


 52%|█████▏    | 13/25 [03:00<02:46, 13.91s/it]

Global Train Set Loss: 0.34519435937435006
Global Train Set Accuracy: 0.8732856518643599
Global Train Set FPR: 0.14135628453542384
Global Train Set TPR: 0.39804755554231763
Global Train Set BER: 0.366654364496553
Global Validation Set Loss: 0.2626184726376094
Global Validation Set Accuracy: 0.9075222970692868
Global Validation Set FPR: 0.018812103345744558
Global Validation Set TPR: 0.45364920985693025
Global Validation Set BER: 0.17258144674440723


 56%|█████▌    | 14/25 [03:13<02:32, 13.88s/it]

Global Train Set Loss: 0.3451436154485319
Global Train Set Accuracy: 0.874570616546074
Global Train Set FPR: 0.14276926279639307
Global Train Set TPR: 0.4016242001102843
Global Train Set BER: 0.3655725313430544
Global Validation Set Loss: 0.2632784629011662
Global Validation Set Accuracy: 0.9075356889508177
Global Validation Set FPR: 0.02085921080488018
Global Validation Set TPR: 0.4585114254953229
Global Validation Set BER: 0.17117389265477875


 60%|██████    | 15/25 [03:27<02:18, 13.86s/it]

Global Train Set Loss: 0.3445456147810323
Global Train Set Accuracy: 0.8749936289493497
Global Train Set FPR: 0.14476300029343528
Global Train Set TPR: 0.4020697350592347
Global Train Set BER: 0.3663466326171001
Global Validation Set Loss: 0.2630189762015497
Global Validation Set Accuracy: 0.9080541827938637
Global Validation Set FPR: 0.0205428537058491
Global Validation Set TPR: 0.4592461844486531
Global Validation Set BER: 0.170648334628598


 64%|██████▍   | 16/25 [03:41<02:04, 13.84s/it]

Global Train Set Loss: 0.34391475925916
Global Train Set Accuracy: 0.8760008846609401
Global Train Set FPR: 0.1454235236071825
Global Train Set TPR: 0.40376406115722113
Global Train Set BER: 0.3658297312249805
Global Validation Set Loss: 0.26252860271231676
Global Validation Set Accuracy: 0.9095445757690412
Global Validation Set FPR: 0.019968881413770068
Global Validation Set TPR: 0.46128583323625505
Global Validation Set BER: 0.1693415240887576


 68%|██████▊   | 17/25 [03:55<01:50, 13.85s/it]

Global Train Set Loss: 0.3433742316802145
Global Train Set Accuracy: 0.8765145003380368
Global Train Set FPR: 0.14548763748872043
Global Train Set TPR: 0.40490606535701595
Global Train Set BER: 0.36529078606585214
Global Validation Set Loss: 0.2622500265254577
Global Validation Set Accuracy: 0.9088905951479014
Global Validation Set FPR: 0.02103606207728073
Global Validation Set TPR: 0.46237662132511625
Global Validation Set BER: 0.1693297203760823


 72%|███████▏  | 18/25 [04:09<01:37, 13.90s/it]

Global Train Set Loss: 0.3427445206136596
Global Train Set Accuracy: 0.8768814863678053
Global Train Set FPR: 0.14557163651647181
Global Train Set TPR: 0.40538495489964044
Global Train Set BER: 0.3650933408084157
Global Validation Set Loss: 0.26193187900917053
Global Validation Set Accuracy: 0.910211645639263
Global Validation Set FPR: 0.020440648856764155
Global Validation Set TPR: 0.4623435841453851
Global Validation Set BER: 0.16904853235568956


 76%|███████▌  | 19/25 [04:23<01:23, 13.91s/it]

Global Train Set Loss: 0.34260110843833125
Global Train Set Accuracy: 0.877426988151861
Global Train Set FPR: 0.14694852022501032
Global Train Set TPR: 0.4063869432502837
Global Train Set BER: 0.36528078848736323
Global Validation Set Loss: 0.2619323796099431
Global Validation Set Accuracy: 0.9110725819514678
Global Validation Set FPR: 0.02129784642755061
Global Validation Set TPR: 0.46407532138143437
Global Validation Set BER: 0.16861126252305814


 80%|████████  | 20/25 [04:37<01:09, 13.90s/it]

Global Train Set Loss: 0.3422664137691374
Global Train Set Accuracy: 0.8800362367133836
Global Train Set FPR: 0.14923051103812512
Global Train Set TPR: 0.40944589537877946
Global Train Set BER: 0.36489230782967286
Global Validation Set Loss: 0.26200408476798853
Global Validation Set Accuracy: 0.9131129243985862
Global Validation Set FPR: 0.02102107346927878
Global Validation Set TPR: 0.4670672836769781
Global Validation Set BER: 0.16697689489615036


 84%|████████▍ | 21/25 [04:51<00:55, 13.93s/it]

Global Train Set Loss: 0.3415279732600976
Global Train Set Accuracy: 0.8836433064240521
Global Train Set FPR: 0.14995746047395742
Global Train Set TPR: 0.41444808666084737
Global Train Set BER: 0.36275468690655494
Global Validation Set Loss: 0.2615400540923727
Global Validation Set Accuracy: 0.9192074108917971
Global Validation Set FPR: 0.01878414108343836
Global Validation Set TPR: 0.4740830068722498
Global Validation Set BER: 0.16235056710559426


 88%|████████▊ | 22/25 [05:04<00:41, 13.82s/it]

Global Train Set Loss: 0.34104539678786394
Global Train Set Accuracy: 0.8865865913206855
Global Train Set FPR: 0.1497449036657989
Global Train Set TPR: 0.4196952658230924
Global Train Set BER: 0.3600248189213533
Global Validation Set Loss: 0.2611410736880343
Global Validation Set Accuracy: 0.9228960925201167
Global Validation Set FPR: 0.0193729469326493
Global Validation Set TPR: 0.4805938059070441
Global Validation Set BER: 0.15938957051280256


 92%|█████████▏| 23/25 [05:18<00:27, 13.84s/it]

Global Train Set Loss: 0.34012716344220684
Global Train Set Accuracy: 0.8871990456900265
Global Train Set FPR: 0.15012579568720216
Global Train Set TPR: 0.42040663345124946
Global Train Set BER: 0.3598595811179763
Global Validation Set Loss: 0.2605209383747952
Global Validation Set Accuracy: 0.9240548824570419
Global Validation Set FPR: 0.018608957473854933
Global Validation Set TPR: 0.4825370801087508
Global Validation Set BER: 0.1580359386825521


 96%|█████████▌| 24/25 [05:32<00:13, 13.82s/it]

Global Train Set Loss: 0.3402873408680388
Global Train Set Accuracy: 0.8874314246447468
Global Train Set FPR: 0.14987928359858077
Global Train Set TPR: 0.42094682030447444
Global Train Set BER: 0.35946623164705316
Global Validation Set Loss: 0.26102198483165673
Global Validation Set Accuracy: 0.924381495381253
Global Validation Set FPR: 0.01806519511000287
Global Validation Set TPR: 0.483421479676337
Global Validation Set BER: 0.15732185771683282


  0%|          | 0/25 [00:00<?, ?it/s]         

Global Train Set Loss: 0.33977771320310024
Global Train Set Accuracy: 0.8874850994785081
Global Train Set FPR: 0.14973713966641602
Global Train Set TPR: 0.4214193657947861
Global Train Set BER: 0.359158886935815
Global Validation Set Loss: 0.2609236231286386
Global Validation Set Accuracy: 0.9248847779664762
Global Validation Set FPR: 0.017204335625303328
Global Validation Set TPR: 0.4853973071959786
Global Validation Set BER: 0.15590351421466228


Save experiment results

In [9]:
save_global_measure(record_global_train_measures, "train_measures_"+dataset_name+".csv", "FedTrust")
save_global_measure(record_global_val_measures, "val_measures_"+dataset_name+".csv", "FedTrust")

Saved measures to Experimental_results/FedTrust/train_measures_unsw-nb15.csv
Saved measures to Experimental_results/FedTrust/val_measures_unsw-nb15.csv


Test the performance of the model on the unseen test set

In [10]:
all_client_measures = {
    'accuracy': [],
    'fpr': [],
    'tpr': [],
    'ber': [],
    'loss': []
}
client_n_samples_list = {}
for client in clients:
    client.receive_global_model(global_model)
    client_measures, client_n_samples = client.evaluate_test_set()
    client_n_samples_list[client.id] = client_n_samples

    for key in all_client_measures:
        all_client_measures[key].append(client_measures[key])

global_measures = {}
for key in all_client_measures:
    global_measures[key] = sum([all_client_measures[key][c.id] * client_n_samples_list[c.id] for c in clients]) / sum(client_n_samples_list.values())

print("Global Test Set Loss:", global_measures['loss'])
print("Global Test Set Accuracy:", global_measures['accuracy'])
print("Global Test Set FPR:", global_measures['fpr'])
print("Global Test Set TPR:", global_measures['tpr'])
print("Global Test Set BER:", global_measures['ber'])

Global Test Set Loss: 0.2849678733798533
Global Test Set Accuracy: 0.8855191418167503
Global Test Set FPR: 0.00016290764224234115
Global Test Set TPR: 0.027515634977215864
Global Test Set BER: 0.40550678759715253
