In [41]:
import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont
from models import SRResNet
import random

In [6]:

# Choose the compute backend in order of preference:
# NVIDIA CUDA, then Apple-silicon MPS, falling back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device("cpu")
device

device(type='mps')

In [40]:
# Display name -> checkpoint file. Insertion order matters: `models` below is
# built in the same order, and later cells zip it with `model_weights.keys()`.
model_weights = {
    'srresnet': 'SRResNet.pth',
    # 'srresnet_attention ': 'SRResNet_attention.pth',
    'srgan': 'SRResNet_Discriminator.pth',
    'srgan_efficientnet': "SRResNet_EfficientNet.pth",
    'srgan_wgan': 'SRResNet_Discriminator_WGAN.pth',
}

models = []
for checkpoint_file in model_weights.values():
    # Checkpoints whose filename mentions "attention" are built with 8 as the
    # second constructor argument, all others with 0.
    # NOTE(review): presumably 8 = number of attention heads/blocks and the first
    # argument (4) is the upscaling factor -- confirm against models.SRResNet.
    attention_arg = 8 if 'attention' in checkpoint_file else 0
    model = SRResNet(4, attention_arg, checkpoint_file).to(device)
    model.eval()  # inference only: freeze dropout / batch-norm statistics
    models.append(model)

Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!


In [36]:
def super_resolve(model, img, scale=4):
    """
    Super-resolve one image with ``model`` and build a bicubic baseline for comparison.

    The HR image is loaded and cropped (top-left anchored) so that both sides are
    exact multiples of ``scale``. It is then downscaled by ``scale`` with bicubic
    interpolation to form the low-resolution (LR) input. The LR image is
    (a) bicubic-upsampled back to the HR size and (b) passed through ``model``.
    Cropping first ensures the HR and bicubic images have exactly the size a
    ``scale``-x model produces from the LR input, so all outputs line up.

    :param model: The super-resolution model. It is called on a single-image batch
        moved to the notebook-level ``device``. Assumed to upscale by exactly
        ``scale`` -- TODO confirm it matches the factor SRResNet was built with.
    :param img: Filepath (or file object) of the high-resolution (HR) image.
    :param scale: Integer down/up-scaling factor (default 4, the value previously hard-coded).
    :return: Tuple ``(lr_img, hr_img, sr_img, bicubic_img)`` of PIL RGB images.
    :raises ValueError: If the image is smaller than ``scale`` pixels in either dimension.
    """
    # Use a context manager so the underlying file handle is released;
    # `convert` returns a new, fully loaded image that outlives the `with`.
    with Image.open(img) as source:
        hr_img = source.convert('RGB')

    # Trim to a multiple of `scale` (drops at most scale-1 pixels on the
    # right/bottom) so HR, bicubic and SR outputs share one size.
    hr_width = (hr_img.width // scale) * scale
    hr_height = (hr_img.height // scale) * scale
    if hr_width == 0 or hr_height == 0:
        raise ValueError(
            f"Image {img!r} ({hr_img.width}x{hr_img.height}) is smaller than scale={scale}")
    if (hr_width, hr_height) != hr_img.size:
        hr_img = hr_img.crop((0, 0, hr_width, hr_height))

    # Degrade the HR image to build the model input.
    lr_img = hr_img.resize((hr_width // scale, hr_height // scale), Image.BICUBIC)

    # Baseline: plain bicubic upsampling of the LR image back to HR size.
    bicubic_img = lr_img.resize((hr_width, hr_height), Image.BICUBIC)

    # `convert_image` comes from utils; presumably it maps a PIL image to an
    # ImageNet-normalised CHW tensor -- confirm in utils.
    imagenet_normed = convert_image(lr_img, source='pil', target='imagenet-norm')

    # Inference only: no autograd graph is built inside no_grad, so the output
    # needs no explicit .detach().
    with torch.no_grad():
        sr_tensor = model(imagenet_normed.unsqueeze(0).to(device))

    # Drop the batch dimension, move to CPU, and map the model's [-1, 1]
    # output range back to a PIL image.
    sr_tensor = sr_tensor.squeeze(0).cpu()
    sr_img = convert_image(sr_tensor, source='[-1, 1]', target='pil')

    return lr_img, hr_img, sr_img, bicubic_img

In [37]:
def _paste_with_label(canvas, draw, font, img, x, y, label):
    """Paste ``img`` onto ``canvas`` at (x, y) and draw ``label`` centred 10 px above it."""
    canvas.paste(img, (x, y))
    left, _top, right, bottom = font.getbbox(label)
    # Text drawn at x0 spans [x0 + left, x0 + right]; shift so its centre sits
    # over the image centre, and its bottom edge 10 px above the image.
    text_x = x + img.width / 2 - (left + right) / 2
    text_y = y - bottom - 10
    draw.text((text_x, text_y), label, font=font, fill='black')


def visualize_images(sr_imgs, hr_img, bicubic_img, model_names, show=True, margin=40):
    """
    Lay out HR, bicubic and per-model SR images in a two-column labelled grid.

    Row 0 holds the original HR image (left) and the bicubic baseline (right);
    the SR images follow, two per row, each labelled with its model name.

    :param sr_imgs: Iterable of PIL images, one per model.
    :param hr_img: Original high-resolution PIL image.
    :param bicubic_img: Bicubic-upsampled PIL image.
    :param model_names: Iterable of labels, parallel to ``sr_imgs``.
    :param show: If True (default, the previous behaviour), open the grid with ``Image.show()``,
        which launches an external viewer.
    :param margin: Pixel gap between cells and around the border (default 40).
    :return: The composed grid as a PIL image.
    :raises ValueError: If ``sr_imgs`` and ``model_names`` have different lengths.
    """
    sr_imgs = list(sr_imgs)
    model_names = [str(name) for name in model_names]
    if len(sr_imgs) != len(model_names):
        raise ValueError(
            f"Got {len(sr_imgs)} SR images but {len(model_names)} model names")

    columns = 2
    all_imgs = (hr_img, bicubic_img, *sr_imgs)
    # Use a uniform cell size so differently sized images cannot overlap.
    cell_width = max(im.width for im in all_imgs)
    cell_height = max(im.height for im in all_imgs)
    # One header row (HR + bicubic) plus ceil(n / columns) rows of SR images.
    rows = 1 + (len(sr_imgs) + columns - 1) // columns

    grid_img = Image.new('RGB',
                         (columns * (cell_width + margin) + margin,
                          rows * (cell_height + margin) + margin),
                         (255, 255, 255))
    draw = ImageDraw.Draw(grid_img)
    font = ImageFont.load_default(size=23)

    def cell_origin(row, column):
        # Top-left pixel of the grid cell at (row, column).
        return (margin + column * (cell_width + margin),
                margin + row * (cell_height + margin))

    _paste_with_label(grid_img, draw, font, hr_img, *cell_origin(0, 0), "Original HR")
    _paste_with_label(grid_img, draw, font, bicubic_img, *cell_origin(0, 1), "Bicubic")

    for index, (sr_img, model_name) in enumerate(zip(sr_imgs, model_names)):
        row, column = divmod(index, columns)
        _paste_with_label(grid_img, draw, font, sr_img,
                          *cell_origin(row + 1, column), model_name)

    if show:
        grid_img.show()
    return grid_img

In [54]:
import os

# Root of the validation split (one sub-folder per class). Override with the
# INM705_VAL_DIR environment variable on machines where the data lives elsewhere.
test_folder = os.environ.get(
    'INM705_VAL_DIR', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X')
N_TEST_IMAGES = 3      # number of class folders (and hence images) to sample
SAMPLE_SEED = 42       # fixed seed so the same images are picked on every run

# Dedicated RNG: reproducible without touching the global `random` state.
rng = random.Random(SAMPLE_SEED)

# Keep only visible sub-directories: macOS drops `.DS_Store` files into folders,
# which would otherwise be sampled and then crash `os.listdir`. Sorting makes the
# seeded draw independent of the filesystem's listing order.
class_folders = sorted(
    entry for entry in os.listdir(test_folder)
    if not entry.startswith('.') and os.path.isdir(os.path.join(test_folder, entry))
)
test_images_folders = rng.sample(class_folders, N_TEST_IMAGES)

# Pick one image file at random from each selected class folder.
test_images = []
for folder in test_images_folders:
    folder_path = os.path.join(test_folder, folder)
    candidates = sorted(
        name for name in os.listdir(folder_path)
        if not name.startswith('.') and os.path.isfile(os.path.join(folder_path, name))
    )
    if not candidates:
        raise FileNotFoundError(f"No image files found in {folder_path}")
    test_images.append(os.path.join(folder_path, rng.choice(candidates)))

print(test_images)

# Without this guard an empty `models` list would silently reuse `hr_img` /
# `bicubic_img` left over from a previous cell.
if not models:
    raise RuntimeError("`models` is empty; run the model-loading cell first.")

for test_image in test_images:
    # One SR output per model; HR and bicubic do not depend on the model,
    # so the values from the last iteration are representative.
    sr_imgs = []
    for model in models:
        lr_img, hr_img, sr_img, bicubic_img = super_resolve(model, test_image)
        sr_imgs.append(sr_img)

    visualize_images(sr_imgs, hr_img, bicubic_img, model_weights.keys())

['/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01592084/ILSVRC2012_val_00012836.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01582220/ILSVRC2012_val_00003665.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01632777/ILSVRC2012_val_00013940.JPEG']


python(50517) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(50524) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(50532) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(50533) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(50534) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(50535) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [53]:
import os

# Hand-picked validation images for a fixed qualitative comparison.
# Override the dataset root with the INM705_VAL_DIR environment variable.
val_dir = os.environ.get(
    'INM705_VAL_DIR', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X')
test_images = [
    os.path.join(val_dir, 'n01592084', 'ILSVRC2012_val_00009374.JPEG'),
    os.path.join(val_dir, 'n02077923', 'ILSVRC2012_val_00046767.JPEG'),
    os.path.join(val_dir, 'n02077923', 'ILSVRC2012_val_00047983.JPEG'),
    os.path.join(val_dir, 'n02077923', 'ILSVRC2012_val_00048240.JPEG'),
]

# Guard against reusing stale `hr_img` / `bicubic_img` from earlier cells
# when no models are loaded.
if not models:
    raise RuntimeError("`models` is empty; run the model-loading cell first.")

for test_image in test_images:
    sr_imgs = []
    for model in models:
        lr_img, hr_img, sr_img, bicubic_img = super_resolve(model, test_image)
        sr_imgs.append(sr_img)
    visualize_images(sr_imgs, hr_img, bicubic_img, model_weights.keys())

python(43989) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(43991) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(43999) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44000) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44001) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44002) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44034) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44035) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
