In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="7"
import shfl
import torch
import copy
import cv2
import numpy as np

from sklearn.preprocessing import LabelBinarizer

from shfl.private import UnprotectedAccess
from CIT.model import CITModel
from utils import get_federated_data_csv, get_data_csv
from ClassifierModel import ClassifierModel

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
#"../data/COVIDGR1.0/centralized/cropped"
#partition_iid_1nodes_1.csv
args = {"data_path":"../data/COVIDGR1.0-Segmentadas", 
        "csv_path": "../partitions/partition_iid_1nodes_1.csv",
        "output_path": "../weights",
        "input_path": "",
        "model_name":"transferlearning.model", 
        "label_bin": "lb.pickle", 
        "batch_size": 8,
        "federated_rounds": 1,
        "epochs_per_FL_round": 20,
        "num_nodes": 3,
        "size_averaging": 1,
        "random_rotation": 0,
        "random_shift": 0, 
        "random_zoom": 0,
        "horizontal_flip": False,        
        "finetune": True,
        "train_network": True}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
a = ['N', 'P']
b = ['NTN', 'NTP', 'PTP', 'PTN']
lb1 = LabelBinarizer()
lb2 = LabelBinarizer()
lb1.fit(a)
lb2.fit(b)

dict_labels = { 'PTP' : np.argmax(lb2.transform(['PTP'])[0]) , 'PTN' : np.argmax(lb2.transform(['PTN'])[0]) , 
                'NTP' : np.argmax(lb2.transform(['NTP'])[0]) , 'NTN' : np.argmax(lb2.transform(['NTN'])[0]), 
                'P' : lb1.transform(['P'])[0][0], 'N' : lb1.transform(['N'])[0][0]
              } 

data, label, train_data, train_label, test_data, test_label, train_files, test_files = get_data_csv(args["data_path"], args["csv_path"], lb1)

print(len(train_data))
print(len(test_data))

681
171


In [4]:
def cit_builder():    
    return CITModel(['N', 'P'], classifier_name = "resnet18", folds=1, lambda_values = [0.05], batch_size=args["batch_size"], epochs=args["epochs_per_FL_round"], device=device)

def classifier_builder( G_dict ):
    return ClassifierModel(G_dict, dict_labels, batch_size=args["batch_size"], epochs=args["epochs_per_FL_round"], finetune = args["finetune"])

def get_transformed_data(federated_data, cit_federated_government, lb1, lb2):
    t_federated_data = copy.deepcopy(federated_data)

    for i in range(federated_data.num_nodes()):
        data_node = federated_data[i]
        t_data_node = t_federated_data[i]
        data = data_node.query()._data
        labels = data_node.query()._label
        t_data, t_labels = cit_federated_government.global_model.transform_data(data, labels, lb1, lb2)
        t_data_node.query()._data = t_data
        t_data_node.query()._label = t_labels

    t_test_data, t_test_label = cit_federated_government.global_model.transform_data(test_data, test_label, lb1, lb2)

    return t_federated_data, t_test_data, t_test_label

In [5]:
cit_model = cit_builder()
cit_model.train(train_data, train_label)

[INFO] weights = [1.         0.97391304]
[INFO] LAMBDA: 0.05


[Validating]: Acc_D: 0.4783: 100%|██████████| 69/69 [00:02<00:00, 33.90it/s]
[1/20] Loss_D: 0.3630 Acc_D: 0.6291 Loss_G_class1: 0.0606 Loss_G_class2: 0.0498: 100%|██████████| 77/77 [01:50<00:00,  1.44s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 69/69 [00:01<00:00, 40.84it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.5250581140103547


[2/20] Loss_D: 0.3020 Acc_D: 0.6993 Loss_G_class1: 0.0184 Loss_G_class2: 0.0184: 100%|██████████| 77/77 [01:50<00:00,  1.44s/it]
[Validating]: Acc_D: 0.6957: 100%|██████████| 69/69 [00:01<00:00, 38.45it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.6956521739130435
Valid Loss = 0.6158042478280655
EarlyStopping counter: 1 out of 10


[3/20] Loss_D: 0.2839 Acc_D: 0.7075 Loss_G_class1: 0.0163 Loss_G_class2: 0.0171: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 69/69 [00:01<00:00, 37.92it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.6361405757134375
EarlyStopping counter: 2 out of 10


[4/20] Loss_D: 0.2769 Acc_D: 0.7206 Loss_G_class1: 0.0157 Loss_G_class2: 0.0166: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7246: 100%|██████████| 69/69 [00:01<00:00, 38.25it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7246376811594203
Valid Loss = 0.5989833228101117
EarlyStopping counter: 3 out of 10


[5/20] Loss_D: 0.2759 Acc_D: 0.7263 Loss_G_class1: 0.0161 Loss_G_class2: 0.0160: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.6087: 100%|██████████| 69/69 [00:01<00:00, 36.17it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.6086956521739131
Valid Loss = 0.616572426578057
EarlyStopping counter: 4 out of 10


[6/20] Loss_D: 0.2465 Acc_D: 0.7851 Loss_G_class1: 0.0156 Loss_G_class2: 0.0128: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7246: 100%|██████████| 69/69 [00:01<00:00, 37.91it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7246376811594203
Valid Loss = 0.6402413869897524
EarlyStopping counter: 5 out of 10


[7/20] Loss_D: 0.2228 Acc_D: 0.8031 Loss_G_class1: 0.0122 Loss_G_class2: 0.0124: 100%|██████████| 77/77 [01:50<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7101: 100%|██████████| 69/69 [00:01<00:00, 38.71it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7101449275362319
Valid Loss = 0.6446031386545603
EarlyStopping counter: 6 out of 10


[8/20] Loss_D: 0.2068 Acc_D: 0.8064 Loss_G_class1: 0.0113 Loss_G_class2: 0.0116: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7826: 100%|██████████| 69/69 [00:01<00:00, 38.06it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.782608695652174
Valid Loss = 0.5841134913481664
EarlyStopping counter: 7 out of 10


[9/20] Loss_D: 0.2180 Acc_D: 0.7900 Loss_G_class1: 0.0116 Loss_G_class2: 0.0124: 100%|██████████| 77/77 [01:50<00:00,  1.44s/it]
[Validating]: Acc_D: 0.7101: 100%|██████████| 69/69 [00:01<00:00, 36.72it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7101449275362319
Valid Loss = 0.6241113209854001
EarlyStopping counter: 8 out of 10


[10/20] Loss_D: 0.1983 Acc_D: 0.8235 Loss_G_class1: 0.0108 Loss_G_class2: 0.0113: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7246: 100%|██████████| 69/69 [00:01<00:00, 37.40it/s]
  0%|          | 0/77 [00:00<?, ?it/s]


Valid Acc = 0.7246376811594203
Valid Loss = 0.5339531193004813
EarlyStopping counter: 9 out of 10


[11/20] Loss_D: 0.1919 Acc_D: 0.8284 Loss_G_class1: 0.0096 Loss_G_class2: 0.0121: 100%|██████████| 77/77 [01:50<00:00,  1.43s/it]
[Validating]: Acc_D: 0.7391: 100%|██████████| 69/69 [00:01<00:00, 37.45it/s]
  0%|          | 0/69 [00:00<?, ?it/s]


Valid Acc = 0.7391304347826086
Valid Loss = 0.6657811182591579
EarlyStopping counter: 10 out of 10
Early stopping, epoch 11


[Validating]: Acc_D: 0.7101: 100%|██████████| 69/69 [00:01<00:00, 35.71it/s]


[INFO] Summary of training for LAMBDA = 0.05 (best model values)
Valid Acc = 0.7101449275362319
Valid Loss = 0.5732312286272645


from IPython.display import display
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from torch.autograd import Variable
import cv2


def sample_loader(sample):
    s = ToTensor()(x).float()
    s = Variable(s, requires_grad=False)
    s = s.unsqueeze(0)  
    return s

self = cit_model

for class_name in self._class_names:
    self._G_dict[class_name]= self._G_dict[class_name].to(self._device)


sample = test_data[40]
label = lb1.inverse_transform(test_label[0])[0]
x = ToTensor()(sample).float().unsqueeze(0).to(device)
class_name = 'P'
y = self._G_dict[class_name](x)
y = y[0].cpu().detach().numpy()
print(y.shape)
y = np.moveaxis(y, 0, -1)
#y = cv2.resize(y, dsize=(224, 224))
plt.imshow(y)
plt.imshow(y)


In [6]:
metrics = cit_model.evaluate(test_data, test_label)
print("CIT Classifier Results:")
print("Loss: {}".format(metrics[0]))
print("Acc: {}".format(metrics[1]))
print(metrics[2])

CIT Classifier Results:
Loss: 0.5907702616454042
Acc: 0.7426900584795322
              precision    recall  f1-score   support

           0    0.91111   0.69492   0.78846       118
           1    0.55556   0.84906   0.67164        53

    accuracy                        0.74269       171
   macro avg    0.73333   0.77199   0.73005       171
weighted avg    0.80091   0.74269   0.75225       171



In [7]:
t_train_data, t_train_label = cit_model.transform_data(train_data, train_label, lb1, lb2)
#t_test_data, t_test_label = cit_model.transform_data(test_data, test_label, lb1, lb2)

In [8]:
"""
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import ToPILImage

save_path = "../data/prueba-transformada/"

#image = t_train_data[0]

#r = np.max(image) - np.min(image)

#image = (image - np.min(image))/r

#label = lb2.inverse_transform(t_train_label[0])[0]
#path = save_path + str(label) + ".png"
#print(path)
#plt.imshow(image)

#plt.imshow(t_train_data[0])

new_t_train_data = copy.deepcopy(t_train_data)

for i in range(len(new_t_train_data)):
    image = new_t_train_data[0]
    new_t_train_data[i] = np.asarray(ToPILImage()(image))

print(new_t_train_data[0])
"""

'\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom torchvision.transforms import ToPILImage\n\nsave_path = "../data/prueba-transformada/"\n\n#image = t_train_data[0]\n\n#r = np.max(image) - np.min(image)\n\n#image = (image - np.min(image))/r\n\n#label = lb2.inverse_transform(t_train_label[0])[0]\n#path = save_path + str(label) + ".png"\n#print(path)\n#plt.imshow(image)\n\n#plt.imshow(t_train_data[0])\n\nnew_t_train_data = copy.deepcopy(t_train_data)\n\nfor i in range(len(new_t_train_data)):\n    image = new_t_train_data[0]\n    new_t_train_data[i] = np.asarray(ToPILImage()(image))\n\nprint(new_t_train_data[0])\n'

In [9]:
dict_labels = { 'PTP' : np.argmax(lb2.transform(['PTP'])[0]) , 'PTN' : np.argmax(lb2.transform(['PTN'])[0]) , 
                'NTP' : np.argmax(lb2.transform(['NTP'])[0]) , 'NTN' : np.argmax(lb2.transform(['NTN'])[0]), 
                'P' : lb1.transform(['P'])[0][0], 'N' : lb1.transform(['N'])[0][0]
              } 

print(dict_labels)

from ClassifierModel import ClassifierModel

classifier_model = classifier_builder(cit_model._G_dict)
classifier_model.train(t_train_data, t_train_label)

{'PTP': 3, 'PTN': 2, 'NTP': 1, 'NTN': 0, 'P': 1, 'N': 0}
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 00011: early stopping


In [11]:
metrics = classifier_model.evaluate(test_data, test_label)
print("SDNET Classifier Results:")
print("Acc: {}".format(metrics[0]))
print("Acc_4: {}".format(metrics[1]))
print("No concuerda: {}".format(metrics[2]))
print(metrics[3])


SDNET Classifier Results:
Acc: 0.5263157894736842
Acc_4: 0.672514619883041
No concuerda: 58
              precision    recall  f1-score   support

           0    0.77647   0.73333   0.75429        90
           1    0.72093   0.76543   0.74251        81

    accuracy                        0.74854       171
   macro avg    0.74870   0.74938   0.74840       171
weighted avg    0.75016   0.74854   0.74871       171



In [13]:
G_dict = cit_model._G_dict
dict_labels = { 'PTP' : np.argmax(lb2.transform(['PTP'])[0]) , 'PTN' : np.argmax(lb2.transform(['PTN'])[0]) , 
                    'NTP' : np.argmax(lb2.transform(['NTP'])[0]) , 'NTN' : np.argmax(lb2.transform(['NTN'])[0]) 
                } 
for key, _ in G_dict.items():
    G_dict[key].to("cpu")
classifier_model.get_classification_report(test_files, dict_labels, G_dict)

preds
['P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P'
 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P' 'P']
preds_4
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3

  _warn_prf(average, modifier, msg_start, len(result))
