In [30]:
import os
import random
import time
import shutil
import numpy as np
from PIL import Image
import torch
from gradio_client import Client
from RealESRGAN import RealESRGAN

# Seed Python's built-in `random` module so that any later use of it is repeatable.
random.seed(420)

In [None]:
def dewarp(input_image_folder, output_image_folder):

    """
        Performs image restoration on all images in the input folder and saves the restored image
        in the output folder under the same file name.

        Args:
            input_image_folder: Directory holding the images to process. Only regular
                files are processed; sub-directories are skipped.
            output_image_folder: Directory that receives the results. It is created
                when it does not exist yet.

        Returns:
            None. Progress and the time spent in each `predict` call are printed.
    """

    os.makedirs(output_image_folder, exist_ok=True)

    # Sorted so the processing order (and the printed log) is the same on every run;
    # directory entries that are not regular files are left out.
    image_list = [
        entry
        for entry in sorted(os.listdir(input_image_folder))
        if os.path.isfile(os.path.join(input_image_folder, entry))
    ]
    if not image_list:
        # Nothing to process, so no client is needed at all.
        return

    # The client does not depend on the image being processed, so it is built once
    # and reused instead of being re-created on every iteration.
    dewarper = Client("https://doctrp.docscanner.top/")

    for image_file in image_list:
        print(f"Dewarping {image_file}")
        input_image_path = os.path.join(input_image_folder, image_file)
        result_image_path = os.path.join(output_image_folder, image_file)

        # perf_counter is a monotonic clock intended for measuring short durations.
        start_time = time.perf_counter()
        dewarped_image_path = dewarper.predict(input_image_path, api_name="/predict")
        elapsed = time.perf_counter() - start_time

        # `predict` hands back a path to the result; relocate it to the output folder.
        shutil.move(dewarped_image_path, result_image_path)
        print(f"Dewarped {image_file}. {elapsed:.3f}s taken!")

In [None]:
def get_model(device, scale):

    """
        Get Real-ESRGAN model for super-resolution of given scale.

        Args:
            device: Device handed to the RealESRGAN constructor.
            scale: Upscaling factor, as an int or a numeric string (e.g. 2 or "2").

        Returns:
            The RealESRGAN instance after `load_weights` has been called with
            'weights/RealESRGAN_x<scale>.pth'.

        Raises:
            ValueError: If `scale` cannot be converted to an int.
    """

    # Use the `scale` argument rather than a notebook-level global, so the function
    # works on a fresh kernel and honours whatever scale the caller asks for.
    scale = int(scale)

    model = RealESRGAN(device, scale=scale)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth')

    return model

    
def super_resolution(model, input_image_folder, output_image_folder):

    """
        Performs super-resolution on all images in the input_image_folder and saves the upscaled image
        in the output folder under the same file name.

        Args:
            model: Object whose `predict` is called with the RGB image as a numpy
                array and whose result exposes `save(path)`.
            input_image_folder: Directory holding the images to upscale. Only regular
                files are processed; sub-directories are skipped.
            output_image_folder: Directory that receives the results. It is created
                when it does not exist yet.

        Returns:
            None. One line is printed per saved image.
    """

    os.makedirs(output_image_folder, exist_ok=True)

    # Sorted for a reproducible processing order; entries that are not regular
    # files are left out.
    images_list = [
        entry
        for entry in sorted(os.listdir(input_image_folder))
        if os.path.isfile(os.path.join(input_image_folder, entry))
    ]

    for image_file in images_list:
        input_image_path = os.path.join(input_image_folder, image_file)
        result_image_path = os.path.join(output_image_folder, image_file)

        # Open inside a `with` block so the underlying file is closed as soon as
        # the RGB copy has been made, instead of lingering until garbage collection.
        with Image.open(input_image_path) as source_image:
            image = source_image.convert('RGB')

        sr_image = model.predict(np.array(image))
        sr_image.save(result_image_path)
        print(f'Finished! Image saved to {result_image_path}')

In [None]:
# Resolve every folder used by the pipeline relative to the working directory.
current_dir = os.getcwd()
input_image_folder, dewarp_output_image_folder, sr_output_image_folder = (
    os.path.join(current_dir, relative_path)
    for relative_path in ("archive/scan_doc_rotation/images", "dewarped_images", "sr_images")
)

# Stage 1: dewarp the raw scans.
dewarp(
    input_image_folder=input_image_folder,
    output_image_folder=dewarp_output_image_folder,
)

# Stage 2: upscale the dewarped images, using CUDA when torch reports it is available.
model_scale = "2" # Can be 2, 4, 8
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
sr_model = get_model(device, model_scale)

super_resolution(
    model=sr_model,
    input_image_folder=dewarp_output_image_folder,
    output_image_folder=sr_output_image_folder,
)