In [None]:
import pickle
import random

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0
from torchvision.models import EfficientNet_B0_Weights
from tqdm.notebook import tqdm

from datasets import load_test_dataset, get_dataset
from methods import big_pipeline, ig_pipeline
from methods import sm_pipeline, ma2norm_b4_softmax_pipeline, ma2norm_after_softmax_pipeline, ma2cos_sign_b4_softmax_pipeline, ma2cos_sign_after_softmax_pipeline, ma2cos_without_sign_b4_softmax_pipeline, ma2cos_without_sign_after_softmax_pipeline, ma2ba_sign_b4_softmax_pipeline, ma2ba_sign_after_softmax_pipeline, ma2ba_without_sign_b4_softmax_pipeline, ma2ba_without_sign_after_softmax_pipeline, dl_pipeline
from methods.utils import VisionSensitivityN
from visualization import HeatmapVisualizer, visualize_original, visualize, visualize_softmax
# Font loaded from fonts/SimHei.ttf for plot text.
# NOTE(review): my_font is not referenced in the visible cells.
my_font = fm.FontProperties(fname="fonts/SimHei.ttf")

# Heatmap renderer that get_sensitivity uses to turn attributions into masks.
mask_viz = HeatmapVisualizer(blur=7, normalization_type="signed_max")

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def set_seed(seed):
    """Seed every RNG the notebook may touch and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch
    (CPU and all CUDA devices). It also turns off cuDNN benchmarking so that
    repeated runs reproduce the same results.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is absent
    # The cuDNN autotuner (benchmark=True) may pick different, possibly
    # non-deterministic kernels between runs; pin it off for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 3407
set_seed(SEED)

In [None]:
# Build the ImageNet data loader. get_dataset also returns data_min / data_max,
# which are forwarded to the attribution pipeline in the collection cell.
# NOTE(review): the meaning of the second argument (32) is not visible here --
# presumably the batch size; confirm against datasets.get_dataset.
dataloader, data_min, data_max = get_dataset("imagenet", 32)


In [None]:
# Pretrained ImageNet EfficientNet-B0 in inference mode.
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` maps to. Assigning the result (instead
# of ending the cell with a dummy `1`) keeps the module repr out of the output.
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
model = model.to(device).eval()


In [None]:
# Sensitivity-n evaluator wrapped around the model.
# NOTE(review): (224, 224) is presumably the input (H, W), and 112*112 -- a
# quarter of 224*224 pixels -- presumably n, the number of pixels perturbed per
# random mask; confirm against methods.utils.VisionSensitivityN.
SENS_INPUT_HW = (224, 224)
SENS_N = 112 * 112
SENS_NUM_MASKS = 100
vision_sensitivity = VisionSensitivityN(
    model,
    SENS_INPUT_HW,
    SENS_N,
    num_masks=SENS_NUM_MASKS,
)


In [None]:
# Collect attributions for the first N_SAMPLES test images that the model
# classifies correctly. Misclassified images are dropped, so every stored
# attribution explains a correct prediction.
N_SAMPLES = 200

count = 0
attribution_map_ma2ba_without_sign_after_softmax = {
    "attribution": [],
    "data": [],
    "label": [],
}
# The context manager closes the sample-count bar even if the loop is interrupted.
with tqdm(total=N_SAMPLES, desc="collected") as pbar:
    for data, target in tqdm(dataloader, total=len(dataloader), desc="batches"):
        data = data.to(device)
        target = target.to(device)
        # Filtering-only forward pass: no autograd graph is needed here.
        with torch.no_grad():
            correct_index = torch.argmax(model(data), dim=-1) == target
        data = data[correct_index]
        target = target[correct_index]
        if len(data) == 0:
            continue
        attr_ma2ba_without_sign_after_softmax, _ = ma2ba_without_sign_after_softmax_pipeline(
            model, data, target, data_min, data_max)
        # Keep only as many samples as are still needed to reach N_SAMPLES.
        take = min(len(data), N_SAMPLES - count)
        for i in range(take):
            attribution_map_ma2ba_without_sign_after_softmax["attribution"].append(attr_ma2ba_without_sign_after_softmax[i])
            attribution_map_ma2ba_without_sign_after_softmax["data"].append(data[i].detach().cpu().numpy())
            attribution_map_ma2ba_without_sign_after_softmax["label"].append(target[i].detach().cpu().numpy())
        count += take
        pbar.update(take)
        if count >= N_SAMPLES:
            break


In [None]:
import pickle

# Cache the collected attributions so the evaluation cells can run without
# recomputing them.
ATTRIBUTION_PKL = "attribution_map_ma2ba_without_sign_after_softmax.pkl"
with open(ATTRIBUTION_PKL, "wb") as pkl_file:
    pickle.dump(attribution_map_ma2ba_without_sign_after_softmax, pkl_file)

In [None]:
import pickle

# Reload the cached attributions written by the save cell.
# pickle.load can execute arbitrary code, so only point this at a file this
# notebook produced itself -- never at an untrusted download.
ATTRIBUTION_PKL = "attribution_map_ma2ba_without_sign_after_softmax.pkl"
with open(ATTRIBUTION_PKL, "rb") as pkl_file:
    attribution_map_ma2ba_without_sign_after_softmax = pickle.load(pkl_file)

In [None]:
def get_sensitivity(attribution_map, idx, visualizer=None, evaluator=None, target_device=None):
    """Sensitivity-n correlation score for one stored attribution sample.

    The stored attribution is rendered as a heatmap mask and min-max rescaled
    to [0, 1]. The evaluator then returns a correlation matrix, and the
    off-diagonal entry ``[1, 0]`` is the score.

    Args:
        attribution_map: dict with parallel lists under "attribution", "data"
            and "label".
        idx: index of the sample to score.
        visualizer: callable invoked as ``visualizer(attribution, data,
            overlay_opacity=0.5, imshow=False, return_tiled=True)`` that returns
            ``(image, mask)``. Defaults to the notebook-level ``mask_viz``.
        evaluator: object exposing ``evaluate(heatmap=..., input_tensor=...,
            target=..., calculate_corr=True)`` that returns a dict whose
            ``"correlation"`` entry is indexable at ``[1, 0]``. Defaults to the
            notebook-level ``vision_sensitivity``.
        target_device: device the tensors are moved to before evaluation.
            Defaults to the notebook-level ``device``.

    Returns:
        float: the correlation, or ``0.0`` when it is NaN, so that a mean over
        samples stays finite.
    """
    # Resolve notebook-level defaults lazily so callers can inject substitutes.
    if visualizer is None:
        visualizer = mask_viz
    if evaluator is None:
        evaluator = vision_sensitivity
    if target_device is None:
        target_device = device

    attribution = np.array(attribution_map["attribution"][idx])
    data = np.array(attribution_map["data"][idx])
    label = attribution_map["label"][idx]
    # The visualizer is fed batched arrays; add a leading axis to 3-D samples.
    if attribution.ndim == 3:
        attribution = attribution[np.newaxis, ...]
    if data.ndim == 3:
        data = data[np.newaxis, ...]

    _, mask = visualizer(attribution, data, overlay_opacity=0.5,
                         imshow=False, return_tiled=True)
    # Min-max normalise to [0, 1]; the epsilon avoids 0/0 on a constant mask.
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)

    result = evaluator.evaluate(
        heatmap=torch.from_numpy(mask).to(target_device),
        input_tensor=torch.from_numpy(data.squeeze()).to(target_device),
        target=torch.from_numpy(np.array(label)).to(target_device),
        calculate_corr=True,
    )
    corr = float(result["correlation"][1, 0])
    # Always return a Python float (the original mixed int 0 and numpy scalars).
    return 0.0 if np.isnan(corr) else corr


In [None]:
# Mean Sensitivity-n correlation over every collected sample. Use the number
# actually stored instead of a hard-coded 200: if the loader runs out first,
# range(200) would raise IndexError and /200 would bias the mean.
n_collected = len(attribution_map_ma2ba_without_sign_after_softmax["attribution"])
assert n_collected > 0, "no correctly classified samples were collected"
sensitivity_scores = [
    get_sensitivity(attribution_map_ma2ba_without_sign_after_softmax, i)
    for i in tqdm(range(n_collected))
]
sum(sensitivity_scores) / n_collected