In [1]:
################################################################################
# Import necessary libraries
################################################################################

from google.colab import drive
drive.mount('/content/drive/')

import os
import random
from keras import layers
from keras.backend import tf as ktf
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.models import load_model
import zipfile
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
import copy

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


Using TensorFlow backend.


In [0]:
################################################################################
# Generator and its functions
################################################################################

# Returns the scaled lightness channel of an image (h x w x 1)
def lab_img(img):
    """Convert a PIL image to L*a*b* and return its L channel divided by 100.

    Args:
        img: PIL image in any mode.

    Returns:
        Array of shape (height, width, 1) holding the L channel / 100.
    """
    # rgb2lab needs a three-channel RGB input. The file format says nothing
    # about the pixel layout (a JPEG can be grayscale or CMYK, and an image
    # produced by resize() carries no format at all), so decide on the mode.
    if img.mode != "RGB":
        img = img.convert("RGB")

    lab = rgb2lab(img)
    lightness = np.array(lab[:, :, 0]) / 100
    return np.expand_dims(lightness, axis=2)


# Returns one (lightness, features) pair, each with a leading batch axis of 1
def batch_images(image_path, imgs, trained_model):
    """Load one archive member and build the two model inputs for it.

    Args:
        image_path: Archive member to read (passed to imgs.open).
        imgs: Open archive object providing open() as a context manager.
        trained_model: Model forwarded to predict_classification.

    Returns:
        Tuple (x, s): the lab_img output and the predict_classification
        output for the same picture, each expanded with a leading axis.
    """
    with imgs.open(image_path) as raw_file:
        picture = Image.open(raw_file)
        width, height = picture.size

        # Round both sides up to the next multiple of 8; the feature map is
        # later resized to one eighth of each side.
        width += (-width) % 8
        height += (-height) % 8
        picture = picture.resize((width, height))

        # The feature extractor sees a grayscale picture replicated to RGB.
        # (predict_segmentation can be swapped in here instead.)
        gray_as_rgb = picture.convert("L").convert("RGB")
        features = predict_classification(gray_as_rgb, trained_model)
        lightness = lab_img(picture)

        x = np.expand_dims(lightness, axis=0)
        s = np.expand_dims(features, axis=0)
        return x, s


# Returns the segmentation output resized to 1/8 of the image (h/8 x w/8 x 150)
def predict_segmentation(img, trained_model):
    """Run trained_model on an image and shrink its output to 1/8 scale.

    Args:
        img: PIL image; img.size is read as (width, height).
        trained_model: Model exposing predict(); it is fed one 473 x 473
            mean-subtracted image with its channel order reversed.

    Returns:
        The first element of the model output, resized to
        (height // 8, width // 8, 150).
    """
    # TODO: do this in my model
    data_mean = np.array([[[123.68, 116.779, 103.939]]])
    input_size = (473, 473)
    # Floor division keeps the target shape integral on every Python version
    # and for sizes that are not multiples of 8 (true division yields floats).
    output_size = (img.size[0] // 8, img.size[1] // 8)

    if img.size != input_size:
        img = img.resize(input_size)

    pixel_img = np.array(img) - data_mean
    # Reverse the channel axis (RGB -> BGR) before prediction
    bgr_img = pixel_img[:, :, ::-1]
    segmented_img = trained_model.predict(np.expand_dims(bgr_img, 0))[0]
    if output_size != input_size:
        # output_size is (width, height); the array shape is (rows, cols, channels)
        segmented_img = resize(segmented_img,
                               (output_size[1], output_size[0], 150),
                               mode="constant",
                               preserve_range=True)
    return segmented_img
  
  
def predict_classification(img, trained_model):
    """Run trained_model on an image and resize its output to 1/8 scale.

    Args:
        img: PIL image; img.size is read as (width, height).
        trained_model: Model exposing predict(); it receives the image as a
            batch of one.

    Returns:
        The first element of the model output, resized to
        (height // 8, width // 8, 1280).
    """
    batch = np.expand_dims(np.array(img), axis=0)
    classification_img = trained_model.predict(batch)[0]

    width, height = img.size
    # Floor division keeps the target shape integral on every Python version
    # and for sizes that are not multiples of 8 (true division yields floats).
    # 1280 is the hard-coded channel count expected from the model output.
    target_shape = (height // 8, width // 8, 1280)
    return resize(classification_img,
                  target_shape,
                  mode="constant",
                  preserve_range=True)


# Yields (x, s) batches endlessly, cycling over the images of a zip archive
def generator_fn(n_images, images_path, trained_model):
    """Endlessly yield model inputs built from the images in a zip archive.

    The archive members are shuffled once; the first n_images of them (or all
    of them, if the archive holds fewer) are then visited in a cycle. The
    archive stays open for as long as the generator is alive.

    Args:
        n_images: Number of distinct images to cycle through.
        images_path: Path (or file object) of the zip archive.
        trained_model: Model forwarded to batch_images.

    Yields:
        The (x, s) tuple returned by batch_images for the current image.

    Raises:
        ValueError: On the first next() call, if n_images < 1 or the archive
            contains no files.
    """
    with zipfile.ZipFile(images_path) as imgs:
        # Directory entries carry no image data; keep regular members only.
        image_paths = [info for info in imgs.infolist()
                       if not info.filename.endswith("/")]

        # Never index past the end of the archive.
        cycle_length = min(n_images, len(image_paths))
        if cycle_length < 1:
            raise ValueError("generator_fn needs n_images >= 1 and an archive "
                             "with at least one file")

        random.shuffle(image_paths)
        i = 0
        while True:
            x, s = batch_images(image_paths[i], imgs, trained_model)
            i = (i + 1) % cycle_length
            yield x, s


In [0]:
################################################################################
# Loading, training and saving model
################################################################################

class Interp(layers.Layer):
    """Layer that resizes its input to a fixed (height, width).

    The target size is kept in new_size and is included in the layer config,
    so the layer can be passed through custom_objects when a saved model is
    loaded.
    """

    def __init__(self, new_size, **kwargs):
        # The target size is stored before the base initialiser runs.
        self.new_size = new_size
        super(Interp, self).__init__(**kwargs)

    def build(self, input_shape):
        # Nothing is created here; the base implementation does the work.
        super(Interp, self).build(input_shape)

    def call(self, inputs, **kwargs):
        target_height, target_width = self.new_size
        return ktf.image.resize_images(inputs,
                                       [target_height, target_width],
                                       align_corners=True)

    def compute_output_shape(self, input_shape):
        # Batch axis unknown, spatial axes replaced, channel axis kept.
        return (None, self.new_size[0], self.new_size[1], input_shape[3])

    def get_config(self):
        config = super(Interp, self).get_config()
        config.update({'new_size': self.new_size})
        return config


# Loads a saved model from disk and prepares it for prediction
def load_trained_model(path):
    """Load the model stored at path, resolving the custom Interp layer.

    Args:
        path: Location of the saved model file.

    Returns:
        The loaded model, after _make_predict_function() has been called.
    """
    custom_layers = {'Interp': Interp}
    model = load_model(path, custom_objects=custom_layers)
    # Private Keras hook, called once up front before the model is used.
    model._make_predict_function()
    return model
  

In [0]:

# Builds an L*a*b* image from its three channels, converts it and writes it
def _save_lab_as_jpg(shape, l, a, b, path):
    lab = np.zeros(shape + (3,))
    lab[:, :, 0] = l
    lab[:, :, 1] = a
    lab[:, :, 2] = b
    Image.fromarray((lab2rgb(lab) * 255).astype('uint8')).save(path)


# Takes one instance of input and output (no batches)
def decode_images(l, s, y, save_path):
    """Write three JPEGs visualising one colorization result.

    Files are named "<i>_input.jpg", "<i>_output.jpg" and
    "<i>_output_labels.jpg", where i is the first index whose input file
    does not exist yet in save_path:
      * input: the lightness channel alone (a and b set to 0),
      * output: lightness combined with the predicted a/b channels,
      * output_labels: the predicted a/b channels over a constant L of 70.

    Args:
        l: Array (h, w, 1) with the lightness channel scaled by 1/100.
        s: Unused; kept so existing callers keep working.
        y: Model output; y[0] is indexed as (h, w, channel) with a in
            channel 0 and b in channel 1.
        save_path: Destination folder; created if it does not exist.
    """
    h, w = l.shape[:2]
    ab = y[0]
    l = l[:, :, 0] * 100
    # NOTE(review): a and b are rescaled with different offsets (127 vs 128);
    # presumably this mirrors the scaling used for training - confirm.
    a = (ab[:, :, 0] + 1) * 255 / 2 - 127
    b = (ab[:, :, 1] + 1) * 255 / 2 - 128

    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    # index to save the images under
    i = 0
    while os.path.isfile(os.path.join(save_path, str(i) + "_input.jpg")):
        i += 1
    prefix = os.path.join(save_path, str(i))

    _save_lab_as_jpg((h, w), l, 0, 0, prefix + "_input.jpg")
    _save_lab_as_jpg((h, w), l, a, b, prefix + "_output.jpg")
    _save_lab_as_jpg((h, w), 70, a, b, prefix + "_output_labels.jpg")


# Colorizes n_images samples drawn from a generator and saves the results
def validate_images(n_images, model, generator_fn, save_path):
    """Predict on successive generator samples and write them to disk.

    Args:
        n_images: How many samples to pull from the generator.
        model: Model whose predict() receives the list [x, s].
        generator_fn: Generator object yielding (x, s) batches.
        save_path: Folder handed on to decode_images.
    """
    for sample_no in range(n_images):
        x_batch, s_batch = next(generator_fn)
        prediction = model.predict([x_batch, s_batch])
        print(sample_no)
        # Only the first element of each batch is decoded.
        decode_images(x_batch[0], s_batch[0], prediction, save_path)

In [5]:
################################################################################
# Control panel
################################################################################

n_images = 15

images_path = "drive/My Drive/Datasets/bw_test_dataset.zip"
images_destination = "drive/My Drive/Validation/16th_gen"

# Create the output folder from Python rather than with a shell `mkdir`:
# the path is then written in one place only, missing parent folders are
# created too, and re-running the cell does not fail on an existing folder.
if not os.path.isdir(images_destination):
    os.makedirs(images_destination)

# Alternative feature model (pairs with predict_segmentation):
# trained_model = load_trained_model("drive/My Drive/Checkpoints/pspnet.h5")
trained_model = MobileNetV2(include_top=False)
colorization_model = load_model("drive/My Drive/Checkpoints/Models/16th_gen.hdf5")
validation_data_fn = generator_fn(n_images, images_path, trained_model)

validate_images(n_images, colorization_model, validation_data_fn, images_destination)




Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
