In [1]:
import os
# These two variables are set before the `shfl` / `torch` imports that follow in
# this cell, so they are already in the environment when those libraries load.
# NOTE(review): presumably this pins the notebook to the GPU with index 1 in
# PCI bus order — confirm the index is right for the machine it runs on.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import shfl
import torch
import copy
import cv2
import numpy as np

from sklearn.preprocessing import LabelBinarizer

from shfl.private import UnprotectedAccess
from CIT.model import CITModel
from utils import get_federated_data_csv, get_data_csv
from ClassifierModel import ClassifierModel

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Run configuration shared by the cells below.
# In the visible cells only "data_path", "csv_path", "batch_size",
# "federated_rounds" and "epochs_per_FL_round" are read; "num_nodes" is
# overwritten with the value returned by get_federated_data_csv.
args = {
    # Locations of the dataset, the partition file and the weight files.
    "data_path": "../data/COVIDGR1.0-Segmentadas",
    "csv_path": "../partitions/partition_iid_3nodes_1.csv",
    "output_path": "../weights",
    "input_path": "",
    "model_name": "transferlearning.model",
    "label_bin": "lb.pickle",
    # Training schedule.
    "batch_size": 8,
    "federated_rounds": 3,
    "epochs_per_FL_round": 20,
    "num_nodes": 3,
    "size_averaging": 1,
    # Augmentation switches (all disabled here).
    "random_rotation": 0,
    "random_shift": 0,
    "random_zoom": 0,
    "horizontal_flip": False,
    # Mode flags.
    "finetune": True,
    "train_network": True,
}

# Use the GPU when torch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Label encoders: lb1 is fitted on the two-class labels, lb2 on the four-class
# labels. Only lb1 is passed to the data loader below.
a = ['N', 'P']
b = ['NTN', 'NTP', 'PTP', 'PTN']
lb1 = LabelBinarizer().fit(a)
lb2 = LabelBinarizer().fit(b)

print("[INFO] Fetching federated data...")
# The loader also returns the number of nodes, which replaces the value that
# was configured in args["num_nodes"].
(
    federated_data,
    train_data,
    train_label,
    test_data,
    test_label,
    train_files,
    test_files,
    args["num_nodes"],
) = get_federated_data_csv(args["data_path"], args["csv_path"], lb1)
federated_data.configure_data_access(UnprotectedAccess())

# Sizes of the train / test sets and the node count, one value per line.
print(len(train_data), len(test_data), args["num_nodes"], sep="\n")

# Number of samples held by each federated node.
for i in range(federated_data.num_nodes()):
    data = federated_data[i].query()._data
    print(len(data))

print("[INFO] done")

[INFO] Fetching federated data...
681
171
3
225
222
234
[INFO] done


In [4]:
def cit_builder():
    """Build and return a new CITModel for the classes 'N' and 'P'.

    Batch size and epoch count are read from the module-level ``args`` dict
    ("batch_size" and "epochs_per_FL_round"); the model is created for the
    module-level ``device``. Every call constructs a separate instance.
    """
    model_kwargs = dict(
        classifier_name="resnet18",
        folds=1,
        lambda_values=[0.05],
        batch_size=args["batch_size"],
        epochs=args["epochs_per_FL_round"],
        device=device,
    )
    return CITModel(['N', 'P'], **model_kwargs)


In [None]:
# Federated training: models come from cit_builder, client parameters are
# combined by a FedAvgAggregator, and the held-out test split is passed to
# run_rounds.
aggregator = shfl.federated_aggregator.FedAvgAggregator()
cit_federated_government = shfl.federated_government.FederatedGovernment(
    cit_builder,
    federated_data,
    aggregator,
)
cit_federated_government.run_rounds(
    args["federated_rounds"],
    test_data,
    test_label,
)


Accuracy round 0


  0%|          | 0/23 [00:00<?, ?it/s]

Training node 0
[INFO] weights = [1.         0.93965517]


[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 26.71it/s]


[INFO] Initial Valid Scores: 
Valid Acc = 0.5217391304347826
Valid Loss = 0.8046197476594344
[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 30.67it/s]
[1/20] Loss_D: 0.4642 Acc_D: 0.5668 Loss_G_class1: 0.1416 Loss_G_class2: 0.1560: 100%|██████████| 26/26 [00:36<00:00,  1.40s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.45it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.7492736487284951


[2/20] Loss_D: 0.3028 Acc_D: 0.7005 Loss_G_class1: 0.0284 Loss_G_class2: 0.0293: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.44it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.8185062654640364
EarlyStopping counter: 1 out of 10


[3/20] Loss_D: 0.2760 Acc_D: 0.7228 Loss_G_class1: 0.0186 Loss_G_class2: 0.0197: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:00<00:00, 30.72it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 1.5514604280824247
EarlyStopping counter: 2 out of 10


[4/20] Loss_D: 0.2672 Acc_D: 0.7450 Loss_G_class1: 0.0202 Loss_G_class2: 0.0185: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.47it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 1.0504046704458154
EarlyStopping counter: 3 out of 10


[5/20] Loss_D: 0.2718 Acc_D: 0.7673 Loss_G_class1: 0.0171 Loss_G_class2: 0.0177: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.24it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.9506891628970271
EarlyStopping counter: 4 out of 10


[6/20] Loss_D: 0.2195 Acc_D: 0.8119 Loss_G_class1: 0.0127 Loss_G_class2: 0.0157: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 30.44it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.6445851491197295


[7/20] Loss_D: 0.1853 Acc_D: 0.8441 Loss_G_class1: 0.0116 Loss_G_class2: 0.0119: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:00<00:00, 32.41it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.6841571235462375
EarlyStopping counter: 1 out of 10


[8/20] Loss_D: 0.1772 Acc_D: 0.8416 Loss_G_class1: 0.0112 Loss_G_class2: 0.0106: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 29.55it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.7249868287016517
EarlyStopping counter: 2 out of 10


[9/20] Loss_D: 0.1802 Acc_D: 0.8639 Loss_G_class1: 0.0115 Loss_G_class2: 0.0106: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 33.01it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.8092112099022969
EarlyStopping counter: 3 out of 10


[10/20] Loss_D: 0.1770 Acc_D: 0.8416 Loss_G_class1: 0.0112 Loss_G_class2: 0.0116: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:00<00:00, 31.06it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 0.7678570819451757
EarlyStopping counter: 4 out of 10


[11/20] Loss_D: 0.1558 Acc_D: 0.8738 Loss_G_class1: 0.0100 Loss_G_class2: 0.0091: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 31.48it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.8993360697413268
EarlyStopping counter: 5 out of 10


[12/20] Loss_D: 0.1410 Acc_D: 0.9010 Loss_G_class1: 0.0095 Loss_G_class2: 0.0079: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 31.20it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.7223509147601284
EarlyStopping counter: 6 out of 10


[13/20] Loss_D: 0.1459 Acc_D: 0.8738 Loss_G_class1: 0.0089 Loss_G_class2: 0.0089: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 32.96it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.9562223142255908
EarlyStopping counter: 7 out of 10


[14/20] Loss_D: 0.1374 Acc_D: 0.8960 Loss_G_class1: 0.0087 Loss_G_class2: 0.0087: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.975662217353997
EarlyStopping counter: 8 out of 10


[15/20] Loss_D: 0.1538 Acc_D: 0.8713 Loss_G_class1: 0.0095 Loss_G_class2: 0.0095: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 30.66it/s]
  0%|          | 0/26 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.9678013473344238
EarlyStopping counter: 9 out of 10


[16/20] Loss_D: 0.1478 Acc_D: 0.8812 Loss_G_class1: 0.0093 Loss_G_class2: 0.0083: 100%|██████████| 26/26 [00:36<00:00,  1.39s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 23/23 [00:00<00:00, 32.19it/s]
  0%|          | 0/23 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.7497970727474793
EarlyStopping counter: 10 out of 10
Early stopping, epoch 16


[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 34.95it/s]
  0%|          | 0/23 [00:00<?, ?it/s]

[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.6521739130434783
Valid Loss = 0.793949225836474
Training node 1
[INFO] weights = [0.83471074 1.        ]


[Validating]: Acc_D: 0.4783: 100%|██████████| 23/23 [00:00<00:00, 33.19it/s]


[INFO] Initial Valid Scores: 
Valid Acc = 0.4782608695652174
Valid Loss = 0.8390094132527061
[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.4783: 100%|██████████| 23/23 [00:00<00:00, 28.80it/s]
[1/20] Loss_D: 0.4340 Acc_D: 0.5829 Loss_G_class1: 0.1428 Loss_G_class2: 0.1533: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:00<00:00, 34.61it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 0.8746322587780331


[2/20] Loss_D: 0.2663 Acc_D: 0.7412 Loss_G_class1: 0.0283 Loss_G_class2: 0.0292: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 35.11it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.7831738377394883


[3/20] Loss_D: 0.2686 Acc_D: 0.7462 Loss_G_class1: 0.0192 Loss_G_class2: 0.0204: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 30.99it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5217391304347826
Valid Loss = 0.718564417699109


[4/20] Loss_D: 0.2227 Acc_D: 0.8040 Loss_G_class1: 0.0157 Loss_G_class2: 0.0170: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 33.99it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5217391304347826
Valid Loss = 1.2189899682998657
EarlyStopping counter: 1 out of 10


[5/20] Loss_D: 0.2108 Acc_D: 0.8191 Loss_G_class1: 0.0163 Loss_G_class2: 0.0145: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:00<00:00, 34.10it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 0.8167724483039068
EarlyStopping counter: 2 out of 10


[6/20] Loss_D: 0.1595 Acc_D: 0.8869 Loss_G_class1: 0.0110 Loss_G_class2: 0.0115: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.67it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.6410856519056403


[7/20] Loss_D: 0.1482 Acc_D: 0.8945 Loss_G_class1: 0.0101 Loss_G_class2: 0.0115: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 33.07it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.7337059988235326
EarlyStopping counter: 1 out of 10


[8/20] Loss_D: 0.1382 Acc_D: 0.9095 Loss_G_class1: 0.0094 Loss_G_class2: 0.0108: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 35.06it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.7937514182666073
EarlyStopping counter: 2 out of 10


[9/20] Loss_D: 0.1102 Acc_D: 0.9020 Loss_G_class1: 0.0090 Loss_G_class2: 0.0075: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 34.52it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.7839117639738581
EarlyStopping counter: 3 out of 10


[10/20] Loss_D: 0.1228 Acc_D: 0.9045 Loss_G_class1: 0.0083 Loss_G_class2: 0.0095: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5217: 100%|██████████| 23/23 [00:00<00:00, 34.00it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5217391304347826
Valid Loss = 0.7506458888478253
EarlyStopping counter: 4 out of 10


[11/20] Loss_D: 0.1352 Acc_D: 0.8970 Loss_G_class1: 0.0090 Loss_G_class2: 0.0105: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 32.71it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.7463941279312839
EarlyStopping counter: 5 out of 10


[12/20] Loss_D: 0.1092 Acc_D: 0.9221 Loss_G_class1: 0.0076 Loss_G_class2: 0.0089: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 32.00it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.8487373694129612
EarlyStopping counter: 6 out of 10


[13/20] Loss_D: 0.1080 Acc_D: 0.9196 Loss_G_class1: 0.0071 Loss_G_class2: 0.0088: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.5652: 100%|██████████| 23/23 [00:00<00:00, 33.62it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.5652173913043478
Valid Loss = 0.7875488257723982
EarlyStopping counter: 7 out of 10


[14/20] Loss_D: 0.1007 Acc_D: 0.9221 Loss_G_class1: 0.0072 Loss_G_class2: 0.0076: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 34.22it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6521739130434783
Valid Loss = 0.673385707580525
EarlyStopping counter: 8 out of 10


[15/20] Loss_D: 0.1143 Acc_D: 0.9146 Loss_G_class1: 0.0075 Loss_G_class2: 0.0087: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 23/23 [00:00<00:00, 32.31it/s]
  0%|          | 0/25 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.8598640859127045
EarlyStopping counter: 9 out of 10


[16/20] Loss_D: 0.1121 Acc_D: 0.9121 Loss_G_class1: 0.0084 Loss_G_class2: 0.0087: 100%|██████████| 25/25 [00:35<00:00,  1.42s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
  0%|          | 0/23 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 1.0685359456616899
EarlyStopping counter: 10 out of 10
Early stopping, epoch 16


[Validating]: Acc_D: 0.6522: 100%|██████████| 23/23 [00:00<00:00, 31.34it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.6521739130434783
Valid Loss = 0.9070882100368971
Training node 2
[INFO] weights = [1.       0.828125]


[Validating]: Acc_D: 0.4583: 100%|██████████| 24/24 [00:00<00:00, 35.53it/s]


[INFO] Initial Valid Scores: 
Valid Acc = 0.4583333333333333
Valid Loss = 0.8612320696314176
[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.4583: 100%|██████████| 24/24 [00:00<00:00, 29.17it/s]
[1/20] Loss_D: 0.5055 Acc_D: 0.5139 Loss_G_class1: 0.1699 Loss_G_class2: 0.1753:  67%|██████▋   | 18/27 [00:25<00:12,  1.41s/it]

In [None]:
"""
cit1 = cit_builder()
cit2 = cit_builder()
cit3 = cit_builder()

cit2._G_dict['P'].to('cpu')

#for p in cit2._G_dict['P'].parameters():
#    print(p)

state_dict = cit2._G_dict['P'].state_dict()
new_state_dict = copy.deepcopy(state_dict)

for k in state_dict:
    print(state_dict[k])
"""

In [None]:
"""
with torch.no_grad():
    
    for k in state_dict:
        #print(k)
        new_state_dict[k] = 2*state_dict[k]
        #print(new_state_dict[k])
    cit2._G_dict['P'].load_state_dict(new_state_dict)
    

"""
    

In [None]:
"""
state_dict = cit2._G_dict['P'].state_dict()
for k in state_dict:
    print(state_dict[k])
"""

In [None]:
# Manual check of the aggregator on the 'P' generator: print one weight slice
# from three freshly built models, aggregate their parameters, load the result
# into cit1 and print the same slice again.
cit1, cit2, cit3 = cit_builder(), cit_builder(), cit_builder()

# State-dict key of the tensor that is inspected before and after aggregation.
k = "block1.0.weight"
for cit in [cit1, cit2, cit3]:
    state_dict = cit._G_dict['P'].state_dict()
    print(state_dict[k][0][0][0])

aggregator = shfl.federated_aggregator.FedAvgAggregator()
aggregated_weights = aggregator.aggregate_weights(
    [model.get_model_params() for model in (cit1, cit2, cit3)]
)

cit1.set_model_params(aggregated_weights)

# Separator before the post-aggregation values.
print("bbb")
state_dict = cit1._G_dict['P'].state_dict()
print(state_dict[k][0][0][0])


In [None]:

# Manual check of the aggregator on the classifier weights.
# One element of cit1's "conv1.weight" is set to 3 so that cit1 is
# distinguishable from cit2 / cit3 in the printed slices, and so the aggregated
# value at that position can be checked by eye after set_model_params.
cit1 = cit_builder()
cit2 = cit_builder()
cit3 = cit_builder()

# State-dict key of the tensor that is modified and inspected.
k = "conv1.weight"

# Mark cit1 only.
state_dict = cit1._classifier.state_dict()
state_dict[k][0][0][0][0] = 3
cit1._classifier.load_state_dict(state_dict)

# Values before aggregation, one slice per model.
for cit in [cit1, cit2, cit3]:
    state_dict = cit._classifier.state_dict()
    print(state_dict[k][0][0][0])

# The original cell repeated the "write 3 + load into cit1" step here. At this
# point `state_dict` is the one left over from the last loop iteration (cit3),
# so that step loaded cit3's state dict into cit1 and discarded the marker set
# above. It has been removed so the three models stay as printed.
# NOTE(review): assumes `_classifier` is a torch.nn.Module, whose state_dict()
# entries reference the live parameters, so the repeated write would also have
# changed cit3 itself — confirm.

aggregator = shfl.federated_aggregator.FedAvgAggregator()
agg = aggregator.aggregate_weights([cit1.get_model_params(), cit2.get_model_params(), cit3.get_model_params()])
cit1.set_model_params(agg)

# Separator before the post-aggregation values.
print("bbb")
state_dict = cit1._classifier.state_dict()
print(state_dict[k][0][0][0])
