In [None]:
import numpy as np
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class L2pooling(nn.Module):
    """Blur-and-subsample pooling applied to the L2 energy of the input.

    The input is squared, each channel is convolved independently (depthwise,
    ``groups`` = number of channels) with a normalised 2-D Hann window,
    subsampled by ``stride``, and the square root is taken. ``DISTS`` below
    inserts one of these between consecutive VGG stages.

    Args:
        filter_size: length of the Hann window before its two zero end-points
            are dropped; the effective kernel is
            ``(filter_size - 2) x (filter_size - 2)``. Must be >= 3.
        stride: subsampling factor of the convolution.
        channels: number of input channels (required); one kernel copy is
            stored per channel.
        pad_off: unused; kept only for interface compatibility.

    Raises:
        ValueError: if ``channels`` is missing/non-positive or ``filter_size < 3``.
    """

    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super(L2pooling, self).__init__()
        if channels is None or channels < 1:
            raise ValueError("L2pooling requires a positive `channels` argument")
        if filter_size < 3:
            # np.hanning(n)[1:-1] would be empty and the normalisation would divide by zero.
            raise ValueError("L2pooling requires filter_size >= 3")
        # Kernel is (filter_size - 2) wide, so this padding keeps it centred.
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels
        # Drop the zero-valued end-points of the Hann window.
        window = np.hanning(filter_size)[1:-1]
        kernel = torch.tensor(np.outer(window, window), dtype=torch.get_default_dtype())
        kernel = kernel / kernel.sum()  # unit DC gain
        self.register_buffer('filter', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input):
        """Return ``sqrt(blur(input**2))`` subsampled by ``stride``."""
        out = F.conv2d(input ** 2, self.filter, stride=self.stride,
                       padding=self.padding, groups=input.shape[1])
        # Small offset keeps sqrt differentiable where the blurred energy is 0.
        return (out + 1e-12).sqrt()


In [None]:
class DISTS(torch.nn.Module):
    """Texture/colour distance between an input image and a fixed target image.

    The backbone is torchvision's ImageNet-pretrained VGG16 ``features``. The
    layers at indices 4, 9, 16 and 23 (VGG16's max-pooling positions) are not
    copied; an ``L2pooling`` layer is inserted in their place. The distance
    compares only per-channel spatial means of seven feature maps (the raw
    image, a 6-channel colour energy/correlation map, and the outputs of the
    five stages) with the means stored by ``set_target``.

    Usage: call ``set_target(target)`` once, then ``model(x)``.

    Args:
        load_weights: currently unused (the VGG16 backbone is always loaded
            with pretrained weights); kept for interface compatibility.
    """

    def __init__(self, load_weights=True):
        super(DISTS, self).__init__()
        vgg_pretrained_features = self._vgg16_features()
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Module names mirror the VGG16 ``features`` indices, keeping state_dict keys stable.
        for x in range(0, 4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        self.stages = [self.stage1, self.stage2, self.stage3, self.stage4, self.stage5]
        # The backbone is used as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False
        # Usual ImageNet normalisation statistics expected by torchvision's pretrained VGG.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))

        # Channel count of each map returned by forward_once, in order: image (3),
        # colour statistics (6), stages 1-5. Used to average the per-channel terms.
        self.chns = [3, 6, 64, 128, 256, 512, 512]
        self.target = None
        self.r_means = None  # per-layer target means, filled by set_target()

    @staticmethod
    def _vgg16_features():
        """Return the ``features`` module of an ImageNet-pretrained torchvision VGG16."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained=True``; IMAGENET1K_V1 is the
            # checkpoint that flag used to select.
            return models.vgg16(weights=weights_enum.IMAGENET1K_V1).features
        return models.vgg16(pretrained=True).features

    def forward_once(self, x):
        """Run ``x`` through the network and collect the maps used by the distance.

        Args:
            x: image batch; assumed (B, 3, H, W) RGB, presumably in [0, 1] -- TODO confirm.

        Returns:
            List of 7 tensors:
            - the raw input ``x``;
            - a (B, 6, H, W) map holding, per pixel, the energies R^2, G^2, B^2 and
              the cross products RG, RB, GB of the normalised channels;
            - the outputs of stages 1-5.
        """
        h = (x - self.mean) / self.std
        r, g, b = h[:, 0], h[:, 1], h[:, 2]
        # Computed for every image in the batch (not only the first one).
        xfcov = torch.stack([r * r, g * g, b * b, r * g, r * b, g * b], dim=1).to(dtype=x.dtype)

        feats = [x, xfcov]
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            h = stage(h)
            feats.append(h)
        return feats

    def forward(self, x):
        """Return the distance between ``x`` and the target registered with ``set_target``.

        For each of the seven maps from ``forward_once``, the per-channel spatial
        mean of ``x`` is compared with the target's as
        ``(mu_x - mu_y)^2 / (c1 + mu_y^2)``. This is averaged over channels and
        then over maps.

        Returns:
            The squeezed score: a 0-d tensor for a single image, shape (B,) for a batch.

        Raises:
            RuntimeError: if ``set_target`` has not been called yet.
        """
        target_means = getattr(self, "r_means", None)
        if target_means is None:
            raise RuntimeError("set_target() must be called before computing the distance")

        c1 = 1e-6  # stabiliser: keeps the ratio finite when a target mean is zero
        feats0 = self.forward_once(x)
        dist1 = 0
        for feat, y_mean, n_channels in zip(feats0, target_means, self.chns):
            x_mean = feat.mean([2, 3], keepdim=True)
            S1 = (x_mean - y_mean) ** 2 / (c1 + y_mean ** 2)
            dist1 = dist1 + S1.sum(1, keepdim=True) / n_channels

        dist1 = dist1 / len(self.chns)  # average over the feature maps
        return dist1.squeeze()

    def set_target(self, target):
        """Register ``target`` as the reference image and cache its per-map spatial means."""
        feats1 = self.forward_once(target)
        self.r_means = [feat.mean([2, 3], keepdim=True) for feat in feats1[:len(self.chns)]]
        self.target = target
