In [None]:
# Colab only: mount Google Drive so the dataset folders under /content/drive
# (used by calc_metrics() further down) are readable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Loading necessary packages
import os
import math
import glob
import numpy as np
import scipy.misc
import scipy.ndimage
import random
import pyemd
from tqdm import tqdm
from collections import OrderedDict
from keras.preprocessing import image
import matplotlib.pyplot as plt


from PIL import Image
import time

def calc_all_maps(images):
    """Accumulate a set of map images into a single pixel-wise sum.

    Each path is opened with PIL, converted to a float64 numpy array and added
    to a running total. ``calc_all`` uses this to sum every binary fixation
    map and then subtracts the current image's map from the total.

    Parameters
    ----------
    images : sequence of str
        Paths of the image files to sum. Every image must decode to an array
        with the same shape as the first one.

    Returns
    -------
    numpy.ndarray
        float64 array shaped like the first image, holding the element-wise sum.

    Raises
    ------
    AssertionError
        If ``images`` is empty.
    ValueError
        If an image's array shape differs from the first image's.
    """
    assert len(images) != 0, "calc_all_maps needs at least one image path"
    all_map = None
    for path in images:
        # The context manager closes the file handle; np.asarray reads the
        # pixel data before that happens. np.float64 replaces the np.float
        # alias that NumPy 1.24 removed.
        with Image.open(path) as pil_image:
            im = np.asarray(pil_image, dtype=np.float64)
        if all_map is None:
            all_map = np.zeros(im.shape, dtype=np.float64)
        elif im.shape != all_map.shape:
            raise ValueError(
                "image %r has shape %s, expected %s" % (path, im.shape, all_map.shape)
            )
        all_map += im
    return all_map

def nss(pred_sal, fix_map):
    """Normalized Scanpath Saliency of a predicted map at the fixated pixels.

    The prediction is standardized to zero mean and unit standard deviation,
    then averaged over the pixels where ``fix_map`` is non-zero.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map. Resized with PIL to ``fix_map``'s shape when
        the shapes differ.
    fix_map : numpy.ndarray
        Binary fixation map; any non-zero value counts as a fixation.

    Returns
    -------
    float
        Mean standardized saliency at the fixated pixels (NaN when there are
        no fixations or the prediction is constant).
    """
    fix_map = np.asarray(fix_map).astype(bool)
    if pred_sal.shape != fix_map.shape:
        # PIL's resize expects (width, height), i.e. (cols, rows); passing the
        # numpy shape directly would swap the axes.
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    pred_sal = np.asarray(pred_sal, dtype=np.float64)
    pred_sal = (pred_sal - pred_sal.mean()) / pred_sal.std()
    return np.mean(pred_sal[fix_map])

def auc_judd(pred_sal, fix_map, jitter=True):
    """Area under the ROC curve, Judd variant.

    Every saliency value at a fixated pixel is used as a threshold. For the
    k-th highest threshold (k = 1..N, N = number of fixations) the
    true-positive rate is k / N and the false-positive rate is
    (number of pixels >= threshold - k) / (number of non-fixated pixels).
    The curve is closed with (0, 0) and (1, 1) and integrated with the
    trapezoidal rule.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``fix_map``'s shape if needed.
    fix_map : numpy.ndarray
        Binary fixation map; non-zero pixels are fixations.
    jitter : bool
        If True, add uniform noise in [0, 1e-7) (global ``np.random`` state)
        to break ties between equal saliency values.

    Returns
    -------
    float
        AUC in [0, 1]; NaN when there are no fixations or every pixel is fixated
        (the false-positive rate is undefined then).
    """
    if pred_sal.shape != fix_map.shape:
        # PIL's resize takes (width, height); numpy shapes are (rows, cols).
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    fix_map = np.asarray(fix_map).flatten().astype(bool)
    pred_sal = np.asarray(pred_sal).flatten().astype(np.float64)

    n_pixels = pred_sal.shape[0]
    n_fix = int(fix_map.sum())
    if n_fix == 0 or n_fix == n_pixels:
        return float('nan')

    if jitter:
        pred_sal += np.random.rand(n_pixels) / 1e7
    pred_sal = (pred_sal - pred_sal.min()) / (pred_sal.max() - pred_sal.min())

    # Thresholds are the fixated saliency values, highest first.
    all_thres = np.sort(pred_sal[fix_map])[::-1]
    sorted_sal = np.sort(pred_sal)
    above_thres = n_pixels - np.searchsorted(sorted_sal, all_thres, side='left')
    # ranks[k-1] == k: the k fixations at or above the k-th threshold.
    ranks = np.arange(1, n_fix + 1)
    tp = np.concatenate([[0], ranks / n_fix, [1]])
    fp = np.concatenate([[0], (above_thres - ranks) / (n_pixels - n_fix), [1]])

    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(tp, fp)

def auc_borji(pred_sal, fix_map, n_split=100, step_size=.1):
    """Area under the ROC curve, Borji variant.

    Positives are the saliency values at fixated pixels. Negatives are, for
    each of ``n_split`` splits, as many pixels as there are fixations drawn
    uniformly at random (with replacement, global ``np.random`` state) from
    the whole map. Thresholds run from the larger maximum down to 0 in steps
    of ``step_size``; the AUC is averaged over the splits.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``fix_map``'s shape if needed.
    fix_map : numpy.ndarray
        Binary fixation map; non-zero pixels are fixations.
    n_split : int
        Number of random negative sets.
    step_size : float
        Threshold spacing on the min-max normalized saliency scale.

    Returns
    -------
    float
        Mean AUC over the splits; NaN when there are no fixations.
    """
    if pred_sal.shape != fix_map.shape:
        # PIL's resize takes (width, height); numpy shapes are (rows, cols).
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))

    fix_map = np.asarray(fix_map).flatten().astype(bool)
    pred_sal = np.asarray(pred_sal).flatten().astype(np.float64)
    pred_sal = (pred_sal - pred_sal.min()) / (pred_sal.max() - pred_sal.min())
    sal_fix = pred_sal[fix_map]
    n_fix = sal_fix.shape[0]
    if n_fix == 0:
        return float('nan')
    sorted_sal_fix = np.sort(sal_fix)

    r = np.random.randint(0, pred_sal.shape[0], (n_fix, n_split))
    rand_fix = pred_sal[r]
    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = np.zeros(n_split)
    for i in range(n_split):
        cur_fix = rand_fix[:, i]
        sorted_cur_fix = np.sort(cur_fix)
        max_val = max(cur_fix.max(), sal_fix.max())
        thresholds = np.arange(0, max_val, step_size)[::-1]
        # Count of values >= each threshold, via binary search on sorted arrays.
        tp = (n_fix - np.searchsorted(sorted_sal_fix, thresholds, side='left')) / n_fix
        fp = (n_fix - np.searchsorted(sorted_cur_fix, thresholds, side='left')) / n_fix
        tp = np.concatenate([[0], tp, [1]])
        fp = np.concatenate([[0], fp, [1]])
        auc[i] = trapezoid(tp, fp)
    return np.mean(auc)

def cc(pred_sal, gt_sal):
    """Pearson linear correlation coefficient between two saliency maps.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``gt_sal``'s shape if needed.
    gt_sal : numpy.ndarray
        Continuous (ground-truth) fixation map.

    Returns
    -------
    float
        Correlation in [-1, 1]; NaN if either map is constant.
    """
    if pred_sal.shape != gt_sal.shape:
        # PIL's resize takes (width, height); numpy shapes are (rows, cols).
        # Getting this wrong leaves the flattened pixels misaligned.
        target_size = (gt_sal.shape[1], gt_sal.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    pred = np.asarray(pred_sal, dtype=np.float64).ravel()
    gt = np.asarray(gt_sal, dtype=np.float64).ravel()
    # np.corrcoef standardizes internally, so no explicit z-scoring is needed.
    return np.corrcoef(pred, gt)[0, 1]

def sim(pred_sal, gt_sal):
    """Similarity (histogram intersection) between two saliency maps.

    Both maps are min-max scaled to [0, 1], normalized to sum to 1, and the
    sum of their element-wise minimum is returned.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``gt_sal``'s shape if needed.
    gt_sal : numpy.ndarray
        Continuous (ground-truth) fixation map.

    Returns
    -------
    float
        Score in [0, 1] (1 for identical distributions); NaN if either map
        is constant.
    """
    if pred_sal.shape != gt_sal.shape:
        # PIL's resize takes (width, height); passing them in that order keeps
        # the map's orientation, so no transpose is needed afterwards.
        target_size = (gt_sal.shape[1], gt_sal.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    pred = np.asarray(pred_sal, dtype=np.float64)
    gt = np.asarray(gt_sal, dtype=np.float64)
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    pred = pred / pred.sum()
    gt = (gt - gt.min()) / (gt.max() - gt.min())
    gt = gt / gt.sum()
    return np.sum(np.minimum(pred, gt))

def kl(pred_sal, fix_map):
    """Kullback-Leibler divergence of the prediction from the fixation map.

    Both maps are normalized to sum to 1. With Q = ``fix_map`` and
    P = ``pred_sal`` the score is sum(Q * log(eps + Q / (P + eps))) using the
    natural log. Lower is better.

    NOTE(review): ``calc_all`` passes the binary fixation map here; KL is
    commonly evaluated against the continuous fixation map instead — confirm
    which one is intended.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``fix_map``'s shape if needed.
    fix_map : numpy.ndarray
        Target map (binary or continuous).

    Returns
    -------
    float
        The divergence (natural-log units).
    """
    if pred_sal.shape != fix_map.shape:
        # PIL's resize takes (width, height); numpy shapes are (rows, cols).
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    eps = np.finfo(float).eps
    pred = np.asarray(pred_sal, dtype=np.float64)
    target = np.asarray(fix_map, dtype=np.float64)
    pred = pred / pred.sum()
    target = target / target.sum()
    return np.sum(target * np.log(eps + target / (pred + eps)))

def ig(pred_sal, fix_map, base_sal):
    """Information gain of the prediction over a baseline map, in bits.

    Both maps are min-max scaled to [0, 1] and normalized to sum to 1. The
    score is the mean, over fixated pixels, of
    log2(eps + pred) - log2(eps + base).

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``fix_map``'s shape if needed.
    fix_map : numpy.ndarray
        Binary fixation map; non-zero pixels are fixations.
    base_sal : numpy.ndarray
        Baseline map (``calc_all`` passes the summed fixations of the other
        images); resized to ``fix_map``'s shape if needed.

    Returns
    -------
    float
        Mean log2 gain per fixation; NaN if there are no fixations.
    """
    # PIL's resize takes (width, height); numpy shapes are (rows, cols).
    target_size = (fix_map.shape[1], fix_map.shape[0])
    if pred_sal.shape != fix_map.shape:
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    if base_sal.shape != fix_map.shape:
        # Resize the baseline itself so it lines up with fix_map.
        base_sal = np.array(Image.fromarray(base_sal).resize(target_size))
    eps = np.finfo(float).eps
    locs = np.asarray(fix_map).flatten().astype(bool)
    pred = np.asarray(pred_sal).astype(np.float64).flatten()
    base = np.asarray(base_sal).astype(np.float64).flatten()
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    base = (base - base.min()) / (base.max() - base.min())
    pred = pred / pred.sum()
    base = base / base.sum()
    return np.mean(np.log2(eps + pred[locs]) - np.log2(eps + base[locs]))

def auc_shuffled(pred_sal, fix_map, base_map, n_split=100, step_size=.1):
    """Shuffled AUC: ROC area with negatives drawn from other images' fixations.

    Positives are the saliency values at fixated pixels. For each of
    ``n_split`` splits, min(#fixations, #candidates) negative locations are
    sampled without replacement (``random.sample``) from the pixels where
    ``base_map > 0``. Thresholds run from the larger maximum down to 0 in
    steps of ``step_size``; the AUC is averaged over the splits.

    Parameters
    ----------
    pred_sal : numpy.ndarray
        Predicted saliency map; resized to ``fix_map``'s shape if needed.
    fix_map : numpy.ndarray
        Binary fixation map; non-zero pixels are fixations.
    base_map : numpy.ndarray
        Map whose positive pixels are the candidate negative locations; must
        have ``fix_map``'s shape.
    n_split : int
        Number of random negative sets.
    step_size : float
        Threshold spacing on the min-max normalized saliency scale.

    Returns
    -------
    float
        Mean AUC over the splits; NaN when there are no fixations or no
        candidate negatives.
    """
    if pred_sal.shape != fix_map.shape:
        # scipy.misc.imresize was removed in SciPy 1.3; PIL's resize takes
        # (width, height) while numpy shapes are (rows, cols).
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
    assert base_map.shape == fix_map.shape
    pred_sal = np.asarray(pred_sal).flatten().astype(np.float64)
    base_map = np.asarray(base_map).flatten().astype(np.float64)
    fix_map = np.asarray(fix_map).flatten().astype(bool)
    pred_sal = (pred_sal - pred_sal.min()) / (pred_sal.max() - pred_sal.min())
    sal_fix = pred_sal[fix_map]
    sorted_sal_fix = np.sort(sal_fix)
    # Built once instead of once per split.
    other_locs = np.where(base_map > 0)[0].tolist()
    n_fix = sal_fix.shape[0]
    n_fix_oth = min(n_fix, len(other_locs))
    if n_fix == 0 or n_fix_oth == 0:
        return float('nan')

    rand_fix = np.zeros((n_fix_oth, n_split))
    for i in range(n_split):
        rand_ind = random.sample(other_locs, n_fix_oth)
        rand_fix[:, i] = pred_sal[rand_ind]

    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = np.zeros(n_split)
    for i in range(n_split):
        cur_fix = rand_fix[:, i]
        sorted_cur_fix = np.sort(cur_fix)
        max_val = max(cur_fix.max(), sal_fix.max())
        thresholds = np.arange(0, max_val, step_size)[::-1]
        # Count of values >= each threshold, via binary search on sorted arrays.
        tp = (n_fix - np.searchsorted(sorted_sal_fix, thresholds, side='left')) / n_fix
        fp = (n_fix_oth - np.searchsorted(sorted_cur_fix, thresholds, side='left')) / n_fix_oth
        tp = np.concatenate([[0], tp, [1]])
        fp = np.concatenate([[0], fp, [1]])
        auc[i] = trapezoid(tp, fp)
    return np.mean(auc)

def get_image_ext(folder):
    """Return the most common file extension among the entries of ``folder``.

    Extensions keep their leading dot (as produced by ``os.path.splitext``);
    entries without one count as ``''``. On a tie, the extension encountered
    first in ``os.listdir`` order wins.

    Prints a message and fails an assertion when the folder is empty.
    """
    entries = os.listdir(folder)
    if not entries:
        print(folder, " doesn't include any images")
        assert len(entries) > 0
    tally = {}
    for entry in entries:
        suffix = os.path.splitext(entry)[-1]
        tally[suffix] = tally.get(suffix, 0) + 1
    if not tally:
        return ''
    # max() keeps the first maximal key in insertion order, which is the same
    # winner a strict ">" scan over the counts would pick.
    return max(tally, key=tally.get)

def check_image_exist(image_infos):
    """Verify that every referenced image file exists on disk.

    Parameters
    ----------
    image_infos : iterable of sequence of str
        One item per image; only the first three paths of each item are
        checked (``calc_all`` builds them as predicted map, continuous map,
        binary map).

    Raises
    ------
    AssertionError
        On the first missing path, after printing the offending item.
    """
    for image_info in image_infos:
        for image_path in image_info[:3]:
            exists = os.path.exists(image_path)
            if not exists:
                print("image does not exist")
                print(image_info)
            # Bare condition: `assert(cond, msg)` asserts a non-empty tuple
            # and can never fail.
            assert exists, "can't find image:" + image_path


def calc_all(pred_sal_folder, gt_sal_folder, binary_folder, base='binary', max_images=698):
    """Compute every saliency metric per image and print the per-metric means.

    Images are matched across the three folders by file stem; each folder's
    extension is the most common one reported by ``get_image_ext``.

    Parameters
    ----------
    pred_sal_folder : str
        Folder of predicted saliency maps.
    gt_sal_folder : str
        Folder of continuous (ground-truth) fixation maps.
    binary_folder : str
        Folder of binary fixation maps.
    base : {'binary', 'pred_sal', 'gt_sal'}
        Which folder's listing defines the set of image names evaluated.
    max_images : int or None
        Evaluate at most this many images (sorted by path). ``None`` evaluates
        all of them. The default keeps the original hard-coded cap of 698.

    Returns
    -------
    collections.OrderedDict
        Metric name -> mean score over the evaluated images (also printed).
    """
    assert base in ['binary', 'pred_sal', 'gt_sal']

    pred_sal_ext = get_image_ext(pred_sal_folder)
    gt_sal_ext = get_image_ext(gt_sal_folder)
    binary_ext = get_image_ext(binary_folder)

    base_folder, base_ext = {
        'binary': (binary_folder, binary_ext),
        'pred_sal': (pred_sal_folder, pred_sal_ext),
        'gt_sal': (gt_sal_folder, gt_sal_ext),
    }[base]
    selected_images = sorted(glob.glob(os.path.join(base_folder, "*" + base_ext)))
    assert len(selected_images) > 0
    if max_images is not None:
        selected_images = selected_images[:max_images]

    selected_names = [os.path.splitext(os.path.basename(x))[0] for x in selected_images]
    pred_sals = [os.path.join(pred_sal_folder, x + pred_sal_ext) for x in selected_names]
    gt_sals = [os.path.join(gt_sal_folder, x + gt_sal_ext) for x in selected_names]
    binarys = [os.path.join(binary_folder, x + binary_ext) for x in selected_names]

    # A list rather than a bare zip iterator, because it is traversed twice.
    image_infos = list(zip(pred_sals, gt_sals, binarys))
    check_image_exist(image_infos)

    metrics = ['NSS', 'AUC_Judd', 'AUC_Borji', 'sAUC', 'CC', 'SIM', 'KL', 'IG']
    res = OrderedDict((metric, []) for metric in metrics)

    # Sum of all binary fixation maps; subtracting the current image's map
    # leaves the fixations of the other images, used as the shuffled-AUC
    # negative set and as the information-gain baseline.
    all_map = calc_all_maps(binarys)

    for pred_path, gt_path, fix_path in tqdm(image_infos, desc="evaluating"):
        with Image.open(fix_path) as im:
            fix_map = np.array(im)
        with Image.open(pred_path) as im:
            pred_sal = np.array(im)
        with Image.open(gt_path) as im:
            gt_sal = np.array(im)

        # PIL's resize takes (width, height) while numpy shapes are
        # (rows, cols). Passing them in that order keeps every map in the same
        # orientation as fix_map, so no transpose is needed.
        target_size = (fix_map.shape[1], fix_map.shape[0])
        pred_sal = np.array(Image.fromarray(pred_sal).resize(target_size))
        gt_sal = np.array(Image.fromarray(gt_sal).resize(target_size))
        other_map = all_map - fix_map

        res['NSS'].append(nss(pred_sal, fix_map))
        res['AUC_Judd'].append(auc_judd(pred_sal, fix_map))
        res['AUC_Borji'].append(auc_borji(pred_sal, fix_map))
        res['sAUC'].append(auc_shuffled(pred_sal, fix_map, other_map))
        res['CC'].append(cc(pred_sal, gt_sal))
        res['SIM'].append(sim(pred_sal, gt_sal))
        # NOTE(review): KL is commonly evaluated against the continuous map
        # (gt_sal); the binary map is used here — confirm this is intended.
        res['KL'].append(kl(pred_sal, fix_map))
        res['IG'].append(ig(pred_sal, fix_map, other_map))

    means = OrderedDict(
        (metric, np.mean(np.array(values))) for metric, values in res.items()
    )
    for metric, value in means.items():
        print(metric, value)
    return means


In [None]:
def calc_metrics(dataset_root="/content/drive/My Drive/NNFL Project/Dataset/Unzipped"):
    """Run ``calc_all`` on the project's prediction and fixation-map folders.

    Parameters
    ----------
    dataset_root : str
        Folder containing ``PredSalMaps2`` and ``Fixation_Maps``. Defaults to
        the project's location on the mounted Google Drive.

    Returns
    -------
    The value returned by ``calc_all``.
    """
    fixation_root = os.path.join(dataset_root, "Fixation_Maps", "FixationMap")
    pred_sal_folder = os.path.join(dataset_root, "PredSalMaps2")
    # "Continous_map" is the folder's actual (misspelled) name on disk.
    gt_sal_folder = os.path.join(fixation_root, "Continous_map")
    binary_folder = os.path.join(fixation_root, "Binary_map")
    return calc_all(pred_sal_folder, gt_sal_folder, binary_folder, base='binary')


# The original cell called the misspelled `cacl_metrics()` (NameError).
metric_means = calc_metrics()

698
0
(768, 1024)
(768, 1024)
1
(768, 1024)
(768, 1024)
2
(768, 1024)
(768, 1024)
3
(768, 1024)
(768, 1024)
4
(768, 1024)
(768, 1024)
5
(768, 1024)
(768, 1024)
6
(768, 1024)
(768, 1024)
7
(768, 1024)
(768, 1024)
8
(768, 1024)
(768, 1024)
9
(768, 1024)
(768, 1024)
10
(768, 1024)
(768, 1024)
11
(768, 1024)
(768, 1024)
12
(768, 1024)
(768, 1024)
13
(768, 1024)
(768, 1024)
14
(768, 1024)
(768, 1024)
15
(768, 1024)
(768, 1024)
16
(768, 1024)
(768, 1024)
17
(768, 1024)
(768, 1024)
18
(768, 1024)
(768, 1024)
19
(768, 1024)
(768, 1024)
20
(768, 1024)
(768, 1024)
21
(768, 1024)
(768, 1024)
22
(768, 1024)
(768, 1024)
23
(768, 1024)
(768, 1024)
24
(768, 1024)
(768, 1024)
25
(768, 1024)
(768, 1024)
26
(768, 1024)
(768, 1024)
27
(768, 1024)
(768, 1024)
28
(768, 1024)
(768, 1024)
29
(768, 1024)
(768, 1024)
30
(768, 1024)
(768, 1024)
31
(768, 1024)
(768, 1024)
32
(768, 1024)
(768, 1024)
33
(768, 1024)
(768, 1024)
34
(768, 1024)
(768, 1024)
35
(768, 1024)
(768, 1024)
36
(768, 1024)
(768, 1024)
37
(768