In [1]:
import os
import pickle
import time

import matplotlib.pyplot as plt
import tensorflow as tf
from IPython import display
from tensorflow.keras import Input, Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Activation, Conv2D, Conv2DTranspose, BatchNormalization, Dropout, LeakyReLU
from tensorflow.keras.layers import Resizing, Rescaling

# Generator and Discriminator

In [8]:
def create_generator():
    """Build the encoder-decoder generator with skip connections.

    The encoder is eight stride-2 4x4 convolutions applied to a
    (256, 256, 1) input; the decoder is eight stride-2 4x4 transposed
    convolutions. Every decoder stage except the first receives the previous
    decoder output concatenated (channel axis) with the encoder output of the
    same spatial size. The first three decoder stages use Dropout(0.5).

    Returns:
        A Keras `Model` mapping a (256, 256, 1) tensor to a single-channel
        tensor passed through `tanh`.
    """
    lrelu = LeakyReLU(alpha = 0.2)

    def downsample(tensor, filters):
        # Stride-2 conv + leaky ReLU, followed by batch normalisation.
        tensor = Conv2D(filters, (4, 4), (2, 2), activation = lrelu, padding = 'same')(tensor)
        return BatchNormalization()(tensor)

    def upsample(tensor, skip, filters, use_dropout):
        # Optionally merge the matching encoder output, then stride-2 transposed conv.
        if skip is not None:
            # `tf.keras.layers.concatenate` is used explicitly: the bare name
            # `concatenate` is not imported by this notebook.
            tensor = tf.keras.layers.concatenate([tensor, skip], axis = -1)
        tensor = Conv2DTranspose(filters, (4, 4), (2, 2), activation = 'relu', padding = 'same')(tensor)
        tensor = BatchNormalization()(tensor)
        if use_dropout:
            tensor = Dropout(0.5)(tensor)
        return tensor

    inputs = Input(shape = (256, 256, 1))

    # encoder_1 has neither activation nor batch norm; encoder_2..encoder_8 do.
    encoder_outputs = [Conv2D(32, (4, 4), (2, 2), padding = 'same')(inputs)]
    for filters in (64, 128, 256, 256, 256, 256, 256):
        encoder_outputs.append(downsample(encoder_outputs[-1], filters))

    # decoder_8 starts from the bottleneck (encoder_8) without a skip input.
    decoded = upsample(encoder_outputs.pop(), None, 256, True)

    # decoder_7 .. decoder_2, each paired with the encoder output of equal size.
    for filters, use_dropout in ((256, True), (256, True), (256, False),
                                 (128, False), (64, False), (32, False)):
        decoded = upsample(decoded, encoder_outputs.pop(), filters, use_dropout)

    # decoder_1: no ReLU here. ReLU followed by tanh would restrict the output
    # to [0, 1), while tanh alone gives the full (-1, 1) range.
    decoded = tf.keras.layers.concatenate([decoded, encoder_outputs.pop()], axis = -1)
    decoded = Conv2DTranspose(1, (4, 4), (2, 2), padding = 'same')(decoded)

    outputs = Activation('tanh')(decoded)

    return Model(inputs = inputs, outputs = outputs)

In [None]:
def create_discriminator():
    """Build the convolutional discriminator.

    Takes a (256, 256, 2) input and applies five 4x4 convolutions (three with
    stride 2, two with stride 1), each using a shared LeakyReLU(0.2)
    activation; the three middle convolutions are followed by batch
    normalisation. The single-channel result is passed through a sigmoid.

    Returns:
        A Keras `Model` from the 2-channel input to the sigmoid output map.
    """
    inputs = Input(shape = (256, 256, 2))
    leaky = LeakyReLU(alpha = 0.2)

    def conv(filters, stride):
        # All convolutions share kernel size, padding and activation.
        return Conv2D(filters, (4, 4), (stride, stride), activation = leaky, padding = 'same')

    features = conv(32, 2)(inputs)

    # (filters, stride) of the batch-normalised middle layers.
    for filters, stride in ((64, 2), (128, 2), (256, 1)):
        features = conv(filters, stride)(features)
        features = BatchNormalization()(features)

    # NOTE(review): the leaky ReLU is also applied right before the sigmoid;
    # confirm this is intended.
    features = conv(1, 1)(features)

    outputs = Activation('sigmoid')(features)

    return Model(inputs = inputs, outputs = outputs)

# Loss and Train Functions

In [None]:
def sum_tv_loss(image):
    """Return the summed squared differences between neighbouring entries.

    Differences are taken along axis 1 and along axis 2 of a rank-4 tensor.

    Args:
        image: Rank-4 tensor (indexed as `image[:, i, j, :]`).

    Returns:
        A float32 scalar.
    """
    diff_axis1 = image[:, 1:, :, :] - image[:, :-1, :, :]
    diff_axis2 = image[:, :, 1:, :] - image[:, :, :-1, :]
    # tf.nn.l2_loss(t) is sum(t ** 2) / 2, so doubling the total gives the
    # plain sum of squared differences.
    total = 2 * (tf.nn.l2_loss(diff_axis1) + tf.nn.l2_loss(diff_axis2))
    return tf.cast(total, tf.float32)

In [None]:
def feature_loss(image, vgg):
    """Return VGG block-3 features of `image` (despite the name, not a loss).

    Two kinds of `vgg` object are supported:

    * a Keras model exposing `get_layer` (e.g. `VGG16` from
      `tensorflow.keras.applications`): the output of its `block3_conv3`
      layer is returned. A Keras model has no `conv3_3` attribute and its
      `build` expects a shape, so the legacy path below cannot be used for it.
    * any other object: the legacy protocol `vgg.build(image)` followed by
      reading `vgg.conv3_3`.

    Args:
        image: Image batch tensor.
        vgg: Feature network, see above.

    Returns:
        The feature tensor.
    """
    if hasattr(vgg, "get_layer"):
        # VGG16 takes 3-channel input; replicate a single grey channel.
        if image.shape[-1] == 1:
            image = tf.image.grayscale_to_rgb(image)
        # Reuses the existing layers (no new weights); it is rebuilt on every
        # call, which is cheap under @tf.function tracing.
        extractor = Model(inputs = vgg.input, outputs = vgg.get_layer("block3_conv3").output)
        # NOTE(review): `image` is fed as-is; confirm whether it should first
        # be rescaled to the value range the VGG weights were trained on.
        return extractor(image, training = False)

    vgg.build(image)
    return vgg.conv3_3

In [None]:
def discriminator_loss(predict_real, predict_fake):
    """Cross-entropy style loss for the discriminator.

    Computes mean(-(log(predict_real) + log(1 - predict_fake))), with a small
    epsilon added inside each logarithm to avoid log(0).

    Args:
        predict_real: Discriminator outputs for real pairs.
        predict_fake: Discriminator outputs for generated pairs (same shape).

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-12
    # `tf.log` no longer exists in TensorFlow 2; `tf.math.log` is the replacement.
    log_real = tf.math.log(predict_real + eps)
    log_not_fake = tf.math.log(1 - predict_fake + eps)
    return tf.reduce_mean(-(log_real + log_not_fake))

In [None]:
def generator_loss(predict_fake, targets, outputs, net1, net2,
                   l1_weight = 10, tv_weight = 1e-5, feature_weight = 1e-4):
    """Weighted sum of the generator's loss terms.

    Terms:
        * GAN term: negative mean of `predict_fake`.
        * L1 term: mean absolute difference between `targets` and `outputs`.
        * TV term: sqrt(l2_loss(sum_tv_loss(outputs))).
        * Feature term: sqrt(l2_loss(feature_loss(targets, net1)
          - feature_loss(outputs, net2))).

    Args:
        predict_fake: Discriminator outputs for the generated pairs.
        targets: Ground-truth images.
        outputs: Generated images.
        net1: Feature network applied to `targets`.
        net2: Feature network applied to `outputs`.
        l1_weight: Multiplier of the L1 term (default 10).
        tv_weight: Multiplier of the TV term (default 1e-5).
        feature_weight: Multiplier of the feature term (default 1e-4).

    Returns:
        Scalar loss tensor.
    """
    gen_loss_GAN = -tf.reduce_mean(predict_fake)
    gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))

    # NOTE(review): tf.sqrt is applied to values that can be exactly zero;
    # confirm this cannot produce NaN gradients in practice.
    gen_loss_tv = tf.reduce_mean(tf.sqrt(tf.nn.l2_loss(sum_tv_loss(outputs))))
    feature_diff = feature_loss(targets, net1) - feature_loss(outputs, net2)
    gen_loss_f = tf.reduce_mean(tf.sqrt(tf.nn.l2_loss(feature_diff)))

    return (gen_loss_GAN
            + (gen_loss_L1 * l1_weight)
            + (gen_loss_tv * tv_weight)
            + (gen_loss_f * feature_weight))

In [None]:
@tf.function
def train_step(sketches, images, net1, net2):
    """Run one optimisation step of the generator and the discriminator(s).

    Uses the notebook-level `generator`, `discriminator_real`,
    `discriminator_fake`, `generator_optimizer` and `discriminator_optimizer`.

    Args:
        sketches: Batch of generator inputs.
        images: Batch of target images paired with `sketches`.
        net1: Feature network forwarded to `generator_loss` for the targets.
        net2: Feature network forwarded to `generator_loss` for the outputs.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(sketches, training=True)

        # The discriminator judges (sketch, image) pairs stacked on the
        # channel axis. `tf.concat` takes the tensors as a list.
        real_pairs = tf.concat([sketches, images], axis = 3)
        fake_pairs = tf.concat([sketches, generated_images], axis = 3)

        predict_real = discriminator_real(real_pairs, training = True)
        predict_fake = discriminator_fake(fake_pairs, training = True)

        gen_loss = generator_loss(predict_fake, images, generated_images, net1, net2)
        disc_loss = discriminator_loss(predict_real, predict_fake)

    generator_variables = generator.trainable_variables

    # Differentiate with respect to every discriminator variable that took
    # part in the loss. Gradients computed for one model must not be applied
    # to another model's variables, and shared variables must be stepped once.
    discriminator_variables = list(discriminator_real.trainable_variables)
    if discriminator_fake is not discriminator_real:
        discriminator_variables += list(discriminator_fake.trainable_variables)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator_variables))

In [None]:
def train(dataset, epochs):
    """Train for `epochs` passes over `dataset`.

    After every epoch a sample image grid is written with
    `generate_and_save_images`, and every 15th epoch a checkpoint is saved.

    Relies on notebook-level names: `generator`, `train_step`,
    `generate_and_save_images`, `seed`, `checkpoint` and `checkpoint_prefix`.
    NOTE(review): `seed`, `checkpoint` and `checkpoint_prefix` are not defined
    in the visible part of this notebook; confirm they exist before running.

    Args:
        dataset: Iterable yielding `(sketch_batch, image_batch)` pairs.
        epochs: Number of passes over `dataset`.
    """
    # The imported class is `VGG16` (the name `Vgg16` is not defined anywhere).
    # The classifier head is dropped so that 256x256 inputs are accepted.
    # NOTE(review): `feature_loss` must be able to read intermediate features
    # from a Keras model for this to work; confirm.
    vgg = VGG16(include_top = False, weights = 'imagenet', input_shape = (256, 256, 3))
    vgg.trainable = False
    # The network is frozen, so one instance can serve both targets and outputs.
    net1 = net2 = vgg

    for epoch in range(epochs):
        start = time.time()

        for sketch_batch, image_batch in dataset:
            train_step(sketch_batch, image_batch, net1, net2)

        # Produce images for the GIF as you go
        # display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

# Preprocessing

In [10]:
def resize_image(image):
    """Resize `image` to 256x256 with padding, using bilinear interpolation.

    Args:
        image: Image tensor accepted by `tf.image.resize_with_pad`.

    Returns:
        The resized and padded image.
    """
    # `ResizeMethod` lives in `tf.image`; the bare name is not defined here.
    return tf.image.resize_with_pad(
        image, 256, 256, method = tf.image.ResizeMethod.BILINEAR, antialias = False
    )

In [None]:
def rescale_image(image):
    """Map pixel values from [0, 255] to [-1, 1] as float32.

    This is the inverse of the `* 127.5 + 127.5` transform that
    `generate_and_save_images` applies before plotting.
    NOTE(review): the original body was an unfinished stub; confirm that the
    inputs really are in the 0-255 range.

    Args:
        image: Image tensor with values in [0, 255].

    Returns:
        float32 tensor of the same shape with values in [-1, 1].
    """
    # Cast first so integer inputs do not wrap around or truncate.
    return (tf.cast(image, tf.float32) - 127.5) / 127.5

In [None]:
generator = create_generator()
# A single discriminator must judge both real and generated pairs: two
# independently initialised networks would each see only one class and could
# never learn to tell them apart. Both names are kept so existing code that
# refers to either one continues to work.
discriminator_real = create_discriminator()
discriminator_fake = discriminator_real

In [None]:
# Generator and discriminator each get their own Adam optimizer (and thus
# their own optimizer state); both use a learning rate of 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)

# Generate and Save Images

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Run `model` on `test_input` and save the outputs as one PNG grid.

    Channel 0 of every prediction is mapped with `* 127.5 + 127.5` and drawn
    in grey scale. The figure is written to `image_at_epoch_{epoch:04d}.png`
    in the current directory and then shown.

    Args:
        model: Model called as `model(test_input, training=False)`.
        epoch: Integer used in the output file name.
        test_input: Input batch for the model.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)
    image_count = predictions.shape[0]

    # Keep the 4x4 layout for up to 16 images; grow the grid for larger
    # batches instead of letting `plt.subplot` raise.
    grid_size = 4
    while grid_size * grid_size < image_count:
        grid_size += 1

    fig = plt.figure(figsize=(4, 4))

    for i in range(image_count):
        plt.subplot(grid_size, grid_size, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure explicitly so figures do not accumulate across calls.
    plt.close(fig)

In [None]:
# def conv(batch_input, out_channels, stride):
#     with tf.variable_scope("conv"):
#         in_channels = batch_input.get_shape()[3]
#         filter = tf.get_variable("filter", [4, 4, in_channels, out_channels], dtype=tf.float32,
#                                  initializer=tf.random_normal_initializer(0, 0.02))
#         # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels]
#         #     => [batch, out_height, out_width, out_channels]
#         padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
#         conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding="VALID")
#         return conv

In [None]:
# def deconv(batch_input, out_channels):
#     with tf.variable_scope("deconv"):
#         batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
#         filter = tf.get_variable("filter", [4, 4, out_channels, in_channels], dtype=tf.float32,
#                                  initializer=tf.random_normal_initializer(0, 0.02))
#         # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels]
#         #     => [batch, out_height, out_width, out_channels]
#         conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels],
#                                       [1, 2, 2, 1], padding="SAME")
#         return conv