# Upscale Images

Imports

In [19]:
import torch
from PIL import Image
import torchvision.transforms.functional as F
from src.esrgan import RRDBNet
import matplotlib.pyplot as plt
import random

Function to load a model checkpoint saved in ./models

In [20]:
def load_model(model_path, device):
    """
    Load an ESRGAN RRDBNet generator from a checkpoint.

    Args:
        model_path (str or os.PathLike): Path to the .pth checkpoint file. The file
            must contain a plain RRDBNet state_dict (it is passed directly to
            load_state_dict).
        device (str or torch.device): Device to load the model onto (e.g. "cpu" or "cuda").

    Returns:
        RRDBNet: The generator model moved to `device` and set to eval() mode.
    """
    print(f"Loading model from: {model_path}")
    generator = RRDBNet().to(device)
    # weights_only=True limits unpickling to tensors and primitive containers, so a
    # tampered .pth file cannot run arbitrary code. A state_dict needs nothing more.
    # (The keyword requires PyTorch >= 1.13.)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    generator.load_state_dict(state_dict)
    generator.eval()
    print("Model loaded successfully!")
    return generator

Function to Upscale Images

In [21]:
def upscale_image(generator, lr_image_path, output_path, device):
    """
    Upscale one image with `generator`, save the result, and return it.

    Args:
        generator (torch.nn.Module): Model applied to a (1, 3, H, W) float tensor with
            values in [0, 1]; presumably the RRDBNet returned by load_model.
        lr_image_path (str): Path of the low-resolution input image.
        output_path (str): Where to write the upscaled image; PIL picks the file
            format from the extension.
        device (str or torch.device): Device to run inference on.

    Returns:
        PIL.Image.Image: The upscaled RGB image (the same image written to
        `output_path`), so callers need not reload it from disk.
    """
    print(f"Loading image: {lr_image_path}")
    # The context manager closes the file handle. convert() returns a new, fully
    # loaded image, so `lr` remains usable after the file is closed.
    with Image.open(lr_image_path) as img:
        lr = img.convert("RGB")
    original_size = lr.size
    print(f"Original size: {original_size[0]}x{original_size[1]}")

    # Convert the PIL image to a (3, H, W) float tensor in [0, 1] and add a batch dim.
    lr_t = F.to_tensor(lr).unsqueeze(0).to(device)

    print("Upscaling...")
    with torch.no_grad():
        sr = generator(lr_t)

    # Model output may overshoot [0, 1]; clamp before converting back to 8-bit pixels.
    sr = torch.clamp(sr, 0, 1)
    sr_img = F.to_pil_image(sr.squeeze(0).cpu())

    sr_img.save(output_path)
    upscaled_size = sr_img.size
    print(f"Upscaled size: {upscaled_size[0]}x{upscaled_size[1]}")
    print(f"Saved to: {output_path}")
    return sr_img

Function to Compare Original and Upscaled Images

In [22]:
def comparison(lr, sr, zoom_size=200, scale=None, seed=None):
    """
    Plot the original and upscaled images side by side, plus a zoomed-in crop of
    the same randomly chosen region from each.

    Args:
        lr (PIL.Image.Image): Original (low-resolution) image.
        sr (PIL.Image.Image): Upscaled image produced from `lr`.
        zoom_size (int): Side length, in original-image pixels, of the square crop.
        scale (int or None): Upscaling factor between `lr` and `sr`. If None, it is
            inferred as ``sr.width // lr.width`` (4 for a 4x model).
        seed (int or None): Seed for the crop position. With None, each call picks a
            new random region from the global `random` module.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    axes[0, 0].imshow(lr)
    axes[0, 0].set_title("Original Image (Full)")

    axes[0, 1].imshow(sr)
    axes[0, 1].set_title("Upscaled Image (Full)")

    lr_width, lr_height = lr.size
    if scale is None:
        # Infer the factor rather than hard-coding 4, so the zoom boxes stay aligned
        # for models with other upscale factors.
        scale = max(1, sr.size[0] // lr_width) if lr_width else 1

    # A crop exactly as large as the image still fits, hence >= (not >).
    if lr_width >= zoom_size and lr_height >= zoom_size:
        rng = random.Random(seed) if seed is not None else random
        x = rng.randint(0, lr_width - zoom_size)
        y = rng.randint(0, lr_height - zoom_size)

        zoom_box_lr = (x, y, x + zoom_size, y + zoom_size)
        # The same region in the upscaled image is the low-res box scaled by `scale`.
        zoom_box_sr = tuple(coord * scale for coord in zoom_box_lr)

        axes[1, 0].imshow(lr.crop(zoom_box_lr))
        axes[1, 0].set_title("Original - Zoomed")

        axes[1, 1].imshow(sr.crop(zoom_box_sr))
        axes[1, 1].set_title("Upscaled - Zoomed")
    else:
        for ax in axes[1]:
            ax.text(0.5, 0.5, "Image too small for zoom", ha='center', va='center')

    # Hide ticks and frames on every panel, including the placeholder ones.
    for ax in axes.flat:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

### Important Parameters
- device: a GPU is strongly recommended; the CPU works but is very slow.
- model_path: ./models contains ESRGAN.pth and ESRGAN_PSNR.pth. Use ESRGAN.pth only; ESRGAN_PSNR.pth was trained only in Phase 1 (PSNR-oriented training).
- input_image: path to the image to upscale. It can be anywhere, but ./images is preferred.
- output_path: path where the upscaled image is saved. It can be anywhere, but ./images is preferred.

In [39]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}\n")

# Paths are relative to the notebook's working directory.
model_path = "models/ESRGAN.pth"               # fully trained generator, not the PSNR-only one
input_image = "images/test.jpg"                # low-resolution image to upscale
output_path = "images/example_output_4.png"    # where the 4x result is written

Using device: cuda



Upscale your image 4x

In [None]:
generator = load_model(model_path, device)
upscale_image(generator, input_image, output_path, device)

# Reload both images for plotting. The context managers close the file handles
# promptly; convert() returns fully loaded copies that stay valid afterwards.
with Image.open(input_image) as img:
    lr = img.convert("RGB")
with Image.open(output_path) as img:
    sr = img.convert("RGB")

comparison(lr, sr)