In [None]:
from model2 import densenet121, densenet121Container
from sklearn import metrics
import torch
from datasets import CheXpert_dataset
from torch.utils import data
from utils import *
import matplotlib.pyplot as plt
import numpy as np
import logging

: 

In [None]:
# Root logging config: time (with milliseconds), logger name, level and message; DEBUG and up.
_LOG_FORMAT = '%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s'
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt='%H:%M:%S',
    level=logging.DEBUG,
)
logger = logging.getLogger("analysis")

In [None]:
args_datadir = "./data/CheXpert-v1.0-small/"
print("Begin loading the weights")
approach = "_not_layerwise_3_averaged"
print("Approach: {}".format(approach))
# Every weight folder is named "<method>_<setting>" followed by the `approach` suffix.
fedavg_weights_no_comm = load_weights_from_folder("FedAvg_no_comm" + approach)
fedavg_weights_comm = load_weights_from_folder("FedAvg_comm" + approach)

fedma_weights_no_comm_help = load_weights_from_folder("FedMA_no_comm" + approach)
assignments = load_weights_from_folder("FedMA_no_comm_assignments" + approach)

# One weight set per local worker (16 workers), derived from the FedMA no-comm weights
# and the matching assignments. Presumably this maps global (matched) weights back to
# each local model's layout — TODO confirm against match_global_to_local_weights_2.
fedma_weights_no_comm = [
    match_global_to_local_weights_2(fedma_weights_no_comm_help, assignments, worker, not_layerwise=True)
    for worker in range(16)
]

fedma_weights_comm = load_weights_from_folder("FedMA_comm" + approach)
# fedma_weights_comm_unmatched = load_weights_from_folder("FedMA_comm" + str(approach) + "_unmatched")

In [None]:
# Build the FedAvg global model and the validation data pipeline.
print("Create a typical densenet121 architecture for FedAvg")
global_model_FedAvg = densenet121()
# Prefer the first GPU when one is available; the evaluation cells send inputs here.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print("Create a dataloader for the validation dataset")
validation_set = CheXpert_dataset(args_datadir, valid=True, transform=True)
# shuffle=False keeps a fixed sample order, so repeated evaluations line up.
validation_dl = data.DataLoader(
    dataset=validation_set,
    batch_size=16,
    shuffle=False,
    num_workers=16,
)

In [None]:
# Copy the FedAvg (no communication) global weights into `global_model_FedAvg`.
# NOTE(review): `param_idx` enumerates *every* state_dict entry, including the skipped
# `num_batches_tracked` buffers. This assumes the saved weight list uses the same
# indexing — confirm against the code that wrote the "FedAvg_no_comm..." folder.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedAvg.state_dict().items()):
    if "num_batches_tracked" in key_name:
        # Not part of the saved weights. Keep the model's own counter so the strict
        # `load_state_dict` below does not fail with a missing key.
        new_state_dict[key_name] = param
        continue
    saved = fedavg_weights_no_comm[param_idx]
    if "conv" in key_name:
        if "weight" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
        elif "bias" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved)
        else:
            raise ValueError("Unhandled conv parameter: {}".format(key_name))
    elif "norm" in key_name:
        new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # The saved linear weight is stored transposed relative to the model's layout.
            new_state_dict[key_name] = torch.from_numpy(saved.T)
        elif "bias" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved)
        else:
            raise ValueError("Unhandled linear parameter: {}".format(key_name))
    else:
        # Previously an unmatched key silently re-inserted the previous tensor.
        raise ValueError("Unhandled state_dict key: {}".format(key_name))
global_model_FedAvg.load_state_dict(new_state_dict)

In [None]:
# Run the FedAvg (no communication) global model over the whole validation set.
# Produces `target` (labels) and `out` (raw model outputs), each reshaped to
# (n_samples, 14): one column per observation class.
global_model_FedAvg.to(device)  # inputs are moved to `device`, so the model must live there too
global_model_FedAvg.eval()      # inference mode: BatchNorm running statistics, no dropout
target_parts = []
out_parts = []
with torch.no_grad():  # evaluation only; no autograd graph needed
    for x, target_b in validation_dl:
        target_b = handle_uncertainty_labels(target_b)
        target_b = handle_NaN_values(target_b)
        out_b = global_model_FedAvg(x.to(device))
        # Collect per-batch chunks and concatenate once (np.append in a loop is quadratic).
        target_parts.append(np.asarray(target_b.tolist(), dtype=float).ravel())
        out_parts.append(np.asarray(out_b.tolist(), dtype=float).ravel())

target = np.reshape(np.concatenate(target_parts), (-1, 14))
out = np.reshape(np.concatenate(out_parts), (-1, 14))

In [None]:
# Summarise the FedAvg (no communication) global model on the full validation set.
# On the (n_samples, 14) multilabel arrays, `roc_auc_score` returns the macro average of
# the per-class AUROCs. NOTE(review): it raises ValueError if any class column holds only
# one label value — confirm every class has positives and negatives in this split.
auroc = metrics.roc_auc_score(target, out)
# `roc_curve` accepts only binary labels, so the single "total" curve is micro-averaged:
# every (sample, class) pair is flattened into one binary prediction.
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(target.ravel(), out.ravel())
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
print("The global model with weights from FedAvg without communication has a total AUROC of: {}".format(auroc))

# Plot the micro-averaged ROC curve against the chance diagonal.
fig, ax = plt.subplots()
lw = 2
ax.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="Micro-avg ROC curve (area = %0.2f)" % micro_auroc,
)
ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Receiver operating characteristic total (FedAvg-no_comm)")
ax.legend(loc="lower right")
plt.show()

In [None]:
# The 14 CheXpert observation classes. Assumed to match the label column order of
# `target` / `out` — TODO confirm against CheXpert_dataset.
observation_classes = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices"]

# Per-class ROC curve and AUROC for the FedAvg (no communication) global model.
fpr = []
tpr = []
roc_auc = []
for class_idx in range(len(observation_classes)):
    class_fpr, class_tpr, _ = metrics.roc_curve(target[:, class_idx], out[:, class_idx])
    fpr.append(class_fpr)
    tpr.append(class_tpr)
    roc_auc.append(metrics.auc(class_fpr, class_tpr))

print("Plot the ROC curve for each class")
lw = 2
for class_name, class_fpr, class_tpr, class_auc in zip(observation_classes, fpr, tpr, roc_auc):
    plt.figure()
    plt.plot(class_fpr, class_tpr, color="darkorange", lw=lw,
             label="ROC curve (area = %0.2f)" % class_auc)
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic " + str(class_name) + " (FedAvg-no_comm)")
    plt.legend(loc="lower right")
    plt.show()

In [None]:
print("Create a densenet121 architecture for FedMA without communication for each layer")
# One independent densenet121 instance per local worker model (16 in total).
global_model_FedMA_no_comm = []
for _ in range(16):
    global_model_FedMA_no_comm.append(densenet121())

In [None]:
# Load the FedMA (no communication) weights of each of the 16 local workers into the
# matching model in `global_model_FedMA_no_comm`.
# NOTE(review): `param_idx` also counts the skipped `num_batches_tracked` buffers, which
# assumes the saved per-worker weight lists use the same indexing — confirm.
for worker_index in range(16):
    local_model = global_model_FedMA_no_comm[worker_index]  # was `.state_dict()` on the list itself
    worker_weights = fedma_weights_no_comm[worker_index]
    new_state_dict = {}
    for param_idx, (key_name, param) in enumerate(local_model.state_dict().items()):
        if "num_batches_tracked" in key_name:
            # Not saved: keep the model's own counter so the strict load does not fail.
            new_state_dict[key_name] = param
            continue
        saved = worker_weights[param_idx]
        if "conv" in key_name:
            if "weight" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
            elif "bias" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved)
            else:
                raise ValueError("Unhandled conv parameter: {}".format(key_name))
        elif "norm" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
        elif "fc" in key_name or "classifier" in key_name:
            if "weight" in key_name:
                # The saved linear weight is stored transposed relative to the model's layout.
                new_state_dict[key_name] = torch.from_numpy(saved.T)
            elif "bias" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved)
            else:
                raise ValueError("Unhandled linear parameter: {}".format(key_name))
        else:
            raise ValueError("Unhandled state_dict key: {}".format(key_name))
        # (Previously the dict update sat outside this loop, so only the last key was loaded.)
    local_model.load_state_dict(new_state_dict)

In [None]:
# Evaluate each of the 16 FedMA (no communication) local models on the validation set.
# For worker w this fills `target[w]` / `out[w]` with (n_samples, 14) arrays (one column
# per observation class), prints the macro-averaged AUROC and plots the micro-averaged
# ROC curve.
# NOTE(review): `roc_auc_score` raises ValueError if any class column holds only one
# label value — confirm every class has positives and negatives in this split.
target = [[] for _ in range(16)]
out = [[] for _ in range(16)]
for worker_index in range(16):
    local_model = global_model_FedMA_no_comm[worker_index]
    local_model.to(device)  # inputs are sent to `device`, so the model must be there too
    local_model.eval()      # inference mode: BatchNorm running statistics, no dropout
    target_parts = []
    out_parts = []
    with torch.no_grad():
        for x, target_b in validation_dl:
            target_b = handle_uncertainty_labels(target_b)
            target_b = handle_NaN_values(target_b)
            out_b = local_model(x.to(device))  # was a syntax error: `[worker_index(x)`
            target_parts.append(np.asarray(target_b.tolist(), dtype=float).ravel())
            out_parts.append(np.asarray(out_b.tolist(), dtype=float).ravel())

    target[worker_index] = np.reshape(np.concatenate(target_parts), (-1, 14))
    out[worker_index] = np.reshape(np.concatenate(out_parts), (-1, 14))

    auroc = metrics.roc_auc_score(target[worker_index], out[worker_index])
    # roc_curve needs binary labels and this worker's data only: flatten all
    # (sample, class) pairs into one micro-averaged curve.
    false_positive_rate, true_positive_rate, _ = metrics.roc_curve(
        target[worker_index].ravel(), out[worker_index].ravel()
    )
    micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
    print("The local model {} with weights from FedMA without communication has a total AUROC of: {}".format(worker_index, auroc))

    # Plot the micro-averaged ROC curve against the chance diagonal.
    fig, ax = plt.subplots()
    lw = 2
    ax.plot(
        false_positive_rate,
        true_positive_rate,
        color="darkorange",
        lw=lw,
        label="Micro-avg ROC curve (area = %0.2f)" % micro_auroc,
    )
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver operating characteristic total (FedMA-no_comm) of model: " + str(worker_index))
    ax.legend(loc="lower right")
    plt.show()

In [None]:
# Per-class ROC curve and AUROC for each FedMA (no communication) local model.
# `fpr`, `tpr` and `roc_auc` are rebuilt for every worker, so after the loop they hold
# only the last worker's values.
for worker_index in range(16):
    fpr = [[] for _ in range(len(observation_classes))]
    tpr = [[] for _ in range(len(observation_classes))]
    roc_auc = [[] for _ in range(len(observation_classes))]
    for i in range(len(observation_classes)):
        fpr[i], tpr[i], _ = metrics.roc_curve(target[worker_index][:, i], out[worker_index][:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])
        # was `observation_class[i]` (NameError)
        print("This is the AUROC-score of model {} for observation class '{}': {}".format(
            worker_index, observation_classes[i], roc_auc[i]))

# print("Plot the ROC curve for each class")
# # Plot of a ROC curve for each class
# for i in range (len(observation_classes)):
#     plt.figure()
#     lw = 2
#     plt.plot(
#         fpr[i],
#         tpr[i],
#         color="darkorange",
#         lw=lw,
#         label="ROC curve (area = %0.2f)" % roc_auc[i],
#     )
#     plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
#     plt.xlim([0.0, 1.0])
#     plt.ylim([0.0, 1.05])
#     plt.xlabel("False Positive Rate")
#     plt.ylabel("True Positive Rate")
#     plt.title("Receiver operating characteristic " + observation_classes[i] + " (FedMA-no_comm)")
#     plt.legend(loc="lower right")
#     plt.show()

In [None]:
print("Continuing with the weights resulting from communication")
print("Starting with FedAvg")
# Copy the FedAvg (with communication) global weights into `global_model_FedAvg`,
# replacing the no-communication weights loaded earlier.
# NOTE(review): `param_idx` also counts the skipped `num_batches_tracked` buffers, which
# assumes the saved weight list uses the same indexing — confirm.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedAvg.state_dict().items()):
    if "num_batches_tracked" in key_name:
        # Not saved: keep the model's own counter so the strict load does not fail.
        new_state_dict[key_name] = param
        continue
    saved = fedavg_weights_comm[param_idx]
    if "conv" in key_name:
        if "weight" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
        elif "bias" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved)
        else:
            raise ValueError("Unhandled conv parameter: {}".format(key_name))
    elif "norm" in key_name:
        new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # The saved linear weight is stored transposed relative to the model's layout.
            new_state_dict[key_name] = torch.from_numpy(saved.T)
        elif "bias" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved)
        else:
            raise ValueError("Unhandled linear parameter: {}".format(key_name))
    else:
        # Previously an unmatched key silently re-inserted the previous tensor.
        raise ValueError("Unhandled state_dict key: {}".format(key_name))
global_model_FedAvg.load_state_dict(new_state_dict)

In [None]:
# Evaluate the FedAvg (with communication) global model on the validation set.
# `target` / `out` become (n_samples, 14) arrays, one column per observation class.
global_model_FedAvg.to(device)  # inputs are moved to `device`, so the model must live there too
global_model_FedAvg.eval()      # inference mode: BatchNorm running statistics, no dropout
target_parts = []
out_parts = []
with torch.no_grad():
    for x, target_b in validation_dl:
        target_b = handle_uncertainty_labels(target_b)
        target_b = handle_NaN_values(target_b)
        out_b = global_model_FedAvg(x.to(device))
        target_parts.append(np.asarray(target_b.tolist(), dtype=float).ravel())
        out_parts.append(np.asarray(out_b.tolist(), dtype=float).ravel())

target = np.reshape(np.concatenate(target_parts), (-1, 14))
out = np.reshape(np.concatenate(out_parts), (-1, 14))

# Macro-averaged AUROC over the 14 classes. NOTE(review): raises ValueError if any class
# column holds only one label value — confirm every class has both labels here.
auroc = metrics.roc_auc_score(target, out)
# roc_curve accepts only binary labels: flatten (sample, class) pairs -> micro-averaged curve.
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(target.ravel(), out.ravel())
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
print("The global model with weights from FedAvg with communication has a total AUROC of: {}".format(auroc))

# Plot the micro-averaged ROC curve against the chance diagonal.
fig, ax = plt.subplots()
lw = 2
ax.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="Micro-avg ROC curve (area = %0.2f)" % micro_auroc,
)
ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Receiver operating characteristic total (FedAvg-comm)")
ax.legend(loc="lower right")
plt.show()

In [None]:
# Per-class ROC curve and AUROC for the FedAvg (with communication) global model.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(len(observation_classes)):
    fpr[i], tpr[i], _ = metrics.roc_curve(target[:, i], out[:, i])
    roc_auc[i] = metrics.auc(fpr[i], tpr[i])

print("Plot the ROC curve for each class")
lw = 2
for i, class_name in enumerate(observation_classes):
    fig, ax = plt.subplots()
    ax.plot(
        fpr[i],
        tpr[i],
        color="darkorange",
        lw=lw,
        label="ROC curve (area = %0.2f)" % roc_auc[i],
    )
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    # These are the with-communication results (the title previously said "FedAvg-no_comm").
    ax.set_title("Receiver operating characteristic " + class_name + " (FedAvg-comm)")
    ax.legend(loc="lower right")
    plt.show()

In [None]:
print("Create a densenet121 architecture for FedMA with communication for each layer")
# One independent densenet121 instance per local worker model (16 in total).
global_model_FedMA_comm = []
for _ in range(16):
    global_model_FedMA_comm.append(densenet121())

In [None]:
# Load the FedMA (with communication) weights of each of the 16 local workers into the
# matching model in `global_model_FedMA_comm`.
# NOTE(review): `param_idx` also counts the skipped `num_batches_tracked` buffers, which
# assumes the saved per-worker weight lists use the same indexing — confirm.
for worker_index in range(16):
    local_model = global_model_FedMA_comm[worker_index]  # was `.state_dict()` on the list itself
    worker_weights = fedma_weights_comm[worker_index]
    new_state_dict = {}
    for param_idx, (key_name, param) in enumerate(local_model.state_dict().items()):
        if "num_batches_tracked" in key_name:
            # Not saved: keep the model's own counter so the strict load does not fail.
            new_state_dict[key_name] = param
            continue
        saved = worker_weights[param_idx]
        if "conv" in key_name:
            if "weight" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
            elif "bias" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved)
            else:
                raise ValueError("Unhandled conv parameter: {}".format(key_name))
        elif "norm" in key_name:
            new_state_dict[key_name] = torch.from_numpy(saved.reshape(param.size()))
        elif "fc" in key_name or "classifier" in key_name:
            if "weight" in key_name:
                # The saved linear weight is stored transposed relative to the model's layout.
                new_state_dict[key_name] = torch.from_numpy(saved.T)
            elif "bias" in key_name:
                new_state_dict[key_name] = torch.from_numpy(saved)
            else:
                raise ValueError("Unhandled linear parameter: {}".format(key_name))
        else:
            raise ValueError("Unhandled state_dict key: {}".format(key_name))
        # (Previously the dict update sat outside this loop, so only the last key was loaded.)
    local_model.load_state_dict(new_state_dict)

In [None]:
# Evaluate each of the 16 FedMA (with communication) local models on the validation set.
# For worker w this fills `target[w]` / `out[w]` with (n_samples, 14) arrays (one column
# per observation class), prints the macro-averaged AUROC and plots the micro-averaged
# ROC curve.
# NOTE(review): `roc_auc_score` raises ValueError if any class column holds only one
# label value — confirm every class has positives and negatives in this split.
target = [[] for _ in range(16)]
out = [[] for _ in range(16)]
for worker_index in range(16):
    local_model = global_model_FedMA_comm[worker_index]
    local_model.to(device)  # inputs are sent to `device`, so the model must be there too
    local_model.eval()      # inference mode: BatchNorm running statistics, no dropout
    target_parts = []
    out_parts = []
    with torch.no_grad():
        for x, target_b in validation_dl:
            target_b = handle_uncertainty_labels(target_b)
            target_b = handle_NaN_values(target_b)
            out_b = local_model(x.to(device))  # was a syntax error: `[worker_index(x)`
            target_parts.append(np.asarray(target_b.tolist(), dtype=float).ravel())
            out_parts.append(np.asarray(out_b.tolist(), dtype=float).ravel())

    target[worker_index] = np.reshape(np.concatenate(target_parts), (-1, 14))
    out[worker_index] = np.reshape(np.concatenate(out_parts), (-1, 14))

    auroc = metrics.roc_auc_score(target[worker_index], out[worker_index])
    # roc_curve needs binary labels and this worker's data only: flatten all
    # (sample, class) pairs into one micro-averaged curve.
    false_positive_rate, true_positive_rate, _ = metrics.roc_curve(
        target[worker_index].ravel(), out[worker_index].ravel()
    )
    micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
    print("The local model {} with weights from FedMA with communication has a total AUROC of: {}".format(worker_index, auroc))

    # Plot the micro-averaged ROC curve against the chance diagonal.
    fig, ax = plt.subplots()
    lw = 2
    ax.plot(
        false_positive_rate,
        true_positive_rate,
        color="darkorange",
        lw=lw,
        label="Micro-avg ROC curve (area = %0.2f)" % micro_auroc,
    )
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver operating characteristic total (FedMA-comm) of model: " + str(worker_index))
    ax.legend(loc="lower right")
    plt.show()

In [None]:
# Per-class ROC curve and AUROC for each FedMA (with communication) local model.
# `fpr`, `tpr` and `roc_auc` are rebuilt for every worker, so after the loop they hold
# only the last worker's values.
for worker_index in range(16):
    fpr = [[] for _ in range(len(observation_classes))]
    tpr = [[] for _ in range(len(observation_classes))]
    roc_auc = [[] for _ in range(len(observation_classes))]
    for i in range(len(observation_classes)):
        fpr[i], tpr[i], _ = metrics.roc_curve(target[worker_index][:, i], out[worker_index][:, i])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])
        # was `observation_class[i]` (NameError)
        print("This is the AUROC-score of model {} for observation class '{}': {}".format(
            worker_index, observation_classes[i], roc_auc[i]))

# print("Plot the ROC curve for each class")
# # Plot of a ROC curve for each class
# for i in range (len(observation_classes)):
#     plt.figure()
#     lw = 2
#     plt.plot(
#         fpr[i],
#         tpr[i],
#         color="darkorange",
#         lw=lw,
#         label="ROC curve (area = %0.2f)" % roc_auc[i],
#     )
#     plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
#     plt.xlim([0.0, 1.0])
#     plt.ylim([0.0, 1.05])
#     plt.xlabel("False Positive Rate")
#     plt.ylabel("True Positive Rate")
#     plt.title("Receiver operating characteristic " + observation_classes[i] + " (FedMA-comm)")
#     plt.legend(loc="lower right")
#     plt.show()