In [1]:
# Make the parent directory importable so the project-local `models`
# package (Generator / Discriminator) can be imported in the next cell.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models


from models.Generator import Generator
from models.Discriminator import Discriminator

In [3]:
class Content_loss(nn.Module):
    """MSE between VGG19 feature maps of a reference image and a generated image.

    Features come from torchvision's ``vgg19(pretrained=True).features`` with
    its last layer (the final max-pool) dropped. The VGG weights act only as a
    fixed feature extractor, so they are frozen: gradients flow back to the
    generated image but are never computed for, or accumulated in, VGG.
    """

    def __init__(self):
        super(Content_loss, self).__init__()

        vgg = models.vgg19(pretrained=True)
        self.feature = nn.Sequential(*list(vgg.features.children())[:-1])
        # Freeze the extractor. Otherwise every backward pass also fills .grad
        # for all VGG weights, and they would be trained if this module's
        # parameters were ever handed to an optimizer.
        for param in self.feature.parameters():
            param.requires_grad_(False)
        self.feature.eval()
        self.MSE = nn.MSELoss()

    def forward(self, HR, SR):
        """Return the mean squared error between VGG features of HR and SR.

        Args:
            HR: reference image batch; presumably (N, 3, H, W). TODO: confirm
                whether inputs should be ImageNet-normalized for VGG.
            SR: generated image batch with the same shape as HR.

        Returns:
            A 0-dim tensor (MSELoss with its default 'mean' reduction).
        """
        HR_feature = self.feature(HR)
        SR_feature = self.feature(SR)
        return self.MSE(HR_feature, SR_feature)
    
class Adversarial_loss(nn.Module):
    """Adversarial loss term: -log10(D_SR), summed over all elements of D_SR.

    NOTE(review): this uses log10; the usual -log D(G(x)) generator loss is
    written with the natural log. log10 scales the term by 1/ln(10) (~0.434),
    which effectively changes the 10**-3 weight applied in Perceptual_loss.
    Kept as log10 to preserve existing results. Confirm this is intended.
    """

    def __init__(self, eps=None):
        """Store the lower bound applied to D_SR before taking the log.

        Args:
            eps: floor for D_SR so an output of exactly 0 gives a large finite
                loss instead of inf (and NaN gradients). None (default) uses
                the smallest positive normal number of D_SR's dtype, so only
                zero/denormal inputs are affected.
        """
        super(Adversarial_loss, self).__init__()
        self.eps = eps

    def forward(self, D_SR):
        """Return sum(-log10(max(D_SR, eps))) as a 0-dim tensor.

        Args:
            D_SR: floating-point discriminator outputs for generated images.
                Assumed to lie in [0, 1]. TODO: confirm the Discriminator ends
                in a sigmoid.
        """
        eps = torch.finfo(D_SR.dtype).tiny if self.eps is None else self.eps
        loss = -torch.log10(D_SR.clamp(min=eps))
        return loss.sum()

    
class Perceptual_loss(nn.Module):
    """Weighted sum: Content_loss(HR, SR) + adv_weight * Adversarial_loss(D_SR)."""

    def __init__(self, adv_weight=10**-3):
        """Build the two loss terms.

        Args:
            adv_weight: multiplier on the adversarial term. The default, 1e-3,
                is the value previously hard-coded in forward().
        """
        super(Perceptual_loss, self).__init__()
        self.content_loss = Content_loss()
        self.adversarial_loss = Adversarial_loss()
        self.adv_weight = adv_weight

    def forward(self, HR, SR, D_SR):
        """Return the combined loss.

        Args:
            HR: reference image batch, passed to the content loss.
            SR: generated image batch, passed to the content loss.
            D_SR: discriminator outputs for SR, passed to the adversarial loss.
        """
        content = self.content_loss(HR, SR)
        adversarial = self.adversarial_loss(D_SR)
        return content + self.adv_weight * adversarial
        

## Inputs

In [4]:
# Seed torch's global RNG before anything random happens. This makes any random
# weight initialisation in the models, and the torch.rand inputs in the next
# cell, reproducible under Restart & Run All.
torch.manual_seed(0)
G = Generator()
D = Discriminator()

In [5]:
# Random stand-in for a high-resolution batch.
HR = torch.rand(5,3,256,256)
print(HR.shape)
# NOTE(review): the low-resolution input is built as (5, 32, 32, 3), i.e. the 3
# sits last, unlike HR's (5, 3, 256, 256). Presumably Generator expects
# channels-last input and permutes it internally. TODO: confirm.
# The recorded output shows a 32 -> 256 (8x) spatial upscale.
SR = G(torch.rand(5,32,32,3))
print(SR.shape)

torch.Size([5, 3, 256, 256])
torch.Size([5, 3, 256, 256])


## Loss functions
### Content loss

In [6]:
# Standalone content loss. The constructor calls torchvision's
# vgg19(pretrained=True), which may download the weights on first use.
c_loss = Content_loss()

In [7]:
# VGG-feature MSE between the random HR batch and the generator output.
loss1 = c_loss(HR, SR)
loss1

tensor(0.0160, grad_fn=<SumBackward0>)

### Adversarial loss

In [8]:
# Adversarial loss term, applied to the discriminator's scores for SR.
a_loss = Adversarial_loss()

In [9]:
# Score the generated batch with the (as yet untrained) discriminator and
# turn the scores into the adversarial loss. Assumes D outputs values in
# (0, 1]. TODO: confirm.
loss2 = a_loss(D(SR))
loss2

tensor(1.6655, grad_fn=<SumBackward0>)

### Perceptual loss (content + weighted adversarial)

In [10]:
# Combined loss. Note that this builds its own Content_loss, i.e. a second VGG19.
p = Perceptual_loss()

In [11]:
# Combined loss on the same inputs, with D(SR) recomputed. This should equal
# loss1 + 1e-3 * loss2, provided D gives the same scores as in the previous cell.
p(HR, SR, D(SR))

tensor(0.0177, grad_fn=<AddBackward0>)