In [1]:
# Visualize some of the results.
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from models import SRCNN

# Load the trained SRCNN weights. map_location="cpu" lets a checkpoint saved
# from a GPU run load on a CPU-only machine; the model here lives on the CPU.
model = SRCNN()
model.load_state_dict(torch.load("Models/Srcnn.pth", map_location="cpu"))
model.eval()  # inference mode (affects dropout / batch-norm layers, if the model has any)

# Load the high-resolution reference image. The `with` block closes the file
# handle; convert("RGB") forces 3 channels so RGBA / palette PNGs match the
# 3-channel ImageNet normalization applied before the model.
img_path = "data/DIV2K_valid_HR/DIV2K_valid_HR/0803.png"
with Image.open(img_path) as src_img:
    image = src_img.convert("RGB")

# Simulate a low-resolution input: downscale to a fixed 720x720 with a Lanczos
# filter (Image.ANTIALIAS was an alias for LANCZOS and no longer exists in
# Pillow >= 10), then bicubic-upscale back to the original size.
# NOTE(review): 720x720 ignores the source aspect ratio — confirm this is intended.
low_res_img = image.resize((720, 720), Image.LANCZOS)
upscaled_img = low_res_img.resize((image.width, image.height), Image.BICUBIC)


In [2]:
# Display pipeline: PIL image -> float tensor -> fixed 1296x1296 size.
# No normalization step (unlike transforms2), so the resulting tensors can be
# turned straight back into viewable PIL images when they are saved.
transforms1 = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(1296, 1296)),
    ]
)

In [3]:
# Run the reference image and the bicubic baseline through the same display
# pipeline so both end up as 1296x1296 tensors.
# NOTE(review): `upscaled_img` is rebound from a PIL image to a tensor here (and
# back to PIL in the save cell), so this cell is not safe to re-run by itself —
# re-run from the cell that first creates `upscaled_img`.
original_image, upscaled_img = transforms1(image), transforms1(upscaled_img)



In [4]:
# Convert the display tensors back to PIL images and save the reference
# (original) and bicubic-baseline versions for comparison with the SRCNN output.
original_img = TF.to_pil_image(original_image)
upscaled_img = TF.to_pil_image(upscaled_img)

# PIL's save() does not create missing directories, so ensure the folder exists.
Path("Predictions").mkdir(parents=True, exist_ok=True)
original_img.save('Predictions/1_original.png')
upscaled_img.save('Predictions/1_bicubic.png')

In [5]:
# Model-input pipeline: the same ToTensor + fixed 1296x1296 resize as the
# display pipeline, followed by per-channel normalization with ImageNet stats.
transforms2 = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(1296, 1296)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [7]:
import torch

def denormalize(tensor, mean, std):
    """
    Undo a per-channel normalization in place, computing ``x * std + mean``.

    Parameters:
        tensor (torch.Tensor): Normalized image(s), either a single image of
            shape (C, H, W) or a batch of shape (N, C, H, W). Modified in place.
        mean (list or tuple): Per-channel means used for normalization (length C).
        std (list or tuple): Per-channel standard deviations used for
            normalization (length C).

    Returns:
        torch.Tensor: ``tensor`` itself (the same object), now denormalized.

    Raises:
        RuntimeError: if the number of mean/std entries does not broadcast
            against the channel dimension of ``tensor``.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W, and also
    # over a leading batch dimension when one is present.
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    # In-place ops preserve the contract that the caller's tensor is updated.
    return tensor.mul_(std_t).add_(mean_t)

In [9]:
# Super-resolve the bicubic-upscaled image with SRCNN.
input_img = transforms2(upscaled_img).unsqueeze(0)  # Add the batch dimension.

with torch.no_grad():
    output_img = model(input_img)

# Map the output back with the same ImageNet statistics transforms2 applied to
# the input. denormalize() works in place, so output_img is updated as well;
# the returned tensor is used explicitly below for clarity.
output_img = output_img.squeeze(0)  # Remove the batch dimension.
denormalized_output_img = denormalize(
    output_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# to_pil_image() scales float tensors by 255 and casts to uint8 without
# clamping. Nothing here bounds the model output to [0, 1], and out-of-range
# values would overflow that cast into wrong pixels, so clamp first.
output_img = TF.to_pil_image(denormalized_output_img.clamp(0.0, 1.0))

# Save the output_image.
output_img.save('Predictions/SRCNN_upscaled.png')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns