In [None]:
from __future__ import print_function
from __future__ import absolute_import

import warnings
import os
import random
import numpy as np
import threading

import keras
from keras.models import Model
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Input
from keras.layers import Conv2D, Deconvolution2D
from keras.layers import MaxPooling2D, UpSampling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.layers import merge, concatenate, add
from keras.engine.topology import get_source_inputs
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.preprocessing import image as image_utils
from keras.preprocessing.image import ImageDataGenerator
from imagenet_utils import _obtain_input_shape
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img


# Download URLs for pretrained VGG16 weight files (deep-learning-models release v0.1).
# Per the file name, these use TensorFlow dim ordering and TensorFlow kernels.
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
# "notop" variant; this is the file create_model() loads into its convolution-only encoder.
# NOTE(review): presumably it omits the fully connected classifier head -- confirm.
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'

In [None]:
def create_model(input_tensor=None, input_shape=None):
    """Build a VGG16-encoder / transposed-convolution-decoder image model.

    The encoder reproduces the five VGG16 convolutional blocks (layer names
    ``block1_conv1`` ... ``block5_pool``) so that the pretrained "notop"
    weights can be loaded into it. The decoder mirrors the encoder: each
    stage applies transposed convolutions followed by 2x upsampling, and
    stages 4..1 first concatenate the matching encoder output (skip
    connection). Two final convolutions reduce the result to 3 channels
    with a sigmoid activation.

    Args:
        input_tensor: Optional tensor to use as the model input. When it is
            not already a Keras tensor it is wrapped in an ``Input`` layer.
        input_shape: Optional input shape, validated by
            ``_obtain_input_shape`` (default size 224, minimum size 48).

    Returns:
        A ``(model, vgg16)`` tuple. ``vgg16`` maps the input to the block-5
        pooling output; its pretrained weights are loaded and all of its
        layers are marked non-trainable. ``model`` (named ``'colorizer'``)
        maps the input to the 3-channel decoder output and is built on the
        same layer objects, so the frozen encoder layers are part of it.

    Note:
        Every encoder block halves the spatial size and every decoder block
        doubles it, so the skip concatenations only line up when height and
        width are multiples of 32 (e.g. the default 224). The ``min_size=48``
        check alone does not guarantee that.
    """
    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=48,
                                      data_format=K.image_data_format(),
                                      include_top=False)

    if input_tensor is None:
        inputs = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            inputs = Input(tensor=input_tensor, shape=input_shape)
        else:
            inputs = input_tensor

    # Encoder 1
    e1 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(inputs)
    e1 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(e1)
    e1 = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(e1)

    # Encoder 2
    e2 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(e1)
    e2 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(e2)
    e2 = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(e2)

    # Encoder 3
    e3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(e2)
    e3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(e3)
    e3 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(e3)
    e3 = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(e3)

    # Encoder 4
    e4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(e3)
    e4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(e4)
    e4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(e4)
    e4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(e4)

    # Encoder 5
    e5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(e4)
    e5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(e5)
    e5 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(e5)
    e5 = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(e5)

    # Now decode the representation.
    # Each decoder stage below the first is fed the previous decoder output
    # concatenated with the encoder output of the same spatial size
    # (skip connections via concatenation, not residual addition).

    # Decoder 5
    d5 = Deconvolution2D(512, (3, 3), activation='relu', padding='same', name='block5_deconv1')(e5)
    d5 = Deconvolution2D(512, (3, 3), activation='relu', padding='same', name='block5_deconv2')(d5)
    d5 = Deconvolution2D(512, (3, 3), activation='relu', padding='same', name='block5_deconv3')(d5)
    d5 = UpSampling2D((2, 2), name='block5_upsample')(d5)

    # Decoder 4
    merged = concatenate([d5, e4])
    d4 = Deconvolution2D(256, (3, 3), activation='relu', padding='same', name='block4_deconv1')(merged)
    d4 = Deconvolution2D(256, (3, 3), activation='relu', padding='same', name='block4_deconv2')(d4)
    d4 = Deconvolution2D(256, (3, 3), activation='relu', padding='same', name='block4_deconv3')(d4)
    d4 = UpSampling2D((2, 2), name='block4_upsample')(d4)

    # Decoder 3
    merged = concatenate([d4, e3])
    d3 = Deconvolution2D(128, (3, 3), activation='relu', padding='same', name='block3_deconv1')(merged)
    d3 = Deconvolution2D(128, (3, 3), activation='relu', padding='same', name='block3_deconv2')(d3)
    d3 = Deconvolution2D(128, (3, 3), activation='relu', padding='same', name='block3_deconv3')(d3)
    d3 = UpSampling2D((2, 2), name='block3_upsample')(d3)

    # Decoder 2
    merged = concatenate([d3, e2])
    d2 = Deconvolution2D(64, (3, 3), activation='relu', padding='same', name='block2_deconv1')(merged)
    d2 = Deconvolution2D(64, (3, 3), activation='relu', padding='same', name='block2_deconv2')(d2)
    d2 = UpSampling2D((2, 2), name='block2_upsample')(d2)

    # Decoder 1
    merged = concatenate([d2, e1])
    d1 = Deconvolution2D(32, (3, 3), activation='relu', padding='same', name='block1_deconv1')(merged)
    d1 = Deconvolution2D(32, (3, 3), activation='relu', padding='same', name='block1_deconv2')(d1)
    d1 = UpSampling2D((2, 2), name='block1_upsample')(d1)

    # Output head: reduce to 3 channels, sigmoid keeps values in [0, 1].
    d0 = Conv2D(16, (3, 3), activation='relu', padding='same', name='out1')(d1)
    d0 = Conv2D(3, (3, 3), activation='sigmoid', padding='same', name='out2')(d0)

    # Encoder-only model; it exists so the pretrained VGG16 weights can be
    # loaded into exactly the layers they belong to.
    vgg16 = Model(inputs, e5, name='vgg16')
    weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                            WEIGHTS_PATH_NO_TOP,
                            cache_subdir='models')

    # Load weights of vgg16 and fix them (set non-trainable)
    vgg16.load_weights(weights_path)
    for layer in vgg16.layers:
        layer.trainable = False

    model = Model(inputs, d0, name='colorizer')
    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print() (that would emit a stray "None" line).
    model.summary()
    return model, vgg16

In [None]:
# Build the full colorizer and the VGG16 encoder sub-model returned alongside it
# (default arguments: no input tensor, default input shape).
model, vgg = create_model()

In [None]:
# Smoke test: push one image through the network and check the output shape.
img = load_img('./Shifen-Waterfall-Taiwan.jpg', target_size=(224, 224))
x = img_to_array(img)[np.newaxis, ...]  # prepend the batch axis -> (1, H, W, C)

res_second = model.predict([x], 1, verbose=1)
print(res_second[0].shape)

# Train

In [None]:
# Configure training: Adam optimizer with the Kullback-Leibler divergence loss.
model.compile(optimizer='adam',
              loss='kullback_leibler_divergence')

In [None]:
class BatchGenerator(object):
    """Endless iterator that yields random image batches for training.

    Every call to ``next()`` draws ``batch_size`` paths at random (with
    replacement) from ``image_paths``, loads each image resized to
    ``(image_height, image_width)`` and returns a ``(features, labels)``
    pair of arrays shaped ``(batch_size, image_height, image_width, 3)``.
    Features and labels hold the same pixel data.

    Batch construction is guarded by a lock so that only one thread builds
    a batch at a time when the same instance is shared between threads.

    Args:
        image_paths: Sequence of image file paths to sample from.
        batch_size: Number of images per batch.
        image_height: Target height every image is resized to.
        image_width: Target width every image is resized to.
    """

    def __init__(self, image_paths, batch_size, image_height, image_width):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.image_height = image_height
        self.image_width = image_width
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(batch_features, batch_labels)`` pair."""
        batch_shape = (self.batch_size, self.image_height, self.image_width, 3)
        target_size = (self.image_height, self.image_width)

        with self.lock:
            batch_features = np.zeros(batch_shape)
            batch_labels = np.zeros(batch_shape)

            for i in range(self.batch_size):
                path = random.choice(self.image_paths)
                # Decode the file once and convert it to an array explicitly;
                # the same pixels fill both the feature and the label slot.
                pixels = np.asarray(load_img(path, target_size=target_size))
                batch_features[i] = pixels
                batch_labels[i] = pixels
            return batch_features, batch_labels

    # Python 2 iterator protocol (and explicit ``gen.next()`` calls);
    # Python 3's ``next(gen)`` uses ``__next__`` above.
    next = __next__

In [None]:
train_root = './test2014/'
# Keep only the .jpg entries found directly inside the training directory.
jpeg_names = [name for name in os.listdir(train_root) if name.endswith('.jpg')]
image_paths = [train_root + name for name in jpeg_names]
print(image_paths[0])

In [None]:
batch_size = 10
# Floor division keeps the step count an integer on both Python 2 and
# Python 3 (plain "/" yields a float on Python 3); always run at least one step.
steps_per_epoch = max(1, len(image_paths) // batch_size)
print( steps_per_epoch, 'iterations per one epoch' )
print( 'There are', len(image_paths), 'images in total' )

In [None]:
# Random-batch source over the training images, resized to the network's 224x224 input.
train_batches = BatchGenerator(image_paths=image_paths,
                               batch_size=batch_size,
                               image_height=224,
                               image_width=224)
# Write training logs for TensorBoard.
tensorboard = TensorBoard(log_dir='/tmp/coloring')

model.fit_generator(train_batches,
                    steps_per_epoch=steps_per_epoch,
                    epochs=1,
                    callbacks=[tensorboard],
                    verbose=1,
                    workers=1)