In [None]:
import keras
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Layer, BatchNormalization
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input
from tensorflow.keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose, Input, Reshape
from keras.layers import merge, concatenate
from tensorflow.keras.layers import Activation, Dense, Dropout, Flatten
from tensorflow.keras.callbacks import TensorBoard 
from tensorflow.keras.models import Sequential, Model
from keras.layers.core import RepeatVector, Permute
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
from keras.callbacks import *
import numpy as np
import os
import seaborn as sns
import random
import tensorflow as tf
import zipfile
import matplotlib.pyplot as plt
from tqdm import tqdm
print("Libriaries loaded")

Libriaries loaded


In [None]:
# Unpack the dataset archive from Google Drive into the relative "input/"
# directory. The context manager closes the archive handle once extraction
# finishes (the original left it open).
# NOTE(review): later cells read from "/content/input/", so this presumably
# relies on the working directory being /content — confirm.
with zipfile.ZipFile("/content/drive/MyDrive/buildings_and_human_dataset.zip", 'r') as photos:
    photos.extractall("input/")

In [None]:
def create_inception_embedding(grayscaled_rgb):
    """Run a batch of images through the module-level ``inception`` model.

    Every image in ``grayscaled_rgb`` is resized to 299x299x3, the resized
    images are stacked into one array, passed through ``preprocess_input``
    and finally through ``inception.predict``.

    Args:
        grayscaled_rgb: iterable of image arrays (callers in this file pass
            gray images replicated to three channels).

    Returns:
        The result of ``inception.predict`` for the whole batch.
    """
    resized_batch = np.array(
        [resize(single_image, (299, 299, 3), mode='constant') for single_image in grayscaled_rgb]
    )
    # NOTE(review): the callers in this file pass images scaled to [0, 1];
    # confirm that this is the range preprocess_input expects. Left unchanged
    # because previously saved weights may have been trained with it.
    return inception.predict(preprocess_input(resized_batch))

    
def load_train_data(path):
    """Load every image file in ``path`` into one float array scaled by 1/255.

    Args:
        path: directory containing the training images. A trailing path
            separator is optional.

    Returns:
        A float numpy array with one entry per image, multiplied by 1/255
        (i.e. in [0, 1] assuming 8-bit source images).
        NOTE(review): stacking presumably requires all images to share the
        same dimensions — confirm against the dataset.
    """
    # Keep regular files only: directory entries such as ".ipynb_checkpoints"
    # would make load_img fail. Sorting makes the load order deterministic.
    file_names = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    train_data = np.array(
        [img_to_array(load_img(os.path.join(path, name))) for name in file_names],
        dtype=float,
    )
    print("Train data is loaded.")
    print(len(train_data), "img loaded.")
    train_data = 1.0 / 255 * train_data
    return train_data


def load_test_data(path):
    """Load the validation images and build the model inputs and targets.

    Args:
        path: directory containing the validation images. A trailing path
            separator is optional.

    Returns:
        A tuple ``(X_test, X_embed, Y_test)``:
        * ``X_test``  - first Lab channel with a trailing channel axis,
        * ``X_embed`` - output of ``create_inception_embedding`` for the
          gray version of the images,
        * ``Y_test``  - remaining two Lab channels divided by 128.
    """
    # Keep regular files only and sort them so the order is deterministic.
    file_names = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    test_data = np.array(
        [img_to_array(load_img(os.path.join(path, name))) for name in file_names],
        dtype=float,
    )
    # Scale to [0, 1] once, before any colour conversion. The original passed
    # the raw 0-255 values to rgb2lab, while the training generator and
    # colorize_part both convert data already scaled by 1/255, so validation
    # inputs and targets did not match the training distribution.
    test_data = 1.0 / 255 * test_data

    gray_me = gray2rgb(rgb2gray(test_data))
    X_embed = create_inception_embedding(gray_me)

    # Convert to Lab a single time and slice both inputs and targets from it.
    lab_data = rgb2lab(test_data)
    X_test = lab_data[:, :, :, 0]
    X_test = X_test.reshape(X_test.shape + (1,))
    Y_test = lab_data[:, :, :, 1:] / 128
    return X_test, X_embed, Y_test


In [None]:
# Inception-ResNet-v2 with ImageNet weights and the top (classifier) layers
# kept (include_top=True). create_inception_embedding() calls predict() on
# this module-level object.
inception = InceptionResNetV2(weights='imagenet', include_top=True)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_resnet_v2/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5


In [None]:
# Root of the extracted dataset; "train/" and "test/" sub-directories are
# loaded from it.
path = "/content/input/"
# Training images, scaled by 1/255 inside load_train_data.
train_data = load_train_data(path + "train/")
# Validation inputs (first Lab channel + inception output) and targets.
X_test, X_embed, Y_test = load_test_data(path + "test/")

In [None]:
# Second model input: a 1000-element vector per image (fed with the output of
# create_inception_embedding elsewhere in this notebook).
embed_input = Input(shape=(1000,))

# Encoder: single-channel input, three stride-2 convolutions, so with
# padding='same' the spatial size is reduced by a factor of 8 overall.
# NOTE(review): height/width are declared as None, but the fusion step below
# hard-codes a 32x32 grid, which appears to require 256x256 inputs — confirm.
encoder_input = Input(shape=(None, None, 1,))
encoder_output = Conv2D(64, (3,3), activation='relu', padding='same', strides=2)(encoder_input)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(128, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same', strides=2)(encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(512, (3,3), activation='relu', padding='same')(encoder_output)
encoder_output = Conv2D(256, (3,3), activation='relu', padding='same')(encoder_output)

# Fusion: repeat the 1000-element vector 32*32 times, reshape it to a
# 32x32x1000 grid, concatenate it with the encoder features along the channel
# axis and mix the result down to 256 channels with a 1x1 convolution.
fusion_output = RepeatVector(32 * 32)(embed_input) 
fusion_output = Reshape(([32, 32, 1000]))(fusion_output)
fusion_output = concatenate([encoder_output, fusion_output], axis=3) 
fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output) 

# Decoder: three 2x up-samplings undo the encoder's 8x reduction. The last
# convolution produces 2 channels with tanh, i.e. values in [-1, 1]; it is
# applied before the final up-sampling.
decoder_output = Conv2D(128, (3,3), activation='relu', padding='same')(fusion_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(64, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(32, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(16, (3,3), activation='relu', padding='same')(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)

# Two-input model: [single-channel image, 1000-element vector] -> 2 channels.
model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

In [None]:
class CustomSaver(Callback):
    """Keras callback that saves the model weights after every 10th epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Save a weight checkpoint on epochs 10, 20, 30, ...

        Args:
            epoch: zero-based epoch index supplied by Keras.
            logs: metrics dictionary supplied by Keras; unused here. Defaults
                to None rather than a shared mutable ``{}``.
        """
        if (epoch + 1) % 10 == 0:
            weights_dir = "/content/drive/MyDrive/model_weights/human_and_buildings"
            # Make sure the target directory exists so that the first
            # checkpoint does not abort a long training run.
            os.makedirs(weights_dir, exist_ok=True)
            self.model.save_weights(f"{weights_dir}/model_v2_{epoch + 1}.h5")
        # NOTE(review): clear_session() is called after every epoch, as in the
        # original; presumably to limit memory growth — confirm it is safe to
        # call while training is still in progress.
        tf.keras.backend.clear_session()

In [None]:
# Load previously saved weights into the model built above.
# NOTE(review): this file sits directly under MyDrive, whereas CustomSaver
# writes to MyDrive/model_weights/human_and_buildings/ — confirm the path.
model.load_weights("/content/drive/MyDrive/model_v2_138.h5")

In [None]:
# Augmentation applied to the training images: random shear, zoom, rotation
# and horizontal flips.
datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True)

# Training hyper-parameters used by the fit cell.
batch_size = 40
epochs_size = 200

def image_a_b_gen(batch_size):
    """Yield training batches built from augmented ``train_data``.

    For every augmented RGB batch produced by ``datagen.flow`` the generator
    yields ``([lightness, embedding], ab_target)`` where

    * ``lightness`` is the first Lab channel with a trailing channel axis,
    * ``embedding`` is ``create_inception_embedding`` applied to the gray
      version of the batch (replicated to three channels),
    * ``ab_target`` is the remaining two Lab channels divided by 128.

    Args:
        batch_size: number of images requested from ``datagen.flow`` per batch.
    """
    for rgb_batch in datagen.flow(train_data, batch_size=batch_size):
        lab = rgb2lab(rgb_batch)
        lightness = lab[:, :, :, 0]
        lightness = lightness.reshape(lightness.shape + (1,))
        ab_target = lab[:, :, :, 1:] / 128
        embedding = create_inception_embedding(gray2rgb(rgb2gray(rgb_batch)))
        yield ([lightness, embedding], ab_target)

In [None]:



# Callbacks: periodic weight checkpoints plus a History object that the
# plotting cell reads afterwards.
saver = CustomSaver()
history = History()
# NOTE(review): "accuracy" is tracked although the loss is MSE on continuous
# targets; it is kept because the plotting cell reads it from the history.
model.compile(optimizer='rmsprop', loss='mse', metrics = ["accuracy"])
# Model.fit accepts Python generators directly; Model.fit_generator is
# deprecated and no longer available in recent Keras releases.
model.fit(image_a_b_gen(batch_size),
          epochs=epochs_size,
          callbacks=[saver, history],
          steps_per_epoch=len(train_data) // batch_size,
          validation_data=([X_test, X_embed], Y_test))

In [None]:
# Save the final weights. The original f-string contained a stray "}" after
# the placeholder, which is a SyntaxError ("single '}' is not allowed").
model.save_weights(f"/content/drive/MyDrive/model_weights/human_and_buildings/model_v2{epochs_size}.h5")

In [None]:
# Plot the per-epoch training and validation "accuracy" values recorded by
# the History callback during training.
plt.figure(figsize=(14,6))
sns.lineplot(data = {'accuracy':history.history['accuracy'], 'val_accuracy' : history.history['val_accuracy']})

In [None]:
def colorize_part(img, path, shape = (256, 256)):
    """Colorize a single image tile with the trained model.

    Args:
        img: image tile accepted by ``img_to_array``; its height and width
            must match ``shape`` for the reshape below to succeed.
        path: unused; kept so existing callers keep working.
        shape: (height, width) of the tile.

    Returns:
        The array returned by ``lab2rgb`` for a Lab image whose first channel
        is taken from the input and whose other two channels are the model
        prediction multiplied by 128.
    """
    pixels = img_to_array(img)
    batch = np.array([pixels], dtype=float)

    # Model input 2: inception output for the gray, three-channel tile.
    embedding = create_inception_embedding(gray2rgb(rgb2gray(1.0/255*batch)))

    # Model input 1: first Lab channel as a (1, height, width, 1) batch.
    lightness = rgb2lab(1.0/255*pixels)[:, :, 0].reshape((1, *shape, 1))

    predicted_ab = model.predict([lightness, embedding])
    predicted_ab *= 128  # undo the /128 scaling applied to training targets

    lab = np.zeros((*shape, 3))
    lab[:, :, 0] = lightness[0][:, :, 0]
    lab[:, :, 1:] = predicted_ab[0]
    return lab2rgb(lab)

In [None]:
from PIL import Image
from PIL import ImageFile


def split_image(img, stride=256):
    """Cut an image into 256x256 crops taken every ``stride`` pixels.

    Args:
        img: object exposing ``size`` as (width, height) and ``crop(box)``
            (a PIL image in this notebook).
        stride: distance in pixels between the top-left corners of
            neighbouring crops, horizontally and vertically.

    Returns:
        A nested list indexed as ``parts[x_index][y_index]``; each entry is
        ``img.crop((x, y, x + 256, y + 256))``.
    """
    size = img.size
    x_offsets = range(0, size[0], stride)
    y_offsets = range(0, size[1], stride)
    return [
        [img.crop((left, top, left + 256, top + 256)) for top in y_offsets]
        for left in x_offsets
    ]

def get_composed_arr(img_list, original_shape, stride):
    """Blend (possibly overlapping) tiles back into one full-size array.

    Args:
        img_list: nested list indexed as ``img_list[x_index][y_index]`` whose
            entries are arrays indexed as [row, column, channel].
        original_shape: (width, height) of the full image.
        stride: pixel distance between neighbouring tiles.

    Returns:
        A (height, width, 3) array in which every pixel is the sum of all
        tiles covering it divided by the number of such tiles.
    """
    width = original_shape[0]
    height = original_shape[1]
    composed_arr = np.zeros((height, width, 3))
    mask = np.zeros((height, width, 3))  # per-pixel count of covering tiles

    n_rows = len(img_list[0])
    n_cols = len(img_list)
    pbar = tqdm(total=n_cols * n_rows, position=0, leave=True)
    count = 0
    for row_idx in range(n_rows):
        top = row_idx * stride
        bottom = min(top + 256, height)  # clip the tile at the image border
        for col_idx in range(n_cols):
            left = col_idx * stride
            right = min(left + 256, width)
            pbar.update(1)

            tile = img_list[col_idx][row_idx]
            composed_arr[top:bottom, left:right, :] += tile[:bottom - top, :right - left, :]
            mask[top:bottom, left:right, :] += 1

            count += 1
            pbar.set_description(f'\tComposed {count} parts')
    pbar.close()

    # Average the accumulated values wherever tiles overlapped.
    composed_arr /= mask
    return composed_arr

def compose_images(img_list, original_shape, stride=256):
    """Blend the tiles and return the result as a PIL image.

    Args:
        img_list: nested tile list, forwarded to ``get_composed_arr``.
        original_shape: (width, height) of the full image.
        stride: pixel distance between neighbouring tiles.

    Returns:
        ``Image.fromarray`` of the blended array multiplied by 255 and cast
        to ``uint8``.
    """
    blended = get_composed_arr(img_list, original_shape, stride)
    as_uint8 = (blended * 255).astype(np.uint8)
    return Image.fromarray(as_uint8)

In [None]:
def colorize_image(img_path, save_path, stride=256):
    """Colorize one image tile by tile and save the blended result.

    The image is split into 256x256 tiles taken every ``stride`` pixels, each
    tile is colorized with ``colorize_part``, the tiles are blended with
    ``compose_images`` and the result is written to
    ``<save_path>/<stride>_<file name>``.

    Args:
        img_path: path of the image to colorize.
        save_path: output directory; created (including missing parents) if
            it does not exist yet.
        stride: pixel distance between neighbouring tiles.
    """
    file_name = os.path.basename(img_path)
    image = load_img(img_path)
    splited_img = split_image(image, stride=stride)

    print(f"Coloring {file_name}")
    print('\tColoring parts...')

    count = 0
    # The context manager closes the progress bar even if a tile fails.
    with tqdm(total=len(splited_img)*len(splited_img[0]), position=0, leave=True) as pbar:
        for i in range(len(splited_img)):
            for j in range(len(splited_img[0])):
                splited_img[i][j] = colorize_part(splited_img[i][j], img_path)

                pbar.update(1)
                count += 1
                pbar.set_description(f"\tColorized {count} parts")

    print('\tCompose parts...')
    composed = compose_images(splited_img, image.size, stride=stride)
    print(file_name, f"colorized with stride {stride}.")

    # makedirs handles nested paths and an already existing directory, unlike
    # the exists()/mkdir() pair it replaces.
    os.makedirs(save_path, exist_ok=True)
    composed.save(os.path.join(save_path, f"{stride}_{file_name}"))


In [None]:
# Colorize every entry of the input folder with overlapping tiles
# (stride 32) and write the results to /content/result/.
bw_image_path = '/content/img/'

for file in os.listdir(bw_image_path):
  # Skip the checkpoint directory that Jupyter creates inside the folder.
  if file == '.ipynb_checkpoints':
    continue
  colorize_image(bw_image_path + file, '/content/result/', stride=32)

In [None]:
# Colorize the listed images twice: once with non-overlapping tiles
# (stride 256) and once with overlapping tiles (stride 32).
path_lst = ['/content/2665.jpg']

for path in path_lst:
  for stride in [256, 32]:
    colorize_image(path, '/content/result/', stride=stride)

Coloring 2665.jpg
	Coloring parts...


	Colorized 16 parts: 100%|██████████| 16/16 [00:06<00:00,  2.29it/s]


	Compose parts...


	Composed 16 parts: 100%|██████████| 16/16 [00:00<00:00, 213.73it/s]


2665.jpg colorized with stride 256.
Coloring 2665.jpg
	Coloring parts...


	Colorized 725 parts: 100%|██████████| 725/725 [04:23<00:00,  2.75it/s]


	Compose parts...


	Composed 725 parts: 100%|██████████| 725/725 [00:02<00:00, 283.34it/s]


2665.jpg colorized with stride 32.


In [None]:
  # Shell escape: pack the result directory into /content/result.zip.
  !zip -r /content/result.zip /content/result

  adding: content/result/ (stored 0%)
  adding: content/result/32_1 (1).png (deflated 1%)
  adding: content/result/32_1 (2).png (deflated 1%)
  adding: content/result/32_1 (3).png (deflated 0%)
  adding: content/result/32_1 (4).png (deflated 0%)
