# Image Colorization using CNNs

In this notebook we explore how to produce a plausible colorized image from a grayscale image.

## Setup

We are going to use Google Drive's storage to save and load data. First, we need to authenticate our user.
We assume everything will be inside a folder named 'shared' in the root of Drive. This cell can be skipped, but then the model won't be saved later using our method.

In [1]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


Other imports that will be used throughout the code:

In [2]:
from keras.datasets import cifar10
!pip install tqdm # for progress bar support
!pip install parmap
import parmap
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import cv2
import sklearn.neighbors as nn
from scipy.interpolate import interp1d
from scipy.signal import gaussian, convolve
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
import os

import keras
from keras.optimizers import Adam
from keras import backend as K
from keras.models import Sequential
from keras.layers import Conv2D, BatchNormalization, ZeroPadding2D, Conv2DTranspose, Lambda, Softmax, Add, Input, Activation
from keras.preprocessing.image import ImageDataGenerator

Using TensorFlow backend.




## Model definition

First we define the model hyperparameters as follows:


In [0]:
h = 32
w = h
input_shape = (h,w,1) # Channels last
batch_size = 1
epochs = 5
nb_classes = 313

We need to import the file pts_in_hull.npy so we define it's location

In [0]:
(raw_train, _), (raw_test, _) = cifar10.load_data()

data_dir = "./drive/My Drive/shared/"
q_ab = np.load(os.path.join(data_dir, "pts_in_hull.npy"))
nb_q = q_ab.shape[0]

Then we define some of the required supporting functions

In [0]:
def calcualte_factor(raw_train, q_ab, nearest):
  gamma=0.5
  alpha=1
  sigma=5
  n, h, w, c = raw_train.shape
  
  X_a = np.ravel(raw_train[:, :, :, 0])
  X_b = np.ravel(raw_train[:, :, :, 1])
  X_ab = np.vstack((X_a, X_b)).T

  # Find index of nearest neighbor for X_ab
  _, ind = nearest.kneighbors(X_ab)

  # We now count the number of occurrences of each color
  ind = np.ravel(ind)
  counts = np.bincount(ind)
  idxs = np.nonzero(counts)[0]
  prior_prob = np.zeros((q_ab.shape[0]))
  for i in range(q_ab.shape[0]):
      prior_prob[idxs] = counts[idxs]

  # We turn this into a color probability
  prior_prob = prior_prob / (1.0 * np.sum(prior_prob))

  # add an epsilon to prior prob to avoid 0 vakues and possible NaN
  prior_prob += 1E-3 * np.min(prior_prob)
  # renormalize
  prior_prob = prior_prob / (1.0 * np.sum(prior_prob))

  # Smooth with gaussian
  f = interp1d(np.arange(prior_prob.shape[0]),prior_prob)
  xx = np.linspace(0,prior_prob.shape[0] - 1, 1000)
  yy = f(xx)
  window = gaussian(2000, sigma)  # 2000 pts in the window, sigma=5
  smoothed = convolve(yy, window / window.sum(), mode='same')
  fout = interp1d(xx,smoothed)
  prior_prob_smoothed = np.array([fout(i) for i in range(prior_prob.shape[0])])
  prior_prob_smoothed = prior_prob_smoothed / np.sum(prior_prob_smoothed)

  u = np.ones_like(prior_prob_smoothed)
  u = u / np.sum(1.0 * u)

  prior_factor = (1 - gamma) * prior_prob_smoothed + gamma * u
  prior_factor = np.power(prior_factor, -alpha)

  # renormalize
  prior_factor = prior_factor / (np.sum(prior_factor * prior_prob_smoothed))

  return prior_factor

In [0]:
def transformY(y, nearest, prior_factor, nb_q): 
  Y_a = np.ravel(y[ :, :, 0])
  Y_b = np.ravel(y[ :, :, 1])
  Y_batch_ab = np.vstack((Y_a, Y_b)).T
  s = Y_batch_ab.shape[0]
  
  #Calculate nearest color cell to the pixels ab
  _, idx_neigh = nearest.kneighbors(Y_batch_ab)
  
  del Y_batch_ab
  
  #Y has size numpixels * numimages, 313 (number of color cells)
  Y = np.empty((s, nb_q))
  #idx_pts = np.arange(s)[:, np.newaxis]
  
  #We load Y with the color cell of every pixel
  Y[np.arange(s)[:, np.newaxis], idx_neigh] = np.ones((len(idx_neigh), 1))
  
  #find the actual cell color from the 313
  idx_max = np.argmax(Y, axis=1)
  
  #Add the color correction of the corresponding cell
  weights = prior_factor[idx_max].reshape(Y.shape[0], 1)
  #Append the weigts. 313 ->314
  Y = np.concatenate((Y, weights), axis=1)
  #Reshape into normal image
  h, w, c = y.shape
  Y = Y.reshape((h, w, nb_q + 1))

  return Y

In [0]:
"""From built-in optimizer classes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import copy
from six.moves import zip

from keras import backend as K
from keras.utils.generic_utils import serialize_keras_object
from keras.utils.generic_utils import deserialize_keras_object
from keras.legacy import interfaces

from keras.optimizers import Optimizer

class AdamW(Optimizer):
    """AdamW optimizer.
    Default parameters follow those provided in the original paper.
    # Arguments
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 1.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
        decay: float >= 0. Learning rate decay over each update.
        weight_decay: float >= 0. Weight decay (L2 penalty) (default: 0.025).
        batch_size: integer >= 1. Batch size used during training.
        samples_per_epoch: integer >= 1. Number of samples (training points) per epoch.
        epochs: integer >= 1. Total number of epochs for training. 
    # References
        - [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
        - [Fixing Weight Decay Regularization in Adam](https://arxiv.org/abs/1711.05101)
    """

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0., weight_decay=0.025, 
                 batch_size=1, samples_per_epoch=1, 
                 epochs=1, **kwargs):
        super(AdamW, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
            self.weight_decay = K.variable(weight_decay, name='weight_decay')
            self.batch_size = K.variable(batch_size, name='batch_size')
            self.samples_per_epoch = K.variable(samples_per_epoch, name='samples_per_epoch')
            self.epochs = K.variable(epochs, name='epochs')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                      K.dtype(self.decay))))

        t = K.cast(self.iterations, K.floatx()) + 1
        '''Bias corrections according to the Adam paper
        '''
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            
            '''Schedule multiplier eta_t = 1 for simple AdamW
            According to the AdamW paper, eta_t can be fixed, decay, or 
            also be used for warm restarts (AdamWR to come). 
            '''
            eta_t = 1.
            p_t = p - eta_t*(lr_t * m_t / (K.sqrt(v_t) + self.epsilon))
            if self.weight_decay != 0:
                '''Normalized weight decay according to the AdamW paper
                '''
                w_d = self.weight_decay*K.sqrt(self.batch_size/(self.samples_per_epoch*self.epochs))
                p_t = p_t - eta_t*(w_d*p) 

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'weight_decay': float(K.get_value(self.weight_decay)),
                  'batch_size': int(K.get_value(self.batch_size)),
                  'samples_per_epoch': int(K.get_value(self.samples_per_epoch)),
                  'epochs': int(K.get_value(self.epochs)),
                  'epsilon': self.epsilon}
        base_config = super(AdamW, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [0]:
def categorical_crossentropy_color(y_true, y_pred):
  # Flatten
  y_pred = K.reshape(y_pred, (batch_size * h * w, nb_classes+1))
  y_true = K.reshape(y_true, (batch_size * h * w, nb_classes+1))

  weights = y_true[:, nb_classes:]  # extract weight from y_true
  weights = K.concatenate([weights] * nb_classes, axis=1)
  y_true = y_true[:, :-1]  # remove last column
  y_pred = y_pred[:, :-1]  # remove last column
  # multiply y_true by weights
  y_true = y_true * weights

  cross_ent = K.categorical_crossentropy(y_pred, y_true)
  cross_ent = K.mean(cross_ent, axis=-1)
  return cross_ent

We define the model

In [0]:
model = Sequential()

model.add(Conv2D(64, kernel_size=3, name='conv1', input_shape=input_shape, activation='relu', padding='same'))

for i in range(0,3):
  model.add(Conv2D(64, kernel_size=3, padding="same"))
  model.add(BatchNormalization(axis=-1))
  model.add(Activation("relu"))
  
  model.add(Conv2D(64, kernel_size=3, padding="same"))
  model.add(BatchNormalization(axis=-1))
  model.add(Activation("relu"))

model.add(Conv2D(nb_classes, (1, 1), name="convFinal", padding="same"))

# Reshape Softmax
def output_shape(input_shape):
    return (batch_size, h, w, nb_classes + 1)

def reshape_softmax(x):
#     x = K.permute_dimensions(x, [0, 2, 3, 1])  # last dimension in number of filters
    x = K.reshape(x, (batch_size * h * w, nb_classes))
    x = K.softmax(x)
    # Add a zero column so that x has the same dimension as the target (313 classes + 1 weight)
    xc = K.zeros((batch_size * h * w, 1))
    x = K.concatenate([x, xc], axis=1)
    # Reshape back to (batch_size, h, w, nb_classes + 1) to satisfy keras' shape checks
    x = K.reshape(x, (batch_size, h, w, nb_classes + 1))
    return x

model.add(Lambda(lambda z: reshape_softmax(z), name="ReshapeSoftmax"))


In [11]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1 (Conv2D)               (None, 32, 32, 64)        640       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
batch_normalization_1 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
batch_normalization_2 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
activation_2 (Activation)    (None, 32, 32, 64)        0         
__________

In [0]:
adamw = AdamW(lr=10e-5, beta_1=0.9, beta_2=0.999, weight_decay=1e-3, batch_size=batch_size, epochs=epochs)
model.compile(loss=categorical_crossentropy_color, optimizer=adamw)

## Model training

Before actually training the model, we have to specify how we are going to feed the data. We will use a Generator derived from the Sequence class so we can use multiprocessing

In [0]:
class DataGenerator(keras.utils.Sequence):
  
  def __init__(self, data, batch_size=16, dim=(32,32), shuffle=True):
    'Initialization'
    self.dim = dim
    self.data = data
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.on_epoch_end()
    print("generator initialzed with, ", len(data))
    
  def on_epoch_end(self):
    self.indexes = np.arange(len(self.data))
    if self.shuffle == True:
      np.random.shuffle(self.indexes)
      
  def __data_generation(self, idx_temp):
    
    X = np.empty((self.batch_size, *self.dim, 1)) # 1 Channel
    y = np.empty((self.batch_size, *self.dim, 314)) # 313+1 channels
    
    for i, idx in enumerate(idx_temp):
      lab = np.asarray(rgb2lab(self.data[idx]))
      
      X[i,] = lab[:,:,:1]/100.
      y[i,] = transformY(lab[:,:,1:3], nearest, prior_factor, nb_q)
    
    return X, y
  
  def __len__(self):
    'Denotes the number of batches per epoch'
    return int(np.floor(len(self.data) / self.batch_size))
  
  def __getitem__(self, index):
    'Generate one batch of data'
    # Generate indexes of the batch
    indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

    # Find list of IDs
    list_IDs_temp = [self.indexes[k] for k in indexes]

    # Generate data
    X, y = self.__data_generation(list_IDs_temp)

    return X, y


Actual training using the CIFAR10 dataset using a train and validation set.

In [16]:
# Create nearest neighbord instance with index = q_ab
nearest = nn.NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(q_ab)
prior_factor = calcualte_factor(raw_train, q_ab, nearest)

training_generator = DataGenerator(data=raw_train, batch_size=batch_size, shuffle=False)
testing_generator = DataGenerator(data=raw_test,  batch_size=batch_size, shuffle=False)

import multiprocessing

workers = multiprocessing.cpu_count()
model.fit_generator(training_generator,
                    validation_data=testing_generator,
                    epochs=epochs,
                    use_multiprocessing=True,
                    workers=workers,
                    verbose=1)


generator initialzed with,  50000
generator initialzed with,  10000
Epoch 1/5
 2726/50000 [>.............................] - ETA: 27:51 - loss: 6.4047

Process ForkPoolWorker-4:
Process ForkPoolWorker-3:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process ForkPoolWorker-1:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/lib/

KeyboardInterrupt: ignored

Save the model to Google Drive. 

In [0]:
model.save("./drive/My Drive/shared/model_b{}_e{}.h5".format(batch_size, epochs))
model.save_weights("./drive/My Drive/shared/model_b{}_e{}_weights.h5".format(batch_size, epochs))

Alternatively, we can save the model to the current working environment using:

In [0]:
model.save("model.h5")

## Model validation

To perform validation , and due to some  of the hardcoded parameters of our model, it is necesary to run the Colorize_Complex.py script