In [5]:
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from tensorflow.keras.layers import (
    Activation, AveragePooling2D, BatchNormalization, Conv2D, Conv2DTranspose,
    Dense, Dropout, Flatten, Input, LeakyReLU, ReLU, UpSampling2D)
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential, Model
from time import time
from skimage.color import rgb2lab, lab2rgb
import time

IMAGE_SIZE = 32            # CIFAR-10 images are 32x32 pixels
EPOCHS = 100
BATCH_SIZE = 128
# NOTE(review): a buffer of 100 only shuffles the training set locally;
# consider a value on the order of the dataset size.
SHUFFLE_BUFFER_SIZE = 100
SEED = 42

# Seed NumPy and TensorFlow so weight initialisation and dataset shuffling are
# repeatable across Restart-&-Run-All (GPU kernels may still be nondeterministic).
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [6]:
def generate_dataset(images, debug=False):
    """Convert RGB images into (L, ab) pairs for colorization training.

    Args:
        images: iterable of RGB images with pixel values in [0, 255]
            (each is divided by 255 before the Lab conversion).
        debug: unused; kept only for backward compatibility.

    Returns:
        X: float32 array of shape (N, H, W, 1) holding the L (lightness) channel.
        Y: float32 array of shape (N, H, W, 2) holding the a/b channels divided
           by 128, i.e. roughly in [-1, 1] to match the generator's tanh output.
    """
    X = []
    Y = []

    for image in images:
        lab = rgb2lab(image / 255)
        # Slice (rather than reshape with the global IMAGE_SIZE) so the trailing
        # channel axis is kept for images of any spatial size.
        X.append(lab[:, :, :1])
        Y.append(lab[:, :, 1:] / 128)  # normalize ab

    return np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32)


def load_data(force=False):
    """Load CIFAR-10 and turn both splits into (L, ab) colorization pairs.

    Class labels are discarded. ``force`` is accepted but currently unused.

    Returns:
        X_train, Y_train, X_test, Y_test as produced by ``generate_dataset``.
    """
    train_split, test_split = cifar10.load_data()
    train_rgb = train_split[0]
    test_rgb = test_split[0]

    X_train, Y_train = generate_dataset(train_rgb)
    X_test, Y_test = generate_dataset(test_rgb)
    return X_train, Y_train, X_test, Y_test


X_train, Y_train, X_test, Y_test = load_data()

# Batched tf.data pipelines; only the training split is shuffled.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, Y_test))
    .batch(BATCH_SIZE)
)

In [7]:
def conv2d_block(filters, kernel_size, apply_batchnorm=True, stride=2):
    """Downsampling block: Conv2D -> [BatchNorm] -> LeakyReLU.

    Args:
        filters: number of output channels.
        kernel_size: convolution kernel size.
        apply_batchnorm: add a BatchNormalization layer after the convolution.
        stride: convolution stride (2 halves the spatial size with 'same' padding).

    Returns:
        A ``Sequential`` model implementing the block.
    """
    # Zero-mean normal init (stddev 0.02). The previous uniform(0, 0.02)
    # initializer made every kernel weight non-negative.
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    # No bias: it would be redundant with BatchNorm's learned offset.
    model.add(Conv2D(filters, kernel_size, strides=stride, padding='same',
                     kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        model.add(BatchNormalization())

    model.add(LeakyReLU())
    return model


def conv2d_transpose_block(filters, kernel_size, apply_batchnorm=True, stride=2, apply_dropout=False):
    """Upsampling block: Conv2DTranspose -> [BatchNorm] -> ReLU -> [Dropout(0.5)].

    Args:
        filters: number of output channels.
        kernel_size: transpose-convolution kernel size.
        apply_batchnorm: add a single BatchNormalization layer after the
            transpose convolution.
        stride: transpose-convolution stride (2 doubles the spatial size).
        apply_dropout: append Dropout(0.5) after the activation.

    Returns:
        A ``Sequential`` model implementing the block.
    """
    # Zero-mean normal init (stddev 0.02) instead of the all-positive uniform(0, 0.02).
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(Conv2DTranspose(filters, kernel_size, strides=stride, padding='same',
                              kernel_initializer=initializer, use_bias=False))

    # BatchNorm is added exactly once and only when requested. It used to be
    # added unconditionally as well, which stacked two BN layers and made
    # apply_batchnorm=False ineffective.
    if apply_batchnorm:
        model.add(BatchNormalization())

    model.add(ReLU())

    if apply_dropout:
        model.add(Dropout(0.5))

    return model


def make_autoencoder_generator_model():
    """Build the U-Net style generator that predicts ab channels from L.

    Input:  (batch, IMAGE_SIZE, IMAGE_SIZE, 1) lightness channel.
    Output: (batch, IMAGE_SIZE, IMAGE_SIZE, 2) ab channels in (-1, 1) (tanh).

    Shapes in the comments below assume IMAGE_SIZE == 32.
    """
    inputs = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))

    # Encoder layers (each stride-2 block halves the spatial size)
    # 1: (BATCH_SIZE, 16, 16, 32)
    # 2: (BATCH_SIZE, 8, 8, 64)
    # 3: (BATCH_SIZE, 4, 4, 128)
    # 4: (BATCH_SIZE, 2, 2, 256)
    # 5: (BATCH_SIZE, 1, 1, 256)  <- bottleneck
    downstack = [
        conv2d_block(32, 4, apply_batchnorm=False),
        conv2d_block(64, 4),
        conv2d_block(128, 4),
        conv2d_block(256, 4),
        conv2d_block(256, 4)
    ]

    # Decoder layers (each stride-2 block doubles the spatial size; shapes
    # are before concatenation with the matching skip connection)
    # 1: (BATCH_SIZE, 2, 2, 256)
    # 2: (BATCH_SIZE, 4, 4, 128)
    # 3: (BATCH_SIZE, 8, 8, 64)
    # 4: (BATCH_SIZE, 16, 16, 32)
    upstack = [
        conv2d_transpose_block(256, 4, apply_dropout=True),
        conv2d_transpose_block(128, 4),
        conv2d_transpose_block(64, 4),
        conv2d_transpose_block(32, 4),
    ]

    # Zero-mean normal init, consistent with DCGAN/pix2pix practice.
    initializer = tf.random_normal_initializer(0., 0.02)
    # Final 16x16 -> 32x32 upsampling to the two ab channels.
    output_layer = Conv2DTranspose(2, 3, strides=2, padding='same',
                                   kernel_initializer=initializer,
                                   activation='tanh')

    x = inputs

    # Downsampling layers
    skips = []
    for down in downstack:
        x = down(x)
        skips.append(x)

    # The bottleneck is not reused as a skip; pair the remaining encoder
    # outputs with decoder stages from deepest to shallowest.
    skips = reversed(skips[:-1])

    # Upsampling layers
    for up, skip in zip(upstack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = output_layer(x)

    return Model(inputs=inputs, outputs=x)

In [8]:
def make_discriminator_model():
    """Build the conditional discriminator over (L, ab) images.

    The input has 3 channels: the L channel concatenated with either real or
    generated ab channels. The output is a spatial map of raw per-patch logits
    ((batch, 4, 4, 1) for IMAGE_SIZE == 32), intended for a loss built with
    ``from_logits=True``.
    """
    inputs = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
    x = inputs
    discstack = [
        conv2d_block(64, 4, apply_batchnorm=False),
        conv2d_block(128, 4),
        conv2d_block(256, 4),
        conv2d_block(512, 4, stride=1, apply_batchnorm=False),
    ]
    # Final layer emits raw logits: no BatchNorm and no activation. BatchNorm
    # here would normalize the real and fake batches (run in separate calls)
    # to the same mean logit. LeakyReLU would distort the logits given to a
    # from_logits loss.
    output_layer = Conv2D(1, 4, strides=1, padding='same',
                          kernel_initializer=tf.random_normal_initializer(0., 0.02))

    for ds in discstack:
        x = ds(x)
    x = output_layer(x)

    return Model(inputs=inputs, outputs=x)

In [9]:
# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100
# The discriminator outputs raw logits, so the sigmoid is applied inside the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(disc_real_output, disc_generated_output):
    """Binary cross-entropy pushing real logits toward 1 and generated toward 0.

    Returns the sum of the real-term and generated-term losses.
    """
    real_targets = tf.ones_like(disc_real_output)
    fake_targets = tf.zeros_like(disc_generated_output)
    return (cross_entropy(real_targets, disc_real_output)
            + cross_entropy(fake_targets, disc_generated_output))


def generator_loss(disc_generated_output, gen_output, target):
    """Adversarial loss plus LAMBDA-weighted mean absolute (L1) error.

    Returns:
        (total_loss, adversarial_loss, l1_loss)
    """
    adversarial = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + (LAMBDA * l1), adversarial, l1

In [10]:
# The two networks of the conditional GAN.
generator = make_autoencoder_generator_model()
discriminator = make_discriminator_model()

# One independent Adam optimizer per network (lr=2e-4, beta_1=0.5).
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)
)



In [11]:

@tf.function
def train_step(input_image, target, epoch):
    """Run one optimisation step for both the generator and the discriminator.

    The discriminator is conditioned on the L channel. It sees ``input_image``
    concatenated on the last axis with either the real ``target`` ab channels
    or the generator's prediction. ``epoch`` is accepted but not used.

    Returns:
        (gen_total_loss, disc_loss) for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_ab = generator(input_image, training=True)

        real_pair = tf.concat([input_image, target], axis=-1)
        disc_real_output = discriminator(real_pair, training=True)
        fake_pair = tf.concat([input_image, fake_ab], axis=-1)
        disc_generated_output = discriminator(fake_pair, training=True)

        gen_total_loss, _, _ = generator_loss(disc_generated_output, fake_ab, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(gen_total_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return gen_total_loss, disc_loss

In [None]:
for e in range(EPOCHS):
    gen_loss_total = disc_loss_total = 0.0
    num_batches = 0
    for input_image, target in train_dataset:
        # Pass the epoch as a tensor: a new Python int on each call would make
        # the @tf.function-decorated train_step retrace every epoch.
        gen_loss, disc_loss = train_step(input_image, target, tf.constant(e))
        gen_loss_total += gen_loss
        disc_loss_total += disc_loss
        num_batches += 1

    # train_step returns per-batch mean losses, so average over the number of
    # batches (not BATCH_SIZE). Guard against an empty dataset.
    num_batches = max(num_batches, 1)
    print('Epoch {}: gen loss: {}, disc loss: {}'.format(
        e + 1, float(gen_loss_total) / num_batches, float(disc_loss_total) / num_batches))

In [12]:

# generator = tf.saved_model.load("./generator")


In [None]:
# Number of test images to colorize and export. The predictions and the loop
# below both use it, so they cannot silently drift apart.
NUM_RESULTS = 20
RESULTS_DIR = os.path.join("./", 'results')
# savefig does not create missing directories.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Inference mode: dropout disabled, BatchNorm uses its moving statistics.
Y_hat = np.asarray(generator(X_test[:NUM_RESULTS], training=False))
total_count = len(Y_hat)

for idx, (x, y, y_hat) in enumerate(zip(X_test[:total_count], Y_test[:total_count], Y_hat)):
    # ab channels were divided by 128 in generate_dataset; scale back for lab2rgb.
    images = {
        'bw': lab2rgb(np.dstack((x, np.zeros(x.shape[:2] + (2,))))),  # L only
        'gt': lab2rgb(np.dstack((x, y * 128))),                        # ground truth
        'gan': lab2rgb(np.dstack((x, y_hat * 128))),                   # prediction
    }

    for suffix, rgb in images.items():
        fig, ax = plt.subplots()
        ax.imshow(rgb)
        ax.axis('off')
        fig.savefig(os.path.join(RESULTS_DIR, '{}-{}.png'.format(idx, suffix)))
        # Close every figure so memory does not grow with the number of images.
        plt.close(fig)

In [None]:
# Export the trained generator as a TensorFlow SavedModel directory.
tf.saved_model.save(obj=generator, export_dir="./generator")

In [None]:
# Export the trained discriminator as a TensorFlow SavedModel directory.
tf.saved_model.save(obj=discriminator, export_dir="./discriminator")