In [None]:
import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt

# Change PATH to the absolute or relative path of the images directory on your machine;
# it must contain the `train` and `val` folders.
PATH = './data//' 

# Change these variables as per your needs.
EPOCHS = 100
BUFFER_SIZE = 10
BATCH_SIZE = 2
IMG_WIDTH = 256
IMG_HEIGHT = 256

# Fix TensorFlow's global seed so weight initialisation, random jitter, file shuffling
# and dropout are (largely) reproducible across Restart & Run All. Some GPU kernels
# can still be non-deterministic.
SEED = 42
tf.random.set_seed(SEED)

In [None]:
def load(image_file, augment=True):
    """Load one side-by-side paired PNG and return ``(input_image, real_image)``.

    Each file holds two images next to each other. The right half is used as
    the model input and the left half as the target ("real") image.

    Args:
        image_file: scalar string tensor, path of the PNG file (as produced by
            ``tf.data.Dataset.list_files``).
        augment: if True (the default, and the original behaviour), apply random
            jitter: a nearest-neighbour resize to 286x286, then one random
            ``IMG_HEIGHT`` x ``IMG_WIDTH`` crop shared by both halves, then a
            random left-right flip applied to both halves together. If False,
            both halves are only resized to ``IMG_HEIGHT`` x ``IMG_WIDTH``, which
            gives deterministic images suitable for evaluation.

    Returns:
        ``(input_image, real_image)``: float32 tensors of shape
        ``[IMG_HEIGHT, IMG_WIDTH, 3]``, scaled from [0, 255] to [-1, 1].
    """
    image = tf.io.read_file(image_file)
    # Always decode to RGB. Without channels=3, an RGBA or grayscale PNG yields 4 or 1
    # channels, and the 3-channel random_crop below fails on it.
    image = tf.image.decode_png(image, channels=3)

    # Split the side-by-side pair down the middle.
    half_width = tf.shape(image)[1] // 2
    input_image = tf.cast(image[:, half_width:, :], tf.float32)
    real_image = tf.cast(image[:, :half_width, :], tf.float32)

    if augment:
        # Random jitter, step 1: upscale both halves to 286x286.
        jitter_height, jitter_width = 286, 286
        input_image = tf.image.resize(input_image, [jitter_height, jitter_width],
                                      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize(real_image, [jitter_height, jitter_width],
                                     method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # Step 2: stack input_image and real_image along a new leading axis (3D -> 4D)
        # so that a single random crop cuts the same window out of both.
        stacked_image = tf.stack([input_image, real_image], axis=0)
        cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
        input_image, real_image = cropped_image[0], cropped_image[1]

        # Step 3: draw a random float and mirror both images together about half the time.
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            real_image = tf.image.flip_left_right(real_image)
    else:
        input_image = tf.image.resize(input_image, [IMG_HEIGHT, IMG_WIDTH],
                                      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize(real_image, [IMG_HEIGHT, IMG_WIDTH],
                                     method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Normalize from [0, 255] to [-1, 1], the range of the generator's tanh output.
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image

In [None]:
# Training pipeline: PNG paths -> (input, target) pairs via load() -> shuffled batches.
# os.path.join builds a glob that is valid on every OS. The previous '\\train\\*.png'
# suffix used Windows separators, so on Linux/macOS it matched no files at all.
train_dataset = tf.data.Dataset.list_files(os.path.join(PATH, 'train', '*.png'))
train_dataset = train_dataset.map(load, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Overlap input preparation with training; this does not change the data produced.
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Validation pipeline, used for the per-epoch previews and the validation losses.
# os.path.join keeps the glob portable; '\\val\\*.png' only matched on Windows.
# list_files shuffles the file order by default, so test_dataset.take(1) may return
# a different batch on each iteration.
# NOTE(review): load() applies its random resize/crop/flip jitter here as well, so the
# validation images are randomly augmented; consider a non-augmenting loader here.
test_dataset = tf.data.Dataset.list_files(os.path.join(PATH, 'val', '*.png'))
test_dataset = test_dataset.map(load, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, shape, apply_batchnorm=True):
    """Return a stride-2 encoder block: Conv2D, optional BatchNorm, then LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        shape: ``batch_input_shape`` passed to the Conv2D, e.g.
            ``(None, 256, 256, 3)``, so the returned Sequential is built for it.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.

    Returns:
        tf.keras.Sequential; stride 2 with 'same' padding halves the spatial size.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        batch_input_shape=shape,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,  # any following BatchNormalization supplies the offset
    )

    block_layers = [conv]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

def upsample(filters, size, shape, apply_dropout=False):
    """Return a stride-2 decoder block: Conv2DTranspose, BatchNorm, optional Dropout(0.5), ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        shape: ``batch_input_shape`` passed to the Conv2DTranspose, e.g.
            ``(None, 2, 2, 1024)``.
        apply_dropout: insert ``Dropout(0.5)`` after the batch normalization.

    Returns:
        tf.keras.Sequential; stride 2 with 'same' padding doubles the spatial size.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters,
        size,
        strides=2,
        batch_input_shape=shape,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )

    block_layers = [deconv, tf.keras.layers.BatchNormalization()]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

def buildGenerator():
    """Build the U-Net generator (256x256x3 image in, 256x256x3 image out).

    Eight ``downsample`` stages shrink the input from 256x256 to 1x1. Seven
    ``upsample`` stages grow it back to 128x128; each stage's output is
    concatenated with the encoder activation of the same resolution (skip
    connection). A final stride-2 Conv2DTranspose with tanh produces
    ``OUTPUT_CHANNELS`` channels at 256x256.

    Returns:
        tf.keras.Model taking one input of shape (None, 256, 256, 3).
    """
    inputs = tf.keras.layers.Input(shape=[256,256,3])

    # (filters, batch_input_shape, apply_batchnorm); trailing comment = stage output.
    encoder_specs = [
        (64, (None, 256, 256, 3), False),   # (none, 128, 128, 64)
        (128, (None, 128, 128, 64), True),  # (none, 64, 64, 128)
        (256, (None, 64, 64, 128), True),   # (none, 32, 32, 256)
        (512, (None, 32, 32, 256), True),   # (none, 16, 16, 512)
        (512, (None, 16, 16, 512), True),   # (none, 8, 8, 512)
        (512, (None, 8, 8, 512), True),     # (none, 4, 4, 512)
        (512, (None, 4, 4, 512), True),     # (none, 2, 2, 512)
        (512, (None, 2, 2, 512), True),     # (none, 1, 1, 512)
    ]
    encoder = [downsample(filters, 4, in_shape, apply_batchnorm=use_bn)
               for filters, in_shape, use_bn in encoder_specs]

    # (filters, batch_input_shape, apply_dropout); trailing comment = output after skip concat.
    decoder_specs = [
        (512, (None, 1, 1, 512), True),     # (none, 2, 2, 1024)
        (512, (None, 2, 2, 1024), True),    # (none, 4, 4, 1024)
        (512, (None, 4, 4, 1024), True),    # (none, 8, 8, 1024)
        (512, (None, 8, 8, 1024), False),   # (none, 16, 16, 1024)
        (256, (None, 16, 16, 1024), False), # (none, 32, 32, 512)
        (128, (None, 32, 32, 512), False),  # (none, 64, 64, 256)
        (64, (None, 64, 64, 256), False),   # (none, 128, 128, 128)
    ]
    decoder = [upsample(filters, 4, in_shape, apply_dropout=use_dropout)
               for filters, in_shape, use_dropout in decoder_specs]

    final_layer = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                                  strides=2,
                                                  padding='same',
                                                  kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                                  activation='tanh')  # (none, 256, 256, 3)

    # Encoder path: keep every stage's output for the skip connections.
    features = inputs
    encoder_outputs = []
    for stage in encoder:
        features = stage(features)
        encoder_outputs.append(features)

    # The 1x1 bottleneck has no skip partner. Pair decoder stages with the remaining
    # encoder outputs from deepest to shallowest.
    for stage, skip in zip(decoder, encoder_outputs[-2::-1]):
        features = stage(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    features = final_layer(features)

    return tf.keras.Model(inputs=inputs, outputs=features)

# Build the generator once; train_step and fit refer to this module-level `generator`.
generator = buildGenerator()
generator.summary()

In [None]:
def downs(filters, size, apply_batchnorm=True):
    """Return a stride-2 discriminator block: Conv2D, optional BatchNorm, LeakyReLU.

    Same layer stack as ``downsample`` but with no ``batch_input_shape``, so the
    block is built lazily on first call.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.

    Returns:
        tf.keras.Sequential halving the spatial size (stride 2, 'same' padding).
    """
    layer_stack = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False)
    ]
    if apply_batchnorm:
        layer_stack.append(tf.keras.layers.BatchNormalization())
    layer_stack.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(layer_stack)

def buildDiscriminator():
    """Build the patch discriminator.

    Takes an ``input_image`` and a ``target_image`` (each 256x256x3) and
    concatenates them along channels. The stack returns a (None, 30, 30, 1)
    map of raw scores (the last Conv2D has no activation), one score per patch.

    Returns:
        tf.keras.Model with inputs ``[input_image, target_image]``.
    """
    init = tf.random_normal_initializer(0., 0.02)

    source = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    candidate = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

    h = tf.keras.layers.concatenate([source, candidate])  # (none, 256, 256, channels*2)

    # (filters, apply_batchnorm) -> (none,128,128,64), (none,64,64,128), (none,32,32,256)
    for filters, use_bn in ((64, False), (128, True), (256, True)):
        h = downs(filters, 4, use_bn)(h)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (none, 34, 34, 256)
    h = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=init,
                               use_bias=False)(h)  # (none, 31, 31, 512)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)
    h = tf.keras.layers.ZeroPadding2D()(h)  # (none, 33, 33, 512)

    patch_scores = tf.keras.layers.Conv2D(1, 4, strides=1,
                                          kernel_initializer=init)(h)  # (none, 30, 30, 1)

    return tf.keras.Model(inputs=[source, candidate], outputs=patch_scores)
  
# Build the discriminator once; train_step refers to this module-level `discriminator`.
discriminator = buildDiscriminator()
discriminator.summary()

In [None]:
# Optimizers: one Adam per network, learning rate 2e-4 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
# Displaying
def generate_images(model, test_input, tar):
    """Plot input, target and model output for the first example of a batch.

    Images are mapped from [-1, 1] to [0, 1] for display (``x / 2 + 0.5``).

    The model is called with ``training=True``, so Dropout and BatchNormalization
    run in training mode for this preview.
    NOTE(review): in Keras this also updates BatchNormalization's moving
    statistics; confirm this is intended.

    Args:
        model: generator to run on ``test_input``.
        test_input: batch of input images.
        tar: batch of matching target images.
    """
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    panels = [
        (test_input[0], 'Input image'),
        (tar[0], 'Target image'),
        (prediction[0], 'Generating image'),
    ]
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        plt.imshow(picture / 2 + 0.5)
        plt.axis('off')

    plt.show()

In [None]:
import numpy as np
# Per-step loss histories. train_func / val_func append to these, and show_graph plots them.
list_gen_target_loss, list_gen_generate_loss, list_disc_target_loss = [], [], []

list_val_gen_target_loss, list_val_gen_generate_loss, list_val_disc_target_loss = [], [], []

In [None]:
def show_graph():
    """Plot the training and validation loss histories in three side-by-side panels.

    Panels show, left to right: the generator's total loss, its adversarial
    loss, and the discriminator loss, each against the step index.
    """
    fig = plt.figure(figsize=(15,4.5))

    panels = (
        ('Gen_target_loss', list_gen_target_loss, list_val_gen_target_loss),
        ('Gen_generate_loss', list_gen_generate_loss, list_val_gen_generate_loss),
        ('Disc_loss', list_disc_target_loss, list_val_disc_target_loss),
    )
    for position, (heading, train_history, val_history) in enumerate(panels, start=1):
        fig.add_subplot(1, 3, position)
        plt.title(heading)
        plt.xlabel('step')
        plt.plot(train_history, label='Training')
        plt.plot(val_history, label='Val')
        plt.legend()

    fig.tight_layout(pad=3.0)
    plt.show()

def train_func(gen_target_loss, gen_generate_loss, disc_target_loss):
    """Append one training step's losses to the global training histories.

    Called through ``tf.numpy_function``, so it receives NumPy values. It
    returns True to satisfy the ``tf.bool`` output type declared there.
    """
    pairs = (
        (list_gen_target_loss, gen_target_loss),
        (list_gen_generate_loss, gen_generate_loss),
        (list_disc_target_loss, disc_target_loss),
    )
    for history, value in pairs:
        history.append(value)

    return True
    
def tf_function_gan(gen_target_loss, gen_generate_loss, disc_target_loss):
    """Record training losses by running ``train_func`` via ``tf.numpy_function``.

    Wrapping in ``numpy_function`` lets the Python-side append also run if the
    caller is traced with ``tf.function``. The op's result is discarded.
    """
    tf.numpy_function(func=train_func,
                      inp=[gen_target_loss, gen_generate_loss, disc_target_loss],
                      Tout=tf.bool)


def val_func(val_gen_target_loss, val_gen_generate_loss,val_disc_target_loss):
    """Append one validation evaluation's losses to the global validation histories.

    Called through ``tf.numpy_function``. It returns True for the declared
    ``tf.bool`` output.
    """
    pairs = (
        (list_val_gen_target_loss, val_gen_target_loss),
        (list_val_gen_generate_loss, val_gen_generate_loss),
        (list_val_disc_target_loss, val_disc_target_loss),
    )
    for history, value in pairs:
        history.append(value)

    return True
    
def tf_function_val(val_gen_target_loss, val_gen_generate_loss,val_disc_target_loss):
    """Record validation losses by running ``val_func`` via ``tf.numpy_function`` (result discarded)."""
    tf.numpy_function(func=val_func,
                      inp=[val_gen_target_loss, val_gen_generate_loss, val_disc_target_loss],
                      Tout=tf.bool)


# Render the loss panels once. In a linear run the histories are still empty here,
# so this only checks that plotting works.
show_graph()

In [None]:
# Loss functions for the models.
# from_logits=True: the discriminator's final Conv2D has no activation, so it outputs raw scores.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator's losses.

    Args:
        disc_generated_output: discriminator scores for the generated images.
        gen_output: generated images.
        target: ground-truth images.

    Returns:
        ``(total_loss, adversarial_loss)``. ``adversarial_loss`` is the BCE of the
        discriminator scores against all-ones labels. ``total_loss`` adds
        ``LAMBDA`` times the mean absolute error between ``target`` and ``gen_output``.
    """
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    mean_abs_error = tf.reduce_mean(tf.abs(target - gen_output))
    total = adversarial + (LAMBDA * mean_abs_error)

    return total, adversarial


def discriminator_loss(disc_target_output, disc_generated_output):
    """Return the discriminator loss.

    The loss is BCE(real scores vs. ones) plus BCE(generated scores vs. zeros).
    """
    real_term = loss_object(tf.ones_like(disc_target_output), disc_target_output)
    fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    return real_term + fake_term

In [None]:
# Train Step

def _record_validation_losses(test_ds):
    """Compute generator/discriminator losses on one batch of ``test_ds`` and log them via tf_function_val."""
    for val_input, val_target in test_ds.take(1) :
        val_fake = generator(val_input)
        val_real_scores = discriminator([val_input, val_target])
        val_fake_scores = discriminator([val_input, val_fake])

        val_gen_total, val_gen_adversarial = generator_loss(val_fake_scores, val_fake, val_target)
        val_disc = discriminator_loss(val_real_scores, val_fake_scores)

        tf_function_val(val_gen_total, val_gen_adversarial, val_disc)


def train_step(input_image, target, epoch, test_ds):
    """Run one optimisation step for both networks and log training/validation losses.

    Args:
        input_image: batch of model inputs.
        target: batch of matching ground-truth images.
        epoch: current epoch index. It is not used in the body; the parameter is
            kept for the caller's signature.
        test_ds: dataset from which one batch is taken to compute validation losses.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake = generator(input_image, training=True)

        real_scores = discriminator([input_image, target], training=True)
        fake_scores = discriminator([input_image, fake], training=True)

        gen_total_loss, gen_adversarial_loss = generator_loss(fake_scores, fake, target)
        disc_loss = discriminator_loss(real_scores, fake_scores)

    # Validation runs outside the tapes and before the weights are updated below.
    _record_validation_losses(test_ds)
    tf_function_gan(gen_total_loss, gen_adversarial_loss, disc_loss)

    gen_grads = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    

In [None]:
from tqdm import tqdm
def fit(train_ds, epochs, test_ds):
    """Train the generator and discriminator for ``epochs`` passes over ``train_ds``.

    Before each epoch, and once more after the last one, a batch from
    ``test_ds`` is previewed with ``generate_images`` and the loss curves
    collected so far are plotted with ``show_graph``.

    Args:
        train_ds: dataset of ``(input_image, target)`` training batches.
        epochs: number of epochs to run.
        test_ds: dataset of ``(input_image, target)`` batches, used for previews
            and (inside ``train_step``) for the validation losses.
    """
    def _preview():
        # Show one test batch plus the current loss curves.
        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)
            show_graph()

    for epoch in range(epochs):
        start = time.time()

        _preview()
        # Epochs are reported 1-based everywhere, matching the timing line below
        # (previously this printed 0-based while the timing line printed 1-based).
        print("Epoch: ", epoch + 1)

        # tqdm replaces tf.print(n+1, end=''), which ran the step numbers together ("12345...").
        for input_image, target in tqdm(train_ds, desc='Epoch {}'.format(epoch + 1)):
            train_step(input_image, target, epoch, test_ds)

        print ('Time taken for epoch {} is {:.2f} sec\n'.format(epoch + 1, time.time()-start))

    # Previews happen at the start of each epoch, so show the final trained state too.
    _preview()

In [None]:
# Launch training: EPOCHS passes over train_dataset, with previews and validation losses from test_dataset.
fit(train_dataset, EPOCHS, test_dataset)