# Import

In [None]:
import random
import os
import os.path as osp
import sys
import math
import copy 
import pprint
import shutil

from tqdm import tqdm
from abc import ABC, abstractmethod

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import cv2
import copy
import pandas as pd
import seaborn as sns
import imgaug.augmenters as iaa
from PIL import Image
from sklearn import manifold, datasets
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from scipy.special import expit
from scipy.stats import pearsonr

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable

import torchvision
from torchvision import datasets
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.models import resnet152, densenet121, mobilenet_v2

import SimpleITK as sitk


# from DiffableMaps import ExpModel
from integrated_gradients import IntegratedGradients
from xrai import XRAI
from utils import *
from ExplModel import *
from attack import attack_expl_T, attack_pred_T

# from utils import ShowImage, ShowHeatMap
# from utils import get_expl, plot_overview, clamp, load_image, make_dir, img_norm
# from chexpert import fetch_dataloader, grad_cam, DenseNet
from dataset import ChexpertSmall

%matplotlib inline

In [None]:
# Random seeds for Python, NumPy and PyTorch (CPU + CUDA).
# NOTE(review): cudnn.benchmark=True lets cuDNN choose algorithms per input shape,
# which is not deterministic, so runs are not bit-for-bit reproducible despite
# the fixed seeds.
cudnn.benchmark = True
manual_seed = 41
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed(manual_seed)

# Restrict this process to physical GPU 6.
# NOTE(review): this only has an effect if CUDA has not been initialised yet in
# this kernel (torch.cuda.manual_seed is believed to be applied lazily — confirm);
# setting it before any torch.cuda call would be safer.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Settings

In [None]:
# Global experiment configuration. Later cells read entries of `cfgs`, and some
# of them overwrite `cfgs['attack']` and `cfgs['output_dir']`.
cfgs = {}
### Model configs. model_name options: densenet121, resnet152, mobilenet_v2, efficientnet-b[0-7]
# exp_method options: ['VanillaBP', 'VanillaBP_Img', 'GuidedBP', 'IntegratedBP', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
dense121_cfgs = {
    'model_info':{
        'model_name': 'densenet121',
        'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
        # Alternative (weakly robust / robust) DenseNet checkpoints:
#         'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_weakrobust/best_checkpoints/checkpoint_9.pt',
#         'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/robust_densenet121/best_checkpoints/checkpoint_9.pt',
    },
    'exp_method': 'VanillaBP',
    # Method-specific options; GradCAM-style methods presumably need 'target_layer'.
    'exp_cfgs':{
#         'target_layer': ['features','denseblock4','denselayer14'],
    }
}
# Disabled alternative source/target model configs, kept for reference.
# res152_cfgs = {
#     'model_info':{
#         'model_name': 'resnet152',
#         'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/resnet152_5c/best_checkpoints/checkpoint_5.pt',
#     },
#     'exp_method': 'IntegratedBP',
#     'exp_cfgs':{
# #         'target_layer': ['layer4','2'],
#     }
# }
# mobilV2_cfgs = {
#     'model_info':{
#         'model_name': 'mobilenet_v2',
#         'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/mobilenet_v2_5c/best_checkpoints/checkpoint_7.pt',
#     },
#     'exp_method': 'IntegratedBP',
#     'exp_cfgs':{
# #         'target_layer': ['features', '17'],
#     }
# }

# Models to attack ('src_models') and models presumably used to test transfer
# ('tgt_models', currently empty). NOTE: the global `model_cfgs` is also read
# directly by the batch attack loop when it builds its save directory.
model_cfgs = {}
model_cfgs['src_models']=[
    dense121_cfgs,
#     mobilV2_cfgs,
#     res152_cfgs,
]
model_cfgs['tgt_models']=[
#     res152_cfgs,
]
model_cfgs['pretrained'] = False
cfgs['model'] = model_cfgs


### data
data_cfgs = {}
data_cfgs['data_path'] = '/home/Attack_Attn/ChestXpert/'
data_cfgs['resize'] = 512  # get_transforms resizes and center-crops to 512x512
# Single-channel normalisation statistics; only applied when a transform is
# built with if_norm=True (fetch_dataloader does not normalise).
data_cfgs['data_mean'] = np.array([0.5330])
data_cfgs['data_std'] = np.array([0.0349])
data_cfgs['mini_data'] = None
data_cfgs['drop_lateral'] = True ## discard images with lateral view
data_cfgs['n_classes'] = len(ChexpertSmall.attr_names)
cfgs['dataset'] = data_cfgs

# Attack hyper-parameters. The meaning of beta / prefactors is defined by the
# attack functions in attack.py (not visible here).
att_cfgs = {}
att_cfgs['num_iter'] = 1000
att_cfgs['lr'] = 2e-4
att_cfgs['output_dir'] = './save_dir'
att_cfgs['epsilon'] = 0.03
att_cfgs['start_beta'] = None
att_cfgs['end_beta'] = None
att_cfgs['beta_growth'] = False
att_cfgs['prefactors'] = [8e4, 1e1]
cfgs['attack'] = att_cfgs

# Alternative attack settings with the beta scheduler enabled (kept for reference).
# att_cfgs = {}
# att_cfgs['num_iter'] = 1000
# att_cfgs['lr'] = 2e-4
# att_cfgs['output_dir'] = './save_dir'
# # att_cfgs['beta'] = 1000
# att_cfgs['epsilon'] = 0.03
# att_cfgs['start_beta'] = 10 #None
# att_cfgs['end_beta'] = 100 #None
# att_cfgs['beta_growth'] = True #False
# # att_cfgs['prefactors'] = [1e6, 1] # [1e13, 1e2]
# # att_cfgs['prefactors'] = [1e10, 0] # [1e13, 1e2]
# # att_cfgs['prefactors'] = [1e10, 1] # [1e13, 1e2]
# att_cfgs['prefactors'] = [1e14, 1e4] #[8e4, 1e1] # [1e13, 1e2]
# # att_cfgs['prefactors'] = [8e11, 1e8] # [1e13, 1e2]
# cfgs['attack'] = att_cfgs


cfgs['device'] = 'cuda'
cfgs['batch_size'] = 1
### save
cfgs['output_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred'
os.makedirs(cfgs['output_dir'], exist_ok=True)

# 'restore' and 'step' are not read in this notebook section; presumably used by
# project code — confirm.
cfgs['restore'] = ''
cfgs['step'] = 0 

# visualize and save expl maps
cfgs['expl_vis'] = True
cfgs['expl_save'] = True
cfgs['expl_save_path'] = './save_dir'
os.makedirs(cfgs['expl_save_path'], exist_ok=True)

In [None]:
# cfgs['model']['tgt_models'][0]['model_info']['model_path']

# Utils

In [None]:
# Boilerplate methods.
def ShowImage(im, title='', ax=None):
    """Display ``im`` with matplotlib's default colormap and a title.

    Args:
        im: image accepted by ``plt.imshow``.
        title: plot title.
        ax: axes to draw into; when ``None`` a new figure is opened.
    """
    if ax is None:
        plt.figure()
    else:
        # Draw into the requested axes even when it is not the current one
        # (previously `ax` was ignored and the image went to plt.gca()).
        plt.sca(ax)
    plt.imshow(im)
    plt.title(title)
    
def ShowGrayscaleImage(im, title='', ax=None):
    """Display ``im`` in grayscale with a fixed [0, 1] intensity range.

    Args:
        im: image accepted by ``plt.imshow``; values outside [0, 1] saturate.
        title: plot title.
        ax: axes to draw into; when ``None`` a new figure is opened.
    """
    if ax is None:
        plt.figure()
    else:
        # Draw into the requested axes even when it is not the current one
        # (previously `ax` was ignored and the image went to plt.gca()).
        plt.sca(ax)
    plt.imshow(im, cmap=plt.cm.gray, vmin=0, vmax=1)
    plt.title(title)
    
def ShowHeatMap(im, title, ax=None):
    """Display ``im`` as a heat map with the 'inferno' colormap.

    Args:
        im: 2-D array accepted by ``plt.imshow``.
        title: plot title.
        ax: axes to draw into; when ``None`` a new figure is opened.
    """
    if ax is None:
        plt.figure()
    else:
        # Draw into the requested axes even when it is not the current one
        # (previously `ax` was ignored and the image went to plt.gca()).
        plt.sca(ax)
    plt.imshow(im, cmap='inferno')
    plt.title(title)
    
def clamp(x, mean, std):
    """Clip a normalised image batch to the range of valid pixel values.

    A pixel ``p`` in [0, 1] normalised as ``(p - mean) / std`` lies in
    ``[(0 - mean) / std, (1 - mean) / std]``; ``x`` is clipped to that interval.

    Args:
        x: ``(N, C, ...)`` tensor of normalised images.
        mean, std: per-channel statistics (scalar, list or array). A single
            value is broadcast to every channel.

    Returns:
        The clipped tensor. For 3-channel input ``x`` is clipped in place and
        returned (as before); otherwise a new tensor is returned.
    """
    mean = np.asarray(mean, dtype=np.float64).reshape(-1)
    std = np.asarray(std, dtype=np.float64).reshape(-1)
    # Cast the bounds to x's dtype so the result keeps x's dtype.
    upper = torch.as_tensor((1.0 - mean) / std, dtype=x.dtype, device=x.device)
    lower = torch.as_tensor((0.0 - mean) / std, dtype=x.dtype, device=x.device)

    if x.shape[1] == 3:  # 3-channel image
        # One bound per channel, broadcast over batch and spatial dims. The old
        # loop only clipped x[0] (the first sample) and indexed lower[i], which
        # raised for the single-value statistics used in this notebook.
        bshape = (1, -1) + (1,) * (x.dim() - 2)
        x[:] = torch.max(torch.min(x, upper.view(bshape)), lower.view(bshape))
        return x
    return torch.clamp(x, min=lower[0], max=upper[0])

def load_image_view(data_mean, data_std, image):
    """Turn a PIL image into an ``(H, W, 3)`` tensor for display.

    The image is resized and center-cropped to the global ``data_cfgs['resize']``,
    converted to a tensor, optionally normalised (only when both ``data_mean``
    and ``data_std`` are given), and expanded to three identical channels.
    """
    size = data_cfgs['resize']
    steps = [T.Resize(size), T.CenterCrop(size), T.ToTensor()]
    normalise = data_mean is not None and data_std is not None
    if normalise:
        steps.append(T.Normalize(mean=data_mean, std=data_std))
    # Replicate the single channel so the tensor can be shown as RGB.
    steps.append(T.Lambda(lambda t: t.expand(3, -1, -1)))
    chw = T.Compose(steps)(image)
    return chw.permute(1, 2, 0)

In [None]:
# Load image for model input.
def get_transforms(data_cfgs, if_norm=False, aug=None):
    """Build the torchvision preprocessing pipeline used for model input.

    Args:
        data_cfgs: dataset config. Reads ``'resize'``, and, when ``if_norm`` is
            true, ``'data_mean'`` / ``'data_std'``.
        if_norm: if true, normalise the tensor with the dataset mean/std.
        aug: optional callable (PIL image -> PIL image) applied after the
            resize/crop. It is followed by a re-crop and a random horizontal flip.

    Returns:
        A ``T.Compose`` mapping a PIL image to a float tensor.
    """
    size = data_cfgs['resize']
    transforms = [
        T.Resize(size),      # shorter side -> size
        T.CenterCrop(size),  # square size x size crop
    ]

    if aug is not None:
        transforms += [
            T.Lambda(aug),
            T.CenterCrop(size),
            T.RandomHorizontalFlip(),
        ]

    # ToTensor must come before Normalize: T.Normalize only accepts tensors, so
    # the previous order (Normalize on the PIL image) raised a TypeError.
    transforms.append(T.ToTensor())
    if if_norm:
        transforms.append(
            T.Normalize(mean=data_cfgs['data_mean'], std=data_cfgs['data_std'])
        )
    return T.Compose(transforms)

def image_process(image, data_cfgs, if_norm=False, aug=None):
    """Apply the ``get_transforms`` pipeline (same options) to one image."""
    pipeline = get_transforms(data_cfgs, if_norm, aug)
    processed = pipeline(image)
    return processed


# '''load image for visualization'''
# def image_process_vis(image, data_cfgs, if_norm=False, aug=None):
#     return image_process(image, data_cfgs, if_norm, aug).expand(3,-1,-1).permute(1,2,0).contiguous()


def img_norm(image, k=1):
    """Min-max normalise ``image`` to [0, 1], apply gain ``k`` and clip.

    Args:
        image: NumPy array (cast to float64) or torch tensor (cast to float32).
        k: contrast gain applied after normalisation; the result is clipped
            back to [0, 1].

    Returns:
        Array/tensor of the same kind and shape. A constant image maps to all
        zeros instead of NaN.
    """
    is_numpy = isinstance(image, np.ndarray)
    image = image.astype('float') if is_numpy else image.to(torch.float32)
    image = image - image.min()
    peak = image.max()
    # A constant image has peak 0. Dividing would fill it with NaN, so leave it
    # as zeros (consistent with batch_img_norm's epsilon guard).
    if peak > 0:
        image = image / peak
    if is_numpy:
        return np.clip(image * k, 0, 1)
    return torch.clip(image * k, 0, 1)


def batch_img_norm(image, k=1):
    """Min-max normalise each image of an ``(n, h, w)`` batch independently.

    Args:
        image: ``(n, h, w)`` NumPy array (cast to float64) or torch tensor
            (cast to float32).
        k: contrast gain applied after normalisation; the result is clipped
            to [0, 1].

    Returns:
        Array/tensor of the same kind and shape. A constant image maps to all
        zeros because of the 1e-10 in the denominator.
    """
    if isinstance(image, np.ndarray):
        image = image.astype('float')
        n, h, w = image.shape
        flat = image.reshape(n, -1)
        flat = flat - flat.min(axis=1, keepdims=True)
        flat = flat / (flat.max(axis=1, keepdims=True) + 1e-10)
        return np.clip(flat * k, 0, 1).reshape(n, h, w)

    image = image.to(torch.float32)
    n, h, w = image.size()
    # reshape (not view) so non-contiguous inputs also work.
    flat = image.reshape(n, -1)
    # Tensor.min/max(dim=...) return a (values, indices) namedtuple; the old
    # code subtracted the tuple itself and crashed. amin/amax return values only.
    flat = flat - flat.amin(dim=1, keepdim=True)
    flat = flat / (flat.amax(dim=1, keepdim=True) + 1e-10)
    return torch.clip(flat * k, 0, 1).reshape(n, h, w)

# Dataset

In [None]:
'''data augmentation'''
class Augmenter(object):
    """Callable that adds random Gaussian noise to a PIL image via imgaug.

    The noise scale is drawn from [0, 0.02 * 255] per call, independently per
    channel; the result is returned as a PIL image.
    """

    def __init__(self):
        noise = iaa.AdditiveGaussianNoise(scale=(0.0*255, 0.02*255), per_channel=True)
        self.aug_seq = iaa.Sequential([noise])

    def __call__(self, img):
        pixels = np.asarray(img)
        noisy = self.aug_seq(image=pixels)
        return Image.fromarray(noisy)


# Shared instance used by fetch_dataloader for the training split.
AUGMENT = Augmenter()

In [None]:
def fetch_dataloader(cfgs, mode):
    """Build a ChexpertSmall DataLoader for one split.

    Args:
        cfgs: global config; reads ``cfgs['dataset']``, ``cfgs['batch_size']``
            and ``cfgs['device']``.
        mode: one of 'train', 'valid', 'test', 'vis'. Only 'train' gets
            augmentation (``AUGMENT``) and shuffling.
    """
    assert mode in ['train', 'valid', 'test', 'vis']
    ds_cfgs = cfgs['dataset']
    augmenter = AUGMENT if mode == 'train' else None
    pipeline = get_transforms(ds_cfgs, aug=augmenter)

    dataset = ChexpertSmall(
        ds_cfgs['data_path'],
        mode,
        pipeline,
        mini_data=ds_cfgs['mini_data'],
        drop_lateral=ds_cfgs['drop_lateral'],
    )

    is_train = mode == 'train'
    # The valid loader is evaluated inside the train loader's loop; 0 workers
    # avoids forking there (cf. torch DataLoader docs), otherwise memory
    # sharing gets clunky.
    workers = 0 if mode == 'valid' else 16
    return DataLoader(
        dataset,
        cfgs['batch_size'],
        shuffle=is_train,
        pin_memory=(cfgs['device'] == 'cuda'),
        num_workers=workers,
    )

In [None]:
# Build the train/valid loaders and sanity-check one training batch.
train_dataloader = fetch_dataloader(cfgs, mode='train')
valid_dataloader = fetch_dataloader(cfgs, mode='valid')

print('Attributes: ', valid_dataloader.dataset.attr_names)
print('Num classes:', len(valid_dataloader.dataset.attr_names))
print('Train data length: ', len(train_dataloader.dataset))
print('Valid data length: ', len(valid_dataloader.dataset))

# Each batch is (images, labels, index-like third element from ChexpertSmall).
batch_datas, batch_labels, batch_index = next(iter(train_dataloader)) 
print('label:', batch_labels[0:5])

# Show up to 32 images of the batch as a grid (cfgs['batch_size'] is 1, so one
# image). NOTE: `img` is rebound by later cells to a single dataset sample.
plt.figure(figsize=(6,6))
img = make_grid(batch_datas[:32], nrow=8, padding=2)
npimg = img.numpy()
print(npimg.shape)
plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

# Batch shape and value range.
print(batch_datas.size(), batch_datas.max(), batch_datas.min())

In [None]:
# batch_labels = batch_labels.to(torch.device(cfgs['device']))
# batch_labels.device

# Train & Test

In [None]:
@torch.no_grad()
def evaluate_single_model(model, dataloader, loss_fn, cfgs):
    """Evaluate ``model`` on ``dataloader`` without autograd and return metrics.

    Runs ``evaluate`` and passes its ``(outputs, targets, losses)`` to
    ``compute_metrics``.
    """
    outs, tgts, per_sample_losses = evaluate(model, dataloader, loss_fn, cfgs)
    metrics = compute_metrics(outs, tgts, per_sample_losses)
    return metrics
    
    
def evaluate(model, dataloader, loss_fn, cfgs):
    """Run ``model`` over ``dataloader`` in eval mode and collect the results.

    Args:
        model: network returning one score per class; assumed to already be on
            the configured device.
        dataloader: iterable of ``(x, target, idx)`` batches.
        loss_fn: called as ``loss_fn(out, target)``. It must return
            non-scalar (per-element) losses, e.g. ``reduction='none'``, because
            the per-batch losses are concatenated with ``torch.cat``.
        cfgs: config dict; ``cfgs['device']`` selects where inputs are sent
            (defaults to ``'cuda'``, the previously hard-coded device).

    Returns:
        ``(outputs, targets, losses)``, each concatenated over all batches on CPU.
    """
    device = (cfgs or {}).get('device', 'cuda')
    model.eval()
    targets, outputs, losses = [], [], []
    # No autograd: only values are needed. Graph-carrying outputs would keep
    # activations alive and cannot be converted to NumPy by the sklearn metrics.
    with torch.no_grad():
        for x, target, _ in dataloader:
            out = model(x.to(device))
            loss = loss_fn(out, target.to(device))
            outputs.append(out.cpu())
            targets.append(target)
            losses.append(loss.cpu())
    return torch.cat(outputs), torch.cat(targets), torch.cat(losses)


def compute_metrics(outputs, targets, losses=None):
    """Per-class ROC / precision-recall curves and ROC-AUCs.

    Args:
        outputs: ``(N, C)`` scores. Only their ranking matters for ROC/PR, so
            logits or probabilities both work.
        targets: ``(N, C)`` binary labels.
        losses: optional per-sample losses; when given, their mean over dim 0 is
            reported under ``'loss'`` (keyed by position). ``None`` omits that
            key, so two-argument callers such as ``eval_total`` also work.

    Returns:
        dict with 'fpr', 'tpr', 'aucs', 'precision', 'recall' (each a dict keyed
        by class index, curves as lists) and, if ``losses`` was given, 'loss'.
    """
    n_classes = outputs.shape[1]
    fpr, tpr, aucs, precision, recall = {}, {}, {}, {}, {}
    for i in range(n_classes):
        class_fpr, class_tpr, _ = roc_curve(targets[:, i], outputs[:, i])
        aucs[i] = auc(class_fpr, class_tpr)
        class_prec, class_rec, _ = precision_recall_curve(targets[:, i], outputs[:, i])
        fpr[i], tpr[i] = class_fpr.tolist(), class_tpr.tolist()
        precision[i], recall[i] = class_prec.tolist(), class_rec.tolist()

    metrics = {'fpr': fpr,
               'tpr': tpr,
               'aucs': aucs,
               'precision': precision,
               'recall': recall}
    if losses is not None:
        metrics['loss'] = dict(enumerate(losses.mean(0).tolist()))
    return metrics

## Eval metrics

In [None]:
def eval_total(dataset, args, model_name='densenet', num_images=55):
    """Aggregate explanation-similarity metrics and clean/adversarial model metrics.

    For each of the first ``num_images`` samples, ``ImgPairMetric`` is called once
    per explanation method. ``ImgPairMetric`` is not defined in this part of the
    notebook; it presumably comes from the star-imported project modules (confirm).
    Its SSIM / MSE / PCC lists are collected. The (original, adversarial) image
    pair returned for the last method is then classified by a DenseNet-121 on CPU.

    Args:
        dataset: indexable; ``dataset[i]`` returns ``(image, target, index)``.
        args: dict with ``'n_classes'`` and ``'densenet_path'``. The latter is a
            checkpoint holding a ``'state_dict'`` entry.
        model_name: only ``'densenet'`` is implemented.
        num_images: number of leading samples to evaluate (previously fixed at 55).

    Returns:
        ``(tot_ssim_list, tot_mse_list, tot_pcc_list, metrics_org, metrics_adv)``.
        The first three map method name -> list of per-image results.

    Raises:
        ValueError: for an unsupported ``model_name``. Previously this failed with
            an UnboundLocalError, and only after the whole metric loop.
    """
    if model_name != 'densenet':
        raise ValueError(
            f"eval_total: unsupported model_name {model_name!r}; only 'densenet' is implemented")
    des_model = densenet121(pretrained=False)
    des_model.classifier = nn.Linear(des_model.classifier.in_features, out_features=args['n_classes'])
    # The model runs on CPU below; map_location lets GPU-saved checkpoints load
    # without CUDA.
    checkpoint = torch.load(args['densenet_path'], map_location='cpu')
    des_model.load_state_dict(checkpoint['state_dict'])
    des_model.eval()

    # Per-method lists for the three explanation-similarity metrics.
    method_list = ['VanillaBP', 'VanillaBP_Img', 'GradCam', 'GuidedGradCam', 'IntegratedBP', 'SmoothBP', 'XRAI']
    tot_ssim_list = {method_name: [] for method_name in method_list}
    tot_mse_list = {method_name: [] for method_name in method_list}
    tot_pcc_list = {method_name: [] for method_name in method_list}
    targets, outputs_org, outputs_adv = [], [], []

    for image_index in np.arange(num_images):
        print('============== image_index:', image_index, '=================\n')
        _, target, _ = dataset[image_index]
        target = target.unsqueeze(0)

        # Three similarity metrics for every explanation method.
        for method_name in method_list:
            img_list, _, ssim_list, mse_list, pcc_list = ImgPairMetric(case_num=image_index, metric_name=method_name)
            tot_ssim_list[method_name].append(ssim_list)
            tot_mse_list[method_name].append(mse_list)
            tot_pcc_list[method_name].append(pcc_list)

        # Model performance (AUC, precision, ...) on the image pair returned for
        # the last method. The images are presumably HxWxC arrays, permuted to
        # 1xCxHxW here.
        img_org0, img_adv0 = img_list
        img_org0_ts = torch.Tensor(img_org0).permute(2, 0, 1).unsqueeze(0)
        img_adv0_ts = torch.Tensor(img_adv0).permute(2, 0, 1).unsqueeze(0)
        with torch.no_grad():  # run on CPU, no CUDA
            outputs_org.append(des_model(img_org0_ts))
            outputs_adv.append(des_model(img_adv0_ts))
        targets.append(target)

    outputs_org, outputs_adv, targets = torch.cat(outputs_org), torch.cat(outputs_adv), torch.cat(targets)
    metrics_org, metrics_adv = compute_metrics(outputs_org, targets), compute_metrics(outputs_adv, targets)
    return tot_ssim_list, tot_mse_list, tot_pcc_list, metrics_org, metrics_adv


# __Main__

## Attack source model & save data

### Instantiation models

In [None]:
# Build one explanation model per source config. expmodel_factory presumably
# comes from the star-imported ExplModel module — confirm.
source_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['src_models']]
# target_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['tgt_models']]

#  def __init__(self, cfgs, target_layer, init_from=None, model_info=None):

In [None]:
print('Attributes: ', valid_dataloader.dataset.attr_names)
print('Num classes:', len(valid_dataloader.dataset.attr_names))
print('Valid data length: ', len(valid_dataloader.dataset))

# Validation sample used by the single-sample cells below.
img_idx = 160

img, label, patient_id = valid_dataloader.dataset[img_idx]
print('label:', label)
print(img.size())

# One-hot selector over the 5 classes: explain/attack class 0 (attr_names[0]).
class_of_interest = [1., 0., 0., 0., 0.]

### Heatmap of one sample

In [None]:
img, label, patient_id = valid_dataloader.dataset[img_idx]
print('label:', label)
print(img.size())

# One (heatmap, prediction) pair per source model for the clean sample.
res = [_.cal_exp_map(img.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]


In [None]:
# Display the one-hot class selector currently in use.
class_of_interest

In [None]:
# Value range of the input; fetch_dataloader does not normalise, so this is the ToTensor range.
print(img.min(), img.max())

In [None]:

# Visualise the clean-sample explanation(s) computed above.
if source_exp_models[0].exp_name == 'XRAI':
    for heatmap, pred in res:
        ROWS = 1
        COLS = 3
        UPSCALE_FACTOR = 20
        # NOTE(review): figsize is (width, height), so this is 20 x 60 inches
        # for a 1x3 row of panels; (COLS * UPSCALE_FACTOR, ROWS * UPSCALE_FACTOR)
        # was probably intended.
        plt.figure(figsize=(ROWS * UPSCALE_FACTOR, COLS * UPSCALE_FACTOR))
        ##
#         im_orig = load_image_view(data_mean=None, data_std=None, image=img)
        im_orig = img.permute(1,2,0)
        # plt.subplot(...) creates each panel and makes it the current axes.
        ## Show original image
        ShowGrayscaleImage(im_orig, title='Original Image', ax=plt.subplot(ROWS, COLS, 1))
        ## Show XRAI heatmap attributions
        ShowHeatMap(heatmap, title='XRAI Heatmap', ax=plt.subplot(ROWS, COLS, 2))
        ## Keep pixels above the 92nd percentile (roughly the top 8%).
        # NOTE(review): this disagrees with the panel title ('Top 15%') and with
        # an earlier comment that said 30%; decide which one is intended.
        mask = heatmap > np.percentile(heatmap, 92)
        im_mask = np.array(im_orig)
        im_mask[~mask] = 0
        ShowImage(im_mask, title='Top 15%', ax=plt.subplot(ROWS, COLS, 3))
else:       
    for heatmap, pred in res:
        print(torch.sigmoid(pred))
        heatmap = heatmap.data.cpu().numpy()
        heatmap = img_norm(heatmap,k=1.7) # normalise to [0, 1] with gain 1.7 (boosts weak attributions)
        cmap = cm.jet_r(1-heatmap)[..., :3]     # colour-map projection (jet_r of 1-x ~ jet of x), RGB only
        cmap = img_norm(cmap)                     # re-norm to [0,1]

        ## Overlay on the original image (50/50 blend).
        base_img = img.expand(3,-1,-1).permute(1,2,0)
        alpha = 0.5
        img_fused = base_img*(1-alpha) + cmap*alpha
        plt.figure(figsize=(6,6))
        # assumes heatmap has a leading singleton dim, making img_fused (1, H, W, 3) — confirm
        plt.imshow(img_fused[0])
        plt.show()

In [None]:
# Shape of the last prediction produced by the loop above.
pred.shape

### Attack one sample PRED

In [None]:
''' trun off the beta update scheduler '''
# Prediction-attack settings with the beta scheduler disabled. This replaces
# cfgs['attack'] with a new dict.
att_cfgs = {}
# Iteration budgets noted per explanation method (presumably VanillaBP,
# VanillaBP_Img, IntegratedBP, SmoothBP, GradCAM):
att_cfgs['num_iter'] = 8 # VBP:8 | VBP*I: 8 | IG: 20 | SG: 50  | GCAM:60
att_cfgs['lr'] = 2e-4
att_cfgs['output_dir'] = './save_dir'
att_cfgs['epsilon'] = 0.03
att_cfgs['start_beta'] = None
att_cfgs['end_beta'] = None
att_cfgs['beta_growth'] = False
att_cfgs['prefactors'] = [8e4, 1e1]
cfgs['attack'] = att_cfgs

In [None]:
''' targeted-logit attack'''
img, label, patient_id = valid_dataloader.dataset[img_idx]
print('label:', label)
print(img.size())

# Newer attack entry point (disabled):
# adv_imgs = attack_pred_T(
#     source_exp_models,
#     img.to(torch.device(cfgs['device'])).unsqueeze(0),
#     class_of_interest,
#     label.to(torch.device(cfgs['device'])).unsqueeze(0),
#     cfgs,
# #     attack_steps=1000,
# #     vis_iter=50
# )

# NOTE(review): attack_pred_T_old is not among the explicit imports (only
# attack_expl_T / attack_pred_T are); presumably it comes from a star import — confirm.
adv_imgs = attack_pred_T_old(
    source_exp_models,
    img.to(torch.device(cfgs['device'])).unsqueeze(0),
    class_of_interest,
    label.to(torch.device(cfgs['device'])).unsqueeze(0),
    cfgs,
#     attack_steps=1000,
#     vis_iter=50
)

# Largest signed difference (clean - adversarial); not the max |perturbation|.
print((img.cuda().unsqueeze(0)-adv_imgs).max())

In [None]:
# Re-create the explanation models after the attack (presumably the attack leaves state/hooks on them — confirm).
source_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['src_models']]
# target_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['tgt_models']]

In [None]:
'''--- evaluate on source model ---'''
# Explanations/predictions of the source models on the adversarial image.
_img = adv_imgs.clone().detach()
res = [_.cal_exp_map(_img, class_of_interest) for _ in source_exp_models]

# img, label, patient_id = valid_dataloader.dataset[img_idx]
# res = [_.cal_exp_map(img.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]


# NOTE: this loop rebinds the global `pred`; later cells print it.
for heatmap, pred in res:
    print(torch.sigmoid(pred))
    heatmap = heatmap.data.cpu().numpy()
    heatmap = img_norm(heatmap, k=1.7) # normalise to [0, 1] with gain 1.7
    cmap = cm.jet_r(1-heatmap)[..., :3]     # colour-map projection, RGB only
    cmap = img_norm(cmap)                     # re-norm to [0,1]
    
    ## Heatmap-weighted overlay.
    # NOTE(review): this img_fused is overwritten below before being shown (dead code).
    base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
    alpha = .99
    htmp_weight = np.zeros_like(cmap.squeeze())
    print(htmp_weight.shape)
    htmp_weight[:,:,0], htmp_weight[:,:,1], htmp_weight[:,:,2] = heatmap, heatmap, heatmap
    img_fused = base_img*(1-htmp_weight) + cmap*alpha*htmp_weight
    
    
    # Panels: 1 clean image, 2 empty (no target map for the pred attack),
    # 3 adversarial image, 4 50/50 overlay of the adversarial explanation.
    plt.figure(figsize=(10,10))
    base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
    alpha = .5
    img_fused = base_img*(1-alpha) + cmap*alpha
    plt.subplot(2,2,1)
    plt.imshow(img[0], cmap='gray')
    plt.subplot(2,2,2)
#     plt.imshow(attack_target_maps[0].data.cpu(), cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(_img[0,0].data.cpu().numpy(), cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img_fused[0])
    plt.show()

In [None]:
# (img[0] - _img[0,0].data.cpu()).max()

In [None]:
# Disabled: target-model evaluation (tgt_models is empty and target_exp_models is never built).
'''--- evaluate on target model ---'''
# # 0.03
# _img = adv_imgs.clone().detach()
# res = [_.cal_exp_map(_img, class_of_interest) for _ in target_exp_models]

# for heatmap, pred in res:
#     print(torch.sigmoid(pred))
#     heatmap = heatmap.data.cpu().numpy()
#     heatmap = img_norm(heatmap) # norm grad to speial range
#     cmap = cm.jet_r(1-heatmap)[..., :3]     # color map proj
#     cmap = img_norm(cmap)                     # re-norm to [0,1]
    
#     ## load original image
#     plt.figure(figsize=(10,10))
#     base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
#     alpha = .5
#     img_fused = base_img*(1-alpha) + cmap*alpha
#     plt.subplot(2,2,1)
#     plt.imshow(img[0], cmap='gray')
#     plt.subplot(2,2,2)
# #     plt.imshow(attack_target_maps[0].data.cpu(), cmap='gray')
#     plt.subplot(2,2,3)
#     plt.imshow(_img[0,0].data.cpu().numpy(), cmap='gray')
#     plt.subplot(2,2,4)
#     plt.imshow(img_fused[0])
#     plt.show()

### Attack one sample EXPL

In [None]:
''' trun on the beta update scheduler '''
# Mutates att_cfgs in place. att_cfgs is the same dict object as
# cfgs['attack'], so these settings take effect there too.
att_cfgs['start_beta'] = 10 #None
att_cfgs['end_beta'] = 100 #None
att_cfgs['beta_growth'] = True #False
# NOTE(review): the attack_expl_T call below also passes attack_steps=200
# explicitly; which of the two the attack honours is defined in attack.py.
att_cfgs['num_iter'] = 200

att_cfgs['prefactors'] = [1e14, 1e4] 

In [None]:
# Candidate attack target: the inverted, re-normalised clean explanation.
# NOTE: the next cell overwrites attack_target_maps with a square patch.
img, label, patient_id = valid_dataloader.dataset[img_idx]
print('label:', label)
print(img.size())

res = [_.cal_exp_map(img.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]
h_map = res[0][0]
h_map = img_norm(h_map)
attack_target_maps = 1-h_map
attack_target_maps = img_norm(attack_target_maps)
print(attack_target_maps.min(), attack_target_maps.max())
print((1-attack_target_maps).min(), (1-attack_target_maps).max())
plt.imshow(1-attack_target_maps[0].data.cpu())
plt.show()

In [None]:
img, label, patient_id = valid_dataloader.dataset[img_idx]
print('label:', label)
print(img.size())

# Target map: a 60x60 square (rows 130:190, cols 380:440) filled with 1/3600.
# Each channel therefore sums to 1 (a unit-mass patch, not literally binary).
attack_target_maps = torch.zeros_like(img).detach_()
attack_target_maps[:, 130:190, 380:440] = 1.0/3600 # binary mask

# NOTE: epsilon=0.06 here, versus 0.03 in cfgs['attack'].
adv_imgs = attack_expl_T(
    source_exp_models, 
    img.to(torch.device(cfgs['device'])).unsqueeze(0), 
    class_of_interest, 
    attack_target_maps.to(torch.device(cfgs['device'])).unsqueeze(0),
    cfgs, 
    attack_steps=200,
    epsilon=0.06,
    vis_iter=40
)

In [None]:
# Re-create the explanation models after the attack (presumably the attack leaves state/hooks on them — confirm).
source_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['src_models']]
# target_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['tgt_models']]

In [None]:
# Evaluate the explanation-attack result on the source model.
# (Older note said "0.03", presumably an epsilon; the attack above used 0.06.)
_img = adv_imgs.clone().detach()
res = [_.cal_exp_map(_img, class_of_interest) for _ in source_exp_models]
# res_org = [_.cal_exp_map(img, class_of_interest) for _ in source_exp_models]

# Here `_` holds the current (adversarial) prediction.
for heatmap, _ in res:
    print(torch.sigmoid(_))
    # NOTE(review): `pred` is stale, left over from an earlier cell's loop, not
    # this result; presumably printed for comparison — confirm.
    print(torch.sigmoid(pred))
    heatmap = heatmap.data.cpu().numpy()
    heatmap = img_norm(heatmap,k=2.7) # normalise to [0, 1] with gain 2.7
#     heatmap = img_norm(heatmap) # normalise to [0, 1]
    cmap = cm.plasma(heatmap)[..., :3]     # colour-map projection, RGB only
    cmap = img_norm(cmap)                     # re-norm to [0,1]
    
    ## Heatmap-weighted overlay: strong attributions take the colour, weak ones keep the image.
    plt.figure(figsize=(10,10))
    base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
    alpha = .99
    htmp_weight = np.zeros_like(cmap.squeeze())
    print(htmp_weight.shape)
    htmp_weight[:,:,0], htmp_weight[:,:,1], htmp_weight[:,:,2] = heatmap, heatmap, heatmap
    
    img_fused = base_img*(1-htmp_weight) + cmap*alpha*htmp_weight
    # Panels: clean image, target map, adversarial image, overlay.
    plt.subplot(2,2,1)
    plt.imshow(img[0], cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(attack_target_maps[0].data.cpu(), cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(_img[0,0].data.cpu().numpy(), cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img_fused[0])
    plt.show()

In [None]:
# Per-pixel |adversarial - clean| on channel 0; the max is the L-inf perturbation size.
diff = np.abs(_img[0,0].data.cpu().numpy() - img[0].data.cpu().numpy())
print(diff.min(), diff.max())

In [None]:
# Same evaluation with a plain 50/50 jet overlay (target models disabled).
_img = adv_imgs.clone().detach()
# res = [_.cal_exp_map(_img, class_of_interest) for _ in target_exp_models]
res = [_.cal_exp_map(_img, class_of_interest) for _ in source_exp_models]
for heatmap, _ in res:
    # NOTE(review): prints the stale global `pred`, not this result's prediction
    # (`_`); probably torch.sigmoid(_) was intended.
    print(torch.sigmoid(pred))
    heatmap = heatmap.data.cpu().numpy()
    heatmap = img_norm(heatmap) # normalise to [0, 1]
    cmap = cm.jet_r(1-heatmap)[..., :3]     # colour-map projection, RGB only
    cmap = img_norm(cmap)                     # re-norm to [0,1]
    
    ## Overlay on the adversarial image.
    plt.figure(figsize=(10,10))
    base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
    alpha = .5
    img_fused = base_img*(1-alpha) + cmap*alpha
    plt.subplot(2,2,1)
    plt.imshow(img[0], cmap='gray')
    plt.subplot(2,2,2)
    plt.imshow(attack_target_maps[0].data.cpu(), cmap='gray')
    plt.subplot(2,2,3)
    plt.imshow(_img[0,0].data.cpu().numpy(), cmap='gray')
    plt.subplot(2,2,4)
    plt.imshow(img_fused[0])
    plt.show()

### Test on target models

In [None]:
# img, label, patient_id = valid_dataloader

In [None]:
# src_model = load_model(src_tgt='src', args=args)
# src_model_name = 'densenet' if 'densenet' in args['src_model'] else 'resnet'

# tgt_model = load_model(src_tgt='tgt', args=args)
# tgt_model_name = 'densenet' if 'densenet' in args['tgt_model'] else 'resnet'

In [None]:
# ### define number of steps to attack
# label_dict = {'0': [1., 0., 0., 0., 0.],
#               '1': [0., 1., 0., 0., 0.],
#               '2': [0., 0., 1., 0., 0.],
#               '3': [0., 0., 0., 1., 0.],
#               '4': [0., 0., 0., 0., 1.]}
# args_att['num_iter'] = 40
# SSIM = []

# for image_index in np.arange(1): #len(ds)
#     img, label, patient_id = ds[image_index]
#     try: 
#         first_one_position = np.where(np.array(label) == 1)[0][0] ## first 1
#     except: 
#         first_one_position = -1  ### no 1
    
#     if first_one_position>=0:
#         label_list = label_dict[str(first_one_position)]
        
#         print('========== src attacked image {} ==========='.format(image_index))
#         args['expl_vis'] = True
#         x_adv, one_hot = AttackExpl(model_name=src_model_name, model=src_model, 
#                                    dataset=ds, image_index=image_index, label_list=label_list, 
#                                    method_name=method_name, args=args, args_att=args_att, args_method=None)
        
#         print('========== tgt original image {} ==========='.format(image_index))
#         args['expl_vis'] = True
#         G_org = GetExpl(model_name=tgt_model_name, model=tgt_model,
#                         dataset=ds, image_index=image_index, label_list=label_list, 
#                         beta=None, method_name=method_name, args=args)

#         print('========== tgt attacked image {} ==========='.format(image_index))
#         G_att = GetExpl_from_image(model_name=tgt_model_name, model=tgt_model,
#                                    image_ts=x_adv, one_hot=one_hot, 
#                                    method_name=method_name, args=args, beta=None, args_method=None)
#         expl_ssim, _ = ssim(img_norm_02(G_org), img_norm_02(G_att), data_range=255, full=True, multichannel=False)
#         SSIM.append(expl_ssim)

In [None]:
# SSIM_np=np.asarray(SSIM)
# plt.hist(SSIM_np,50)
# plt.xlabel('SSIM',fontsize=16)
# plt.ylabel('Counts',fontsize=16)

# __Test__

In [None]:
# Scratch check: for every element equal to 1, print it plus one (prints 2 three times).
aa = [1, 2, 3, 4, 1, 2, 3, 4, 1]
for m in aa:
    if m == 1:
        m += 1
        print(m)

## attack PRED

In [None]:
''' trun off the beta update scheduler '''
# Earlier variant (disabled):
# att_cfgs['start_beta'] = None
# att_cfgs['end_beta'] = None
# att_cfgs['beta_growth'] = False
# att_cfgs['num_iter'] = 4

# att_cfgs['prefactors'] = [1, 8e14]#[8e4, 1e1]
# att_cfgs['class_of_interest'] = 3

# cfgs['attack'] = att_cfgs

''' trun off the beta update scheduler '''
# Prediction-attack settings for the batch loop below (fresh dict, beta scheduler off).
att_cfgs = {}
att_cfgs['num_iter'] = 8
att_cfgs['lr'] = 2e-4
att_cfgs['output_dir'] = './save_dir'
att_cfgs['epsilon'] = 0.03
att_cfgs['start_beta'] = None
att_cfgs['end_beta'] = None
att_cfgs['beta_growth'] = False
att_cfgs['prefactors'] = [8e4, 1e1]
cfgs['attack'] = att_cfgs

# Class index (as a string) -> one-hot class_of_interest over the 5 classes.
label_dict = {'0': [1., 0., 0., 0., 0.],
              '1': [0., 1., 0., 0., 0.],
              '2': [0., 0., 1., 0., 0.],
              '3': [0., 0., 0., 1., 0.],
              '4': [0., 0., 0., 0., 1.]}

In [None]:
# for imgs, labels, idx in valid_dataloader:
#     print(labels)

In [None]:
# For each of the 5 classes, run the prediction attack on the whole validation
# set and save clean/adversarial images, heatmaps, predictions and labels as .npy.
for ii in np.arange(5):
    
    org_imgs = []
    adv_imgs = []
    org_labels = []
    # One list per source model; only model [0] is concatenated and saved below.
    source_org_hm = [[] for _ in range(len(cfgs['model']['src_models']))]
    source_org_pred = [[] for _ in range(len(cfgs['model']['src_models']))]
    # target_org_hm = [[] for _ in range(len(cfgs['model']['tgt_models']))]
    # target_org_pred = [[] for _ in range(len(cfgs['model']['tgt_models']))]
    source_adv_hm = [[] for _ in range(len(cfgs['model']['src_models']))]
    source_adv_pred = [[] for _ in range(len(cfgs['model']['src_models']))]
    # target_adv_hm = [[] for _ in range(len(cfgs['model']['tgt_models']))]
    # target_adv_pred = [[] for _ in range(len(cfgs['model']['tgt_models']))]
    # class_of_interest = [1., 0., 0., 0., 0.]
    iteration = 0


    for imgs, labels, idx in valid_dataloader:
        print('===========Iteration %d===========' % iteration)

        iteration += 1
#         try: first_one_position = np.where(np.array(labels[0]) == 1)[0][0] ## first 1
#         except: first_one_position = -1  ### no 1
        # ii is a NumPy int; str(ii) gives '0'..'4'. Every sample is attacked
        # toward class ii, whatever its label.
        class_of_interest = label_dict[str(ii)]

#         if labels[0][ii] == 1:
#             class_of_interest = label_dict[str(ii)]
#             print('interest: {}, labels: {}'.format(ii, labels))
#         else:
#             continue

    #     if iteration > 10:
    #         break
        # Free the previous models. `del` runs left to right, so source_exp_models
        # is deleted; target_exp_models is never defined (its construction is
        # commented out), and the resulting NameError is swallowed by the bare except.
        try: del source_exp_models, target_exp_models
        except: pass
        torch.cuda.empty_cache()

        org_images = imgs.detach()
        org_labels.append(labels)
        source_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['src_models']]
    #     target_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['tgt_models']]
    
        # NOTE(review): CPU tensors are passed here, unlike the single-sample
        # cell, which passed CUDA tensors; presumably the attack moves them — confirm.
        _advs = attack_pred_T_old(
                                source_exp_models,
                                imgs.detach(),
                                class_of_interest,
                                labels.detach(),
                                cfgs,
                            )
        
        adv_imgs.append(_advs.data.cpu())
        org_imgs.append(org_images.data.cpu())

        # Fresh models for evaluation after the attack (presumably the attack
        # leaves state/hooks on the previous ones — confirm).
        source_exp_models = [expmodel_factory(model_cfgs, cfgs) for model_cfgs in cfgs['model']['src_models']]

        _advs = _advs.detach()
        torch.cuda.empty_cache()

        # (heatmap, prediction) per source model on the clean batch.
        s_org = [_.cal_exp_map(org_images.cuda(), class_of_interest) for _ in source_exp_models]
        
        for i in range(len(s_org)):
            source_org_hm[i].append(s_org[i][0].data.cpu())
            source_org_pred[i].append(s_org[i][1].data.cpu())

        del s_org

        # ... and on the adversarial batch.
        s_adv = [_.cal_exp_map(_advs.detach().cuda(), class_of_interest) for _ in source_exp_models]

        for i in range(len(s_adv)):
            source_adv_hm[i].append(s_adv[i][0].data.cpu())
            source_adv_pred[i].append(s_adv[i][1].data.cpu())

        del s_adv, _advs, source_exp_models
        torch.cuda.empty_cache()

    '''save npy images'''
    adv_imgs = torch.cat(adv_imgs, dim=0)
    print(adv_imgs.shape)
    org_imgs = torch.cat(org_imgs, dim=0)
    print(org_imgs.shape)

    # Only the first source model's heatmaps/predictions are kept; heatmaps
    # get a channel dim via unsqueeze_(1).
    source_org_hm = torch.cat(source_org_hm[0], dim=0).unsqueeze_(1)
    print(source_org_hm.shape)
    source_org_pred = torch.cat(source_org_pred[0], dim=0)
    print(source_org_pred.shape)

    source_adv_hm = torch.cat(source_adv_hm[0], dim=0).unsqueeze_(1)
    print(source_adv_hm.shape)
    source_adv_pred = torch.cat(source_adv_pred[0], dim=0)
    print(source_adv_pred.shape)

    # source_org_hm = [torch.cat(_, dim=0) for _ in source_org_hm]
    # source_org_pred = [torch.cat(_, dim=0) for _ in source_org_pred]
    # target_org_hm = [torch.cat(_, dim=0) for _ in target_org_hm]
    # target_org_pred = [torch.cat(_, dim=0) for _ in target_org_pred]
    # source_adv_hm = [torch.cat(_, dim=0) for _ in source_adv_hm]
    # source_adv_pred = [torch.cat(_, dim=0) for _ in source_adv_pred]
    # target_adv_hm = [torch.cat(_, dim=0) for _ in target_adv_hm]
    # target_adv_pred = [torch.cat(_, dim=0) for _ in target_adv_pred]

    org_labels = torch.cat(org_labels, dim=0)
    print('original label size: ', org_labels.size())

    '''save image cubes'''
    # Directory: <output_dir>/<model_name>_<exp_method>_class_<ii>. `model_cfgs`
    # here is the global from the settings cell (comprehension variables do not
    # leak in Python 3).
    cfgs['output_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred'
    save_dir = os.path.join(cfgs['output_dir'], 
                            model_cfgs['src_models'][0]['model_info']['model_name']+'_'+\
                            model_cfgs['src_models'][0]['exp_method']+\
                            '_class_'+str(ii))
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, 'org_imgs.npy'), org_imgs.numpy())
    np.save(os.path.join(save_dir, 'adv_imgs.npy'), adv_imgs.numpy())

    np.save(os.path.join(save_dir, 'org_hm.npy'), source_org_hm.numpy())
    np.save(os.path.join(save_dir, 'adv_hm.npy'), source_adv_hm.numpy())

    np.save(os.path.join(save_dir, 'org_labels.npy'), org_labels.numpy())
    np.save(os.path.join(save_dir, 'org_pred.npy'), source_org_pred.numpy())
    np.save(os.path.join(save_dir, 'adv_pred.npy'), source_adv_pred.numpy())


In [None]:
# org_labels = []
# for imgs, labels, idx in valid_dataloader:
#     org_labels.append(labels)
# org_labels = torch.cat(org_labels, dim=0)

### Utils

In [None]:
def compute_metrics(outputs, targets):
    """Compute per-class ROC and precision-recall statistics.

    Args:
        outputs: 2-D array-like of scores; column ``i`` holds the scores for class ``i``.
        targets: 2-D array-like of ground-truth labels with the same column layout.

    Returns:
        dict with keys 'fpr', 'tpr', 'aucs', 'precision', 'recall'. Each value is a
        dict keyed by class index. The curve entries are converted to plain Python
        lists; 'aucs' holds the area under each class's ROC curve.
    """
    metrics = {name: {} for name in ('fpr', 'tpr', 'aucs', 'precision', 'recall')}
    for cls in range(outputs.shape[1]):
        y_true = targets[:, cls]
        y_score = outputs[:, cls]

        roc_fpr, roc_tpr, _ = roc_curve(y_true, y_score)
        metrics['aucs'][cls] = auc(roc_fpr, roc_tpr)
        pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_score)

        # Store curves as lists (e.g. for JSON-friendly export).
        metrics['fpr'][cls] = roc_fpr.tolist()
        metrics['tpr'][cls] = roc_tpr.tolist()
        metrics['precision'][cls] = pr_precision.tolist()
        metrics['recall'][cls] = pr_recall.tolist()
    return metrics

In [None]:
def eval_img_similarity(imgs_1, imgs_2):
    """Pairwise similarity between two stacks of single-channel images.

    Each stack is rescaled with ``batch_img_norm`` and multiplied by 255 before the
    metrics are computed, so SSIM is evaluated with ``data_range=255``.

    Args:
        imgs_1, imgs_2: 3-D arrays. Axis 0 indexes images, and ``imgs_1[i]`` is
            compared with ``imgs_2[i]``.

    Returns:
        dict with keys
          'ssim': np.ndarray of per-pair SSIM scores,
          'mse':  scalar, NaN-ignoring mean squared difference after dividing the
                  x255 scaling back out,
          'pcc':  np.ndarray of per-pair Pearson correlation coefficients.
        A pair whose normalised images contain NaN is skipped for SSIM/PCC, and its
        index is printed. Those arrays may therefore be shorter than the input stack.
    """
    assert len(imgs_1.shape) == len(imgs_2.shape) == 3
    imgs_1 = batch_img_norm(imgs_1) * 255
    imgs_2 = batch_img_norm(imgs_2) * 255

    ssim_vals, pcc_vals = [], []
    for i in range(imgs_1.shape[0]):
        _img1, _img2 = imgs_1[i], imgs_2[i]
        # A NaN anywhere makes the mean NaN; such pairs cannot be scored.
        if np.isnan(_img1.mean()) or np.isnan(_img2.mean()):
            print(i)
            continue

        # SSIM: only the scalar mean score is used, so the full SSIM map is not
        # requested. The deprecated ``multichannel`` kwarg is omitted; 2-D input
        # is treated as grayscale by default in both old and new scikit-image.
        ssim_vals.append(ssim(_img1, _img2, data_range=255))

        # PCC over the flattened pixels
        pcc_vals.append(pearsonr(_img1.reshape(-1), _img2.reshape(-1))[0])

    # MSE: mean over the last axis, then a NaN-ignoring mean over the rest.
    mse_vals = np.nanmean(((imgs_1 / 255 - imgs_2 / 255) ** 2).mean(-1))
    return {
        'ssim': np.asarray(ssim_vals),
        'mse': mse_vals,
        'pcc': np.asarray(pcc_vals),
    }

In [None]:
# cfgs['output_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred'
#     save_dir = os.path.join(cfgs['output_dir'], 
#                             model_cfgs['src_models'][0]['model_info']['model_name']+'_'+\
#                             model_cfgs['src_models'][0]['exp_method']+\
#                             '_class_'+str(ii))

In [None]:
# def eval_auc_sim(model_name, xai_name):
def load_all(model_name, xai_name, ii_specific=None, num_classes=5):
    """Load the arrays saved by an attack run for one model / explanation method.

    Files are read from ``cfgs['output_dir']/<model_name>_<xai_name>_class_<k>/``:
    org_imgs.npy, adv_imgs.npy, org_hm.npy, adv_hm.npy, org_pred.npy, adv_pred.npy
    and org_labels.npy.

    Args:
        model_name: model part of the directory name. It is now honoured; previously
            the name was taken from the global ``model_cfgs``.
        xai_name: explanation-method part of the directory name.
        ii_specific: class index to load. With ``None``, classes
            ``0 .. num_classes-1`` are loaded and concatenated along axis 0.
        num_classes: number of per-class directories read when ``ii_specific`` is None.

    Returns:
        ([img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label])
    """
    stems = ('org_imgs', 'adv_imgs', 'org_hm', 'adv_hm', 'org_pred', 'adv_pred', 'org_labels')
    class_ids = np.arange(num_classes) if ii_specific is None else [ii_specific]

    loaded = {stem: [] for stem in stems}
    for ii in class_ids:
        save_dir = os.path.join(cfgs['output_dir'],
                                model_name + '_' + xai_name + '_class_' + str(ii))
        for stem in stems:
            loaded[stem].append(np.load(os.path.join(save_dir, stem + '.npy')))

    if ii_specific is None:
        # Stack the per-class runs along the sample axis.
        arrays = {stem: np.concatenate(parts, 0) for stem, parts in loaded.items()}
    else:
        arrays = {stem: parts[0] for stem, parts in loaded.items()}

    return ([arrays['org_imgs'], arrays['adv_imgs']],
            [arrays['org_hm'], arrays['adv_hm']],
            [arrays['org_pred'], arrays['adv_pred'], arrays['org_labels']])

In [None]:
def eval_all_metrics(model_name, xai_name, ii_specific=None):
    """Print and return AUC and similarity metrics for one saved attack run.

    The run is loaded via ``load_all``. The function reports:
      * AUC of the original and of the attacked predictions against the labels;
      * heat-map similarity, original vs. attacked;
      * image similarity, original vs. attacked.

    Returns:
        (auc_org, auc_att, ssim_hp, mse_hp, ssim_im, mse_im). For an integer
        ``ii_specific``, the AUCs are that class's values. For ``None`` (all classes
        pooled), they are the full per-class AUC dicts; this case used to raise KeyError.
    """
    img_, hm_, pred_ = load_all(model_name, xai_name, ii_specific)

    def _report(a, b):
        # Print the NaN-ignoring mean of every similarity metric; return (ssim, mse).
        means = {}
        for _k, vals in eval_img_similarity(a, b).items():
            means[_k] = np.nanmean(vals)
            print(_k, means[_k])
        return means['ssim'], means['mse']

    if ii_specific is None:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')
    else:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))

    print('=== Org AUC ===')
    auc_org = compute_metrics(pred_[0], pred_[2])['aucs']
    print(auc_org)

    print('=== Attack AUC ===')
    auc_att = compute_metrics(pred_[1], pred_[2])['aucs']
    print(auc_att)

    print('=== Heat Map ===')
    ssim_hp, mse_hp = _report(hm_[0][:, 0, :, :], hm_[1][:, 0, :, :])

    print('=== Image ===')
    ssim_im, mse_im = _report(img_[0][:, 0, :, :], img_[1][:, 0, :, :])

    if ii_specific is None:
        return auc_org, auc_att, ssim_hp, mse_hp, ssim_im, mse_im
    return auc_org[ii_specific], auc_att[ii_specific], ssim_hp, mse_hp, ssim_im, mse_im

### load and eval

In [None]:
# xai_names = ['VanillaBP']
# Class-averaged evaluation: for each model and explanation method, evaluate every
# class's saved attack run with eval_all_metrics and average the returned metrics.
# NOTE(review): the attack cells save runs for classes 0..4 only; this assumes
# attr_names has 5 entries — confirm.
num_classes = len(valid_dataloader.dataset.attr_names)
xai_names = ['VanillaBP', 'VanillaBP_Img', 'IntegratedBP', 'GradCAM', 'SmoothBP']
model_name_list = ['resnet152', 'densenet121']

for model_name in model_name_list:
    print('-------------------------------- {} --------------------------------'.format(model_name))
    for xai_name in xai_names:
        print('-------------------------------- {} --------------------------------'.format(xai_name))
        # Running means over classes: each class contributes value / num_classes.
        AUC_org, AUC_att, SSIM_hp, MSE_hp, SSIM_im, MSE_im = 0,0,0,0,0,0
        for ii in np.arange(num_classes):
            auc_org, auc_att, ssim_hp, mse_hp, ssim_im, mse_im = eval_all_metrics(model_name, xai_name, ii)
            AUC_org += auc_org/num_classes
            AUC_att += auc_att/num_classes
            SSIM_hp += ssim_hp/num_classes
            SSIM_im += ssim_im/num_classes
            MSE_hp  += mse_hp/num_classes
            MSE_im  += mse_im/num_classes
    #     eval_all_metrics(model_name, xai_name, None)
        print('============= averaged results of {} ============= '.format(xai_name))
        print('AUC_org: ',AUC_org)
        print('AUC_att: ',AUC_att)
        print('SSIM_hp: ',SSIM_hp)
        print('SSIM_im: ',SSIM_im)
        print('MSE_hp: ',MSE_hp)
        print('MSE_im: ',MSE_im)

### Compute and save heatmap of G-GradCAM

In [None]:
# One-hot target vectors keyed by class index as a string ('0' .. '4').
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# --- explanation model: DenseNet-121 with Guided Grad-CAM ---
# Alternative checkpoints (densenet121_weakrobust, robust_densenet121) live under the
# same save_dir_multimodel directory.
dense121_cfgs = {
    'model_info': {
        'model_name': 'densenet121',
        'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
    },
    'exp_method': 'GuidedGradCAM',
    'exp_cfgs': {
        'target_layer': ['features', 'denseblock4', 'denselayer14'],
    },
}

# Only a source model is used here; no transfer (target) models.
model_cfgs = {
    'src_models': [dense121_cfgs],
    'tgt_models': [],
    'pretrained': False,
}
cfgs['model'] = model_cfgs
source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]

# Original / adversarial images come from the saved Grad-CAM attack run.
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GuidedBP', 'IntegratedBP', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name = 'densenet121'
img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='GradCAM', ii_specific=1)
print(img_[0].shape)  # original images
print(img_[1].shape)  # adversarial images
# print(image_num)

In [None]:
# For each class, recompute heat maps with source_exp_models (configured in the cell
# above, exp_method 'GuidedGradCAM') on the original and adversarial images saved by
# the Grad-CAM attack run, then save them as hm_org.npy / hm_att.npy.
for ii in np.arange(5): # 5 different classes
    print('==================== calss number: {} ===================='.format(ii))
    hm_org, hm_att = [], []
    img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='GradCAM', ii_specific=ii)
    # The 4-way unpacking requires the image stack to be 4-D.
    image_num, _, _, _ = img_[0].shape
    # Output directory is named after the current source model's exp_method.
    save_dir = os.path.join(cfgs['output_dir'], 
                            model_cfgs['src_models'][0]['model_info']['model_name']+'_'+\
                            model_cfgs['src_models'][0]['exp_method']+\
                            '_class_'+str(ii))
    # One-hot vector selecting class ii as the explained output.
    class_of_interest = label_dict[str(ii)]
    os.makedirs(save_dir, exist_ok=True)
    for img_index in np.arange(image_num):
        print('---- img_index: {} ----'.format(img_index))
        img_org = torch.Tensor(img_[0][img_index]).clone().detach()
        img_adv = torch.Tensor(img_[1][img_index]).clone().detach()
#         class_of_interest = torch.Tensor(pred_[2][0])
#         print(img_org.size())
        # One image at a time (unsqueeze adds a batch axis).
        res_org = [_.cal_exp_map(img_org.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]
        res_adv = [_.cal_exp_map(img_adv.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]

        # Keep only the first source model's heat map, moved to CPU.
        hm_org.append(res_org[0][0].data.cpu())
        hm_att.append(res_adv[0][0].data.cpu())
        # Free device memory before the next image.
        del res_org, res_adv
        torch.cuda.empty_cache()

    '''save npy images'''
    # Concatenate along axis 0, then add axis 1.
    # NOTE(review): assumes each map has a leading size-1 axis — confirm.
    hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
    hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
    print(hm_org.shape)
    
    np.save(os.path.join(save_dir, 'hm_org.npy'), hm_org.numpy())
    np.save(os.path.join(save_dir, 'hm_att.npy'), hm_att.numpy())

In [None]:
# pred_

### Compute and save heatmap of XRAI

In [None]:
# One-hot target vectors keyed by class index as a string ('0' .. '4').
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# --- explanation model: DenseNet-121 with XRAI ---
# Alternative checkpoints (densenet121_weakrobust, robust_densenet121) live under the
# same save_dir_multimodel directory. XRAI takes no extra explanation settings
# (a target_layer such as ['features', 'denseblock4', 'denselayer14'] is not used).
dense121_cfgs = {
    'model_info': {
        'model_name': 'densenet121',
        'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
    },
    'exp_method': 'XRAI',
    'exp_cfgs': {},
}

# Only a source model is used here; no transfer (target) models.
model_cfgs = {
    'src_models': [dense121_cfgs],
    'tgt_models': [],
    'pretrained': False,
}
cfgs['model'] = model_cfgs
source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]

# Original / adversarial images come from the saved IntegratedBP attack run.
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GuidedBP', 'IntegratedBP', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name = 'densenet121'
img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='IntegratedBP', ii_specific=1)
print(img_[0].shape)  # original images
print(img_[1].shape)  # adversarial images

In [None]:
# For each class, recompute heat maps with source_exp_models (configured in the cell
# above, exp_method 'XRAI') on the original and adversarial images saved by the
# IntegratedBP attack run, then save them as hm_org.npy / hm_att.npy.
for ii in np.arange(5): # 5 different classes
    print('==================== calss number: {} ===================='.format(ii))
    hm_org, hm_att = [], []
    img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='IntegratedBP', ii_specific=ii)
    # The 4-way unpacking requires the image stack to be 4-D.
    image_num, _, _, _ = img_[0].shape
    # Output directory is named after the current source model's exp_method.
    save_dir = os.path.join(cfgs['output_dir'], 
                            model_cfgs['src_models'][0]['model_info']['model_name']+'_'+\
                            model_cfgs['src_models'][0]['exp_method']+\
                            '_class_'+str(ii))
    # One-hot vector selecting class ii as the explained output.
    class_of_interest = label_dict[str(ii)]
    os.makedirs(save_dir, exist_ok=True)
    for img_index in np.arange(image_num):
        print('---- img_index: {} ----'.format(img_index))
        img_org = torch.Tensor(img_[0][img_index]).clone().detach()
        img_adv = torch.Tensor(img_[1][img_index]).clone().detach()
#         class_of_interest = torch.Tensor(pred_[2][0])
#         print(img_org.size())
        res_org = [_.cal_exp_map(img_org.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]
        res_adv = [_.cal_exp_map(img_adv.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest) for _ in source_exp_models]

        # Unlike the Guided Grad-CAM loop, the map is wrapped with torch.Tensor
        # and given a leading axis. Presumably XRAI's cal_exp_map returns a
        # non-tensor (e.g. numpy) map — TODO confirm.
        hm_org.append(torch.Tensor(res_org[0][0]).unsqueeze_(0))
        hm_att.append(torch.Tensor(res_adv[0][0]).unsqueeze_(0))
        # Free device memory before the next image.
        del res_org, res_adv
        torch.cuda.empty_cache()

    '''save npy images'''
    hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
    hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
    print(hm_org.shape)
    
    np.save(os.path.join(save_dir, 'hm_org.npy'), hm_org.numpy())
    np.save(os.path.join(save_dir, 'hm_att.npy'), hm_att.numpy())

### eval G-GradCAM and XRAI

In [None]:
def load_all__(model_name, xai_name, ii_specific=None, num_classes=5):
    """Load recomputed heat maps (hm_org.npy / hm_att.npy) for one model / method.

    Files are read from ``cfgs['output_dir']/<model_name>_<xai_name>_class_<k>/``.

    Args:
        model_name: model part of the directory name. It is now honoured; previously
            the name was taken from the global ``model_cfgs``.
        xai_name: explanation-method part of the directory name.
        ii_specific: class index to load. With ``None``, classes
            ``0 .. num_classes-1`` are loaded and concatenated along axis 0.
        num_classes: number of per-class directories read when ``ii_specific`` is None.

    Returns:
        [hm_org, hm_att]
    """
    class_ids = np.arange(num_classes) if ii_specific is None else [ii_specific]
    org_parts, att_parts = [], []
    for ii in class_ids:
        save_dir = os.path.join(cfgs['output_dir'],
                                model_name + '_' + xai_name + '_class_' + str(ii))
        org_parts.append(np.load(os.path.join(save_dir, 'hm_org.npy')))
        att_parts.append(np.load(os.path.join(save_dir, 'hm_att.npy')))

    if ii_specific is None:
        # Stack the per-class runs along the sample axis.
        return [np.concatenate(org_parts, 0), np.concatenate(att_parts, 0)]
    return [org_parts[0], att_parts[0]]

In [None]:
def eval_all_metrics__(model_name, xai_name, ii_specific=None):
    """Print original-vs-attacked similarity of recomputed heat maps.

    Maps are loaded via ``load_all__``: one class, or all classes pooled when
    ``ii_specific`` is None. Channel 0 of each map pair is compared with
    ``eval_img_similarity``, and the NaN-ignoring mean of every metric is printed.
    Returns None.
    """
    hm_org, hm_att = load_all__(model_name, xai_name, ii_specific)
    if ii_specific is not None:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))
    else:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')

    print('=== Heat Map ===')
    scores = eval_img_similarity(hm_org[:, 0, :, :], hm_att[:, 0, :, :])
    for metric_name, values in scores.items():
        print(metric_name, np.nanmean(values))


In [None]:
# Evaluate the recomputed Guided Grad-CAM and XRAI heat maps for DenseNet-121:
# first each class separately, then all classes pooled (ii_specific=None).
xai_names =  ['GuidedGradCAM', 'XRAI']
model_name = 'densenet121'
for xai_name in xai_names:
    print('-------------------------------- {} --------------------------------'.format(xai_name))
    for ii in np.arange(5):
        eval_all_metrics__(model_name, xai_name, ii)
    eval_all_metrics__(model_name, xai_name, None)

In [None]:
# Debug: shape of the first source model's first output for the last processed image.
# NOTE(review): both heat-map loops above run `del res_org` on every iteration, so
# this raises NameError when executed right after either of them.
res_org[0][0].shape

In [None]:
# for heatmap, pred in res:
#     print(torch.sigmoid(pred))
#     heatmap = heatmap.data.cpu().numpy()
#     heatmap = img_norm(heatmap) # norm grad to speial range
#     cmap = cm.jet_r(1-heatmap)[..., :3]     # color map proj
#     cmap = img_norm(cmap)                     # re-norm to [0,1]
    
#     ## load original image
#     base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
#     alpha = .99
#     htmp_weight = np.zeros_like(cmap.squeeze())
#     print(htmp_weight.shape)
#     htmp_weight[:,:,0], htmp_weight[:,:,1], htmp_weight[:,:,2] = heatmap, heatmap, heatmap
#     img_fused = base_img*(1-htmp_weight) + cmap*alpha*htmp_weight
    
    
#     plt.figure(figsize=(6,6))
#     base_img = _img[0].data.cpu().expand(3,-1,-1).permute(1,2,0)
#     alpha = .5
#     img_fused = base_img*(1-alpha) + cmap*alpha
#     plt.subplot(2,2,1)
#     plt.imshow(_img[0], cmap='gray')
#     plt.subplot(2,2,2)
# #     plt.imshow(attack_target_maps[0].data.cpu(), cmap='gray')
#     plt.subplot(2,2,3)
#     plt.imshow(_img[0].data.cpu().numpy(), cmap='gray')
#     plt.subplot(2,2,4)
#     plt.imshow(img_fused[0])
#     plt.show()

## attack EXPL

In [None]:
# Turn on the beta update scheduler (start_beta -> end_beta with beta_growth) and
# set the remaining explanation-attack hyper-parameters.
att_cfgs.update({
    'start_beta': 10,      # previously None
    'end_beta': 100,       # previously None
    'beta_growth': True,   # previously False
    'num_iter': 200,
    'lr': 2e-4,
    'epsilon': 0.06,
    'prefactors': [1e14, 1e4],
})
cfgs['attack'] = att_cfgs

# One-hot target vectors keyed by class index as a string ('0' .. '4').
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# Target explanation map: zeros except a 60x60 box (rows 130-189, cols 380-439)
# whose entries are 1/3600, i.e. the box sums to 1 per leading-axis slice.
attack_target_maps = torch.zeros_like(img).detach_()
attack_target_maps[:, 130:190, 380:440] = 1.0 / (60 * 60)
attack_target_maps = attack_target_maps.to(torch.device(cfgs['device'])).unsqueeze(0)

In [None]:
def _build_src_models():
    """Instantiate one explanation model per entry of cfgs['model']['src_models']."""
    return [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]


# For each class ii: attack every validation batch whose (first) label lacks class ii
# so its explanation matches attack_target_maps, then save images, heat maps,
# predictions and labels for the original and adversarial inputs.
for ii in np.arange(5):

    org_imgs = []
    adv_imgs = []
    org_labels = []
    source_org_hm = [[] for _ in range(len(cfgs['model']['src_models']))]
    source_org_pred = [[] for _ in range(len(cfgs['model']['src_models']))]
    source_adv_hm = [[] for _ in range(len(cfgs['model']['src_models']))]
    source_adv_pred = [[] for _ in range(len(cfgs['model']['src_models']))]
    iteration = 0
    print('---------------------- Class %d ---------------------' % ii)
    for imgs, labels, idx in valid_dataloader:
        print('====== Iteration %d ======' % iteration)
        iteration += 1
        class_of_interest = label_dict[str(ii)]

        # Release models left over from an earlier cell before allocating new ones.
        try:
            del source_exp_models
        except NameError:
            pass
        torch.cuda.empty_cache()

        org_images = imgs.detach()
        org_labels.append(labels)

        # --- get adversarial images ---
        # NOTE(review): only labels[0] is inspected, so the whole batch is attacked
        # (or skipped) based on its first sample; this assumes batch_size == 1.
        if labels[0][ii] == 1:
            print('Have such class:{}, labels:{}'.format(ii, labels))
            _advs = org_images
            flag = 0
        else:
            print('No such class:{}, labels:{}, target:{}'.format(ii, labels, class_of_interest))
            # Models are only needed for the attack; unattacked batches skip building them.
            source_exp_models = _build_src_models()
            _advs = attack_expl_T(
                                    source_exp_models,
                                    imgs.to(torch.device(cfgs['device'])),
                                    class_of_interest,
                                    attack_target_maps,
                                    cfgs,
                                    attack_steps=200,   # mirrors att_cfgs['num_iter']
                                    epsilon=0.06,       # mirrors att_cfgs['epsilon']
                                    vis_iter=49
                                )
            del source_exp_models
            flag = 1

        # --- eval ---
        adv_imgs.append(_advs.data.cpu())
        org_imgs.append(org_images.data.cpu())
        # Evaluate with freshly built models, as before.
        source_exp_models = _build_src_models()

        _advs = _advs.detach()
        torch.cuda.empty_cache()

        ### original
        s_org = [m.cal_exp_map(org_images.cuda(), class_of_interest) for m in source_exp_models]
        for k, res in enumerate(s_org):
            source_org_hm[k].append(res[0].data.cpu())
            source_org_pred[k].append(res[1].data.cpu())

        ### adv
        if flag == 0:
            # Not attacked: the "adversarial" outputs are the clean ones, recorded for
            # every source model (previously only the last index from the loop above).
            s_adv = s_org
            del s_org
        else:
            del s_org  # free device memory before the adversarial pass
            s_adv = [m.cal_exp_map(_advs.detach().cuda(), class_of_interest) for m in source_exp_models]
        for k, res in enumerate(s_adv):
            source_adv_hm[k].append(res[0].data.cpu())
            source_adv_pred[k].append(res[1].data.cpu())
        del s_adv, _advs, source_exp_models
        torch.cuda.empty_cache()

    # --- save npy images ---
    adv_imgs = torch.cat(adv_imgs, dim=0)
    print(adv_imgs.shape)
    org_imgs = torch.cat(org_imgs, dim=0)
    print(org_imgs.shape)

    # Only the first source model's heat maps / predictions are saved.
    source_org_hm = torch.cat(source_org_hm[0], dim=0).unsqueeze_(1)
    print(source_org_hm.shape)
    source_org_pred = torch.cat(source_org_pred[0], dim=0)
    print(source_org_pred.shape)

    source_adv_hm = torch.cat(source_adv_hm[0], dim=0).unsqueeze_(1)
    print(source_adv_hm.shape)
    source_adv_pred = torch.cat(source_adv_pred[0], dim=0)
    print(source_adv_pred.shape)

    org_labels = torch.cat(org_labels, dim=0)
    print('original label size: ', org_labels.size())

    # --- save image cubes ---
    cfgs['output_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl_inv'
    save_dir = os.path.join(cfgs['output_dir'],
                            model_cfgs['src_models'][0]['model_info']['model_name'] + '_' +
                            model_cfgs['src_models'][0]['exp_method'] +
                            '_class_' + str(ii))
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, 'org_imgs.npy'), org_imgs.numpy())
    np.save(os.path.join(save_dir, 'adv_imgs.npy'), adv_imgs.numpy())

    np.save(os.path.join(save_dir, 'org_hm.npy'), source_org_hm.numpy())
    np.save(os.path.join(save_dir, 'adv_hm.npy'), source_adv_hm.numpy())

    np.save(os.path.join(save_dir, 'org_labels.npy'), org_labels.numpy())
    np.save(os.path.join(save_dir, 'org_pred.npy'), source_org_pred.numpy())
    np.save(os.path.join(save_dir, 'adv_pred.npy'), source_adv_pred.numpy())


### Utils

In [None]:
def compute_metrics(outputs, targets):
    """Compute per-class ROC and precision-recall statistics.

    Args:
        outputs: 2-D array-like of scores; column ``i`` holds the scores for class ``i``.
        targets: 2-D array-like of ground-truth labels with the same column layout.

    Returns:
        dict with keys 'fpr', 'tpr', 'aucs', 'precision', 'recall'. Each value is a
        dict keyed by class index. The curve entries are converted to plain Python
        lists; 'aucs' holds the area under each class's ROC curve.
    """
    metrics = {name: {} for name in ('fpr', 'tpr', 'aucs', 'precision', 'recall')}
    for cls in range(outputs.shape[1]):
        y_true = targets[:, cls]
        y_score = outputs[:, cls]

        roc_fpr, roc_tpr, _ = roc_curve(y_true, y_score)
        metrics['aucs'][cls] = auc(roc_fpr, roc_tpr)
        pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_score)

        # Store curves as lists (e.g. for JSON-friendly export).
        metrics['fpr'][cls] = roc_fpr.tolist()
        metrics['tpr'][cls] = roc_tpr.tolist()
        metrics['precision'][cls] = pr_precision.tolist()
        metrics['recall'][cls] = pr_recall.tolist()
    return metrics

In [None]:
def eval_img_similarity(imgs_1, imgs_2):
    """Pairwise similarity between two stacks of single-channel images.

    Each stack is rescaled with ``batch_img_norm`` and multiplied by 255 before the
    metrics are computed, so SSIM is evaluated with ``data_range=255``.

    Args:
        imgs_1, imgs_2: 3-D arrays. Axis 0 indexes images, and ``imgs_1[i]`` is
            compared with ``imgs_2[i]``.

    Returns:
        dict with keys
          'ssim': np.ndarray of per-pair SSIM scores,
          'mse':  scalar, NaN-ignoring mean squared difference after dividing the
                  x255 scaling back out,
          'pcc':  np.ndarray of per-pair Pearson correlation coefficients.
        A pair whose normalised images contain NaN is skipped for SSIM/PCC, and its
        index is printed. Those arrays may therefore be shorter than the input stack.
    """
    assert len(imgs_1.shape) == len(imgs_2.shape) == 3
    imgs_1 = batch_img_norm(imgs_1) * 255
    imgs_2 = batch_img_norm(imgs_2) * 255

    ssim_vals, pcc_vals = [], []
    for i in range(imgs_1.shape[0]):
        _img1, _img2 = imgs_1[i], imgs_2[i]
        # A NaN anywhere makes the mean NaN; such pairs cannot be scored.
        if np.isnan(_img1.mean()) or np.isnan(_img2.mean()):
            print(i)
            continue

        # SSIM: only the scalar mean score is used, so the full SSIM map is not
        # requested. The deprecated ``multichannel`` kwarg is omitted; 2-D input
        # is treated as grayscale by default in both old and new scikit-image.
        ssim_vals.append(ssim(_img1, _img2, data_range=255))

        # PCC over the flattened pixels
        pcc_vals.append(pearsonr(_img1.reshape(-1), _img2.reshape(-1))[0])

    # MSE: mean over the last axis, then a NaN-ignoring mean over the rest.
    mse_vals = np.nanmean(((imgs_1 / 255 - imgs_2 / 255) ** 2).mean(-1))
    return {
        'ssim': np.asarray(ssim_vals),
        'mse': mse_vals,
        'pcc': np.asarray(pcc_vals),
    }

In [None]:
# def eval_auc_sim(model_name, xai_name):
def load_all(model_name, xai_name, ii_specific=None, num_classes=5):
    """Load the arrays saved by an attack run for one model / explanation method.

    Files are read from ``cfgs['output_dir']/<model_name>_<xai_name>_class_<k>/``:
    org_imgs.npy, adv_imgs.npy, org_hm.npy, adv_hm.npy, org_pred.npy, adv_pred.npy
    and org_labels.npy.

    Args:
        model_name: model part of the directory name. It is now used for both the
            single-class and the all-classes path; the single-class path previously
            read it from the global ``model_cfgs``.
        xai_name: explanation-method part of the directory name.
        ii_specific: class index to load. With ``None``, classes
            ``0 .. num_classes-1`` are loaded and concatenated along axis 0.
        num_classes: number of per-class directories read when ``ii_specific`` is None.

    Returns:
        ([img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label])
    """
    stems = ('org_imgs', 'adv_imgs', 'org_hm', 'adv_hm', 'org_pred', 'adv_pred', 'org_labels')
    class_ids = np.arange(num_classes) if ii_specific is None else [ii_specific]

    loaded = {stem: [] for stem in stems}
    for ii in class_ids:
        save_dir = os.path.join(cfgs['output_dir'],
                                model_name + '_' + xai_name + '_class_' + str(ii))
        for stem in stems:
            loaded[stem].append(np.load(os.path.join(save_dir, stem + '.npy')))

    if ii_specific is None:
        # Stack the per-class runs along the sample axis.
        arrays = {stem: np.concatenate(parts, 0) for stem, parts in loaded.items()}
    else:
        arrays = {stem: parts[0] for stem, parts in loaded.items()}

    return ([arrays['org_imgs'], arrays['adv_imgs']],
            [arrays['org_hm'], arrays['adv_hm']],
            [arrays['org_pred'], arrays['adv_pred'], arrays['org_labels']])

In [None]:
def eval_all_metrics(model_name, xai_name, ii_specific=None):
    """Print and return AUC / heat-map / image similarity metrics for one attack run.

    Loads the saved arrays through ``load_all`` and reports:
      * AUCs of the original and of the adversarial predictions,
      * similarity between original and adversarial heat maps,
      * similarity between adversarial heat maps and the fixed attack target
        (rows 130:190, cols 380:440, each pixel 1/3600),
      * similarity between original and adversarial images.

    Heat-map metrics use only the images whose label for ``ii_specific`` is 1.
    When ``ii_specific`` is None every loaded image is used instead.

    Args:
        model_name: model identifier forwarded to ``load_all``.
        xai_name: explanation-method identifier forwarded to ``load_all``.
        ii_specific: class index to evaluate, or None for all classes.

    Returns:
        ``(auc_org, auc_att, ssim_hp_org, mse_hp_org, ssim_hp_att, mse_hp_att,
        ssim_im, mse_im)``. The two AUC entries are the values for
        ``ii_specific``, or the full per-class AUC collections when it is None.
    """
    img_, hm_, pred_ = load_all(model_name, xai_name, ii_specific)
    labels = pred_[2]

    if ii_specific is None:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')
        # labels[:, None] would insert an axis rather than pick a column
        # (and repeat multi-label rows), so select every image explicitly.
        att_index = np.arange(labels.shape[0])
    else:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))
        # Only images positive for the attacked class were attacked.
        att_index = np.where(labels[:, ii_specific] == 1)[0]

    print('=== Org AUC ===')
    auc_org = compute_metrics(pred_[0], labels)['aucs']
    print(auc_org)

    print('=== Attack AUC ===')
    auc_att = compute_metrics(pred_[1], labels)['aucs']
    print(auc_att)

    def _report(res):
        """Print the NaN-ignoring mean of every metric in ``res``; return (ssim, mse)."""
        means = {}
        for _k in res.keys():
            means[_k] = np.nanmean(res[_k])
            print(_k, means[_k])
        return means['ssim'], means['mse']

    hm_org_sel = hm_[0][att_index, 0, :, :]
    hm_att_sel = hm_[1][att_index, 0, :, :]

    print('=== Heat Map Org===')
    ssim_hp_org, mse_hp_org = _report(eval_img_similarity(hm_org_sel, hm_att_sel))

    print('=== Heat Map Target===')
    # Uniform mass over a 60x60 patch (60 * 60 = 3600 pixels): the attack target.
    attack_target_maps = np.zeros_like(hm_org_sel)
    attack_target_maps[:, 130:190, 380:440] = 1.0/3600
    ssim_hp_att, mse_hp_att = _report(eval_img_similarity(hm_att_sel, attack_target_maps))

    print('=== Image ===')
    ssim_im, mse_im = _report(eval_img_similarity(img_[0][:, 0, :, :], img_[1][:, 0, :, :]))

    if ii_specific is not None:
        auc_org, auc_att = auc_org[ii_specific], auc_att[ii_specific]
    return auc_org, auc_att, ssim_hp_org, mse_hp_org, ssim_hp_att, mse_hp_att, ssim_im, mse_im

### load and eval

In [None]:
# Class-averaged metrics for every (model, explanation method) pair of the
# explanation attack.
cfgs['output_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl'
num_classes = len(valid_dataloader.dataset.attr_names)
xai_names = ['VanillaBP', 'VanillaBP_Img', 'IntegratedBP', 'GradCAM', 'SmoothBP']
model_name_list = ['resnet152', 'densenet121']

# Names of the eight values returned by eval_all_metrics, in return order.
metric_keys = ('AUC_org', 'AUC_att', 'SSIM_hp_org', 'MSE_hp_org',
               'SSIM_hp_att', 'MSE_hp_att', 'SSIM_im', 'MSE_im')
# Order in which the averaged values are reported.
print_order = ('AUC_org', 'AUC_att', 'SSIM_hp_org', 'SSIM_hp_att', 'SSIM_im',
               'MSE_hp_org', 'MSE_hp_att', 'MSE_im')

for model_name in model_name_list:
    print('-------------------------------- {} --------------------------------'.format(model_name))
    for xai_name in xai_names:
        print('-------------------- {} ----------------------'.format(xai_name))
        averages = dict.fromkeys(metric_keys, 0)
        for ii in np.arange(num_classes):
            per_class = eval_all_metrics(model_name, xai_name, ii)
            for key, value in zip(metric_keys, per_class):
                averages[key] += value/num_classes
        print('============= averaged results of {} ============= '.format(xai_name))
        for key in print_order:
            print(key + ': ', averages[key])

### Compute and save heatmap of G-GradCAM

In [None]:
# Root folder holding the saved explanation-attack results.
cfgs.update(output_dir='/data/users/Attack_Attn/save_dir/NatComm/attack_expl')

In [None]:
# One-hot target vector per class, keyed by the class index as a string.
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# Build the DenseNet-121 explanation model with Guided Grad-CAM.
dense121_cfgs = {
    'model_info': {
        'model_name': 'densenet121',
        'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
    },
    'exp_method': 'GuidedGradCAM',
    'exp_cfgs': {
        'target_layer': ['features', 'denseblock4', 'denselayer14'],
    },
}

model_cfgs = {
    'src_models': [dense121_cfgs],
    'tgt_models': [],
    'pretrained': False,
}
cfgs['model'] = model_cfgs
source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]

# The attacked images re-explained below come from the GradCAM attack run.
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GuidedBP', 'IntegratedBP', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name = 'densenet121'
img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='GradCAM', ii_specific=1)
# Shapes of the original and of the adversarial image stacks.
for stack in img_[:2]:
    print(stack.shape)

In [None]:
# Recompute Guided Grad-CAM maps for the original and the adversarial images of
# the GradCAM attack, one class at a time, and store them as .npy files.
exp_model_cfg = model_cfgs['src_models'][0]
for ii in np.arange(5):  # one pass per class
    print('==================== calss number: {} ===================='.format(ii))
    hm_org, hm_att = [], []
    img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='GradCAM', ii_specific=ii)
    image_num, _, _, _ = img_[0].shape
    run_name = '{}_{}_class_{}'.format(exp_model_cfg['model_info']['model_name'],
                                       exp_model_cfg['exp_method'], ii)
    save_dir = os.path.join(cfgs['output_dir'], run_name)
    class_of_interest = label_dict[str(ii)]
    os.makedirs(save_dir, exist_ok=True)

    for img_index in np.arange(image_num):
        print('---- img_index: {} ----'.format(img_index))
        img_org = torch.Tensor(img_[0][img_index]).clone().detach()
        img_adv = torch.Tensor(img_[1][img_index]).clone().detach()
        res_org = [m.cal_exp_map(img_org.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                   for m in source_exp_models]
        res_adv = [m.cal_exp_map(img_adv.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                   for m in source_exp_models]

        # Keep only the first model's map, moved to host memory.
        hm_org.append(res_org[0][0].data.cpu())
        hm_att.append(res_adv[0][0].data.cpu())
        del res_org, res_adv
        torch.cuda.empty_cache()

    # Stack into (N, 1, ...) and write to disk.
    hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
    hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
    print(hm_org.shape)

    for file_name, maps in (('hm_org.npy', hm_org), ('hm_att.npy', hm_att)):
        np.save(os.path.join(save_dir, file_name), maps.numpy())

### Compute and save heatmap of XRAI

In [None]:
# One-hot target vector per class, keyed by the class index as a string.
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# Build the DenseNet-121 explanation model with XRAI (no target layer needed).
dense121_cfgs = {
    'model_info': {
        'model_name': 'densenet121',
        'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
    },
    'exp_method': 'XRAI',
    'exp_cfgs': {},
}

model_cfgs = {
    'src_models': [dense121_cfgs],
    'tgt_models': [],
    'pretrained': False,
}
cfgs['model'] = model_cfgs
source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]

# The attacked images re-explained below come from the IntegratedBP attack run.
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GuidedBP', 'IntegratedBP', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name = 'densenet121'
img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='IntegratedBP', ii_specific=1)
# Shapes of the original and of the adversarial image stacks.
for stack in img_[:2]:
    print(stack.shape)

In [None]:
# Recompute XRAI maps for the original and the adversarial images of the
# IntegratedBP attack, one class at a time, and store them as .npy files.
exp_model_cfg = model_cfgs['src_models'][0]
for ii in np.arange(5):  # one pass per class
    print('==================== calss number: {} ===================='.format(ii))
    hm_org, hm_att = [], []
    img_, hm_, pred_ = load_all(model_name='densenet121', xai_name='IntegratedBP', ii_specific=ii)
    image_num, _, _, _ = img_[0].shape
    run_name = '{}_{}_class_{}'.format(exp_model_cfg['model_info']['model_name'],
                                       exp_model_cfg['exp_method'], ii)
    save_dir = os.path.join(cfgs['output_dir'], run_name)
    class_of_interest = label_dict[str(ii)]
    os.makedirs(save_dir, exist_ok=True)

    for img_index in np.arange(image_num):
        print('---- img_index: {} ----'.format(img_index))
        img_org = torch.Tensor(img_[0][img_index]).clone().detach()
        img_adv = torch.Tensor(img_[1][img_index]).clone().detach()
        res_org = [m.cal_exp_map(img_org.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                   for m in source_exp_models]
        res_adv = [m.cal_exp_map(img_adv.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                   for m in source_exp_models]

        # XRAI output is wrapped with torch.Tensor (presumably not already a
        # tensor) and given a leading batch axis.
        hm_org.append(torch.Tensor(res_org[0][0]).unsqueeze_(0))
        hm_att.append(torch.Tensor(res_adv[0][0]).unsqueeze_(0))
        del res_org, res_adv
        torch.cuda.empty_cache()

    # Stack into (N, 1, ...) and write to disk.
    hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
    hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
    print(hm_org.shape)

    for file_name, maps in (('hm_org.npy', hm_org), ('hm_att.npy', hm_att)):
        np.save(os.path.join(save_dir, file_name), maps.numpy())

### Evaluate Guided Grad-CAM and XRAI

In [None]:
def load_all__(model_name, xai_name, ii_specific=None, output_dir=None, num_classes=5):
    """Load the recomputed original / adversarial heat maps of one run.

    Files are read from ``<output_dir>/<model_name>_<xai_name>_class_<i>/``
    (``hm_org.npy`` and ``hm_att.npy``).

    Args:
        model_name: model identifier used in the directory name.
        xai_name: explanation-method identifier used in the directory name.
        ii_specific: class index to load, or None to concatenate all classes
            along axis 0.
        output_dir: root directory; defaults to ``cfgs['output_dir']``.
        num_classes: number of class folders read when ``ii_specific`` is None.

    Returns:
        ``[hm_org, hm_att]`` as numpy arrays.
    """
    if output_dir is None:
        output_dir = cfgs['output_dir']

    def _class_dir(class_idx):
        # Use the model_name argument (not a global config) so callers can
        # load any model's maps.
        return os.path.join(output_dir, model_name + '_' + xai_name + '_class_' + str(class_idx))

    def _load(save_dir):
        return (np.load(os.path.join(save_dir, 'hm_org.npy')),
                np.load(os.path.join(save_dir, 'hm_att.npy')))

    if ii_specific is None:
        per_class = [_load(_class_dir(ii)) for ii in range(num_classes)]
        hm_org = np.concatenate([pair[0] for pair in per_class], 0)
        hm_att = np.concatenate([pair[1] for pair in per_class], 0)
    else:
        hm_org, hm_att = _load(_class_dir(ii_specific))

    return [hm_org, hm_att]

In [None]:
def eval_all_metrics__(model_name, xai_name, ii_specific=None):
    """Print heat-map similarity metrics for recomputed Guided Grad-CAM / XRAI maps.

    The maps come from ``load_all__``. The labels used to choose the attacked
    images come from the attack run whose images were re-explained:
    'GradCAM' for 'GuidedGradCAM' and 'IntegratedBP' for 'XRAI'.

    Args:
        model_name: model identifier forwarded to the loaders.
        xai_name: 'GuidedGradCAM' or 'XRAI'.
        ii_specific: class index, or None to use every loaded image.

    Raises:
        ValueError: if ``xai_name`` has no known source attack run.
    """
    # Attack run whose saved images / labels the recomputed maps were derived from.
    source_runs = {'GuidedGradCAM': 'GradCAM', 'XRAI': 'IntegratedBP'}
    if xai_name not in source_runs:
        raise ValueError('eval_all_metrics__ supports {}, got {!r}'.format(sorted(source_runs), xai_name))

    hm_ = load_all__(model_name, xai_name, ii_specific)
    _, _, pred_ = load_all(model_name, source_runs[xai_name], ii_specific)
    labels = pred_[2]

    if ii_specific is None:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')
        # labels[:, None] would insert an axis rather than pick a column.
        att_index = np.arange(labels.shape[0])
    else:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))
        att_index = np.where(labels[:, ii_specific] == 1)[0]

    hm_org_sel = hm_[0][att_index, 0, :, :]
    hm_att_sel = hm_[1][att_index, 0, :, :]

    print('=== Heat Map Org===')
    res = eval_img_similarity(hm_org_sel, hm_att_sel)
    for _k in res.keys():
        print(_k, np.nanmean(res[_k]))

    print('=== Heat Map Target===')
    # Uniform mass over a 60x60 patch (60 * 60 = 3600 pixels): the attack target.
    attack_target_maps = np.zeros_like(hm_org_sel)
    attack_target_maps[:, 130:190, 380:440] = 1.0/3600
    res = eval_img_similarity(hm_att_sel, attack_target_maps)
    for _k in res.keys():
        print(_k, np.nanmean(res[_k]))

In [None]:
# Per-class evaluation of the recomputed Guided Grad-CAM and XRAI maps.
model_name = 'densenet121'
xai_names = ['GuidedGradCAM', 'XRAI']
for xai_name in xai_names:
    print('-------------------------------- {} --------------------------------'.format(xai_name))
    for ii in np.arange(5):
        eval_all_metrics__(model_name, xai_name, ii)

In [None]:
# Inspect a single class: labels from the IntegratedBP run, maps from XRAI.
ii_specific_ = 4
pred_ = load_all(model_name='densenet121', xai_name='IntegratedBP', ii_specific=ii_specific_)[2]
hm_ = load_all__(model_name='densenet121', xai_name='XRAI', ii_specific=ii_specific_)

In [None]:
np.shape(pred_[2])  # shape of the label array

In [None]:
# Rows whose label for the selected class is positive.
index = np.flatnonzero(pred_[2][:, ii_specific_] == 1)
partial = hm_[0][index, 0]
print(partial.shape)
print(index)

In [None]:
hm_org, hm_att = hm_  # original / adversarial heat maps
print(hm_org[0, 0].shape, hm_att.shape)

In [None]:
res = eval_img_similarity(hm_org[index, 0, :, :], hm_att[index, 0, :, :])
for metric_name, metric_values in res.items():
    print(metric_name, np.nanmean(metric_values))

In [None]:
kk = 17  # position within `index` of the example to display
plt.imshow(hm_org[index[kk]][0])

In [None]:
plt.imshow(hm_att[index[kk]][0])  # adversarial map of the same example

## Compute PSC (Pearson correlation between prediction JSD and heat-map JSD)

### Utils

In [None]:
def load_all(output_dir, model_name, xai_name, ii_specific=None, num_classes=5):
    """Load the saved images, heat maps and predictions of one attack run.

    Files are read from ``<output_dir>/<model_name>_<xai_name>_class_<i>/``.

    Args:
        output_dir: root directory of the saved runs.
        model_name: model identifier used in the directory name.
        xai_name: explanation-method identifier used in the directory name.
        ii_specific: class index to load, or None to concatenate every class
            folder along axis 0.
        num_classes: number of class folders read when ``ii_specific`` is None.

    Returns:
        ``[img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]``.
    """
    # File names in the order they are returned: images, heat maps, predictions.
    file_names = ('org_imgs.npy', 'adv_imgs.npy',
                  'org_hm.npy', 'adv_hm.npy',
                  'org_pred.npy', 'adv_pred.npy', 'org_labels.npy')

    def _class_dir(class_idx):
        return os.path.join(output_dir, model_name + '_' + xai_name + '_class_' + str(class_idx))

    def _load_dir(save_dir):
        return [np.load(os.path.join(save_dir, name)) for name in file_names]

    if ii_specific is None:
        per_class = [_load_dir(_class_dir(ii)) for ii in range(num_classes)]
        # Stack each file type across classes along the sample axis.
        arrays = [np.concatenate(parts, 0) for parts in zip(*per_class)]
    else:
        arrays = _load_dir(_class_dir(ii_specific))

    img_org, img_att, hm_org, hm_att, pred_org, pred_att, label = arrays
    return [img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]

In [None]:
def eval_all_metrics(model_name, xai_name, ii_specific=None, output_dir=None):
    """Print AUC, similarity and JSD metrics for one saved attack run.

    Args:
        model_name: model identifier of the run directory.
        xai_name: explanation-method identifier of the run directory.
        ii_specific: class index, or None to pool all class folders.
        output_dir: root of the saved runs; defaults to ``cfgs['output_dir']``.
    """
    if output_dir is None:
        output_dir = cfgs['output_dir']
    # load_all in this section takes output_dir as its first positional
    # argument; omitting it would shift model_name into output_dir.
    img_, hm_, pred_ = load_all(output_dir, model_name, xai_name, ii_specific)
    if ii_specific is None:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')
    else:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))
    print('=== Org AUC ===')
    print(compute_metrics(pred_[0], pred_[2])['aucs'])
    print('=== Attack AUC ===')
    print(compute_metrics(pred_[1], pred_[2])['aucs'])

    print('=== Heat Map ===')
    res = eval_img_similarity(hm_[0][:, 0, :, :], hm_[1][:, 0, :, :])
    for _k in res.keys():
        print(_k, np.nanmean(res[_k]))

    print('=== Image ===')
    res = eval_img_similarity(img_[0][:, 0, :, :], img_[1][:, 0, :, :])
    for _k in res.keys():
        print(_k, np.nanmean(res[_k]))

    print('=== JSD-PRED ===')
    jsd = compute_jsd(pred_[0], pred_[1], ii_specific)
    print(jsd.mean())
    print('=== JSD-HM ===')
    jsd = compute_jsd(hm_[0], hm_[1], ii_specific)
    print(jsd.mean())

In [None]:
def compute_jsd(pred_org, pred_adv, class_of_interest):
    """Per-sample Jensen-Shannon divergence between original and adversarial outputs.

    Two kinds of input are handled:
      * maps (3-D or more, e.g. (N, 1, H, W) heat maps): each sample is
        flattened and its values are used directly as a distribution. No
        normalisation is done here, so callers should pass non-negative maps
        that sum to 1; negative entries give NaN.
      * logits (2-D, (N, C)): the sigmoid probability p of column
        ``class_of_interest`` becomes the two-point distribution [p, 1 - p].

    The mixture is clamped to [1e-7, 1] before the log to avoid log(0).

    Returns:
        Tensor with one divergence (natural-log units) per sample when
        ``class_of_interest`` is an integer index.
    """
    pred_org, pred_adv = torch.Tensor(pred_org), torch.Tensor(pred_adv)
    if pred_org.dim() > 2:  # map data of any spatial rank
        prob_org = torch.flatten(pred_org, start_dim=1)
        prob_adv = torch.flatten(pred_adv, start_dim=1)
    else:  # logits data
        p_org = torch.sigmoid(pred_org)[:, class_of_interest].unsqueeze(1)
        p_adv = torch.sigmoid(pred_adv)[:, class_of_interest].unsqueeze(1)
        prob_org = torch.cat([p_org, 1 - p_org], dim=1)
        prob_adv = torch.cat([p_adv, 1 - p_adv], dim=1)

    log_mix = torch.clamp((prob_org + prob_adv) / 2., 1e-7, 1).log()
    kld_org = F.kl_div(log_mix, prob_org, reduction="none").sum(1)
    kld_adv = F.kl_div(log_mix, prob_adv, reduction="none").sum(1)
    return (kld_org + kld_adv) * 0.5

In [None]:
### load pred attack
# PSC: Pearson correlation between the per-image JSD of the predictions and the
# per-image JSD of the heat maps, averaged over classes.
num_classes = len(valid_dataloader.dataset.attr_names)
model_list = ['densenet121', 'resnet152']
xai_list = ['VanillaBP', 'VanillaBP_Img', 'IntegratedBP', 'GradCAM', 'SmoothBP']
pred_att_dir = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred'
expl_att_dir = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl'


def _class_psc(pred_triplet, map_pair, class_index, rows=None):
    """Pearson r between prediction JSD and heat-map JSD for one class.

    NaN heat-map JSD entries are replaced by the mean of the remaining ones.
    ``rows`` optionally restricts the computation to a subset of samples.
    """
    org_pred, adv_pred = pred_triplet[0], pred_triplet[1]
    org_map, adv_map = map_pair[0], map_pair[1]
    if rows is not None:
        org_pred, adv_pred = org_pred[rows, :], adv_pred[rows, :]
        org_map, adv_map = org_map[rows, :], adv_map[rows, :]
    pred_jsd = compute_jsd(pred_org=org_pred, pred_adv=adv_pred, class_of_interest=class_index)
    map_jsd = compute_jsd(pred_org=org_map, pred_adv=adv_map, class_of_interest=class_index)
    pred_jsd_np = np.asarray(pred_jsd.numpy())
    map_jsd_np = np.asarray(map_jsd.numpy())
    map_jsd_np[np.isnan(map_jsd_np)] = np.nanmean(map_jsd_np)
    return pearsonr(pred_jsd_np, map_jsd_np)[0]


for model_ in model_list:
    print('++++++++++++++++ Model: {} ++++++++++++++++'.format(model_))
    for xai_ in xai_list:
        print('========= {} ========='.format(xai_))
        psc_pred, psc_expl = 0, 0
        for class_index in np.arange(num_classes):
            # Prediction attack: every saved image takes part.
            _, pred_hm_, pred_lab_ = load_all(output_dir=pred_att_dir,
                                              model_name=model_,
                                              xai_name=xai_,
                                              ii_specific=class_index)
            psc_pred += _class_psc(pred_lab_, pred_hm_, class_index)/num_classes

            # Explanation attack: only images positive for the class were attacked.
            _, expl_hm_, expl_lab_ = load_all(output_dir=expl_att_dir,
                                              model_name=model_,
                                              xai_name=xai_,
                                              ii_specific=class_index)
            att_index = np.where(expl_lab_[2][:, class_index] == 1)[0]
            psc_expl += _class_psc(expl_lab_, expl_hm_, class_index, rows=att_index)/num_classes
        print('pred psc: {},   expl psc: {}'.format(psc_pred, psc_expl))

# Cross XAI transfer

## PRED

In [None]:
def load_all(output_dir, model_name, xai_name, ii_specific=None, num_classes=5):
    """Load the saved images, heat maps and predictions of one attack run.

    Files are read from ``<output_dir>/<model_name>_<xai_name>_class_<i>/``.

    Args:
        output_dir: root directory of the saved runs.
        model_name: model identifier used in the directory name.
        xai_name: explanation-method identifier used in the directory name.
        ii_specific: class index to load, or None to concatenate every class
            folder along axis 0.
        num_classes: number of class folders read when ``ii_specific`` is None.

    Returns:
        ``[img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]``.
    """
    # File names in the order they are returned: images, heat maps, predictions.
    file_names = ('org_imgs.npy', 'adv_imgs.npy',
                  'org_hm.npy', 'adv_hm.npy',
                  'org_pred.npy', 'adv_pred.npy', 'org_labels.npy')

    def _class_dir(class_idx):
        return os.path.join(output_dir, model_name + '_' + xai_name + '_class_' + str(class_idx))

    def _load_dir(save_dir):
        return [np.load(os.path.join(save_dir, name)) for name in file_names]

    if ii_specific is None:
        per_class = [_load_dir(_class_dir(ii)) for ii in range(num_classes)]
        # Stack each file type across classes along the sample axis.
        arrays = [np.concatenate(parts, 0) for parts in zip(*per_class)]
    else:
        arrays = _load_dir(_class_dir(ii_specific))

    img_org, img_att, hm_org, hm_att, pred_org, pred_att, label = arrays
    return [img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]

In [None]:
# Explanation methods recomputed on the images attacked through `src_xai`.
xai_namelist = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
src_xai = 'IntegratedBP'
model_name = 'resnet152'  # alternative: 'densenet121'

# One-hot target vector per class, keyed by the class index as a string.
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

# Source attack results are read from data_dir; recomputed maps go to save_dir.
cfgs.update({
    'data_dir': '/data/users/Attack_Attn/save_dir/NatComm/attack_pred/',
    'save_dir': '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_pred/',
})


In [None]:

# For every target explanation method, recompute its maps on the original and
# adversarial images of the `src_xai` attack and save them per class.
for xai_name in xai_namelist:
    # Only the CAM-based methods need a target layer in exp_cfgs.
    needs_layer = xai_name in ['GradCAM', 'GuidedGradCAM']
    if model_name == 'densenet121':
        dense121_cfgs = {
            'model_info': {
                'model_name': model_name,
                'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
            },
            'exp_method': xai_name,
            'exp_cfgs': ({'target_layer': ['features', 'denseblock4', 'denselayer14']}
                         if needs_layer else {}),
        }
        model_cfgs = {'src_models': [dense121_cfgs], 'pretrained': False}
        cfgs['model'] = model_cfgs
    elif model_name == 'resnet152':
        res152_cfgs = {
            'model_info': {
                'model_name': model_name,
                'model_path': '/home/Attack_Attn/ChestXpert/save_dir_multimodel/resnet152_5c/best_checkpoints/checkpoint_5.pt',
            },
            'exp_method': xai_name,
            'exp_cfgs': {'target_layer': ['layer4', '2']} if needs_layer else {},
        }
        model_cfgs = {'src_models': [res152_cfgs], 'pretrained': False}
        cfgs['model'] = model_cfgs

    source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in cfgs['model']['src_models']]

    for ii in np.arange(5):  # one pass per class
        print('============== calss number: {} ==============='.format(ii))
        # Load the attacked images from the source-method folder.
        hm_org, hm_att = [], []
        img_, _, _ = load_all(output_dir=cfgs['data_dir'],
                              model_name=model_cfgs['src_models'][0]['model_info']['model_name'],
                              xai_name=src_xai,
                              ii_specific=ii)

        image_num, _, _, _ = img_[0].shape
        run_name = '{}_{}_class_{}'.format(model_cfgs['src_models'][0]['model_info']['model_name'],
                                           xai_name, ii)
        save_dir = os.path.join(cfgs['save_dir'], run_name)
        class_of_interest = label_dict[str(ii)]
        os.makedirs(save_dir, exist_ok=True)

        for img_index in np.arange(image_num):
            print('---- img_index: {} ----'.format(img_index))
            img_org = torch.Tensor(img_[0][img_index]).clone().detach()
            img_adv = torch.Tensor(img_[1][img_index]).clone().detach()

            res_org = [m.cal_exp_map(img_org.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                       for m in source_exp_models]
            res_adv = [m.cal_exp_map(img_adv.to(torch.device(cfgs['device'])).unsqueeze(0), class_of_interest)
                       for m in source_exp_models]

            if xai_name != 'XRAI':
                # Tensor maps: move the first model's result to host memory.
                hm_org.append(res_org[0][0].data.cpu())
                hm_att.append(res_adv[0][0].data.cpu())
            else:
                # XRAI output is wrapped with torch.Tensor (presumably not
                # already a tensor) and given a leading batch axis.
                hm_org.append(torch.Tensor(res_org[0][0]).unsqueeze_(0))
                hm_att.append(torch.Tensor(res_adv[0][0]).unsqueeze_(0))
            del res_org, res_adv
            torch.cuda.empty_cache()

        # Stack into (N, 1, ...) and write to disk.
        hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
        hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
        print(hm_org.shape)

        for file_name, maps in (('hm_org.npy', hm_org), ('hm_att.npy', hm_att)):
            np.save(os.path.join(save_dir, file_name), maps.numpy())

### Eval-PRED

In [None]:
# Root folder holding the cross-XAI heat maps computed on IG-attacked images.
cfgs.update(output_dir='/data/users/Attack_Attn/save_dir/NatComm/ig_attack_pred/')

In [None]:
def compute_jsd(pred_org, pred_adv, class_of_interest):
    """Per-sample Jensen-Shannon divergence between original and adversarial outputs.

    Two kinds of input are handled:
      * maps (3-D or more, e.g. (N, 1, H, W) heat maps): each sample is
        flattened and its values are used directly as a distribution. No
        normalisation is done here, so callers should pass non-negative maps
        that sum to 1; negative entries give NaN.
      * logits (2-D, (N, C)): the sigmoid probability p of column
        ``class_of_interest`` becomes the two-point distribution [p, 1 - p].

    The mixture is clamped to [1e-7, 1] before the log to avoid log(0).

    Returns:
        Tensor with one divergence (natural-log units) per sample when
        ``class_of_interest`` is an integer index.
    """
    pred_org, pred_adv = torch.Tensor(pred_org), torch.Tensor(pred_adv)
    if pred_org.dim() > 2:  # map data of any spatial rank
        prob_org = torch.flatten(pred_org, start_dim=1)
        prob_adv = torch.flatten(pred_adv, start_dim=1)
    else:  # logits data
        p_org = torch.sigmoid(pred_org)[:, class_of_interest].unsqueeze(1)
        p_adv = torch.sigmoid(pred_adv)[:, class_of_interest].unsqueeze(1)
        prob_org = torch.cat([p_org, 1 - p_org], dim=1)
        prob_adv = torch.cat([p_adv, 1 - p_adv], dim=1)

    log_mix = torch.clamp((prob_org + prob_adv) / 2., 1e-7, 1).log()
    kld_org = F.kl_div(log_mix, prob_org, reduction="none").sum(1)
    kld_adv = F.kl_div(log_mix, prob_adv, reduction="none").sum(1)
    return (kld_org + kld_adv) * 0.5


def eval_img_similarity(imgs_1, imgs_2):
    """Compare two stacks of single-channel images pair by pair.

    Both stacks go through ``batch_img_norm`` (project helper; assumed to map
    each image to [0, 1] -- TODO confirm) and are scaled by 255.

    Args:
        imgs_1, imgs_2: 3-D arrays, indexed as (N, H, W).

    Returns:
        dict with
          'ssim': per-pair SSIM (pairs containing NaN are skipped),
          'mse':  scalar mean squared error on the /255 scale, where per-row
                  means containing NaN are ignored,
          'pcc':  per-pair Pearson correlation (same pairs as 'ssim').

    Raises:
        ValueError: if either input is not 3-D.
    """
    if len(imgs_1.shape) != 3 or len(imgs_2.shape) != 3:
        raise ValueError('eval_img_similarity expects 3-D stacks, got shapes {} and {}'.format(
            imgs_1.shape, imgs_2.shape))
    ssim_vals = []
    pcc_vals = []
    imgs_1 = batch_img_norm(imgs_1) * 255
    imgs_2 = batch_img_norm(imgs_2) * 255
    for i in range(imgs_1.shape[0]):
        _img1 = imgs_1[i]
        _img2 = imgs_2[i]
        if np.isnan(_img1.mean()) or np.isnan(_img2.mean()):
            print('eval_img_similarity: skipping pair {} (contains NaN)'.format(i))
            continue

        # SSIM. Grayscale is the default (channel_axis=None), so the deprecated
        # multichannel=False flag is unnecessary.
        _ssim, _ = ssim(_img1, _img2, data_range=255, full=True)
        ssim_vals.append(_ssim)

        # PCC
        _pcc = pearsonr(_img1.reshape(-1), _img2.reshape(-1))[0]
        pcc_vals.append(_pcc)

    # MSE: mean over the last axis, then NaN-ignoring mean over all rows.
    mse_vals = np.nanmean(((imgs_1/255 - imgs_2/255)**2).mean(-1))
    return {
        'ssim': np.asarray(ssim_vals),
        'mse': mse_vals,
        'pcc': np.asarray(pcc_vals),
    }


def eval_all_metrics(output_dir, model_name, xai_name, ii_specific=None):
    """Print heat-map similarity between original and adversarial maps of a run.

    Maps are read with ``load_all_tgt``; every metric from
    ``eval_img_similarity`` is printed as its NaN-ignoring mean.

    Returns:
        ``(ssim, mse)`` averages for the heat maps.
    """
    hm_org, hm_att = load_all_tgt(output_dir, model_name, xai_name, ii_specific)
    print(hm_org.shape, hm_att.shape)
    if ii_specific is not None:
        banner = '+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific)
    else:
        banner = '+++++++++++++++++++++ all images ++++++++++++++++++++'
    print(banner)

    print('=== Heat Map ===')
    similarity = eval_img_similarity(hm_org[:, 0, :, :], hm_att[:, 0, :, :])
    summary = {name: np.nanmean(values) for name, values in similarity.items()}
    for name, value in summary.items():
        print(name, value)
    return summary['ssim'], summary['mse']



def load_all_src(output_dir, model_name, xai_name, ii_specific=None, num_classes=5):
    """Load saved logits and labels for one class folder or for all of them.

    Each folder is ``<output_dir>/<model_name>_<xai_name>_class_<k>`` and
    holds ``org_pred.npy``, ``adv_pred.npy`` and ``org_labels.npy``.

    Args:
        output_dir: root directory of the saved attack results.
        model_name: model prefix of the folder name.
        xai_name: explanation-method part of the folder name.
        ii_specific: class index to load. ``None`` loads classes
            ``0 .. num_classes - 1`` and concatenates them along axis 0.
        num_classes: number of class folders read when ``ii_specific`` is
            ``None`` (default 5).

    Returns:
        (pred_org, pred_att, label) numpy arrays.
    """
    def _load_class(class_idx):
        # Read the three arrays stored for one class folder.
        folder = os.path.join(output_dir,
                              model_name + '_' + xai_name + '_class_' + str(class_idx))
        return (np.load(os.path.join(folder, 'org_pred.npy')),
                np.load(os.path.join(folder, 'adv_pred.npy')),
                np.load(os.path.join(folder, 'org_labels.npy')))

    if ii_specific is not None:
        return _load_class(ii_specific)

    # Accumulators for the per-class arrays (concatenated below).
    pred_org, pred_att, label = [], [], []
    for ii in range(num_classes):
        p_org, p_att, lab = _load_class(ii)
        pred_org.append(p_org)
        pred_att.append(p_att)
        label.append(lab)

    return (np.concatenate(pred_org, 0),
            np.concatenate(pred_att, 0),
            np.concatenate(label, 0))


def load_all_tgt(output_dir, model_name, xai_name, ii_specific=None):
    """Load the saved original/attacked heat maps of one class or of classes 0-4.

    Each folder is ``<output_dir>/<model_name>_<xai_name>_class_<k>`` and
    holds ``hm_org.npy`` and ``hm_att.npy``. With ``ii_specific=None`` the
    five class folders are read in order and concatenated along axis 0.

    Returns:
        (hm_org, hm_att) numpy arrays.
    """
    def _class_dir(class_idx):
        return os.path.join(output_dir,
                            model_name + '_' + xai_name + '_class_' + str(class_idx))

    def _read_pair(folder):
        return (np.load(os.path.join(folder, 'hm_org.npy')),
                np.load(os.path.join(folder, 'hm_att.npy')))

    if ii_specific is not None:
        return _read_pair(_class_dir(ii_specific))

    pairs = [_read_pair(_class_dir(cls)) for cls in np.arange(5)]
    return (np.concatenate([org for org, _ in pairs], 0),
            np.concatenate([att for _, att in pairs], 0))

In [None]:
num_classes = len(valid_dataloader.dataset.attr_names)
model_list = ['resnet152']  # other option: 'densenet121'
xai_list = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
pred_att_dir_src = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred'
pred_att_dir_tgt = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_pred'
xai_src = 'IntegratedBP'

# For each model and target XAI method: per class, correlate (Pearson) the JSD
# between original/attacked predictions read from pred_att_dir_src (method
# xai_src) with the JSD between original/attacked heat maps read from
# pred_att_dir_tgt, then average the correlation over classes.
for model_ in model_list:
    print('++++++++++++++++ Model: {} ++++++++++++++++'.format(model_))
    for xai_ in xai_list:
        print('========= {} ========='.format(xai_))
        psc_pred = 0
        for class_index in np.arange(num_classes):
            pred_org, pred_att, label = load_all_src(
                output_dir=pred_att_dir_src,
                model_name=model_,
                xai_name=xai_src,
                ii_specific=class_index,
            )
            hm_org, hm_att = load_all_tgt(
                output_dir=pred_att_dir_tgt,
                model_name=model_,
                xai_name=xai_,
                ii_specific=class_index,
            )

            pred_lab_jsd = compute_jsd(pred_org=pred_org, pred_adv=pred_att,
                                       class_of_interest=class_index)
            pred_hm_jsd = compute_jsd(pred_org=hm_org, pred_adv=hm_att,
                                      class_of_interest=class_index)

            pred_lab_jsd_np = np.asarray(pred_lab_jsd.numpy())
            pred_hm_jsd_np = np.asarray(pred_hm_jsd.numpy())
            # Replace NaN heat-map JSDs by the mean of the non-NaN ones so pearsonr can run.
            pred_hm_jsd_np[np.isnan(pred_hm_jsd_np)] = np.nanmean(pred_hm_jsd_np)
            psc_pred += pearsonr(pred_lab_jsd_np, pred_hm_jsd_np)[0] / num_classes
        print('pred psc: {}'.format(psc_pred))

# '''eval expl attack'''
# expl_img_, expl_hm_, expl_lab_ = load_all(output_dir=expl_att_dir, 
#                                          model_name=model_, 
#                                          xai_name=xai_, 
#                                          ii_specific=class_index)
# att_index = np.where(expl_lab_[2][:,class_index]==1)[0]
# expl_lab_jsd = compute_jsd(pred_org=expl_lab_[0][att_index,:], 
#                            pred_adv=expl_lab_[1][att_index,:], 
#                            class_of_interest=class_index)
# expl_hm_jsd = compute_jsd(pred_org=expl_hm_[0][att_index,:], 
#                           pred_adv=expl_hm_[1][att_index,:], 
#                           class_of_interest=class_index)
# expl_lab_jsd_np = np.asarray(expl_lab_jsd.numpy())
# expl_hm_jsd_np = np.asarray(expl_hm_jsd.numpy())
# expl_hm_jsd_np[np.isnan(expl_hm_jsd_np)] = np.nanmean(expl_hm_jsd_np)
# psc_expl += pearsonr(expl_lab_jsd_np, expl_hm_jsd_np)[0]/num_classes
# print('pred psc: {},   expl psc: {}'.format(psc_pred, psc_expl))

In [None]:
num_classes = len(valid_dataloader.dataset.attr_names)
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name = 'resnet152'  # other option: 'densenet121'
pred_att_dir_tgt = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_pred'

# Class-averaged heat-map SSIM / MSE between original and attacked maps.
for xai_name in xai_names:
    print('-------------------------------- {} --------------------------------'.format(xai_name))
    SSIM_hp, MSE_hp = 0, 0
    for ii in np.arange(num_classes):
        ssim_hp, mse_hp = eval_all_metrics(pred_att_dir_tgt, model_name, xai_name, ii)
        SSIM_hp += ssim_hp / num_classes
        MSE_hp += mse_hp / num_classes

    print('============= averaged results of {} ============= '.format(xai_name))
    print('SSIM_hp: ', SSIM_hp)
    print('MSE_hp: ', MSE_hp)

### Vis-PRED

In [None]:
### 5 xai
# Load the arrays saved by one source-method attack run (one class) and show
# the label vector of the selected image.
img_index = 124
class_index = 1
model_name = 'densenet121'  # other option: 'resnet152'
src_xai_name = 'IntegratedBP'  # other runs: 'VanillaBP', 'VanillaBP_Img', 'GradCAM', 'SmoothBP'
src_img_dir = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred/'
folder_name = model_name + '_' + src_xai_name + '_class_' + str(class_index)

img_org, img_att, pred_org, pred_att, label = (
    np.load(os.path.join(src_img_dir, folder_name, fname))
    for fname in ('org_imgs.npy', 'adv_imgs.npy', 'org_pred.npy',
                  'adv_pred.npy', 'org_labels.npy')
)
print('label: ', label[img_index])

In [None]:
### 5 xai
# For one image, render the heat maps of several target XAI methods on the
# attacked and on the original input, and save each overlay as a PNG.
img_index = 124
class_index = 1
model_name = 'densenet121'  # other option: 'resnet152'
src_xai_name = 'IntegratedBP'  # other runs: 'VanillaBP', 'VanillaBP_Img', 'GradCAM', 'SmoothBP'
src_img_dir = '/data/users/Attack_Attn/save_dir/NatComm/attack_pred/'
folder_name = model_name + '_' + src_xai_name + '_class_' + str(class_index)

# Images, logits and labels saved by the source-method attack run.
img_org = np.load(os.path.join(src_img_dir, folder_name, 'org_imgs.npy'))
img_att = np.load(os.path.join(src_img_dir, folder_name, 'adv_imgs.npy'))
pred_org = np.load(os.path.join(src_img_dir, folder_name, 'org_pred.npy'))
pred_att = np.load(os.path.join(src_img_dir, folder_name, 'adv_pred.npy'))
label = np.load(os.path.join(src_img_dir, folder_name, 'org_labels.npy'))
print('label: ', label[img_index])


def _color_heatmap(raw_heatmap, xai_name, xrai_floor_thresh):
    """Normalise one 2-D heat map with ``img_norm`` and colour it for ``xai_name``.

    Args:
        raw_heatmap: a single 2-D heat-map slice.
        xai_name: 'GradCAM' and 'XRAI' are coloured with reversed jet; all
            other methods use 'plasma'. Each branch uses its own ``img_norm``
            setting.
        xrai_floor_thresh: XRAI only. Normalised values at or below this are
            set to 0.01, so they get almost no blending weight.

    Returns:
        (heatmap, cmap): the normalised map (used as the blending weight) and
        its RGB colouring, re-normalised with ``img_norm``.
    """
    if xai_name == 'GradCAM':
        heatmap = img_norm(raw_heatmap, k=0.88)
        cmap = img_norm(cm.jet_r(1 - heatmap)[..., :3])
    elif xai_name == 'XRAI':
        heatmap = img_norm(raw_heatmap, 1.0)
        heatmap[np.where(heatmap <= xrai_floor_thresh)] = 0.01
        cmap = img_norm(cm.jet_r(1 - heatmap)[..., :3])
        cmap[np.where(cmap <= 0.3)] = 0.00
    else:
        heatmap = img_norm(raw_heatmap, k=4.7)
        cmap = img_norm(cm.plasma(heatmap)[..., :3])
    return heatmap, cmap


def _save_overlay(base_gray, heatmap, cmap, save_loc, alpha=.99):
    """Blend ``cmap`` over a grey image, weighted per pixel by ``heatmap``, and save it.

    Returns:
        The 3-channel copy of ``base_gray`` used as the background.
    """
    base_img = np.zeros_like(cmap.squeeze())
    htmp_weight = np.zeros_like(cmap.squeeze())
    for ch in range(3):  # replicate the grey image / weight into every channel
        base_img[:, :, ch] = base_gray
        htmp_weight[:, :, ch] = heatmap
    img_fused = base_img * (1 - htmp_weight) + cmap * alpha * htmp_weight

    plt.figure(figsize=(6, 6))
    fig = plt.imshow(img_fused)
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.savefig(save_loc, bbox_inches='tight', pad_inches=0)
    return base_img


save_dir = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_pred/'
for xai_name in ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']:
    folder_name = model_name + '_' + xai_name + '_class_' + str(class_index)
    hm_org = np.load(os.path.join(save_dir, folder_name, 'hm_org.npy'))
    hm_att = np.load(os.path.join(save_dir, folder_name, 'hm_att.npy'))

    save_img_dir = os.path.join(save_dir, 'save_img')
    os.makedirs(save_img_dir, exist_ok=True)
    print(hm_att.shape, img_att.shape)
    save_prefix = os.path.join(save_img_dir, folder_name + '_index_' + str(img_index))

    # NOTE(review): the XRAI floor threshold is 0.6 for the attacked map and
    # 0.95 for the original map; kept as-is -- confirm this asymmetry is intended.

    # Attacked input: attacked logits and heat map, saved with suffix '_adv'.
    print(torch.Tensor(pred_att[img_index]).sigmoid())
    print(label[img_index])
    heatmap, cmap = _color_heatmap(hm_att[img_index, 0, :, :], xai_name, xrai_floor_thresh=0.6)
    base_img = _save_overlay(img_att[img_index, 0, :, :], heatmap, cmap,
                             save_prefix + '_adv.png')

    # Original input, saved with suffix '_org'. It is rendered last so that the
    # heatmap / cmap / base_img left in the namespace describe the original image.
    print(torch.Tensor(pred_org[img_index]).sigmoid())
    heatmap, cmap = _color_heatmap(hm_org[img_index, 0, :, :], xai_name, xrai_floor_thresh=0.95)
    base_img = _save_overlay(img_org[img_index, 0, :, :], heatmap, cmap,
                             save_prefix + '_org.png')

In [None]:
# Save the plain grey-scale background left over from the previous cell,
# with no heat-map overlay, using the same figure settings.
plt.figure(figsize=(6, 6))
alpha = .99
htmp_weight = np.zeros_like(cmap.squeeze())
print(htmp_weight.shape)
htmp_weight[:, :, :] = np.dstack((heatmap, heatmap, heatmap))

img_fused = base_img  # background only, no blending

fig = plt.imshow(img_fused)
fig.axes.xaxis.set_visible(False)
fig.axes.yaxis.set_visible(False)

save_loc = os.path.join(save_img_dir, '{}_index_{}_img.png'.format(folder_name, img_index))
plt.savefig(save_loc, bbox_inches='tight', pad_inches=0)

## EXPL

In [None]:
def load_all(output_dir, model_name, xai_name, ii_specific=None):
    """Load images, heat maps and predictions saved by an attack run.

    Each folder is ``<output_dir>/<model_name>_<xai_name>_class_<k>``. With
    ``ii_specific=None`` the folders for classes 0-4 are read in order and
    every array is concatenated along axis 0.

    Returns:
        ``[img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]``
    """
    # Files in the order they are read inside each folder.
    file_names = ('org_pred.npy', 'adv_pred.npy', 'org_labels.npy',
                  'org_hm.npy', 'adv_hm.npy',
                  'org_imgs.npy', 'adv_imgs.npy')

    def _class_dir(class_idx):
        return os.path.join(output_dir,
                            model_name + '_' + xai_name + '_class_' + str(class_idx))

    def _read_folder(folder):
        return [np.load(os.path.join(folder, fname)) for fname in file_names]

    if ii_specific is None:
        per_class = [_read_folder(_class_dir(ii)) for ii in np.arange(5)]
        arrays = [np.concatenate(column, 0) for column in zip(*per_class)]
    else:
        arrays = _read_folder(_class_dir(ii_specific))

    pred_org, pred_att, label, hm_org, hm_att, img_org, img_att = arrays
    return [img_org, img_att], [hm_org, hm_att], [pred_org, pred_att, label]

In [None]:
# Target XAI methods evaluated on the images of the source-method attack run.
xai_namelist = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
src_xai = 'IntegratedBP'
model_name_list = ['densenet121']  # other option: 'resnet152'

# One-hot class vectors keyed by the class index as a string ('0' .. '4').
label_dict = {str(k): [1. if j == k else 0. for j in range(5)] for k in range(5)}

cfgs['data_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl/'
cfgs['save_dir'] = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_expl'


In [None]:
def _build_model_cfgs(model_name, xai_name):
    """Build the ``cfgs['model']`` entry for one architecture / XAI-method pair.

    'GradCAM' and 'GuidedGradCAM' also receive the configured ``target_layer``
    in ``exp_cfgs``; every other method gets an empty ``exp_cfgs``.

    Raises:
        ValueError: if no checkpoint is configured here for ``model_name``.
    """
    known_models = {
        'densenet121': (
            '/home/Attack_Attn/ChestXpert/save_dir_multimodel/densenet121_5c/best_checkpoints/checkpoint_9.pt',
            ['features', 'denseblock4', 'denselayer14'],
        ),
        'resnet152': (
            '/home/Attack_Attn/ChestXpert/save_dir_multimodel/resnet152_5c/best_checkpoints/checkpoint_5.pt',
            ['layer4', '2'],
        ),
    }
    if model_name not in known_models:
        raise ValueError('no checkpoint configured for model {!r}'.format(model_name))
    model_path, cam_layer = known_models[model_name]
    exp_cfgs = {'target_layer': cam_layer} if xai_name in ('GradCAM', 'GuidedGradCAM') else {}
    src_cfg = {
        'model_info': {
            'model_name': model_name,
            'model_path': model_path,
        },
        'exp_method': xai_name,
        'exp_cfgs': exp_cfgs,
    }
    return {'src_models': [src_cfg], 'pretrained': False}


# Recompute the heat maps of every target XAI method on the original and
# attacked images saved by the `src_xai` attack run, and store them per class.
exp_device = torch.device(cfgs['device'])
for model_name in model_name_list:
    print('---------------- model: {} ------------------'.format(model_name))
    for xai_name in xai_namelist:
        print('=============== {} ==============='.format(xai_name))
        model_cfgs = _build_model_cfgs(model_name, xai_name)
        cfgs['model'] = model_cfgs
        source_exp_models = [expmodel_factory(src_cfg, cfgs) for src_cfg in model_cfgs['src_models']]

        for ii in np.arange(5):  # 5 different classes
            print('++++++ calss number: {} ++++++'.format(ii))
            # Images and labels from the source-method attack folder.
            img_, _, pred_ = load_all(output_dir=cfgs['data_dir'],
                                      model_name=model_name,
                                      xai_name=src_xai,
                                      ii_specific=ii)
            # Select the attacked images: those whose label is positive for class ii.
            att_index = np.where(pred_[2][:, ii] == 1)[0]
            if len(att_index) == 0:
                # torch.cat below would fail on an empty list; skip this class instead.
                print('no attacked images for class {}; skipping'.format(ii))
                continue

            save_dir = os.path.join(cfgs['save_dir'],
                                    model_name + '_' + xai_name + '_class_' + str(ii))
            class_of_interest = label_dict[str(ii)]
            os.makedirs(save_dir, exist_ok=True)

            hm_org, hm_att = [], []
            for img_index in att_index:
                print('---- img_index: {} ----'.format(img_index))
                img_org = torch.Tensor(img_[0][img_index]).clone().detach()
                img_adv = torch.Tensor(img_[1][img_index]).clone().detach()

                res_org = [m.cal_exp_map(img_org.to(exp_device).unsqueeze(0), class_of_interest)
                           for m in source_exp_models]
                res_adv = [m.cal_exp_map(img_adv.to(exp_device).unsqueeze(0), class_of_interest)
                           for m in source_exp_models]

                if xai_name == 'XRAI':
                    # XRAI maps are converted with torch.Tensor and given a leading
                    # axis (presumably they are not tensors already -- confirm).
                    hm_org.append(torch.Tensor(res_org[0][0]).unsqueeze_(0))
                    hm_att.append(torch.Tensor(res_adv[0][0]).unsqueeze_(0))
                else:
                    hm_org.append(res_org[0][0].data.cpu())
                    hm_att.append(res_adv[0][0].data.cpu())
                del res_org, res_adv
                torch.cuda.empty_cache()

            # Stack the per-image maps, add a channel axis, and save them as .npy.
            hm_org = torch.cat(hm_org, dim=0).unsqueeze_(1)
            hm_att = torch.cat(hm_att, dim=0).unsqueeze_(1)
            print(hm_org.shape)

            np.save(os.path.join(save_dir, 'hm_org.npy'), hm_org.numpy())
            np.save(os.path.join(save_dir, 'hm_att.npy'), hm_att.numpy())

### Eval-EXPL

In [None]:
def compute_jsd(pred_org, pred_adv, class_of_interest):
    """Per-sample Jensen-Shannon divergence between original and attacked outputs.

    Two kinds of input are handled:
      * rank-4 inputs (image-like, e.g. heat maps): each sample is flattened
        and used directly as a distribution; no normalisation is applied here.
      * any other rank (logits): a sigmoid is applied, and column
        ``class_of_interest`` becomes the two-outcome distribution (p, 1 - p).

    The mixture (p + q) / 2 is clamped to [1e-7, 1] before taking its log.

    Returns:
        1-D torch tensor with one JSD value per sample.
    """
    org, adv = torch.Tensor(pred_org), torch.Tensor(pred_adv)

    if org.dim() == 4:
        p = org.flatten(start_dim=1)
        q = adv.flatten(start_dim=1)
    else:
        p_pos = org.sigmoid()[:, class_of_interest].unsqueeze(1)
        q_pos = adv.sigmoid()[:, class_of_interest].unsqueeze(1)
        p = torch.cat([p_pos, 1 - p_pos], dim=1)
        q = torch.cat([q_pos, 1 - q_pos], dim=1)

    log_mix = torch.clamp((p + q) / 2., 1e-7, 1).log()
    kl_p = F.kl_div(log_mix, p, reduction="none").sum(1)
    kl_q = F.kl_div(log_mix, q, reduction="none").sum(1)
    return (kl_p + kl_q) * 0.5


def eval_img_similarity(imgs_1, imgs_2):
    """Compare two rank-3 stacks of maps pair by pair with SSIM, PCC and MSE.

    Both stacks are rescaled with ``batch_img_norm`` and multiplied by 255
    before scoring. Pairs whose mean is NaN are reported (their index is
    printed) and left out of SSIM/PCC.

    Args:
        imgs_1, imgs_2: rank-3 arrays, presumably (N, H, W) heat-map stacks
            -- TODO confirm against callers.

    Returns:
        dict with keys, in this order:
          'ssim': array of per-pair SSIM values (NaN pairs skipped),
          'mse':  scalar MSE on the [0, 1] scale, via ``np.nanmean``,
          'pcc':  array of per-pair Pearson correlations (NaN pairs skipped).
    """
    assert len(imgs_1.shape) == len(imgs_2.shape) == 3
    scaled_a = batch_img_norm(imgs_1) * 255
    scaled_b = batch_img_norm(imgs_2) * 255

    ssim_scores, pcc_scores = [], []
    for idx in range(scaled_a.shape[0]):
        map_a, map_b = scaled_a[idx], scaled_b[idx]
        has_nan = np.isnan(map_a.mean()) or np.isnan(map_b.mean())
        if has_nan:
            print(idx)
            continue

        score, _ = ssim(map_a, map_b, data_range=255, full=True, multichannel=False)
        ssim_scores.append(score)
        pcc_scores.append(pearsonr(map_a.reshape(-1), map_b.reshape(-1))[0])

    # Mean over the last axis first; rows containing NaN are then ignored by nanmean.
    row_sq_err = ((scaled_a / 255 - scaled_b / 255) ** 2).mean(-1)
    mse = np.nanmean(row_sq_err)
    return {
        'ssim': np.asarray(ssim_scores),
        'mse': mse,
        'pcc': np.asarray(pcc_scores),
    }


def eval_all_metrics(output_dir, model_name, xai_name, ii_specific=None,
                     target_box=(130, 190, 380, 440)):
    """Score attacked heat maps against the originals and against the attack target.

    Heat maps are read with ``load_all_tgt``. Channel 0 of the attacked maps
    is compared (``eval_img_similarity``) with
      * the original maps, and
      * a target map that spreads mass 1 / area uniformly over ``target_box``
        and is zero elsewhere.

    Args:
        output_dir, model_name, xai_name: locate the saved heat-map folders.
        ii_specific: class index, or ``None`` for all classes.
        target_box: ``(row_start, row_stop, col_start, col_stop)`` slice bounds
            of the target rectangle. The default reproduces the previously
            hard-coded ``[130:190, 380:440]`` with value 1/3600.

    Returns:
        (ssim_hp_org, ssim_hp_tgt, mse_hp_org, mse_hp_tgt): NaN-aware means.
    """
    hm_org, hm_att = load_all_tgt(output_dir, model_name, xai_name, ii_specific)
    print(hm_org.shape, hm_att.shape)
    if ii_specific is None:
        print('+++++++++++++++++++++ all images ++++++++++++++++++++')
    else:
        print('+++++++++++++++++++++ class {} ++++++++++++++++++++'.format(ii_specific))

    print('=== Heat Map ===')
    row_start, row_stop, col_start, col_stop = target_box
    attack_target_maps = np.zeros_like(hm_org[:, 0, :, :])
    # Uniform (not binary) map: 1 / box-area inside the box, 0 elsewhere.
    attack_target_maps[:, row_start:row_stop, col_start:col_stop] = \
        1.0 / ((row_stop - row_start) * (col_stop - col_start))

    attacked = hm_att[:, 0, :, :]
    res_org = eval_img_similarity(hm_org[:, 0, :, :], attacked)
    res_tgt = eval_img_similarity(attack_target_maps, attacked)

    ssim_hp_org = np.nanmean(res_org['ssim'])
    mse_hp_org = np.nanmean(res_org['mse'])
    ssim_hp_tgt = np.nanmean(res_tgt['ssim'])
    mse_hp_tgt = np.nanmean(res_tgt['mse'])
    return ssim_hp_org, ssim_hp_tgt, mse_hp_org, mse_hp_tgt



def load_all_src(output_dir, model_name, xai_name, ii_specific=None, num_classes=5):
    """Load the logits and labels of the attacked images of one class or of all classes.

    Each folder is ``<output_dir>/<model_name>_<xai_name>_class_<k>`` and
    holds ``org_pred.npy``, ``adv_pred.npy`` and ``org_labels.npy``. Within
    folder ``k`` only the rows with ``label[:, k] == 1`` (the attacked
    images) are kept.

    Args:
        output_dir: root directory of the saved attack results.
        model_name: model prefix of the folder name.
        xai_name: explanation-method part of the folder name.
        ii_specific: class index to load. ``None`` loads classes
            ``0 .. num_classes - 1``, filters each class the same way, and
            concatenates the results along axis 0.
        num_classes: number of class folders read when ``ii_specific`` is
            ``None`` (default 5).

    Returns:
        (pred_org, pred_att, label) numpy arrays.
    """
    def _load_class(class_idx):
        # Read one class folder and keep only the rows positive for that class.
        folder = os.path.join(output_dir,
                              model_name + '_' + xai_name + '_class_' + str(class_idx))
        p_org = np.load(os.path.join(folder, 'org_pred.npy'))
        p_att = np.load(os.path.join(folder, 'adv_pred.npy'))
        lab = np.load(os.path.join(folder, 'org_labels.npy'))
        att_index = np.where(lab[:, class_idx] == 1)[0]
        return p_org[att_index], p_att[att_index], lab[att_index]

    if ii_specific is not None:
        return _load_class(ii_specific)

    per_class = [_load_class(ii) for ii in range(num_classes)]
    pred_org = np.concatenate([p for p, _, _ in per_class], 0)
    pred_att = np.concatenate([a for _, a, _ in per_class], 0)
    label = np.concatenate([l for _, _, l in per_class], 0)
    return pred_org, pred_att, label


def load_all_tgt(output_dir, model_name, xai_name, ii_specific=None, num_classes=5):
    '''Load the saved original and attacked heatmaps.

    Arrays are read from ``<output_dir>/<model_name>_<xai_name>_class_<k>/``:
    ``hm_org.npy`` and ``hm_att.npy``.

    Args:
        output_dir: root directory holding one sub-folder per attacked class.
        model_name: model identifier used in the folder name.
        xai_name: explanation-method identifier used in the folder name.
        ii_specific: class index to load. When None, the folders of classes
            0 .. num_classes-1 are loaded and concatenated along axis 0.
        num_classes: number of class folders read when ii_specific is None
            (defaults to 5, the previously hard-coded value).

    Returns:
        (hm_org, hm_att) as loaded (no row filtering is applied).
    '''
    def _class_dir(class_idx):
        return os.path.join(output_dir,
                            model_name + '_' + xai_name + '_class_' + str(class_idx))

    if ii_specific is not None:
        save_dir = _class_dir(ii_specific)
        return (np.load(os.path.join(save_dir, 'hm_org.npy')),
                np.load(os.path.join(save_dir, 'hm_att.npy')))

    hm_org, hm_att = [], []
    for ii in range(num_classes):
        save_dir = _class_dir(ii)
        hm_org.append(np.load(os.path.join(save_dir, 'hm_org.npy')))
        hm_att.append(np.load(os.path.join(save_dir, 'hm_att.npy')))
    return np.concatenate(hm_org, 0), np.concatenate(hm_att, 0)

In [None]:
# Per model and per XAI method: average, over all classes, the SSIM / MSE that
# eval_all_metrics reports between the attacked heatmap and (a) the original
# heatmap and (b) the attack target.
num_classes = len(valid_dataloader.dataset.attr_names)
xai_names = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
model_name_list = ['resnet152', 'densenet121']
pred_att_dir_tgt = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_expl'

for model_name in model_name_list:
    print('Model: ', model_name)
    for xai_name in xai_names:
        print(f'-------------------------------- {xai_name} --------------------------------')
        SSIM_hp_org = SSIM_hp_tgt = MSE_hp_org = MSE_hp_tgt = 0
        for ii in np.arange(num_classes):
            (ssim_hp_org, ssim_hp_tgt,
             mse_hp_org, mse_hp_tgt) = eval_all_metrics(pred_att_dir_tgt, model_name, xai_name, ii)
            # every class contributes with equal weight to the running mean
            SSIM_hp_org += ssim_hp_org / num_classes
            SSIM_hp_tgt += ssim_hp_tgt / num_classes
            MSE_hp_org += mse_hp_org / num_classes
            MSE_hp_tgt += mse_hp_tgt / num_classes

        print(f'============= averaged results of {xai_name} ============= ')
        for caption, value in (('SSIM origin: ', SSIM_hp_org),
                               ('SSIM target: ', SSIM_hp_tgt),
                               ('MSE origin: ', MSE_hp_org),
                               ('MSE target: ', MSE_hp_tgt)):
            print(caption, value)

In [None]:
# Per model and per target XAI method: Pearson correlation (averaged over
# classes) between the prediction JSD of the IntegratedBP source attack and
# the heatmap JSD of the target method, both computed with compute_jsd.
num_classes = len(valid_dataloader.dataset.attr_names)
model_list = ['resnet152', 'densenet121']
xai_list = ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']
pred_att_dir_src = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl'
pred_att_dir_tgt = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_expl'
xai_src = 'IntegratedBP'

for model_ in model_list:
    print(f'++++++++++++++++ Model: {model_} ++++++++++++++++')
    for xai_ in xai_list:
        print(f'========= {xai_} =========')
        psc_pred = 0
        for class_index in np.arange(num_classes):
            # predictions / labels of the source (IntegratedBP) attack
            pred_org, pred_att, label = load_all_src(output_dir=pred_att_dir_src,
                                                     model_name=model_,
                                                     xai_name=xai_src,
                                                     ii_specific=class_index)
            # original / attacked heatmaps of the target explanation method
            hm_org, hm_att = load_all_tgt(output_dir=pred_att_dir_tgt,
                                          model_name=model_,
                                          xai_name=xai_,
                                          ii_specific=class_index)

            pred_lab_jsd = compute_jsd(pred_org=pred_org,
                                       pred_adv=pred_att,
                                       class_of_interest=class_index)
            pred_hm_jsd = compute_jsd(pred_org=hm_org,
                                      pred_adv=hm_att,
                                      class_of_interest=class_index)

            pred_lab_jsd_np = np.asarray(pred_lab_jsd.numpy())
            pred_hm_jsd_np = np.asarray(pred_hm_jsd.numpy())
            # impute NaN heatmap JSDs with the mean of the finite ones
            pred_hm_jsd_np[np.isnan(pred_hm_jsd_np)] = np.nanmean(pred_hm_jsd_np)
            psc_pred += pearsonr(pred_lab_jsd_np, pred_hm_jsd_np)[0] / num_classes
        print(f'pred psc: {psc_pred}')
            

### Vis-EXPL

In [None]:
### 5 xai
# Render heatmap overlays for one image (img_index) of one class (class_index):
# for every target XAI method, draw the heatmap of the attacked image and of
# the original image on top of the corresponding input and save both as PNG
# (attacked -> '*_adv.png', original -> '*_org.png').


def _colorize_heatmap(heatmap, xai_name, xrai_cutoff):
    '''Normalize a single heatmap and map it to an RGB color image.

    Args:
        heatmap: 2-D saliency map of one image.
        xai_name: explanation method; selects normalization and colormap.
        xrai_cutoff: 'XRAI' only -- normalized values <= this are suppressed.

    Returns:
        (heatmap, cmap): the normalized heatmap (later used as blend weight)
        and its colorization with the alpha channel dropped.
    '''
    if xai_name == 'GradCAM':
        heatmap = img_norm(heatmap, k=0.88)  # norm grad to special range
        cmap = img_norm(cm.jet_r(1 - heatmap)[..., :3])  # color map proj
    elif xai_name == 'XRAI':
        heatmap = img_norm(heatmap, 1.0)  # norm grad to special range
        # NOTE(review): with xrai_cutoff >= 0.7 (as used below) every value
        # <= xrai_cutoff is zeroed and then set to 0.01 by the next line.
        heatmap[np.where(heatmap <= xrai_cutoff)] = 0.0
        heatmap[np.where(heatmap <= 0.7)] = 0.01
        cmap = img_norm(cm.jet_r(1 - heatmap)[..., :3])  # re-norm to [0,1]
        cmap[np.where(cmap <= 0.3)] = 0.00
    else:
        heatmap = img_norm(heatmap, k=4.7)  # norm grad to special range
        cmap = img_norm(cm.plasma(heatmap)[..., :3])  # color map proj
    return heatmap, cmap


def _overlay_and_save(base_gray, heatmap, cmap, save_loc, alpha=.99):
    '''Blend cmap over a grayscale image, weighted per pixel by heatmap.

    Opens a new 6x6 figure, shows the fused image without axis ticks and
    saves it to save_loc.
    '''
    # replicate the single-channel image / weight into the three RGB channels
    base_img = np.zeros_like(cmap.squeeze())
    base_img[...] = base_gray[..., None]
    htmp_weight = np.zeros_like(cmap.squeeze())
    htmp_weight[...] = heatmap[..., None]

    plt.figure(figsize=(6, 6))
    img_fused = base_img * (1 - htmp_weight) + cmap * alpha * htmp_weight
    fig = plt.imshow(img_fused)
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.savefig(save_loc, bbox_inches='tight', pad_inches=0)


# load images, predictions and labels of the source attack
img_index = 16
class_index = 1
model_name = 'densenet121'  # 'resnet152'
src_xai_name = 'IntegratedBP'  # ['VanillaBP', 'VanillaBP_Img', 'IntegratedBP', 'GradCAM', 'SmoothBP']
src_img_dir = '/data/users/Attack_Attn/save_dir/NatComm/attack_expl/'
folder_name = model_name + '_' + src_xai_name + '_class_' + str(class_index)

img_org = np.load(os.path.join(src_img_dir, folder_name, 'org_imgs.npy'))
img_att = np.load(os.path.join(src_img_dir, folder_name, 'adv_imgs.npy'))
pred_org = np.load(os.path.join(src_img_dir, folder_name, 'org_pred.npy'))
pred_att = np.load(os.path.join(src_img_dir, folder_name, 'adv_pred.npy'))
label = np.load(os.path.join(src_img_dir, folder_name, 'org_labels.npy'))

# select attacked images: rows labelled positive for class_index
att_index = np.where(label[:, class_index] == 1)[0]
img_org = img_org[att_index]
img_att = img_att[att_index]
pred_org = pred_org[att_index]
pred_att = pred_att[att_index]
label = label[att_index]

# per target XAI method: load heatmaps and render attacked / original overlays
for xai_name in ['VanillaBP', 'VanillaBP_Img', 'GradCAM', 'GuidedGradCAM', 'SmoothBP', 'XRAI']:
    save_dir = '/data/users/Attack_Attn/save_dir/NatComm/ig_attack_expl/'

    folder_name = model_name + '_' + xai_name + '_class_' + str(class_index)
    hm_org = np.load(os.path.join(save_dir, folder_name, 'hm_org.npy'))
    hm_att = np.load(os.path.join(save_dir, folder_name, 'hm_att.npy'))

    save_img_dir = os.path.join(save_dir, 'save_img')
    os.makedirs(save_img_dir, exist_ok=True)
    file_stem = os.path.join(save_img_dir, folder_name + '_index_' + str(img_index))

    # attacked image -> '_adv.png'
    output = torch.Tensor(pred_att[img_index])
    print(output.sigmoid())
    print(label[img_index])
    heatmap, cmap = _colorize_heatmap(hm_att[img_index, 0, :, :], xai_name, xrai_cutoff=0.98)
    _overlay_and_save(img_att[img_index, 0, :, :], heatmap, cmap, file_stem + '_adv.png')

    # original image -> '_org.png'
    output = torch.Tensor(pred_org[img_index])
    print(output.sigmoid())
    heatmap, cmap = _colorize_heatmap(hm_org[img_index, 0, :, :], xai_name, xrai_cutoff=0.82)
    _overlay_and_save(img_org[img_index, 0, :, :], heatmap, cmap, file_stem + '_org.png')