In [1]:
import os
import subprocess
import matplotlib
matplotlib.use("Agg")
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import cv2

import shfl

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Run configuration for this notebook. The visible cells read `data_path`,
# `num_nodes` and `federated_rounds`; several other entries may be unused
# here — verify before relying on them.
args = dict(
    # Input / output locations
    data_path="../data/COVIDGR1.0reducido/centralized/cropped-split",
    output_path="../weights",
    input_path="",
    model_name="transferlearning.model",
    label_bin="lb.pickle",
    # Training / federation settings
    batch_size=8,
    federated_rounds=1,
    epochs_per_FL_round=1,
    num_nodes=2,
    size_averaging=1,
    # Presumably data-augmentation settings (0 / False = disabled)
    random_rotation=0,
    random_shift=0,
    random_zoom=0,
    horizontal_flip=False,
    # Flags
    finetune=True,
    train_network=True,
)

In [3]:
LABELS = ["N", "P"]
print(f"[INFO] training for labels: {LABELS}")

# Load the image folder with height/width set to 256.
database = shfl.data_base.DatabaseFromDirectory(
    args["data_path"], height=256, width=256
)
train_data, train_labels, test_data, test_labels = database.load_data()
print(f"[INFO] Number of train images: {len(train_data)}")
print(f"[INFO] Number of test images: {len(test_data)}")

# Distribute the train set across `num_nodes` nodes. Note that this rebinds
# `test_data`, and the held-out labels used later come back as `test_label`
# (singular) — distinct from `test_labels` returned by load_data() above.
print("[INFO] Distributing the train set across the nodes...")
iid_distribution = shfl.data_distribution.IidDataDistribution(database)
federated_data, test_data, test_label = iid_distribution.get_federated_data(
    num_nodes=args["num_nodes"]
)
print("[INFO] done")

[INFO] training for labels: ['N', 'P']
[INFO] Number of train images: 16
[INFO] Number of test images: 4
[INFO] Distributing the train set across the nodes...
[INFO] done


In [4]:
from CITModel import CITModel
import torch

# Preferred CUDA device index. If the machine has fewer GPUs than this, fall
# back to cuda:0 instead of failing; use the CPU when CUDA is unavailable.
GPU_INDEX = 7
if torch.cuda.is_available():
    _gpu_index = GPU_INDEX if torch.cuda.device_count() > GPU_INDEX else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")


def model_builder():
    """Factory passed to FederatedGovernment to create CITModel instances.

    Batch size and number of local epochs are read from ``args`` so the
    configuration cell controls training instead of duplicated literals.

    Returns:
        A new ``CITModel`` for ``LABELS`` using a ResNet-18 classifier on
        ``device``.
    """
    return CITModel(
        LABELS,
        classifier_name="resnet18",
        # NOTE(review): magic constant; consider moving it into `args`.
        lambda_value=0.00075,
        batch_size=args["batch_size"],
        epochs=args["epochs_per_FL_round"],
        device=device,
    )

In [5]:
# Combine the node models with shfl's FedAvgAggregator and run the configured
# number of federated rounds, evaluating against the held-out test split.
aggregator = shfl.federated_aggregator.FedAvgAggregator()
federated_government = shfl.federated_government.FederatedGovernment(
    model_builder,
    federated_data,
    aggregator,
)
federated_government.run_rounds(
    args["federated_rounds"],
    test_data,
    test_label,
)

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
[INFO] FREEZING
Accuracy round 0
[INFO] weights = [1. 1.]


[0/1] Loss_D: 0.4488 Acc_D: 0.5000 Loss_G_class1: 10020.9307 Loss_G_class2: 6248.4468: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] weights = [1. 1.]


[0/1] Loss_D: 0.4074 Acc_D: 0.5000 Loss_G_class1: 9437.0664 Loss_G_class2: 5697.4790: 100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f9f35898390>: [0.6873142719268799, 0.5]
Test performance client <shfl.private.federated_operation.FederatedDataNode object at 0x7f9f358982e8>: [0.6861069202423096, 0.5]
Global model test performance : [0.6813296675682068, 0.5]





In [11]:
from shfl.private import UnprotectedAccess
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torch.autograd import Variable
import matplotlib.pyplot as plt
from IPython.display import display
from PIL import Image
import cv2
from sklearn.preprocessing import LabelBinarizer

import copy


def sample_loader(sample):
    """Convert a single image into a batched float tensor on ``device``.

    Args:
        sample: an image accepted by ``torchvision.transforms.ToTensor``
            (presumably an HxWxC array from the node data — confirm).

    Returns:
        A float32 tensor with a leading batch dimension of size 1, moved to
        the notebook-level ``device`` (which may be a GPU or the CPU).
    """
    # `torch.autograd.Variable` has been a no-op wrapper since PyTorch 0.4;
    # a plain tensor already has requires_grad=False.
    tensor = transforms.ToTensor()(sample).float()
    # Add a leading batch dimension of size 1.
    return tensor.unsqueeze(0).to(device)

def show_img(G_dict, sample, class_name=None):
    """Translate ``sample`` with one class generator and display it inline.

    Args:
        G_dict: mapping from class name to a generator module.
        sample: an image accepted by ``sample_loader``.
        class_name: key of ``G_dict`` selecting the generator; defaults to
            ``LABELS[0]``.

    Returns:
        None. The translated image is rendered via ``IPython.display.display``.
    """
    if class_name is None:
        class_name = LABELS[0]
    x = sample_loader(sample)
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        y = G_dict[class_name](x)
    # NOTE(review): ToPILImage expects float values in [0, 1]; confirm the
    # generator's output range (e.g. not tanh in [-1, 1]).
    display(ToPILImage()(y[0].cpu().detach()))
    
def transform_data(G_dict, class_names, data, labels, label_classes=None):
    """Translate every sample into each target class and one-hot encode labels.

    For each ``data[i]`` with original label ``labels[i]``, the generator
    ``G_dict[class_name]`` is applied once per entry of ``class_names``; each
    generated image becomes a new sample labelled "<label>T<class_name>".

    Args:
        G_dict: mapping from class name to a generator module.
        class_names: target class names; each must be a key of ``G_dict``.
        data: indexable collection of images accepted by ``sample_loader``.
        labels: original labels aligned with ``data`` (scalars or size-1
            arrays).
        label_classes: optional iterable of every possible original label
            value. When given, the binarizer is fitted on the full product
            ``label_classes x class_names``, so every node gets the same column
            layout even if some combinations are missing locally. When None
            (default), only the combinations observed in ``labels`` are used.

    Returns:
        ``(new_data, new_labels)`` as numpy arrays with
        ``len(data) * len(class_names)`` rows. ``new_labels`` is the
        LabelBinarizer encoding (a single column when there are exactly two
        classes).

    Raises:
        ValueError: if ``label_classes`` is given but an observed label/class
            combination is not covered by it.
    """
    def _label_key(label):
        # Size-1 arrays (e.g. np.array([0])) and scalars map to the same key,
        # so `label_classes` can be given as plain values.
        arr = np.asarray(label)
        return str(arr.item()) if arr.size == 1 else str(label)

    new_data = []
    new_labels = []
    # Generation is inference only; skip autograd bookkeeping.
    # NOTE(review): the generators are used in whatever train/eval mode they
    # are currently in; confirm whether `.eval()` is intended here.
    with torch.no_grad():
        for i, sample in enumerate(data):
            key = _label_key(labels[i])
            x = sample_loader(sample)
            for class_name in class_names:
                y = G_dict[class_name](x)
                new_data.append(y[0].cpu().detach().numpy())
                new_labels.append(key + "T" + class_name)

    lb = LabelBinarizer()
    if label_classes is None:
        lb.fit(new_labels)
    else:
        lb.fit([_label_key(c) + "T" + name
                for c in label_classes for name in class_names])
        unknown = set(new_labels) - set(lb.classes_)
        if unknown:
            raise ValueError(
                f"labels not covered by label_classes: {sorted(unknown)}")
    encoded = lb.transform(new_labels)

    return np.asarray(new_data), np.asarray(encoded)


# Take the per-class generators from the trained global model and move each one
# onto `device`. This rebinds the entries of the model's own `_G_dict` in place.
G_dict = federated_government.global_model._G_dict
for class_name in LABELS:
    G_dict[class_name]= G_dict[class_name].to(device)
# Switch the nodes to UnprotectedAccess so their raw data can be read via query().
federated_data.configure_data_access(UnprotectedAccess())

# Work on a deep copy so the translated samples go into a separate set of nodes
# and `federated_data` keeps the original images and labels.
new_federated_data = copy.deepcopy(federated_data)

# For every node, replace the copy's data with the generator-translated samples
# (one per original sample per class in LABELS) and their encoded labels.
# NOTE(review): this reads and writes the private `_data` / `_label` attributes
# and relies on query() returning the node's underlying data object, so the
# assignments persist (the printed output shows the new labels) — confirm this
# is a supported shfl usage.
# NOTE(review): transform_data fits a new LabelBinarizer per call, so the label
# columns can differ between nodes when a node lacks some label/class
# combination — verify before training on the translated data.
for i in range(federated_data.num_nodes()):
    data_node = federated_data[i]
    new_data_node = new_federated_data[i]
    data = data_node.query()._data
    labels = data_node.query()._label
    new_data, new_labels = transform_data(G_dict, LABELS, data, labels)
    new_data_node.query()._data = new_data
    new_data_node.query()._label = new_labels
    
    # Debug output: the original labels, then the copy's new encoded labels.
    print(data_node.query()._label)
    print(new_data_node.query()._label)
    # Disabled alternative: overwrite the original nodes instead of the copy.
    #data_node.query()._data = new_data
    #data_node.query()._label = new_labels




[[0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]]
[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]
 [1 0 0 0]
 [0 1 0 0]
 [1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]
 [1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]
 [0 0 1 0]
 [0 0 0 1]]
[[1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]]
[[0 0 1 0]
 [0 0 0 1]
 [0 0 1 0]
 [0 0 0 1]
 [1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]
 [1 0 0 0]
 [0 1 0 0]
 [1 0 0 0]
 [0 1 0 0]
 [1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]]
