# In [None]:
import gc
import os
import random
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# NOTE(review): this alias does not affect the `from keras...` imports below;
# those resolve the separately installed `keras` package, not `tf.keras`.
keras = tf.keras

from keras.layers import Conv2D, MaxPool2D, BatchNormalization, LeakyReLU, Concatenate, Activation, Input
from keras.layers import Conv2DTranspose as Deconv2D
from keras.models import Model

################################################################# FUNCTION DECLARATION

# NOTE(review): `s` is not referenced anywhere in this section, and the Keras
# model below is built and trained without an explicit session. This TF1-style
# session is presumably a leftover; confirm no other cell uses `s` before
# removing it.
s = tf.compat.v1.InteractiveSession()

def input_layer(n, input):
    """Build the U-Net stem: two (3x3 conv -> batch-norm -> LeakyReLU) stages.

    Args:
        n: number of filters used by both convolutions.
        input: Keras tensor fed into the first convolution.

    Returns:
        The Keras tensor produced by the second LeakyReLU.
    """
    x = input
    for _ in range(2):
        # Stride 1 with 'same' padding keeps the spatial size unchanged.
        x = Conv2D(n, (3, 3), strides=1, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
    return x
    
def maxPool(n, input):
    """Encoder step: 2x2 max-pool (stride 2), then two conv stages.

    Args:
        n: number of filters for each 3x3 convolution.
        input: Keras tensor to downsample.

    Returns:
        Keras tensor with half the spatial size of ``input`` and ``n`` channels.
    """
    features = MaxPool2D((2, 2), strides=2)(input)

    stage = 0
    while stage < 2:
        features = Conv2D(n, (3, 3), padding='same')(features)
        features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)
        stage += 1

    return features
    
def upConv(n, input_1, input_2):
    """Decoder step: upsample ``input_2``, merge with skip ``input_1``, refine.

    Args:
        n: number of filters for the transposed conv and both 3x3 convs.
        input_1: skip-connection tensor from the encoder; placed first in the
            channel concatenation.
        input_2: deeper tensor, upsampled 2x by a stride-2 transposed conv.

    Returns:
        Keras tensor after two (conv -> batch-norm -> ReLU) stages.
    """
    upsampled = Deconv2D(n, (2, 2), strides=2)(input_2)
    merged = Concatenate()([input_1, upsampled])

    for _ in range(2):
        conv_out = Conv2D(n, (3, 3), padding='same')(merged)
        normed = BatchNormalization()(conv_out)
        merged = Activation('relu')(normed)

    return merged

#define the model
def UNet(x_shape):
    """Build the U-Net that maps an input image to a 3-channel output.

    The encoder is a stem plus four stride-2 max-pool stages. The decoder is
    four stride-2 transposed-conv stages, each concatenated with the matching
    encoder tensor. Height and width therefore presumably need to be divisible
    by 16 for the skip concatenations to line up.

    Args:
        x_shape: input shape passed to ``Input`` (e.g. ``(w, w, 1)``).

    Returns:
        A ``Model`` whose output comes from a 1x1 convolution with 3 filters
        and no explicit activation.
    """
    inputs = Input(x_shape)

    # Encoder: keep every level so the decoder can use it as a skip input.
    skips = [input_layer(64, inputs)]
    for filters in (128, 256, 512, 1024):
        skips.append(maxPool(filters, skips[-1]))

    # Decoder: start from the deepest level and walk back up the skip stack.
    x = skips.pop()
    for filters in (512, 256, 128, 64):
        x = upConv(filters, skips.pop(), x)

    # Project to 3 output channels.
    outputs = Conv2D(3, (1, 1), strides=1)(x)

    return Model(inputs=inputs, outputs=outputs)

#load the previous training history
def load_previous_epoch(path_history):
    """Return the number of non-empty lines in the training-history file.

    The training loop appends one line per completed epoch, so this count is
    the epoch index to resume from.

    Args:
        path_history: path to the history text file.

    Returns:
        Number of lines that contain at least one character.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    # `with` guarantees the handle is closed even if reading fails.
    with open(path_history, "r") as history:
        return sum(1 for line in history if line.rstrip("\n"))
  
#testing and show the images
def test_sample(color, gray, w, e, predictor=None, save_dir=None):
    """Predict colour images from grayscale inputs and plot them in a grid.

    Each row shows: ground-truth colour image | grayscale input | prediction.
    At most three rows are drawn, fewer if fewer samples are given.

    Args:
        color: ground-truth images in BGR channel order; each sample must
            reshape to ``(w, w, 3)``.
        gray: grayscale inputs passed to ``predict``; each sample must
            reshape to ``(w, w)`` (presumably shaped ``(N, w, w, 1)``).
        w: image side length in pixels.
        e: epoch number, used only to name the saved figure.
        predictor: model to run; defaults to the module-level ``model``.
        save_dir: if given, the figure is also saved as
            ``save_dir/Epoch_<e>.jpg`` before being shown.
    """
    # Only touch the global `model` when no predictor is supplied.
    net = model if predictor is None else predictor
    output = net.predict(gray)

    # Images are BGR (cv2 convention); reverse channels so matplotlib shows RGB.
    # The network's last layer is linear, so clip to imshow's valid float range.
    output = np.clip(output[..., ::-1], 0.0, 1.0)
    color = color[..., ::-1]

    n_rows = min(3, len(gray), len(color))
    for row in range(n_rows):
        first = 3 * row + 1
        plt.subplot(n_rows, 3, first)
        plt.imshow(color[row].reshape((w, w, 3)), interpolation='none')
        plt.subplot(n_rows, 3, first + 1)
        plt.imshow(gray[row].reshape((w, w)), cmap="gray", interpolation='none')
        plt.subplot(n_rows, 3, first + 2)
        plt.imshow(output[row].reshape((w, w, 3)), interpolation='none')

    # savefig must run before show(), which may clear the current figure.
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, 'Epoch_' + str(e) + '.jpg'))
    plt.show()

#read images and convert to array form
def _read_resized(paths, indices, size, label):
    """Read ``paths[k]`` for each k in ``indices`` with cv2, resized to size x size."""
    loaded = []
    total = len(indices)
    for count, idx in enumerate(indices, start=1):
        filename = paths[idx]
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: {}".format(filename))
        loaded.append(cv2.resize(img, (size, size)))

        sys.stdout.write("\rLoading {} img {}/{}".format(label, count, total))
        sys.stdout.flush()
    return loaded


def batch_generator():
    """Load a random train/test sample of images and their grayscale versions.

    Reads module-level settings: ``TOTAL_TRAIN``, ``TOTAL_TEST``,
    ``TRAIN_IMAGES_PER_EPOCH``, ``TEST_IMAGES_PER_EPOCH``, ``img_path_train``
    (some defined outside this section) and the image size ``w``.

    Returns:
        Tuple ``(train_color, train_gray, test_color, test_gray)``.
        Colour arrays are float32 in [0, 1] with shape ``(N, w, w, 3)`` in BGR
        order. Gray arrays are the per-pixel channel mean with shape
        ``(N, w, w, 1)``.

    Raises:
        FileNotFoundError: if an image file cannot be read by OpenCV.
    """
    train_list = random.sample(range(TOTAL_TRAIN), TRAIN_IMAGES_PER_EPOCH)
    test_list = random.sample(range(TOTAL_TEST), TEST_IMAGES_PER_EPOCH)

    images = _read_resized(img_path_train, train_list, w, "training")
    # NOTE(review): test indices are sampled from range(TOTAL_TEST) but looked
    # up in img_path_train, so validation images come from the training list
    # and can overlap the training sample. If a separate test path list exists,
    # it was probably intended here; confirm.
    images += _read_resized(img_path_train, test_list, w, "testing")
    sys.stdout.write("\n")  # end the carriage-return progress line

    # float32 halves memory compared with the float64 produced by `/ 255.`.
    images = np.asarray(images, dtype=np.float32)
    images /= 255.0

    gray_images = images.mean(axis=-1, keepdims=True)

    # Slice by the train count: `[-0:]` would return everything when no test
    # images are requested.
    n_train = TRAIN_IMAGES_PER_EPOCH
    return images[:n_train], gray_images[:n_train], images[n_train:], gray_images[n_train:]
   
############################################################################# DEFINE PARAMETERS
EPOCH = 2000                      # number of epochs to run in this session
BATCH_SIZE = 16
TRAIN_IMAGES_PER_EPOCH = 2048     # images freshly sampled for training each epoch
TEST_IMAGES_PER_EPOCH = 512       # images freshly sampled for validation each epoch
load_PreTrain = True              # resume from ./model.hdf5 and ./history.txt

w = 256                           # side length images are resized to

############################################################################ LOADING PREVIOUS INFORMATION
# Single-channel w x w input.
x_shape = (w, w, 1)
model = UNet(x_shape)

# NOTE(review): 'acc' is likely not a meaningful metric for this regression
# output; it is kept because the training loop logs hist.history['acc'].
model.compile('adam', loss='mean_squared_error', metrics=['mae', 'acc'])

if not load_PreTrain:
    previous_epoch = 0
else:
    # Restore the last saved weights, then count the epochs already logged.
    model.load_weights("./model.hdf5")
    previous_epoch = load_previous_epoch('./history.txt')

print('Start from Epoch : {}'.format(previous_epoch))

############################################################################# TRAINING
# Keys logged to history.txt, one space-separated line per epoch.
HISTORY_KEYS = ('loss', 'mae', 'acc', 'val_loss', 'val_mae', 'val_acc')

for epoch in range(previous_epoch, previous_epoch + EPOCH):

    print('EPOCH : {}'.format(epoch))

    # Drop the previous epoch's arrays before loading new ones, so two full
    # batches are never held in memory at once.
    Y_train = X_train = Y_test = X_test = None
    gc.collect()

    Y_train, X_train, Y_test, X_test = batch_generator()

    hist = model.fit(X_train, Y_train, batch_size=BATCH_SIZE, validation_data=(X_test, Y_test))

    # Save weights BEFORE logging: load_previous_epoch resumes from the number
    # of history lines, so a line must only appear once its weights are on disk.
    model.save('./model.hdf5')

    record = "{} {} {} {} {} {}\n".format(*(hist.history[key][0] for key in HISTORY_KEYS))
    with open("./history.txt", "a") as file_object:
        file_object.write(record)

    test_sample(Y_test[:3], X_test[:3], w, epoch)