In [41]:
import os
import random

import torch
from PIL import Image, ImageDraw, ImageFont

# Wildcard import kept ahead of the explicit local import so the explicit name wins on a clash.
from utils import *
from models import SRResNet

In [6]:

# Choose the compute backend in order of preference: CUDA, then MPS, then plain CPU.
if torch.cuda.is_available():
    backend = "cuda"
elif torch.backends.mps.is_available():
    backend = 'mps'
else:
    backend = "cpu"

device = torch.device(backend)
device

device(type='mps')

In [40]:
# Checkpoint file for every model variant; the keys double as the captions
# passed to visualize_images further down.
model_weights = {
    'srresnet': 'SRResNet.pth',
    #  'srresnet_attention ': 'SRResNet_attention.pth',
    'srgan': 'SRResNet_Discriminator.pth',
    'srgan_efficientnet': "SRResNet_EfficientNet.pth",
    'srgan_wgan': 'SRResNet_Discriminator_WGAN.pth',
}

# Build one network per checkpoint, move it to the selected device and
# switch it to evaluation mode.
models = []
for weights_file in model_weights.values():
    # Checkpoints with 'attention' in their file name get 8 as the second
    # constructor argument, all others 0.
    # NOTE(review): the meaning of SRResNet's positional arguments is not
    # visible here (presumably scale factor, attention setting, weights
    # path) - confirm against models.py.
    second_arg = 8 if 'attention' in weights_file else 0
    network = SRResNet(4, second_arg, weights_file).to(device)
    network.eval()
    models.append(network)

Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!
Model's pretrained weights loaded!


In [36]:

def super_resolve(model, img):
    """
    Super-resolves a 4x-downsampled copy of an image with the given model and returns it together with the
    reference images needed for a side-by-side comparison.

    The HR image is read from disk and converted to RGB, bicubically downsampled by a factor of 4 to obtain the
    LR input, and that LR image is then (a) bicubically upsampled back to the HR size and (b) passed through
    `model`.

    :param model: callable network; it receives the imagenet-normed LR image as a batch of one on `device` and its
                  output is treated as being in the '[-1, 1]' range understood by `convert_image`
    :param img: filepath of the HR image
    :return: tuple `(lr_img, hr_img, sr_img, bicubic_img)` where `lr_img`, `hr_img` and `bicubic_img` are PIL
             images and `sr_img` is the result of `convert_image(..., target='pil')`
    """
    # Load image, downsample to obtain low-res version.
    # convert() returns an independent image, so the file can be closed right after it.
    with Image.open(img, mode="r") as source:
        hr_img = source.convert('RGB')
    lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)),
                           Image.BICUBIC)

    # Bicubic Upsampling
    bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

    # Super-resolution (SR) with the supplied model
    imagenet_normed = convert_image(
        lr_img, source='pil', target='imagenet-norm')

    # Inference only: without no_grad() every forward pass would record an autograd graph,
    # wasting memory for gradients that are never used.
    with torch.no_grad():
        sr_img = model(imagenet_normed.unsqueeze(0).to(device))
    sr_img = sr_img.squeeze(0).cpu().detach()
    sr_img = convert_image(
        sr_img, source='[-1, 1]', target='pil')

    return lr_img, hr_img, sr_img, bicubic_img

In [37]:
def _paste_with_caption(grid_img, draw, font, img, origin, caption):
    """Pastes `img` at `origin` on `grid_img` and draws `caption` horizontally centred just above it."""
    x, y = origin
    grid_img.paste(img, (x, y))
    text_size = font.getbbox(caption)
    draw.text((x + img.width / 2 - text_size[2] / 2, y - text_size[3] - 10),
              caption, font=font, fill='black')


def visualize_images(sr_imgs, hr_img, bicubic_img, model_names):
    """
    Lays out the original HR image, the bicubic-upsampled image and every super-resolved image on a
    two-column grid with a caption above each tile, then opens the result in the default image viewer.

    The first row holds the HR image (left) and the bicubic image (right); the SR images follow from the
    second row onwards, two per row, in the order given.

    :param sr_imgs: iterable of super-resolved PIL images
    :param hr_img: original high-resolution PIL image
    :param bicubic_img: bicubic-upsampled PIL image
    :param model_names: iterable of captions, one per entry of `sr_imgs` (extra entries on either side are ignored)
    :return: the composed grid as a PIL image
    """
    margin = 40
    columns = 2

    # Materialise the inputs so they can be measured; callers pass e.g. dict.keys() for the names.
    sr_imgs = list(sr_imgs)
    model_names = list(model_names)

    # Size the grid from what is actually drawn, not from a module-level dict.
    num_models = min(len(sr_imgs), len(model_names))
    # One row for HR + bicubic, then ceil(num_models / columns) rows of SR images.
    # The previous (num_models + 2) // 2 lost the last row for odd counts.
    rows = 1 + (num_models + columns - 1) // columns

    tiles = [hr_img, bicubic_img] + sr_imgs[:num_models]
    captions = ["Original HR", "Bicubic"] + model_names[:num_models]

    # SR output can be a few pixels smaller than HR when the HR size is not a multiple of 4,
    # so use one common cell size to keep columns aligned and every tile inside the canvas.
    cell_width = max(tile.width for tile in tiles)
    cell_height = max(tile.height for tile in tiles)

    # Define the size of the grid image
    grid_img = Image.new('RGB',
                         (columns * (cell_width + margin) + margin,
                          rows * (cell_height + margin) + margin),
                         (255, 255, 255))

    # Initialize drawing context
    draw = ImageDraw.Draw(grid_img)

    try:
        font = ImageFont.load_default(size=23)
    except TypeError:
        # Older Pillow releases do not accept the `size` argument; fall back to the fixed-size default font.
        font = ImageFont.load_default()

    # Fill the grid row by row: index 0 and 1 are HR and bicubic, the rest are SR images.
    for index, (tile, caption) in enumerate(zip(tiles, captions)):
        row, column = divmod(index, columns)
        origin = (margin + column * (cell_width + margin),
                  margin + row * (cell_height + margin))
        _paste_with_caption(grid_img, draw, font, tile, origin, caption)

    # Display the grid image
    grid_img.show()
    return grid_img

In [53]:
# Hand-picked validation images to compare all loaded models on.
test_images = [
    "/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01592084/ILSVRC2012_val_00009374.JPEG",
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00046767.JPEG',
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00047983.JPEG',
    '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n02077923/ILSVRC2012_val_00048240.JPEG'
]

for test_image in test_images:
    # Run every model on the same image and keep only the SR output of each run.
    results = [super_resolve(model, test_image) for model in models]
    sr_imgs = [result[2] for result in results]

    # The HR and bicubic references are produced identically on every run,
    # so take them from the last one.
    lr_img, hr_img, sr_img, bicubic_img = results[-1]
    visualize_images(sr_imgs, hr_img, bicubic_img, model_weights.keys())

python(43989) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(43991) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(43999) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44000) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44001) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44002) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44034) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(44035) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [52]:
test_folder = '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X'
num_random_images = 3

# Keep only real sub-directories: os.listdir also returns stray files (e.g. macOS's .DS_Store),
# and calling os.listdir on such an entry below would raise NotADirectoryError.
class_folders = sorted(
    entry for entry in os.listdir(test_folder)
    if os.path.isdir(os.path.join(test_folder, entry)))

# Never ask random.sample for more folders than exist (it raises ValueError otherwise).
test_images_folders = random.sample(
    class_folders, min(num_random_images, len(class_folders)))

# Pick one random file from each sampled folder, skipping hidden files and empty folders.
test_images = []
for folder in test_images_folders:
    folder_path = os.path.join(test_folder, folder)
    candidates = [name for name in os.listdir(folder_path)
                  if not name.startswith('.')
                  and os.path.isfile(os.path.join(folder_path, name))]
    if candidates:
        test_images.append(os.path.join(folder_path, random.choice(candidates)))
print(test_images)

for test_image in test_images:
    sr_imgs = []
    for model in models:
        lr_img, hr_img, sr_img, bicubic_img = super_resolve(
            model, test_image)
        sr_imgs.append(sr_img)
    visualize_images(sr_imgs, hr_img, bicubic_img, model_weights.keys())

['/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01986214/ILSVRC2012_val_00021819.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01924916/ILSVRC2012_val_00049105.JPEG', '/Users/youssefshaarawy/Documents/Datasets/INM705/data/val.X/n01695060/ILSVRC2012_val_00008095.JPEG']


python(41585) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(41618) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(41626) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(41627) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(41630) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(41631) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
