In [15]:
import cv2
import numpy

import os
import time
import math

from utility.h5data import h5DataWrite

In [16]:
# --- Input locations ---------------------------------------------------------
# Marker used to pick image files out of the data directories.
DATA_FORMAT = ".png"
DATAPATH_TRAIN = "./imageTrain/"
DATAPATH_VALIDATE = "./imageValidate/"

# --- Patch sampling ----------------------------------------------------------
# Number of random patches taken from each image, and the patch edge in pixels.
RANDOM_CROP = 30
SIZE_PATCH = 32

# --- Degradation -------------------------------------------------------------
# Images are shrunk by SCALE and enlarged back to produce the low-res input.
SCALE = 4
INTERPOLATION = cv2.INTER_CUBIC

# --- Output files ------------------------------------------------------------
# The scale factor and patch size are encoded in the file names.
FILENAME_TRAIN = "yayoi_waifu2x_dataTrain_{}_{}.h5".format(SCALE, SIZE_PATCH)
FILENAME_VALIDATE = "yayoi_waifu2x_dataValidate_{}_{}.h5".format(SCALE, SIZE_PATCH)

### OpenCV interpolation methods
INTER_NEAREST - a nearest-neighbor interpolation<br>
INTER_LINEAR - a bilinear interpolation (used by default)<br>
INTER_AREA - resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moiré-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.<br>
INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood<br>
INTER_LANCZOS4 - a Lanczos interpolation over 8x8 pixel neighborhood<br>

In [17]:
def parseData(_path, _seed=None):
    """Build randomly cropped low-res / high-res patch pairs from a directory.

    Every image in ``_path`` whose name ends with ``DATA_FORMAT`` is read,
    shrunk by ``SCALE`` and enlarged back to its original size (both steps
    using ``INTERPOLATION``) to obtain the degraded input. ``RANDOM_CROP``
    patches of ``SIZE_PATCH`` x ``SIZE_PATCH`` pixels are then cut at the same
    random positions from the degraded and the original image.

    Args:
        _path: Directory containing the source images.
        _seed: Optional seed. When given, crop positions come from a private
            ``numpy.random.RandomState(_seed)`` so the result is reproducible;
            when ``None`` the global ``numpy.random`` state is used.

    Returns:
        ``(data, label)``: two ``numpy.double`` arrays of shape
        ``(n, 1, SIZE_PATCH, SIZE_PATCH, 3)`` scaled to ``[0, 1]``, where
        ``data`` holds the degraded patches and ``label`` the original ones.
        ``n`` is ``RANDOM_CROP`` times the number of images actually used;
        skipped files contribute no rows.
    """
    suffix = DATA_FORMAT.lower()
    names = sorted(n for n in os.listdir(_path) if n.lower().endswith(suffix))

    rng = numpy.random if _seed is None else numpy.random.RandomState(_seed)

    # Allocate for the best case (every file usable) and trim at the end.
    capacity = len(names) * RANDOM_CROP
    data = numpy.zeros((capacity, 1, SIZE_PATCH, SIZE_PATCH, 3), dtype=numpy.double)
    label = numpy.zeros((capacity, 1, SIZE_PATCH, SIZE_PATCH, 3), dtype=numpy.double)
    filled = 0

    for name in names:
        hr_img = cv2.imread(os.path.join(_path, name), cv2.IMREAD_COLOR)
        if hr_img is None:
            # imread signals an unreadable file by returning None.
            print("skipped (cannot be read): " + name)
            continue

        height, width = hr_img.shape[:2]
        if height < SIZE_PATCH or width < SIZE_PATCH:
            print("skipped (smaller than patch size): " + name)
            continue

        # The interpolation flag must be passed by keyword: the third
        # positional parameter of cv2.resize is the output array, not the flag.
        lr_img = cv2.resize(hr_img, (int(width / SCALE), int(height / SCALE)),
                            interpolation=INTERPOLATION)
        lr_img = cv2.resize(lr_img, (width, height), interpolation=INTERPOLATION)

        # Row and column offsets are drawn from their own axis so the whole
        # image can be sampled; "+ 1" because randint's upper bound is exclusive.
        rows = rng.randint(0, height - SIZE_PATCH + 1, RANDOM_CROP)
        cols = rng.randint(0, width - SIZE_PATCH + 1, RANDOM_CROP)

        for row, col in zip(rows, cols):
            lr_patch = lr_img[row: row + SIZE_PATCH, col: col + SIZE_PATCH]
            hr_patch = hr_img[row: row + SIZE_PATCH, col: col + SIZE_PATCH]

            data[filled, 0, :, :, :] = lr_patch.astype(float) / 255.
            label[filled, 0, :, :, :] = hr_patch.astype(float) / 255.
            filled += 1

    # Drop the unused tail so skipped files do not leave all-zero samples.
    return data[:filled], label[:filled]

In [18]:
# Grid tiling used by parseCropData: square tiles of BLOCK_SIZE pixels,
# with consecutive tiles starting BLOCK_STEP pixels apart.
BLOCK_SIZE = 32
BLOCK_STEP = 16

In [19]:
def parseCropData(_path):
    """Build grid-tiled low-res / high-res patch pairs from a directory.

    Every image in ``_path`` whose name ends with ``DATA_FORMAT`` is read,
    shrunk by ``SCALE`` and enlarged back to its original size (both steps
    using ``INTERPOLATION``) to obtain the degraded input. Both images are
    then cut into ``BLOCK_SIZE`` x ``BLOCK_SIZE`` tiles whose origins lie
    ``BLOCK_STEP`` pixels apart.

    Args:
        _path: Directory containing the source images.

    Returns:
        ``(data, label)``: two float arrays scaled to ``[0, 1]`` with one tile
        per row of the first axis; ``data`` holds the degraded tiles and
        ``label`` the original ones. When no tile is produced, both have shape
        ``(0, BLOCK_SIZE, BLOCK_SIZE, 3)``.
    """
    suffix = DATA_FORMAT.lower()
    names = sorted(n for n in os.listdir(_path) if n.lower().endswith(suffix))

    data = []
    label = []

    for name in names:
        hr_img = cv2.imread(os.path.join(_path, name), cv2.IMREAD_COLOR)
        if hr_img is None:
            # imread signals an unreadable file by returning None.
            print("skipped (cannot be read): " + name)
            continue

        height, width = hr_img.shape[:2]

        # Use the configured interpolation, passed by keyword because the
        # third positional parameter of cv2.resize is the output array.
        lr_img = cv2.resize(hr_img, (int(width / SCALE), int(height / SCALE)),
                            interpolation=INTERPOLATION)
        lr_img = cv2.resize(lr_img, (width, height), interpolation=INTERPOLATION)

        # NOTE(review): the tile count is kept exactly as originally written;
        # it stops short of the bottom/right edges. Confirm whether that
        # margin is intentional.
        margin = (BLOCK_SIZE - BLOCK_STEP) * 2
        row_tiles = int((height - margin) / BLOCK_STEP)
        col_tiles = int((width - margin) / BLOCK_STEP)

        for k in range(row_tiles):
            for j in range(col_tiles):
                x = k * BLOCK_STEP
                y = j * BLOCK_STEP
                hr_patch = hr_img[x: x + BLOCK_SIZE, y: y + BLOCK_SIZE]
                lr_patch = lr_img[x: x + BLOCK_SIZE, y: y + BLOCK_SIZE]

                data.append(lr_patch.astype(float) / 255.)
                label.append(hr_patch.astype(float) / 255.)

    if not data:
        # Keep the patch dimensions even when nothing was collected.
        empty_shape = (0, BLOCK_SIZE, BLOCK_SIZE, 3)
        return numpy.zeros(empty_shape, dtype=float), numpy.zeros(empty_shape, dtype=float)

    data = numpy.array(data, dtype=float)
    label = numpy.array(label, dtype=float)
    return data, label

In [20]:
# Starting tick count; the final cell subtracts it from a second reading
# to report how long dataset generation took.
tick1 = cv2.getTickCount()

In [21]:
# Training pairs come from the grid tiler, validation pairs from random crops.
# Each (data, label) pair is handed to h5DataWrite with its target file name.
for build_pairs, source_dir, target_file in (
    (parseCropData, DATAPATH_TRAIN, FILENAME_TRAIN),
    (parseData, DATAPATH_VALIDATE, FILENAME_VALIDATE),
):
    data, label = build_pairs(source_dir)
    h5DataWrite(data, label, target_file)
    print(target_file + " generated")

yayoi_waifu2x_dataTrain_4_32.h5 generated
yayoi_waifu2x_dataValidate_4_32.h5 generated


In [22]:
# Elapsed time since tick1, in whole milliseconds.
tick2 = cv2.getTickCount()
tick = math.floor((tick2 - tick1) * 1000 / cv2.getTickFrequency())

# Report only the units that are needed: ms, secs + ms, or mins + secs + ms.
if tick < 1000:
    elapsed = "{} ms".format(tick)
elif tick < 60000:
    secs, msec = divmod(tick, 1000)
    elapsed = "{} secs {} ms".format(secs, msec)
else:
    mins = tick // 60000
    secs, msec = divmod(tick - mins * 60000, 1000)
    elapsed = "{} mins {} secs {} ms".format(mins, secs, msec)

print("processed time: " + elapsed)

processed time: 711 ms
