In [1]:
###########################################################################################################
# Download dataset and model from google drive to local storage
###########################################################################################################

!ls
!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate through Colab, then hand the application-default credentials to
# PyDrive so the `drive` client can act on the user's behalf.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# `drive` is reused by the training cell to upload the checkpoint.
drive = GoogleDrive(gauth)

# Fetch the image archive by its Drive file id and store it as dataset.zip.
dataset = drive.CreateFile({'id': '1g9FGP5mrqp1iki6PXwnUDD1Dq7EuNiW4'})
dataset.GetContentFile("dataset.zip")

# Fetch the second Drive file and store it as pspnet.h5.
# NOTE(review): the name `dataset` is rebound here to the pspnet.h5 handle.
dataset = drive.CreateFile({'id': '1YpCm6bho9fNCVgY9Bh8WBaXX5OQlFqI1'})
dataset.GetContentFile("pspnet.h5")


adc.json  dataset.zip  model_checkpoint.hdf5  pspnet.h5  sample_data


In [2]:
###########################################################################################################
# Import necessary libraries
###########################################################################################################

!pip install --upgrade keras
import zipfile
from PIL import Image
from skimage.color import rgb2lab
import numpy as np
from skimage.transform import resize
from keras.models import Model, load_model
from keras.layers import Conv2D, concatenate, Input
from keras.callbacks import ModelCheckpoint
from keras import layers
from keras.backend import tf as ktf
import os
import time
import tensorflow as tf

# Restrict TensorFlow's logger to the ERROR verbosity level.
tf.logging.set_verbosity(tf.logging.ERROR)


Using TensorFlow backend.


In [0]:
###########################################################################################################
# Generator and its functions
###########################################################################################################

# Returns numpy arrays l, a and b (height x width x 1)
def lab_img(img):
    """Convert an image to Lab and return its three channels, rescaled.

    Args:
        img: PIL image (any object exposing ``mode`` and ``convert``).

    Returns:
        Tuple ``(l, a, b)``. Each array keeps the first two dimensions of the
        Lab array and has a trailing axis of length 1. ``l`` is divided by
        100; ``a`` is shifted by 127 and ``b`` by 128 before dividing by 255.
    """
    # rgb2lab needs three colour channels. Test the pixel mode rather than the
    # file format: a JPEG may be grayscale or CMYK, and the earlier check
    # `img.format != ("JPEG" or "JPG")` only ever compared against "JPEG".
    if img.mode != "RGB":
        img = img.convert("RGB")
    lab = rgb2lab(img)

    # Slicing with a one-element range keeps the channel axis.
    l = lab[:, :, 0:1] / 100
    # NOTE(review): a and b use different offsets (127 vs 128). Left unchanged
    # because any code that inverts this scaling must use the same constants.
    a = (lab[:, :, 1:2] + 127) / 255
    b = (lab[:, :, 2:3] + 128) / 255

    return l, a, b


# Returns batch of x and y values packed together
def batch_images(index, batch_size, images_path, image_paths, trained_model):
    """Assemble the samples for batch number ``index`` from the zip archive.

    Every archive entry in the batch contributes two consecutive samples: the
    image as stored, followed by its left-right mirrored copy.

    Args:
        index: Zero-based batch number.
        batch_size: Number of archive entries per batch.
        images_path: Path of the zip archive.
        image_paths: Sequence of archive members, indexed by position.
        trained_model: Model forwarded to ``predict_segmentation``.

    Returns:
        Tuple ``(x_batch, s_batch, y_batch)`` of lists holding the lightness
        arrays, segmentation arrays and concatenated a/b arrays.
    """
    with zipfile.ZipFile(images_path) as archive:
        lightness_samples, segmentation_samples, colour_samples = [], [], []

        first = index * batch_size
        # The last batch may be short when the archive runs out of entries.
        last = min(len(image_paths), first + batch_size)

        for position in range(first, last):
            with archive.open(image_paths[position]) as entry:
                picture = Image.open(entry)

                segmentation = predict_segmentation(picture, trained_model)
                lightness, chroma_a, chroma_b = lab_img(picture)
                colour = np.concatenate((chroma_a, chroma_b), axis=2)

                # Original sample first, mirrored sample second.
                lightness_samples.extend((lightness, np.fliplr(lightness)))
                segmentation_samples.extend((segmentation, np.fliplr(segmentation)))
                colour_samples.extend((colour, np.fliplr(colour)))

    return lightness_samples, segmentation_samples, colour_samples


# Returns the segmentation model's prediction for one image (height x width x channels)
def predict_segmentation(img, trained_model):
    """Run the segmentation model on ``img`` and return its prediction.

    The image is converted to RGB if needed, resized to the 473 x 473 input
    size, mean-subtracted and reordered to BGR before prediction. When the
    image was resized, the prediction is resized back to the image's own size.

    Args:
        img: PIL image; ``img.size`` is ``(width, height)``.
        trained_model: Object whose ``predict`` accepts a batch of one image
            and returns a batch of one prediction.

    Returns:
        Prediction array for the image, ``height x width x channels``.
    """
    # TODO: do this in my model
    data_mean = np.array([[[123.68, 116.779, 103.939]]])
    input_size = (473, 473)
    image_size = img.size

    # The mean has three channels, so grayscale/RGBA/CMYK input must be
    # converted first or the subtraction below cannot broadcast.
    if img.mode != "RGB":
        img = img.convert("RGB")
    if image_size != input_size:
        img = img.resize(input_size)

    pixel_img = np.array(img) - data_mean
    # Reverse the channel axis: RGB -> BGR.
    bgr_img = pixel_img[:, :, ::-1]
    segmented_img = trained_model.predict(np.expand_dims(bgr_img, 0))[0]

    if image_size != input_size:
        width, height = image_size
        # Keep the model's own channel count instead of assuming 150.
        output_shape = (height, width, segmented_img.shape[-1])
        segmented_img = resize(segmented_img, output_shape, mode="constant")
    return segmented_img


# Yields batches of x and y values
def generator_fn(n_images, batch_size, images_path, trained_model):
    """Endlessly yield training batches built from a zip archive of images.

    Args:
        n_images: Number of archive images to cycle through; clamped to the
            number of file entries actually present in the archive.
        batch_size: Number of archive images per batch.
        images_path: Path of the zip archive.
        trained_model: Segmentation model forwarded to ``batch_images``.

    Yields:
        ``([x, s], y)`` where ``x``, ``s`` and ``y`` are numpy arrays built
        from the lists returned by ``batch_images``.

    Raises:
        ValueError: If not even one full batch can be formed. Without this
            check the loop below would spin forever without yielding.
    """
    with zipfile.ZipFile(images_path) as my_zip:
        # Directory entries (names ending in "/") are not images.
        image_paths = [info for info in my_zip.infolist()
                       if not info.filename.endswith("/")]

    # Asking for more images than exist would produce empty batches.
    usable_images = min(n_images, len(image_paths))
    batches_per_epoch = usable_images // batch_size
    if batches_per_epoch < 1:
        raise ValueError(
            "Cannot form a batch: {} usable image(s) for batch_size {}".format(
                usable_images, batch_size))

    while True:
        for i in range(batches_per_epoch):
            x, s, y = batch_images(i, batch_size, images_path, image_paths, trained_model)
            yield [np.array(x), np.array(s)], np.array(y)


In [0]:
###########################################################################################################
# Loading, training and saving model
###########################################################################################################

class Interp(layers.Layer):
    """Layer that resizes its input to a fixed ``new_size``.

    Passed as a custom object when the saved segmentation model is loaded.
    """

    def __init__(self, new_size, **kwargs):
        # new_size is unpacked as (height, width) in call().
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        # Nothing is created here; the call is delegated to the base class.
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_height, target_width = self.new_size
        return ktf.image.resize_images(inputs,
                                       [target_height, target_width],
                                       align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch axis is reported as None; the last axis is taken from the input.
        target_height, target_width = self.new_size[0], self.new_size[1]
        return (None, target_height, target_width, input_shape[3])

    def get_config(self):
        # Base config plus the one extra constructor argument.
        return dict(super(Interp, self).get_config(), new_size=self.new_size)


# Returns trained model instance
def load_trained_model(path):
    """Load the saved model at ``path`` and build its predict function.

    ``Interp`` is supplied as a custom object so the loader can resolve it.
    """
    custom_layers = {'Interp': Interp}
    segmentation_model = load_model(path, custom_objects=custom_layers)
    segmentation_model._make_predict_function()
    return segmentation_model


# Returns main model
def model_definition():
    """Build and compile the two-input colorization model.

    One branch takes a single-channel input, the other a 150-channel input.
    Each runs through three 3x3 convolutions, the branches are concatenated
    on the channel axis, and six more 3x3 convolutions reduce the result to
    two output channels. The model is compiled with MSE loss and Adam.
    """
    def conv_stack(tensor, filter_counts):
        # Chain one same-padded 3x3 ReLU convolution per entry, in order.
        for filters in filter_counts:
            tensor = Conv2D(filters, (3, 3), padding="same", activation="relu", use_bias=True)(tensor)
        return tensor

    branch_filters = (128, 256, 256)

    grayscale_input = Input(shape=(None, None, 1))
    grayscale_features = conv_stack(grayscale_input, branch_filters)

    segmentation_input = Input(shape=(None, None, 150))
    segmentation_features = conv_stack(segmentation_input, branch_filters)

    merged = concatenate([grayscale_features, segmentation_features], axis=3)
    colorized = conv_stack(merged, (512, 256, 128, 64, 32, 2))

    model = Model(inputs=[grayscale_input, segmentation_input], outputs=colorized)
    model.compile(loss="mse", optimizer="adam")
    return model


# Returns list of used keras callbacks
def callbacks(model_path):
    """Return the callback list: one checkpoint writer.

    The checkpoint file is ``model_checkpoint.hdf5`` inside ``model_path``.
    """
    checkpoint_file = os.path.join(model_path, "model_checkpoint.hdf5")
    return [ModelCheckpoint(checkpoint_file)]


# Returns the model after training it
def train_model(model, training_data_fn, validation_data_fn, epochs, steps_per_epoch, validation_steps, save_path):
    """Fit ``model`` from the training generator and return the same model.

    Checkpoint callbacks are created for ``save_path``; the validation
    generator and step counts are forwarded to ``fit_generator`` unchanged.
    """
    fit_options = {
        "epochs": epochs,
        "steps_per_epoch": steps_per_epoch,
        "callbacks": callbacks(save_path),
        "validation_data": validation_data_fn,
        "validation_steps": validation_steps,
        "verbose": 1,
    }
    model.fit_generator(training_data_fn, **fit_options)
    return model


In [5]:
###########################################################################################################
# Control panel
###########################################################################################################

# Working directory for the downloaded inputs and the written checkpoint.
current_directory = ""

# Run size. Every archive image also contributes a mirrored copy per batch.
n_images = 2
batch_size = 1
n_epochs = 3
batches_per_epoch = n_images // batch_size
validation_batches = 1

images_path = os.path.join(current_directory, "dataset.zip")
trained_model_path = os.path.join(current_directory, "pspnet.h5")
# Must name the file that callbacks() writes inside current_directory, so the
# upload below keeps working when current_directory is changed.
checkpoint_path = os.path.join(current_directory, "model_checkpoint.hdf5")
trained_model = load_trained_model(trained_model_path)

# NOTE(review): the validation generator reads the same archive and the same
# images as the training generator, so validation loss is not held-out.
training_data_fn = generator_fn(n_images, batch_size, images_path, trained_model)
validation_data_fn = generator_fn(n_images, batch_size, images_path, trained_model)
model = model_definition()
start_time = time.time()
train_model(model,
            training_data_fn,
            validation_data_fn,
            n_epochs,
            batches_per_epoch,
            validation_batches,
            current_directory)
print("Training took: " + str(time.time() - start_time))

# Upload the checkpoint through the `drive` client created in the download cell.
model_checkpoint = drive.CreateFile()
model_checkpoint.SetContentFile(checkpoint_path)
model_checkpoint.Upload()

!ls


Epoch 1/3
Epoch 2/3
Epoch 3/3
Training took: 82.85155916213989
adc.json  dataset.zip  model_checkpoint.hdf5  pspnet.h5  sample_data
