In [1]:
import numpy as np
import os 
import random

import tensorflow

from keras.layers import Conv2D, UpSampling2D, Conv2DTranspose
from keras.layers import Activation, Dense, Dropout,Flatten,InputLayer
from keras.layers.normalization import BatchNormalization
from keras.callbacks import TensorBoard
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

from skimage.color import rgb2lab, lab2rgb, rgb2gray
from skimage.io import imsave

Using TensorFlow backend.


In [2]:
# Get images
# os.listdir() returns names in arbitrary order, so the names are sorted to
# make both the set of loaded images and the later train/test split
# reproducible from one run (or machine) to the next.
MAX_TRAIN_IMAGES = 400  # load only this many images, to help fight out-of-memory errors
X = []
number_of_files = 0
for filename in sorted(os.listdir('Train')):
    temporary_img = img_to_array(load_img(os.path.join('Train', filename)))
    X.append(temporary_img)  # add the image to the list
    number_of_files += 1  # count the number of files loaded so far
    if number_of_files % 10 == 0:
        print(number_of_files)  # progress report: print every 10th count
    if number_of_files >= MAX_TRAIN_IMAGES:
        break
# Convert the list of per-image arrays into a single numpy array for later use.
# The dtype is left as produced by img_to_array (the author noted that forcing
# dtype=float here gave an error).
X = np.array(X)

10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400


In [3]:
# Train/test partition: the leading 95% of the loaded images are used for
# training; the remainder (X[split:]) is held back for evaluation later on.
split = int(len(X) * 0.95)
# Rescale the pixel values by 1/255; the float factor (1.0) keeps the result
# as floating point numbers.
Xtrain = (1.0 / 255) * X[:split]

In [4]:
# Building the neural network.
# Input: 256px by 256px images with a single channel (only the lightness
# channel is given to the network). Output: 2 channels from a tanh
# convolution, upsampled once more at the very end.
model = Sequential()
network_layers = [
    InputLayer(input_shape=(256, 256, 1)),
    # Convolutions with growing filter counts; every second one uses strides=2.
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    Conv2D(64, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    Conv2D(128, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(256, (3, 3), activation='relu', padding='same'),
    Conv2D(256, (3, 3), activation='relu', padding='same', strides=2),
    # Widest part of the network, then shrinking filter counts.
    Conv2D(512, (3, 3), activation='relu', padding='same'),
    Conv2D(256, (3, 3), activation='relu', padding='same'),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    # Three 2x upsampling steps mirror the three strides=2 convolutions.
    UpSampling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    # tanh keeps the two predicted channels in [-1, 1].
    Conv2D(2, (3, 3), activation='tanh', padding='same'),
    UpSampling2D((2, 2)),
]
# Layers are attached in exactly the order listed above.
for network_layer in network_layers:
    model.add(network_layer)

In [5]:
# Configure training: mean squared error loss, optimised with RMSprop.
model.compile(loss='mse', optimizer='rmsprop')

In [6]:
# Image transformer.
# Random shear, zoom, rotation and horizontal flips are applied to the
# pictures so that the training set effectively becomes larger, which is
# intended to improve the accuracy too.
augmentation_options = dict(
    shear_range=0.2,
    zoom_range=0.2,
    rotation_range=20,
    horizontal_flip=True,
)
datagen = ImageDataGenerator(**augmentation_options)

In [7]:
# Generate training data
def image_a_b_gen(batch_size=32):
    """Yield (lightness, ab) training pairs built from augmented ``Xtrain`` batches.

    Every batch produced by ``datagen.flow(Xtrain, ...)`` is converted with
    ``rgb2lab``. Channel 0 (lightness) becomes the network input and the
    remaining channels (a and b), divided by 128, become the target.

    Args:
        batch_size: Number of images requested per batch from
            ``datagen.flow``. Defaults to 32, which is the default Keras
            itself uses, so calling the generator with no arguments behaves
            exactly as before.

    Yields:
        tuple: ``(X_batch, Y_batch)`` where ``X_batch`` is the lightness
        channel reshaped to carry a trailing axis of size 1, and ``Y_batch``
        holds the ab channels divided by 128.
    """
    for batch in datagen.flow(Xtrain, batch_size=batch_size):
        lab_batch = rgb2lab(batch)
        X_batch = lab_batch[:, :, :, 0]  # the lightness channel
        Y_batch = lab_batch[:, :, :, 1:] / 128  # the ab channels, scaled down
        # yield hands one batch at a time to the consumer and then resumes the
        # loop, so only the current batch is built per step.
        # NOTE(review): the author suspected this generator of contributing to
        # an out-of-memory error; a smaller batch_size may help -- unconfirmed.
        yield (X_batch.reshape(X_batch.shape + (1,)), Y_batch)

In [8]:
# Debugging information, to ensure everything is set up correctly: first the
# number of files that were loaded, then the shape of the training array
# (the images that are not in Xtrain are left for testing and evaluating).
for debug_value in (number_of_files, Xtrain.shape):
    print(debug_value)

400
(380, 256, 256, 3)


In [None]:
# Train model
# TensorBoard logging is not passed to the fit call: the author reported that
# it works on Manjaro but breaks on Windows 10. To enable it, add
# callbacks=[TensorBoard(log_dir="output/current_run")] to the call below.
# use_multiprocessing previously had to be commented out by hand on Windows 10;
# it is now derived from the platform (os.name is 'nt' on Windows), so the
# same cell runs unchanged on both systems.
model.fit_generator(image_a_b_gen(),
                    steps_per_epoch=3,
                    epochs=1,
                    use_multiprocessing=(os.name != 'nt'))

In [None]:
# Persist the trained network in two files: the architecture as JSON text
# ("model.json") and the weights separately ("model.h5").
model_json = model.to_json()
with open("model.json", mode="w") as architecture_file:
    architecture_file.write(model_json)
model.save_weights("model.h5")

In [None]:
# Test the model using the held-out images (everything after `split`).
# The Lab conversion is computed once and shared by the inputs and the
# targets instead of converting the same images twice.
lab_test = rgb2lab(1.0/255*X[split:])
# Get the lightness channel and give it a trailing channel axis of size 1.
Xtest = lab_test[:,:,:,0]
Xtest = Xtest.reshape(Xtest.shape+(1,))
# Get the ab channels and normalise them by 128 (roughly between -1 and 1),
# the same scaling the training targets receive.
Ytest = lab_test[:,:,:,1:] / 128
# Print what model.evaluate returns. The model is compiled with loss='mse'
# and no extra metrics, so this is the mean squared error on the held-out
# images rather than an accuracy.
print(model.evaluate(Xtest, Ytest, batch_size=20))

In [None]:
# Prepare the images to be run through the model's predictions.
# os.listdir() returns names in arbitrary order; sorting them keeps the
# mapping between input files and the numbered output pictures stable from
# run to run.
color_me = [
    img_to_array(load_img(os.path.join('Test', filename)))
    for filename in sorted(os.listdir('Test/'))
]
color_me = np.array(color_me, dtype=float)
# Keep only the lightness channel of the Lab representation ...
color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]
# ... and add the trailing channel axis of size 1 that the model input uses.
color_me = color_me.reshape(color_me.shape+(1,))

In [None]:
# Test model: run the network over the prepared lightness images.
# The predicted values are between -1 and 1, so they are multiplied by 128
# to restore the value range of the ab channels.
output = model.predict(color_me) * 128

In [None]:
# Save the output
# img_as_ubyte does the float64 to uint8 conversion for us.
from skimage import img_as_ubyte

# Make sure the output folder exists before anything is written into it.
os.makedirs("result", exist_ok=True)

# Output colorizations
for i in range(len(output)):
    # Create an empty 3-channel matrix with the same height and width as the
    # prediction, in preparation for the final picture.
    create_image = np.zeros(output[i].shape[:2] + (3,))
    # Fill the first channel with the lightness information of the input.
    create_image[:,:,0] = color_me[i][:,:,0]
    # Fill the 2nd and 3rd channel with the predicted ab output.
    create_image[:,:,1:] = output[i]

    create_image = lab2rgb(create_image)
    # Convert float64 to uint8 explicitly to avoid a lossy implicit conversion on save.
    create_image = img_as_ubyte(create_image)
    imsave("result/img_"+str(i)+".png", create_image)
    print("Saved picture number: ", i)