In [1]:
import os

# Cap native thread pools at one thread each. These variables are read when the
# numerical libraries are first loaded, so they must be set BEFORE importing
# anything that pulls in that stack (the CVXPY message printed by this cell
# shows it is loaded during these imports). Setting them afterwards, as this
# cell previously did, is too late for pools that are sized at load time.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import networkx as nx

from sadmm_solver import NetworkLassoRunner
from utils import load_client_data, save_global_measure

(CVXPY) Dec 08 04:22:52 PM: Encountered unexpected exception importing solver OSQP:
ImportError("dlopen(/Users/hueybai/miniconda3/envs/DL_Pytroch/lib/python3.10/site-packages/osqp/_osqp.cpython-310-darwin.so, 0x0002): symbol not found in flat namespace '_csc_matrix'")


In [2]:
# Experiment configuration: everything a reader might tune lives in this cell.
seed = 54
dataset_name = "n-baiot"
num_rounds = 25

# Input feature count for each supported dataset.
FEATURE_DIMS = {
    "unsw-nb15": 47,
    "n-baiot": 115,
}

# Fail fast on an unknown dataset instead of silently continuing with
# input_dim == 0, which would only surface later as an obscure error.
if dataset_name not in FEATURE_DIMS:
    raise ValueError(
        "Unknown dataset_name " + repr(dataset_name)
        + "; expected one of " + str(sorted(FEATURE_DIMS))
    )
input_dim = FEATURE_DIMS[dataset_name]

# Settings passed to NetworkLassoRunner.run(). The exact meaning of
# 'sadmm_lambda', 'rho' and 'c' is defined in sadmm_solver (not shown here).
# Note that the save cell labels results "Local_SVM" when rho == 0.
hyperparameters = {
    'sadmm_lambda': 0.1,
    'rho': 1.0,
    'c': 0.75,
    'n_rounds': num_rounds,
    'num_features': input_dim,
    'seed': seed
}

In [3]:
# Client data is read from the "<dataset_name>/split" directory.
split_dir = os.path.join(dataset_name, "split")
client_data_list = load_client_data(split_dir)

In [4]:
# One slot per client for each of the three splits.
num_clients = len(client_data_list)
train_set, test_set, val_set = ([None] * num_clients for _ in range(3))

for idx, client_data in enumerate(client_data_list):
    train_data, test_data, val_data = client_data

    # Rewrite every 0 in element [1] of each split to -1. Element [1] is
    # presumably the label array (0/1 -> -1/+1) -- TODO confirm against
    # load_client_data. The write is in place, so the objects held by
    # client_data_list are modified as well.
    for split in (train_data, test_data, val_data):
        targets = split[1]
        targets[targets == 0] = -1

    train_set[idx], test_set[idx], val_set[idx] = train_data, test_data, val_data

In [5]:
def build_graph(num_nodes):
    """Build the simulated MEC topology over nodes ``0 .. num_nodes - 1``.

    The topology comes from a fixed table of 100 rows, where row ``i`` lists
    five neighbour ids of node ``i``. Every listed pair becomes an undirected
    edge with ``weight=1``.

    Args:
        num_nodes: Number of nodes to include. Must not exceed the size of the
            neighbour table (100).

    Returns:
        An ``nx.Graph`` containing exactly the nodes ``0 .. num_nodes - 1``.
        When ``num_nodes`` is smaller than the table, edges to nodes outside
        that range are dropped so the graph does not grow beyond the requested
        size (previously such edges silently added the extra nodes).

    Raises:
        ValueError: If ``num_nodes`` is larger than the neighbour table
            (previously this failed part-way through with an IndexError).
    """
    neighbours = [[73, 38, 88, 1, 42], [75, 64, 88, 12, 67], [77, 60, 56, 17, 13], [30, 75, 74, 93, 64],
                  [90, 27, 51, 47, 82], [96, 16, 55, 95, 28], [12, 7, 88, 42, 64], [6, 12, 42, 88, 64],
                  [48, 18, 80, 47, 61], [11, 33, 85, 44, 87], [34, 15, 58, 50, 43], [9, 33, 85, 84, 87],
                  [88, 6, 75, 64, 1], [81, 60, 2, 17, 56], [30, 3, 1, 74, 39], [34, 50, 10, 32, 91],
                  [28, 5, 87, 95, 96], [77, 60, 2, 13, 1], [48, 8, 24, 47, 80], [66, 79, 7, 6, 63],
                  [89, 81, 13, 50, 92], [25, 28, 95, 87, 49], [52, 54, 26, 53, 71], [43, 97, 62, 58, 10],
                  [44, 25, 18, 8, 28], [44, 24, 49, 21, 28], [76, 71, 54, 53, 86], [90, 4, 51, 82, 47],
                  [16, 87, 5, 95, 44], [59, 99, 82, 47, 80], [3, 74, 14, 75, 88], [68, 23, 97, 62, 43],
                  [91, 45, 35, 15, 60], [11, 9, 85, 44, 18], [15, 10, 50, 58, 32], [32, 60, 91, 93, 45],
                  [86, 96, 87, 5, 16], [5, 16, 96, 36, 28], [42, 1, 67, 0, 12], [14, 45, 46, 0, 91],
                  [42, 67, 38, 1, 12], [50, 81, 20, 62, 89], [40, 12, 6, 38, 88], [58, 62, 10, 23, 41],
                  [24, 25, 49, 28, 16], [91, 32, 46, 39, 14], [45, 1, 14, 78, 39], [80, 51, 59, 48, 82],
                  [47, 80, 61, 8, 59], [29, 59, 44, 25, 80], [81, 20, 15, 41, 13], [82, 47, 80, 4, 79],
                  [54, 26, 86, 57, 36], [71, 26, 76, 86, 54], [52, 86, 36, 26, 71], [5, 96, 95, 86, 76],
                  [2, 60, 77, 93, 13], [54, 52, 36, 98, 86], [43, 10, 34, 62, 50], [29, 99, 47, 80, 51],
                  [13, 2, 17, 77, 56], [80, 48, 99, 47, 79], [43, 97, 58, 41, 10], [65, 72, 38, 6, 42],
                  [75, 12, 1, 88, 3], [63, 79, 61, 48, 72], [19, 88, 6, 7, 12], [1, 40, 38, 64, 75],
                  [23, 97, 43, 62, 58], [94, 70, 22, 52, 57], [94, 69, 22, 52, 57], [53, 26, 86, 76, 54],
                  [63, 42, 38, 40, 65], [0, 38, 42, 88, 12], [30, 3, 93, 14, 75], [64, 88, 12, 1, 3],
                  [26, 96, 5, 55, 86], [2, 17, 60, 64, 3], [46, 45, 39, 91, 32], [51, 99, 80, 47, 4],
                  [47, 99, 51, 48, 59], [13, 50, 20, 60, 2], [51, 47, 80, 29, 90], [85, 84, 99, 61, 49],
                  [83, 85, 99, 61, 59], [83, 84, 11, 33, 9], [36, 96, 95, 98, 5], [16, 28, 95, 36, 96],
                  [12, 75, 64, 6, 1], [92, 20, 81, 50, 13], [27, 4, 51, 82, 47], [32, 45, 60, 35, 14],
                  [89, 20, 81, 13, 41], [3, 74, 30, 56, 60], [69, 70, 22, 52, 57], [96, 5, 86, 16, 87],
                  [5, 86, 95, 36, 55], [62, 23, 43, 58, 41], [86, 87, 36, 71, 95], [80, 59, 79, 47, 51]]

    # Reject sizes the table cannot serve before doing any work.
    if num_nodes > len(neighbours):
        raise ValueError(
            "num_nodes=" + str(num_nodes) + " exceeds the "
            + str(len(neighbours)) + "-node topology table"
        )

    print("Building a simulated MEC topology for " + str(num_nodes) + " nodes.")
    G = nx.Graph()
    for node_id in range(num_nodes):
        G.add_node(node_id)
        for neighbour_id in neighbours[node_id]:
            # add_edge would implicitly create a missing endpoint, so skip
            # neighbours that fall outside the requested node range.
            if neighbour_id < num_nodes:
                G.add_edge(node_id, neighbour_id, weight=1)
    return G

In [6]:
# Build the full topology: 100 is the size of the neighbour table hard-coded
# in build_graph. NOTE(review): presumably this should equal
# len(client_data_list) (one graph node per client) -- confirm.
G = build_graph(100)

Building a simulated MEC topology for 100 nodes.


In [7]:
# Create the runner on the topology graph G built above.
runner = NetworkLassoRunner(G)

In [8]:
# Run the solver on the per-client splits; this is the long-running cell
# (roughly 90 s per iteration in the captured output).
# NOTE(review): the arguments are passed as (train, test, val) while the
# results are unpacked as (train, val, test) -- confirm both orders against
# NetworkLassoRunner.run in sadmm_solver.
train_measures, val_measures, test_measure = runner.run(hyperparameters, train_set, test_set,val_set)

Running Stochastic Network Lasso...
Time Iteration: 0, Time: 57.97767424583435
train_loss: 0.6577476952285549, train_acc: 0.5537948619134183, train_FPR: 0.6912989030342532, train_TPR: 0.7050156738778632, train_BER: 0.4881416145781951
val_loss: 0.692087553862325, val_acc: 0.5553691243455817, val_FPR: 0.6259373819616941, val_TPR: 0.6137070936664537, val_BER: 0.4611151441476203
Time Iteration: 1, Time: 146.34398818016052
train_loss: 0.6073727821506688, train_acc: 0.6800426511376897, train_FPR: 0.7484763773083923, train_TPR: 0.8457846994415746, train_BER: 0.4463458389334088
val_loss: 0.6421083326073636, val_acc: 0.7023352322341845, val_FPR: 0.6191898394438065, val_TPR: 0.7820436522412268, val_BER: 0.3735730936012898
Time Iteration: 2, Time: 234.55174016952515
train_loss: 0.5873960883492978, train_acc: 0.6633228885411833, train_FPR: 0.784928727436968, train_TPR: 0.8543263585396735, train_BER: 0.4603011844486474
val_loss: 0.6057517776494826, val_acc: 0.6807568769109362, val_FPR: 0.6532986183

In [9]:
# Label the results by method: a zero rho is stored as "Local_SVM",
# any other value as "S-ADMM".
if hyperparameters['rho'] != 0:
    name = "S-ADMM"
else:
    name = "Local_SVM"

# Write the train and validation measures to per-dataset CSV files.
for measures, prefix in ((train_measures, "train_measures_"), (val_measures, "val_measures_")):
    save_global_measure(measures, prefix + dataset_name + ".csv", name)

Saved measures to Experimental_results/S-ADMM/train_measures_n-baiot.csv
Saved measures to Experimental_results/S-ADMM/val_measures_n-baiot.csv
