In [None]:
from metrics import pic
import numpy as np
from methods.utils import VisionSensitivityN
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
from visualization import HeatmapVisualizer, visualize_original, visualize, visualize_softmax
from torchvision.models import efficientnet_b0,resnet50
import torch.nn as nn
import torch
from tqdm.notebook import tqdm
from datasets import load_test_dataset, get_dataset
from methods import big_pipeline, ig_pipeline
from methods import sm_pipeline, ma2norm_b4_softmax_pipeline, ma2norm_after_softmax_pipeline, ma2cos_sign_b4_softmax_pipeline, ma2cos_sign_after_softmax_pipeline, ma2cos_without_sign_b4_softmax_pipeline, ma2cos_without_sign_after_softmax_pipeline, ma2ba_sign_b4_softmax_pipeline, ma2ba_sign_after_softmax_pipeline, ma2ba_without_sign_b4_softmax_pipeline, ma2ba_without_sign_after_softmax_pipeline, dl_pipeline
# Font loaded from a bundled TTF file (fonts/SimHei.ttf).
my_font = fm.FontProperties(fname="fonts/SimHei.ttf")
# Heatmap renderer that get_vid uses to turn attributions into a 2-D mask.
mask_viz = HeatmapVisualizer(blur=7, normalization_type="signed_max")
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
from multiprocessing.pool import ThreadPool

from methods.utils import VisionInsertionDeletion
def set_seed(seed):
    """Make runs reproducible.

    Seeds torch (CPU generator and every CUDA device) and NumPy's legacy
    global RNG, then puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to all generators.
    """
    # Torch generators first: CPU, then all visible CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy's global RandomState (np.random.* functions).
    np.random.seed(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
set_seed(3407)
# Per-channel (R, G, B) ImageNet normalization statistics, used to
# normalize images in the prediction function.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
import pickle
# import pickle
# NOTE: pickle.load can execute arbitrary code -- only load pickles that this
# project produced itself, never files from untrusted sources.
_attribution_pkl_path = "attribution_map_ma2ba_without_sign_after_softmax_resnet50.pkl"
with open(_attribution_pkl_path, "rb") as pkl_file:
    attribution_map_ma2ba_without_sign_after_softmax = pickle.load(pkl_file)

In [None]:
def show_curve_xy(x, y, title='PIC', label=None, color='blue',
    ax=None):
  """Plots a PIC curve (y against x) and reports its AUC in the legend.

  Args:
    x: x coordinates of the curve; the x axis is fixed to [0, 1].
    y: y coordinates; any array-like (plain lists are accepted).
    title: axes title.
    label: legend label prefix; ', AUC=<value>' is appended to it.
    color: matplotlib line color.
    ax: axes to draw on; a new 12x6 figure is created when None.
  """
  if ax is None:
    _, ax = plt.subplots(figsize=(12, 6))
  y = np.asarray(y)
  # Trapezoidal area over unit-spaced samples divided by the sample count
  # (x is not used for the area). np.trapz is deprecated since NumPy 2.0 in
  # favour of np.trapezoid; fall back to trapz on older NumPy.
  trapezoid = getattr(np, 'trapezoid', None) or np.trapz
  auc = trapezoid(y) / y.size
  label = f'{label}, AUC={auc:.3f}'
  ax.plot(x, y, label=label, color=color)
  ax.set_title(title)
  ax.set_xlim([0.0, 1.0])
  ax.set_ylim([0.0, 1.0])
  ax.legend()


def show_curve(compute_pic_metric_result, title='PIC', label=None, color='blue',
    ax=None):
  """Plots the curve stored in a compute_pic_metric result.

  Reads `curve_x` and `curve_y` from the result and forwards them, together
  with every other argument unchanged, to show_curve_xy.
  """
  curve_x = compute_pic_metric_result.curve_x
  curve_y = compute_pic_metric_result.curve_y
  show_curve_xy(curve_x, curve_y, title=title, label=label, color=color, ax=ax)

In [None]:
# Load the "imagenet" dataset through the project's get_dataset helper.
# NOTE(review): dataloader / data_min / data_max are not used by the cells
# shown here, and the meaning of the second argument (2) is not visible --
# presumably a batch size; confirm against datasets.get_dataset.
dataloader, data_min, data_max = get_dataset("imagenet", 2)

In [None]:
# Pretrained ResNet-50 on the selected device, switched to inference mode.
# nn.Module.to() and .eval() both return the module itself, so the chain binds
# the same object; as an assignment it also keeps the cell from echoing the
# model repr (previously hidden by a stray trailing `1` expression).
model = resnet50(pretrained=True).to(device).eval()

In [None]:
# Insertion/deletion evaluator around the ResNet-50 above; get_vid calls its
# evaluate() method. pixel_batch_size and sigma are passed through as-is --
# their exact meaning is defined in methods.utils.
vision_insertion_deletion = VisionInsertionDeletion(
    model,
    pixel_batch_size=128,
    sigma=15,
)

In [None]:
def get_vid(attribution_map):
    """Score one sample with the insertion/deletion metric.

    Args:
        attribution_map: indexable (data, label, attribution) triple; only
            elements 0-2 are read. data and attribution are converted with
            np.array and get a leading batch axis when they are 3-D.

    Returns:
        (ins_auc, del_auc) taken from vision_insertion_deletion.evaluate().
    """
    sample, label, attr = attribution_map[0], attribution_map[1], attribution_map[2]
    attr = np.array(attr)
    sample = np.array(sample)
    # Promote single 3-D arrays to a batch of one for mask_viz.
    attr = attr[np.newaxis, ...] if attr.ndim == 3 else attr
    sample = sample[np.newaxis, ...] if sample.ndim == 3 else sample
    _, mask = mask_viz(attr, sample, overlay_opacity=0.5,
                       imshow=False, return_tiled=True)
    # Min-max scale the heatmap into [0, 1]; epsilon guards a constant mask.
    lo, hi = mask.min(), mask.max()
    mask = (mask - lo) / (hi - lo + 1e-10)
    result = vision_insertion_deletion.evaluate(
        heatmap=torch.from_numpy(mask).to(device),
        input_tensor=torch.from_numpy(sample.squeeze()).to(device),
        target=torch.from_numpy(np.array(label)).to(device),
    )
    return result['ins_auc'], result['del_auc']

In [None]:
def calculate_vid_parallel(attribution,num_workers=8):
    """Run get_vid over every sample and print the mean insertion/deletion AUCs.

    Despite the name, samples are processed sequentially; num_workers is
    accepted for signature compatibility but currently unused.

    Args:
        attribution: dict holding parallel lists under "data", "label" and
            "attribution".
        num_workers: unused.

    Returns:
        np.ndarray of shape (N, 2): column 0 is ins_auc, column 1 is del_auc.
    """
    samples = list(zip(attribution["data"], attribution["label"], attribution["attribution"]))
    results = np.array([get_vid(sample) for sample in tqdm(samples)])
    if results.size == 0:
        # Keep the (N, 2) layout so column slicing works for empty input.
        results = results.reshape(0, 2)
    has_rows = len(results) > 0
    print("ins_auc", results[:, 0].mean() if has_rows else float("nan"))
    print("del_auc", results[:, 1].mean() if has_rows else float("nan"))
    return results

In [None]:
# calculate_vid_parallel drives its own tqdm bar and prints the mean AUCs, so
# no outer bar is needed (the old tqdm(total=200) here was never updated).
# Keep the per-sample (ins_auc, del_auc) rows instead of discarding them.
vid_results_ma2ba_without_sign_after_softmax = calculate_vid_parallel(
    attribution_map_ma2ba_without_sign_after_softmax)

In [None]:
# Define prediction function.
def create_predict_function_softmax(class_idx):
  """Creates the model prediction function that can be passed to compute_pic_metric method.

  The returned function outputs the softmax probability of `class_idx`, as
  needed for the Softmax Information Curve (SIC). torchvision classifiers such
  as resnet50 return raw logits, so the softmax is applied explicitly here;
  otherwise min_pred_value and the curve would be computed on logits.

  Args:
    class_idx: the index of the class for which the model prediction should
      be returned.

  Returns:
    A predict(image_batch) callable.
  """

  def predict(image_batch):
    """Returns the softmax probability of `class_idx` for a batch of images.

    The batch is scaled to [0, 1], normalized with the module-level `mean` and
    `std`, converted from NHWC to NCHW and run through the global `model` on
    `device` without gradient tracking.

    Args:
      image_batch: batch of uint8 images of dimension [B, H, W, C]. The
        caller's array is not modified.

    Returns:
      Probabilities of dimension [B].
    """
    # The division allocates a new float array, so the in-place ops below do
    # not touch the caller's batch.
    image_batch = image_batch / 255
    image_batch -= mean
    image_batch /= std
    image_batch = image_batch.transpose(0, 3, 1, 2)
    image_batch = torch.from_numpy(image_batch).float().to(device)
    # Inference only: avoid building an autograd graph on every call.
    with torch.no_grad():
      logits = model(image_batch)
      probs = torch.softmax(logits, dim=1)[:, class_idx]
    return probs.cpu().numpy()

  return predict


In [None]:
from multiprocessing.pool import ThreadPool
from tqdm.notebook import tqdm
import time

import threading

# Serializes set_seed() + pic.generate_random_mask() below. set_seed reseeds
# NumPy's global RNG and the mask generator presumably draws from it, so under
# ThreadPool workers could otherwise interleave reseeding and drawing and the
# "seeded" random mask would vary between runs.
_RANDOM_MASK_LOCK = threading.Lock()


def process_data_(data, target, attribution):
    """Compute the SIC (softmax information curve) AUC for a single sample.

    Args:
        data: one normalized image, channel-first (C, H, W); it is
            de-normalized with the ImageNet mean/std and converted to a uint8
            (H, W, C) image.
        target: class index whose softmax score is tracked.
        attribution: attribution map with the channel axis first; it is summed
            over axis 0 and made non-negative to form the (H, W) saliency map.

    Returns:
        result_sic.auc from pic.compute_pic_metric, or 0 if that call raises.
        Every call, successful or not, advances the global `pbar`.
    """
    std = np.array((0.229, 0.224, 0.225))
    mean = np.array((0.485, 0.456, 0.406))
    # np.asarray also accepts CPU torch tensors, whose .transpose() does not
    # take three axes.
    chw = np.asarray(data)
    # Undo the normalization. Round instead of truncating (float error often
    # lands just below an integer) and clip so out-of-range values cannot wrap
    # around in the uint8 cast.
    img = np.clip(np.rint((chw.transpose(1, 2, 0) * std + mean) * 255), 0, 255).astype(np.uint8)
    saliency_map = np.abs(np.sum(np.asarray(attribution), axis=0))
    saliency_thresholds = [0.5, 0.75]
    with _RANDOM_MASK_LOCK:
        set_seed(3407)
        random_mask = pic.generate_random_mask(image_height=img.shape[0],
                                               image_width=img.shape[1],
                                               fraction=0.01)
    pred_func_sic = create_predict_function_softmax(target)
    try:
        result_sic = pic.compute_pic_metric(img=img,
                                            saliency_map=saliency_map,
                                            random_mask=random_mask,
                                            pred_func=pred_func_sic,
                                            min_pred_value=0.5,
                                            saliency_thresholds=saliency_thresholds,
                                            keep_monotonous=True,
                                            num_data_points=1000)
    except Exception:
        # Best effort: a failed sample scores 0 (the caller averages over all
        # samples). Unlike a bare `except:`, KeyboardInterrupt/SystemExit are
        # no longer swallowed.
        return 0
    finally:
        # Advance the bar for failures too so it can reach its total.
        pbar.update()
    return result_sic.auc

def process_data(data):
    """ThreadPool.map adapter: unpack a (data, target, attribution) tuple and
    forward its first three elements to process_data_."""
    sample, label, attr = data[0], data[1], data[2]
    return process_data_(sample, label, attr)

def calculate_auc_parallel(attribution,num_workers=8):
    """Compute the per-sample SIC AUC for every sample with a thread pool.

    Args:
        attribution: dict holding parallel lists under "data", "label" and
            "attribution".
        num_workers: number of worker threads.

    Returns:
        List of process_data results, in input order.
    """
    samples = list(zip(attribution["data"], attribution["label"], attribution["attribution"]))
    pool = ThreadPool(num_workers)
    try:
        return pool.map(process_data, samples)
    finally:
        # Always release the pool, even when a worker raises inside map().
        pool.close()
        pool.join()

In [None]:
# Size the progress bar from the data instead of a hard-coded 200. The bar
# must be bound to the global name `pbar` because process_data_ updates it.
n_samples = len(attribution_map_ma2ba_without_sign_after_softmax["data"])
with tqdm(total=n_samples) as pbar:
    results_ma2ba_without_sign_after_softmax = calculate_auc_parallel(attribution_map_ma2ba_without_sign_after_softmax)
# Mean SIC AUC over all samples (samples whose metric failed contribute 0).
np.mean(results_ma2ba_without_sign_after_softmax)