In [1]:
import os
import torch
import cv2
import numpy as np
from torchvision import transforms, utils, models
import torch.nn as nn
from tqdm import tqdm
from utils.data_process import preprocess_img, postprocess_img
from PIL import Image

# Prefer the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Model selector: 0 for TranSalNet_Dense, 1 for TranSalNet_Res
flag = 1

  from .autonotebook import tqdm as notebook_tqdm


↑↑↑  Set **flag=0** to load *TranSalNet_Dense*, set **flag=1** to load *TranSalNet_Res*. <br>
<br>
↓↓↓  Load the model and pre-trained parameters.<br>

In [2]:
# Pick the model variant and its checkpoint according to `flag`
# (truthy -> TranSalNet_Res, falsy -> TranSalNet_Dense).
if flag:
    from TranSalNet_Res import TranSalNet
    weights_path = 'Ar_sal_model_most_new.pth'
else:
    from TranSalNet_Dense import TranSalNet
    # os.path.join instead of a hard-coded backslash so the path also resolves on POSIX.
    weights_path = os.path.join('pretrained_models', 'TranSalNet_Dense.pth')

model = TranSalNet()
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))

model = model.to(device)
model.eval()

TranSalNet(
  (encoder): _Encoder(
    (encoder): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
      

In [3]:
from glob import glob

# glob() returns paths in arbitrary (filesystem) order. The inference loop pairs
# data[i] with data_map[i] by index, so sort both lists to make the pairing
# deterministic.
# NOTE(review): this assumes each stimulus and its text map share the same file
# name in their respective folders -- confirm against the dataset layout.
data = sorted(glob("val/val_stimuli/*.jpg"))
data_map = sorted(glob("val/val_text_map/*.jpg"))

In [4]:
# Run the model on every validation stimulus together with its text map and
# save each predicted saliency map under result_val/ using the stimulus' file name.
if len(data) != len(data_map):
    raise ValueError(
        "stimuli and text maps cannot be paired by index: "
        "{} stimuli vs {} text maps".format(len(data), len(data_map))
    )

output_dir = 'result_val'
os.makedirs(output_dir, exist_ok=True)  # cv2.imwrite does not create missing directories

toPIL = transforms.ToPILImage()  # loop-invariant, so build it once

with torch.no_grad():  # inference only: skip building the autograd graph
    for test_img, test_map in zip(data, data_map):
        name = os.path.basename(test_img)  # portable replacement for split("/")[-1]

        # Pad and resize both inputs (the target size is decided by preprocess_img).
        img = preprocess_img(test_img)
        tmap = preprocess_img(test_map)

        # Scale to [0, 1] and reorder HWC -> 1xCxHxW.
        img = np.expand_dims(np.transpose(np.array(img) / 255., (2, 0, 1)), axis=0)
        tmap = np.expand_dims(np.transpose(np.array(tmap) / 255., (2, 0, 1)), axis=0)

        # .float() works on both CPU and GPU; torch.cuda.FloatTensor fails without CUDA.
        img = torch.from_numpy(img).float().to(device)
        tmap = torch.from_numpy(tmap).float().to(device)

        pred_saliency = model(img, tmap)
        pic = toPIL(pred_saliency.squeeze().cpu())
        pred_saliency = postprocess_img(pic, test_img)  # restore the image to its original size as the result

        out_path = os.path.join(output_dir, name)
        # cv2.imwrite reports failure through its return value instead of raising.
        if not cv2.imwrite(out_path, pred_saliency, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
            raise IOError("could not write saliency map to {}".format(out_path))

↓↓↓ Get the test image, feed it into the model, and get a result.

In [None]:

# Single-image demo: preprocess one stimulus, run the model on it and save the
# predicted saliency map, restored to the original image size, as result.png.
test_img = r'dataset_ECdata/val/val_stimuli/103.jpg'

img = preprocess_img(test_img)  # padding and resizing (target size is decided by preprocess_img)
print(img.shape)
img = np.array(img)/255.
img = np.expand_dims(np.transpose(img,(2,0,1)),axis=0)  # HWC -> 1xCxHxW
# .float() works on both CPU and GPU; torch.cuda.FloatTensor fails without CUDA.
img = torch.from_numpy(img).float().to(device)

# NOTE(review): the batch loop in this notebook calls model(img, tmap); this cell
# passes the image only. Confirm the selected model variant accepts a single input,
# or add the matching text map here.
with torch.no_grad():  # inference only: skip building the autograd graph
    pred_saliency = model(img)
toPIL = transforms.ToPILImage()
pic = toPIL(pred_saliency.squeeze().cpu())

pred_saliency = postprocess_img(pic, test_img) # restore the image to its original size as the result

# The output is a PNG, so the JPEG quality flag had no effect and is dropped.
# cv2.imwrite reports failure through its return value instead of raising.
if not cv2.imwrite(r'result.png', pred_saliency):
    raise IOError('could not write {}'.format(r'result.png'))
print('Finished, check the result at: {}'.format(r'result.png'))