In [1]:
# Sources
# --------
# Emil Wallner blog post and code on Floydhub for image colorization
# "Deep Koalarization": https://github.com/baldassarreFe/deep-koalarization
# "Colorful Image Colorization": https://github.com/richzhang/colorization
# TODO: Remove this line once Keras is installed, then restart the kernel: Menu > Kernel > Restart & Clear Output
# !git clone https://github.com/fchollet/keras.git && cd keras && python setup.py install --user

In [2]:
import math
import numpy as np
import os
import random
import sys
import tensorflow as tf

from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave

import matplotlib
# This line was for some of our machines; some edge case
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pickle

import keras
# import keras.backend as K
from keras.layers.core import RepeatVector
# Not sure we need all of these
from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape, merge, concatenate, Activation, Dense, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.losses import categorical_crossentropy
from keras.metrics import top_k_categorical_accuracy
from keras.models import Model
from keras.preprocessing.image import array_to_img, img_to_array, load_img
# from keras.preprocessing.image import ImageDataGenerator
from keras.utils.np_utils import to_categorical

from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_resnet_v2 import preprocess_input

Using TensorFlow backend.


In [3]:
# Paths and utilities
# -------------------
# Directory paths keep their trailing slash because later cells build file
# paths by plain string concatenation.
data_dir = "../data/"
train_dir, test_dir = data_dir + "train/", data_dir + "test/"
results_dir = "../results/"

# Checkpoints are written as <model_chkpt><model_name>.{json,h5}.
model_chkpt, model_name = "model_params/", "rainbownet"

# Adjustable constants
BATCH_SIZE = 16
N_EPOCHS = 1
NUM_BUCKETS = 259   # expected bucket count; the bucket-loading cell re-syncs it
GRID_SIZE = 10      # side length of each square bucket in the ab plane

# Python 2/3 compatibility: on Python 3 alias xrange to the built-in range.
if sys.version_info[0] == 3:
    xrange = range

## Helper Functions
Utility functions used by the cells below.

In [4]:
def resize_images(images, shape):
  """ Resize every image in |images| to |shape|.

      |images|: an iterable of images, e.g. an (m, H, W, 3) array.
      |shape|: a 3-tuple of the dimensions to resize to,
               (H, W, 3) usually.

      :return: np.array of the m resized images, shape (m,) + shape.
  """
  # BUGFIX: the loop used to bind `i` but resize the undefined name `im`.
  resized = [resize(im, shape, mode='constant') for im in images]
  if not resized:
    # Keep the documented (m,) + shape layout even when m == 0.
    return np.empty((0,) + tuple(shape))
  return np.array(resized)

## **The Inception ResNet Fusion**

We're going to be using Google's Inception ResNet v2 to augment our CNN. This model has been highly trained on millions of examples, and can detect features that would be useful in colorizing our test images. The fusion of IRNet features with our CNN will come later in RainbowNetModel(). For now, these functions load the model and set up predictions.



In [5]:
def load_inception_net():
  """ Build InceptionResNetV2 (with its classification top) and load the
      weights stored at <data_dir>inception_resnet_v2_weights.h5.

      :return: the Keras model, with the TF graph it was built in attached
               as its `.graph` attribute.
  """
  net = InceptionResNetV2(weights=None, include_top=True)
  net.load_weights(data_dir + 'inception_resnet_v2_weights.h5')
  # Record the building graph so later predictions can run inside it.
  net.graph = tf.get_default_graph()
  return net

In [6]:
def create_inception_embedding(grayscaled_rgb, model=None):
  """ One forward pass through InceptionResNetV2 for each image, giving
      every image a 1000-vector embedding of its features.

      |grayscaled_rgb|: the images to embed, grayscale replicated into 3
                        channels, (m, H, W, 3).
      |model|: optional Keras model to embed with. Defaults to the global
               `inception`, which is created with load_inception_net() on
               first use if it does not exist yet.

      :return: (m, 1000) array of embeddings.
  """
  global inception
  if model is None:
    # No cell calls load_inception_net() before the first embedding is
    # requested, so the old direct reference to `inception` raised
    # NameError. Load it lazily and keep it in the global for later calls.
    if globals().get('inception') is None:
      inception = load_inception_net()
    model = inception

  # Resize all the images to (299, 299, 3).
  grayscaled_rgb_resized = resize_images(grayscaled_rgb, (299, 299, 3))

  # Preprocessing for the Inception ResNet.
  grayscaled_rgb_resized = preprocess_input(grayscaled_rgb_resized)

  # Predict inside the graph the model was built in (load_inception_net()
  # records it as `.graph`); otherwise use the current default graph.
  graph = getattr(model, 'graph', None)
  if graph is None:
    graph = tf.get_default_graph()
  with graph.as_default():
    embed = model.predict(grayscaled_rgb_resized)

  return embed

## Finding our color buckets
This is how we discretized the *ab* channels' color space into buckets for our model.


In [7]:
# Python 3 alias so this cell also runs on its own.
xrange = range

def discretize_lab():
    """Split the ab plane [-128, 128) x [-128, 128) into square buckets.

    Each bucket is a GRID_SIZE x GRID_SIZE cell, identified by its centre.
    Ids are assigned with `a` as the outer loop and `b` as the inner loop.

    Returns:
        (bucket2ab_map, ab2bucket_map): bucket id -> (a_mid, b_mid), and the
        inverse map (a_mid, b_mid) -> bucket id.
    """
    print("Building original mapping")
    half_cell = GRID_SIZE / 2
    cell_starts = list(xrange(-128, 128, GRID_SIZE))
    centres = [(float(a + half_cell), float(b + half_cell))
               for a in cell_starts
               for b in cell_starts]
    bucket2ab_map = dict(enumerate(centres))  # index = bucket #, value = (a, b)

    ab2bucket_map = {centre: idx for idx, centre in bucket2ab_map.items()}
    # Sanity check: the two maps must be exact inverses (no duplicate centres).
    for idx, centre in bucket2ab_map.items():
        assert ab2bucket_map[centre] == idx

    return bucket2ab_map, ab2bucket_map

def get_bucket(lab_pixel, ab2bucket_map, grid_size=None):
    """Return the id of the bucket whose cell contains a Lab pixel's (a, b).

    Args:
        lab_pixel: indexable (L, a, b) triple; only indices 1 and 2 are used.
        ab2bucket_map: map from bucket centre (a_mid, b_mid) to bucket id.
        grid_size: side length of each bucket; defaults to GRID_SIZE.

    Raises:
        AssertionError: if the containing cell is not in ab2bucket_map.
    """
    if grid_size is None:
        grid_size = GRID_SIZE

    # BUGFIX: floor instead of int(). int() truncates toward zero, which put
    # e.g. a = -8.5 into the [-8, 2) cell instead of the [-18, -8) cell.
    a = int(math.floor(lab_pixel[1]))
    b = int(math.floor(lab_pixel[2]))

    # Snap to the centre of the cell; cells start at -128, so with
    # grid_size=10, -128 through -119 maps to -123.
    a_mid = grid_size * ((a + 128) // grid_size) - 128 + grid_size / 2
    b_mid = grid_size * ((b + 128) // grid_size) - 128 + grid_size / 2

    assert (a_mid, b_mid) in ab2bucket_map, (
        "ab value (%s, %s) falls in unknown bucket (%s, %s)"
        % (lab_pixel[1], lab_pixel[2], a_mid, b_mid))
    return ab2bucket_map[(a_mid, b_mid)]

def find_nonzero_buckets(bucket2ab_map, ab2bucket_map):
    """Keep only the ab buckets that some sRGB color actually falls into.

    Samples the RGB cube (every 3rd red, 2nd green and 3rd blue value),
    converts the samples to Lab, and records which buckets they land in.

    Args:
        bucket2ab_map: bucket id -> (a_mid, b_mid) for the full grid.
        ab2bucket_map: its inverse, used to look samples up via get_bucket().

    Returns:
        (new_bucket2ab_map, new_ab2bucket_map) restricted to the buckets that
        were hit and renumbered 0..K-1 (in original id order). The ids can
        then be used directly as class indices.
    """
    print("Getting bucket counts")
    # Lay the samples out as an (N, 1, 3) "image" so rgb2lab converts them
    # in a single call.
    samples = np.array([[[r / 255., g / 255., b / 255.]]
                        for r in xrange(0, 256, 3)
                        for g in xrange(0, 256, 2)
                        for b in xrange(0, 256, 3)])
    lab_pixels = rgb2lab(samples)
    hit = set(get_bucket(lab_pixel, ab2bucket_map)
              for lab_pixel in lab_pixels.reshape(-1, 3))

    print("Building updated mapping")
    # BUGFIX: the old code kept the sparse ids of the full grid (up to
    # len(bucket2ab_map) - 1). Those ids overflow a one-hot encoding whose
    # depth is the number of surviving buckets.
    kept = [bucket2ab_map[k] for k in sorted(bucket2ab_map) if k in hit]
    new_bucket2ab_map = dict(enumerate(kept))
    new_ab2bucket_map = {ab_pair: bucket
                         for bucket, ab_pair in new_bucket2ab_map.items()}

    print("There are %d buckets." % len(new_bucket2ab_map))
    return new_bucket2ab_map, new_ab2bucket_map

def plot_mapping(bucket2ab_map):
    """Draw each bucket as a GRID_SIZE x GRID_SIZE square in the ab plane.

    Each square is filled with its centre color rendered at lightness L=50.
    The horizontal axis is b and the vertical axis is -a.
    """
    fig, ax = plt.subplots(subplot_kw={'aspect': 'equal'})
    for a_mid, b_mid in bucket2ab_map.values():
        a_low = a_mid - GRID_SIZE / 2
        b_low = b_mid - GRID_SIZE / 2

        lab_pixel = np.array([50., a_mid, b_mid]).reshape((1, 1, 3))
        rgb = tuple(lab2rgb(lab_pixel)[0][0])

        # The cell spans a in [a_low, a_low + GRID_SIZE), i.e. y = -a in
        # (-a_low - GRID_SIZE, -a_low]. The lower-left corner therefore sits at
        # y = -a_low - GRID_SIZE. Anchoring at -a_low drew every square one
        # cell too high.
        ax.add_patch(
            patches.Rectangle(
                (b_low, -a_low - GRID_SIZE),
                GRID_SIZE,
                GRID_SIZE,
                facecolor=rgb
            )
        )

    ax.set_xlim([-128, 128])
    ax.set_ylim([-128, 128])
    ax.set_xlabel('b')
    ax.set_ylabel('-a')
    ax.set_title('ab color buckets (L = 50)')
    plt.show()

def get_buckets():
    """Build the full ab grid, then keep only the buckets sRGB colors reach.

    Returns the (bucket2ab_map, ab2bucket_map) pair produced by
    find_nonzero_buckets().
    """
    full_grid_maps = discretize_lab()
    return find_nonzero_buckets(*full_grid_maps)

In [8]:
"""
Loading our bucket maps.
"""

# To open: `map = pickle.load(open(<filename>, 'r'))`
BUCKET2AB = 'bucket2ab_map.pkl'
AB2BUCKET = 'ab2bucket_map.pkl'

bucket2ab_map = None
ab2bucket_map = None

if not os.path.isfile(BUCKET2AB) or not os.path.isfile(AB2BUCKET):
  bucket2ab_map, ab2bucket_map = get_buckets()

  plot_demo = False  
  if plot_demo:
    plot_mapping(bucket2ab_map)
  
  pickle.dump(bucket2ab_map, open(BUCKET2AB, 'wb'))
  pickle.dump(ab2bucket_map, open(AB2BUCKET, 'wb'))
  
else:
  print("Loading buckets from pickle.")
  bucket2ab_map = pickle.load(open(BUCKET2AB, 'rb'))
  ab2bucket_map = pickle.load(open(AB2BUCKET, 'rb'))
  
if NUM_BUCKETS != len(bucket2ab_map):
  print("NUM_BUCKETS= %s does not match the number of buckets found."
        % str(NUM_BUCKETS))
  print("Setting NUM_BUCKETS=%d" % len(bucket2ab_map))
  NUM_BUCKETS = len(bucket2ab_map)

print("Buckets loaded successfully!")

# bucket2ab_map: map of bucket # to (a, b)
# ab2bucket_map: map of (a, b) to bucket #
# use get_bucket(lab_pixel, ab2bucketmap) to get corresponding (a, b)

Loading buckets from pickle.
Buckets loaded successfully!


In [9]:
def discretize(images_ab, ab2bucket=None, grid_size=None):
  """ Convert ab color channels into one color-bucket id per pixel.

      This is a preprocessing step. It turns Y (the true ab channels) into
      buckets so that colorization_loss() can treat colorization as
      classification.

      |images_ab|: (m, H, W, 2) array representing ab channels of images.
      |ab2bucket|: map from bucket centre (a_mid, b_mid) to bucket id;
                   defaults to the global ab2bucket_map.
      |grid_size|: side length of each ab bucket; defaults to GRID_SIZE.

      :return: (m, H, W) integer array of bucket ids.
      :raises KeyError: if a pixel falls in a bucket missing from the map.
  """
  if ab2bucket is None:
    ab2bucket = ab2bucket_map
  if grid_size is None:
    grid_size = GRID_SIZE

  images_ab = np.asarray(images_ab, dtype=float)
  m, H, W, _ = images_ab.shape

  # BUGFIX: the map is keyed by bucket centres, not raw continuous ab
  # values. Snap every value to the centre of its cell first, using the same
  # rule as get_bucket(): floor, then cells of width grid_size from -128.
  cells = (np.floor(images_ab) + 128) // grid_size
  centres = grid_size * cells - 128 + grid_size / 2.

  images_d = np.zeros((m, H, W), dtype=int)
  for idx in np.ndindex(m, H, W):
    a_mid, b_mid = centres[idx]
    images_d[idx] = ab2bucket[(float(a_mid), float(b_mid))]

  return images_d

In [10]:
# This function will be called when we want to demo a predicted image.
# It takes an image with pixels labeled with buckets, and transforms it
#   into 2 color channels.

def inverse_discretize(images_d):
  """ Undo discretize(): replace each pixel's color bucket with that
      bucket's (a, b) midpoint from bucket2ab_map.

      |images_d|: (m, H, W) array holding the bucket id of every pixel.

      :return: (m, H, W, 2) float array of ab color values.
  """
  m, H, W = images_d.shape
  ab_out = np.zeros((m, H, W, 2))

  for pixel in np.ndindex(m, H, W):
    ab_out[pixel] = bucket2ab_map[images_d[pixel]]  # [a, b] of the bucket

  return ab_out
  

## Getting the data
```load_data()``` loads the raw images, for both training and testing.

```preprocess_data()``` then processes them for input into our CNN. This involves converting to *Lab* and separating the channels, running through InceptionResnet, and normalizing the *L* channel input.

In [11]:
# Examples of use:
# load_data(train_dir)
# load_data(test_dir)

def load_data(directory):
  """ Load an entire set of |m| example images from |directory|.

      Files are read in sorted filename order so runs are reproducible.
      Hidden files (e.g. .DS_Store) and sub-directories are skipped. All
      images must have the same size to stack into one array.

      If loading the entire dataset takes too much memory, it may have to
      run in batches: train, save a checkpoint, put new examples in the
      directory, repeat.

      :return: float np.array of the stacked images (m, H, W, channels).
  """
  images = []
  for filename in sorted(os.listdir(directory)):
    # os.path.join works whether or not |directory| has a trailing slash.
    path = os.path.join(directory, filename)
    if filename.startswith('.') or not os.path.isfile(path):
      continue
    image = load_img(path)                # PIL image
    images.append(img_to_array(image))    # np.array

  return np.array(images, dtype=float)
  

In [12]:
def preprocess_data(images):
  """ Preprocess raw RGB images for input into our rainbownet model.

      |images|: (m, H, W, 3) RGB images as returned by load_data().
                NOTE(review): assumed to be on the 0-255 scale produced by
                img_to_array — confirm.

      :return: (X, Y), where
               X = [L, embs]. L is the (m, 224, 224, 1) lightness channel
                   scaled to [0, 1]; embs is the (m, 1000) Inception
                   embedding. The order is [L channel, embedding].
               Y = (m, 224, 224, 1) array of the true color bucket of each
                   pixel.
  """
  # Resize them to (224, 224).
  images_resized = resize_images(images, (224, 224, 3))

  # Get the inception embeddings from grayscaled images, replicated to 3
  # channels. rgb2gray already works on whole images. The old np.vectorize
  # wrappers broke it, and they were used before they were assigned
  # (UnboundLocalError).
  images_gray = np.array([rgb2gray(im) for im in images_resized])
  grayscaled_images = np.repeat(images_gray[..., np.newaxis], 3, axis=-1)
  embs = create_inception_embedding(grayscaled_images)

  # Convert the *resized* images so L and Y match the 224x224 model input.
  # rgb2lab is fed RGB in [0, 1], the same convention find_nonzero_buckets()
  # uses.
  images_lab = np.array([rgb2lab(im / 255.) for im in images_resized])
  images_l = images_lab[..., 0:1] / 100.   # L* lies in [0, 100]; scale to [0, 1]
  images_ab = images_lab[..., 1:]          # last 2 channels: a and b

  # Keras expects a list of input arrays (a zip object is an iterator in
  # Python 3 and pairs up per-image rows instead).
  X = [images_l, embs]

  # Discretize the ab image. Keep a trailing singleton axis so the targets
  # have the same rank as the model's (m, H, W, buckets) output.
  Y = discretize(images_ab)[..., np.newaxis]

  return X, Y

## Defining our loss
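Our colorization loss is the softmax cross-entropy between the multinomial color distributions of every pixel in ```y_true``` and ```y_pred```, taken over all pixels (H, W) and over all images in the minibatch. We follow the *Colorful Image Colorization* paper in treating colorization as per-pixel classification. This avoids a known problem with regression losses like MSE, which average plausible colors into desaturated results.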

TODO: implement color weighting, like in Richard Zhang et al.



In [13]:
def colorization_loss(y_true, y_pred):
  """ Per-image mean of the per-pixel cross-entropy between true and
      predicted color buckets.

      |y_true|: Our true colors: bucket ids of shape (batch_size, H, W), or
                (batch_size, H, W, 1) with a trailing singleton axis.
      |y_pred|: A (batch_size, H, W, NUM_BUCKETS) volume whose last
                dimension is a softmax over bucket probabilities.

      :return: (batch_size,) tensor: each image's cross-entropy averaged
               over its pixels.
  """
  batch_size = tf.shape(y_pred)[0]
  num_buckets = tf.shape(y_pred)[-1]

  # One row per pixel: labels (batch, H*W) and probabilities
  # (batch, H*W, num_buckets). The old flatten() folded the bucket axis
  # into the pixel axis, so labels and predictions no longer lined up.
  labels = tf.reshape(tf.cast(y_true, tf.int32), [batch_size, -1])
  probs = tf.reshape(y_pred, [batch_size, -1, num_buckets])

  # tf.one_hot works on symbolic tensors. Keras' to_categorical is a NumPy
  # utility and cannot be applied to them.
  one_hot = tf.one_hot(labels, num_buckets, dtype=y_pred.dtype)

  per_pixel = categorical_crossentropy(one_hot, probs)   # (batch, H*W)
  # Average over pixels so the loss has one value per image.
  return tf.reduce_mean(per_pixel, axis=-1)

## Defining our model
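We plan to initialize the first four layers of our CNN with VGG16's pretrained layers and freeze them, so that this transfer learning lets us reuse pretrained low-level features. **TODO:** this is not implemented yet. ```RainbowNetModel()``` currently trains every layer from scratch; its only pretrained features come from the Inception ResNet embedding.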

In [14]:
def RainbowNetModel(num_buckets=None):
  """ Build the RainbowNet encoder / fusion / decoder colorization network.

      Inputs:  [L channel (None, 224, 224, 1), Inception embedding (None, 1000)]
      Output:  (None, 224, 224, num_buckets) per-pixel softmax over the color
               buckets, the volume colorization_loss() consumes.

      |num_buckets|: number of color classes; defaults to NUM_BUCKETS.
  """
  # The architecture used to sit inside a string literal, which left
  # `decoder_output` undefined. It is live code now.
  if num_buckets is None:
    num_buckets = NUM_BUCKETS

  embed_input = Input(shape=(1000,))
  encoder_input = Input(shape=(224, 224, 1,))

  # Encoder: three stride-2 convolutions take 224x224 down to 28x28.
  encoder_output = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_input)
  encoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
  encoder_output = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
  encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
  encoder_output = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
  encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
  encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
  encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)

  # Fusion: tile the embedding over every encoder location (28x28, not 32x32)
  # and merge it with the encoder features.
  fusion_output = RepeatVector(28 * 28)(embed_input)
  fusion_output = Reshape((28, 28, 1000))(fusion_output)
  fusion_output = concatenate([encoder_output, fusion_output], axis=3)
  fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)

  # Decoder: 28 -> 56 -> 112 -> 224.
  decoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(fusion_output)
  decoder_output = UpSampling2D((2, 2))(decoder_output)
  decoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_output)
  decoder_output = UpSampling2D((2, 2))(decoder_output)
  decoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(decoder_output)
  decoder_output = Conv2D(16, (3,3), activation='relu', padding='same')(decoder_output)
  # Per-pixel bucket distribution, matching colorization_loss(). This replaces
  # the 2-channel tanh ab regression head.
  decoder_output = Conv2D(num_buckets, (1, 1), activation='softmax', padding='same')(decoder_output)
  # Nearest-neighbour upsampling copies each probability vector, so every
  # output pixel still sums to 1.
  decoder_output = UpSampling2D((2, 2))(decoder_output)

  # Input order matches X = [L channel, embedding].
  model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)
  return model


In [15]:
def save_model(model, directory=None, name=None):
  """ Save |model| as <directory>/<name>.json (architecture) and
      <directory>/<name>.h5 (weights).

      |directory|: checkpoint directory, created if missing; defaults to
                   model_chkpt.
      |name|: file stem; defaults to model_name.
  """
  directory = model_chkpt if directory is None else directory
  name = model_name if name is None else name

  # The checkpoint directory may not exist yet (e.g. a fresh checkout), and
  # open()/save_weights() would fail on it.
  if directory and not os.path.isdir(directory):
    os.makedirs(directory)

  base_path = os.path.join(directory, name)
  with open(base_path + ".json", "w") as json_file:
    json_file.write(model.to_json())
  model.save_weights(base_path + ".h5")

In [17]:
def load_existing_model(directory=None, name=None):
  """ Rebuild RainbowNet and load the weights stored at <directory>/<name>.h5.

      |directory|: checkpoint directory; defaults to model_chkpt.
      |name|: file stem; defaults to model_name.

      :return: the rebuilt (uncompiled) model with its weights loaded.
      :raises IOError: if the weights file does not exist.
  """
  directory = model_chkpt if directory is None else directory
  name = model_name if name is None else name
  weights_path = os.path.join(directory, name + ".h5")

  if not os.path.isfile(weights_path):
    # Raise instead of calling quit(), which would end the whole
    # interpreter session rather than just this cell.
    raise IOError("The model at path %s was not found." % weights_path)

  model = RainbowNetModel()
  # BUGFIX: load into the model built here, not the global `rainbowModel`.
  model.load_weights(weights_path)
  return model

## Time to train the model!
Now we can run our model. It will save its parameters after every training session.

If you're looking to only predict, run the cell after this.

In [18]:
images_train = load_data(train_dir)
X_train, Y_train = preprocess_data(images_train)

images_test = load_data(test_dir)
X_test, Y_test = preprocess_data(images_test)

# Run the model!
rainbowModel = RainbowNetModel()

# No extra metric: top_k_categorical_accuracy expects one-hot targets and
# 2-D predictions. Here Y holds one bucket id per pixel, and the model
# predicts a bucket distribution per pixel.
rainbowModel.compile(optimizer='adam',
                     loss=colorization_loss)

rainbowModel.fit(x=X_train, y=Y_train,
                 epochs=N_EPOCHS,
                 batch_size=BATCH_SIZE)

save_model(rainbowModel)

# Score the trained model on the test set. evaluate() returns the test loss,
# not predictions; use rainbowModel.predict(X_test) for per-pixel bucket
# probabilities.
test_loss = rainbowModel.evaluate(X_test, Y_test, batch_size=BATCH_SIZE)


FileNotFoundError: [Errno 2] No such file or directory: '../data/train/'

## Predict using the model
If you don't want to train, but only use the saved model to predict something, run this cell.

In [None]:
# load_data() returns only the raw image array; preprocess_data() builds the
# (X, Y) pair the model expects.
X_test, Y_test = preprocess_data(load_data(test_dir))

rainbowModel = load_existing_model()
# A freshly rebuilt model must be compiled before evaluate() can run.
rainbowModel.compile(optimizer='adam', loss=colorization_loss)

# evaluate() returns the test loss; use rainbowModel.predict(X_test) to get
# per-pixel bucket probabilities.
test_loss = rainbowModel.evaluate(X_test, Y_test, batch_size=BATCH_SIZE)
