In [2]:
#Import Libraries
import numpy as np
import cv2
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.stats import pearsonr


In [3]:
# Saliency evaluation metrics:
#   NSS - Normalized Scanpath Saliency (prediction vs. fixation map)
#   CC  - Pearson linear correlation coefficient (prediction vs. ground-truth map)
def normalize_map(s_map):
    """Standardize a saliency map to zero mean and unit standard deviation.

    ``1e-8`` is added to the standard deviation so that a constant map yields
    zeros instead of a division by zero.
    """
    return (s_map - np.mean(s_map)) / (np.std(s_map) + 1e-8)

def compute_nss(pred_map, fixation_map):
    """Normalized Scanpath Saliency of a predicted map against a fixation map.

    The prediction is standardized with ``normalize_map`` and its values are
    averaged over every pixel where ``fixation_map > 0``.

    Parameters
    ----------
    pred_map : array-like
        Predicted saliency map.
    fixation_map : array-like
        Fixation map with the same shape as ``pred_map``; any value > 0 counts
        as a fixated pixel.

    Returns
    -------
    scalar
        The NSS score, or ``nan`` when ``fixation_map`` contains no fixations.

    Raises
    ------
    ValueError
        If ``pred_map`` and ``fixation_map`` do not have the same shape.
    """
    pred_map = np.asarray(pred_map)
    fixated = np.asarray(fixation_map) > 0
    if pred_map.shape != fixated.shape:
        raise ValueError(
            f"pred_map shape {pred_map.shape} does not match "
            f"fixation_map shape {fixated.shape}"
        )
    if not fixated.any():
        # The mean over an empty selection is undefined; return nan explicitly
        # instead of letting np.mean emit a "Mean of empty slice" warning.
        return np.nan
    return np.mean(normalize_map(pred_map)[fixated])

def compute_cc(pred_map, gt_map):
    """Pearson correlation coefficient (CC) between a predicted and a ground-truth map.

    Both maps are flattened and compared pixel by pixel, so they must have the
    same shape: flattening maps of different shapes but equal size would pair
    unrelated pixels. Per the SciPy docs, ``pearsonr`` warns and returns nan
    when either map is constant.

    Raises
    ------
    ValueError
        If ``pred_map`` and ``gt_map`` do not have the same shape.
    """
    pred_map = np.asarray(pred_map)
    gt_map = np.asarray(gt_map)
    if pred_map.shape != gt_map.shape:
        raise ValueError(
            f"pred_map shape {pred_map.shape} does not match gt_map shape {gt_map.shape}"
        )
    return pearsonr(pred_map.ravel(), gt_map.ravel())[0]

In [4]:
# Locations of the stimuli, ground-truth saliency maps, DeepGaze IIE predicted
# saliency maps and fixation maps, plus the CSV file the metrics are written to.
# Every map is looked up by the file name of the corresponding prediction.
raw_dir = 'Stimuli/'
pred_dir = 'DeepGazeIIE_saliencymaps/'
gt_dir = 'GTSaliencymaps/'
fix_dir = 'Fixation-maps/'
output_csv = 'metric_results.csv'

# Per-image metric rows, filled by the evaluation loop.
results = []

# os.listdir also returns hidden entries (e.g. '.DS_Store', '.ipynb_checkpoints')
# and subdirectories; cv2.imread cannot read those, so keep only visible files.
file_list = sorted(
    name
    for name in os.listdir(pred_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(pred_dir, name))
)

In [None]:
# Compute NSS for every predicted saliency map and collect one row per image.
os.makedirs("visualizations", exist_ok=True)

def _read_image(path, flags=cv2.IMREAD_COLOR):
    """Read an image with OpenCV, raising FileNotFoundError instead of returning None."""
    image = cv2.imread(path, flags)
    if image is None:
        # cv2.imread reports a missing or undecodable file by returning None,
        # which would otherwise surface later as an opaque AttributeError.
        raise FileNotFoundError(f"Could not read image: {path}")
    return image

# Start from an empty list so re-running this cell does not append duplicate rows.
results = []

for file_name in tqdm(file_list):
    # Every prediction must have a stimulus, GT map and fixation map with the same file name.
    raw = _read_image(os.path.join(raw_dir, file_name))
    pred = _read_image(os.path.join(pred_dir, file_name), cv2.IMREAD_GRAYSCALE).astype(np.float32)
    gt = _read_image(os.path.join(gt_dir, file_name), cv2.IMREAD_GRAYSCALE).astype(np.float32)
    fix = _read_image(os.path.join(fix_dir, file_name), cv2.IMREAD_GRAYSCALE).astype(np.float32)

    # NSS indexes the prediction with the fixation mask, so both must share a shape.
    # Nearest-neighbour resizing copies existing pixel values instead of blending them.
    fix_resized = fix
    if pred.shape != fix.shape:
        fix_resized = cv2.resize(fix, (pred.shape[1], pred.shape[0]), interpolation=cv2.INTER_NEAREST)

    nss = compute_nss(pred, fix_resized)

    results.append({'Image': file_name, 'NSS': nss})

100%|██████████| 9/9 [00:00<00:00, 36.91it/s]


In [None]:
# Save a side-by-side figure (stimulus, ground truth, prediction with its NSS)
# for every evaluated image into the "visualizations" folder.
os.makedirs("visualizations", exist_ok=True)

for row in results:
    file_name = row['Image']
    raw = cv2.imread(os.path.join(raw_dir, file_name))
    gt = cv2.imread(os.path.join(gt_dir, file_name), cv2.IMREAD_GRAYSCALE)
    pred = cv2.imread(os.path.join(pred_dir, file_name), cv2.IMREAD_GRAYSCALE)
    if raw is None or gt is None or pred is None:
        # cv2.imread returns None for missing or undecodable files.
        raise FileNotFoundError(f"Could not read all images for {file_name}")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        # OpenCV loads colour images as BGR; matplotlib expects RGB.
        (cv2.cvtColor(raw, cv2.COLOR_BGR2RGB), None, 'Raw Image'),
        (gt.astype(np.float32), 'gray', 'Groundtruth'),
        (pred.astype(np.float32), 'jet', f"Prediction NSS={row['NSS']:.2f}"),
    ]
    for ax, (image, cmap, title) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    save_path = os.path.join("visualizations", f"{os.path.splitext(file_name)[0]}.png")
    fig.savefig(save_path, dpi=150)
    # Close each figure so memory does not grow with the number of images.
    plt.close(fig)


In [8]:
# Gather the per-image metric rows into a DataFrame, write them to `output_csv`
# without the index column, and preview the first rows.
df = pd.DataFrame(data=results)
df.to_csv(path_or_buf=output_csv, index=False)
df.head(n=5)

Unnamed: 0,Image,NSS
0,High(i).png,0.702362
1,High(ii).png,1.751143
2,High(iii).png,1.540396
3,Low(i).png,2.897657
4,Low(ii).png,3.727693
