In [1]:
################################################################################
# Download dataset and model from google drive to local storage
################################################################################

!ls
!pip install PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# dataset_training = drive.CreateFile({'id': '10rjhJWZaL_SG6ObDH72vcOQMfcADHYuV'})
# dataset_training.GetContentFile("coco_training_1.zip")

# dataset_validation = drive.CreateFile({'id': '1cVaHMEVBYdtsCLa6q7xtEIWn5jDRDDYm'})
# dataset_validation.GetContentFile("coco_validation_1.zip")
  
dataset = drive.CreateFile({'id': '1Yo4xiohHHoRquswIda4yZP6QxHK3aWo_'})
dataset.GetContentFile("places2_training_2.zip") 
  
# dataset = drive.CreateFile({'id': '17sUNxu0Eq7-sGih46QOVO1E62L6sw4Dp'})
# dataset.GetContentFile("places2_training_1.zip")

dataset = drive.CreateFile({'id': '1zJVZMpxyHw7pScvnG-pVr5PDD9XbQqRd'})
dataset.GetContentFile("places2_validation_1.zip")

# dataset = drive.CreateFile({'id': '1QO4NdHG99U6TVswmy4j3lz7ITxRp8D1B'})
# dataset.GetContentFile("faces_in_wild.zip")

# dataset = drive.CreateFile({'id': '10zuh0FbTmgMF6lgUf_EHnP9CAdvaofxI'})
# dataset.GetContentFile("celeba_training_2.zip")

# dataset = drive.CreateFile({'id': '1WQ6h5jQ9B4RCN7dDg2ySR5FRkK1EMfEl'})
# dataset.GetContentFile("celeba_validation_1.zip")

model = drive.CreateFile({'id': '1gw1gJoHBy2kR6hmjLQFRzM6wLTJJ-6Ah'})
model.GetContentFile("colorization_model.hdf5")

model = drive.CreateFile({'id': '1YpCm6bho9fNCVgY9Bh8WBaXX5OQlFqI1'})
model.GetContentFile("pspnet.h5")


adc.json		 model_log.csv		   pspnet.h5
colorization_model.hdf5  places2_training_2.zip    sample_data
model_checkpoint.hdf5	 places2_validation_1.zip


In [2]:
################################################################################
# Import necessary libraries
################################################################################

# !pip install q keras==2.0.0
# !pip install --upgrade Pillow
import zipfile
from PIL import Image
from skimage.color import rgb2lab
import numpy as np
from skimage.transform import resize
from keras.models import Model, load_model
from keras.layers import Conv2D, concatenate, Input, UpSampling2D
from keras.callbacks import ModelCheckpoint, CSVLogger
from keras import layers
from keras.backend import tf as ktf
# from keras import backend as k
import random
import os
import time
from google.colab import files

# config = ktf.ConfigProto()
# config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
# k.tensorflow_backend.set_session(ktf.Session(config=config))

Using TensorFlow backend.


In [0]:
################################################################################
# Generator and its functions
################################################################################

def lab_img(img):
    """Convert a PIL image to normalised CIE-Lab channels.

    Args:
        img: PIL image of any mode; non-RGB images are converted to RGB first.

    Returns:
        (l, a, b): numpy arrays, each of shape (height, width, 1).
            l = L / 100 * 2 - 1, a = (a + 127) / 255 * 2 - 1,
            b = (b + 128) / 255 * 2 - 1.
    """
    # rgb2lab needs three channels. The previous check,
    # `img.format != ("JPEG" or "JPG")`, only ever compared against "JPEG"
    # (the `or` is evaluated first), so grayscale/CMYK JPEGs were passed
    # through unconverted and crashed. Decide on the pixel mode instead.
    if img.mode != "RGB":
        img = img.convert("RGB")
    lab = rgb2lab(img)

    # k:k+1 slicing keeps a trailing channel axis -> (h, w, 1).
    l = (lab[:, :, 0:1] / 100) * 2 - 1
    # NOTE(review): a is offset by 127 but b by 128. Kept unchanged because the
    # already-trained colorization model (and any decoder) rely on this mapping.
    a = ((lab[:, :, 1:2] + 127) / 255) * 2 - 1
    b = ((lab[:, :, 2:3] + 128) / 255) * 2 - 1

    return l, a, b


def batch_images(index, batch_size, images_size, image_paths, imgs, trained_model, flip):
    """Build one batch of model inputs and targets from a zip archive.

    Args:
        index: position in `image_paths` of the first sample of the batch.
        batch_size: number of consecutive entries to load.
        images_size: (width, height) every image is resized to.
        image_paths: zip member entries accepted by `imgs.open`.
        imgs: open zipfile.ZipFile containing the images.
        trained_model: segmentation model passed to predict_segmentation.
        flip: if True every sample is mirrored left-right. This REPLACES the
            sample; it does not add extra samples to the batch.

    Returns:
        (x_batch, s_batch, y_batch) with shapes
        (batch_size, height, width, 1), (batch_size, height/8, width/8, 150)
        and (batch_size, height, width, 2).
    """
    width, height = images_size
    x_batch = np.zeros((batch_size, height, width, 1))
    s_batch = np.zeros((batch_size, int(height / 8), int(width / 8), 150))
    y_batch = np.zeros((batch_size, height, width, 2))

    for offset in range(batch_size):
        with imgs.open(image_paths[index + offset]) as member:
            # Normalise to 3-channel RGB up front: grayscale/palette/CMYK
            # files otherwise break both the segmentation net's mean
            # subtraction and the Lab conversion. convert() also forces the
            # pixel data to be read while the zip member is still open.
            img = Image.open(member).convert("RGB")
            if img.size != images_size:
                img = img.resize(images_size)

            s = predict_segmentation(img, trained_model)
            l, a, b = lab_img(img)

        y = np.concatenate((a, b), axis=2)
        if flip:
            # Mirror along the width axis (axis 1 of each (h, w, c) array).
            l, s, y = np.fliplr(l), np.fliplr(s), np.fliplr(y)

        x_batch[offset] = l
        s_batch[offset] = s
        y_batch[offset] = y

    return x_batch, s_batch, y_batch


def predict_segmentation(img, trained_model, output_size=(32, 32)):
    """Run the segmentation network on a PIL image and resize its scores.

    Args:
        img: PIL image of any mode; converted to RGB and resized to 473x473.
        trained_model: Keras model whose `predict` takes a (1, 473, 473, 3)
            BGR, mean-subtracted batch.
        output_size: (width, height) of the returned score map. The default
            (32, 32) matches 256x256 images at 1/8 scale.

    Returns:
        numpy array of shape (output_size[1], output_size[0], n_classes) with
        the per-class scores from `trained_model.predict`, resized with
        preserve_range. These are raw scores, not one-hot; n_classes is 150
        for the PSPNet used here.
    """
    # TODO: do this in my model
    # Presumably the ImageNet per-channel means in RGB order; they are
    # subtracted before the RGB -> BGR swap below.
    data_mean = np.array([[[123.68, 116.779, 103.939]]])
    input_size = (473, 473)

    # Grayscale / RGBA / palette images give 2-D or 4-channel arrays that cannot
    # be broadcast against the 3-channel mean, so normalise the mode first.
    if img.mode != "RGB":
        img = img.convert("RGB")
    if img.size != input_size:
        img = img.resize(input_size)

    pixel_img = np.array(img) - data_mean
    bgr_img = pixel_img[:, :, ::-1]
    segmented_img = trained_model.predict(np.expand_dims(bgr_img, 0))[0]

    if tuple(output_size) != input_size:
        n_classes = segmented_img.shape[2]
        segmented_img = resize(segmented_img,
                               (output_size[1], output_size[0], n_classes),
                               mode="constant",
                               preserve_range=True)
    return segmented_img


def generator_fn(batch_size, images_path, images_size, trained_model, flip=False, validation=False):
    """Yield ``([x, s], y)`` batches forever from images stored in a zip archive.

    Args:
        batch_size: samples per batch.
        images_path: path (or file-like object) of the zip archive.
        images_size: (width, height) images are resized to.
        trained_model: segmentation model used by batch_images.
        flip: forwarded to batch_images (mirrors every sample when True).
        validation: if True, the same first `batch_size` images (after one
            shuffle) are yielded every time, so val_loss is comparable
            across epochs.

    Raises:
        ValueError: on the first ``next()``, if the archive holds fewer
            than `batch_size` images.
    """
    with zipfile.ZipFile(images_path) as imgs:
        # Directory entries (names ending in "/") are not images and would
        # make Image.open fail inside batch_images.
        image_paths = [info for info in imgs.infolist() if not info.filename.endswith("/")]
        if len(image_paths) < batch_size:
            raise ValueError("{} contains {} images, fewer than batch_size={}".format(
                images_path, len(image_paths), batch_size))
        random.shuffle(image_paths)
        n_images = batch_size if validation else len(image_paths)

        i = 0
        while True:
            # Wrap only when a full batch no longer fits. The old `>=` also
            # discarded the last complete batch (i + batch_size == n_images).
            if i + batch_size > n_images:
                i = 0
                if not validation:
                    # Visit the training images in a new order on each pass.
                    random.shuffle(image_paths)
            x, s, y = batch_images(i, batch_size, images_size, image_paths, imgs, trained_model, flip)
            i += batch_size
            yield [x, s], y


In [0]:
################################################################################
# Loading, training and saving model
################################################################################

class Interp(layers.Layer):
    """Keras layer that resizes a rank-4 (batch, h, w, c) tensor to `new_size`.

    Registered under the name 'Interp' when loading the segmentation model
    (see load_trained_model). The class name and the 'new_size' config key
    therefore have to stay as they are.

    Args:
        new_size: (height, width) of the output feature map.
    """

    def __init__(self, new_size, **kwargs):
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no weights; just let the base class mark it as built.
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_h, target_w = self.new_size
        return ktf.image.resize_images(inputs, [target_h, target_w], align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch dimension stays unknown; the channel count passes through.
        out_h, out_w = self.new_size[0], self.new_size[1]
        return (None, out_h, out_w, input_shape[3])

    def get_config(self):
        layer_config = super(Interp, self).get_config()
        layer_config['new_size'] = self.new_size
        return layer_config


def load_trained_model(path):
    """Load the pretrained segmentation model stored at `path`.

    The custom `Interp` layer is supplied via custom_objects so the saved
    model can be deserialized.

    Returns:
        The loaded Keras model, with its predict function already built.
    """
    segmentation_model = load_model(path, custom_objects=dict(Interp=Interp))
    # Build the predict function eagerly on this thread. Presumably this is so
    # predict() can later be called from the fit_generator worker thread
    # (a TF1-graph Keras requirement) -- confirm if the backend changes.
    segmentation_model._make_predict_function()
    return segmentation_model


def model_definition():
    """Build and compile the two-input colorization network.

    Inputs:
        * grayscale L channel, shape (None, None, 1);
        * segmentation scores, shape (None, None, 150). Their spatial size must
          match the encoder output, i.e. H/8 x W/8 when H and W are multiples
          of 8 (the encoder has three stride-2 convolutions).

    Output:
        2 channels (a, b) from a tanh, brought back to input resolution by
        three UpSampling2D layers.

    The model is compiled with MSE loss and the Adam optimizer.
    """
    def conv3x3(filters, strides=1, activation="relu"):
        # Every convolution here is 3x3, "same"-padded and uses a bias.
        return Conv2D(filters, (3, 3), padding="same", activation=activation,
                      use_bias=True, strides=strides)

    # Grayscale encoder: (filters, stride) for each layer, applied in order.
    encoder_spec = [(64, 2), (128, 1), (128, 2), (256, 1),
                    (256, 2), (512, 1), (512, 1), (256, 1)]
    grayscale_input = Input(shape=(None, None, 1))
    features = grayscale_input
    for n_filters, stride in encoder_spec:
        features = conv3x3(n_filters, strides=stride)(features)

    # Fuse encoded grayscale features with the segmentation scores channel-wise.
    segmentation_input = Input(shape=(None, None, 150))
    fused = concatenate([features, segmentation_input], axis=3)

    # Decoder: 1/8 -> 1/4 -> 1/2 -> full resolution.
    decoded = conv3x3(128)(fused)
    decoded = UpSampling2D()(decoded)
    decoded = conv3x3(64)(decoded)
    decoded = UpSampling2D()(decoded)
    decoded = conv3x3(32)(decoded)
    decoded = conv3x3(16)(decoded)
    decoded = conv3x3(2, activation="tanh")(decoded)
    decoded = UpSampling2D()(decoded)

    network = Model(inputs=[grayscale_input, segmentation_input], outputs=decoded)
    network.compile(loss="mse", optimizer="adam")
    return network


def callbacks(model_path):
    """Return the Keras callbacks used during training.

    Args:
        model_path: directory for training artifacts ("" = current directory).

    Returns:
        [ModelCheckpoint, CSVLogger]:
        * ModelCheckpoint saves `model_path/model_checkpoint.hdf5` only when
          val_loss improves;
        * CSVLogger writes per-epoch metrics to `model_path/model_log.csv`.
    """
    checkpoint_path = os.path.join(model_path, "model_checkpoint.hdf5")
    # The log used to be written to the CWD regardless of model_path, which
    # separated it from the checkpoint. os.path.join("", name) == name, so the
    # default "" location is unchanged.
    log_path = os.path.join(model_path, "model_log.csv")
    return [
        ModelCheckpoint(checkpoint_path, save_best_only=True, monitor="val_loss"),
        CSVLogger(log_path),
    ]


def train_model(model, training_data_fn, validation_data_fn, epochs, steps_per_epoch, validation_steps, save_path):
    """Fit `model` on batch generators and return the same model object.

    Args:
        model: compiled Keras model.
        training_data_fn: generator yielding training batches.
        validation_data_fn: generator yielding validation batches.
        epochs: number of epochs.
        steps_per_epoch: training batches drawn per epoch.
        validation_steps: validation batches drawn per epoch.
        save_path: directory handed to callbacks() for checkpoint output.
    """
    fit_options = dict(
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks(save_path),
        validation_data=validation_data_fn,
        validation_steps=validation_steps,
        verbose=2,         # one summary line per epoch
        max_queue_size=1,  # at most one pre-built batch waits in the queue
    )
    model.fit_generator(training_data_fn, **fit_options)
    return model


In [0]:
################################################################################
# Control panel
################################################################################

current_directory = ""

batch_size = 64  # flip=True mirrors every sample in place; it does NOT enlarge the batch
batches_per_epoch = 10
n_epochs = 73
batches_per_validation = 1  # validation generator re-yields one fixed batch; more steps only repeat it
images_size = (256, 256)  # (width, height); segmentation maps are 32x32 (= 256 / 8), change together

training_path = os.path.join(current_directory, "places2_training_2.zip")
validation_path = os.path.join(current_directory, "places2_validation_1.zip")
trained_model_path = os.path.join(current_directory, "pspnet.h5")
trained_model = load_trained_model(trained_model_path)

training_data_fn = generator_fn(batch_size, training_path, images_size, trained_model)
validation_data_fn = generator_fn(batch_size, validation_path, images_size, trained_model, validation=True)
# model = model_definition()  # use instead of load_model to train from scratch
model = load_model("colorization_model.hdf5")

start_time = time.time()
train_model(model,
            training_data_fn,
            validation_data_fn,
            n_epochs,
            batches_per_epoch,
            batches_per_validation,
            current_directory)
print("Training took: " + str(time.time() - start_time))

model.save("model_final.hdf5")


def _download_locally(filename, label, attempts=5):
    """Send `filename` to the browser via Colab, retrying up to `attempts` times.

    Best effort: failures are reported, not raised, so the remaining
    artifacts are still saved. Returns True on success.
    """
    for attempt in range(attempts):
        try:
            files.download(filename)
            return True
        except Exception as exc:  # narrower than bare except: lets Ctrl-C through
            print("local {} download error:{} ({})".format(label, attempt, exc))
    return False


def _upload_to_drive(filename, label):
    """Upload `filename` as a new Drive file; report (don't raise) failures.

    Returns True on success.
    """
    try:
        drive_file = drive.CreateFile()
        drive_file.SetContentFile(filename)
        drive_file.Upload()
        return True
    except Exception as exc:
        # Previously the checkpoint/final labels were swapped in these messages.
        print("cloud {} upload error ({})".format(label, exc))
        return False


TRAINING_ARTIFACTS = [
    ("model_final.hdf5", "final"),
    ("model_checkpoint.hdf5", "checkpoint"),
    ("model_log.csv", "log"),
]
for artifact, label in TRAINING_ARTIFACTS:
    _download_locally(artifact, label)
for artifact, label in TRAINING_ARTIFACTS:
    _upload_to_drive(artifact, label)

# !ls


Epoch 1/73
 - 604s - loss: 0.0086 - val_loss: 0.0096
Epoch 2/73
 - 544s - loss: 0.0084 - val_loss: 0.0095
Epoch 3/73
 - 548s - loss: 0.0084 - val_loss: 0.0097
Epoch 4/73
 - 548s - loss: 0.0086 - val_loss: 0.0098
Epoch 5/73
 - 545s - loss: 0.0091 - val_loss: 0.0095
Epoch 6/73
