In [1]:
################################################################################
# Download dataset and model from google drive to local storage
################################################################################

!ls
!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
from google.colab import files
import os.path

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

dataset_training = drive.CreateFile({'id': '1YTHaiFGabB-YWCCzD873mNg5ekvOgVBb'})
dataset_training.GetContentFile("imagenet_training_1.zip")

dataset_validation = drive.CreateFile({'id': '1gqDHERzr2pH1j22GuPuv55rDbgjtY_ed'})
dataset_validation.GetContentFile("imagenet_validation_1.zip")

# model = drive.CreateFile({'id': '16H2UavQKdC_T8OTAWM5Z8EwL9FyesCMH'})
# model.GetContentFile("colorization_model.hdf5")

model = drive.CreateFile({'id': '1YpCm6bho9fNCVgY9Bh8WBaXX5OQlFqI1'})
model.GetContentFile("pspnet.h5")


adc.json		 imagenet_training_1.zip    pspnet.h5
colorization_model.hdf5  imagenet_validation_1.zip  sample_data


In [2]:
################################################################################
# Import necessary libraries
################################################################################

import zipfile
from PIL import Image
from skimage.color import rgb2lab
import numpy as np
from skimage.transform import resize
from keras.models import Model, load_model
from keras.layers import Conv2D, concatenate, Input, UpSampling2D
from keras.callbacks import CSVLogger, Callback
from keras.optimizers import Adam
from keras import layers
from keras.backend import tf as ktf
import time


Using TensorFlow backend.


In [0]:
################################################################################
# Generator and its functions
################################################################################

# Returns normalised numpy arrays l, a and b (each h x w x 1)
def lab_img(img):
    """Split an RGB image into normalised CIELAB channel planes.

    Args:
        img: RGB image (PIL image or array) accepted by ``rgb2lab``; it must
            already have 3 channels, as no conversion is done here.

    Returns:
        ``(l, a, b)``, three float arrays of shape ``(h, w, 1)``:
          * ``l``: L* divided by 100;
          * ``a``: ``(a* + 127) / 255 * 2 - 1``;
          * ``b``: ``(b* + 128) / 255 * 2 - 1``.
        a and b are mapped to roughly [-1, 1], the range of the model's tanh
        output.
    """
    lab = rgb2lab(img)
    planes = [np.array(lab[:, :, channel]) for channel in range(3)]

    l_plane = planes[0] / 100
    # NOTE(review): a* is shifted by 127 while b* is shifted by 128. This looks
    # unintentional, but it must stay in sync with whatever inverse transform
    # is used at inference time, so it is left unchanged here.
    a_plane = ((planes[1] + 127) / 255) * 2 - 1
    b_plane = ((planes[2] + 128) / 255) * 2 - 1

    l_plane, a_plane, b_plane = (np.expand_dims(plane, axis=2)
                                 for plane in (l_plane, a_plane, b_plane))
    return l_plane, a_plane, b_plane


# Returns batch of x, segmentation and y values packed together
def batch_images(index, batch_size, images_size, image_paths, imgs, trained_model, flip):
    """Read ``batch_size`` consecutive images from an open zip and build one batch.

    Args:
        index: position in ``image_paths`` of the first image of the batch.
        batch_size: number of source images read from the archive.
        images_size: ``(width, height)`` (PIL order) every image is brought to.
        image_paths: zip members as returned by ``ZipFile.infolist``.
        imgs: the open ``zipfile.ZipFile`` holding the images.
        trained_model: segmentation model forwarded to ``predict_segmentation``.
        flip: when true, a horizontally mirrored copy of every sample is
            appended, so the arrays hold ``2 * batch_size`` samples (the
            originals first, then their mirrors).

    Returns:
        ``(x_batch, s_batch, y_batch)`` with shapes ``(n, h, w, 1)``,
        ``(n, h // 8, w // 8, 150)`` and ``(n, h, w, 2)``, where ``n`` is the
        number of samples described above.
    """
    width, height = images_size
    n_samples = 2 * batch_size if flip else batch_size
    x_batch = np.zeros((n_samples, height, width, 1))
    s_batch = np.zeros((n_samples, height // 8, width // 8, 150))
    y_batch = np.zeros((n_samples, height, width, 2))

    for offset in range(batch_size):
        with imgs.open(image_paths[index + offset]) as member:
            # convert() reads the pixels while the member is still open and
            # turns grayscale / palette / CMYK / RGBA images into the
            # 3-channel RGB that rgb2lab requires.
            img = Image.open(member).convert("RGB")

        # Any other size would not fit into the preallocated arrays.
        if img.size != images_size:
            img = img.resize(images_size)

        # The segmentation model sees only the luminance, as at inference.
        s = predict_segmentation(img.convert("L").convert("RGB"), trained_model)
        l, a, b = lab_img(img)
        y = np.concatenate((a, b), axis=2)

        x_batch[offset] = l
        s_batch[offset] = s
        y_batch[offset] = y

        if flip:
            # Mirrors go in the second half so the original sample is kept.
            mirror = batch_size + offset
            x_batch[mirror] = np.fliplr(l)
            s_batch[mirror] = np.fliplr(s)
            y_batch[mirror] = np.fliplr(y)

    return x_batch, s_batch, y_batch


# Returns the 150-channel segmentation output resized to (h // 8) x (w // 8) x 150
def predict_segmentation(img, trained_model):
    """Run the segmentation model on ``img`` and resize its output to 1/8 scale.

    The image is resized to 473x473, mean-subtracted and channel-reversed
    (RGB -> BGR) before ``trained_model.predict`` is called.

    Args:
        img: 3-channel RGB PIL image of size ``(w, h)``.
        trained_model: Keras model taking a ``(1, 473, 473, 3)`` batch; its
            output is assumed to have 150 channels (presumably per-class
            scores -- TODO confirm).

    Returns:
        Float array of shape ``(h // 8, w // 8, 150)``. Values are
        interpolated by ``resize``, so they are not one-hot.
    """
    data_mean = np.array([[[123.68, 116.779, 103.939]]])
    input_size = (473, 473)
    # Integer division: the map must match the (h // 8, w // 8) slots that
    # batch_images allocates, even when the size is not a multiple of 8
    # (a float size would let resize produce one extra row/column).
    output_size = (img.size[0] // 8, img.size[1] // 8)

    if img.size != input_size:
        img = img.resize(input_size)

    pixel_img = np.array(img) - data_mean
    bgr_img = pixel_img[:, :, ::-1]
    segmented_img = trained_model.predict(np.expand_dims(bgr_img, axis=0))[0]
    if output_size != input_size:
        segmented_img = resize(segmented_img,
                               (output_size[1], output_size[0], 150),
                               mode="constant",
                               preserve_range=True)
    return segmented_img


# Yields batches of ([x, s], y) values
def generator_fn(batch_size, images_path, images_size, trained_model, flip=False):
    """Endlessly yield ``([x, s], y)`` training batches from the images in a zip archive.

    Images are consumed in archive order. When fewer than ``batch_size``
    images remain, iteration restarts from the first image, so the leftover
    images of that pass are skipped.

    Args:
        batch_size: images read per batch (forwarded to ``batch_images``).
        images_path: path (or file-like object) of the zip archive.
        images_size: ``(width, height)`` of the images.
        trained_model: segmentation model forwarded to ``batch_images``.
        flip: forwarded to ``batch_images``.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` is not positive
            or the archive contains fewer than ``batch_size`` images.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    with zipfile.ZipFile(images_path) as imgs:
        # Directory entries are not images and would make Image.open fail.
        image_paths = [info for info in imgs.infolist() if not info.is_dir()]
        n_images = len(image_paths)
        if n_images < batch_size:
            raise ValueError("{} contains {} images, fewer than batch_size={}".format(
                images_path, n_images, batch_size))
        i = 0
        while True:
            if i + batch_size > n_images:
                i = 0
            x, s, y = batch_images(i, batch_size, images_size, image_paths, imgs, trained_model, flip)
            i += batch_size
            yield [x, s], y


In [0]:
################################################################################
# Loading, training and saving model
################################################################################

class Interp(layers.Layer):
    """Keras layer resizing feature maps to a fixed spatial size.

    Wraps ``ktf.image.resize_images`` (its default method) with
    ``align_corners=True``. It has to be registered as a custom object when
    the segmentation model is loaded.

    Args:
        new_size: ``(height, width)`` of the output feature maps.
    """

    def __init__(self, new_size, **kwargs):
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; the base class only marks the layer as built.
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_height, target_width = self.new_size
        return ktf.image.resize_images(inputs,
                                       [target_height, target_width],
                                       align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch dimension stays dynamic; the channel count passes through.
        return (None, self.new_size[0], self.new_size[1], input_shape[3])

    def get_config(self):
        # new_size must be serialised so load_model can rebuild the layer.
        config = super(Interp, self).get_config()
        config.update({'new_size': self.new_size})
        return config


# Returns the pretrained segmentation model loaded from disk
def load_trained_model(path):
    """Load the Keras segmentation model stored at ``path``.

    The file uses the custom ``Interp`` layer, which must be supplied to
    ``load_model`` via ``custom_objects``.
    """
    segmentation_model = load_model(path, custom_objects={'Interp': Interp})
    # Build the predict function eagerly. Presumably this lets predict() be
    # called later from fit_generator's data-loading thread, a known issue
    # with Keras on the TF1 backend -- TODO confirm.
    segmentation_model._make_predict_function()
    return segmentation_model


# Returns the compiled colorization model
def model_definition():
    """Build and compile the segmentation-guided colorization network.

    Inputs:
        * grayscale plane of shape ``(None, None, 1)``;
        * segmentation map of shape ``(None, None, 150)``. It is concatenated
          with the encoder output, which is 1/8 of the input resolution
          (three stride-2 convolutions).

    Output: two tanh channels at the input resolution (three 2x upsamplings).

    Every stage is a residual triplet of 3x3 convolutions: the output of the
    third convolution is added to the output of the first. Layers are
    created in exactly the same order as in the original flat definition.
    Compiled with MSE loss and ``Adam(lr=0.0005, decay=0.015)``.
    """
    def conv(filters, activation="relu", strides=(1, 1)):
        # Every convolution in the network is 3x3 with "same" padding.
        return Conv2D(filters, (3, 3), padding="same", activation=activation, strides=strides)

    def residual_stage(tensor, filters):
        # Create and apply the convs one after another so layer creation
        # order (and thus automatic layer naming) matches application order.
        first = conv(filters)(tensor)
        second = conv(filters)(first)
        third = conv(filters)(second)
        return layers.add([third, first])

    grayscale_input = Input(shape=(None, None, 1))
    encoded = grayscale_input
    for down_filters, stage_filters in ((32, 64), (64, 128), (128, 256)):
        encoded = conv(down_filters, strides=2)(encoded)
        encoded = residual_stage(encoded, stage_filters)

    segmentation_input = Input(shape=(None, None, 150))
    merged = concatenate([encoded, segmentation_input], axis=3)

    decoded = residual_stage(merged, 256)
    for stage_filters in (128, 64):
        decoded = UpSampling2D()(decoded)
        decoded = residual_stage(decoded, stage_filters)
    decoded = UpSampling2D()(decoded)
    decoded = conv(32)(decoded)
    decoded = conv(16)(decoded)
    colorized = conv(2, activation="tanh")(decoded)

    model = Model(inputs=[grayscale_input, segmentation_input], outputs=colorized)
    model.compile(loss="mse", optimizer=Adam(lr=0.0005, decay=0.015))
    return model

  
class Upload2Drive(Callback):
    """Keras callback that saves the model and uploads it plus its CSV log to Google Drive.

    In the last ~30 % of training, and always after the final epoch, the
    model is saved to ``<model_name>.hdf5``. That file and ``<model_name>.csv``
    (expected to be written by a CSVLogger that runs before this callback)
    are then uploaded into the Drive folder ``folder_id``.

    Relies on the module-level ``drive`` (GoogleDrive) and ``gauth``
    (GoogleAuth) objects created in the download cell.
    """

    DEFAULT_FOLDER_ID = "15svx5A7mYSLMrDcRDmgY5cB3Cj66m6j8"

    def __init__(self, model_name, n_epochs, folder_id=DEFAULT_FOLDER_ID):
        super(Upload2Drive, self).__init__()
        self.model_name = model_name
        self.n_epochs = n_epochs
        self.model_checkpoint = self._drive_file(self.model_name + ".hdf5", folder_id)
        self.model_log = self._drive_file(self.model_name + ".csv", folder_id)

    @staticmethod
    def _drive_file(title, folder_id):
        """Create (without uploading) a Drive file object titled ``title`` in ``folder_id``."""
        return drive.CreateFile({"title": title,
                                 "parents": [{"kind": "drive#childList",
                                              "id": folder_id}]})

    def on_epoch_end(self, epoch, logs=None):
        if gauth.access_token_expired:
            gauth.Refresh()
        # Keras passes a 0-based epoch. The ratio test alone never fires for
        # n_epochs <= 3 (e.g. 2 / 3 < 0.7), which would leave the finished
        # model un-uploaded, so the final epoch is always uploaded too.
        is_last_epoch = epoch + 1 >= self.n_epochs
        if not (epoch / self.n_epochs > 0.7 or is_last_epoch):
            return

        self.model.save(self.model_name + ".hdf5", overwrite=True)
        self.model_checkpoint.SetContentFile(self.model_name + ".hdf5")
        self.model_log.SetContentFile(self.model_name + ".csv")

        # An upload failure should not abort training; report it and go on.
        try:
            self.model_checkpoint.Upload()
            self.model_log.Upload()
        except Exception as err:
            print("upload checkpoint failed: {}".format(err))

# Returns the list of keras callbacks used during training
def callbacks(model_name, n_epochs):
    """Return ``[CSVLogger, Upload2Drive]`` for a training run named ``model_name``.

    The CSVLogger comes first and writes ``<model_name>.csv``, the same log
    file that Upload2Drive uploads.
    """
    return [
        CSVLogger(model_name + ".csv"),
        Upload2Drive(model_name, n_epochs),
    ]


# Trains the model in place and returns it
def train_model(model, training_data_fn, validation_data_fn, n_epochs, steps_per_epoch, validation_steps, model_name):
    """Fit ``model`` on generator data and return the same model object.

    Args:
        model: compiled Keras model.
        training_data_fn: generator yielding training batches.
        validation_data_fn: generator yielding validation batches.
        n_epochs: number of epochs.
        steps_per_epoch: training batches drawn per epoch.
        validation_steps: validation batches drawn per epoch.
        model_name: prefix for the CSV log / checkpoint files of the callbacks.
    """
    fit_options = dict(
        epochs=n_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks(model_name, n_epochs),
        validation_data=validation_data_fn,
        validation_steps=validation_steps,
        verbose=2,          # one summary line per epoch
        max_queue_size=1,   # keep at most one prepared batch in the queue
    )
    model.fit_generator(training_data_fn, **fit_options)
    return model


In [0]:
################################################################################
# Control panel
################################################################################

model_name = "FinalModel_1" # don't forget about this

batch_size = 64 # twice as much gets trained if flip=True
batches_per_epoch = 73
n_epochs = 10
batches_per_validation = 4
images_size = (256, 256)

trained_model = load_trained_model("pspnet.h5")
model = model_definition()
# model = load_model("colorization_model.hdf5")

training_data_fn = generator_fn(batch_size, "imagenet_training_1.zip", images_size, trained_model)
validation_data_fn = generator_fn(batch_size, "imagenet_validation_1.zip", images_size, trained_model)

start_time = time.time()
train_model(model, 
            training_data_fn, 
            validation_data_fn, 
            n_epochs, 
            batches_per_epoch, 
            batches_per_validation, 
            model_name)
print("Training took: " + str(time.time() - start_time))

!ls


Epoch 1/10
 - 4113s - loss: 0.0157 - val_loss: 0.0139
Epoch 2/10
 - 3969s - loss: 0.0151 - val_loss: 0.0135
Epoch 3/10
 - 3997s - loss: 0.0148 - val_loss: 0.0135
Epoch 4/10
