# Receive output from SR generator for a low-resolution image

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage, Normalize

import lpips

from skimage.metrics import structural_similarity as SSIM

# SR Generator class

In [None]:
class ResidualBlock(nn.Module):
    """Conv -> PReLU -> conv branch wrapped in an identity skip connection.

    Both convolutions map ``channels`` to ``channels`` with a 3x3 kernel and
    ``padding=1``, so the branch output can be added to the input directly.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Attribute names double as state-dict keys; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        residual = self.conv1(x)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        return x + residual

class GSR(nn.Module):
    """Super-resolution generator that upsamples first and refines afterwards.

    The input is enlarged 4x with bicubic interpolation, then passed through
    a 5x5 encoder convolution, ``num_residuals`` residual blocks on 64
    feature channels, and a 5x5 decoder convolution. The result is clamped
    to the range [0, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, num_residuals=5):
        super().__init__()
        features = 64
        # Attribute names double as state-dict keys; keep them stable.
        self.encoder = nn.Conv2d(in_channels, features, kernel_size=5, padding=2)
        residual_blocks = [ResidualBlock(features) for _ in range(num_residuals)]
        self.resnet = nn.Sequential(*residual_blocks)
        self.decoder = nn.Conv2d(features, out_channels, kernel_size=5, padding=2)
        self.upsample = nn.Upsample(scale_factor=4, mode='bicubic')

    def forward(self, x):
        """Upsample ``x``, run it through the conv stack, and clamp to [0, 1]."""
        for stage in (self.upsample, self.encoder, self.resnet, self.decoder):
            x = stage(x)
        return x.clamp(0, 1)

## Load the weights of the pre-trained model

In [None]:
# Device selection and generator construction
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GSR is the SR generator class defined in the previous cell
generator_sr = GSR().to(device)

# Load the pre-trained weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint that was saved from a GPU run still loads on a
# CPU-only machine instead of raising a CUDA deserialization error.
generator_sr.load_state_dict(torch.load('generator_sr.pth', map_location=device))

# Set the model to evaluation mode
generator_sr.eval()

GSR(
  (encoder): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (resnet): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride

## Prepare the low-resolution image

In [None]:
# Read the low-resolution input and force three RGB channels
lr_image = Image.open('input/LR/8e5ce3b37f7e73.jpg').convert('RGB')

# Converters: PIL image -> tensor, and per-channel mean/std normalization
transform = ToTensor()
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Convert to a tensor, prepend a batch dimension, and move to the model's device
lr_image_tensor = transform(lr_image)
lr_image_tensor = lr_image_tensor.unsqueeze(0).to(device)

# Normalize the batched tensor before it is passed through the generator
lr_image_tensor = normalize(lr_image_tensor)

## Generate the high-resolution image

In [None]:
# Run the generator on the prepared input; gradients are not needed for inference
with torch.no_grad():
    output = generator_sr(lr_image_tensor)

# Reverse the input normalization on the output: scale by the per-channel std
# and shift by the per-channel mean, both reshaped to (1, 3, 1, 1) so they
# broadcast across the batch and spatial dimensions.
# NOTE(review): GSR.forward already clamps its output to [0, 1], so this step
# maps the result into roughly [0.41, 0.71] -- confirm against how the model
# was trained that the output really lives in normalized space.
channel_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
channel_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
output = output * channel_std + channel_mean

# Drop the batch dimension, move to the CPU, and convert to a PIL image
to_pil = ToPILImage()
sr_image_tensor = output.squeeze(0)
sr_image = to_pil(sr_image_tensor.cpu())

## Save high resolution image output

In [None]:
# Save the high-resolution image as PNG: the metrics cell further down opens
# 'output/HR/super_resolved_image.png', and a lossless format keeps JPEG
# compression artifacts out of the PSNR/SSIM/LPIPS measurements.
sr_image.save('output/HR/super_resolved_image.png')

## Metrics

In [None]:
def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    Args:
        img1: First image tensor.
        img2: Second image tensor, compared element-wise against ``img1``.
        max_val: Largest value a pixel can take. The default of 1.0 matches
            tensors scaled to [0, 1]; pass 255.0 for 8-bit pixel values.

    Returns:
        The PSNR as a Python float, or ``float('inf')`` when the mean
        squared error is exactly zero (identical inputs).
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Identical images: the ratio is unbounded, avoid dividing by zero.
        return float('inf')
    psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
    return psnr.item()

In [None]:
def calculate_ssim(img1, img2, win_size=7, data_range=None):
    """Return the structural similarity between two image tensors.

    Both tensors are squeezed on dimension 0 and permuted from channel-first
    to channel-last before being handed to scikit-image.

    Args:
        img1: First image tensor; a leading batch dimension of size 1 is removed.
        img2: Second image tensor, laid out like ``img1``.
        win_size: Side length of the sliding window passed to scikit-image.
        data_range: Distance between the smallest and largest possible pixel
            value. When ``None`` the observed range of ``img1`` is used, as
            before; pass 1.0 explicitly for tensors scaled to [0, 1].

    Returns:
        The SSIM value computed by ``skimage.metrics.structural_similarity``.
    """
    img1_np = img1.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC NumPy array
    img2_np = img2.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC NumPy array
    if data_range is None:
        data_range = img1_np.max() - img1_np.min()
    # `channel_axis` (scikit-image >= 0.19) replaces the deprecated
    # `multichannel=True`; without it the trailing axis of size 3 is not
    # treated as colour channels.
    ssim_value = SSIM(img1_np, img2_np, channel_axis=-1, data_range=data_range, win_size=win_size)
    return ssim_value

In [None]:
def calculate_lpips(img1, img2):
    """Return the LPIPS perceptual distance between two image tensors.

    Args:
        img1: First image batch; the caller is expected to have scaled it to
            [-1, 1] and matched its size to ``img2``.
        img2: Second image batch, prepared the same way as ``img1``.

    Returns:
        The LPIPS value as a Python float.
    """
    # Initialize the LPIPS model on the same device as the inputs so that
    # GPU tensors do not hit a device-mismatch error.
    loss_fn = lpips.LPIPS(net='alex').to(img1.device)

    # This is pure evaluation; skip building an autograd graph.
    with torch.no_grad():
        lpips_value = loss_fn(img1, img2)
    return lpips_value.item()

## Load images and run calculation of metrics for the given images

In [None]:
# Read the super-resolved result and its ground-truth counterpart as RGB
sr_image = Image.open('output/HR/super_resolved_image.png').convert('RGB')
hr_image = Image.open('input/HR/8e5ce3b37f7e73.jpg').convert('RGB')

# ToTensor gives [3, H, W] values in [0, 1]; unsqueeze(0) makes it [1, 3, H, W]
sr_image_tensor = ToTensor()(sr_image).unsqueeze(0)
hr_image_tensor = ToTensor()(hr_image).unsqueeze(0)

# Spatial size (H, W) of the reference; the SR tensors are resampled to it
hr_size = hr_image_tensor.shape[2:]

# [0, 1] version used for PSNR and SSIM
sr_image_tensor_resized = F.interpolate(sr_image_tensor, size=hr_size, mode='bilinear', align_corners=False)

# [-1, 1] versions used for LPIPS; the SR one is rescaled first, then resized
hr_image_tensor_lpips = 2 * hr_image_tensor - 1
sr_image_tensor_lpips = F.interpolate(2 * sr_image_tensor - 1, size=hr_size, mode='bilinear', align_corners=False)

In [None]:
# Evaluate the three metrics in turn and print each result as soon as it is
# available. The statement order determines the order of the printed lines.

# PSNR on the [0, 1] tensors (SR resized to the HR spatial size)
psnr_value = calculate_psnr(sr_image_tensor_resized, hr_image_tensor)
print(f"PSNR: {psnr_value} dB")

# SSIM on the same pair, with a 3-pixel window instead of the default 7
ssim_value = calculate_ssim(sr_image_tensor_resized, hr_image_tensor, win_size=3)
print(f"SSIM: {ssim_value}")

# LPIPS on the [-1, 1] tensors; calculate_lpips builds its model on each call
lpips_value = calculate_lpips(sr_image_tensor_lpips, hr_image_tensor_lpips)
print(f"LPIPS: {lpips_value}")

PSNR: 9.65070629119873 dB
SSIM: 0.04575595462125225
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth
LPIPS: 0.804252028465271
