In [None]:
import os
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import numpy as np

In [None]:
# Location of the serialized state dict that is passed to torch.load when the model is restored.
model_path = '/kaggle/input/glimpse_saliency_v1/pytorch/v1/1/saliency_model_v11.pt'

In [None]:
class ResBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    The inner branch is two (reflection-pad -> 3x3 conv -> instance-norm)
    steps with a ReLU between them; its result is added to the block input.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        # The order of the layers fixes the indices used in state_dict keys
        # ("resblock.1.weight", ...), so it must stay exactly like this.
        branch = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
        ]
        self.resblock = nn.Sequential(*branch)

    def forward(self, x):
        """Return the branch output plus the unchanged input (skip connection)."""
        branch_out = self.resblock(x)
        return branch_out + x

In [None]:
class Upsample2d(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` in 'nearest' mode.

    Args:
        scale_factor: multiplier forwarded to ``interpolate``.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.interp = nn.functional.interpolate

    def forward(self, x):
        """Return ``x`` resized by ``scale_factor`` with nearest-neighbour sampling."""
        return self.interp(x, scale_factor=self.scale_factor, mode='nearest')

In [None]:
class MicroResNet(nn.Module):
    """Small encoder / residual / decoder network producing a one-channel map.

    Stages, in the order they are applied:
      * ``downsampler`` - three pad/conv/instance-norm/ReLU groups with strides
        4, 2 and 2, taking 3 input channels to 32.
      * ``residual``    - ResBlock(32), a grouped 1x1 conv (32 -> 64, no bias,
        groups=32), then ResBlock(64).
      * ``segmentator`` - pad/conv/instance-norm/ReLU (64 -> 16), 2x nearest
        upsampling, then a padded 9x9 conv to 1 channel followed by a sigmoid.

    The layer order inside every ``nn.Sequential`` determines the state_dict
    key names, so it must not be rearranged.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch, kernel_size, stride=1):
        """Build [reflection pad, conv, affine instance norm, ReLU] as a list.

        The padding is ``kernel_size // 2`` (4 for a 9x9 kernel, 1 for 3x3).
        """
        return [
            nn.ReflectionPad2d(kernel_size // 2),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride),
            nn.InstanceNorm2d(out_ch, affine=True),
            nn.ReLU(),
        ]

    def __init__(self):
        super().__init__()

        encoder_layers = (
            self._conv_stage(3, 8, kernel_size=9, stride=4)
            + self._conv_stage(8, 16, kernel_size=3, stride=2)
            + self._conv_stage(16, 32, kernel_size=3, stride=2)
        )
        self.downsampler = nn.Sequential(*encoder_layers)

        self.residual = nn.Sequential(
            ResBlock(32),
            # grouped 1x1 conv: each of the 32 input channels feeds 2 output channels
            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
            ResBlock(64),
        )

        decoder_layers = self._conv_stage(64, 16, kernel_size=3) + [
            Upsample2d(scale_factor=2),
            nn.ReflectionPad2d(4),
            nn.Conv2d(16, 1, kernel_size=9),
            nn.Sigmoid(),
        ]
        self.segmentator = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` through downsampler, residual and segmentator in turn."""
        features = self.residual(self.downsampler(x))
        return self.segmentator(features)

In [None]:
# Instantiate the network and restore its parameters from `model_path`.
model = MicroResNet()
# map_location='cpu': the model is used on CPU here, and without this a
# checkpoint that was saved from a GPU cannot be loaded on a CPU-only runtime.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
# Switch to inference mode; as the last expression this also displays the module tree.
model.eval()

In [None]:
# Saliency-guided crop demo.
# 1. run the model on one resized test image,
# 2. sharpen the saliency map into a probability distribution (temperature),
# 3. take the distribution's centre of mass as the point of interest,
# 4. draw / cut a fixed-size window around that point and save the crop.

path = '/kaggle/input/test-images/la-so-vk4vjTNVrTg-unsplash.jpg'
# Alternative test images:
# path = '/kaggle/input/test-images/alex-azabache-x5KZMQ_RPXc-unsplash.jpg'
# path = '/kaggle/input/test-images/mieke-campbell-2AVb8vBKAPA-unsplash.jpg'
# path = '/kaggle/input/test-images/shifaaz-shamoon-qtbV_8P_Ksk-unsplash.jpg'
# path = '/kaggle/input/test-images/woody-yan-TkaOhOFPKdM-unsplash.jpg'

# --- load -------------------------------------------------------------------
# convert('RGB') guarantees 3 channels; grayscale / RGBA / palette files would
# otherwise not match the network's first Conv2d(3, ...).
image = transforms.Resize(240)(Image.open(path).convert('RGB'))
image_np = np.array(image)                      # H x W x 3, RGB, uint8
img_h, img_w = image_np.shape[:2]
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
print(image_tensor.shape)
plt.imshow(image_tensor[0].permute(1, 2, 0))
plt.show()

# --- predict ----------------------------------------------------------------
# Only the forward pass needs no_grad; everything after it works on detached data.
with torch.no_grad():
    preds = model(image_tensor)[0]
print(preds.shape)

# --- temperature-sharpened distribution ---------------------------------------
# softmax(log_softmax(p) / T) is the same quantity as
# exp(log(softmax(p)) / T) normalised to sum 1, computed in a numerically safer way.
pred_h, pred_w = preds.size(1), preds.size(2)
temperature = 0.25
tempered_pred = torch.softmax(
    torch.log_softmax(preds[0].reshape(-1), dim=0) / temperature, dim=0
)
pred = tempered_pred.view(pred_h, pred_w)
plt.imshow(pred, cmap='Greys_r')
plt.show()

# --- centre of mass (vectorised; replaces a Python loop over every pixel) -----
col_idx = torch.arange(pred_w, dtype=pred.dtype)
row_idx = torch.arange(pred_h, dtype=pred.dtype)
center_x = (pred.sum(dim=0) * col_idx).sum().item()   # expected column index
center_y = (pred.sum(dim=1) * row_idx).sum().item()   # expected row index
print("CENTER", center_x, center_y)

# --- map prediction-grid coordinates to image pixels ------------------------
# The network shrinks both axes by the same factor, so the scale is derived
# from the actual sizes instead of separate hard-coded constants per axis.
scale_x = img_w / pred_w
scale_y = img_h / pred_h
center = (int(center_x * scale_x), int(center_y * scale_y))

# cv2 drawing functions write into the array they receive, hence the copies.
circle = cv2.circle(image_np.copy(), center, 50, (0, 255, 0), 3)
plt.imshow(circle)
plt.show()

# --- crop window ---------------------------------------------------------------
height = 100
width = 100
crop_w = min(width, img_w)
crop_h = min(height, img_h)
# x uses center_x / width and y uses center_y / height; the window is then
# clamped so it lies fully inside the image (a negative start index would
# otherwise wrap around when slicing and produce a wrong or empty crop).
left = int(center_x * scale_x - width / 2)
top = int(center_y * scale_y - height / 2)
left = min(max(left, 0), img_w - crop_w)
top = min(max(top, 0), img_h - crop_h)
top_left = (left, top)
bottom_right = (left + crop_w, top + crop_h)
print(top_left, bottom_right)

rectangle = cv2.rectangle(image_np.copy(), top_left, bottom_right, (0, 255, 0), 3)
plt.imshow(rectangle)
plt.show()

crop_img = np.ascontiguousarray(image_np[top:top + crop_h, left:left + crop_w])
plt.imshow(crop_img)
plt.show()

# cv2.imwrite interprets 3-channel arrays as BGR while PIL delivers RGB, so the
# channels are swapped before saving to keep the colours correct on disk.
cv2.imwrite('/kaggle/working/crop.jpg',
            cv2.cvtColor(crop_img.astype(np.uint8), cv2.COLOR_RGB2BGR))