In [None]:
import keras
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.preprocessing import image
from keras.engine import Layer
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate, Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard 
from keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
from keras.datasets import cifar10
from skimage.transform import resize

from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import os
import random
import tensorflow as tf

In [20]:
# Pretrained InceptionResNetV2 (ImageNet weights, classification head kept) whose
# predictions are fed to the colorizer as the 1000-d global embedding.
inception = InceptionResNetV2(include_top=True, weights='imagenet')

# Keep a handle on the TF graph the model was built in, so later code can call
# predict() inside `with inception.graph.as_default():`.
inception.graph = tf.get_default_graph()

In [21]:
# Root directory of training images (flow_from_directory expects class
# sub-folders). Override with the IMAGE_TRAIN_DIR environment variable instead
# of editing the hard-coded user path.
data_path = os.environ.get('IMAGE_TRAIN_DIR', '/Users/shreyajain/Downloads/image_train_dataset/')


In [22]:
# Square side length (pixels) that images are resized to before entering the model.
img_size_ori = 256
# NOTE(review): not referenced by upsample/downsample below.
img_size_target = 128


def upsample(img):
    """Resize ``img`` to ``img_size_ori`` x ``img_size_ori``.

    Despite the name, larger images are shrunk as well: any image whose height
    and width are not both ``img_size_ori`` is resized to that square size, with
    the original value range preserved (``preserve_range=True``).

    Returns the very same ``img`` object when it is already the target size.
    """
    rows, cols = img.shape[0], img.shape[1]
    if rows != img_size_ori or cols != img_size_ori:
        square = (img_size_ori, img_size_ori)
        return resize(img, square, mode='constant', preserve_range=True)
    return img


def downsample(img, img_shape):
    """Resize ``img`` to the height/width ``(img_shape[0], img_shape[1])``.

    Counterpart of ``upsample``: takes a model-sized image back to an arbitrary
    resolution. The value range is preserved (``preserve_range=True``).
    """
    target_hw = (img_shape[0], img_shape[1])
    return resize(img, target_hw, mode='constant', preserve_range=True)

In [23]:
# Colorization network: L channel (256x256x1) in, Lab a/b channels (256x256x2) out,
# conditioned on a 1000-d global embedding fused at the 32x32 bottleneck.


def _conv(filters, **kwargs):
    """3x3, 'same'-padded ReLU convolution -- the network's default layer."""
    return Conv2D(filters, (3, 3), activation='relu', padding='same', **kwargs)


# Global-feature input (InceptionResNetV2 prediction vector).
embed_input = Input(shape=(1000,))

# Encoder: (first filters, second filters, stride of the first conv).
# The three stride-2 stages take 256x256 down to 32x32.
encoder_input = Input(shape=(256, 256, 1,))
print("encoder_input ", encoder_input.shape)

encoder_stages = [
    (64, 128, 2),
    (128, 256, 2),
    (256, 512, 2),
    (512, 256, 1),
]
encoder_output = encoder_input
for first_filters, second_filters, stride in encoder_stages:
    encoder_output = _conv(first_filters, strides=stride)(encoder_output)
    encoder_output = _conv(second_filters)(encoder_output)
    print("encoder_output ", encoder_output.shape)

# Fusion: tile the embedding over every bottleneck position, concatenate it to
# the encoder features along channels, then mix with a 1x1 convolution.
fusion_hw = 256 // 8  # spatial size after three stride-2 convs
fusion_output = RepeatVector(fusion_hw * fusion_hw)(embed_input)
fusion_output = Reshape(([fusion_hw, fusion_hw, 1000]))(fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3)
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)
print("fusion_output ", fusion_output.shape)

# Decoder: upsample back to 256x256; tanh keeps outputs in [-1, 1] to match the
# a/b targets, which are divided by 128.
decoder_output = _conv(128)(fusion_output)
print("decoder_output ", decoder_output.shape)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = _conv(64)(decoder_output)
print("decoder_output ", decoder_output.shape)

decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = _conv(32)(decoder_output)
decoder_output = _conv(16)(decoder_output)
print("decoder_output ", decoder_output.shape)

decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
print("decoder_output ", decoder_output.shape)

model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

encoder_input  (?, 256, 256, 1)
encoder_output  (?, 128, 128, 128)
encoder_output  (?, 64, 64, 256)
encoder_output  (?, 32, 32, 512)
encoder_output  (?, 32, 32, 256)
fusion_output  (?, 32, 32, 256)
decoder_output  (?, 32, 32, 128)
decoder_output  (?, 64, 64, 64)
decoder_output  (?, 128, 128, 16)
decoder_output  (?, 256, 256, 2)


In [24]:
# Validation images (override with the IMAGE_VAL_DIR environment variable).
val_path = os.environ.get('IMAGE_VAL_DIR', "/Users/shreyajain/Downloads/image_val_dataset/")

import time
start_time = time.time()


def create_inception_embedding(grayscaled_rgb):
    """Run a batch of grayscale-as-RGB images through InceptionResNetV2.

    Args:
        grayscaled_rgb: batch of RGB images with values in [0, 1] (the float
            range skimage's colour conversions use).

    Returns:
        ``inception.predict`` output, one row per image; this is what the
        colorizer receives on its 1000-d ``embed_input``.
    """
    resized = np.array([resize(img, (299, 299, 3), mode='constant')
                        for img in grayscaled_rgb])
    # Keras' inception_resnet_v2.preprocess_input expects pixels in [0, 255]
    # (it maps them to [-1, 1]); our inputs are in [0, 1].
    resized = preprocess_input(resized * 255.0)
    with inception.graph.as_default():
        return inception.predict(resized)


# Training parameters.
batch_size = 2  # previously 20
lr_rate = 0.001

# Training augmentation. `rescale` turns the 0-255 pixels produced by
# flow_from_directory into [0, 1] floats, the range rgb2gray/rgb2lab assume
# (and the range used when colorizing test images).
datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)

# Validation images are only rescaled -- no random augmentation.
val_datagen = ImageDataGenerator(rescale=1.0 / 255)


def _lab_batches(generator, directory, batch_size):
    """Yield ``([L, embedding], ab / 128)`` batches from images under ``directory``.

    ``L`` is the Lab lightness channel with a trailing channel axis added; the
    embedding is computed once per batch from the grayscale images.
    """
    flow = generator.flow_from_directory(directory=directory,
                                         target_size=(256, 256),
                                         color_mode='rgb',
                                         batch_size=batch_size,
                                         shuffle=True,
                                         seed=42)
    for img_batch in flow:
        images = img_batch[0]  # img_batch is (images, labels)
        grayscaled_rgb = gray2rgb(rgb2gray(images))
        embed = create_inception_embedding(grayscaled_rgb)
        lab_batch = rgb2lab(images)
        X_batch = lab_batch[:, :, :, 0]
        X_batch = X_batch.reshape(X_batch.shape + (1,))
        Y_batch = lab_batch[:, :, :, 1:] / 128
        yield ([X_batch, embed], Y_batch)


def val_image_a_b_gen(batch_size):
    """Endless validation batches from ``val_path`` (no augmentation)."""
    yield from _lab_batches(val_datagen, val_path, batch_size)


def image_a_b_gen(batch_size):
    """Endless augmented training batches from ``data_path``."""
    yield from _lab_batches(datagen, data_path, batch_size)



In [None]:
# TensorBoard logging. The callback only takes effect when passed to
# fit_generator via `callbacks=`.
tensorboard = TensorBoard(log_dir="/Users/shreyajain/PycharmProjects/GAN/Image_colorisation/output",
                          histogram_freq=0,
                          write_graph=True,
                          write_images=True)
model.compile(optimizer='adam', loss='mse')

# Short run: 5 epochs x 2 steps of `batch_size` images.
# NOTE(review): validation (val_image_a_b_gen) is not wired in here.
model.fit_generator(generator=image_a_b_gen(batch_size),
                    epochs=5,
                    steps_per_epoch=2,
                    callbacks=[tensorboard])



In [12]:
# Persist the architecture (JSON) and the learned weights (HDF5) as two files.
architecture_file = "model.json"
weights_file = "color_tensorflow_real_mode.h5"

model_json = model.to_json()
with open(architecture_file, "w") as json_out:
    json_out.write(model_json)
model.save_weights(weights_file)

In [18]:
# Test images to colorize (override with the IMAGE_TEST_DIR environment variable).
test_path = os.environ.get('IMAGE_TEST_DIR', '/Users/shreyajain/Downloads/image_test_dataset2/')

# Sorted so result/img_<i>.png numbering is reproducible; dot-files such as
# macOS .DS_Store are skipped because load_img cannot decode them.
test_files = sorted(name for name in os.listdir(test_path) if not name.startswith('.'))
if not test_files:
    raise ValueError("No test images found in %s" % test_path)

color_me = []
image_shapes = []  # original shape of each test image, aligned with test_files
for filename in test_files:
    original = img_to_array(load_img(os.path.join(test_path, filename)))
    image_shapes.append(original.shape)
    color_me.append(upsample(original))
image_shape = image_shapes[-1]  # last image's shape (printed below)

color_me_rgb = np.array(color_me, dtype=float) / 255.0  # RGB scaled to [0, 1]
gray_me = gray2rgb(rgb2gray(color_me_rgb))
color_me_embed = create_inception_embedding(gray_me)
color_me = rgb2lab(color_me_rgb)[:, :, :, 0]  # Lab lightness channel only
color_me = color_me.reshape(color_me.shape + (1,))

print ("image_shape ", image_shape)

# Test model; training targets were a/b divided by 128, so scale back.
output = model.predict([color_me, color_me_embed])
output = output * 128

# Output colorizations, each resized back to its own original resolution.
os.makedirs("result", exist_ok=True)
for i, (lightness, ab, shape) in enumerate(zip(color_me, output, image_shapes)):
    cur = np.zeros(lightness.shape[:2] + (3,))
    cur[:, :, 0] = lightness[:, :, 0]
    cur[:, :, 1:] = ab
    img = lab2rgb(cur)
    img_resize = downsample(img, shape)
    imsave("result/img_" + str(i) + ".png", img_resize)

image_shape  (2040, 1356, 3)
