# Image Colorization using CNNs

In this notebook we explore how to produce a plausible colorized image from a grayscale image.

## Setup

We are going to use Google Drive's storage to save and load data. First, we need to authenticate our user.
We assume everything lives in a folder named 'shared' in the root of Drive. This cell can be skipped, but then the cell that saves the model to Drive later in the notebook will not work.

In [1]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


Other imports that will be used throughout the code:

In [2]:
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
from skimage import img_as_ubyte
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
import os
import keras
from keras.optimizers import Adam
from keras import backend as K
from keras.models import Sequential
from keras.layers import Conv2D, BatchNormalization, UpSampling2D, ZeroPadding2D, Conv2DTranspose, Lambda, Softmax, Add, Input, Activation
from keras.preprocessing.image import ImageDataGenerator
import multiprocessing

Using TensorFlow backend.


## Model definition

First we define the model hyperparameters as follows:


In [0]:
h, w = (32, 32)
input_shape = (h,w,1) # Channels last
batch_size = 16
epochs = 5
nb_classes = 313 # Not used by the regression model defined below

In [0]:
from keras.callbacks import TensorBoard

model = Sequential()

model.add(Conv2D(64, kernel_size=3, padding='same', strides=1, activation='relu', input_shape=input_shape))
model.add(Conv2D(64, kernel_size=3, padding='same', strides=2, activation='relu'))
model.add(Conv2D(128, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(128, kernel_size=3, padding='same' ,strides=2, activation='relu'))
model.add(Conv2D(256, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(256, kernel_size=3, padding='same', strides=2, activation='relu'))
model.add(Conv2D(512, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(256, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(128, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(32, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(16, kernel_size=3, padding='same', strides=1, activation='relu'))
model.add(Conv2D(2, kernel_size=3, padding='same', strides=1, activation='tanh'))
model.add(UpSampling2D((2, 2)))

In [5]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 32, 32, 64)        640       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 64)        36928     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       73856     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 128)         147584    
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 256)         295168    
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 4, 4, 256)         590080    
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 4, 4, 512)         1180160   
__________
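
The spatial sizes in the summary can be reproduced by hand: with padding='same', a convolution with stride s maps a side of length n to ceil(n / s), and UpSampling2D((2, 2)) doubles it. The following is a minimal sketch of that bookkeeping for the layers defined above; the helper name `trace_spatial_size` and the tuple encoding of the layers are illustrative assumptions, not part of Keras.

```python
import math

def trace_spatial_size(side, layers):
    """Return the side length after each layer.

    layers is a list of ("conv", stride) or ("up", factor) tuples
    (a hypothetical encoding used only for this sketch).
    """
    sizes = []
    for kind, value in layers:
        if kind == "conv":
            # padding='same': output side is ceil(input side / stride)
            side = math.ceil(side / value)
        elif kind == "up":
            # UpSampling2D repeats rows and columns `value` times
            side = side * value
        else:
            raise ValueError("unknown layer kind: {}".format(kind))
        sizes.append(side)
    return sizes

# Strides of the first nine convolutions, followed by the decoder layers
layers = [("conv", s) for s in (1, 2, 1, 2, 1, 2, 1, 1, 1)]
layers += [("up", 2), ("conv", 1), ("up", 2),
           ("conv", 1), ("conv", 1), ("conv", 1), ("up", 2)]

sizes = trace_spatial_size(32, layers)
print(sizes)
```

The three stride-2 convolutions reduce the side from 32 to 4, and the three upsampling layers bring it back to 32, so the two-channel output has the same height and width as the grayscale input.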

In [0]:
model.compile(loss='mse', optimizer='adam')

## Model training

Before training the model, we have to specify how the data is fed to it. We use a generator derived from the Sequence class, which is safe to use with several workers that prepare batches in parallel.

In [0]:
class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, data, batch_size=16, dim=(32,32), shuffle=True):
    'Initialization'
    self.dim = dim
    self.data = data
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.on_epoch_end()
    print("generator initialzed with ", len(data), "samples")
    
  def on_epoch_end(self):
    self.indexes = np.arange(len(self.data))
    if self.shuffle:
      np.random.shuffle(self.indexes)
      
  def __data_generation(self, idx_temp):
    
    X = np.empty((self.batch_size, *self.dim, 1)) # 1 Channel
#     y = np.empty((self.batch_size, *self.dim, 314)) # 313+1 channels
    y = np.empty((self.batch_size, *self.dim, 2)) # a,b channels
    
    for i, idx in enumerate(idx_temp):
      # resize() already returns floats in [0, 1] for uint8 input, so no extra /255 is needed
      lab = rgb2lab(resize(self.data[idx], self.dim))
      
      X[i,] = lab[:,:,:1]/100.
      y[i,] = (lab[:,:,1:3]/128.).reshape((*self.dim, 2))
#       y[i,] = transformY(lab[:,:,1:3], nearest, prior_factor, nb_q)
    
    return X, y
  
  def __len__(self):
    'Denotes the number of batches per epoch'
    return int(np.floor(len(self.data) / self.batch_size))
  
  def __getitem__(self, index):
    'Generate one batch of data'
    # Sample indices for this batch (already shuffled in on_epoch_end)
    batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

    # Generate data
    X, y = self.__data_generation(batch_indexes)

    return X, y
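
To make the batching explicit: `__len__` floors the number of samples divided by the batch size, so an incomplete final batch is dropped, and `__getitem__` takes a contiguous slice of the (possibly shuffled) index array. A minimal pure-Python sketch of the same arithmetic follows; the helper names are hypothetical and not part of the class above.

```python
def batches_per_epoch(n_samples, batch_size):
    # Same result as int(np.floor(n_samples / batch_size)) for positive integers
    return n_samples // batch_size

def batch_slice(indexes, batch_index, batch_size):
    # Contiguous slice of the index list that forms one batch
    start = batch_index * batch_size
    return indexes[start:start + batch_size]

# CIFAR-10 split sizes with the notebook's batch size of 16
print(batches_per_epoch(50000, 16))  # 3125
print(batches_per_epoch(10000, 16))  # 625

# With 10 samples and a batch size of 4, the last 2 samples are not used
indexes = list(range(10))
n = batches_per_epoch(len(indexes), 4)
print([batch_slice(indexes, i, 4) for i in range(n)])
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Both CIFAR-10 splits are exact multiples of 16, so no images are dropped with the batch size used here.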


We now train on the CIFAR10 training split and use its test split for validation.

In [8]:
(raw_train, _), (raw_test, _) = cifar10.load_data()

training_generator = DataGenerator(data=raw_train, batch_size=batch_size, shuffle=True, dim=(h,w))
testing_generator = DataGenerator(data=raw_test,  batch_size=batch_size, shuffle=False, dim=(h,w))

model.fit_generator(training_generator,
                    workers=multiprocessing.cpu_count(),
                    validation_data=testing_generator,
                    epochs=epochs,
                    verbose=1)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
generator initialzed with  50000 samples
generator initialzed with  10000 samples
Epoch 1/5


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7efe32f627b8>

Save the full model and, separately, its weights to the 'shared' folder in Google Drive.

In [0]:
model.save("./drive/My Drive/shared/model_b{}_e{}.h5".format(batch_size, epochs))
model.save_weights("./drive/My Drive/shared/model_b{}_e{}_weights.h5".format(batch_size, epochs))
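
To reuse the trained network in a later session, the saved file can be read back with `keras.models.load_model`. The sketch below rebuilds the same file name as the save cell; the helper names `model_path` and `load_saved_model` are hypothetical, and it assumes Drive is mounted and the working directory is the same as when the model was saved.

```python
import os

def model_path(batch_size, epochs, root="./drive/My Drive/shared"):
    # Mirrors the file name pattern used when saving the model
    return "{}/model_b{}_e{}.h5".format(root, batch_size, epochs)

def load_saved_model(path):
    # Fail early with a clear message if the file is missing
    if not os.path.exists(path):
        raise FileNotFoundError("No saved model at {}".format(path))
    # Imported lazily so the path helper works without Keras installed
    from keras.models import load_model
    return load_model(path)

print(model_path(16, 5))
```

Calling `load_saved_model(model_path(batch_size, epochs))` would then return the full model, including its architecture and weights.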

Alternatively, we can save the model to the current working environment using:

In [0]:
model.save("model.h5")

## Model validation

We perform a simple validation with one of the elements of raw_test.

In [0]:
def recolorize_img(img):
  # resize() already returns floats in [0, 1] for uint8 input, so no extra /255 is needed
  lab = rgb2lab(resize(img, (h,w)))
  x_black = lab[:,:,:1]/100.
  y_pred = model.predict(x_black.reshape(1,h,w,1))
  
  res_img = np.empty((h,w,3))
  res_img[:,:,:1] = x_black*100
  res_img[:,:,1:] = y_pred*128
  
  # lab2rgb returns floats in [0, 1]; img_as_ubyte rescales them to uint8 in [0, 255]
  return img_as_ubyte(lab2rgb(res_img))
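
The function above undoes the scaling applied to the training data: L is divided by 100 on the way in and multiplied by 100 on the way out, while the a and b channels use a factor of 128. A minimal sketch of that round trip on a single Lab triple, using hypothetical helper names and plain Python numbers rather than arrays:

```python
def normalize_lab(L, a, b):
    # L lies in [0, 100]; a and b lie roughly in [-128, 127]
    return L / 100.0, a / 128.0, b / 128.0

def denormalize_lab(L_n, a_n, b_n):
    # Inverse of normalize_lab
    return L_n * 100.0, a_n * 128.0, b_n * 128.0

lab = (50.0, 64.0, -32.0)
scaled = normalize_lab(*lab)
print(scaled)                    # (0.5, 0.5, -0.25)
print(denormalize_lab(*scaled))  # (50.0, 64.0, -32.0)
```

Because the last layer of the network uses a tanh activation, its outputs lie in (-1, 1), which maps to a and b values in (-128, 128) after rescaling.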

In [0]:
img = raw_test[13] # Testing on one of the images, for example
plt.imshow(img)
plt.show()
plt.imshow(recolorize_img(img))
plt.show()