In [1]:
import numpy as np
import random
import torch
import torchvision
from torchvision.datasets import VOCSegmentation
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import matplotlib.pyplot as plt

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_unique_classes(dataset):
    unique_classes = set()
    for img, mask in dataset:
        classes = list(mask.getcolors())
        for count, pixel_value in classes:
            unique_classes.add(pixel_value)
    
    return unique_classes

# 데이터셋 다운로드 및 생성
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor()
])

mask_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224), interpolation=Image.NEAREST), 
    torchvision.transforms.ToTensor()
])

voc_dataset = VOCSegmentation(root='~/.data', year='2012', image_set='val', download=False, transform=transforms, target_transform=mask_transforms)
voc_dataloader = torch.utils.data.DataLoader(voc_dataset, batch_size=200, shuffle=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def mask_img_to_target(mask, num_classes=20):
    # get colors
    mask = to_pil_image(mask)
    colors = list(mask.getcolors())
    torch_target = torch.zeros(num_classes)
    for count, pixel_value in colors:
        if pixel_value in [0, 255]: # 0 is background, 255 is border,
            continue
        torch_target[pixel_value - 1] = 1
    return torch_target

def masks_to_targets(masks, num_classes=20):
    targets = []
    for mask in masks:
        targets.append(mask_img_to_target(mask, num_classes))
    return torch.stack(targets)

images, masks = iter(voc_dataloader).next()
labels = masks_to_targets(masks)

In [3]:
# def display_images(dataset, num_images=5):
#     fig, ax = plt.subplots(num_images, 3, figsize=(5, num_images * 2))
#     for i in range(num_images):
#         # 이미지 및 마스크 가져오기
#         img, mask = dataset[i]

#         # 이미지 및 마스크 표시
#         ax[i, 0].imshow(img.permute(1, 2, 0))
#         ax[i, 0].set_title(f"Image {i + 1}")
#         ax[i, 1].imshow(mask.permute(1, 2, 0), cmap='gray')
#         ax[i, 1].set_title(f"Segmentation Image {i + 1}")
#         # overlay segmentation mask on image
#         ax[i, 2].imshow(img.permute(1, 2, 0))
#         ax[i, 2].imshow(mask.permute(1, 2, 0), alpha=0.4, cmap='gray')
#         ax[i, 2].set_title(f"Overlay {i + 1}")
        
#     plt.tight_layout()
#     plt.show()

# display_images(voc_dataset)
# def display_images(images, masks):
#     num_images = len(images)
#     print(num_images)
#     fig, ax = plt.subplots(num_images, 3, figsize=(5, num_images * 2))
#     for i in range(num_images):
#         # 이미지 및 마스크 가져오기
#         img = images[i].permute(1, 2, 0)
#         mask = masks[i].permute(1, 2, 0)

#         # 이미지 및 마스크 표시
#         ax[i, 0].imshow(img)
#         ax[i, 0].set_title(f"Image {i + 1}")
#         ax[i, 1].imshow(mask, cmap='gray')
#         ax[i, 1].set_title(f"Segmentation Image {i + 1}")
#         # overlay segmentation mask on image
#         ax[i, 2].imshow(img)
#         ax[i, 2].imshow(mask, alpha=0.4, cmap='gray')
#         ax[i, 2].set_title(f"Overlay {i + 1}")
    
#     plt.tight_layout()
#     plt.savefig('test.png')
#     plt.show()
    
# # 데이터셋에서 이미지 및 마스크 표시
# display_images(images, masks)

In [4]:
# %%
import sys
sys.path.append('..')
import utils
import os
import pathlib
import argparse
from tensorboardX import SummaryWriter
import logging
from datetime import datetime
import torch 
import mymodels 
import mydataset 
from torch.utils.data import DataLoader
from utils.myfed import *
import yaml
# %%

In [5]:
yamlfilepath = pathlib.Path.cwd().parent.joinpath('config.yaml')
args = yaml.load(yamlfilepath.open('r'), Loader=yaml.FullLoader)
args = argparse.Namespace(**args)
args.datapath = "~/.data"
args.batchsize = 200
os.environ['CUDA_VISIBLE_DEVICES']=args.gpu
# 1. data
args.datapath = os.path.expanduser(args.datapath)

if args.dataset == 'cifar10':
    publicdata = 'cifar100'
    args.N_class = 10
elif args.dataset == 'cifar100':
    publicdata = 'imagenet'
    args.N_class = 100
elif args.dataset == 'pascal_voc2012':
    publicdata = 'mscoco'
    args.N_class = 20

In [6]:
args.model_name = 'vit_tiny_patch16_224'
net = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class)
net 

DataParallel(
  (module): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
          (act): GELU(approximate=none)
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=768, out_features=192, bias=True)
          (drop2): Dropout(p=0.0, inplace=False)
   

#### test

In [21]:
import copy
model1 = copy.deepcopy(net)
utils.load_dict('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h3/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e7_0.6575.pt', model1)
model2 = copy.deepcopy(net)
utils.load_dict('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/a1.0+sd1+e300+b64+lkl+slNone/oneshot_c1_q0.0_n0.0_h0/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e7_0.6563.pt', model2)
model1 = copy.deepcopy(net)
utils.load_dict('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_only/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h2/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e10_0.5234.pt', model1)
model2 = copy.deepcopy(net)
# utils.load_dict('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_singlelabel/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e20_0.76.pt', model2)
utils.load_dict('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_only/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h1/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e10_0.5270.pt', model2)


#### dicescore

In [8]:
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
# multi label to multi captions
def multi_label_to_multi_captions(labels):
    # if already convert to captions
    if isinstance(labels[0][0], str):
        return labels
    
    captions = []
    for label in labels:
        caption = []
        for i in range(len(label)):
            if label[i] == 1:
                caption.append(VOC_CLASSES[i])
        captions.append(caption)
    return captions

labels = multi_label_to_multi_captions(labels)
labels

[['boat'],
 ['bird', 'bus'],
 ['chair', 'diningtable', 'pottedplant'],
 ['chair', 'person', 'tvmonitor'],
 ['bottle', 'person', 'tvmonitor'],
 ['train'],
 ['cat', 'chair'],
 ['bottle', 'diningtable', 'person'],
 ['bird'],
 ['sheep'],
 ['car'],
 ['sheep'],
 ['bus'],
 ['motorbike', 'person'],
 ['boat'],
 ['pottedplant', 'sofa'],
 ['dog', 'person'],
 ['horse'],
 ['bus', 'car', 'person'],
 ['diningtable', 'person'],
 ['car'],
 ['bicycle', 'person'],
 ['bicycle', 'pottedplant'],
 ['pottedplant'],
 ['aeroplane'],
 ['pottedplant', 'tvmonitor'],
 ['bus'],
 ['tvmonitor'],
 ['cat'],
 ['dog'],
 ['chair', 'sofa'],
 ['sheep'],
 ['horse'],
 ['chair', 'diningtable'],
 ['motorbike', 'person'],
 ['train'],
 ['tvmonitor'],
 ['bicycle', 'car', 'person'],
 ['bus', 'car', 'person'],
 ['sofa'],
 ['bus'],
 ['chair', 'person'],
 ['diningtable', 'person'],
 ['bicycle'],
 ['bus'],
 ['car', 'motorbike', 'person'],
 ['aeroplane'],
 ['cat', 'dog', 'sofa', 'tvmonitor'],
 ['chair'],
 ['cat'],
 ['aeroplane'],
 ['hors

In [None]:
grad_cam_images = []
pred_labels = []
for model in models:
    grad_cam_images.append(model.module.get_class_activation_map(images, labels))
    m = torch.nn.Sigmoid()
    th = 0.3
    outputs = m(model(images)).detach().cpu().numpy()
    outputs[outputs > th] = 1
    outputs[outputs <= th] = 0
    pred = multi_label_to_multi_captions(outputs)
    pred_labels.append(pred)

grad_cam_images = torch.stack([torch.tensor(grad_cam_images[i]) for i in range(len(grad_cam_images))])
grad_cam_images.shape # n_clients * b * 224 * 224

In [9]:
# -----------------------------------------------
def getAccuracy(y_true, y_pred):
    temp = 0
    for i in range(y_true.shape[0]):
        temp += sum(np.logical_and(y_true[i], y_pred[i])) / sum(np.logical_or(y_true[i], y_pred[i]))
    return temp / y_true.shape[0]

def get_Hamming_Loss(y_true, y_pred):
    temp=0
    for i in range(y_true.shape[0]):
        temp += np.size(y_true[i] == y_pred[i]) - np.count_nonzero(y_true[i] == y_pred[i])
    return temp/(y_true.shape[0] * y_true.shape[1])

def getPrecision(y_true, y_pred):
    temp = 0
    for i in range(y_true.shape[0]):
        if sum(y_true[i]) == 0:
            continue
        temp+= sum(np.logical_and(y_true[i], y_pred[i]))/ sum(y_true[i])
    return temp/ y_true.shape[0]

def getRecall(y_true, y_pred):
    temp = 0
    for i in range(y_true.shape[0]):
        if sum(y_pred[i]) == 0:
            continue
        temp+= sum(np.logical_and(y_true[i], y_pred[i]))/ sum(y_pred[i])
    return temp/ y_true.shape[0]

def getF1score(y_true, y_pred):
    temp = 0
    for i in range(y_true.shape[0]):
        if (sum(y_true[i]) == 0) and (sum(y_pred[i]) == 0):
            continue
        temp+= (2*sum(np.logical_and(y_true[i], y_pred[i])))/ (sum(y_true[i])+sum(y_pred[i]))
    return temp/ y_true.shape[0]

def getMetrics(y_true, y_score, th):
    y_pred = (y_score > th).astype(int)
    acc = getAccuracy(y_true, y_pred)
    pre = getPrecision(y_true, y_pred)
    rec = getRecall(y_true, y_pred)
    f1 = getF1score(y_true, y_pred)
    return acc, pre, rec, f1

def accuracyforsinglelabel(output, target, topk=(1,)):
    output = torch.tensor(output)
    target = torch.tensor(target)
    if len(output.shape) == 2:
        predicted = output.argmax(dim=1)
    if len(target.shape) == 2:
        target = target.argmax(dim=1)
    
    total, correct = 0, 0
    total += target.size(0)
    correct += predicted.eq(target).sum().item()
    return correct / total

def accuracy(output, target, topk=(1,)):
    """
    usage:
    prec1,prec5=accuracy(output,target,topk=(1,5))
    """
    # print(output.shape, target.shape)
    
    th_ls = [0.1 * i for i in range(10)]
    opt_th = 0
    best_acc = 0
    for th in th_ls:
        acc, pre, rec, f1 = getMetrics(target, output, th)
        if acc > best_acc:
            best_acc = acc
            opt_th = th

    acc, pre, rec, f1 = getMetrics(target, output, opt_th)
    print(f"opt_th: {opt_th:.2f}, best_acc: {best_acc:.2f}, pre: {pre:.2f}, rec: {rec:.2f}, f1: {f1:.2f}")
        
    res = []
    for k in topk:
        res.append(acc)
    return res

In [11]:
grad_cam_image = model1.module.get_class_activation_map(images, labels) 
grad_cam_image2 = model2.module.get_class_activation_map(images, labels)


In [47]:
def get_pred_label(model, images, th=0.3):
    m = torch.nn.Sigmoid()
    outputs = m(model(images)).detach().cpu().numpy()
    outputs[outputs > th] = 1
    outputs[outputs <= th] = 0
    pred = multi_label_to_multi_captions(outputs)
    return pred
pred_label1 = get_pred_label(model1, images, th=0.3)
pred_label2 = get_pred_label(model2, images, th=0.3)

In [48]:
len(images), len(images)

(200, 200)

In [49]:
pred_label1 == pred_label2
test_score1 = [label == pred for label, pred in zip(labels, pred_label1)]
test_score2 = [label == pred for label, pred in zip(labels, pred_label2)]
import pandas as pd
df = pd.DataFrame({'test_score1': test_score1, 'test_score2': test_score2})
df.apply(pd.value_counts)

Unnamed: 0,test_score1,test_score2
False,108,108
True,92,92


In [None]:
correct_dict = {}
for idx, (true_label, pred_label) in enumerate(zip(labels, pred_label1)):
    correct_clients_count = client_pred[idx] == true_label for client_pred in pred_labels)
    correct_central = pred_label == true_label
    correct_dict[idx] = (correct_clients_count, correct_central)

# correct_dict to dataframe
import pandas as pd
df = pd.DataFrame.from_dict(correct_dict, orient='index', columns=['correct_clients_count', 'correct_central'])
df = df.pivot_table(index='correct_clients_count', columns='correct_central', aggfunc=len, fill_value=0)
df.T

In [12]:
%matplotlib inline
def dice_score(y_pred, y_true, smooth=1):
    if torch.max(y_pred) > 1:
        print("y_pred should be in range [0, 1]")
    if torch.max(y_true) > 1:
        print("y_true should be in range [0, 1]")
    y_pred = y_pred.float()
    y_true = y_true.float()
    dice_loss = (2 * (y_pred * y_true).sum() + smooth) / ((y_pred + y_true).sum() + smooth)
    return dice_loss

def calculate_dice_score(grad_cam_image, masks):
    if not isinstance(grad_cam_image, torch.Tensor):
        grad_cam_image = torch.tensor(grad_cam_image)
        
    dice_scores = []
    for i in range(10):
        print("mean, median of grad_cam_image: ", torch.mean(grad_cam_image[i]), torch.median(grad_cam_image[i])) 
        central_grad_cam = torch.tensor(grad_cam_image[i] > torch.mean(grad_cam_image[i])).float()
        mask_img = masks[i].unsqueeze(0).cpu() > 0
        ds = dice_score(central_grad_cam, mask_img)
        dice_scores.append(ds)
    print(dice_scores)
    return dice_scores

dice_scores = calculate_dice_score(grad_cam_image, masks)
dice_scores2 = calculate_dice_score(grad_cam_image2, masks)
print("mean of dice score: ", torch.mean(torch.tensor(dice_scores)), torch.mean(torch.tensor(dice_scores2)))

def getThresholdImages(grad_cam_image):
    if not isinstance(grad_cam_image, torch.Tensor):
        grad_cam_image = torch.tensor(grad_cam_image)
        
    threshold_images = []
    for i in range(len(grad_cam_image)):
        threshold_images.append(torch.tensor(grad_cam_image[i] > torch.median(grad_cam_image[i])).float())
    return threshold_images
threshold_images = getThresholdImages(grad_cam_image)
threshold_images2 = getThresholdImages(grad_cam_image2)

def drawplots(images, masks, grad_cam_images, threshold_images):
    length = len(images)
    fig, ax = plt.subplots(length, 4, figsize=(20, 20))
    for i in range(length):
        ax[i, 0].imshow(images[i].permute(1, 2, 0))
        ax[i, 0].set_title(f"Image {i + 1}")
        ax[i, 1].imshow(masks[i].permute(1, 2, 0), alpha=0.4, cmap='gray')
        ax[i, 1].set_title(f"Mask {i + 1}")
        ax[i, 2].imshow(grad_cam_images[i])
        ax[i, 2].set_title(f"Grad CAM {i + 1}")
        ax[i, 3].imshow(threshold_images[i], alpha=0.4, cmap='gray')
        ax[i, 3].set_title(f"Threshold {i + 1}")
    plt.show()
# drawplots(images, masks, grad_cam_image, threshold_images)


mean, median of grad_cam_image:  tensor(0.3056) tensor(0.3184)
mean, median of grad_cam_image:  tensor(0.2437) tensor(0.2030)
mean, median of grad_cam_image:  tensor(0.2055) tensor(0.1121)
mean, median of grad_cam_image:  tensor(0.1670) tensor(0.1316)
mean, median of grad_cam_image:  tensor(0.2151) tensor(0.1222)
mean, median of grad_cam_image:  tensor(0.2624) tensor(0.2734)
mean, median of grad_cam_image:  tensor(0.3220) tensor(0.2996)
mean, median of grad_cam_image:  tensor(0.2489) tensor(0.1058)
mean, median of grad_cam_image:  tensor(0.2290) tensor(0.2169)
mean, median of grad_cam_image:  tensor(0.3118) tensor(0.3337)
[tensor(0.1493), tensor(0.7580), tensor(0.7383), tensor(0.5802), tensor(0.0545), tensor(0.0784), tensor(0.3480), tensor(0.7645), tensor(0.0073), tensor(0.4655)]
mean, median of grad_cam_image:  tensor(0.3234) tensor(0.2822)
mean, median of grad_cam_image:  tensor(0.3267) tensor(0.3311)
mean, median of grad_cam_image:  tensor(0.1800) tensor(0.0424)
mean, median of grad

  central_grad_cam = torch.tensor(grad_cam_image[i] > torch.mean(grad_cam_image[i])).float()
  threshold_images.append(torch.tensor(grad_cam_image[i] > torch.median(grad_cam_image[i])).float())


In [22]:
mha, th = model1.module.get_attention_maps_postprocessing_(images.cuda())
if not isinstance(mha, torch.Tensor):
    mha = torch.tensor(mha)
mha_agg = torch.max(mha, dim=1)[0]
print(mha_agg.shape)
mha2, th2 = model2.module.get_attention_maps_postprocessing_(images.cuda())
if not isinstance(mha2, torch.Tensor):
    mha2 = torch.tensor(mha2)
mha_agg2 = torch.max(mha2, dim=1)[0]
print(mha_agg2.shape)


mha_dice_scores = calculate_dice_score(mha_agg, masks)
mha_dice_scores2 = calculate_dice_score(mha_agg2, masks)

mha_threshold_images = getThresholdImages(mha_agg)
mha_threshold_images2 = getThresholdImages(mha_agg2)

torch.Size([200, 224, 224])
torch.Size([200, 224, 224])
mean, median of grad_cam_image:  tensor(0.0034) tensor(0.0019)
mean, median of grad_cam_image:  tensor(0.0051) tensor(0.0047)
mean, median of grad_cam_image:  tensor(0.0046) tensor(0.0028)
mean, median of grad_cam_image:  tensor(0.0059) tensor(0.0030)
mean, median of grad_cam_image:  tensor(0.0056) tensor(0.0037)
mean, median of grad_cam_image:  tensor(0.0051) tensor(0.0023)
mean, median of grad_cam_image:  tensor(0.0065) tensor(0.0047)
mean, median of grad_cam_image:  tensor(0.0036) tensor(0.0021)
mean, median of grad_cam_image:  tensor(0.0049) tensor(0.0017)
mean, median of grad_cam_image:  tensor(0.0053) tensor(0.0027)
[tensor(0.4080), tensor(0.8271), tensor(0.6360), tensor(0.5680), tensor(0.5881), tensor(0.5388), tensor(0.5395), tensor(0.5606), tensor(0.4489), tensor(0.8326)]
mean, median of grad_cam_image:  tensor(0.0048) tensor(0.0018)
mean, median of grad_cam_image:  tensor(0.0063) tensor(0.0051)
mean, median of grad_cam_im

  central_grad_cam = torch.tensor(grad_cam_image[i] > torch.mean(grad_cam_image[i])).float()
  threshold_images.append(torch.tensor(grad_cam_image[i] > torch.median(grad_cam_image[i])).float())


In [26]:
print('mha_dice_scores:', torch.mean(torch.tensor(mha_dice_scores)), torch.mean(torch.tensor(mha_dice_scores2)))

mha_dice_scores: tensor(0.5948) tensor(0.5356)


In [None]:
drawplots(images, masks, mha_agg, mha_threshold_images)

In [None]:
n = 0
loadname = os.path.join("/home/suncheol/code/VFL/FedMAD/checkpoints_backup/pascal_voc2012/a1.0+sd1+e300+b16+lkl", str(n)+'.pt')
if os.path.exists(loadname):
    localmodels = torch.load(loadname)
    #self.localmodels[n].load_state_dict(self.best_statdict, strict=True)
    logging.info(f'Loading Local{n}......')
    print('filepath : ', loadname)
    utils.load_dict(loadname, net)

In [None]:
loadname = os.path.join("/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/a1.0+sd1+e300+b16+lkl/model-0.pth")
loadname = os.path.join("/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/a1.0+sd1+e300+b128+lkl+slmha/oneshot_c1_q0.0_n0.0/q0.0_n0.0_ADAM_b128_5e-05_200_5e-05_m0.9_e10_0.66.pt")
# loadname = os.path.join("/home/suncheol/code/FedTest/pytorch-models/checkpoint/pascal_voc_vit_tiny_patch16_224_0.0001_-1/ckpt.pth")
if os.path.exists(loadname):
    localmodels = torch.load(loadname)
    #self.localmodels[n].load_state_dict(self.best_statdict, strict=True)
    logging.info(f'Loading Local......')
    print('filepath : ', loadname)
    utils.load_dict(loadname, net)

In [None]:
import copy
models = []
for i in range(0, 5):
    model = copy.deepcopy(net)
    loadname = os.path.join(f"/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/a1.0+sd1+e300+b128+lkl+slmha/model-{i}.pth")
    # loadname = os.path.join(f"/home/suncheol/code/FedTest/pytorch-models/checkpoint/pascal_voc_vit_tiny_patch16_224_0.0001_{i}/ckpt.pth")
    if os.path.exists(loadname):
        localmodels = torch.load(loadname)
        #self.localmodels[n].load_state_dict(self.best_statdict, strict=True)
        logging.info(f'Loading Local......', 'filepath : ', loadname)
        utils.load_dict(loadname, model)
    models.append(model)

In [None]:
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
# multi label to multi captions
def multi_label_to_multi_captions(labels):
    captions = []
    for label in labels:
        caption = []
        for i in range(len(label)):
            if label[i] == 1:
                caption.append(VOC_CLASSES[i])
        captions.append(caption)
    return captions
labels = multi_label_to_multi_captions(labels)

In [None]:
grad_cam_images = []
pred_labels = []
for model in models:
    grad_cam_images.append(model.module.get_class_activation_map(images, labels))
    m = torch.nn.Sigmoid()
    th = 0.3
    outputs = m(model(images)).detach().cpu().numpy()
    outputs[outputs > th] = 1
    outputs[outputs <= th] = 0
    pred = multi_label_to_multi_captions(outputs)
    pred_labels.append(pred)

grad_cam_images = torch.stack([torch.tensor(grad_cam_images[i]) for i in range(len(grad_cam_images))])
grad_cam_images.shape # n_clients * b * 224 * 224

# grayscale_cam = net.module.get_class_activation_map(images, labels)
central_model = copy.deepcopy(net)
central_grad_cam_image = central_model.module.get_class_activation_map(images, labels)
m = torch.nn.Sigmoid()
th = 0.3
outputs = m(central_model(images)).detach().cpu().numpy()
outputs[outputs > th] = 1
outputs[outputs <= th] = 0
central_pred_labels = multi_label_to_multi_captions(outputs)


In [None]:
union_cam = torch.max(grad_cam_images, dim=0)[0]
intersection_cam = torch.min(grad_cam_images, dim=0)[0]
union_cam.cpu().shape

In [None]:
%matplotlib inline
def get_border_color(true_label, pred_label):
    if pred_label == true_label:
        return 'lime'  # green
    elif set(pred_label) & set(true_label):
        return 'gold'  # yellow
    else:
        return 'red'

row = 4
col = 9
clients = 5
extra_plots = 3  # union, intersection, global

import matplotlib.pyplot as plt
plt.figure(figsize=(3 * col, 3 * row))

for j in range(row):
    true_label = labels[j]
    for i in range(clients + extra_plots + 1):
        ax = plt.subplot(row, col, j * col + i + 1)
        # print(i, j)
        if i == 0:
            img = images[j].cpu().permute(1, 2, 0)
            mask = masks[j].permute(1, 2, 0).cpu()
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title(f'GT: {true_label}')
        elif i <= clients:
            img = grad_cam_images[i - 1].cpu()[j]
            pred_label = pred_labels[i - 1][j]
            border_color = get_border_color(true_label, pred_label)
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            # set border color
            ax.patch.set_edgecolor(border_color)
            ax.patch.set_linewidth(5)
            plt.title(f'client{i}: {pred_label}')
        elif i == clients + 1:
            img = union_cam.cpu()[j]
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title('union')
        elif i == clients + 2:
            img = intersection_cam.cpu()[j]
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title('intersection')
        elif i == clients + 3:
            img = central_grad_cam_image[j]
            border_color = get_border_color(true_label, central_pred_labels[j])
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title(f'global: {central_pred_labels[j]}')
            ax.patch.set_edgecolor(border_color)
            ax.patch.set_linewidth(5)
        # plt.gca().set_xticks([])
        # plt.gca().set_yticks([])
        # plt.gca().spines['top'].set_visible(False)
        # plt.gca().spines['right'].set_visible(False)
        # plt.gca().spines['bottom'].set_visible(False)
        # plt.gca().spines['left'].set_visible(False)

plt.savefig('grad_cam.png')
plt.show()
# plt.tight_layout()

In [None]:
correct_dict = {}
for idx, (true_label, central_pred) in enumerate(zip(labels, central_pred_labels)):
    correct_clients_count = sum(client_pred[idx] == true_label for client_pred in pred_labels)
    correct_central = central_pred == true_label
    correct_dict[idx] = (correct_clients_count, correct_central)

# correct_dict to dataframe
import pandas as pd
df = pd.DataFrame.from_dict(correct_dict, orient='index', columns=['correct_clients_count', 'correct_central'])
df = df.pivot_table(index='correct_clients_count', columns='correct_central', aggfunc=len, fill_value=0)
df.T

In [None]:
correct_dict

In [None]:
def find_wrong_clients_correct_central(true_labels, client_pred_labels, central_pred_labels):
    result_indices = []
    for idx, (true_label, central_pred) in enumerate(zip(true_labels, central_pred_labels)):
        if sum(client_pred[idx] == true_label for client_pred in client_pred_labels) < 2 and central_pred == true_label:
            result_indices.append(idx)
    return result_indices

def find_correct_clients_wrong_central(true_labels, client_pred_labels, central_pred_labels):
    result_indices = []
    for idx, (true_label, central_pred) in enumerate(zip(true_labels, central_pred_labels)):
        # if all(client_pred[idx] == true_label for client_pred in client_pred_labels) and central_pred != true_label:
        #     result_indices.append(idx)
        # more than 3 clients are correct and central is wrong
        if sum(client_pred[idx] == true_label for client_pred in client_pred_labels) > 3 and central_pred != true_label:
            result_indices.append(idx)
            
    return result_indices

# Example usage
wrong_clients_correct_central_indices = find_wrong_clients_correct_central(labels, pred_labels, central_pred_labels[0])
print("Wrong clients, correct central indices:", wrong_clients_correct_central_indices)

correct_clients_wrong_central_indices = find_correct_clients_wrong_central(labels, pred_labels, central_pred_labels[0])
print("Correct clients, wrong central indices:", correct_clients_wrong_central_indices)

In [None]:
def dice_score(y_pred, y_true, smooth=1):
    if torch.max(y_pred) > 1:
        print("y_pred should be in range [0, 1]")
    if torch.max(y_true) > 1:
        print("y_true should be in range [0, 1]")
    y_pred = y_pred.float()
    y_true = y_true.float()
    dice_loss = (2 * (y_pred * y_true).sum() + smooth) / ((y_pred + y_true).sum() + smooth)
    return dice_loss

dice_scores = []
for i in range(10):
    central_grad_cam = torch.tensor(central_grad_cam_image[i] >0.1).unsqueeze(0).cpu()
    mask_img = masks[i].unsqueeze(0).cpu() > 0
    ds = dice_score(central_grad_cam, mask_img)
    dice_scores.append(ds)

print(dice_scores)    

In [None]:
%matplotlib inline
def get_border_color(true_label, pred_label):
    if pred_label == true_label:
        return 'lime'  # green
    elif set(pred_label) & set(true_label):
        return 'gold'  # yellow
    else:
        return 'red'
image_index = [1, 4, 5, 8]
row = len(image_index)
col = 9
clients = 5
extra_plots = 3  # union, intersection, global

import matplotlib.pyplot as plt
plt.figure(figsize=(3 * col, 3 * row))

for j in range(row):
    idxImage = image_index[j]
    true_label = labels[idxImage]
    for i in range(clients + extra_plots + 1):
        ax = plt.subplot(row, col, j * col + i + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        # print(i, idxImage)
        if i == 0:
            img = images[idxImage].cpu().permute(1, 2, 0)
            mask = masks[idxImage].cpu().permute(1, 2, 0)
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title(f'GT: {true_label}')
        elif i <= clients:
            img = grad_cam_images[i - 1].cpu()[idxImage]
            pred_label = pred_labels[i - 1][idxImage]
            border_color = get_border_color(true_label, pred_label)
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            # set border color
            ax.patch.set_edgecolor(border_color)
            ax.patch.set_linewidth(7)
            plt.title(f'client{i}: {pred_label}')
        elif i == clients + 1:
            img = union_cam.cpu()[idxImage]
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title('union')
        elif i == clients + 2:
            img = intersection_cam.cpu()[idxImage]
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title('intersection')
        elif i == clients + 3:
            img = central_grad_cam_image[idxImage]
            border_color = get_border_color(true_label, central_pred_labels[idxImage])
            plt.imshow(img)
            plt.imshow(mask, alpha=0.3, cmap='gray')
            plt.title(f'global: {central_pred_labels[idxImage]}, dice: {dice_scores[idxImage]:.2f}')
            ax.patch.set_edgecolor(border_color)
            ax.patch.set_linewidth(7)
        # plt.gca().set_xticks([])
        # plt.gca().set_yticks([])
        # plt.gca().spines['top'].set_visible(False)
        # plt.gca().spines['right'].set_visible(False)
        # plt.gca().spines['bottom'].set_visible(False)
        # plt.gca().spines['left'].set_visible(False)

plt.savefig('grad_cam2.png')
plt.show()
# plt.tight_layout()

In [None]:
mha_images = [] 
th_images = []
for model in models:
    mha, th = model.module.get_attention_maps_postprocessing_(images.cuda())
    mha_images.append(mha)
    th_images.append(th)
    
mha_images = torch.stack([torch.tensor(mha_images[i]) for i in range(len(mha_images))])
th_images = torch.stack([torch.tensor(th_images[i]) for i in range(len(th_images))])
mha_images.shape # n_clients * b * 224 * 224
th_images.shape # n_clients * b * 224 * 224
print(mha_images.shape, th_images.shape)


In [None]:
central_mha, central_th = central_model.module.get_attention_maps_postprocessing_(images.cuda())

In [None]:
central_mha = torch.tensor(central_mha)
central_mha_agg = []
for i in range(central_mha.shape[0]):
    central_mha_agg.append(torch.max(central_mha[i], dim=0)[0])
central_mha_agg = torch.stack(central_mha_agg)
central_mha_agg.shape

In [None]:
plt.imshow(central_mha_agg[2].cpu())
plt.colorbar()

In [None]:
mha_dice_scores = []
for i in range(10):
    central_mha_agg_ = torch.tensor(central_mha_agg[i] > torch.median(central_mha_agg[i])).unsqueeze(0).cpu()
    mask_img = masks[i].unsqueeze(0).cpu() > 0
    ds = dice_score(central_mha_agg_, mask_img)
    mha_dice_scores.append(ds)
mha_dice_scores

In [None]:
# plt.imshow(masks[0].cpu().permute(1, 2, 0))
# plot image : masks > 0
plt.imshow(masks[0].cpu().permute(1, 2, 0) > 0)

In [None]:
grad_cam_images[0]
np.median(grad_cam_images[0].cpu().numpy())

In [None]:
grad_cam_images[0].shape

In [None]:
def plot_mha_images(images, labels, mha_images, pred_labels, central_mha, central_pred_labels):
    n_clients, n_images, n_head, h, w = mha_images.shape
    image_indices = [2,3,5,8]
    row = len(image_indices) * n_head
    col = 1 + n_clients + 1
    plt.figure(figsize=(3 * col, 3 * row))
    for j in range(0, row):
        _j = j // n_head
        img_index = image_indices[_j]
        true_label = labels[img_index]
        k = j % n_head
        if k == 0:
            ax = plt.subplot(row, col, j * col + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            plt.imshow(images[img_index].numpy().transpose(1, 2, 0))
            plt.imshow(masks[img_index].numpy().transpose(1, 2, 0), alpha=0.3, cmap='gray')
            plt.title(true_label)
        for i in range(0, n_clients):
            ax = plt.subplot(row, col, j * col + i + 2)
            ax.set_xticks([])
            ax.set_yticks([])
            plt.imshow(mha_images[i, img_index, k, :, :].numpy())
            plt.imshow(masks[img_index].numpy().transpose(1, 2, 0), alpha=0.3, cmap='gray')
            pred_label = pred_labels[i][img_index]
            border_color = get_border_color(true_label, pred_label)
            ax.patch.set_edgecolor(border_color)
            ax.patch.set_linewidth(7)
            if k == 0:
                plt.title(f'client {i}: {pred_label}')
        ax = plt.subplot(row, col, j * col + n_clients + 2)
        plt.imshow(central_mha[img_index, k, :, :])
        plt.imshow(masks[img_index].numpy().transpose(1, 2, 0), alpha=0.3, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        central_pred_label = central_pred_labels[img_index]
        border_color = get_border_color(true_label, central_pred_label)
        ax.patch.set_edgecolor(border_color)
        ax.patch.set_linewidth(7)

        if k == 0:
            plt.title(f'global: {central_pred_label}, dice: {mha_dice_scores[img_index]:.2f}')
    plt.savefig('mha_images.png')
    plt.show()
    plt.tight_layout()

# Usage
plot_mha_images(images, labels, mha_images, pred_labels, central_mha, central_pred_labels)