In [None]:
import os

import PIL
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchbearer import Trial
from torchvision import transforms

In [None]:
# Paths and image transforms.
path_image = 'images/image.png'
path_gt = 'gt/gt_gau.png'

BACKBONE_INPUT_SIZE = 224  # resolution fed to the VGG16 / ResNet-152 backbones (also used for display)
POSENET_INPUT_SIZE = 28    # resolution fed to the image-only PosENet (and fm_resize's default output size)
POSENET_KERNEL_SIZE = 3    # PosENet_eval's conv kernel; the conv uses padding=0
# A padding-0 k x k conv shrinks each side by k - 1, so the predicted map (and hence
# the GT it is compared against) is 28 - 3 + 1 = 26 pixels wide.
gt_size = POSENET_INPUT_SIZE - POSENET_KERNEL_SIZE + 1

# ImageNet channel statistics used to normalize inputs for the torchvision pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Backbone input: resize, convert to a float tensor, normalize.
transform_image = transforms.Compose([
    transforms.Resize((BACKBONE_INPUT_SIZE, BACKBONE_INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Image-only PosENet input (the prediction cell multiplies the result by 255).
transform = transforms.Compose([
    transforms.Resize((POSENET_INPUT_SIZE, POSENET_INPUT_SIZE)),
    transforms.ToTensor(),
])

# Display: same resize as the backbone input but without normalization, so it renders as-is.
transform_display = transforms.Compose([
    transforms.Resize((BACKBONE_INPUT_SIZE, BACKBONE_INPUT_SIZE)),
    transforms.ToTensor(),
])

# Ground-truth position map resized to the prediction resolution.
transform_gt = transforms.Compose([
    transforms.Resize((gt_size, gt_size)),
    transforms.ToTensor(),
])

In [None]:
# Pretrained feed-forward convolutional encoders (torchvision ImageNet weights).
# They are only used for inference:
# - eval() makes ResNet's BatchNorm layers use their running statistics;
# - requires_grad_(False) stops autograd from recording every forward pass through
#   these large networks (which otherwise keeps all activations alive).
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=...`; kept here for compatibility with older torchvision versions.
vgg = torchvision.models.vgg16(pretrained=True).eval().requires_grad_(False)
resnet = torchvision.models.resnet152(pretrained=True).eval().requires_grad_(False)

In [None]:
# Multi-layer feature extractors for the pretrained VGG16 / ResNet-152 backbones

def fm_resize(feature, out_channels=3, out_size=28):
    """Resample one feature map to a fixed (out_channels, out_size, out_size) volume.

    The channel axis is treated as a third spatial axis and trilinearly
    interpolated together with H and W. Feature maps with any channel count
    and any resolution therefore become same-sized tensors that can be stacked
    and fed to PosENet_eval.

    Args:
        feature: tensor of shape (1, C, H, W), i.e. a single image's feature map.
        out_channels: number of channels after resampling (default 3).
        out_size: output height and width (default 28).

    Returns:
        Tensor of shape (out_channels, out_size, out_size).

    Raises:
        ValueError: if ``feature`` is not a 4-D tensor with batch size 1.
    """
    if feature.dim() != 4 or feature.shape[0] != 1:
        raise ValueError(
            f"expected a (1, C, H, W) feature map, got shape {tuple(feature.shape)}")
    # (1, C, H, W) -> (1, 1, C, H, W): a one-channel 3-D volume, as trilinear mode needs 5-D input.
    volume = feature.unsqueeze(0)
    resized = F.interpolate(volume, size=[out_channels, out_size, out_size],
                            mode="trilinear", align_corners=False)
    # Drop the two singleton leading dims -> (out_channels, out_size, out_size).
    return resized[0, 0]

def extractor_vgg(image, layer_ids=(5, 10, 17, 24, 31)):
    """Stack resized VGG16 feature maps taken at several depths.

    For every ``i`` in ``layer_ids`` this yields the same tensor as running
    ``vgg.features[0:i]`` on ``image`` and passing the output to ``fm_resize``.
    The feature trunk is run only once, and the intermediate activations are
    captured instead of recomputing each prefix from scratch. In torchvision's
    VGG16 layout the default ids are the outputs of the five max-pool stages.

    Args:
        image: a single image tensor, presumably (3, H, W) and normalized
            (TODO confirm); a batch dimension of 1 is added here.
        layer_ids: prefix lengths (number of leading ``features`` layers) to
            capture. Results are concatenated in this order.

    Returns:
        Tensor of shape (3 * len(layer_ids), 28, 28) with fm_resize's defaults,
        or an empty tensor if ``layer_ids`` is empty.
    """
    if not layer_ids:
        return torch.tensor([])
    wanted = set(layer_ids)
    # Slicing shares modules with `vgg`, so eval() acts on the backbone's own layers.
    trunk = vgg.features[:max(wanted)]
    trunk.eval()
    captured = {}
    x = image.unsqueeze(0)
    # Inference only: do not record an autograd graph through the backbone.
    with torch.no_grad():
        if 0 in wanted:  # an empty prefix is the identity
            captured[0] = fm_resize(x)
        for depth, layer in enumerate(trunk, start=1):
            x = layer(x)
            if depth in wanted:
                captured[depth] = fm_resize(x)
    # Concatenate once, in the caller's order.
    return torch.cat([captured[i] for i in layer_ids], 0)

def extractor_resnet(image, layer_ids=(4, 5, 6, 7, 8)):
    """Stack resized ResNet-152 feature maps taken at several depths.

    For every ``i`` in ``layer_ids`` this yields the same tensor as running the
    first ``i`` top-level children of ``resnet`` on ``image`` and passing the
    output to ``fm_resize``. The trunk is run only once, and the intermediate
    activations are captured instead of recomputing each prefix from scratch.
    With torchvision's child order (conv1, bn1, relu, maxpool, layer1..layer4,
    avgpool, fc), the defaults capture the stem output and layer1..layer4.

    Args:
        image: a single image tensor, presumably (3, H, W) and normalized
            (TODO confirm); a batch dimension of 1 is added here.
        layer_ids: prefix lengths (number of leading children) to capture.
            Results are concatenated in this order.

    Returns:
        Tensor of shape (3 * len(layer_ids), 28, 28) with fm_resize's defaults,
        or an empty tensor if ``layer_ids`` is empty.
    """
    if not layer_ids:
        return torch.tensor([])
    wanted = set(layer_ids)
    trunk = nn.Sequential(*list(resnet.children())[:max(wanted)])
    # The Sequential shares modules with `resnet`: eval() switches its BatchNorm
    # layers to running statistics.
    trunk.eval()
    captured = {}
    x = image.unsqueeze(0)
    # Inference only: do not record an autograd graph through the backbone.
    with torch.no_grad():
        if 0 in wanted:  # an empty prefix is the identity
            captured[0] = fm_resize(x)
        for depth, layer in enumerate(trunk, start=1):
            x = layer(x)
            if depth in wanted:
                captured[depth] = fm_resize(x)
    # Concatenate once, in the caller's order.
    return torch.cat([captured[i] for i in layer_ids], 0)

# Position encoding module
class PosENet_eval(nn.Module):
    """Position-encoding read-out head used at evaluation time.

    A single 3x3 convolution (stride 1, no padding) maps an ``input_dim``-channel
    input to a one-channel position map. Because padding is 0, an H x W input
    produces an (H - 2) x (W - 2) output.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute must stay named `conv`: saved weights use the state-dict
        # keys `conv.weight` / `conv.bias`.
        self.conv = nn.Conv2d(in_channels=input_dim, out_channels=1,
                              kernel_size=(3, 3), stride=1, padding=0)

    def forward(self, x):
        """Return the one-channel position map predicted from ``x``."""
        return self.conv(x)

In [None]:
# Load the trained position-encoding heads and predict position maps.
# The image-only PosENet reads a 3-channel image; the VGG/ResNet heads read 5 stacked
# 3-channel feature maps (15 channels). If features from only one layer are fed,
# input_dim must be changed to 3.
def _load_posenet(input_dim, weights_path):
    """Create a PosENet_eval head, load its weights onto the CPU, and switch it to eval mode."""
    head = PosENet_eval(input_dim=input_dim)
    # NOTE(review): torch.load unpickles the file, so only load trusted weights.
    # Newer PyTorch versions accept weights_only=True for plain state dicts.
    head.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    head.eval()
    return head

model = _load_posenet(3, 'weights/model_3_0_gau.weights')
model_vgg = _load_posenet(15, 'weights/model_vgg_3_0_gau.weights')
model_resnet = _load_posenet(15, 'weights/model_resnet_3_0_gau.weights')

# Open the image once and close the file handle. Force 3-channel RGB so RGBA /
# grayscale / palette PNGs don't break Normalize or the 3-channel inputs.
with PIL.Image.open(path_image) as raw_image:
    image_rgb = raw_image.convert('RGB')

img_model_PosENet = transform(image_rgb) * 255
img_model_vgg_resnet = transform_image(image_rgb)

# Prediction. permute(0, 2, 1) swaps the H and W axes of the (C, H, W) tensor.
# NOTE(review): kept as in the original, presumably because the weights were trained
# that way; confirm against the training code, since the GT map is not transposed.
with torch.no_grad():
    pred = model(img_model_PosENet.permute(0, 2, 1).unsqueeze(0))
    pred_vgg = model_vgg(extractor_vgg(img_model_vgg_resnet.permute(0, 2, 1)).unsqueeze(0))
    pred_resnet = model_resnet(extractor_resnet(img_model_vgg_resnet.permute(0, 2, 1)).unsqueeze(0))

In [None]:
# Display the input image, the GT map and the three predicted position maps side by side.
# `with` closes each file; the transform reads the pixels while the file is open.
with PIL.Image.open(path_image) as raw_image:
    img = transform_display(raw_image)
with PIL.Image.open(path_gt) as raw_gt:
    gt_img = transform_display(raw_gt)

# Upsample the low-resolution predictions to the display size.
pred_img, pred_vgg_img, pred_resnet_img = (
    F.interpolate(p, size=(224, 224), mode="bilinear", align_corners=False)
    for p in (pred, pred_vgg, pred_resnet)
)

# permute(2, 1, 0) turns (C, H, W) into (W, H, C): the image is shown transposed,
# presumably to match the H/W swap applied to the model inputs.
panels = [
    ('Image', img.permute(2, 1, 0)),
    ('GT', gt_img[0]),
    ('PosENet', pred_img[0, 0].detach()),
    ('VGG', pred_vgg_img[0, 0].detach()),
    ('ResNet', pred_resnet_img[0, 0].detach()),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 3))
for ax, (title, panel) in zip(axes, panels):
    ax.imshow(panel, aspect='equal')
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Spearman's rank correlation coefficient (SPC) between each predicted map and the GT map.
with PIL.Image.open(path_gt) as raw_gt:
    # Take the first channel, as the display cell does. For a 1-channel GT this
    # equals the previous .view(gt_size, gt_size); it also tolerates an RGB(A) GT file.
    gt = transform_gt(raw_gt)[0].numpy()

predicted_maps = {'PosENet': pred, 'Vgg': pred_vgg, 'Resnet': pred_resnet}
columns = {}
for name, p in predicted_maps.items():
    flat = p.detach().numpy().flatten()
    assert flat.size == gt.size, f'{name} map has {flat.size} values, GT has {gt.size}'
    columns[name] = flat
columns['GT'] = gt.flatten()
data = pd.DataFrame(columns)

# Correlation of each model's column with GT (GT's self-correlation dropped).
data.corr(method='spearman').round(3)['GT'].drop('GT')

In [None]:
# Mean absolute error (MAE) between each predicted map and the GT map.
# Uses the numpy `gt` built in the SPC cell; fail loudly if it was overwritten.
assert isinstance(gt, np.ndarray), 'run the SPC cell first: `gt` must be the numpy GT map'
mae = pd.Series(
    {name: float(np.abs(p.detach().numpy().reshape(gt.shape) - gt).mean())
     for name, p in {'PosENet': pred, 'Vgg': pred_vgg, 'Resnet': pred_resnet}.items()},
    name='MAE',
)
mae.round(3)

In [None]:
# KL divergence KL(GT || prediction), treating each map as a probability distribution
# over pixel positions.
# F.kl_div(input, target) expects `input` to be log-probabilities and `target` to be
# probabilities. Raw network outputs are neither, and reduction='mean' averages over
# elements rather than returning the KL divergence. So both maps are normalized to
# sum to 1 and the pointwise terms are summed over all positions.
def spatial_kl(pred_map, gt_map, eps=1e-8):
    """Return KL(gt || pred) after normalizing both maps to sum to 1.

    Args:
        pred_map: predicted map tensor (any shape). Values below ``eps``, including
            negative outputs, are clamped to ``eps`` so the log stays finite.
        gt_map: ground-truth map tensor with the same number of elements.
            Negative values are clamped to 0.
        eps: floor applied to the prediction before normalization.

    Returns:
        0-d tensor holding the divergence (0 when the normalized maps coincide).

    Raises:
        ValueError: if the maps differ in size or the GT map has no positive mass.
    """
    q = pred_map.detach().flatten().clamp_min(eps)
    p = gt_map.detach().flatten().clamp_min(0)
    if q.numel() != p.numel():
        raise ValueError(f'prediction has {q.numel()} values, GT has {p.numel()}')
    total = p.sum()
    if total <= 0:
        raise ValueError('GT map has no positive mass')
    return F.kl_div((q / q.sum()).log(), p / total, reduction='sum')

with PIL.Image.open(path_gt) as raw_gt:
    # Use a separate name so the numpy `gt` used by the SPC/MAE cells is not overwritten.
    gt_tensor = transform_gt(raw_gt)[0]

kl = pd.Series(
    {name: spatial_kl(p, gt_tensor).item()
     for name, p in {'PosENet': pred, 'Vgg': pred_vgg, 'Resnet': pred_resnet}.items()},
    name='KL(GT || pred)',
)
kl.round(4)