In [2]:
import os
import random

import PIL
import PIL.Image
import torchvision.transforms as transforms
from easydict import EasyDict
from torchvision.transforms import InterpolationMode

In [5]:
# Configuration: everything a reader might want to tune lives here.
#   upscale   : super-resolution factor; HR sides are cropped to a multiple of it.
#   path2data : directory holding the source images. The default is the Colab
#               path; set the SR_DATA_DIR environment variable to override it.
#   hr_path   : output directory for the HR images.
#   lr_path   : output directory for the LR images.
#   seed      : seed for the random interpolation choices used to build LR images.

config = {
    "upscale": 4,
    "path2data": os.environ.get("SR_DATA_DIR", "/content/hr_image"),
    "hr_path": "HR_dataset/",
    "lr_path": "LR_dataset/",
    "seed": 42,
}

args = EasyDict(config)

# Seed early so that Restart & Run All reproduces the same random interpolation
# choices made when degrading HR -> LR.
random.seed(args.seed)

In [6]:
# Create the HR/LR output directories, including any missing parent directories.
# exist_ok=True makes this cell idempotent: re-running it is a no-op.
os.makedirs(args.hr_path, exist_ok=True)
os.makedirs(args.lr_path, exist_ok=True)

In [7]:
# Pool of resampling kernels. Each resize step of the HR -> LR degradation picks
# one of these at random, so the LR images mix different downsampling artefacts.
ALL_INTERPOLATIONS = [
    InterpolationMode.NEAREST,
    InterpolationMode.BILINEAR,
    InterpolationMode.BICUBIC,
    InterpolationMode.BOX,
    InterpolationMode.HAMMING,
    InterpolationMode.LANCZOS,
]

# Accepted image file suffixes: the lower-case spellings followed by the upper-case ones.
IMG_EXTENSIONS = [".jpg", ".png", ".jpeg"] + [".JPG", ".PNG", ".JPEG"]

In [11]:
def check_image(filename, extensions=None):
    '''
    Return True if `filename` ends with one of the accepted image suffixes.

    input : filename   - file name (or path) to test
            extensions - iterable of suffixes such as ".jpg"; defaults to
                         IMG_EXTENSIONS
    output: bool

    The comparison is case-insensitive, so mixed-case suffixes such as
    "photo.Jpg" are accepted instead of being silently skipped.
    '''
    if extensions is None:
        extensions = IMG_EXTENSIONS
    # str.endswith accepts a tuple of suffixes. Lower-casing both sides makes the
    # match independent of how the extension is capitalised.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)

def get_dimensions(img, upscale):
    '''
    Compute the largest (W, H) no bigger than the image whose sides are
    multiples of `upscale`, so an HR crop of that size downsamples evenly.

    input : img     - PIL Image (only `img.size`, i.e. (width, height), is read)
            upscale - positive super-resolution factor
    output: (W, H) tuple of ints, each a multiple of `upscale`; a side shorter
            than `upscale` yields 0
    raises: ValueError if `upscale` is not positive
    '''
    if upscale <= 0:
        raise ValueError(f"upscale must be positive, got {upscale!r}")

    w, h = img.size

    # Drop the remainder so each side is an exact multiple of the factor.
    W = int(w - (w % upscale))
    H = int(h - (h % upscale))

    return W, H

In [12]:
# Build (HR, LR) training pairs from every image in args.path2data:
#   HR = the source image, converted to RGB and centre-cropped so both sides are
#        multiples of args.upscale.
#   LR = HR resized twice, first to 1/(2*upscale) and then to 1/upscale of the
#        HR size. Each resize uses an interpolation drawn at random from
#        ALL_INTERPOLATIONS.
# Both images are saved as PNG, named after the source file's stem.
all_images = sorted(
    os.path.join(args.path2data, name)
    for name in os.listdir(args.path2data)
    if check_image(name)
)

# Re-seed when a seed is configured, so re-running only this cell reproduces
# the same random interpolation choices.
seed = getattr(args, "seed", None)
if seed is not None:
    random.seed(seed)

written_names = set()
for img_path in all_images:
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so `img` remains usable after the file is closed.
    with PIL.Image.open(img_path) as src:
        img = src.convert("RGB")
    W, H = get_dimensions(img, args.upscale)                      # multiples of args.upscale
    img_name = os.path.splitext(os.path.basename(img_path))[0]   # stem, independent of path separator

    # The intermediate LR size is H // (2*upscale). Below 2*upscale that would be
    # a zero-sized target, so skip such images explicitly.
    if H < 2 * args.upscale or W < 2 * args.upscale:
        print(f"Skipping {img_path}: {W}x{H} is too small for upscale={args.upscale}")
        continue

    # Inputs with the same stem (e.g. a.jpg and a.png) map to the same output
    # file. Warn rather than overwrite silently.
    if img_name in written_names:
        print(f"Warning: {img_path} overwrites an earlier output named {img_name}.png")
    written_names.add(img_name)

    # torchvision sizes are given as (H, W).
    lr_transformation = transforms.Compose([
        transforms.Resize((H // (2 * args.upscale), W // (2 * args.upscale)),
                          interpolation=random.choice(ALL_INTERPOLATIONS)),
        transforms.Resize((H // args.upscale, W // args.upscale),
                          interpolation=random.choice(ALL_INTERPOLATIONS)),
    ])
    hr_transformation = transforms.CenterCrop((H, W))

    hr_img = hr_transformation(img)
    lr_img = lr_transformation(hr_img)

    # os.path.join works whether or not the output dirs end with a separator.
    hr_img.save(os.path.join(args.hr_path, img_name + ".png"), format="png")
    lr_img.save(os.path.join(args.lr_path, img_name + ".png"), format="png")

print("Task Completed Succesfully!")

Task Completed Succesfully!
