In [1]:
# Visualize some of the results.
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from models import SRCNN
from utils import convert_ycbcr_to_rgb

# Load the trained model. map_location="cpu" lets a checkpoint that was saved
# from a GPU session load on a CPU-only kernel (all inference below is on CPU).
model = SRCNN()
model.load_state_dict(torch.load("Models/Srcnn.pth", map_location="cpu"))
model.eval()

# Load the image that we want to super-resolve and split it into its
# luma (Y) and chroma (Cb, Cr) channels; only Y is fed through the network.
img_path = "data/DIV2K_train_HR/DIV2K_train_HR/0003.png"
image = Image.open(img_path).convert("YCbCr")
y, cb, cr = image.split()

# Simulate a low-resolution input, then bicubic-upscale it back to the
# original size. Image.LANCZOS is the same filter the old Image.ANTIALIAS
# alias pointed to; ANTIALIAS was deprecated and removed in Pillow 10.
# Note: the fixed square size does not preserve the source aspect ratio.
LOW_RES_SIZE = (720, 720)
low_res_img = y.resize(LOW_RES_SIZE, Image.LANCZOS)
upscaled_img = low_res_img.resize(
    (y.width, y.height), Image.BICUBIC)


In [2]:
# Sanity check: PIL sizes are (width, height) — low-res input vs. bicubic upscale.
(low_res_img.size, upscaled_img.size)

((720, 720), (2040, 1356))

In [3]:
# Preprocessing for the visual comparison: PIL image -> float tensor in
# [0, 1], then resize to a fixed 1296x1296 grid (aspect ratio not preserved).
# No normalisation is applied, so results convert straight back to PIL.
transforms1 = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((1296, 1296))]
)

In [4]:
# Convert both images to 1296x1296 tensors. Note that `original_image` is the
# *low-resolution* input, and `upscaled_img` is rebound from a PIL image to a
# tensor here, so this cell is not idempotent (ToTensor rejects tensors).
original_image, upscaled_img = transforms1(low_res_img), transforms1(upscaled_img)



In [5]:
# Both tensors are expected to be (channels, height, width) = (1, 1296, 1296).
(original_image.shape, upscaled_img.shape)

(torch.Size([1, 1296, 1296]), torch.Size([1, 1296, 1296]))

In [6]:
# Save the (low-res) original and bicubic-upscaled luma (Y-channel) images.
# Create the output directory first: on a fresh checkout it does not exist
# and PIL's save() would raise FileNotFoundError.
Path("Predictions").mkdir(parents=True, exist_ok=True)

original_img = TF.to_pil_image(original_image)
upscaled_img = TF.to_pil_image(upscaled_img)

original_img.save('Predictions/1_original_ycrbr.png')
upscaled_img.save('Predictions/1_bicubic_ycrbr.png')

In [7]:
# Resize the full-resolution chroma channels to each luma image's size, then
# merge Y/Cb/Cr back into RGB with the project's converter.
original_cb, original_cr = (
    channel.resize(original_img.size, Image.BICUBIC) for channel in (cb, cr)
)
upscaled_cb, upscaled_cr = (
    channel.resize(upscaled_img.size, Image.BICUBIC) for channel in (cb, cr)
)

original_rgb_image = convert_ycbcr_to_rgb(original_img, original_cb, original_cr)
upscaled_rgb_image = convert_ycbcr_to_rgb(upscaled_img, upscaled_cb, upscaled_cr)

In [8]:
# Persist the RGB versions of both comparison images.
for rgb_image, out_path in (
    (original_rgb_image, 'Predictions/1_original_rgb.png'),
    (upscaled_rgb_image, 'Predictions/1_bicubic_rgb.png'),
):
    rgb_image.save(out_path)

In [9]:
# Preprocessing for the SRCNN input: same tensor conversion and 1296x1296
# resize as transforms1, then Normalize(mean=0.5, std=0.5), which maps
# [0, 1] to [-1, 1]. These are NOT the ImageNet statistics; presumably they
# match what the model was trained with — TODO confirm against training code.
transforms2 = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((1296, 1296)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [10]:
import torch

def denormalize(tensor, mean, std):
    """
    Undo ``transforms.Normalize`` in place, i.e. ``tensor * std + mean``.

    The statistics are broadcast along the channel axis (dim -3), so both a
    single image ``(C, H, W)`` and a batch ``(N, C, H, W)`` are handled; a
    single-value ``mean``/``std`` is applied to every channel, matching how
    ``transforms.Normalize`` broadcasts. For a 2-D ``(H, W)`` tensor, pass a
    single-value ``mean``/``std``.

    Parameters:
        tensor (torch.Tensor): The normalized image(s); modified in place.
        mean (sequence of float): The mean used for normalization.
        std (sequence of float): The standard deviation used for normalization.

    Returns:
        torch.Tensor: ``tensor`` itself, now denormalized.
    """
    # Build the stats on the tensor's device; for float tensors keep the same
    # dtype. Non-float tensors cannot be scaled in place and will raise.
    dtype = tensor.dtype if tensor.is_floating_point() else None
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if tensor.dim() >= 3:
        # Shape (C, 1, 1) lines the stats up with the channel axis (dim -3).
        mean_t = mean_t.view(-1, 1, 1)
        std_t = std_t.view(-1, 1, 1)
    # In place, so callers holding a reference to `tensor` see the result.
    tensor.mul_(std_t).add_(mean_t)
    return tensor

In [11]:
# Super-resolve the bicubic-upscaled image with SRCNN.
input_img = transforms2(upscaled_img).unsqueeze(0)  # Add the batch dimension.

with torch.no_grad():
    output_img = model(input_img)

# Undo the normalisation and clamp to [0, 1] before converting to PIL:
# TF.to_pil_image multiplies float tensors by 255 and casts to uint8 without
# saturating, so out-of-range predictions would otherwise become artefacts.
output_img = output_img.squeeze(0)  # Remove the batch dimension.
denormalized_output_img = denormalize(output_img, mean=[0.5], std=[0.5])
output_img = TF.to_pil_image(denormalized_output_img.clamp(0.0, 1.0))

# Re-attach the chroma channels, resized to the output resolution, and save.
super_resolved_cb = cb.resize(output_img.size, Image.BICUBIC)
super_resolved_cr = cr.resize(output_img.size, Image.BICUBIC)
super_resolved_rgb_image = convert_ycbcr_to_rgb(output_img, super_resolved_cb, super_resolved_cr)
super_resolved_rgb_image.save('Predictions/SRCNN_upscaled.png')