In [1]:
import os
import random
# GPU selection must happen before torch initialises CUDA, so these stay above `import torch`.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="5"
import shfl
import torch
import copy
import cv2
import numpy as np

from sklearn.preprocessing import LabelBinarizer

from shfl.private import UnprotectedAccess
from CIT.model import CITModel
from utils import get_federated_data_csv, get_data_csv
from ClassifierModel import ClassifierModel

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Run configuration: paths, training schedule, augmentation switches and the
# random seed. Everything a reader might tune lives in `args`.
args = {"data_path": "../data/COVIDGR1.0-Segmentadas",
        "csv_path": "../partitions/partition_reducido_fed.csv",  # read by get_federated_data_csv
        "output_path": "../weights",
        "input_path": "",
        "model_name": "transferlearning.model",
        "label_bin": "lb.pickle",
        "batch_size": 1,            # forwarded to CITModel by cit_builder
        "federated_rounds": 1,      # passed to run_rounds
        "epochs_per_FL_round": 1,   # forwarded to CITModel as `epochs`
        # Placeholder only: the data-loading cell overwrites this with the node
        # count returned by get_federated_data_csv.
        "num_nodes": 3,
        "size_averaging": 1,
        "random_rotation": 0,
        "random_shift": 0,
        "random_zoom": 0,
        "horizontal_flip": False,
        "finetune": True,
        "train_network": True,
        # Seed for every RNG seeded below; change it for a different but still
        # repeatable run.
        "seed": 42}

# Seed Python, NumPy and torch (all devices) so model initialisation and any
# RNG-driven sampling/augmentation repeat across Restart & Run All.
# NOTE: cuDNN kernels may still be non-deterministic on GPU.
random.seed(args["seed"])
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Label binarizers. lb1 is fit on the two top-level classes and handed to the
# federated loader; lb2 covers the four-way label set (not used in this cell).
a = ['N', 'P']
b = ['NTN', 'NTP', 'PTP', 'PTN']
lb1 = LabelBinarizer().fit(a)   # fit() returns the fitted binarizer itself
lb2 = LabelBinarizer().fit(b)

print("[INFO] Fetching federated data...")
# The loader also returns the number of nodes found in the CSV, which overwrites
# the placeholder args["num_nodes"] from the config cell.
(federated_data, train_data, train_label, test_data, test_label,
 train_files, test_files, args["num_nodes"]) = get_federated_data_csv(
    args["data_path"], args["csv_path"], lb1)
federated_data.configure_data_access(UnprotectedAccess())

# Sanity output: train-set size, test-set size, node count (one per line).
print(len(train_data), len(test_data), args["num_nodes"], sep="\n")

# Per-node sample counts.
# NOTE(review): reads the private `_data` attribute of the queried node data.
for i in range(federated_data.num_nodes()):
    data = federated_data[i].query()._data
    print(len(data))

print("[INFO] done")

[INFO] Fetching federated data...
16
4
1
16
[INFO] done


In [4]:
def cit_builder(*, classes=("N", "P"), classifier_name="resnet18", folds=1,
                lambda_values=(0.05,)):
    """Build a fresh CITModel; used as the model factory for FederatedGovernment.

    The original took no arguments, so existing callers invoke it bare. Every new
    parameter is therefore keyword-only and defaults to the value this experiment
    previously hard-coded.

    Args:
        classes: Class labels given to CITModel. They should match the labels
            `lb1` was fit on (['N', 'P']).
        classifier_name: Classifier architecture name forwarded to CITModel.
        folds: Forwarded as CITModel's `folds` argument.
        lambda_values: Forwarded as CITModel's `lambda_values` (the run log
            prints these as "LAMBDA").

    Batch size and epochs are read from the global `args` at call time, so edits
    to `args` take effect on the next build. The global `device` is forwarded
    as CITModel's `device`.

    Returns:
        A new CITModel instance.
    """
    # Build new lists on every call so no two models share a mutable argument,
    # matching the fresh list literals the original passed.
    return CITModel(list(classes),
                    classifier_name=classifier_name,
                    folds=folds,
                    lambda_values=list(lambda_values),
                    batch_size=args["batch_size"],
                    epochs=args["epochs_per_FL_round"],
                    device=device)


In [5]:
# Federated training with FedAvg aggregation: node models come from cit_builder,
# and after each round the client and global models are evaluated on the test
# split (see the "Test performance" lines in the output).
aggregator = shfl.federated_aggregator.FedAvgAggregator()
cit_federated_government = shfl.federated_government.FederatedGovernment(
    cit_builder,
    federated_data,
    aggregator,
)
cit_federated_government.run_rounds(
    args["federated_rounds"],
    test_data,
    test_label,
)

Accuracy round 0
Training node 0
[INFO] weights = [1. 1.]


[Validating]: Acc_D: 1.0000: 100%|██████████| 2/2 [00:01<00:00,  1.47it/s]


[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 1.0000: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s]
[1/1] Loss_D: 0.4824 Acc_D: 0.3571 Loss_G_class1: 0.1903 Loss_G_class2: 0.1667: 100%|██████████| 14/14 [00:03<00:00,  4.21it/s]
[Validating]: Acc_D: 0.5000: 100%|██████████| 2/2 [00:00<00:00,  5.85it/s]
  0%|          | 0/2 [00:00<?, ?it/s]


Valid Acc = 0.5
Valid Loss = 1.01156884431839


[Validating]: Acc_D: 1.0000: 100%|██████████| 2/2 [00:00<00:00,  7.87it/s]

[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 1.0
Valid Loss = 0.6437866687774658





Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f2ed47b1320>: [0.6898781061172485, 0.75, {'0': {'precision': 0.5, 'recall': 1.0, 'f1-score': 0.6666666666666666, 'support': 1}, '1': {'precision': 1.0, 'recall': 0.6666666666666666, 'f1-score': 0.8, 'support': 3}, 'accuracy': 0.75, 'macro avg': {'precision': 0.75, 'recall': 0.8333333333333333, 'f1-score': 0.7333333333333334, 'support': 4}, 'weighted avg': {'precision': 0.875, 'recall': 0.75, 'f1-score': 0.7666666666666667, 'support': 4}}]
Global model test performance : [0.6898781061172485, 0.75, {'0': {'precision': 0.5, 'recall': 1.0, 'f1-score': 0.6666666666666666, 'support': 1}, '1': {'precision': 1.0, 'recall': 0.6666666666666666, 'f1-score': 0.8, 'support': 3}, 'accuracy': 0.75, 'macro avg': {'precision': 0.75, 'recall': 0.8333333333333333, 'f1-score': 0.7333333333333334, 'support': 4}, 'weighted avg': {'precision': 0.875, 'recall': 0.75, 'f1-score': 0.7666666666666667, 'support': 4}}]





In [6]:
# Disabled scratch cell: the code is wrapped in a bare string literal, so running
# this cell executes none of it. Jupyter only echoes the string as the cell output.
# If re-enabled, it builds three CITModel instances, moves cit2's `_G_dict['P']`
# sub-model to the CPU, takes its state_dict plus a deep copy (`new_state_dict`)
# and prints every tensor.
# NOTE(review): later scratch cells read `cit2`, `state_dict` and `new_state_dict`
# defined here, which is hidden cross-cell state.
# TODO: delete, or turn into a real check with assertions.
"""
cit1 = cit_builder()
cit2 = cit_builder()
cit3 = cit_builder()

cit2._G_dict['P'].to('cpu')

#for p in cit2._G_dict['P'].parameters():
#    print(p)

state_dict = cit2._G_dict['P'].state_dict()
new_state_dict = copy.deepcopy(state_dict)

for k in state_dict:
    print(state_dict[k])
"""

"\ncit1 = cit_builder()\ncit2 = cit_builder()\ncit3 = cit_builder()\n\ncit2._G_dict['P'].to('cpu')\n\n#for p in cit2._G_dict['P'].parameters():\n#    print(p)\n\nstate_dict = cit2._G_dict['P'].state_dict()\nnew_state_dict = copy.deepcopy(state_dict)\n\nfor k in state_dict:\n    print(state_dict[k])\n"

In [7]:
# Disabled scratch cell (string literal: nothing executes, and the string is echoed).
# If re-enabled, it runs under torch.no_grad(), fills `new_state_dict` with every
# tensor of `state_dict` multiplied by 2, and loads the result into cit2's
# `_G_dict['P']`. This is presumably a manual check that load_state_dict
# overwrites the weights.
# NOTE(review): depends on `cit2`, `state_dict` and `new_state_dict` from the
# previous scratch cell, so it fails on a fresh kernel if run on its own.
"""
with torch.no_grad():
    
    for k in state_dict:
        #print(k)
        new_state_dict[k] = 2*state_dict[k]
        #print(new_state_dict[k])
    cit2._G_dict['P'].load_state_dict(new_state_dict)
    

"""
    

"\nwith torch.no_grad():\n    \n    for k in state_dict:\n        #print(k)\n        new_state_dict[k] = 2*state_dict[k]\n        #print(new_state_dict[k])\n    cit2._G_dict['P'].load_state_dict(new_state_dict)\n    \n\n"

In [8]:
# Disabled scratch cell (string literal: nothing executes, and the string is echoed).
# If re-enabled, it re-reads cit2's `_G_dict['P']` state_dict and prints each
# tensor, presumably to confirm the doubling from the previous scratch cell.
# NOTE(review): depends on `cit2` from an earlier scratch cell.
"""
state_dict = cit2._G_dict['P'].state_dict()
for k in state_dict:
    print(state_dict[k])
"""

"\nstate_dict = cit2._G_dict['P'].state_dict()\nfor k in state_dict:\n    print(state_dict[k])\n"

In [9]:
# Disabled scratch cell (string literal: nothing executes, and the string is echoed).
# If re-enabled, it is a manual FedAvg sanity check on the `_G_dict['P']`
# sub-models. It builds three fresh CITModels and prints a slice of
# "block1.0.weight" from each. It then aggregates their parameters with
# FedAvgAggregator.aggregate_weights, writes the result into cit1, and prints
# the same slice again after the "bbb" marker.
# NOTE(review): it only prints values. An assertion comparing cit1's new slice
# with the (presumably element-wise mean) of the three originals would make it
# a real test.
# NOTE(review): it also rebinds the notebook-level `aggregator` used by the
# training cell.
"""
cit1 = cit_builder()
cit2 = cit_builder()
cit3 = cit_builder()

for cit in [cit1, cit2, cit3]:
    state_dict = cit._G_dict['P'].state_dict()
    k = "block1.0.weight"
    print(state_dict[k][0][0][0])

aggregator = shfl.federated_aggregator.FedAvgAggregator()
agg = aggregator.aggregate_weights([cit1.get_model_params(), cit2.get_model_params(), cit3.get_model_params()])
cit1.set_model_params(agg)

print("bbb")
state_dict = cit1._G_dict['P'].state_dict()
k = "block1.0.weight"
print(state_dict[k][0][0][0])
"""


'\ncit1 = cit_builder()\ncit2 = cit_builder()\ncit3 = cit_builder()\n\nfor cit in [cit1, cit2, cit3]:\n    state_dict = cit._G_dict[\'P\'].state_dict()\n    k = "block1.0.weight"\n    print(state_dict[k][0][0][0])\n\naggregator = shfl.federated_aggregator.FedAvgAggregator()\nagg = aggregator.aggregate_weights([cit1.get_model_params(), cit2.get_model_params(), cit3.get_model_params()])\ncit1.set_model_params(agg)\n\nprint("bbb")\nstate_dict = cit1._G_dict[\'P\'].state_dict()\nk = "block1.0.weight"\nprint(state_dict[k][0][0][0])\n'

In [10]:
# Disabled scratch cell (string literal: nothing executes, and the string is echoed).
# If re-enabled, it runs the same FedAvg sanity check as the previous scratch cell,
# but on the `_classifier` sub-models and their "conv1.weight" tensor. It first
# writes 3 into one element of a weight, presumably so the average visibly changes.
# NOTE(review): latent bugs if this is ever re-enabled:
#   * the first `state_dict[k][0][0][0][0] = 3` uses `k` before this cell assigns
#     it (NameError on a fresh kernel, or a stale key left over from another cell);
#   * after the print loop, `state_dict` is cit3's, so the second edit and the
#     following `cit1._classifier.load_state_dict(state_dict)` copy cit3's weights
#     into cit1 instead of modifying cit1's own.
# TODO: assign `k` first and re-fetch cit1's state_dict before editing it.
"""
cit1 = cit_builder()
cit2 = cit_builder()
cit3 = cit_builder()

state_dict = cit1._classifier.state_dict()
state_dict[k][0][0][0][0] = 3
cit1._classifier.load_state_dict(state_dict)

k = "conv1.weight"
for cit in [cit1, cit2, cit3]:
    state_dict = cit._classifier.state_dict()
    print(state_dict[k][0][0][0])

state_dict[k][0][0][0][0] = 3
cit1._classifier.load_state_dict(state_dict)
    
aggregator = shfl.federated_aggregator.FedAvgAggregator()
agg = aggregator.aggregate_weights([cit1.get_model_params(), cit2.get_model_params(), cit3.get_model_params()])
cit1.set_model_params(agg)

print("bbb")
state_dict = cit1._classifier.state_dict()
k = "conv1.weight"
print(state_dict[k][0][0][0])
"""

'\ncit1 = cit_builder()\ncit2 = cit_builder()\ncit3 = cit_builder()\n\nstate_dict = cit1._classifier.state_dict()\nstate_dict[k][0][0][0][0] = 3\ncit1._classifier.load_state_dict(state_dict)\n\nk = "conv1.weight"\nfor cit in [cit1, cit2, cit3]:\n    state_dict = cit._classifier.state_dict()\n    print(state_dict[k][0][0][0])\n\nstate_dict[k][0][0][0][0] = 3\ncit1._classifier.load_state_dict(state_dict)\n    \naggregator = shfl.federated_aggregator.FedAvgAggregator()\nagg = aggregator.aggregate_weights([cit1.get_model_params(), cit2.get_model_params(), cit3.get_model_params()])\ncit1.set_model_params(agg)\n\nprint("bbb")\nstate_dict = cit1._classifier.state_dict()\nk = "conv1.weight"\nprint(state_dict[k][0][0][0])\n'