In [None]:
################################################################################
# Download dataset and models from google drive
################################################################################

!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user and reuse those credentials for PyDrive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)


def _fetch_drive_file(file_id, local_name):
    """Download the Drive file `file_id` to `local_name`; return its handle."""
    handle = drive.CreateFile({'id': file_id})
    handle.GetContentFile(local_name)
    return handle


# Training / validation archives, the colorization checkpoint and the
# segmentation network, each saved under a fixed local file name.
dataset_training = _fetch_drive_file('1T3ekZdrQnPurftu-3jnoL-lu68ni4Fxh',
                                     "dataset_training_0.zip")
dataset_validation = _fetch_drive_file('1Hu4bm92fmnNNPlvweKUgdaPRkL_uN74Z',
                                       "dataset_validation_0.zip")
model = _fetch_drive_file('1Ul5PTZ9S8CsuXEPEaLdngr_snClYvf9f',
                          "colorization_model.hdf5")
pspnet = _fetch_drive_file('1oFuPmSbFqy_b8UPxNHNCMmYC99mogCEm',
                           "pspnet.h5")

In [None]:
################################################################################
# Import libraries
################################################################################

import zipfile
from random import random

import numpy as np
from keras import layers
from keras.backend import tf as ktf
from keras.callbacks import CSVLogger, Callback
from keras.layers import Conv2D, concatenate, Input, UpSampling2D
from keras.models import Model, load_model
from keras.optimizers import Adam
from PIL import Image, ImageFile
from skimage.color import rgb2lab
from skimage.transform import resize
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
################################################################################
# Generator and its functions
################################################################################

# Returns normalized numpy arrays l, a and b (each h x w x 1)
def lab_img(img):
    """Convert a PIL image into normalized L, a and b channel arrays.

    Args:
        img: PIL image in any mode; it is converted to RGB when necessary.

    Returns:
        Tuple ``(l, a, b)`` of float arrays, each of shape
        (height, width, 1):
          * ``l`` is the L channel divided by 100,
          * ``a`` is ``(a + 127) / 255 * 2 - 1``,
          * ``b`` is ``(b + 128) / 255 * 2 - 1``.
    """
    # rgb2lab needs a three-channel RGB image.  The previous check,
    # `img.format != ("JPEG" or "JPG")`, only ever compared against "JPEG"
    # and let grayscale/CMYK JPEGs through unconverted; testing the mode
    # covers every input.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # convert to LAB
    lab = rgb2lab(img)

    # split color channels, keeping a trailing axis of length 1
    l = lab[:, :, 0:1]
    a = lab[:, :, 1:2]
    b = lab[:, :, 2:3]

    # normalize (the offsets 127 / 128 are kept exactly as the existing
    # checkpoint was trained with them)
    l = l / 100
    a = ((a + 127) / 255) * 2 - 1
    b = ((b + 128) / 255) * 2 - 1

    return l, a, b


# Returns batches of x, y, and segmentation for training
def batch_images(index, batch_size, images_size, image_paths, imgs, pspnet):
    """Build one batch of inputs and targets from images in a zip archive.

    Args:
        index: position in `image_paths` of the first image of the batch.
        batch_size: number of images in the batch.
        images_size: (width, height) of the images.  Images are not resized
            here, so every archive member must already have this size.
        image_paths: archive members, as returned by `ZipFile.infolist()`.
        imgs: open `zipfile.ZipFile` the members are read from.
        pspnet: segmentation network handed to `predict_segmentation`.

    Returns:
        Tuple ``(x_batch, s_batch, y_batch)``:
          * x_batch: (batch, height, width, 1) normalized L channel,
          * s_batch: (batch, height // 8, width // 8, 150) segmentation,
          * y_batch: (batch, height, width, 2) normalized a and b channels.
    """
    width, height = images_size[0], images_size[1]

    # create empty batches
    x_batch = np.zeros((batch_size, height, width, 1))
    s_batch = np.zeros((batch_size, height // 8, width // 8, 150))
    y_batch = np.zeros((batch_size, height, width, 2))

    # fill batches with images
    for slot in range(batch_size):
        with imgs.open(image_paths[index + slot]) as img_file:
            img = Image.open(img_file)
            # PIL decodes lazily: read the pixel data now, while the archive
            # member is still open, so later operations do not touch a
            # closed file.
            img.load()

        # randomly flip image
        if random() < .5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # create segmentation from the grayscale version of the image
        s = predict_segmentation(img.convert("L").convert("RGB"), pspnet)

        # create x (lightness) and y (a/b color channels)
        l, a, b = lab_img(img)
        y = np.concatenate((a, b), axis=2)

        x_batch[slot] = l
        s_batch[slot] = s
        y_batch[slot] = y

    return x_batch, s_batch, y_batch


# Returns one-hot encoded segmentation object (w x h x 150)
def predict_segmentation(img, pspnet):
    """Run the segmentation network on an image and downscale the result.

    Args:
        img: three-channel PIL image (anything exposing `.size`, `.resize`
            and convertible with `np.array`).
        pspnet: model whose `predict` takes a (1, 473, 473, 3) batch and
            returns one prediction map per batch item.

    Returns:
        Array of shape (height // 8, width // 8, 150), where width and
        height are those of the image passed in.
    """
    data_mean = np.array([[[128, 128, 128]]])
    input_size = (473, 473)
    # (rows, cols) of the returned map.  Integer division keeps the shape
    # integral so it is valid for `resize` and matches the batch buffers.
    target_shape = (img.size[1] // 8, img.size[0] // 8)

    # resize to 473 x 473px
    if img.size != input_size:
        img = img.resize(input_size)

    # normalize image and add the batch axis
    pixels = np.array(img) - data_mean
    pixels = np.expand_dims(pixels, axis=0)

    # predict segmentation
    segmentation = pspnet.predict(pixels)[0]

    # resize segmentation whenever the prediction is not already the
    # requested size
    if tuple(segmentation.shape[:2]) != target_shape:
        segmentation = resize(segmentation,
                              (target_shape[0], target_shape[1], 150),
                              mode="constant",
                              preserve_range=True)
    return segmentation


# returns training batches
def generator_fn(batch_size, images_path, images_size, pspnet):
    """Yield training batches forever from the images in a zip archive.

    Args:
        batch_size: number of images per batch.
        images_path: path of the zip archive holding the images.
        images_size: (width, height) of the images in the archive.
        pspnet: segmentation network handed to `batch_images`.

    Yields:
        ``([x, s], y)`` where x, s and y are the arrays returned by
        `batch_images`.  Batches are taken in archive order; when fewer than
        `batch_size` images remain, reading restarts from the first image.

    Raises:
        ValueError: if the archive holds fewer than `batch_size` images.
    """
    with zipfile.ZipFile(images_path) as imgs:
        # Directory entries cannot be opened as images, so leave them out.
        image_paths = [member for member in imgs.infolist()
                       if not member.filename.endswith("/")]
        n_images = len(image_paths)
        if batch_size > n_images:
            raise ValueError(
                "{} contains {} images, fewer than batch_size={}".format(
                    images_path, n_images, batch_size))

        i = 0
        while True:
            # wrap around when a full batch no longer fits
            if i + batch_size > n_images:
                i = 0
            x, s, y = batch_images(i, batch_size, images_size, image_paths, imgs, pspnet)
            i += batch_size
            yield [x, s], y


In [None]:
################################################################################
# Loading, training and saving model
################################################################################

# custom pspnet layer
class Interp(layers.Layer):
    """Keras layer that resizes its input to a fixed (height, width).

    Registered as a custom object when the saved pspnet model is loaded.
    """

    def __init__(self, new_size, **kwargs):
        # new_size is read as (height, width) by call() and
        # compute_output_shape().
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights of its own; defer to the base implementation.
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_height, target_width = self.new_size
        return ktf.image.resize_images(inputs,
                                       [target_height, target_width],
                                       align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch dimension is left unspecified; channels pass through.
        target_height = self.new_size[0]
        target_width = self.new_size[1]
        return (None, target_height, target_width, input_shape[3])

    def get_config(self):
        # Include new_size so the layer can be rebuilt from its config.
        config = super(Interp, self).get_config()
        config.update({'new_size': self.new_size})
        return config


# Returns pspnet
def load_pspnet(path):
    """Load the saved pspnet model stored at `path` and return it.

    The custom `Interp` layer is registered for deserialization, and the
    predict function is built before the model is returned.
    """
    custom_layers = {'Interp': Interp}
    network = load_model(path, custom_objects=custom_layers)
    network._make_predict_function()
    return network


# Returns main model
def model_definition():
    """Build and compile the colorization model.

    Inputs:
        * grayscale image, shape (None, None, 1)
        * segmentation map, shape (None, None, 150), concatenated with the
          encoder output (which has been downsampled by three stride-2
          convolutions).

    Output:
        Two-channel tanh map produced after three 2x upsamplings.

    Returns:
        The Keras model, compiled with MSE loss and
        Adam(lr=0.00015, decay=0.035).
    """

    def conv_relu(filters, kernel, **extra):
        # "same"-padded ReLU convolution; `extra` carries optional strides.
        return Conv2D(filters, kernel, padding="same", activation="relu", **extra)

    def residual_block(tensor, filters, kernel, **first_conv_extra):
        # Three convolutions; the output of the first is added to the
        # output of the third.  Layers are created strictly in order so the
        # auto-generated layer names match the original definition.
        head = conv_relu(filters, kernel, **first_conv_extra)(tensor)
        middle = conv_relu(filters, kernel)(head)
        tail = conv_relu(filters, kernel)(middle)
        return layers.add([tail, head])

    # Encoder: four 5x5 residual blocks, the first three halve the resolution.
    grayscale_input = Input(shape=(None, None, 1))
    encoder_plan = (
        (64, {"strides": 2}),
        (128, {"strides": 2}),
        (256, {"strides": 2}),
        (256, {}),
    )
    features = grayscale_input
    for filters, extra in encoder_plan:
        features = residual_block(features, filters, (5, 5), **extra)

    # Fuse encoder features with the segmentation map along the channel axis.
    segmentation_input = Input(shape=(None, None, 150))
    decoded = concatenate([features, segmentation_input], axis=3)

    # Decoder: residual block followed by 2x upsampling, three times.
    decoder_plan = (
        (256, (5, 5)),
        (128, (3, 3)),
        (64, (3, 3)),
    )
    for filters, kernel in decoder_plan:
        decoded = residual_block(decoded, filters, kernel)
        decoded = UpSampling2D()(decoded)

    # Output head: narrow to the two tanh-activated output channels.
    decoded = conv_relu(32, (3, 3))(decoded)
    decoded = conv_relu(16, (3, 3))(decoded)
    ab_output = Conv2D(2, (3, 3), padding="same", activation="tanh")(decoded)

    model = Model(inputs=[grayscale_input, segmentation_input], outputs=ab_output)
    model.compile(loss="mse", optimizer=Adam(lr=0.00015, decay=0.035))
    return model

# Class for saving and uploading model & csv logger
class save_and_upload(Callback):
    """Callback that saves the model and uploads it, with its CSV log, to Drive.

    Relies on the module-level `drive` and `gauth` objects created in the
    download cell.  All Drive/save errors are reported and otherwise ignored
    so that a failed upload never interrupts training.

    Args:
        model_name: base name; "<model_name>.hdf5" and "<model_name>.csv"
            are the files saved locally and uploaded.
        n_epochs: total number of epochs; uploads only happen once
            ``epoch / n_epochs`` exceeds 0.4.
    """

    def __init__(self, model_name, n_epochs):
        # Initialize the Keras base class state before adding our own.
        super(save_and_upload, self).__init__()
        self.model_name = model_name
        self.n_epochs = n_epochs

        # Both files go into the same Drive folder.
        parent_folder = [{"kind": "drive#childList",
                          "id": "15svx5A7mYSLMrDcRDmgY5cB3Cj66m6j8"}]
        self.model_checkpoint = drive.CreateFile({"title": self.model_name + ".hdf5",
                                                  "parents": parent_folder})
        self.model_log = drive.CreateFile({"title": self.model_name + ".csv",
                                           "parents": parent_folder})

    def on_epoch_end(self, epoch, logs=None):
        # Catch Exception rather than using a bare except: failures stay
        # non-fatal, but KeyboardInterrupt/SystemExit still stop training
        # and the cause of each failure is printed.
        try:
            if gauth.access_token_expired:
                gauth.Refresh()
        except Exception as error:
            print("refresh failed: {}".format(error))

        if epoch / self.n_epochs > 0.4:
            try:
                self.model.save(self.model_name + ".hdf5", overwrite=True)
                self.model_checkpoint.SetContentFile(self.model_name + ".hdf5")
                self.model_log.SetContentFile(self.model_name + ".csv")
            except Exception as error:
                print("save failed: {}".format(error))

            try:
                self.model_checkpoint.Upload()
                self.model_log.Upload()
            except Exception as error:
                print("upload failed: {}".format(error))


# Returns list of callbacks
def callbacks(model_name, n_epochs):
    """Return the training callbacks: a CSV logger, then the Drive uploader.

    The logger writes "<model_name>.csv"; it comes first in the list,
    ahead of the uploader that sends that same file to Drive.
    """
    csv_logger = CSVLogger(model_name + ".csv")
    uploader = save_and_upload(model_name, n_epochs)
    return [csv_logger, uploader]


# Trains the model
def train_model(model, training_data_generator, validation_data_generator, n_epochs, steps_per_epoch, validation_steps, model_name):
    """Fit `model` on the training generator, validating after each epoch.

    Args:
        model: compiled Keras model to train.
        training_data_generator: generator yielding training batches.
        validation_data_generator: generator yielding validation batches.
        n_epochs: number of epochs to run.
        steps_per_epoch: training batches drawn per epoch.
        validation_steps: validation batches drawn per epoch.
        model_name: base name passed to `callbacks` for the log/checkpoint.
    """
    fit_options = dict(
        epochs=n_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks(model_name, n_epochs),
        validation_data=validation_data_generator,
        validation_steps=validation_steps,
        verbose=2,
        max_queue_size=20,
    )
    model.fit_generator(training_data_generator, **fit_options)

In [None]:
################################################################################
# Main
################################################################################

# Run configuration
model_name = "Model_0"
batch_size = 64
batches_per_epoch = 73
n_epochs = 10
batches_per_validation_epoch = 4
images_size = (256, 256)  # (width, height)

pspnet = load_pspnet("pspnet.h5")
# Training resumes from the downloaded checkpoint; to start from scratch use
# model = model_definition()
model = load_model("colorization_model.hdf5")

# The archive names must match the files written by the download cell
# ("dataset_training_0.zip" / "dataset_validation_0.zip").
training_data_generator = generator_fn(batch_size, "dataset_training_0.zip", images_size, pspnet)
validation_data_generator = generator_fn(batch_size, "dataset_validation_0.zip", images_size, pspnet)

train_model(model,
            training_data_generator,
            validation_data_generator,
            n_epochs,
            batches_per_epoch,
            batches_per_validation_epoch,
            model_name)