In [1]:
# perform inference on the test set using a trained model
import os
import torch
import torch.nn as nn
from fl_bench.chest.src.dm_test import ChestDataModuleTest
from fl_bench.chest.src.validator import Validator
from fl_bench.networks.chest_nets import *
import json
import numpy as np
from monai.transforms import Activations
from sklearn.metrics import multilabel_confusion_matrix
from tqdm import tqdm

%reload_ext autoreload
%autoreload 2

# Path suffixes, appended to a run directory, of the global FL checkpoint and of the
# cross-site validation JSON report (read later for its "optimal_thresholds" entries).
suf_model = "/simulate_job/app_server/FL_global_model.pt"
suf_report = "/simulate_job/cross_site_val/cross_val_results.json"
# NOTE(review): the next assignment is a placeholder with no value; this cell cannot run
# until a real path is filled in. It is passed as `data_dir` to every data module below
# even though the inline comment calls it the model directory - confirm which is meant.
root_path = # path to the trained model directory
# One test data module per client; the evaluation functions read `.test_dataloader` from each.
# NOTE(review): cache_rate=1.0 presumably caches the full split in memory - TODO confirm.
dm_pc = ChestDataModuleTest(
    data_dir=root_path,
    client_idx='client_padchest',
    cache_rate=1.0,
)
dm_cxr = ChestDataModuleTest(
    data_dir=root_path,
    client_idx='client_cxr14',
    cache_rate=1.0,
)
dm_cxp_young = ChestDataModuleTest(
    data_dir=root_path,
    client_idx='client_cxp_young',
    cache_rate=1.0,
)
dm_cxp_old = ChestDataModuleTest(
    data_dir=root_path,
    client_idx='client_cxp_old',
    cache_rate=1.0,
)
# Shared validator; its run(model, dataloader) result is indexed with 'mean_auroc' below.
validator = Validator()

Loading dataset: 100%|██████████| 10000/10000 [00:28<00:00, 352.60it/s]
Loading dataset: 100%|██████████| 10000/10000 [00:28<00:00, 347.06it/s]
Loading dataset: 100%|██████████| 7797/7797 [00:27<00:00, 287.44it/s]
Loading dataset: 100%|██████████| 7203/7203 [00:16<00:00, 446.09it/s]


In [2]:
def get_test_auroc(arch, method, seed):
    """Score the global FL checkpoint on the test split of every client.

    Args:
        arch: Name of a network constructor in scope; resolved with ``eval``.
        method: Unused here; kept so the signature matches the other evaluators.
        seed: Passed as the single argument to the network constructor.

    Returns:
        Tuple of the ``'mean_auroc'`` entries reported by the validator, in the
        order (dm_pc, dm_cxr, dm_cxp_young, dm_cxp_old).
    """
    # NOTE(review): eval() runs arbitrary expressions - only pass trusted names.
    net = eval(arch)(seed).cuda()

    # Placeholder run directory; the checkpoint path is this prefix + suf_model.
    target_model = ""
    state = torch.load(target_model + suf_model)
    net.load_state_dict(state['model'])

    # Run the same model over each client's test loader, then pull out the metric.
    site_modules = (dm_pc, dm_cxr, dm_cxp_young, dm_cxp_old)
    outcomes = [validator.run(net, dm.test_dataloader) for dm in site_modules]
    return tuple(outcome['mean_auroc'] for outcome in outcomes)

def get_pfl_test_auroc(arch, method, seed):
    """Score per-client (personalized) checkpoints, each on its own client's test split.

    Args:
        arch: Name of a network constructor in scope; resolved with ``eval``.
        method: Unused here; kept so the signature matches the other evaluators.
        seed: Passed as the single argument to the network constructor.

    Returns:
        Tuple of the ``'mean_auroc'`` entries reported by the validator, in the
        order (dm_pc, dm_cxr, dm_cxp_young, dm_cxp_old).
    """
    # NOTE(review): eval() runs arbitrary expressions - only pass trusted names.
    net = eval(arch)(seed).to('cuda:0')

    # Placeholders: fill in the run directory / per-client checkpoint paths before use.
    target_model = ""
    model_path_pc = ""
    model_path_cxr = ""
    model_path_cxp_young = ""
    model_path_cxp_old = ""

    pairings = (
        (model_path_pc, dm_pc),
        (model_path_cxr, dm_cxr),
        (model_path_cxp_young, dm_cxp_young),
        (model_path_cxp_old, dm_cxp_old),
    )

    # The single network instance is re-loaded with each client's weights in turn.
    outcomes = []
    for ckpt_path, dm in pairings:
        state = torch.load(ckpt_path)
        net.load_state_dict(state['model'])
        outcomes.append(validator.run(net, dm.test_dataloader))

    return tuple(outcome['mean_auroc'] for outcome in outcomes)

In [3]:
def get_3_seed_auroc(arch, method, seed_list, func=get_test_auroc):
    """Print per-site and overall AUROC statistics across several seeds.

    ``func(arch, method, seed)`` is called once per seed and must return an
    indexable whose first four items are the PC, CXR, CXP Young and CXP Old
    AUROC values. Means and standard deviations (``np.std`` defaults) are printed
    scaled by 100; the overall std is the square root of the mean per-site variance.

    Args:
        arch: Forwarded to ``func``.
        method: Forwarded to ``func``.
        seed_list: Iterable of seeds; any length is accepted despite the name.
        func: Evaluator to run for each seed.

    Returns:
        None. Results are only printed.
    """
    per_seed = [func(arch, method, seed) for seed in tqdm(seed_list)]

    # Transpose: one list of per-seed scores for each of the four sites.
    pc_metrics, cxr_metrics, cxp_young_metrics, cxp_old_metrics = (
        [scores[site] for scores in per_seed] for site in range(4)
    )

    print(f"PC AUROC: \nMean: {100*np.mean(pc_metrics)}\nStd: {100*np.std(pc_metrics)}\n")
    print(f"CXR AUROC: \nMean: {100*np.mean(cxr_metrics)}\nStd: {100*np.std(cxr_metrics)}\n")
    print(f"CXP Young AUROC: \nMean: {100*np.mean(cxp_young_metrics)}\nStd: {100*np.std(cxp_young_metrics)}\n")
    print(f"CXP Old AUROC: \nMean: {100*np.mean(cxp_old_metrics)}\nStd: {100*np.std(cxp_old_metrics)}\n")
    print(f"Total AUROC: {100*np.mean([np.mean(pc_metrics), np.mean(cxr_metrics), np.mean(cxp_young_metrics), np.mean(cxp_old_metrics)])}")

    # Pool the spread across sites by averaging the four per-site variances.
    site_variances = [
        np.var(values)
        for values in (pc_metrics, cxr_metrics, cxp_young_metrics, cxp_old_metrics)
    ]
    std_total = np.sqrt(sum(site_variances)/4)
    print(f"Total AUROC Std: {100*std_total}\n")

In [5]:
# Global-model AUROC over three seeds, using the default evaluator (get_test_auroc).
# NOTE(review): get_test_auroc ignores `method`, so 'fedprox' does not select a run
# by itself - confirm the checkpoint path set inside that function.
get_3_seed_auroc('resnet_50_supervised', 'fedprox', [42, 1995, 99])

Validation DataLoader: 100%|██████████| 100/100 [00:03<00:00, 32.10it/s]
Validation DataLoader: 100%|██████████| 100/100 [00:02<00:00, 33.35it/s]
Validation DataLoader: 100%|██████████| 78/78 [00:02<00:00, 31.88it/s]
Validation DataLoader: 100%|██████████| 73/73 [00:02<00:00, 31.30it/s]
Validation DataLoader: 100%|██████████| 100/100 [00:03<00:00, 32.55it/s]
Validation DataLoader: 100%|██████████| 100/100 [00:02<00:00, 33.63it/s]
Validation DataLoader: 100%|██████████| 78/78 [00:02<00:00, 32.81it/s]
Validation DataLoader: 100%|██████████| 73/73 [00:02<00:00, 32.78it/s]
Validation DataLoader: 100%|██████████| 100/100 [00:03<00:00, 32.44it/s]
Validation DataLoader: 100%|██████████| 100/100 [00:02<00:00, 33.35it/s]
Validation DataLoader: 100%|██████████| 78/78 [00:02<00:00, 32.54it/s]
Validation DataLoader: 100%|██████████| 73/73 [00:02<00:00, 32.47it/s]
100%|██████████| 3/3 [00:35<00:00, 11.96s/it]

PC AUROC: 
Mean: 89.10805468466182
Std: 0.03476624621040819

CXR AUROC: 
Mean: 84.37405418596526
Std: 0.0756157078910905

CXP Young AUROC: 
Mean: 79.12829581384672
Std: 0.1695321185951078

CXP Old AUROC: 
Mean: 75.965183348437
Std: 0.08457003941732374

Total AUROC: 82.1438970082277
Total AUROC Std: 0.10346455668226057






In [None]:
# Personalized-FL AUROC: get_pfl_test_auroc loads a separate checkpoint for each
# client before scoring that client's test set.
get_3_seed_auroc('resnet_50_supervised', 'fedper', [42, 1995, 2024], func=get_pfl_test_auroc)

In [43]:
def get_metrics(model, dm, val_thresholds):
    """Compute macro-averaged F1 and accuracy of ``model`` on ``dm``'s test loader.

    Model outputs are passed through a sigmoid, binarised with a strict
    ``> val_thresholds`` comparison, and scored from the per-class confusion
    matrices returned by ``multilabel_confusion_matrix``.

    Args:
        model: Torch module to evaluate. It is switched to eval mode as a side
            effect and is not switched back.
        dm: Object whose ``test_dataloader`` attribute yields dict batches with
            ``"image"`` and ``"label"`` tensors.
        val_thresholds: Thresholds compared against the prediction array via
            NumPy broadcasting - presumably one per class; TODO confirm.

    Returns:
        Dict with ``"macro_f1_score"`` and ``"macro_accuracy"``, each the mean
        of the per-class values.
    """
    # Send inputs to wherever the model lives instead of assuming "cuda:0";
    # a parameter-less model falls back to the previously hard-coded device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda:0")

    transform_post = Activations(sigmoid=True)
    model.eval()
    y = []
    y_pred = []

    with torch.no_grad():
        for batch in dm.test_dataloader:
            images = batch["image"].to(device)
            # Labels are only consumed as NumPy arrays, so they need no device transfer.
            y.append(batch["label"])
            y_pred.append(transform_post(model(images)))

        # force=True detaches and copies to host memory as needed.
        y_np = torch.cat(y).numpy(force=True)
        y_pred_np = torch.cat(y_pred).numpy(force=True)

    bin_labels = (y_pred_np > val_thresholds).astype(np.int32)

    # One 2x2 matrix per class, indexed [true, predicted]: [[TN, FP], [FN, TP]].
    total_cm = multilabel_confusion_matrix(y_true=y_np, y_pred=bin_labels)
    eps = 1e-7  # avoids 0/0 for classes with no positives and no predictions
    tn, fp = total_cm[:, 0, 0], total_cm[:, 0, 1]
    fn, tp = total_cm[:, 1, 0], total_cm[:, 1, 1]
    f1 = 2 * tp / (2 * tp + fn + fp + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)

    return {
        "macro_f1_score": np.mean(f1),
        "macro_accuracy": np.mean(accuracy),
    }


In [44]:
def get_multi_site_metrics(arch, method, seed):
    """Evaluate the global FL checkpoint on every client with report-provided thresholds.

    The thresholds are read from the cross-site validation JSON report under
    ``report[<client>]['SRV_server']["optimal_thresholds"]``.

    Args:
        arch: Name of a network constructor in scope; resolved with ``eval``.
        method: Unused here; kept so the signature matches the other evaluators.
        seed: Passed as the single argument to the network constructor.

    Returns:
        Tuple of four ``get_metrics`` results, in the order
        (padchest, cxr14, cxp_young, cxp_old).
    """
    # NOTE(review): eval() runs arbitrary expressions - only pass trusted names.
    net = eval(arch)(seed).cuda()

    # Placeholder run directory; checkpoint and report live under this prefix.
    target_model = ""
    state = torch.load(target_model + suf_model)
    net.load_state_dict(state['model'])

    with open(target_model + suf_report, "r") as f:
        report = json.load(f)

    sites = (
        ('client_padchest', dm_pc),
        ('client_cxr14', dm_cxr),
        ('client_cxp_young', dm_cxp_young),
        ('client_cxp_old', dm_cxp_old),
    )
    # Resolve every threshold up front so a malformed report fails before inference.
    thresholds = [
        report[client]['SRV_server']["optimal_thresholds"] for client, _ in sites
    ]

    per_site = []
    for (_, dm), site_thresholds in zip(sites, thresholds):
        per_site.append(get_metrics(net, dm, site_thresholds))
    return tuple(per_site)

In [48]:
def get_multi_pfl_metrics(arch, method, seed):
    """Evaluate each client's personalized checkpoint on that client's test split.

    The run directory is ``f"{arch}_{method}_{seed}"``. Thresholds are read from
    the cross-site validation report under
    ``report[<client>][<client>]["optimal_thresholds"]``.

    Args:
        arch: Name of a network constructor in scope; resolved with ``eval``.
        method: FL method name; part of the run directory name.
        seed: Passed to the network constructor and part of the run directory name.

    Returns:
        Tuple of four ``get_metrics`` results, in the order
        (padchest, cxr14, cxp_young, cxp_old).
    """
    # NOTE(review): eval() runs arbitrary expressions - only pass trusted names.
    net = eval(arch)(seed).cuda()
    target_model = f"{arch}_{method}_{seed}"

    with open(target_model + suf_report, "r") as f:
        report = json.load(f)

    # (report key, checkpoint path suffix, data module) for every client.
    sites = (
        ('client_padchest', "/simulate_job/app_client_padchest/models/best_model.pt", dm_pc),
        ('client_cxr14', "/simulate_job/app_client_cxr14/models/best_model.pt", dm_cxr),
        ('client_cxp_young', "/simulate_job/app_client_cxp_young/models/best_model.pt", dm_cxp_young),
        ('client_cxp_old', "/simulate_job/app_client_cxp_old/models/best_model.pt", dm_cxp_old),
    )
    # Resolve every threshold up front so a malformed report fails before any loading.
    thresholds = [
        report[client][client]["optimal_thresholds"] for client, _, _ in sites
    ]

    # The single network instance is re-loaded with each client's weights in turn.
    per_site = []
    for (_, ckpt_suffix, dm), site_thresholds in zip(sites, thresholds):
        state = torch.load(target_model + ckpt_suffix)
        net.load_state_dict(state['model'])
        per_site.append(get_metrics(net, dm, site_thresholds))

    return tuple(per_site)

In [49]:
def get_3_seed_metrics(arch, method, seed_list, func=get_multi_site_metrics):
    """Print per-site and overall macro F1 / accuracy statistics across several seeds.

    ``func(arch, method, seed)`` is called once per seed and must return an
    indexable whose first four items are dicts (PC, CXR, CXP Young, CXP Old)
    holding ``'macro_f1_score'`` and ``'macro_accuracy'``. Means and standard
    deviations (``np.std`` defaults) are printed scaled by 100; each overall std
    is the square root of the mean per-site variance.

    Args:
        arch: Forwarded to ``func``.
        method: Forwarded to ``func``.
        seed_list: Iterable of seeds; any length is accepted despite the name.
        func: Evaluator to run for each seed.

    Returns:
        None. Results are only printed.
    """
    per_seed = [func(arch, method, seed) for seed in tqdm(seed_list)]

    def column(site, key):
        # One metric of one site, collected across all seeds.
        return [run[site][key] for run in per_seed]

    pc_f1_metrics = column(0, 'macro_f1_score')
    cxr_f1_metrics = column(1, 'macro_f1_score')
    cxp_young_f1_metrics = column(2, 'macro_f1_score')
    cxp_old_f1_metrics = column(3, 'macro_f1_score')
    pc_acc_metrics = column(0, 'macro_accuracy')
    cxr_acc_metrics = column(1, 'macro_accuracy')
    cxp_young_acc_metrics = column(2, 'macro_accuracy')
    cxp_old_acc_metrics = column(3, 'macro_accuracy')

    print(f"PC F1: \nMean: {100*np.mean(pc_f1_metrics)}\nStd: {100*np.std(pc_f1_metrics)}\n")
    print(f"CXR F1: \nMean: {100*np.mean(cxr_f1_metrics)}\nStd: {100*np.std(cxr_f1_metrics)}\n")
    print(f"CXP Young F1: \nMean: {100*np.mean(cxp_young_f1_metrics)}\nStd: {100*np.std(cxp_young_f1_metrics)}\n")
    print(f"CXP Old F1: \nMean: {100*np.mean(cxp_old_f1_metrics)}\nStd: {100*np.std(cxp_old_f1_metrics)}\n")
    print(f"Total F1: {100*np.mean([np.mean(pc_f1_metrics), np.mean(cxr_f1_metrics), np.mean(cxp_young_f1_metrics), np.mean(cxp_old_f1_metrics)])}")
    # Pooled spread: average the four per-site variances, then take the root.
    f1_variances = [
        np.var(values)
        for values in (pc_f1_metrics, cxr_f1_metrics, cxp_young_f1_metrics, cxp_old_f1_metrics)
    ]
    std_total = np.sqrt(sum(f1_variances)/4)
    print(f"Total F1 Std: {100*std_total}\n")

    print(f"PC Acc: \nMean: {100*np.mean(pc_acc_metrics)}\nStd: {100*np.std(pc_acc_metrics)}\n")
    print(f"CXR Acc: \nMean: {100*np.mean(cxr_acc_metrics)}\nStd: {100*np.std(cxr_acc_metrics)}\n")
    print(f"CXP Young Acc: \nMean: {100*np.mean(cxp_young_acc_metrics)}\nStd: {100*np.std(cxp_young_acc_metrics)}\n")
    print(f"CXP Old Acc: \nMean: {100*np.mean(cxp_old_acc_metrics)}\nStd: {100*np.std(cxp_old_acc_metrics)}\n")

    print(f"Total Acc: {100*np.mean([np.mean(pc_acc_metrics), np.mean(cxr_acc_metrics), np.mean(cxp_young_acc_metrics), np.mean(cxp_old_acc_metrics)])}")
    # Same pooled spread, for accuracy.
    acc_variances = [
        np.var(values)
        for values in (pc_acc_metrics, cxr_acc_metrics, cxp_young_acc_metrics, cxp_old_acc_metrics)
    ]
    std_total = np.sqrt(sum(acc_variances)/4)
    print(f"Total Acc Std: {100*std_total}")

In [77]:
# Global-model macro F1 / accuracy over three seeds, using the default evaluator
# (get_multi_site_metrics).
# NOTE(review): that evaluator ignores `method`, so 'scaffold' does not select a run
# by itself - confirm the checkpoint path set inside it.
get_3_seed_metrics('seresnet_50_supervised', 'scaffold', [99,1,42])

100%|██████████| 3/3 [01:43<00:00, 34.46s/it]

PC F1: 
Mean: 43.36098475939678
Std: 0.22224987498085474

CXR F1: 
Mean: 39.10883599227938
Std: 0.26266691840025225

CXP Young F1: 
Mean: 47.26411166448809
Std: 0.4783709901928217

CXP Old F1: 
Mean: 44.89886147960561
Std: 0.3575812474246077

Total F1: 43.65819847394247
Total F1 Std: 0.3446346143811912

PC Acc: 
Mean: 82.82374999917175
Std: 0.2484137174126345

CXR Acc: 
Mean: 78.17208333255161
Std: 0.5309134554864734

CXP Young Acc: 
Mean: 72.42796374581127
Std: 0.9392042631502389

CXP Old Acc: 
Mean: 71.3296543097136
Std: 0.7130748633492264

Total Acc: 76.18836284681205
Total Acc Std: 0.6584373320361347





In [None]:
# Personalized-FL macro F1 / accuracy: get_multi_pfl_metrics loads each client's own
# best_model.pt from the run directory "<arch>_<method>_<seed>".
# NOTE(review): the arch name here ('seresnet50_pretrained') follows a different pattern
# from 'seresnet_50_supervised' used above - confirm that constructor exists.
get_3_seed_metrics('seresnet50_pretrained', 'fedper', [2024,1995,42], get_multi_pfl_metrics)