In [14]:
import math
import os
import pathlib
import random

import cv2
import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.preprocessing import image
from keras.engine import Layer
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate, Activation, Dense, Dropout, Flatten, LeakyReLU
from keras.layers import RepeatVector
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.models import Sequential, Model
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
from skimage import exposure

In [15]:
# Loading every training image at once ran out of memory, so the model is
# fed from disk in mini-batches of this size, read from this directory.
train_dir = '../Capstone/images/train/'
batch_size = 20

def custom_preprocessing(image):
    """Apply one randomly selected contrast adjustment to an image.

    A value in {0, 1, 2} is drawn uniformly on every call and selects:
      0 - adaptive histogram equalization (clip_limit=0.02) of image / 255
      1 - global histogram equalization of the image as given
      2 - intensity rescaling between the image's 2nd and 98th percentiles

    Args:
        image: image array; presumably pixel values in 0-255, since the
            adaptive branch divides by 255 first - TODO confirm.

    Returns:
        The adjusted image as returned by the chosen skimage.exposure call.
    """
    choice = random.randint(0, 2)

    if choice == 0:
        scaled = image * 1.0 / 255
        return exposure.equalize_adapthist(scaled, clip_limit=0.02)

    if choice == 1:
        return exposure.equalize_hist(image)

    # choice == 2: stretch contrast between the 2nd and 98th percentiles.
    low, high = np.percentile(image, (2, 98))
    return exposure.rescale_intensity(image, in_range=(low, high))

# Augmenting image generator: shear and zoom ranges of 0.4, vertical
# flipping enabled, and custom_preprocessing registered as the
# per-image preprocessing function.
image_gen = ImageDataGenerator(
    preprocessing_function=custom_preprocessing,
    vertical_flip=True,
    zoom_range=0.4,
    shear_range=0.4,
)

# Generator that turns augmented RGB batches into (input, target) training pairs.
def image_a_b_gen(batch_size):
    """Yield (X_train, y_train) batches built from augmented images in train_dir.

    Each batch produced by image_gen.flow_from_directory is converted to the
    Lab color space and split into the network input and target.

    Args:
        batch_size: number of images requested per batch.

    Yields:
        A tuple (X_train, y_train):
          X_train - channel 0 of the Lab batch divided by 100, with a
                    trailing axis of size 1.
          y_train - channels 1 and 2 of the Lab batch divided by 128.
    """
    batches = image_gen.flow_from_directory(
        train_dir, batch_size=batch_size, class_mode=None, shuffle=False)
    for rgb_batch in batches:
        lab_batch = rgb2lab(rgb_batch)
        # 0:1 keeps the channel axis, so no separate reshape is needed.
        X_train = lab_batch[:, :, :, 0:1] / 100
        y_train = lab_batch[:, :, :, 1:] / 128
        # Keras expects generators to yield a tuple (inputs, targets); a list
        # can be misread as a list of model inputs with no targets.
        yield (X_train, y_train)


In [17]:
class DSSIMObjective:
    """Callable loss returning the mean of (1 - SSIM) / 2.

    Local statistics are taken with an 11x11 Gaussian window (sigma 1.5)
    applied per channel through a 'valid' depthwise convolution.

    Args:
        k1: constant used for the luminance stabilizer c1 = (k1 * max_value)^2.
        k2: constant used for the contrast stabilizer c2 = (k2 * max_value)^2.
        max_value: value range multiplier used in c1 and c2.
    """

    def __init__(self, k1=0.01, k2=0.03, max_value=1.0):
        # __name__ is set on the instance so it can be identified by name.
        self.__name__ = 'DSSIMObjective'
        self.k1 = k1
        self.k2 = k2
        self.max_value = max_value
        self.backend = K.backend()

    def __int_shape(self, x):
        # Static shape on TensorFlow, symbolic shape on any other backend.
        if self.backend == 'tensorflow':
            return K.int_shape(x)
        return K.shape(x)

    def __call__(self, y_true, y_pred):
        """Return the scalar dissimilarity between y_true and y_pred.

        Both tensors are presumably (batch, height, width, channels);
        the channel count is read from the last axis of y_pred - TODO confirm
        the layout against the backend's default data format.
        """
        channels = K.shape(y_pred)[-1]

        def _gaussian_window(size, sigma):
            # Mimics MATLAB's 'fspecial' gaussian: build the 2-D log-weights
            # as an outer sum, then normalize them with a softmax.
            offsets = np.arange(0, size, dtype=K.floatx())
            offsets -= (size - 1) / 2.0
            log_weights = offsets ** 2
            log_weights *= (-0.5 / (sigma ** 2))
            log_weights_2d = (np.reshape(log_weights, (1, -1))
                              + np.reshape(log_weights, (-1, 1)))
            window = K.constant(np.reshape(log_weights_2d, (1, -1)))
            window = K.softmax(window)
            window = K.reshape(window, (size, size, 1, 1))
            # One copy of the window per channel for the depthwise convolution.
            return K.tile(window, (1, 1, channels, 1))

        window = _gaussian_window(11, 1.5)

        def local_average(x):
            return K.depthwise_conv2d(x, window, strides=(1, 1), padding='valid')

        compensation = 1.0  # compensation factor applied to c2
        c1 = (self.k1 * self.max_value) ** 2
        c2 = ((self.k2 * self.max_value) ** 2) * compensation

        # Luminance term from the local means.
        mu_true = local_average(y_true)
        mu_pred = local_average(y_pred)
        mean_product = mu_true * mu_pred * 2.0
        mean_squares = K.square(mu_true) + K.square(mu_pred)
        luminance = (mean_product + c1) / (mean_squares + c1)

        # Contrast-structure term from the local second moments.
        cross_moment = local_average(y_true * y_pred) * 2.0
        square_moment = local_average(K.square(y_true) + K.square(y_pred))
        contrast_structure = ((cross_moment - mean_product + c2)
                              / (square_moment - mean_squares + c2))

        ssim_map_mean = K.mean(luminance * contrast_structure, axis=(-3, -2))
        return K.mean((1.0 - ssim_map_mean) / 2.0)

In [18]:
# Loss object handed to model.compile; the arguments spell out the
# constructor's default values.
ssim_loss = DSSIMObjective(k1=0.01, k2=0.03, max_value=1.0)

In [16]:
# Colorization network: a shared encoder feeds a local feature branch
# (Model A) and a global feature branch (Model B); the two are fused and
# decoded into two tanh-activated output channels.
# Shape comments assume the 256x256x1 input declared below.

# Shared encoder: 256x256x1 -> 32x32x512
encoder_input = Input(shape=(256, 256, 1))
encoder_output = Conv2D(64, (3, 3), activation='relu', padding='same', strides=2)(encoder_input)      # 128x128x64
encoder_output = Conv2D(128, (3, 3), activation='relu', padding='same')(encoder_output)               # 128x128x128
encoder_output = BatchNormalization()(encoder_output)
encoder_output = Conv2D(128, (3, 3), activation='relu', padding='same', strides=2)(encoder_output)    # 64x64x128
encoder_output = Conv2D(256, (3, 3), activation='relu', padding='same')(encoder_output)               # 64x64x256
encoder_output = BatchNormalization()(encoder_output)
encoder_output = Conv2D(256, (3, 3), activation='relu', padding='same', strides=2)(encoder_output)    # 32x32x256
encoder_output_shared = Conv2D(512, (3, 3), activation='relu', padding='same')(encoder_output)        # 32x32x512

# Model A (local features): 32x32x512 -> 32x32x256
encoder_output = Conv2D(512, (3, 3), activation='relu', padding='same')(encoder_output_shared)
encoder_output = Conv2D(256, (3, 3), activation='relu', padding='same')(encoder_output)

# Spatial grid of the local branch; the global vector is tiled onto this grid
# so both branches can be concatenated along the channel axis.
fusion_h, fusion_w = K.int_shape(encoder_output)[1:3]

# Model B (global features): 32x32x512 -> 256-vector -> tiled to 32x32x256
global_encoder = Conv2D(512, (3, 3), activation='relu', padding='same', strides=2)(encoder_output_shared)  # 16x16x512
global_encoder = Conv2D(512, (3, 3), activation='relu', padding='same')(global_encoder)
global_encoder = BatchNormalization()(global_encoder)
global_encoder = Conv2D(512, (3, 3), activation='relu', padding='same', strides=2)(global_encoder)         # 8x8x512
global_encoder = Conv2D(512, (3, 3), activation='relu', padding='same')(global_encoder)
global_encoder = BatchNormalization()(global_encoder)
global_encoder = Flatten()(global_encoder)
global_encoder = Dense(1024, activation='relu')(global_encoder)
global_encoder = Dense(512, activation='relu')(global_encoder)
global_encoder = Dense(256, activation='relu')(global_encoder)
# RepeatVector comes from keras.layers (see the import cell).
global_encoder = RepeatVector(fusion_h * fusion_w)(global_encoder)
global_encoder = Reshape([fusion_h, fusion_w, 256])(global_encoder)

# Fusion: concatenate along channels, then mix with a 1x1 convolution
fusion_output = concatenate([encoder_output, global_encoder], axis=3)                                 # 32x32x512
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)                 # 32x32x256

# Decoder: 32x32x256 -> 256x256x2
decoder_output = Conv2D(128, (3, 3), activation='relu', padding='same')(fusion_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)                                                 # 64x64x128
decoder_output = Conv2D(64, (3, 3), activation='relu', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)                                                 # 128x128x64
decoder_output = Conv2D(32, (3, 3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(16, (3, 3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)                 # 128x128x2
decoder_output = UpSampling2D((2, 2))(decoder_output)                                                 # 256x256x2

model = Model(inputs=encoder_input, outputs=decoder_output)
# Finish model: optimize the DSSIM loss and report MSE / MAE alongside it.
model.compile(optimizer='adam', loss=ssim_loss, metrics=['mse', 'mean_absolute_error'])

In [None]:
# Train on the augmented generator: 1000 epochs of 10 generator batches each.
model.fit_generator(
    image_a_b_gen(batch_size=batch_size),
    epochs=1000,
    steps_per_epoch=10,
)

In [14]:
# Persist the trained model to an HDF5 file.
# NOTE(review): this is an absolute, machine-specific path.
model.save(filepath='C:/Users/n3rDx/Desktop/Homework Upload/Capstone/3rd_expt_2.h5')

In [19]:
# Load the test images and keep only their Lab lightness channel.
# The directory is listed and read through one path so the file names always
# match the files being opened (previously the listing used a relative path
# while the load used a different absolute one).
test_dir = '../Capstone/images/test/test/'
test = []
# sorted() makes the image order (and the numbered output files) deterministic.
for filename in sorted(os.listdir(test_dir)):
    # Resize to the model's 256x256 input so every array has the same shape.
    test.append(img_to_array(load_img(os.path.join(test_dir, filename), target_size=(256, 256))))
test = np.array(test, dtype=float)
# Channel 0 of Lab; left unscaled here (same values the original cell produced).
test = rgb2lab(test/255.0)[:,:,:,0]
# Add a trailing axis of size 1 to match the model's single-channel input.
test = test.reshape(test.shape+(1,))

In [20]:
# Test model
# The network was trained on the L channel divided by 100 (image_a_b_gen),
# so the same scaling is applied to the test input before predicting.
output = model.predict(test / 100)
# Undo the / 128 scaling that was applied to the training targets.
output = output * 128


In [None]:
# Output colorizations: rebuild each Lab image from the input lightness and
# the predicted color channels, convert it to RGB and save it as a JPEG.
for i in range(len(output)):
        cur = np.zeros((256, 256, 3))
        # Lightness comes from the loaded test array (the undefined name
        # `color_me` previously raised a NameError here).
        cur[:,:,0] = test[i][:,:,0]
        cur[:,:,1:] = output[i]
        # lab2rgb returns floats in [0, 1]; scale to 0-255 before the uint8
        # cast, otherwise the cast truncates almost every pixel to 0.
        rgb = np.clip(np.rint(lab2rgb(cur) * 255), 0, 255).astype('uint8')
        imsave("C:/Users/n3rDx/Desktop/Homework Upload/Capstone/result/"+str(i)+".jpg", rgb)