In [None]:
from model2 import densenet121, densenet121Container
from sklearn import metrics
import torch
from datasets import CheXpert_dataset
from torch.utils import data
from utils import handle_uncertainty_labels, load_weights
import matplotlib.pyplot as plt
import numpy as np
import logging

: 

In [None]:
# Attach a default stderr handler (only if the root logger has none yet) and
# make INFO-level progress messages visible.
logging.basicConfig()
logger = logging.root  # the root logger; the same object logging.getLogger() returns
logger.setLevel("INFO")

In [None]:
args_datadir = "./data/CheXpert-v1.0-small/"

logger.info("Begin loading the weights")
# Load the stored weights of each (algorithm, communication setting) run by name.
# NOTE(review): each result is indexed positionally further down, so it is presumably a
# sequence of numpy arrays — confirm against utils.load_weights.
fedavg_weights_no_comm, fedavg_weights_comm = [load_weights(run) for run in ("FedAvg_no_comm", "FedAvg_comm")]
fedma_weights_no_comm, fedma_weights_comm = [load_weights(run) for run in ("FedMA_no_comm", "FedMA_comm")]

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

logger.info("Create a typical densenet121 architecture for FedAvg")
# Put the model on the same device the validation batches are sent to; otherwise a
# CUDA run fails with a device mismatch.
global_model_FedAvg = densenet121().to(device)

logger.info("Create a dataloader for the validation dataset")
validation_set = CheXpert_dataset(args_datadir, valid=True, transform=True)
# A single batch holding the whole validation set. No shuffling, so the sample order
# (and therefore the cached inputs/labels reused by every model) is reproducible.
validation_dl = data.DataLoader(dataset=validation_set, batch_size=len(validation_set), shuffle=False)

In [None]:
# Copy the FedAvg weights trained without communication into the standard densenet121.
# NOTE(review): assumes fedavg_weights_no_comm is ordered like state_dict(), so that
# entry param_idx belongs to key_name — confirm against utils.load_weights.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedAvg.state_dict().items()):
    # Default: keep the model's current tensor. Without this, a key matching none of the
    # branches below (e.g. BatchNorm running statistics) re-inserted the previous key's
    # entry and was itself missing from new_state_dict (NameError on the first key).
    temp_dict = {key_name: param}
    if "conv" in key_name or "features" in key_name:
        if "weight" in key_name:
            # restore the layer's own shape (the stored array may be flattened)
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_no_comm[param_idx].reshape(param.size()))}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_no_comm[param_idx])}
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # the stored linear weight is transposed relative to the layer's weight
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_no_comm[param_idx].T)}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_no_comm[param_idx])}

    new_state_dict.update(temp_dict)
global_model_FedAvg.load_state_dict(new_state_dict)

In [None]:
# Run the FedAvg (no communication) model once over the validation set and cache the
# inputs and labels, so the remaining models are scored on exactly the same samples.
global_model_FedAvg.to(device)
global_model_FedAvg.eval()  # inference mode: BatchNorm uses running stats, dropout is off

# Three separate lists (unpacking a single empty list into three names raises ValueError).
dataset, target, out = [], [], []
with torch.no_grad():  # no autograd graph needed; also allows the outputs to become numpy arrays
    # there should be only one batch containing all of the data
    for x, target_b in validation_dl:
        target_b = handle_uncertainty_labels(target_b)
        # cache the inputs on the CPU; the evaluation cells move them to `device` themselves
        dataset.append(x)

        x, target_b = x.to(device), target_b.to(device)
        out_b = global_model_FedAvg(x)

        target.append(target_b.cpu())
        out.append(out_b.cpu())

# np.reshape cannot stack a list of (possibly CUDA) tensors: concatenate, then convert.
target = torch.cat(target).numpy().reshape(-1, 14)
out = torch.cat(out).numpy().reshape(-1, 14)

In [None]:
# Overall AUROC of the FedAvg model trained without communication.
# For a multilabel (n_samples, 14) target, roc_auc_score defaults to the macro average.
auroc = metrics.roc_auc_score(target, out)
# roc_curve only accepts one binary label vector, so pool all (label, score) pairs of the
# 14 observations into a single micro-averaged curve for the "total" plot.
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(np.ravel(target), np.ravel(out))
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
# Lazy %-formatting (concatenating the float `auroc` to a str raised TypeError).
logger.info(
    "The global model with weights from FedAvg without communication has a total AUROC of: %.4f"
    " (micro-average: %.4f)", auroc, micro_auroc)

# Plot the micro-averaged ROC curve; the legend reports the area under the plotted curve.
plt.figure()
lw = 2
plt.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="micro-average ROC curve (area = %0.2f)" % micro_auroc,
)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic total (FedAvg-no_comm)")
plt.legend(loc="lower right")
plt.show()

In [None]:
observation_classes = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

# Per-class ROC curve and area; key k refers to column k of `target` / `out`.
fpr, tpr, roc_auc = {}, {}, {}
for class_idx in range(len(observation_classes)):
    class_fpr, class_tpr, _ = metrics.roc_curve(target[:, class_idx], out[:, class_idx])
    fpr[class_idx] = class_fpr
    tpr[class_idx] = class_tpr
    roc_auc[class_idx] = metrics.auc(class_fpr, class_tpr)

logger.info("Plot the ROC curve for each class")
lw = 2
for class_idx, class_name in enumerate(observation_classes):
    plt.figure()
    plt.plot(fpr[class_idx], tpr[class_idx], color="darkorange", lw=lw,
             label="ROC curve (area = %0.2f)" % roc_auc[class_idx])
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"Receiver operating characteristic {class_name} (FedAvg-no_comm)")
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# densenet121 has a total of 241 conv or norm layers plus one classifier layer (242 entries).
# NOTE(review): reading every second entry assumes the weight list holds a
# (weight, bias) pair per layer — confirm against utils.load_weights.
logger.info("Determining the number of filters for each layer of the FedMA model without communication")
num_filters = [fedma_weights_no_comm[2 * layer_idx].shape[0] for layer_idx in range(242)]

logger.info("Create a custom densenet121 architecture for FedMA without communication")
global_model_FedMA_no_comm = densenet121Container(num_filters)

In [None]:
# Copy the FedMA weights trained without communication into the matching custom densenet121.
# NOTE(review): assumes fedma_weights_no_comm is ordered like state_dict(), so that
# entry param_idx belongs to key_name — confirm against utils.load_weights.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedMA_no_comm.state_dict().items()):
    # Default: keep the model's current tensor, so keys matching none of the branches
    # below (e.g. BatchNorm running statistics) are not replaced by the previous key's
    # entry and dropped from new_state_dict.
    temp_dict = {key_name: param}
    if "conv" in key_name or "features" in key_name:
        if "weight" in key_name:
            # restore the layer's own shape (the stored array may be flattened)
            temp_dict = {key_name: torch.from_numpy(fedma_weights_no_comm[param_idx].reshape(param.size()))}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedma_weights_no_comm[param_idx])}
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # the stored linear weight is transposed relative to the layer's weight
            temp_dict = {key_name: torch.from_numpy(fedma_weights_no_comm[param_idx].T)}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedma_weights_no_comm[param_idx])}

    new_state_dict.update(temp_dict)
global_model_FedMA_no_comm.load_state_dict(new_state_dict)

In [None]:
# Score the FedMA model (no communication) on the cached validation inputs.
global_model_FedMA_no_comm.to(device)
global_model_FedMA_no_comm.eval()  # BatchNorm running stats, dropout off
with torch.no_grad():
    # move each cached batch to the model's device; bring predictions back to the CPU
    out = torch.cat([global_model_FedMA_no_comm(x.to(device)).cpu() for x in dataset])
out = out.numpy().reshape(-1, 14)

# Macro-averaged AUROC (roc_auc_score's default for a multilabel target).
auroc = metrics.roc_auc_score(target, out)
# roc_curve needs one binary label vector: pool all 14 observations (micro-average).
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(np.ravel(target), np.ravel(out))
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
logger.info(
    "The global model with weights from FedMA without communication has a total AUROC of: %.4f"
    " (micro-average: %.4f)", auroc, micro_auroc)

# Plot the micro-averaged ROC curve. The legend uses this model's own area; it previously
# showed roc_auc[2], a stale per-class value left over from the FedAvg evaluation.
plt.figure()
lw = 2
plt.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="micro-average ROC curve (area = %0.2f)" % micro_auroc,
)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic total (FedMA-no_comm)")
plt.legend(loc="lower right")
plt.show()

In [None]:
# Per-class ROC curve and area for the FedMA model without communication;
# key k refers to column k of `target` / `out`.
fpr, tpr, roc_auc = {}, {}, {}
for class_idx in range(len(observation_classes)):
    class_fpr, class_tpr, _ = metrics.roc_curve(target[:, class_idx], out[:, class_idx])
    fpr[class_idx] = class_fpr
    tpr[class_idx] = class_tpr
    roc_auc[class_idx] = metrics.auc(class_fpr, class_tpr)

logger.info("Plot the ROC curve for each class")
lw = 2
for class_idx, class_name in enumerate(observation_classes):
    plt.figure()
    plt.plot(fpr[class_idx], tpr[class_idx], color="darkorange", lw=lw,
             label="ROC curve (area = %0.2f)" % roc_auc[class_idx])
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"Receiver operating characteristic {class_name} (FedMA-no_comm)")
    plt.legend(loc="lower right")
    plt.show()

In [None]:
logger.info("Continuing with the weights resulting from communication")
logger.info("Starting with FedAvg")
# Overwrite the standard densenet121 with the FedAvg weights trained with communication.
# NOTE(review): assumes fedavg_weights_comm is ordered like state_dict(), so that
# entry param_idx belongs to key_name — confirm against utils.load_weights.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedAvg.state_dict().items()):
    # Default: keep the model's current tensor, so keys matching none of the branches
    # below (e.g. BatchNorm running statistics) are not replaced by the previous key's
    # entry and dropped from new_state_dict.
    temp_dict = {key_name: param}
    if "conv" in key_name or "features" in key_name:
        if "weight" in key_name:
            # restore the layer's own shape (the stored array may be flattened)
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_comm[param_idx].reshape(param.size()))}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_comm[param_idx])}
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # the stored linear weight is transposed relative to the layer's weight
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_comm[param_idx].T)}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedavg_weights_comm[param_idx])}

    new_state_dict.update(temp_dict)
global_model_FedAvg.load_state_dict(new_state_dict)

In [None]:
# Score the FedAvg model (with communication) on the cached validation inputs.
global_model_FedAvg.to(device)
global_model_FedAvg.eval()  # BatchNorm running stats, dropout off
with torch.no_grad():
    # move each cached batch to the model's device; bring predictions back to the CPU
    out = torch.cat([global_model_FedAvg(x.to(device)).cpu() for x in dataset])
out = out.numpy().reshape(-1, 14)

# Macro-averaged AUROC (roc_auc_score's default for a multilabel target).
auroc = metrics.roc_auc_score(target, out)
# roc_curve needs one binary label vector: pool all 14 observations (micro-average).
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(np.ravel(target), np.ravel(out))
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
logger.info(
    "The global model with weights from FedAvg with communication has a total AUROC of: %.4f"
    " (micro-average: %.4f)", auroc, micro_auroc)

# Plot the micro-averaged ROC curve; the legend uses this model's own area
# (it previously showed the stale per-class value roc_auc[2] of another model).
plt.figure()
lw = 2
plt.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="micro-average ROC curve (area = %0.2f)" % micro_auroc,
)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic total (FedAvg-comm)")
plt.legend(loc="lower right")
plt.show()

In [None]:
# Per-class ROC curve and area for the FedAvg model with communication;
# key k refers to column k of `target` / `out`.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(len(observation_classes)):
    fpr[i], tpr[i], _ = metrics.roc_curve(target[:, i], out[:, i])
    roc_auc[i] = metrics.auc(fpr[i], tpr[i])

logger.info("Plot the ROC curve for each class")
# Plot of a ROC curve for each class
lw = 2
for i, class_name in enumerate(observation_classes):
    plt.figure()
    plt.plot(
        fpr[i],
        tpr[i],
        color="darkorange",
        lw=lw,
        label="ROC curve (area = %0.2f)" % roc_auc[i],
    )
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    # These are the weights trained WITH communication (the title said "FedAvg-no_comm").
    plt.title("Receiver operating characteristic " + class_name + " (FedAvg-comm)")
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# densenet121 has a total of 241 conv or norm layers plus one classifier layer (242 entries).
# NOTE(review): reading every second entry assumes the weight list holds a
# (weight, bias) pair per layer — confirm against utils.load_weights.
logger.info("Determining the number of filters for each layer of the FedMA model with communication")
num_filters = [fedma_weights_comm[2 * layer_idx].shape[0] for layer_idx in range(242)]

# This model holds the weights trained WITH communication (the message said "without").
logger.info("Create a custom densenet121 architecture for FedMA with communication")
global_model_FedMA_comm = densenet121Container(num_filters)

In [None]:
# Copy the FedMA weights trained with communication into the matching custom densenet121.
# (This cell previously read fedavg_weights_comm, i.e. the FedAvg weights.)
# NOTE(review): assumes fedma_weights_comm is ordered like state_dict(), so that
# entry param_idx belongs to key_name — confirm against utils.load_weights.
new_state_dict = {}
for param_idx, (key_name, param) in enumerate(global_model_FedMA_comm.state_dict().items()):
    # Default: keep the model's current tensor, so keys matching none of the branches
    # below (e.g. BatchNorm running statistics) are not replaced by the previous key's
    # entry and dropped from new_state_dict.
    temp_dict = {key_name: param}
    if "conv" in key_name or "features" in key_name:
        if "weight" in key_name:
            # restore the layer's own shape (the stored array may be flattened)
            temp_dict = {key_name: torch.from_numpy(fedma_weights_comm[param_idx].reshape(param.size()))}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedma_weights_comm[param_idx])}
    elif "fc" in key_name or "classifier" in key_name:
        if "weight" in key_name:
            # the stored linear weight is transposed relative to the layer's weight
            temp_dict = {key_name: torch.from_numpy(fedma_weights_comm[param_idx].T)}
        elif "bias" in key_name:
            temp_dict = {key_name: torch.from_numpy(fedma_weights_comm[param_idx])}

    new_state_dict.update(temp_dict)
global_model_FedMA_comm.load_state_dict(new_state_dict)

In [None]:
# Score the FedMA model WITH communication on the cached validation inputs
# (this cell previously ran global_model_FedMA_no_comm by mistake).
global_model_FedMA_comm.to(device)
global_model_FedMA_comm.eval()  # BatchNorm running stats, dropout off
with torch.no_grad():
    # move each cached batch to the model's device; bring predictions back to the CPU
    out = torch.cat([global_model_FedMA_comm(x.to(device)).cpu() for x in dataset])
out = out.numpy().reshape(-1, 14)

# Macro-averaged AUROC (roc_auc_score's default for a multilabel target).
auroc = metrics.roc_auc_score(target, out)
# roc_curve needs one binary label vector: pool all 14 observations (micro-average).
false_positive_rate, true_positive_rate, _ = metrics.roc_curve(np.ravel(target), np.ravel(out))
micro_auroc = metrics.auc(false_positive_rate, true_positive_rate)
logger.info(
    "The global model with weights from FedMA with communication has a total AUROC of: %.4f"
    " (micro-average: %.4f)", auroc, micro_auroc)

# Plot the micro-averaged ROC curve; the legend uses this model's own area
# (it previously showed the stale per-class value roc_auc[2] of another model).
plt.figure()
lw = 2
plt.plot(
    false_positive_rate,
    true_positive_rate,
    color="darkorange",
    lw=lw,
    label="micro-average ROC curve (area = %0.2f)" % micro_auroc,
)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic total (FedMA-comm)")
plt.legend(loc="lower right")
plt.show()

In [None]:
# Per-class ROC curve and area for the FedMA model with communication;
# key k refers to column k of `target` / `out`.
fpr, tpr, roc_auc = {}, {}, {}
for class_idx in range(len(observation_classes)):
    class_fpr, class_tpr, _ = metrics.roc_curve(target[:, class_idx], out[:, class_idx])
    fpr[class_idx] = class_fpr
    tpr[class_idx] = class_tpr
    roc_auc[class_idx] = metrics.auc(class_fpr, class_tpr)

logger.info("Plot the ROC curve for each class")
lw = 2
for class_idx, class_name in enumerate(observation_classes):
    plt.figure()
    plt.plot(fpr[class_idx], tpr[class_idx], color="darkorange", lw=lw,
             label="ROC curve (area = %0.2f)" % roc_auc[class_idx])
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"Receiver operating characteristic {class_name} (FedMA-comm)")
    plt.legend(loc="lower right")
    plt.show()