In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="5"
import shfl
import torch
import copy
import numpy as np
import tensorflow as tf

from sklearn.preprocessing import LabelBinarizer

from shfl.private import UnprotectedAccess
from CIT.model import CITModel
from utils import get_federated_data_csv
from ClassifierModel import ClassifierModel

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
#"../data/COVIDGR1.0/centralized/cropped"
#partition_iid_3nodes_1.csv
args = dict(
    # --- data / artifact locations ---
    data_path="../data/COVIDGR1.0-Segmentadas",
    csv_path="../partitions/partition_iid_3nodes_1.csv",
    output_path="../weights",
    input_path="",
    model_name="transferlearning.model",
    label_bin="lb.pickle",
    # --- federated training schedule ---
    batch_size=8,
    federated_rounds=1,
    epochs_per_FL_round=50,
    num_nodes=3,  # replaced by the value returned from get_federated_data_csv
    size_averaging=1,
    # --- augmentation settings (all 0/False here, presumably disabled) ---
    random_rotation=0,
    random_shift=0,
    random_zoom=0,
    horizontal_flip=False,
    # --- model options ---
    finetune=True,
    train_network=True,
)

# Use the GPU selected via CUDA_VISIBLE_DEVICES when available, else the CPU.
# To force CPU execution, set: device = 'cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Label vocabularies: binary diagnosis (N/P) and the four-way CIT label set.
a = ['N', 'P']
b = ['NTN', 'NTP', 'PTP', 'PTN']

# LabelBinarizer.fit() returns the fitted binarizer, so build and fit in one step.
lb1 = LabelBinarizer().fit(a)
lb2 = LabelBinarizer().fit(b)

# Map every label to an integer class index:
#  - four-way labels -> position of their one-hot column in lb2's output;
#  - binary labels   -> the single 0/1 value lb1 emits for a two-class problem.
# Key order (PTP, PTN, NTP, NTN, P, N) matches the original literal.
dict_labels = {
    **{name: np.argmax(lb2.transform([name])[0]) for name in ('PTP', 'PTN', 'NTP', 'NTN')},
    **{name: lb1.transform([name])[0][0] for name in ('P', 'N')},
}



In [4]:
def cit_builder():
    """Return a new CITModel; used as the model builder for shfl's FederatedGovernment.

    The model classifies the binary labels ['N', 'P'] with a "resnet18"
    classifier, one fold and a single lambda value (0.05). Batch size and
    epochs come from the global ``args`` config, and the device from ``device``.
    """
    classes = ['N', 'P']
    return CITModel(
        classes,
        classifier_name="resnet18",
        folds=1,
        lambda_values=[0.05],
        batch_size=args["batch_size"],
        epochs=args["epochs_per_FL_round"],
        device=device,
    )

def classifier_builder( G_dict ):
    """Return a new ClassifierModel built on top of the given CIT ``G_dict``.

    Args:
        G_dict: the ``_G_dict`` taken from the trained CIT global model
            (presumably its generator networks; confirm in CITModel).

    Label mapping comes from the global ``dict_labels``. Batch size, epochs and
    the fine-tuning flag come from the global ``args`` config.
    """
    return ClassifierModel(
        G_dict,
        dict_labels,
        batch_size=args["batch_size"],
        epochs=args["epochs_per_FL_round"],
        finetune=args["finetune"],
    )

def get_transformed_data(federated_data, cit_federated_government, lb1, lb2):
    """Return a deep copy of ``federated_data`` whose node data went through the CIT model.

    For every node, the node's data and labels are passed to
    ``cit_federated_government.global_model.transform_data(data, labels, lb1, lb2)``.
    The returned ``(t_data, t_labels)`` pair replaces the data and labels of the
    corresponding node in the copy. ``federated_data`` itself is not modified.

    Args:
        federated_data: federated dataset exposing ``num_nodes()`` and
            ``__getitem__`` that returns nodes with a ``query()`` method.
        cit_federated_government: government whose ``global_model`` provides
            ``transform_data``.
        lb1, lb2: label binarizers forwarded unchanged to ``transform_data``.

    Returns:
        The transformed deep copy of ``federated_data``.

    NOTE(review): writing back relies on ``query()`` returning the node's
    underlying data object, so that attribute assignment mutates the node.
    This is the case for the UnprotectedAccess configured in this notebook;
    re-check it before using another access policy.
    """
    t_federated_data = copy.deepcopy(federated_data)
    cit_model = cit_federated_government.global_model

    for i in range(federated_data.num_nodes()):
        # Query each node once, so data and labels come from the same access
        # (a second query could differ under a non-pass-through access policy).
        source = federated_data[i].query()
        t_data, t_labels = cit_model.transform_data(source._data, source._label, lb1, lb2)

        target = t_federated_data[i].query()
        target._data = t_data
        target._label = t_labels

    return t_federated_data

In [5]:
print("[INFO] Fetching federated data...")
# Load the partitioned dataset. The number of nodes comes from the partition
# CSV and overwrites the value in args.
(
    federated_data,
    train_data,
    train_label,
    test_data,
    test_label,
    train_files,
    test_files,
    args["num_nodes"],
) = get_federated_data_csv(args["data_path"], args["csv_path"], lb1)
# No privacy mechanism on node access (presumably queries return raw node data).
federated_data.configure_data_access(UnprotectedAccess())
print(len(train_data), len(test_data), sep="\n")
print("[INFO] done")

[INFO] Fetching federated data...
681
171
[INFO] done


In [6]:
# Stage 1: train the CIT model on every node and combine the node models with
# federated averaging, evaluating on the held-out test split.
aggregator = shfl.federated_aggregator.FedAvgAggregator()
cit_federated_government = shfl.federated_government.FederatedGovernment(
    cit_builder,
    federated_data,
    aggregator,
)
cit_federated_government.run_rounds(
    args["federated_rounds"],
    test_data,
    test_label,
)

Accuracy round 0


  0%|          | 0/23 [00:00<?, ?it/s]

Training node 0
[INFO] weights = [1.         0.93965517]


[Validating]: Acc_D: 0.2609: 100%|██████████| 23/23 [00:00<00:00, 28.40it/s]


[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.2609: 100%|██████████| 23/23 [00:00<00:00, 28.64it/s]
[1/50] Loss_D: 0.4433 Acc_D: 0.5470 Loss_G_class1: 0.1281 Loss_G_class2: 0.1363: 100%|██████████| 26/26 [00:36<00:00,  1.41s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:00<00:00, 34.22it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.7412493876788927


[2/50] Loss_D: 0.2980 Acc_D: 0.7178 Loss_G_class1: 0.0262 Loss_G_class2: 0.0228: 100%|██████████| 26/26 [00:37<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:02<00:00,  8.05it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.9836041266503541
EarlyStopping counter: 1 out of 10


[3/50] Loss_D: 0.2813 Acc_D: 0.7178 Loss_G_class1: 0.0203 Loss_G_class2: 0.0169: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:02<00:00,  8.59it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.6923403221627941


[4/50] Loss_D: 0.2857 Acc_D: 0.7129 Loss_G_class1: 0.0198 Loss_G_class2: 0.0170: 100%|██████████| 26/26 [00:38<00:00,  1.47s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:02<00:00,  8.77it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.9339130488426789
EarlyStopping counter: 1 out of 10


[5/50] Loss_D: 0.2675 Acc_D: 0.7450 Loss_G_class1: 0.0196 Loss_G_class2: 0.0141: 100%|██████████| 26/26 [00:37<00:00,  1.43s/it]
[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 32.48it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.5217391304347826
Valid Loss = 1.5391942844928608
EarlyStopping counter: 2 out of 10


[6/50] Loss_D: 0.2537 Acc_D: 0.7252 Loss_G_class1: 0.0102 Loss_G_class2: 0.0207: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7826: 100%|██████████| 23/23 [00:02<00:00,  7.92it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.782608695652174
Valid Loss = 0.6894416679506716


[7/50] Loss_D: 0.2062 Acc_D: 0.8144 Loss_G_class1: 0.0126 Loss_G_class2: 0.0127: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:01<00:00, 14.92it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.8978462504303973
EarlyStopping counter: 1 out of 10


[8/50] Loss_D: 0.1773 Acc_D: 0.8614 Loss_G_class1: 0.0129 Loss_G_class2: 0.0107: 100%|██████████| 26/26 [00:37<00:00,  1.43s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:03<00:00,  6.49it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 1.0729294420908326
EarlyStopping counter: 2 out of 10


[9/50] Loss_D: 0.1806 Acc_D: 0.8366 Loss_G_class1: 0.0125 Loss_G_class2: 0.0114: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.4348: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.43478260869565216
Valid Loss = 1.1331503903088362
EarlyStopping counter: 3 out of 10


[10/50] Loss_D: 0.1699 Acc_D: 0.8416 Loss_G_class1: 0.0136 Loss_G_class2: 0.0107: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:02<00:00, 10.32it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.873130450754062
EarlyStopping counter: 4 out of 10


[11/50] Loss_D: 0.1365 Acc_D: 0.8861 Loss_G_class1: 0.0099 Loss_G_class2: 0.0089: 100%|██████████| 26/26 [00:36<00:00,  1.41s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:01<00:00, 17.80it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 1.302487801598466
EarlyStopping counter: 5 out of 10


[12/50] Loss_D: 0.1460 Acc_D: 0.8688 Loss_G_class1: 0.0103 Loss_G_class2: 0.0098: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 31.24it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.9600926832012509
EarlyStopping counter: 6 out of 10


[13/50] Loss_D: 0.1296 Acc_D: 0.8985 Loss_G_class1: 0.0095 Loss_G_class2: 0.0093: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.4783: 100%|██████████| 23/23 [00:00<00:00, 32.21it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.4782608695652174
Valid Loss = 1.2366683487982855
EarlyStopping counter: 7 out of 10


[14/50] Loss_D: 0.1212 Acc_D: 0.9035 Loss_G_class1: 0.0088 Loss_G_class2: 0.0094: 100%|██████████| 26/26 [00:36<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:02<00:00, 10.65it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.998969194681748
EarlyStopping counter: 8 out of 10


[15/50] Loss_D: 0.0936 Acc_D: 0.9381 Loss_G_class1: 0.0073 Loss_G_class2: 0.0073: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:01<00:00, 11.91it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 1.4431310088738152
EarlyStopping counter: 9 out of 10


[16/50] Loss_D: 0.1244 Acc_D: 0.9010 Loss_G_class1: 0.0104 Loss_G_class2: 0.0080: 100%|██████████| 26/26 [00:36<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 34.40it/s]
  0%|          | 0/23 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 1.6001119107493889
EarlyStopping counter: 10 out of 10
Early stopping, epoch 16


[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 34.74it/s]
  0%|          | 0/23 [00:00<?, ?it/s]

[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.6086956521739131
Valid Loss = 0.6995087175265603
Training node 1
[INFO] weights = [0.83471074 1.        ]


[Validating]: Acc_D: 0.3913: 100%|██████████| 23/23 [00:00<00:00, 35.48it/s]


[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.3913: 100%|██████████| 23/23 [00:01<00:00, 17.01it/s]
[1/50] Loss_D: 0.4482 Acc_D: 0.5779 Loss_G_class1: 0.1172 Loss_G_class2: 0.1459: 100%|██████████| 25/25 [00:35<00:00,  1.44s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 33.60it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.8111272778200067


[2/50] Loss_D: 0.2831 Acc_D: 0.7161 Loss_G_class1: 0.0223 Loss_G_class2: 0.0269: 100%|██████████| 25/25 [00:35<00:00,  1.44s/it]
[Validating]: Acc_D: 0.8261: 100%|██████████| 23/23 [00:01<00:00, 13.86it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.8260869565217391
Valid Loss = 0.45015015258737234


[3/50] Loss_D: 0.2495 Acc_D: 0.7613 Loss_G_class1: 0.0179 Loss_G_class2: 0.0177: 100%|██████████| 25/25 [00:39<00:00,  1.60s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:01<00:00, 14.97it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.6683615439611933
EarlyStopping counter: 1 out of 10


[4/50] Loss_D: 0.2196 Acc_D: 0.7990 Loss_G_class1: 0.0151 Loss_G_class2: 0.0150: 100%|██████████| 25/25 [00:36<00:00,  1.45s/it]
[Validating]: Acc_D: 0.8261: 100%|██████████| 23/23 [00:00<00:00, 32.72it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.8260869565217391
Valid Loss = 0.541494655220405
EarlyStopping counter: 2 out of 10


[5/50] Loss_D: 0.1995 Acc_D: 0.7940 Loss_G_class1: 0.0128 Loss_G_class2: 0.0141: 100%|██████████| 25/25 [00:35<00:00,  1.43s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 33.50it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.8837818412677102
EarlyStopping counter: 3 out of 10


[6/50] Loss_D: 0.1644 Acc_D: 0.8618 Loss_G_class1: 0.0144 Loss_G_class2: 0.0120: 100%|██████████| 25/25 [00:35<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7826: 100%|██████████| 23/23 [00:01<00:00, 14.36it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.782608695652174
Valid Loss = 0.6207274286197904
EarlyStopping counter: 4 out of 10


[7/50] Loss_D: 0.1184 Acc_D: 0.8920 Loss_G_class1: 0.0089 Loss_G_class2: 0.0107: 100%|██████████| 25/25 [00:37<00:00,  1.49s/it]
[Validating]: Acc_D: 0.7826: 100%|██████████| 23/23 [00:00<00:00, 33.42it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.782608695652174
Valid Loss = 0.43684717729363753


[8/50] Loss_D: 0.1218 Acc_D: 0.8995 Loss_G_class1: 0.0091 Loss_G_class2: 0.0098: 100%|██████████| 25/25 [00:37<00:00,  1.50s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 31.77it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 1.2815020822960397
EarlyStopping counter: 1 out of 10


[9/50] Loss_D: 0.1141 Acc_D: 0.8995 Loss_G_class1: 0.0101 Loss_G_class2: 0.0087: 100%|██████████| 25/25 [00:36<00:00,  1.47s/it]
[Validating]: Acc_D: 0.8261: 100%|██████████| 23/23 [00:00<00:00, 33.69it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.8260869565217391
Valid Loss = 0.6781988402261682
EarlyStopping counter: 2 out of 10


[10/50] Loss_D: 0.0912 Acc_D: 0.9121 Loss_G_class1: 0.0076 Loss_G_class2: 0.0084: 100%|██████████| 25/25 [00:35<00:00,  1.44s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:03<00:00,  7.57it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.7863435639959314
EarlyStopping counter: 3 out of 10


[11/50] Loss_D: 0.0835 Acc_D: 0.9372 Loss_G_class1: 0.0088 Loss_G_class2: 0.0064: 100%|██████████| 25/25 [00:37<00:00,  1.51s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 32.29it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 1.2101526058724392
EarlyStopping counter: 4 out of 10


[12/50] Loss_D: 0.0655 Acc_D: 0.9523 Loss_G_class1: 0.0072 Loss_G_class2: 0.0058: 100%|██████████| 25/25 [00:36<00:00,  1.45s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 33.44it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.7641653081847598
EarlyStopping counter: 5 out of 10


[13/50] Loss_D: 0.0831 Acc_D: 0.9347 Loss_G_class1: 0.0075 Loss_G_class2: 0.0073: 100%|██████████| 25/25 [00:36<00:00,  1.46s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:03<00:00,  7.13it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.876377747013264
EarlyStopping counter: 6 out of 10


[14/50] Loss_D: 0.0655 Acc_D: 0.9523 Loss_G_class1: 0.0065 Loss_G_class2: 0.0054: 100%|██████████| 25/25 [00:36<00:00,  1.47s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:02<00:00, 10.43it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 1.0602057983500266
EarlyStopping counter: 7 out of 10


[15/50] Loss_D: 0.0908 Acc_D: 0.9171 Loss_G_class1: 0.0073 Loss_G_class2: 0.0079: 100%|██████████| 25/25 [00:35<00:00,  1.43s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:01<00:00, 15.45it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.6898005203060482
EarlyStopping counter: 8 out of 10


[16/50] Loss_D: 0.0732 Acc_D: 0.9497 Loss_G_class1: 0.0061 Loss_G_class2: 0.0072: 100%|██████████| 25/25 [00:35<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:02<00:00,  9.82it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.9097748606297461
EarlyStopping counter: 9 out of 10


[17/50] Loss_D: 0.0694 Acc_D: 0.9548 Loss_G_class1: 0.0058 Loss_G_class2: 0.0053: 100%|██████████| 25/25 [00:35<00:00,  1.43s/it]
[Validating]: Acc_D: 0.8261: 100%|██████████| 23/23 [00:00<00:00, 33.19it/s]
  0%|          | 0/23 [00:00<?, ?it/s]


Valid Acc = 0.8260869565217391
Valid Loss = 0.805582619920049
EarlyStopping counter: 10 out of 10
Early stopping, epoch 17


[Validating]: Acc_D: 0.7826: 100%|██████████| 23/23 [00:00<00:00, 31.94it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.782608695652174
Valid Loss = 0.8547412001568339
Training node 2
[INFO] weights = [1.       0.828125]


[Validating]: Acc_D: 0.4583: 100%|██████████| 24/24 [00:00<00:00, 34.92it/s]


[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.4583: 100%|██████████| 24/24 [00:00<00:00, 27.22it/s]
[1/50] Loss_D: 0.4319 Acc_D: 0.6119 Loss_G_class1: 0.1240 Loss_G_class2: 0.1389: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:02<00:00, 10.10it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.5936628505587578


[2/50] Loss_D: 0.2960 Acc_D: 0.6881 Loss_G_class1: 0.0262 Loss_G_class2: 0.0237: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:05<00:00,  4.27it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.5385719755043586


[3/50] Loss_D: 0.2473 Acc_D: 0.7667 Loss_G_class1: 0.0176 Loss_G_class2: 0.0169: 100%|██████████| 27/27 [00:40<00:00,  1.51s/it]
[Validating]: Acc_D: 0.6250: 100%|██████████| 24/24 [00:02<00:00, 11.80it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.625
Valid Loss = 0.6689012253967425
EarlyStopping counter: 1 out of 10


[4/50] Loss_D: 0.2504 Acc_D: 0.7571 Loss_G_class1: 0.0184 Loss_G_class2: 0.0152: 100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 35.70it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.5888022153327862
EarlyStopping counter: 2 out of 10


[5/50] Loss_D: 0.2204 Acc_D: 0.8000 Loss_G_class1: 0.0167 Loss_G_class2: 0.0139: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6667: 100%|██████████| 24/24 [00:00<00:00, 34.94it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.6666666666666666
Valid Loss = 0.8542049409200748
EarlyStopping counter: 3 out of 10


[6/50] Loss_D: 0.1890 Acc_D: 0.8429 Loss_G_class1: 0.0127 Loss_G_class2: 0.0134: 100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:01<00:00, 16.15it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.4777667298913002


[7/50] Loss_D: 0.1737 Acc_D: 0.8405 Loss_G_class1: 0.0123 Loss_G_class2: 0.0118: 100%|██████████| 27/27 [00:40<00:00,  1.50s/it]
[Validating]: Acc_D: 0.8333: 100%|██████████| 24/24 [00:01<00:00, 14.52it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.8333333333333334
Valid Loss = 0.6290932726114988
EarlyStopping counter: 1 out of 10


[8/50] Loss_D: 0.1525 Acc_D: 0.8595 Loss_G_class1: 0.0108 Loss_G_class2: 0.0107: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:01<00:00, 15.41it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.736982688618203
EarlyStopping counter: 2 out of 10


[9/50] Loss_D: 0.1573 Acc_D: 0.8595 Loss_G_class1: 0.0111 Loss_G_class2: 0.0119: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:00<00:00, 34.50it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.4323978318522374


[10/50] Loss_D: 0.1555 Acc_D: 0.8738 Loss_G_class1: 0.0118 Loss_G_class2: 0.0101: 100%|██████████| 27/27 [00:39<00:00,  1.47s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:00<00:00, 31.54it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.8123262313505014
EarlyStopping counter: 1 out of 10


[11/50] Loss_D: 0.1368 Acc_D: 0.8881 Loss_G_class1: 0.0113 Loss_G_class2: 0.0093: 100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
[Validating]: Acc_D: 0.8750: 100%|██████████| 24/24 [00:01<00:00, 14.00it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.875
Valid Loss = 0.4418488061055541
EarlyStopping counter: 2 out of 10


[12/50] Loss_D: 0.1444 Acc_D: 0.8810 Loss_G_class1: 0.0103 Loss_G_class2: 0.0106: 100%|██████████| 27/27 [00:38<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 31.99it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.4766076131102939
EarlyStopping counter: 3 out of 10


[13/50] Loss_D: 0.1229 Acc_D: 0.9048 Loss_G_class1: 0.0083 Loss_G_class2: 0.0091: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 32.92it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.5894755534827709
EarlyStopping counter: 4 out of 10


[14/50] Loss_D: 0.1307 Acc_D: 0.9048 Loss_G_class1: 0.0097 Loss_G_class2: 0.0087: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 34.70it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.6953975247840086
EarlyStopping counter: 5 out of 10


[15/50] Loss_D: 0.1481 Acc_D: 0.8810 Loss_G_class1: 0.0092 Loss_G_class2: 0.0106: 100%|██████████| 27/27 [00:38<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:01<00:00, 14.66it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.5615521954993407
EarlyStopping counter: 6 out of 10


[16/50] Loss_D: 0.1316 Acc_D: 0.8929 Loss_G_class1: 0.0087 Loss_G_class2: 0.0089: 100%|██████████| 27/27 [00:42<00:00,  1.56s/it]
[Validating]: Acc_D: 0.8333: 100%|██████████| 24/24 [00:01<00:00, 12.58it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.8333333333333334
Valid Loss = 0.5624654876689116
EarlyStopping counter: 7 out of 10


[17/50] Loss_D: 0.1344 Acc_D: 0.8833 Loss_G_class1: 0.0090 Loss_G_class2: 0.0109: 100%|██████████| 27/27 [00:38<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:00<00:00, 33.88it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.7692466483761867
EarlyStopping counter: 8 out of 10


[18/50] Loss_D: 0.1320 Acc_D: 0.9167 Loss_G_class1: 0.0093 Loss_G_class2: 0.0094: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6667: 100%|██████████| 24/24 [00:00<00:00, 32.95it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.6666666666666666
Valid Loss = 0.8964312796015292
EarlyStopping counter: 9 out of 10


[19/50] Loss_D: 0.1372 Acc_D: 0.8857 Loss_G_class1: 0.0097 Loss_G_class2: 0.0100: 100%|██████████| 27/27 [00:38<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 33.52it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.36004532935718697


[20/50] Loss_D: 0.1264 Acc_D: 0.9119 Loss_G_class1: 0.0082 Loss_G_class2: 0.0095: 100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:02<00:00,  9.64it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.4631953189770381
EarlyStopping counter: 1 out of 10


[21/50] Loss_D: 0.1310 Acc_D: 0.8952 Loss_G_class1: 0.0076 Loss_G_class2: 0.0101: 100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
[Validating]: Acc_D: 0.8333: 100%|██████████| 24/24 [00:01<00:00, 18.19it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.8333333333333334
Valid Loss = 0.5737002919583271
EarlyStopping counter: 2 out of 10


[22/50] Loss_D: 0.1281 Acc_D: 0.9167 Loss_G_class1: 0.0085 Loss_G_class2: 0.0085: 100%|██████████| 27/27 [00:39<00:00,  1.46s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:00<00:00, 35.13it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.5875272379877666
EarlyStopping counter: 3 out of 10


[23/50] Loss_D: 0.1091 Acc_D: 0.9238 Loss_G_class1: 0.0070 Loss_G_class2: 0.0077: 100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6667: 100%|██████████| 24/24 [00:00<00:00, 32.57it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.6666666666666666
Valid Loss = 0.6441383448739847
EarlyStopping counter: 4 out of 10


[24/50] Loss_D: 0.1309 Acc_D: 0.8786 Loss_G_class1: 0.0080 Loss_G_class2: 0.0095: 100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:01<00:00, 16.28it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.6129590994678438
EarlyStopping counter: 5 out of 10


[25/50] Loss_D: 0.1237 Acc_D: 0.8857 Loss_G_class1: 0.0087 Loss_G_class2: 0.0085: 100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:01<00:00, 12.58it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.674946780782193
EarlyStopping counter: 6 out of 10


[26/50] Loss_D: 0.1390 Acc_D: 0.8857 Loss_G_class1: 0.0100 Loss_G_class2: 0.0092: 100%|██████████| 27/27 [00:38<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7917: 100%|██████████| 24/24 [00:02<00:00, 11.16it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.7916666666666666
Valid Loss = 0.5358523917384446
EarlyStopping counter: 7 out of 10


[27/50] Loss_D: 0.1406 Acc_D: 0.8786 Loss_G_class1: 0.0117 Loss_G_class2: 0.0090: 100%|██████████| 27/27 [00:38<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7500: 100%|██████████| 24/24 [00:00<00:00, 34.91it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.75
Valid Loss = 0.5089286339934915
EarlyStopping counter: 8 out of 10


[28/50] Loss_D: 0.1283 Acc_D: 0.8976 Loss_G_class1: 0.0095 Loss_G_class2: 0.0092: 100%|██████████| 27/27 [00:38<00:00,  1.43s/it]
[Validating]: Acc_D: 0.8750: 100%|██████████| 24/24 [00:00<00:00, 32.74it/s]
  0%|          | 0/27 [00:00<?, ?it/s]


Valid Acc = 0.875
Valid Loss = 0.4670220084177951
EarlyStopping counter: 9 out of 10


[29/50] Loss_D: 0.1323 Acc_D: 0.8905 Loss_G_class1: 0.0086 Loss_G_class2: 0.0083: 100%|██████████| 27/27 [00:40<00:00,  1.51s/it]
[Validating]: Acc_D: 0.8333: 100%|██████████| 24/24 [00:03<00:00,  7.72it/s]
  0%|          | 0/24 [00:00<?, ?it/s]


Valid Acc = 0.8333333333333334
Valid Loss = 0.6151562499192854
EarlyStopping counter: 10 out of 10
Early stopping, epoch 29


[Validating]: Acc_D: 0.8750: 100%|██████████| 24/24 [00:00<00:00, 31.69it/s]


[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.875
Valid Loss = 0.505636791077753
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f03425a99e8>: [0.7426223094351807, 0.6608187134502924, {'0': {'precision': 0.9444444444444444, 'recall': 0.6159420289855072, 'f1-score': 0.7456140350877193, 'support': 138}, '1': {'precision': 0.345679012345679, 'recall': 0.8484848484848485, 'f1-score': 0.4912280701754386, 'support': 33}, 'accuracy': 0.6608187134502924, 'macro avg': {'precision': 0.6450617283950617, 'recall': 0.7322134387351779, 'f1-score': 0.618421052631579, 'support': 171}, 'weighted avg': {'precision': 0.8288932207060863, 'recall': 0.6608187134502924, 'f1-score': 0.6965220067713143, 'support': 171}}]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f03425a9710>: [1.1105456775280962, 0.7543859649122807, {'0': {'precision': 0.9666666666666667, 'recall': 0.6904761904761905, 'f1-sc

In [11]:
# Evaluate the aggregated CIT global model on the test split.
# NOTE(review): evaluate() appears to return [loss, accuracy, per-class report];
# confirm in CITModel.
metrics = cit_federated_government.global_model.evaluate(test_data, test_label)
print("CIT Classifier Results:")
print(f"Loss: {metrics[0]}")
print(f"Acc: {metrics[1]}")
print(metrics[2])

CIT Classifier Results:
Loss: 0.7426223080136884
Acc: 0.6608187134502924
{'0': {'precision': 0.9444444444444444, 'recall': 0.6159420289855072, 'f1-score': 0.7456140350877193, 'support': 138}, '1': {'precision': 0.345679012345679, 'recall': 0.8484848484848485, 'f1-score': 0.4912280701754386, 'support': 33}, 'accuracy': 0.6608187134502924, 'macro avg': {'precision': 0.6450617283950617, 'recall': 0.7322134387351779, 'f1-score': 0.618421052631579, 'support': 171}, 'weighted avg': {'precision': 0.8288932207060863, 'recall': 0.6608187134502924, 'f1-score': 0.6965220067713143, 'support': 171}}


In [12]:
# Pass every node's data through the trained CIT global model. The result is
# a transformed deep copy; federated_data itself is left intact.
t_federated_data = get_transformed_data(
    federated_data,
    cit_federated_government,
    lb1=lb1,
    lb2=lb2,
)

In [13]:
# Stage 2: federated training of the downstream classifier on the
# CIT-transformed node data, again with federated averaging.
aggregator = shfl.federated_aggregator.FedAvgAggregator()
# Taken from the trained CIT global model and handed to every ClassifierModel
# (presumably the CIT generators; confirm in CITModel).
G_dict = cit_federated_government.global_model._G_dict


def classifier_model_builder():
    """Zero-argument model builder for FederatedGovernment.

    Reads the global ``G_dict`` at call time, as the former lambda did.
    """
    return classifier_builder(G_dict)


# Backward-compatible alias for the name the builder previously had.
x = classifier_model_builder

classifier_federated_government = shfl.federated_government.FederatedGovernment(
    classifier_model_builder, t_federated_data, aggregator
)
classifier_federated_government.run_rounds(args["federated_rounds"], test_data, test_label)

# NOTE(review): saving the trained model is currently disabled. The earlier
# commented-out snippet referenced an undefined `federated_government`; use
# `classifier_federated_government` if saving is re-enabled.
print("[INFO] done")

Accuracy round 0
Training node 0
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50


Epoch 00049: early stopping
Training node 1
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 00026: early stopping
Training node 2
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50


Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 00045: early stopping
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f03f844a588>: [0.6491228070175439, 0.6374269005847953, 50, {'0': {'precision': 0.7777777777777778, 'recall': 0.4666666666666667, 'f1-score': 0.5833333333333334, 'support': 90}, '1': {'precision': 0.5897435897435898, 'recall': 0.8518518518518519, 'f1-score': 0.6969696969696971, 'support': 81}, 'accuracy': 0.6491228070175439, 'macro avg': {'precision': 0.6837606837606838, 'recall': 0.6592592592592592, 'f1-score': 0.6401515151515152, 'support': 171}, 'weighted avg': {'precision': 0.6887089518668467, 'recall': 0.6491228070175439, 'f1-score': 0.6371610845295057, 'support': 171}}]
Test performance client <

In [10]:
# Evaluate the aggregated SDNET classifier on the test split.
# NOTE(review): evaluate() appears to return [acc, acc_4, mismatch count, report].
# "No concuerda" is Spanish for "does not match", presumably the number of
# test samples whose predictions disagree; confirm in ClassifierModel.
metrics = classifier_federated_government.global_model.evaluate(test_data, test_label)
print("SDNET Classifier Results:")
print(f"Acc: {metrics[0]}")
print(f"Acc_4: {metrics[1]}")
print(f"No concuerda: {metrics[2]}")
print(metrics[3])

SDNET Classifier Results:
Acc: 0.7543859649122807
Acc_4: 0.7660818713450293
No concuerda: 20
{'0': {'precision': 0.7264150943396226, 'recall': 0.8555555555555555, 'f1-score': 0.7857142857142856, 'support': 90}, '1': {'precision': 0.8, 'recall': 0.6419753086419753, 'f1-score': 0.7123287671232877, 'support': 81}, 'accuracy': 0.7543859649122807, 'macro avg': {'precision': 0.7632075471698113, 'recall': 0.7487654320987653, 'f1-score': 0.7490215264187867, 'support': 171}, 'weighted avg': {'precision': 0.761271102284012, 'recall': 0.7543859649122807, 'f1-score': 0.7509527242764447, 'support': 171}}
