In [1]:
################################################################################
# Download the datasets and models from Google Drive to local storage
################################################################################

# List the working directory before anything is downloaded.
!ls
# PyDrive is installed at runtime on Colab.
# NOTE(review): pin the version (the captured output shows PyDrive-1.3.1).
!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# NOTE(review): `files` and `os.path` are not used in the visible cells.
from google.colab import files
import os.path

# Authenticate through Colab and hand the application-default credentials to
# PyDrive. `gauth` and `drive` stay module-level globals: the Upload2Drive
# callback later uses them to refresh the token and upload checkpoints.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Zipped training and validation images, fetched by Drive file id and written
# to the working directory under the names the control-panel cell opens.
dataset_training = drive.CreateFile({'id': '1mhQWoYjj7ScvLmJ6FJL5QikkRAVPtW0t'})
dataset_training.GetContentFile("dataset2_training_27.zip")

dataset_validation = drive.CreateFile({'id': '1Hu4bm92fmnNNPlvweKUgdaPRkL_uN74Z'})
dataset_validation.GetContentFile("dataset2_validation_1.zip")

##########################################################################
# Colorization model checkpoint, loaded with load_model() in the control panel.
model = drive.CreateFile({'id': '1TgbmPFzqiNUdmnoIue6n5eKzYL2xPf0q'}) # 2
model.GetContentFile("colorization_model.hdf5")

# PSPNet segmentation model, loaded with load_trained_model() in the control
# panel. `model` is rebound here; the control panel later rebinds it again to
# the Keras model.
model = drive.CreateFile({'id': '1oFuPmSbFqy_b8UPxNHNCMmYC99mogCEm'})
model.GetContentFile("pspnet.h5")

sample_data
Collecting PyDrive
[?25l  Downloading https://files.pythonhosted.org/packages/52/e0/0e64788e5dd58ce2d6934549676243dc69d982f198524be9b99e9c2a4fd5/PyDrive-1.3.1.tar.gz (987kB)
[K    100% |████████████████████████████████| 993kB 20.5MB/s 
Building wheels for collected packages: PyDrive
  Building wheel for PyDrive (setup.py) ... [?25ldone
[?25h  Stored in directory: /root/.cache/pip/wheels/fa/d2/9a/d3b6b506c2da98289e5d417215ce34b696db856643bad779f4
Successfully built PyDrive
Installing collected packages: PyDrive
Successfully installed PyDrive-1.3.1


In [2]:
################################################################################
# Import necessary libraries
################################################################################

import zipfile
from PIL import Image
from skimage.color import rgb2lab
import numpy as np
from skimage.transform import resize
from keras.models import Model, load_model
from keras.layers import Conv2D, concatenate, Input, UpSampling2D
from keras.callbacks import CSVLogger, Callback
from keras.optimizers import Adam
from keras import layers
from keras.backend import tf as ktf, eval, set_value
import time


Using TensorFlow backend.


In [0]:
################################################################################
# Generator and its functions
################################################################################

# Returns numpy arrays l, a and b (h x w x 1)
def lab_img(img):
    """Convert a PIL image into normalized CIELAB channels.

    Args:
        img: PIL image of any mode. Anything that is not already RGB
            (grayscale, palette, CMYK, RGBA, ...) is converted to RGB first.

    Returns:
        Tuple ``(l, a, b)`` of float arrays, each shaped (height, width, 1):
        ``l`` is L / 100, and ``a`` and ``b`` are shifted, divided by 255 and
        mapped linearly with ``* 2 - 1``.
    """
    # The old test `img.format != ("JPEG" or "JPG")` collapsed to
    # `img.format != "JPEG"`, so a JPEG that is not RGB (grayscale, CMYK)
    # skipped conversion and broke rgb2lab. rgb2lab needs 3-channel RGB,
    # which is exactly what the mode tells us.
    if img.mode != "RGB":
        img = img.convert("RGB")
    lab = rgb2lab(img)

    # Slicing with 0:1 / 1:2 / 2:3 keeps a trailing channel axis of size 1.
    l = lab[:, :, 0:1] / 100
    # NOTE(review): a is offset by 127 but b by 128. Kept unchanged because the
    # trained model (and any decoding code) presumably relies on these constants.
    a = ((lab[:, :, 1:2] + 127) / 255) * 2 - 1
    b = ((lab[:, :, 2:3] + 128) / 255) * 2 - 1
    return l, a, b


# Returns batch of x, s and y values packed together
def batch_images(index, batch_size, images_size, image_paths, imgs, trained_model, flip):
    """Build one batch from consecutive images of an open zip archive.

    Args:
        index: position in ``image_paths`` of the batch's first image.
        batch_size: number of images in the batch.
        images_size: (width, height) used to allocate the batch arrays. Images
            are NOT resized here, so every image must already have this size.
        image_paths: sequence of zip members (anything ``imgs.open`` accepts).
        imgs: open ``zipfile.ZipFile`` that holds the images.
        trained_model: segmentation model forwarded to ``predict_segmentation``.
        flip: if true, each (already mirrored, see below) sample is mirrored
            back before it is stored. This replaces the sample rather than
            adding a second copy.

    Returns:
        ``(x_batch, s_batch, y_batch)``: L channel (B, H, W, 1), segmentation
        features (B, H/8, W/8, 150) and ab channels (B, H, W, 2).
    """
    width, height = images_size[0], images_size[1]
    x_batch = np.zeros((batch_size, height, width, 1))
    s_batch = np.zeros((batch_size, int(height / 8), int(width / 8), 150))
    y_batch = np.zeros((batch_size, height, width, 2))

    for slot in range(batch_size):
        with imgs.open(image_paths[index + slot]) as member:
            img = Image.open(member)
            # PIL decodes lazily, so force decoding while the zip member is
            # still open. On Python >= 3.7 ZipExtFile is seekable, so PIL keeps
            # the handle instead of buffering it, and decoding after this
            # `with` block would read from a closed file.
            img.load()

        # NOTE(review): every image is mirrored left-right regardless of
        # `flip`; with flip=True it is mirrored back below. Confirm intended.
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

        s = predict_segmentation(img.convert("L").convert("RGB"), trained_model)
        l, a, b = lab_img(img)
        y = np.concatenate((a, b), axis=2)

        if flip:
            l, s, y = np.fliplr(l), np.fliplr(s), np.fliplr(y)

        x_batch[slot] = l
        s_batch[slot] = s
        y_batch[slot] = y

    return x_batch, s_batch, y_batch


# Returns the segmentation model's per-class output at 1/8 of the image size (h/8 x w/8 x C)
def predict_segmentation(img, trained_model):
    """Run the segmentation model on a PIL image and downsize its output.

    Args:
        img: PIL RGB image of any size. It is resized to the model's fixed
            473x473 input and mean-centred by subtracting 128 per channel.
        trained_model: Keras model whose ``predict`` maps a (1, 473, 473, 3)
            array to (1, h, w, C). NOTE(review): (h, w) is assumed to be
            (473, 473); the resize-skip check below relies on it. Confirm.

    Returns:
        Float array of shape (img_height // 8, img_width // 8, C). C is taken
        from the model output (150 for the PSPNet used here). Values are the
        interpolated model outputs (presumably class probabilities), not a
        one-hot encoding.
    """
    data_mean = np.array([[[128, 128, 128]]])
    input_size = (473, 473)
    # Integer sizes: skimage's resize needs an integral output shape, and the
    # result must match the int(size / 8) arrays allocated in batch_images.
    # Plain `/` gave floats such as 12.5 for sizes that are not multiples of 8.
    output_size = (img.size[0] // 8, img.size[1] // 8)

    if img.size != input_size:
        img = img.resize(input_size)

    pixel_img = np.array(img) - data_mean
    segmented_img = trained_model.predict(np.expand_dims(pixel_img, axis=0))[0]
    if output_size != input_size:
        # PIL sizes are (width, height); numpy shapes are (rows, cols, channels).
        segmented_img = resize(segmented_img,
                               (output_size[1], output_size[0], segmented_img.shape[2]),
                               mode="constant",
                               preserve_range=True)
    return segmented_img


# Yields ([x, s], y) batches forever, cycling through the archive
def generator_fn(batch_size, images_path, images_size, trained_model, flip=False):
    """Endlessly yield ``([x_batch, s_batch], y_batch)`` from a zip of images.

    Args:
        batch_size: images per batch.
        images_path: path (or file-like object) of the zip archive.
        images_size: (width, height) that every image in the archive has.
        trained_model: segmentation model forwarded to ``batch_images``.
        flip: forwarded to ``batch_images``.

    Yields:
        Batches in archive order, without shuffling. When fewer than
        ``batch_size`` images remain, they are skipped and iteration restarts
        from the first image.

    Raises:
        ValueError: if the archive holds fewer than ``batch_size`` images
            (raised on the first ``next()``, as for any generator).
    """
    with zipfile.ZipFile(images_path) as imgs:
        # Directory entries (names ending in "/") are not images and would
        # make PIL fail inside batch_images.
        image_paths = [info for info in imgs.infolist() if not info.is_dir()]
        n_images = len(image_paths)
        if n_images < batch_size:
            raise ValueError(
                "archive %r holds %d images, fewer than batch_size=%d"
                % (images_path, n_images, batch_size))
        i = 0
        while True:
            if i + batch_size > n_images:
                i = 0
            x, s, y = batch_images(i, batch_size, images_size, image_paths, imgs, trained_model, flip)
            i += batch_size
            yield [x, s], y


In [0]:
################################################################################
# Loading, training and saving model
################################################################################

class Interp(layers.Layer):
    """Keras layer that resizes its input to a fixed spatial size.

    Needed as a custom object when loading the PSPNet model. The resize is done
    with ``tf.image.resize_images`` using its default method and
    ``align_corners=True``.

    Args:
        new_size: target (height, width).
    """

    def __init__(self, new_size, **kwargs):
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_h, target_w = self.new_size
        return ktf.image.resize_images(inputs, [target_h, target_w],
                                       align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch dimension stays unknown; the channel count passes through unchanged.
        target_h = self.new_size[0]
        target_w = self.new_size[1]
        return (None, target_h, target_w, input_shape[3])

    def get_config(self):
        # Serialize new_size so load_model can rebuild the layer.
        layer_config = super(Interp, self).get_config()
        layer_config['new_size'] = self.new_size
        return layer_config


# Returns trained model instance
def load_trained_model(path):
    """Load the saved segmentation model, registering the custom Interp layer.

    Args:
        path: path of the saved Keras model file.

    Returns:
        The loaded model, with its predict function already built.
    """
    segmenter = load_model(path, custom_objects={"Interp": Interp})
    # Build the predict function up front rather than lazily on the first
    # predict() call made from inside the training generator.
    segmenter._make_predict_function()
    return segmenter


# Returns main model (compiled)
def model_definition():
    """Build and compile the colorization network.

    Inputs:
        * the grayscale L channel, shape (batch, H, W, 1);
        * segmentation features, shape (batch, H/8, W/8, 150). These must line
          up with the encoder output, which three stride-2 stages downsample 8x.

    Output:
        Two tanh channels (trained against the normalized a/b channels) at 8x
        the encoder resolution, i.e. (batch, H, W, 2) when H and W are
        multiples of 8. Compiled with MSE loss and Adam.
    """
    def residual_stage(x, filters, kernel, strides=1):
        # conv -> conv -> conv, with the first conv's output added to the
        # third's. Only the first conv may be strided.
        head = Conv2D(filters, (kernel, kernel), padding="same", activation="relu", strides=strides)(x)
        body = Conv2D(filters, (kernel, kernel), padding="same", activation="relu")(head)
        body = Conv2D(filters, (kernel, kernel), padding="same", activation="relu")(body)
        return layers.add([body, head])

    # Encoder: three stride-2 stages take the grayscale image down to 1/8 size.
    grayscale_input = Input(shape=(None, None, 1))
    encoded = residual_stage(grayscale_input, 64, 5, strides=2)
    encoded = residual_stage(encoded, 128, 5, strides=2)
    encoded = residual_stage(encoded, 256, 5, strides=2)
    encoded = residual_stage(encoded, 256, 5)

    # Fuse the segmentation features channel-wise at 1/8 resolution.
    segmentation_input = Input(shape=(None, None, 150))
    fused = concatenate([encoded, segmentation_input], axis=3)

    # Decoder: residual stages interleaved with 2x upsampling back to full size.
    decoded = residual_stage(fused, 256, 5)
    for stage_filters in (128, 64):
        decoded = UpSampling2D()(decoded)
        decoded = residual_stage(decoded, stage_filters, 3)
    decoded = UpSampling2D()(decoded)
    decoded = Conv2D(32, (3, 3), padding="same", activation="relu")(decoded)
    decoded = Conv2D(16, (3, 3), padding="same", activation="relu")(decoded)
    ab_output = Conv2D(2, (3, 3), padding="same", activation="tanh")(decoded)

    model = Model(inputs=[grayscale_input, segmentation_input], outputs=ab_output)
    # NOTE(review): Keras' `decay` gives lr = 0.00015 / (1 + 0.035 * iterations),
    # already ~4.5x smaller after 100 updates. Confirm this is intended.
    model.compile(loss="mse", optimizer=Adam(lr=0.00015, decay=0.035))
    return model
  
class Upload2Drive(Callback):
    """Keras callback that checkpoints the model and its CSV log to Google Drive.

    Once more than 40% of the planned epochs have passed, at every epoch end it
    saves the model to ``<model_name>.hdf5`` in the working directory and
    uploads that file plus ``<model_name>.csv`` (written by the CSVLogger set
    up in ``callbacks``) into a Drive folder. The same two GoogleDriveFile
    objects are reused for every upload. Failures are reported and skipped so
    that a Drive problem does not abort training.

    Relies on the module-level PyDrive objects ``drive`` and ``gauth``.
    """

    def __init__(self, model_name, n_epochs, folder_id="1gg4PUYC-10tH1Mno0LzDa9Zdew8BJ5_n"):
        """
        Args:
            model_name: base name of the local and Drive ``.hdf5``/``.csv`` files.
            n_epochs: total planned epochs, used for the 40% threshold.
            folder_id: id of the Drive folder that receives both files.
        """
        # Let the Keras base class initialise its own attributes.
        super(Upload2Drive, self).__init__()
        self.model_name = model_name
        self.n_epochs = n_epochs
        self.model_checkpoint = drive.CreateFile({"title": model_name + ".hdf5",
                                                  "parents": [{"kind": "drive#childList",
                                                               "id": folder_id}]})
        self.model_log = drive.CreateFile({"title": model_name + ".csv",
                                           "parents": [{"kind": "drive#childList",
                                                        "id": folder_id}]})

    def on_epoch_end(self, epoch, logs=None):
        # Long runs can outlive the OAuth access token; refresh it if needed.
        try:
            if gauth.access_token_expired:
                gauth.Refresh()
        except Exception as err:
            print("refresh failed: %s" % err)

        # `epoch` is 0-based; no checkpoints during the early part of training.
        if epoch / self.n_epochs <= 0.4:
            return

        checkpoint_path = self.model_name + ".hdf5"
        try:
            self.model.save(checkpoint_path, overwrite=True)
            self.model_checkpoint.SetContentFile(checkpoint_path)
            self.model_log.SetContentFile(self.model_name + ".csv")
        except Exception as err:
            print("save failed: %s" % err)
            # Nothing fresh to upload; don't push stale or missing content.
            return

        try:
            self.model_checkpoint.Upload()
            self.model_log.Upload()
        except Exception as err:
            print("upload checkpoint failed: %s" % err)

# Returns list of used keras callbacks
def callbacks(model_name, n_epochs):
    """Build the training callbacks: a CSV logger followed by the Drive uploader.

    Keep the CSVLogger first. Keras runs callbacks in list order, so the
    current epoch's row is written to ``<model_name>.csv`` before
    Upload2Drive uploads that file.
    """
    csv_logger = CSVLogger(model_name + ".csv")
    drive_uploader = Upload2Drive(model_name, n_epochs)
    return [csv_logger, drive_uploader]


# Returns the model after training it
def train_model(model, training_data_fn, validation_data_fn, n_epochs, steps_per_epoch, validation_steps, model_name):
    """Fit ``model`` on generator data and return the same (now trained) object.

    Args:
        model: compiled Keras model.
        training_data_fn: generator yielding ``([x, s], y)`` training batches.
        validation_data_fn: generator yielding validation batches.
        n_epochs: number of epochs.
        steps_per_epoch: training batches per epoch.
        validation_steps: validation batches drawn per epoch.
        model_name: base name for the files written by the callbacks.

    Uses the Keras 2 ``fit_generator`` API with terse per-epoch logging
    (``verbose=2``) and a prefetch queue of up to 20 batches.
    """
    fit_options = dict(
        epochs=n_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks(model_name, n_epochs),
        validation_data=validation_data_fn,
        validation_steps=validation_steps,
        verbose=2,
        max_queue_size=20,
    )
    model.fit_generator(training_data_fn, **fit_options)
    return model


In [0]:
################################################################################
# Control panel
################################################################################

# Base name of the .hdf5 checkpoint and .csv log written locally and uploaded
# to Drive by Upload2Drive. Change it for every run.
model_name = "DefinitelyFinal_58" # don't forget about this

# Images per batch. NOTE(review): flip=True in batch_images replaces each
# sample with its mirror image instead of adding a second copy, so it does not
# double the training data. flip is also not passed to generator_fn below.
batch_size = 64
batches_per_epoch = 73  # steps per epoch, batch_size images each
n_epochs = 10
# validation_steps: val_loss is measured on a single batch per epoch, so
# expect it to be noisy.
batches_per_validation = 1
# (width, height). Every image in both zips must already have this size;
# batch_images allocates arrays of this shape but does not resize images.
images_size = (256, 256)

# PSPNet segmentation model, used only through predict() inside the generators.
trained_model = load_trained_model("pspnet.h5")
# To train from scratch instead of resuming from the downloaded checkpoint:
# model = model_definition()
model = load_model("colorization_model.hdf5")

training_data_fn = generator_fn(batch_size, "dataset2_training_27.zip", images_size, trained_model)
validation_data_fn = generator_fn(batch_size, "dataset2_validation_1.zip", images_size, trained_model)

start_time = time.time()
train_model(model, 
            training_data_fn, 
            validation_data_fn, 
            n_epochs, 
            batches_per_epoch, 
            batches_per_validation, 
            model_name)
print("Training took: " + str(time.time() - start_time))

# List the files left in the working directory after training.
!ls


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Epoch 1/10
 - 4888s - loss: 0.0086 - val_loss: 0.0034
Epoch 2/10
 - 4289s - loss: 0.0085 - val_loss: 0.0037
Epoch 3/10
 - 4293s - loss: 0.0085 - val_loss: 0.0038
Epoch 4/10
 - 4270s - loss: 0.0084 - val_loss: 0.0049
Epoch 5/10
 - 4246s - loss: 0.0082 - val_loss: 0.0035
Epoch 6/10
 - 4244s - loss: 0.0084 - val_loss: 0.0037
Epoch 7/10
 - 4159s - loss: 0.0085 - val_loss: 0.0038
Epoch 8/10
 - 4217s - loss: 0.0084 - val_loss: 0.0050
Epoch 9/10
