In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from PIL.Image import open as imopen
from cv2 import resize

from model import *
from utils import compute_saliency
from dataloader import ImageDataSet

# Selection parameters:
#   n_images           - how many images to take from each end of the list
#   distance_from_ends - how many rows to skip at each end before taking them
#   image_size         - resolution passed to ImageDataSet / compute_saliency
n_images = 100
distance_from_ends = 0
image_size = 64
# NOTE(review): "top"/"bottom" below assume ImageList.csv is already sorted by
# predicted rank - confirm against whatever produces this file.
test_result_df = pd.read_csv('Data/ImageList.csv')
images = test_result_df['ImageFile']

# Rows [d, d + n): the first n images after skipping d rows from the start.
# reset_index(drop=True) returns a fresh 0..k-1 indexed frame instead of
# mutating an iloc slice of test_result_df in place.
top_images = test_result_df.iloc[
    distance_from_ends:distance_from_ends + n_images
].reset_index(drop=True)

# Rows [len - d - n, len - d): the last n images after skipping d rows from the
# end. Bounds are explicit non-negative positions because the previous
# negative-start slice returned the WHOLE frame when n_images +
# distance_from_ends == 0 (-0 == 0), and picked rows from the front of the
# frame when distance_from_ends > len(images).
bottom_stop = max(len(images) - distance_from_ends, 0)
bottom_start = max(bottom_stop - n_images, 0)
bottom_images = test_result_df.iloc[bottom_start:bottom_stop].reset_index(drop=True)
# Load the trained rank-prediction network and run everything on the CPU.
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code - only load trusted files (torch >= 1.13 supports
# weights_only=True for plain state dicts).
device = torch.device('cpu')
model = resnet18(pretrained=False)
model.load_state_dict(torch.load('RankPrediction-model.pkl', map_location=device))
# Switch to inference mode so any BatchNorm/Dropout layers behave as at test
# time while gradients are taken for the saliency maps. eval() is idempotent,
# so this is harmless if compute_saliency already calls it (not visible here).
model.eval()


def _make_loader(frame):
    """Build an un-augmented ImageDataSet over *frame* and a DataLoader for it.

    shuffle=False keeps batches in row order; presumably compute_saliency
    returns maps in that same order (confirm), so map n matches row n.

    Returns:
        (dataset, dataloader)
    """
    ds = ImageDataSet(frame, image_size, device, do_augmentation=False)
    loader = DataLoader(ds,
                        batch_size=20,
                        num_workers=0,
                        shuffle=False,
                        pin_memory=False)
    return ds, loader


# `dataset` / `dataloader` remain module-level names (ending on the bottom set,
# as before) in case later cells use them.
dataset, dataloader = _make_loader(top_images)
top_maps = compute_saliency(dataloader, model, image_size)
dataset, dataloader = _make_loader(bottom_images)
bottom_maps = compute_saliency(dataloader, model, image_size)

raw_image_dir = 'Data/raw-img/'


def _plot_image_and_saliency(fig, n_rows, first_col, image_file, saliency_map):
    """Draw a raw image and its saliency map side by side on *fig*.

    The image goes into cell *first_col* of an (n_rows x 4) subplot grid and
    the saliency map, resized to the image's width/height, into cell
    *first_col* + 1.
    """
    # Read the image once with plt.imread instead of also opening it through
    # PIL's Image.open just to get its shape. Image.open is lazy and keeps the
    # file handle open until the image is closed, which the old code never did.
    raw = plt.imread(raw_image_dir + image_file)
    height, width = raw.shape[:2]
    ax = fig.add_subplot(n_rows, 4, first_col)
    ax.imshow(raw, interpolation='nearest')

    # cv2.resize takes dsize as (width, height).
    # NOTE(review): assumes each saliency map is a NumPy array cv2 can resize.
    ax = fig.add_subplot(n_rows, 4, first_col + 1)
    ax.imshow(resize(saliency_map, (width, height)), interpolation='nearest')


# Each row shows: top image | top saliency | bottom image | bottom saliency.
# Never index past the rows actually selected, even if n_images exceeds them.
n_rows = min(n_images, len(top_images), len(bottom_images))
# NOTE(review): 8.8 inches per row gives a very tall figure (880 in for 100 rows).
fig = plt.figure(figsize=(25, int(8.8 * max(n_rows, 1))))
for n in range(n_rows):
    first_col = 4 * n + 1
    _plot_image_and_saliency(fig, n_rows, first_col,
                             top_images['ImageFile'].iloc[n], top_maps[n])
    _plot_image_and_saliency(fig, n_rows, first_col + 2,
                             bottom_images['ImageFile'].iloc[n], bottom_maps[n])
# plt.show() works with the inline backend. fig.show() can warn that the
# backend is non-interactive.
plt.show()