In [1]:
import torch
import torch.nn as nn
from torchvision.models import vgg19
from torchvision.models.feature_extraction import create_feature_extractor

class VGGLoss(nn.Module):
    """Perceptual loss computed in the feature space of a frozen VGG19.

    Both images are normalized with ImageNet statistics, passed through an
    ImageNet-pretrained VGG19 convolutional stack, and compared with an L1
    distance at a single intermediate layer.

    Args:
        feature_layer: Index (int or str) of the module inside
            ``vgg19().features`` whose output is compared. Defaults to
            ``'35'`` (relu5_4), which is the previously hard-coded layer.
    """

    def __init__(self, feature_layer='35'):
        super().__init__()

        # VGG19 pretrained on ImageNet. torchvision >= 0.13 deprecates the
        # boolean ``pretrained`` flag in favour of ``weights=``; use the new
        # API when it exists and fall back to the legacy flag otherwise.
        # Both spellings select the same IMAGENET1K_V1 checkpoint.
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            backbone = vgg19(pretrained=True)
        else:
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

        # Only the convolutional stack is needed; freeze it so the loss
        # network itself is never updated by an optimizer.
        vgg = backbone.features.eval()
        for param in vgg.parameters():
            param.requires_grad = False

        # Expose the requested layer's activation under the key 'feat'.
        # With the default '35' this is features[35], i.e. relu5_4.
        self.feature_extractor = create_feature_extractor(
            vgg,
            return_nodes={str(feature_layer): 'feat'}
        )

        # ImageNet normalization, stored as buffers so they follow the module
        # across .to(device) / .half() calls and broadcast over (B, 3, H, W).
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # L1 loss is more robust to outliers than MSE
        self.criterion = nn.L1Loss()

    def forward(self, sr, hr):
        """Compute the perceptual loss between two image batches.

        Args:
            sr: Super-resolved image (B, 3, H, W), values in [0,1]
            hr: High-resolution GT image (B, 3, H, W), values in [0,1]
        Returns:
            Scalar tensor: mean absolute difference between the VGG features
            of ``sr`` and ``hr``.
        """
        # Normalize to ImageNet stats
        sr_norm = (sr - self.mean) / self.std
        hr_norm = (hr - self.mean) / self.std

        # Extract VGG features
        sr_feat = self.feature_extractor(sr_norm)['feat']
        hr_feat = self.feature_extractor(hr_norm)['feat']

        # Compute perceptual loss
        return self.criterion(sr_feat, hr_feat)


In [2]:
# Smoke test of VGGLoss on random images. The inputs are drawn from a
# dedicated, seeded generator so the printed value is the same after every
# Restart & Run All, without touching the global RNG state.
demo_rng = torch.Generator().manual_seed(0)
sr_input = torch.rand(1, 3, 224, 224, generator=demo_rng)  # Example super-resolved input
hr_input = torch.rand(1, 3, 224, 224, generator=demo_rng)  # Example high-resolution ground truth
vgg_loss = VGGLoss()
loss_value = vgg_loss(sr_input, hr_input)
print(f"Perceptual loss: {loss_value.item()}")



Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Network/Servers/fs.local/Volumes/home/kawai/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:06<00:00, 88.1MB/s] 


Perceptual loss: 0.027264084666967392
