In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Frozen VGG16 feature extractor that stops at the ``relu2_2`` activation.

    Used as the loss network ``phi`` of a feature-reconstruction (perceptual)
    loss: calling the module returns the ``relu2_2`` feature map of its input.
    All parameters are frozen (``requires_grad=False``) and the module is put
    in eval mode, so it behaves as a fixed function; gradients can still flow
    back to the *input*.

    Args:
        features: Optional indexable sequence of layers; the first
            ``NUM_LAYERS`` modules are used. Defaults to the ``features`` stack
            of an ImageNet-pretrained ``torchvision.models.vgg16`` (which
            torchvision may download on first use). Supplying it lets callers
            pass pre-loaded / offline weights.
    """

    # Leading layers of torchvision's vgg16().features up to and including
    # relu2_2: conv1_1, relu1_1, conv1_2, relu1_2, pool1, conv2_1, relu2_1,
    # conv2_2, relu2_2.
    NUM_LAYERS = 9

    def __init__(self, features=None):
        super().__init__()
        if features is None:
            features = self._load_pretrained_vgg16_features()
        self.relu2_2 = nn.Sequential()
        # Keep the "relu2_2_<k>" child names so state_dict keys stay unchanged.
        for i in range(self.NUM_LAYERS):
            self.relu2_2.add_module(name="relu2_2_" + str(i + 1), module=features[i])
        # Freeze the loss network: the optimizer must never update it.
        for param in self.parameters():
            param.requires_grad = False
        # VGG16 features contain no dropout/batch-norm, but eval mode makes the
        # frozen intent explicit and matters if a different ``features`` is used.
        self.eval()

    @staticmethod
    def _load_pretrained_vgg16_features():
        """Return ``vgg16().features`` with ImageNet weights on old and new torchvision."""
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights=`` enum; IMAGENET1K_V1 is the checkpoint the legacy flag loaded.
        weights_enum = getattr(torchvision.models, "VGG16_Weights", None)
        if weights_enum is not None:
            model = torchvision.models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            model = torchvision.models.vgg16(pretrained=True)
        return model.features

    def forward(self, x):
        """Return the relu2_2 feature map of ``x``."""
        # NOTE(review): x is fed in as-is; presumably it is an (N, 3, H, W) batch
        # already normalized the way VGG16 was trained — confirm with callers.
        out_relu2_2 = self.relu2_2(x)
        return out_relu2_2

In [None]:
# Build the frozen relu2_2 loss network once; PerceptualLoss reads this module-level
# instance. Construction asks torchvision for ImageNet-pretrained VGG16 weights
# (may trigger a download if they are not cached locally).
VGGLoss = VGGPerceptualLoss()

In [None]:
def PerceptualLoss(x, y, loss_model=None):
    """Feature-reconstruction (perceptual) loss between ``x`` and ``y``.

    Implements the feature loss of Johnson et al. (2016),
    ``1 / (C*H*W) * ||phi(y) - phi(x)||_2^2``, averaged over the batch, where
    ``phi`` is ``loss_model`` and ``C, H, W`` are the feature-map dimensions.

    Args:
        x: Input batch passed to ``loss_model``.
        y: Target batch passed to ``loss_model``.
        loss_model: Callable feature extractor. Defaults to the module-level
            ``VGGLoss`` instance.

    Returns:
        Scalar tensor with the batch-averaged feature loss.
    """
    if loss_model is None:
        loss_model = VGGLoss

    x_features = loss_model(x)
    y_features = loss_model(y)

    # The squared Euclidean norm is the *sum* of squared differences. Dividing
    # that sum by C*H*W per image and averaging over the N images equals the
    # mean over all N*C*H*W elements, which is exactly mse_loss's default
    # 'mean' reduction, so no further division by C*H*W is needed.
    feature_loss = F.mse_loss(y_features, x_features)
    return feature_loss

In [None]:
# Random dummy batches used only to smoke-test the loss; seeded so that
# Restart & Run All reproduces the same numbers.
torch.manual_seed(0)
x = torch.randn(4, 3, 256, 256)  # (batch=4, 3 channels, 256 x 256)
y = torch.randn(4, 3, 256, 256)

In [None]:
loss = PerceptualLoss(x=x, y=y)  # perceptual loss between the two dummy batches

In [None]:
loss  # last expression in the cell: displays the scalar loss tensor