In [1]:
import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont
from models import SRResNet
import random

In [2]:

device = torch.device("cuda" if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else "cpu")
device

device(type='mps')

In [27]:
def super_resolve(model, img, crop_size):
    """
    Apply super-resolution using a given model to a specified image and compare it
    with the bicubic-upsampled version. The function processes an image to produce
    low-resolution, bicubic-upsampled, and super-resolved versions for evaluation.

    :param model: The super-resolution model to be used for image processing.
    :param img: Filepath of the high-resolution (HR) image to be processed.
    :return: A tuple containing the low-resolution image, the original high-resolution image,
             the super-resolved image, and the bicubic-upsampled image.
    """

    # Load the high-resolution (HR) image from the given file path
    hr_img = Image.open(img)
    # Ensure the image is in RGB format
    hr_img = hr_img.convert('RGB')

    if crop_size:
        width, height = hr_img.size   # Get dimensions
        new_width, new_height = crop_size, crop_size
        left = (width - new_width)/2
        top = (height - new_height)/2
        right = (width + new_width)/2
        bottom = (height + new_height)/2

        # Crop the center of the image
        hr_img = hr_img.crop((left, top, right, bottom))

    # Create a low-resolution (LR) version of the image by downscaling using bicubic interpolation
    lr_img = hr_img.resize(
        (int(hr_img.width / 4), int(hr_img.height / 4)), Image.BICUBIC)

    # Upsample the low-resolution image back to high-resolution using bicubic interpolation
    # This serves as a baseline for comparison with the super-resolved image
    bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

    # Normalize the low-resolution image for model input according to ImageNet standards
    # `convert_image` is assumed to be a function that adjusts image data to model-specific input requirements
    imagenet_normed = convert_image(
        lr_img, source='pil', target='imagenet-norm')

    # Process the normalized low-resolution image using the super-resolution model
    # Assuming the model and device (GPU/CPU) are correctly configured
    with torch.no_grad():
        sr_img = model(imagenet_normed.unsqueeze(0).to(device))

    # Remove the batch dimension and move the image data back to CPU, if necessary
    sr_img = sr_img.squeeze(0).cpu().detach()

    # Convert the model output back to a PIL image from its normalized form
    sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')

    # Return the processed images
    return lr_img, hr_img, sr_img, bicubic_img

In [24]:
def visualize_images(*imgs, labels=None, image_save_path=None):
    print(labels)
    if labels is None:
        labels = [''] * len(imgs)
    # Create grid
    margin = 40
    num_images = len(imgs)
    # Calculating number of rows to include HR and bicubic at top
    rows = (num_images + 1) // 2

    # Define the size of the grid image
    max_width = max([img.width for img in imgs])
    max_height = max([img.height for img in imgs])
    grid_img = Image.new('RGB',
                         (2 * (max_width + margin) + margin,
                          rows * (max_height + margin) + margin),
                         (255, 255, 255))

    # Initialize drawing context
    draw = ImageDraw.Draw(grid_img)

    font = ImageFont.load_default(size=23)

    # Loop through SR images and their corresponding model names starting from the second row
    row, column = 0, 0  # Start from the second row, first column

    for img, label in zip(imgs, labels):
        x = margin + column * (img.width + margin)
        y = margin + row * (img.height + margin)

        # Place image
        grid_img.paste(img, (x, y))
        text_size = font.getbbox(label)
        draw.text((x + img.width / 2 -
                   text_size[2] / 2, y - text_size[3] - 10), label, font=font, fill='black')

        # Update column and check if we need to move to the next row
        column += 1
        if column >= 2:
            column = 0
            row += 1

    # Display the grid image
    grid_img.show()
    if image_save_path:
        grid_img.save(image_save_path)

In [20]:
model_weights = {
    'SRResNet': 'SRResNet.pth',
    # 'SRResNet_attention': 'SRResNet_attention.pth',
    'SRGAN': 'SRResNet_Discriminator.pth',
    'SRGAN_EfficientNet': "SRResNet_EfficientNet.pth",
    'SRGAN_WGAN': 'SRResNet_Discriminator_WGAN.pth'
}

models = {}
for key, weight in model_weights.items():
    model = SRResNet(4, 8 if 'attention' in weight else 0, weight).to(device)
    model.eval()
    models[key] = model

Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!


In [30]:
# Define the path to the validation dataset directory
test_folder = '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X'

# Randomly select 3 folders from the test folder
test_images_folders = random.sample(os.listdir(test_folder), 3)

# For each selected folder, randomly select one image and construct its full path
test_images = [
    os.path.join(test_folder, folder, random.sample(os.listdir(
        os.path.join(test_folder, folder)), 1)[0])  # Randomly pick one file
    for folder in test_images_folders  # Iterate over the selected folders
]

# Print the paths of the randomly selected images
print(test_images)

# Loop through each selected test image
for test_image in test_images:
    # Initialize an empty list to store super-resolved images from each model
    sr_imgs = []

    # Loop over each model defined in 'models'
    for model_name, model in models.items():
        # Apply the super-resolution model to the test image
        lr_img, hr_img, sr_img, bicubic_img = super_resolve(
            model, test_image, 200 if 'SRResNet_attention' in models else None)

        # Append the super-resolved image to the list of super-resolved images
        sr_imgs.append(sr_img)

    # Visualize all super-resolved images along with the high-resolution and bicubic images

    visualize_images(hr_img, bicubic_img, *sr_imgs,
                     labels=['Original', 'Bicubic', *models.keys()], image_save_path=os.path.join('output', os.path.basename(test_image)))

['/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01985128/ILSVRC2012_val_00022222.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01582220/ILSVRC2012_val_00000963.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01773797/ILSVRC2012_val_00003182.JPEG']
['Original', 'Bicubic', 'SRResNet', 'SRGAN', 'SRGAN_EfficientNet', 'SRGAN_WGAN']


python(80838) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(80841) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


['Original', 'Bicubic', 'SRResNet', 'SRGAN', 'SRGAN_EfficientNet', 'SRGAN_WGAN']


python(80863) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(80864) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


['Original', 'Bicubic', 'SRResNet', 'SRGAN', 'SRGAN_EfficientNet', 'SRGAN_WGAN']


python(80865) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(80866) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [95]:
test_images = [
    "/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01592084/ILSVRC2012_val_00009374.JPEG",
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00046767.JPEG',
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00047983.JPEG',
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00048240.JPEG'
]
for test_image in test_images:
    sr_imgs = []
    for model in models:
        lr_img, hr_img, sr_img, bicubic_img = super_resolve(
            model, test_image)
        sr_imgs.append(sr_img)
    visualize_images(hr_img, bicubic_img, *sr_imgs,
                     labels=['Original', 'Bicubic', *model_weights.keys()])

['Original', 'Bicubic', 'SRResNet']


python(28861) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(28863) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


['Original', 'Bicubic', 'SRResNet']


python(28898) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(28900) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


['Original', 'Bicubic', 'SRResNet']


python(28904) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(28905) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


['Original', 'Bicubic', 'SRResNet']


python(28906) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(28907) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
