In [None]:
import hashlib
import os
from time import time

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPooling2D, BatchNormalization, Add
from tensorflow.keras import Model
from tensorflow.keras.regularizers import L2

from matplotlib import pyplot as plt
plt.rcParams.update({'figure.figsize': [15, 10]})

# Reproducibility: seed NumPy's global RNG (np.random.shuffle decides the
# training file order) and TensorFlow's global seed (random TF ops such as
# unseeded weight initializers derive their seeds from it).
from numpy.random import seed
from tensorflow.random import set_seed

seed_value = 1234578790
seed(seed_value)
set_seed(seed_value)

In [None]:
# Where the data lives: 'colab' (Google Drive mount) or 'local' (relative paths).
env = 'local'

if env == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')
    train_folder = '/content/drive/MyDrive/Colab Notebooks/FaceUpscale/train'
    test_folder = '/content/drive/MyDrive/Colab Notebooks/FaceUpscale/test'
    cache_folder = '/content/cache'
    train_batch_size = 512
elif env == 'local':
    train_folder = '../data/train'
    test_folder = '../data/test'
    cache_folder = '../data/cache'
    train_batch_size = 32
else:
    # Fail here with a clear message instead of a NameError on cache_folder later.
    raise ValueError("Unknown env {!r}; expected 'colab' or 'local'".format(env))

# Side length of the low-resolution grayscale input (X) images.
x_img_size = 32

# Side length and channel count of the high-resolution RGB target (Y) images.
y_img_size = 128
y_img_channels = 3

# Cache of preprocessed X/Y arrays; exist_ok makes re-running this cell harmless.
os.makedirs(cache_folder, exist_ok=True)

In [None]:
# Diagnostic counters bumped by the cache helpers and data_iterator:
# cache files written, cache files read, and source images decoded.
file_cache_write = file_cache_read = file_source_read = 0

# Image path lists, filled lazily on first use by the train/test file iterators.
train_files_list, test_files_list = [], []

def reset_cache_counters():
    """Zero the module-level cache/source read-write diagnostic counters."""
    global file_cache_write, file_cache_read, file_source_read
    file_cache_write = file_cache_read = file_source_read = 0

def image_file_iterator(root, extensions=('.png',)):
    """Yield the path of every image file under ``root``, recursively.

    ``os.walk`` already descends into every sub-directory, so no manual
    recursion is needed. The previous extra recursion re-walked bare
    sub-directory names relative to the current working directory and could
    yield the same file twice. Directories and files are visited in sorted
    order so the listing, and therefore the seeded shuffle of the training
    list, is the same on every filesystem.

    Args:
        root: directory to scan.
        extensions: accepted (case-sensitive) file suffixes; a single string
            is also accepted. Defaults to PNG only.

    Yields:
        ``os.path.join(dirpath, filename)`` for each matching file.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    for subdir, dirs, files in os.walk(root):
        # Sorting in place also fixes the order os.walk descends into subdirs.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(suffixes):
                yield os.path.join(subdir, file)

def train_files_shuffled_iterator():
    """Yield every training image path exactly once, in a fresh random order.

    The path list is scanned from ``train_folder`` on first use and kept in
    the module-level ``train_files_list``. Each call shuffles a private copy
    with NumPy's global RNG. The shared list is never reordered underneath
    another generator that is still iterating over it, which previously could
    make that generator repeat or skip files.
    """
    global train_files_list
    if not train_files_list:
        train_files_list = list(image_file_iterator(train_folder))

    shuffled = list(train_files_list)
    np.random.shuffle(shuffled)
    yield from shuffled

def test_files_iterator():
    """Yield every test image path in a stable order.

    ``test_folder`` is scanned only once; the result is kept in the
    module-level ``test_files_list`` and reused by later calls.
    """
    global test_files_list
    if not test_files_list:
        test_files_list = list(image_file_iterator(test_folder))
    yield from test_files_list

def scale_and_normalize(img, size):
    """Return ``img / 255``, first resized to ``size`` x ``size`` if needed.

    Resizing uses ``cv2.resize`` with its default interpolation and happens
    only when either side differs from ``size``. Dividing by 255 maps 8-bit
    pixel values into [0, 1].
    """
    already_sized = img.shape[0] == size and img.shape[1] == size
    resized = img if already_sized else cv2.resize(img, (size, size))
    return resized / 255

def make_cache_file_name(id, img_size, prefix):
    """Build a stable cache file name for source ``id`` at ``img_size``.

    Returns ``prefix + <sha1 hex digest> + '.bin'``. A content digest is used
    instead of the built-in ``hash()`` because str hashes are salted per
    interpreter process (PYTHONHASHSEED). With ``hash()``, every kernel
    restart missed the cache and wrote a fresh set of files.
    """
    # The '|' separator keeps e.g. ('a1', 28) and ('a12', 8) from colliding.
    key = '{}|{}'.format(id, img_size)
    # surrogatepass keeps undecodable file-system names (surrogate escapes) hashable.
    digest = hashlib.sha1(key.encode('utf-8', 'surrogatepass')).hexdigest()
    return prefix + digest + '.bin'

def load_cached_array(id, img_size, prefix):
    """Return the cached flat array for (id, img_size, prefix), or None on a miss.

    The file under ``cache_folder`` is raw bytes read back as float (float64)
    and returned 1-D; callers reshape it. Each hit increments
    ``file_cache_read``.
    """
    global file_cache_read
    path = os.path.join(cache_folder, make_cache_file_name(id, img_size, prefix))
    if not os.path.exists(path):
        return None
    file_cache_read += 1
    return np.fromfile(path, dtype=float)

def save_array_to_cache(arr, id, img_size, prefix):
    """Write ``arr`` to the cache as raw float64 bytes unless an entry exists.

    The format is headerless, and load_cached_array reads it back with
    np.fromfile's default float64 dtype. The data is therefore always stored
    as float64, so other dtypes are not silently misread.

    The bytes go to a temporary file first and are then moved into place with
    os.replace. An interrupted write therefore never leaves a truncated file
    under the real name, where it would later be loaded as a valid entry.
    Each new entry increments ``file_cache_write``.
    """
    global file_cache_write
    path = os.path.join(cache_folder, make_cache_file_name(id, img_size, prefix))
    if os.path.exists(path):
        return

    file_cache_write += 1
    tmp_path = path + '.tmp'
    np.asarray(arr, dtype=np.float64).tofile(tmp_path)
    os.replace(tmp_path, path)

def load_cached_xy_train(id):
    """Load the cached (X, Y) pair for source image ``id``.

    Returns:
        ``(x, y)``, where ``x`` has shape (x_img_size, x_img_size) and ``y``
        has shape (y_img_size, y_img_size, y_img_channels).
        ``(None, None)`` if either half is missing or holds an unexpected
        number of elements (e.g. a truncated file). In that case the caller
        falls back to decoding the source image instead of crashing on
        reshape.
    """
    x_shape = (x_img_size, x_img_size)
    y_shape = (y_img_size, y_img_size, y_img_channels)

    x_cached = load_cached_array(id, x_img_size, '_x')
    if x_cached is None or x_cached.size != x_img_size * x_img_size:
        return (None, None)

    y_cached = load_cached_array(id, y_img_size, '_y')
    if y_cached is None or y_cached.size != y_img_size * y_img_size * y_img_channels:
        return (None, None)

    return (x_cached.reshape(x_shape), y_cached.reshape(y_shape))

def save_xy_train_to_cache(id, x_img, y_img):
    """Cache both halves of an (X, Y) pair, each keyed by its own image size."""
    for arr, size, prefix in ((x_img, x_img_size, "_x"), (y_img, y_img_size, "_y")):
        save_array_to_cache(arr, id, size, prefix)

def data_iterator(file_iterator, batch_size):
    """Yield ``(x_batch, y_batch)`` NumPy arrays built from image file paths.

    Args:
        file_iterator: iterable of image paths. Generators and plain lists
            both work.
        batch_size: maximum samples per batch; the final batch may be smaller.

    Yields:
        ``x_batch`` with shape (b, x_img_size, x_img_size) holding grayscale
        inputs, and ``y_batch`` with shape (b, y_img_size, y_img_size, C)
        holding RGB targets. Both are scaled to [0, 1].

    Each pair is taken from the on-disk cache when present. Otherwise the
    source image is decoded, converted and written to the cache, and
    ``file_source_read`` is incremented.

    Raises:
        ValueError: if ``batch_size < 1``, which previously looped forever
            without yielding, or if an image file cannot be decoded.
    """
    global file_source_read

    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {!r}'.format(batch_size))

    files = iter(file_iterator)
    exhausted = False
    while not exhausted:
        x_batch = []
        y_batch = []

        while len(x_batch) < batch_size:
            fpath = next(files, None)
            if fpath is None:
                exhausted = True
                break

            # Try the cache first.
            x_img, y_img = load_cached_xy_train(fpath)

            if x_img is None or y_img is None:
                # Cache miss: decode the source image.
                file_source_read += 1
                src_img = cv2.imread(fpath)
                if src_img is None:
                    # cv2.imread reports failure by returning None rather than raising.
                    raise ValueError('Could not read image file: {}'.format(fpath))
                src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
                # X: grayscale, low resolution.
                x_img = scale_and_normalize(cv2.cvtColor(src_img, cv2.COLOR_RGB2GRAY), x_img_size)
                # Y: RGB, high resolution.
                y_img = scale_and_normalize(src_img, y_img_size)
                save_xy_train_to_cache(fpath, x_img, y_img)

            x_batch.append(x_img)
            y_batch.append(y_img)

        if x_batch:
            yield np.array(x_batch), np.array(y_batch)

In [None]:
# Count the images available in each split.
train_set_len = sum(1 for _ in image_file_iterator(train_folder))
test_set_len = sum(1 for _ in image_file_iterator(test_folder))

print(train_set_len)
print(test_set_len)

# Smoke-test the data generator. n images in batches of b must give
# ceil(n / b) batches, because the final batch may be partial. Comparing
# against n / 2 failed spuriously for an odd n.
# These passes also populate the on-disk cache for the test split.
for smoke_batch_size in (1, 2):
    expected_batches = -(-test_set_len // smoke_batch_size)  # integer ceil division
    produced_batches = sum(1 for _ in data_iterator(test_files_iterator(), smoke_batch_size))
    if produced_batches != expected_batches:
        print('datagen failure!')

print('Source reads:', file_source_read)
print('Cache reads:', file_cache_read)
print('Cache writes:', file_cache_write)

# Visualize one cached sample to confirm the cache round-trips correctly.
reset_cache_counters()
cached_x_batch, cached_y_batch = next(data_iterator(test_files_iterator(), batch_size=1))
# Exactly two cache reads (one X, one Y) means the sample was served from the cache.
if file_cache_read == 2:
    plt.subplot(121), plt.imshow(cached_x_batch[0], cmap='gray')
    plt.subplot(122), plt.imshow(cached_y_batch[0])
    print('cached x shape:', cached_x_batch[0].shape)
    print('cached y shape:', cached_y_batch[0].shape)
else:
    print('cache failure!')



In [None]:
# Pull one shuffled training batch of 10 samples to eyeball the generator output.
batch = next(data_iterator(train_files_shuffled_iterator(), 10))
x_train_batch, y_train_batch = batch

# Low-resolution grayscale inputs (X) on a 3x5 grid, titled by batch index.
for idx, x_img in enumerate(x_train_batch):
    plt.subplot(3, 5, idx + 1)
    plt.imshow(x_img, cmap='gray')
    plt.title(idx)

In [None]:
# High-resolution RGB targets (Y) for the same batch, same 3x5 grid layout.
for idx, y_img in enumerate(y_train_batch):
    plt.subplot(3, 5, idx + 1)
    plt.imshow(y_img)
    plt.title(idx)

In [None]:
def rdb_block(inputs, layers_count):
    """Residual Dense Block built with the Keras functional API.

    Each of the ``layers_count`` 3x3 convolutions takes the channel-wise
    concatenation of the block input and every earlier conv output. A 1x1
    convolution then fuses all of those features back to the input channel
    count, and the block input is added to the result (residual connection).

    Args:
        inputs: Keras tensor whose last axis is the channel axis.
        layers_count: number of densely connected 3x3 conv layers.

    Returns:
        Keras tensor with the same channel count as ``inputs``.
    """
    n_channels = inputs.get_shape()[-1]
    conv_kwargs = dict(activation="relu", kernel_initializer="Orthogonal", padding="same")

    features = [inputs]
    for _ in range(layers_count):
        dense_input = tf.concat(features, axis=-1)
        features.append(Conv2D(n_channels, 3, **conv_kwargs)(dense_input))

    merged = tf.concat(features, axis=-1)
    fused = Conv2D(n_channels, 1, **conv_kwargs)(merged)
    return Add()([fused, inputs])

In [None]:
def psnr(orig, pred):
    """Keras metric: PSNR (dB) between target and predicted images in [0, 1].

    Both tensors are scaled to [0, 255], clipped, rounded and quantized to
    uint8, so the score reflects the 8-bit image that is actually viewed.

    Clipping must happen before the cast. The model's ReLU output is
    unbounded above, and tf.cast does not clamp out-of-range floats.
    Casting first could therefore wrap bright pixels, and a clip applied
    afterwards to uint8 values does nothing.
    """
    def to_uint8(img):
        img = tf.clip_by_value(img * 255.0, 0.0, 255.0)
        return tf.cast(tf.round(img), tf.uint8)

    return tf.image.psnr(to_uint8(orig), to_uint8(pred), max_val=255)

In [None]:
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen on some setups. The OpenMP runtime likely checks this variable
# when it is first loaded, which may already have happened during the
# imports in the first cell. If the abort persists, move this line above
# those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

conv_args = {
    "activation": "relu",
    "kernel_initializer": "Orthogonal",
    "padding": "same",
}

# Pixel shuffle (depth_to_space) needs a whole-number upscale factor.
if y_img_size % x_img_size != 0:
    raise ValueError('y_img_size ({}) must be a multiple of x_img_size ({})'.format(y_img_size, x_img_size))
scale_ratio = y_img_size // x_img_size
print('Scale ratio: ', scale_ratio)

inputs = Input(shape=(x_img_size, x_img_size, 1))
net = Conv2D(64, 5, **conv_args)(inputs)
net = Conv2D(64, 3, **conv_args)(net)
# First Residual Dense Block.
net = rdb_block(net, layers_count=7)
net = Conv2D(32, 3, **conv_args)(net)
# Second Residual Dense Block.
net = rdb_block(net, layers_count=7)
# Sub-pixel upscaling: emit y_img_channels * r^2 low-resolution feature maps,
# then depth_to_space rearranges them into an image r times larger per side.
net = Conv2D(y_img_channels * scale_ratio ** 2, 3, **conv_args)(net)
outputs = tf.nn.depth_to_space(net, scale_ratio)

model = Model(inputs, outputs)
model.summary()

In [None]:
def datagen(batch_size):
    """Endless training-batch generator for ``model.fit``.

    Each pass walks the whole training set once in a fresh shuffled order,
    yields every batch (including a smaller final one), then starts a new
    pass.

    The previous version built a new shuffled iterator for every step. It
    therefore only ever yielded the first batch of each reshuffle, and its
    "restart" branch was never used.

    Raises:
        ValueError: if a full pass produces no batches (e.g. an empty training
            folder), instead of spinning forever without yielding.
    """
    while True:
        produced_any = False
        for batch in data_iterator(train_files_shuffled_iterator(), batch_size):
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError('No training batches produced; is the training folder empty?')

In [None]:
# Train the network
epochs = 10
# One epoch = one full pass over the training set: ceil(N / batch) batches,
# the last of which may be partial. Integer ceil avoids a float step count
# and the extra step that "+ 1" added when N divides evenly.
steps_per_epoch = -(-train_set_len // train_batch_size)

print(steps_per_epoch)

# The whole test split, held in memory as validation data.
x_test, y_test = next(data_iterator(test_files_iterator(), test_set_len))

print(len(x_test))

model.compile(loss="mse", optimizer="adam", metrics=[psnr])
history = model.fit(datagen(train_batch_size), steps_per_epoch=steps_per_epoch, epochs=epochs, validation_data=(x_test, y_test))

In [None]:
test_range = 20

# Take the first `test_range` test images (inputs + ground truth) and predict.
# data_iterator is used directly: datagen() takes only a batch size and always
# reads the training set, so the old datagen(test_folder, test_range) call
# raised a TypeError. Distinct names keep the validation x_test/y_test intact.
x_sample, y_sample = next(data_iterator(test_files_iterator(), test_range))
y_pred = model.predict(x_sample)

# Per sample: LR input | bilinear upscale baseline | model output | ground truth.
panel_titles = ('input', 'bilinear', 'model', 'ground truth')
for ii in range(len(x_sample)):
    f, axarr = plt.subplots(1, 4)
    axarr[0].imshow(x_sample[ii], cmap='gray')
    axarr[1].imshow(cv2.resize(x_sample[ii], (y_img_size, y_img_size), interpolation=cv2.INTER_LINEAR), cmap='gray')
    # ReLU output can exceed 1.0; clip for display instead of relying on imshow's warning path.
    axarr[2].imshow(np.clip(y_pred[ii], 0.0, 1.0))
    axarr[3].imshow(y_sample[ii])
    for ax, title in zip(axarr, panel_titles):
        ax.set_title(title)
