In [1]:
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from skimage import color

In [2]:
# Each dataset is loaded as a single batch, so its batch size is the number of
# entries in the corresponding directory (assumes the folders hold only images).
print(os.listdir('/kaggle/input'))
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = (
    len(os.listdir(folder))
    for folder in (
        '/kaggle/input/landscape-pictures',
        '/kaggle/input/image-colorization-dataset/data/test_color',
    )
)

print(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE)

In [3]:
BASE_PATH = '/kaggle/input/'

# Pixel values are only rescaled to [0, 1]; no augmentation is configured.
train_datagen = ImageDataGenerator(rescale=1 / 255.0)

# The 'landscape-pictures' folder under BASE_PATH is read as the only class,
# every image is resized to 224x224, and the whole folder forms one batch.
train_data = train_datagen.flow_from_directory(
    BASE_PATH,
    target_size=(224, 224),
    batch_size=TRAIN_BATCH_SIZE,
    classes=['landscape-pictures'],
)

In [4]:
# Split every training image into the model input (Lab lightness channel L)
# and the regression target (the a and b colour channels divided by 128, which
# keeps them inside the [-1, 1] range of the model's tanh output).
X = []  # grayscale inputs: L channel, each (224, 224, 1)
y = []  # ground truth for X: ab channels, each (224, 224, 2)
# Use the builtin next() rather than the DirectoryIterator.next() method, which
# newer Keras versions no longer provide. Element 0 of the batch is the images.
for image in next(train_data)[0]:
    lab = color.rgb2lab(image)
    X.append(lab[:, :, 0].reshape(224, 224, 1))
    y.append(lab[:, :, 1:] / 128)

# The whole dataset is held in memory at once. Depending on the scikit-image
# version, rgb2lab may return float64; storing float32 halves the footprint,
# and Keras trains in float32 anyway.
X = np.array(X, dtype=np.float32)
y = np.array(y, dtype=np.float32)

In [5]:
# Expected: (N, 224, 224, 1) inputs and (N, 224, 224, 2) targets.
print(np.shape(X), np.shape(y))

In [6]:
def build_encoder():
    """Build the colourisation network as a single Sequential model.

    Despite the name, this is the full encoder-decoder. It maps a
    (224, 224, 1) lightness image to a (224, 224, 2) map of tanh-bounded
    values (the scaled a/b colour channels).

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    def _conv3x3(filters, strides=1, activation='relu'):
        # Every convolution in the network is 3x3 with 'same' padding.
        return layers.Conv2D(filters, (3, 3), strides=strides,
                             padding='same', activation=activation)

    encoder_stage = [
        # Three stride-2 convolutions reduce 224x224 to 28x28.
        _conv3x3(64, strides=2),
        _conv3x3(128),
        _conv3x3(128, strides=2),
        _conv3x3(256),
        _conv3x3(256, strides=2),
        _conv3x3(512),
        _conv3x3(512),
        _conv3x3(512),
        _conv3x3(256),
    ]

    decoder_stage = [
        # Three 2x upsamplings bring 28x28 back to 224x224.
        _conv3x3(128),
        layers.UpSampling2D(size=(2, 2)),
        _conv3x3(64),
        _conv3x3(32),
        layers.UpSampling2D(size=(2, 2)),
        _conv3x3(16),
        _conv3x3(2, activation='tanh'),
        layers.UpSampling2D(size=(2, 2)),
    ]

    return keras.Sequential(
        [layers.Input(shape=(224, 224, 1))] + encoder_stage + decoder_stage
    )
model = build_encoder()

# Cut the learning rate 5x after 5 epochs without val_loss improvement.
# Adam's default learning rate is 1e-3, so the floor must be below that.
# With min_lr=0.001 the callback could never lower the rate at all.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                                 patience=5, min_lr=1e-6)

# MSE regression on the scaled ab channels. 'acc' is not meaningful for
# continuous targets; it is kept only so history['acc'] stays available.
# 'mae' is tracked as an interpretable error metric.
model.compile(optimizer='adam', metrics=['acc', 'mae'], loss='mean_squared_error')

# steps_per_epoch is left for Keras to derive. With validation_split=0.3 only
# 70% of X is used for training, so X.shape[0] // 64 steps per epoch would
# request more batches than the training split holds. Training would then be
# interrupted early with an "input ran out of data" warning.
history = model.fit(X, y, epochs=50, batch_size=64, validation_split=0.3,
                    verbose=1, callbacks=[reduce_lr])

In [7]:
# Reuse the rescale-only generator for the held-out colour images, resized to
# the training resolution and served 64 at a time.
test_data = train_datagen.flow_from_directory(
    '/kaggle/input/image-colorization-dataset/data',
    target_size=(224, 224),
    batch_size=64,
    classes=["test_color"],
)

In [15]:
# Builtin next() works across Keras versions; newer DirectoryIterators no
# longer expose a .next() method.
test_image_batch = next(test_data)
# Element 0 holds the rescaled images. Element 1 presumably holds the class
# labels (default class_mode) and is not used here.
test_images = test_image_batch[0]

In [17]:
# Accuracy is not meaningful for continuous ab regression, so chart the MSE loss
# that training minimises. fit() always records 'loss', and 'val_loss' exists
# because training uses validation_split.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.plot(history.history['val_loss'], label='val')
ax.set(title='model loss (MSE)', xlabel='epoch', ylabel='loss')
ax.legend(loc='upper right')
plt.show()

In [16]:
# Convert every test image to Lab once, then colourise the whole batch with a
# single predict call instead of one call (and one progress bar) per image.
test_labs = np.stack([color.rgb2lab(image) for image in test_images])
# The model was trained on ab / 128, so scale its output back to Lab units.
pred_abs = model.predict(test_labs[..., :1]) * 128

for test_img, ab in zip(test_labs, pred_abs):
    fig, (ax_gray, ax_pred, ax_true) = plt.subplots(1, 3, figsize=(8, 6), dpi=80)

    ax_gray.set_title("Grayscale Image")
    ax_gray.imshow(test_img[:, :, 0], cmap='gray')

    # Recombine the true L channel with the predicted ab channels.
    pred = np.concatenate([test_img[:, :, :1], ab], axis=-1)
    ax_pred.set_title("Encoder Output")
    ax_pred.imshow(color.lab2rgb(pred))

    ax_true.set_title("Ground Truth")
    ax_true.imshow(color.lab2rgb(test_img))
    plt.show()
    # Release each figure so drawing many test images does not accumulate memory.
    plt.close(fig)

In [10]:
# The .h5 extension makes Keras write the legacy single-file HDF5 format.
model.save(filepath="color.h5")