# Preprocess Dataset/Generate Dataset

In [1]:
import os
import torch
import pathlib
import shutil
import numpy as np
from tqdm import tqdm
import PIL
from fastai.vision import *
from torchvision import transforms

In [2]:
# TODO: tune these constants for your experiment.
# Number of extra augmented copies (random flips/rotations) saved for each
# training image; ImageAugmentor accepts at most 5.
num_extra = 2
# Side length, in pixels, that every resized image is brought back to.
target_size = 300
# Procedure: each image is resized to every size below and then resized back
# to target_size; each size gets its own train_<size>/valid_<size> folders.
sizes = [300, 200, 150, 100, 50]

In [3]:
# Root folder that holds the DIV2K splits.
# Note: `pathlib.Path('.').parent` is still `Path('.')`, so the dataset root is
# the current working directory (not its parent); spell that out explicitly.
DIV2K_path = pathlib.Path.cwd()
train_HR = DIV2K_path/'DIV2K_train_HR'
valid_HR = DIV2K_path/'DIV2K_valid_HR'
# Fail early and name the missing folder(s); unlike `assert`, this check is
# not stripped when Python runs with -O.
missing_dirs = [str(p) for p in (DIV2K_path, train_HR, valid_HR) if not p.exists()]
if missing_dirs:
    raise FileNotFoundError(f"required DIV2K folder(s) not found: {', '.join(missing_dirs)}")

In [4]:
# Everything generated by this notebook is written below this folder.
output_path = DIV2K_path.joinpath('same_300')
# Start from a clean slate: any result of a previous run is deleted first.
if output_path.exists():
    shutil.rmtree(output_path)
output_path.mkdir(parents=True, exist_ok=True)

## Generate Center Cropped & Transformed Images

In [5]:
def _image_size(img_path):
    """Return the PIL `size` tuple of the image at `img_path`, closing the file afterwards."""
    with PIL.Image.open(img_path) as img:
        return img.size


# Collect the high-resolution images of both splits.
hr_valid_image_list = ImageList.from_folder(valid_HR)
hr_train_image_list = ImageList.from_folder(train_HR)

# For each split: keep the file names relative to the split folder and report
# the smallest side found across all images.
hr_train_image_name_list = [img_path.relative_to(train_HR) for img_path in hr_train_image_list.items]
shapes = [_image_size(img_path) for img_path in hr_train_image_list.items]
print(f"min dimension of all training images={torch.min(torch.tensor(shapes))}")
hr_valid_image_name_list = [img_path.relative_to(valid_HR) for img_path in hr_valid_image_list.items]
shapes = [_image_size(img_path) for img_path in hr_valid_image_list.items]
print(f"min dimension of all validation images={torch.min(torch.tensor(shapes))}")

min dimension of all training images=648
min dimension of all validation images=816


In [6]:
class ImageAugmentor(object):
    """Center-crop an image to a square and save it with random flipped/rotated copies.

    An instance is a callable taking `(image_path, index)`, which is the
    signature used when it is handed to `parallel` with a list of paths.

    Args:
        dest_dir: folder (a path object supporting `/`) that receives the images.
        num_extra: number of distinct augmentations saved per image, between 0
            and the number of available augmentations (5).

    Raises:
        ValueError: if `num_extra` is outside that range.
    """

    # (file-name suffix, name of the `transforms.functional` function, extra args).
    # The order matters: a random index into this table selects the transform.
    _AUGMENTATIONS = (
        ('-hf', 'hflip', ()),
        ('-vf', 'vflip', ()),
        ('-r90', 'rotate', (90,)),
        ('-r180', 'rotate', (180,)),
        ('-r270', 'rotate', (270,)),
    )

    def __init__(self, dest_dir, num_extra:int=2):
        if not 0 <= num_extra <= len(self._AUGMENTATIONS):
            raise ValueError(
                f"num_extra must be between 0 and {len(self._AUGMENTATIONS)}, got {num_extra}"
            )
        self.num_extra = num_extra
        self.dest_dir = dest_dir
        # Per-image state, (re)filled by every __call__.
        self.base_image, self.image_name, self.filename_no_ext, self.ext = None, None, None, None
        self.generated_images, self.generated_image_names = [], []

    def generate(self):
        """Append `num_extra` randomly chosen, distinct augmentations of `base_image`.

        Extends `generated_images` and `generated_image_names` in place; every
        new name is the original file stem plus a suffix naming the transform.
        """
        # NOTE(review): this draws from the global NumPy RNG; if `parallel` runs
        # this in forked worker processes they may share the same RNG state -
        # confirm whether that matters for the dataset.
        picks = np.random.choice(a=len(self._AUGMENTATIONS), replace=False, size=self.num_extra)
        for pick in picks:
            suffix, func_name, extra_args = self._AUGMENTATIONS[pick]
            transform = getattr(transforms.functional, func_name)
            self.generated_images.append(transform(self.base_image, *extra_args))
            self.generated_image_names.append(self.filename_no_ext + suffix + self.ext)

    def save(self):
        """Write every generated image into `dest_dir` under its generated name."""
        for image, name in zip(self.generated_images, self.generated_image_names):
            image.save(self.dest_dir/name)

    def __call__(self, image_path, i):
        """Crop, augment and save the image at `image_path`.

        Args:
            image_path: path of the source image.
            i: index supplied by the caller; unused.
        """
        self.image_name = os.path.basename(image_path)
        self.filename_no_ext, self.ext = os.path.splitext(self.image_name)
        # The context manager closes the source file even if cropping fails.
        with PIL.Image.open(image_path) as image:
            # Square crop whose side is the shorter side of the image.
            self.base_image = transforms.CenterCrop(min(image.size))(image)
        self.generated_images = [self.base_image]
        self.generated_image_names = [self.image_name]
        self.generate()
        self.save()


In [7]:
# Create the folders for the center-cropped & transformed images
# ("cct" stands for center-crop-transform).
cct_train = output_path/'train_cct'
cct_valid = output_path/'valid_cct'
for path in [cct_train, cct_valid]:
    # exist_ok=True already tolerates an existing folder, so no pre-check is needed.
    path.mkdir(parents=True, exist_ok=True)

In [8]:
%%time
# Center-crop every high-resolution image and save it into its cct folder.
# Validation images are only cropped (no extra copies) ...
valid_augmentor = ImageAugmentor(cct_valid, num_extra=0)
parallel(valid_augmentor, hr_valid_image_list.items)
# ... while each training image also gets `num_extra` random augmentations.
train_augmentor = ImageAugmentor(cct_train, num_extra=num_extra)
parallel(train_augmentor, hr_train_image_list.items)

## Generate All Other Size Images

In [9]:
class Resizer(object):
    """Resize an image to `size` x `size`, then back to `target_size` x `target_size`.

    An instance is a callable taking `(image_name, index)`, which is the
    signature used when it is handed to `parallel` with a list of names.

    Args:
        src_path: folder (a path object supporting `/`) holding the source images.
        dest_path: folder that receives the processed images.
        size: side length of the intermediate square image.
        target_size: side length of the saved square image.
    """

    def __init__(self, src_path, dest_path, size:int, target_size:int):
        self.src_path = src_path
        self.dest_path = dest_path
        self.size = size
        self.target_size = target_size

    def __call__(self, image_name, i):
        """Process `src_path/image_name` and save the result as `dest_path/image_name`.

        Args:
            image_name: file name of the image, relative to `src_path`.
            i: index supplied by the caller; unused.
        """
        intermediate_shape = (self.size, self.size)
        final_shape = (self.target_size, self.target_size)
        # The context manager closes the source file even if resizing fails.
        with PIL.Image.open(self.src_path/image_name) as src_img:
            intermediate = src_img.resize(intermediate_shape, resample=PIL.Image.BICUBIC).convert('RGB')
        # `intermediate` is already RGB, so the second resize needs no further conversion.
        resized_image = intermediate.resize(final_shape, resample=PIL.Image.BICUBIC)
        resized_image.save(self.dest_path/image_name)

In [10]:
# List the center-cropped images of both splits ...
cct_valid_list = ImageList.from_folder(cct_valid)
cct_train_list = ImageList.from_folder(cct_train)
# ... and keep each file's path relative to its split folder.
cct_valid_names = [cropped.relative_to(cct_valid) for cropped in cct_valid_list.items]
cct_train_names = [cropped.relative_to(cct_train) for cropped in cct_train_list.items]

In [11]:
%%time
# For every intermediate size, write the resized copies of both splits to
# train_<size> / valid_<size> below the output folder.
for size in tqdm(sizes):
    print(f"\n\nsize={size}")
    sub_out_train = output_path/f'train_{size}'
    sub_out_valid = output_path/f'valid_{size}'
    for sub_out in (sub_out_train, sub_out_valid):
        # exist_ok=True already tolerates an existing folder, so no pre-check is needed.
        sub_out.mkdir(parents=True, exist_ok=True)
    parallel(Resizer(cct_train, sub_out_train, size, target_size), cct_train_names)
    parallel(Resizer(cct_valid, sub_out_valid, size, target_size), cct_valid_names)
print("\n")

  0%|          | 0/5 [00:00<?, ?it/s]

size=300


 20%|██        | 1/5 [00:32<02:10, 32.55s/it]

size=200


 40%|████      | 2/5 [01:04<01:36, 32.33s/it]

size=150


 60%|██████    | 3/5 [01:34<01:03, 31.58s/it]

size=100


 80%|████████  | 4/5 [02:03<00:30, 30.84s/it]

size=50


In [12]:
# Remove the very large cct_valid and cct_train folders: they were only needed
# as the source for the resized images and are unnecessary from here on.
for intermediate_dir in (cct_valid, cct_train):
    shutil.rmtree(intermediate_dir)