In [1]:
import os
from utils import load_client_data
from local_svm import LocalSVM
from utils import save_global_measure

In [2]:
# Experiment configuration.
# dataset_name: folder holding the per-client splits; also used as the suffix
#               of the CSV files written with the saved measures.
# seed:         forwarded to LocalSVM as its `seed` keyword (presumably for
#               its random state — TODO confirm inside LocalSVM).
dataset_name = "n-baiot"
seed = 54

In [3]:
# Load every client's data from <dataset_name>/split. Each entry is later
# unpacked into a (train, test, val) triple.
client_split_dir = os.path.join(dataset_name, "split")
client_data_list = load_client_data(client_split_dir)

In [4]:
def to_svm_labels(split):
    """Return a copy of ``split`` whose label element (index 1) maps 0 to -1.

    Only label value 0 is rewritten; every other label value (presumably 1 —
    TODO confirm the label encoding) is left as is. Element 0 (presumably the
    features) and any further elements are passed through unchanged.

    The label array is copied before being modified, so ``client_data_list``
    keeps its raw labels and this cell gives the same result on every re-run.

    Args:
        split: an indexable sequence (assumed tuple or list — a namedtuple
            would need its own constructor) whose element 1 is a label array
            supporting ``.copy()`` and boolean-mask assignment, e.g. a NumPy
            array.

    Returns:
        A new sequence of the same type as ``split`` with the relabelled copy
        at index 1.
    """
    labels = split[1].copy()
    labels[labels == 0] = -1
    return type(split)([split[0], labels, *split[2:]])


# Build SVM-ready copies of each client's (train, test, val) splits without
# mutating client_data_list.
train_set, test_set, val_set = [], [], []
for train_data, test_data, val_data in client_data_list:
    train_set.append(to_svm_labels(train_data))
    test_set.append(to_svm_labels(test_data))
    val_set.append(to_svm_labels(val_data))

In [5]:
# Wrap the per-client train / test / validation lists in a LocalSVM;
# the configured seed is passed through by keyword.
local_svm = LocalSVM(
    train_set,
    test_set,
    val_set,
    seed=seed,
)

In [6]:
# Fit the model(s) for all clients. The recorded run shows a progress bar over
# 100 steps, presumably one per client — TODO confirm against LocalSVM.
local_svm.fit_all_client()

100%|██████████| 100/100 [00:56<00:00,  1.77it/s]


In [7]:
# calculate_measures() prints the global train/validation metrics itself and
# returns a (train, validation) pair of records; those records are what gets
# handed to save_global_measure for the CSV exports.
train_val_records = local_svm.calculate_measures()
record_train_measures, record_val_measures = train_val_records

Global Train Set Loss: 0.49913570854539885
Global Train Set Accuracy: 0.7576974518497456
Global Train Set FPR: 0.5559462605789572
Global Train Set TPR: 0.7663919301987303
Global Train Set BER: 0.39028736927174296
Global Validation Set Loss: 0.5406318606019971
Global Validation Set Accuracy: 0.7852189231633434
Global Validation Set FPR: 0.31900036667283516
Global Validation Set TPR: 0.7924115480913798
Global Validation Set BER: 0.22426379704583008


In [8]:
# Persist the train and validation records as CSV files named
# "<split>_measures_<dataset_name>.csv" under the "Local_SVM" results group.
for csv_prefix, measure_record in (
    ("train_measures_", record_train_measures),
    ("val_measures_", record_val_measures),
):
    save_global_measure([measure_record], csv_prefix + dataset_name + ".csv", "Local_SVM")

Saved measures to Experimental_results/Local_SVM/train_measures_n-baiot.csv
Saved measures to Experimental_results/Local_SVM/val_measures_n-baiot.csv


In [9]:
# Evaluate on the held-out test split and report each metric.
# NOTE(review): in the saved run, every test-set value printed by this cell is
# bit-for-bit identical to the corresponding validation-set value printed by
# calculate_measures(). Verify that calculate_measures_test_set() really
# evaluates test_set and not val_set before trusting these numbers — TODO.
global_measures = local_svm.calculate_measures_test_set()

test_report = (
    ("Global Test Set Loss:", "loss"),
    ("Global Test Set Accuracy:", "accuracy"),
    ("Global Test Set FPR:", "fpr"),
    ("Global Test Set TPR:", "tpr"),
    ("Global Test Set BER:", "ber"),
)
for caption, metric_key in test_report:
    print(caption, global_measures[metric_key])

Global Test Set Loss: 0.5406318606019971
Global Test Set Accuracy: 0.7852189231633434
Global Test Set FPR: 0.31900036667283516
Global Test Set TPR: 0.7924115480913798
Global Test Set BER: 0.22426379704583008
