In [1]:
import tensorflow
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import os
import time

In [3]:
os.chdir("C:\\Users\\kc510\\Documents\\Projects\\Project_iNeuron\\Sketch_To_Color_Image")

In [4]:
path = 'data'

epochs=100
buffer_size=6000
batch_size=4
img_width = 256
img_height = 256

In [5]:
def load(image_file):
    image = tensorflow.io.read_file(image_file)
    image = tensorflow.image.decode_png(image)

    w = tensorflow.shape(image)[1]
    print(w)

    w = w // 2
    real_image = image[:, :w, :]
    input_image = image[:, w:, :]

    input_image = tensorflow.cast(input_image, tensorflow.float32)
    real_image = tensorflow.cast(real_image, tensorflow.float32)

    return input_image, real_image

In [6]:
def resize(input_image, real_image, height, width):
    input_image = tensorflow.image.resize(input_image, [height, width],
                                          tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tensorflow.image.resize(real_image, [height, width],
                                          tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR)

    return input_image, real_image

In [7]:
def random_crop(input_image, real_image):
    stacked_image = tensorflow.stack([input_image, real_image], axis=0)
    cropped_image = tensorflow.image.random_crop(
        stacked_image, size=[2, img_height, img_width, 3])

    return cropped_image[0], cropped_image[1]

In [8]:
def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image

In [9]:
@tensorflow.function
def random_jitter(input_image,real_image):
    input_image, real_image = resize(input_image,real_image,286,286)
    input_image, real_image = random_crop(input_image,real_image)

    if tensorflow.random.uniform(())>0.5:
        input_image = tensorflow.image.flip_left_right(input_image)
        real_image = tensorflow.image.flip_left_right(real_image)
    
    return input_image, real_image

In [10]:
def load_image_train(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image,real_image)
    input_image, real_image = normalize(input_image,real_image)

    return input_image, real_image

train_dataset = tensorflow.data.Dataset.list_files(path+'\\train\\*.png')
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tensorflow.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(buffer_size=buffer_size).batch(batch_size)

Tensor("strided_slice:0", shape=(), dtype=int32)


In [11]:
def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image,real_image,img_width,img_height)
    input_image, real_image = normalize(input_image,real_image)

    return input_image, real_image

test_dataset = tensorflow.data.Dataset.list_files(path+'\\val\\*.png')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(batch_size)

Tensor("strided_slice:0", shape=(), dtype=int32)


In [12]:
OUTPUT_CHANNELS = 3

def downsample(filters,size,shape,apply_batchnorm=True):
    initializer = tensorflow.random_normal_initializer(0.,0.02)

    result = tensorflow.keras.Sequential()
    result.add(tensorflow.keras.layers.Conv2D(filters,size,strides=2,padding='same',
                                                kernel_initializer=initializer,
                                                batch_input_shape=shape,
                                                use_bias=False))
    
    if apply_batchnorm:
        result.add(tensorflow.keras.layers.BatchNormalization())
    
    result.add(tensorflow.keras.layers.LeakyReLU())

    return result


def upsmaple(filters,size,shape,apply_dropout=False):
    initializer = tensorflow.random_normal_initializer(0.,0.02)

    result = tensorflow.keras.Sequential()
    result.add(tensorflow.keras.layers.Conv2DTranspose(filters,size,strides=2,padding='same',
                                                kernel_initializer=initializer,
                                                batch_input_shape=shape,
                                                use_bias=False))
    
    result.add(tensorflow.keras.layers.BatchNormalization())

    if apply_dropout:
        result.add(tensorflow.keras.layers.Dropout(0.5))

    result.add(tensorflow.keras.layers.ReLU())

    return result

In [13]:
def buildGenerator():
    inputs = tensorflow.keras.layers.Input(shape=[256,256,3])

    down_stack = [
        downsample(64,4, (None,256,256,3), apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128,4, (None,128,128,64)),
        downsample(256,4, (None,64,64,128)),
        downsample(512,4, (None,32,32,256)),
        downsample(512, 4, (None,16,16,512)),
        downsample(512, 4, (None,8,8,512)),
        downsample(512,4, (None, 4,4,512)),
        downsample(512, 4, (None,2,2,512))
    ]

    upstack = [
        upsmaple(512,4,(None,1,1,512), apply_dropout=True),
        upsmaple(512,4,(None,2,2,1024), apply_dropout=True),
        upsmaple(512,4,(None,4,4,1024), apply_dropout=True),
        upsmaple(512,4,(None,8,8,1024)),
        upsmaple(256,4,(None,16,16,1024)),
        upsmaple(128,4,(None,32,32,512)),
        upsmaple(64,4,(None,64,64,256))
    ]

    initializer = tensorflow.random_normal_initializer(0.,0.02)
    last = tensorflow.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS,4,strides=2,padding='same',kernel_initializer=initializer,activation='tanh')

    x = inputs

    skips=[]
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    for up, skip in zip(upstack,skips):
        x = up(x)
        x = tensorflow.keras.layers.Concatenate()([x,skip])
    
    x = last(x)

    return tensorflow.keras.Model(inputs=inputs,outputs=x)

generator = buildGenerator()

In [14]:
def downs(filters,size,apply_batchnorm=False):
    initializer = tensorflow.random_normal_initializer(0.,0.02)

    result = tensorflow.keras.Sequential()
    result.add(tensorflow.keras.layers.Conv2D(filters,size,strides=2,padding='same',kernel_initializer=initializer,use_bias=False))

    if apply_batchnorm:
        result.add(tensorflow.keras.layer.BatchNormalization())

    result.add(tensorflow.keras.layers.LeakyReLU())

    return result

def buildDescriminator():
    initializer = tensorflow.random_normal_initializer(0.,0.02)

    inp = tensorflow.keras.layers.Input(shape=[256,256,3],name='input_image')
    tar = tensorflow.keras.layers.Input(shape=[256,256,3],name='target_image')

    x = tensorflow.keras.layers.concatenate([inp,tar])

    down1 = downs(64,4,False)(x)
    down2 = downs(128,4)(down1)
    down3 = downs(256,4)(down2)

    zero_pad1 = tensorflow.keras.layers.ZeroPadding2D()(down3)
    conv = tensorflow.keras.layers.Conv2D(512,4,strides=1,kernel_initializer=initializer,use_bias=False)(zero_pad1)

    batchnorm1 = tensorflow.keras.layers.BatchNormalization()(conv)

    leaky_relu = tensorflow.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tensorflow.keras.layers.ZeroPadding2D()(leaky_relu)

    last = tensorflow.keras.layers.Conv2D(1,4,strides=1,kernel_initializer=initializer)(zero_pad2)

    return tensorflow.keras.Model(inputs=[inp,tar],outputs=last)

discriminator = buildDescriminator()

In [15]:
loss_object = tensorflow.keras.losses.BinaryCrossentropy(from_logits=True)

LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tensorflow.ones_like(disc_generated_output),disc_generated_output)

    l1_loss = tensorflow.reduce_mean(tensorflow.abs(target - gen_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)

    return total_gen_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tensorflow.ones_like(disc_real_output),disc_real_output)
    generated_loss = loss_object(tensorflow.zeros_like(disc_generated_output),disc_generated_output)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss

In [16]:
generator_optimizer = tensorflow.keras.optimizers.Adam(2e-4,beta_1=0.5)
discriminator_optimizer = tensorflow.keras.optimizers.Adam(2e-4,beta_1=0.5)

In [17]:
checkpoint_dir = '.\\notebooks\\training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,"ckpt")
checkpoint = tensorflow.train.Checkpoint(generator_optimizer=generator_optimizer,
                                                                                discriminator_optimizer=discriminator_optimizer,
                                                                                generator=generator,
                                                                                discriminator=discriminator)

In [18]:
def generate_images(model,test_input,tar):
    prediction = model(test_input,training=True)
    plt.figure(figsize=(15,15))

    display_list = [test_input[0],tar[0],prediction[0]]
    title = ["Input Image", "Ground Truth","Predicted Image"]

    for i in range(3):
        plt.subplot(1,3,i+1)
        plt.title(title[i])
        plt.imshow(display_list[i]*0.5+0.5)
        plt.axis('off')
    plt.show

In [19]:
import datetime
log_dir = "notebooks\\logs\\"
summary_writer = tensorflow.summary.create_file_writer(
    log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    

In [20]:
@tensorflow.function
def train_step(input_image,target,epoch):
    with tensorflow.GradientTape() as gen_tape, tensorflow.GradientTape() as disc_tape:
        
        gen_out = generator(input_image,training=True)

        disc_real_output = discriminator([input_image,target],training=True)
        disc_generated_output = discriminator([input_image,gen_out],training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output,gen_out,target)
        disc_loss = discriminator_loss(disc_real_output,disc_generated_output)
    
    generator_gradients = gen_tape.gradient(gen_total_loss,generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,discriminator.trainable_variables))

    with summary_writer.as_default():
        tensorflow.summary.scalar('gen_total_loss',gen_total_loss,step=epoch)
        tensorflow.summary.scalar('gen_gan_loss',gen_gan_loss,step=epoch)
        tensorflow.summary.scalar('gen_l1_loss',gen_l1_loss,step=epoch)
        tensorflow.summary.scalar('disc_loss',disc_loss,step=epoch)

In [22]:
def fit(train_ds,epochs):
    for epoch in range(epochs):
        start = time.time()

        for n, (input_image,target) in train_ds.enumerate():
            print(".",end=" ")
            if (n+1)%100==0:
                print('n+1=',n+1)
            train_step(input_image,target,epoch)
        print()

        if (epoch+1)%5==0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch+1,time.time()-start))

    checkpoint.save(file_prefix=checkpoint_prefix)

In [23]:
fit(train_dataset,epochs)

. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . n+1= tf.Tensor(100, shape=(), dtype=int64)
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . n+1= tf.Tensor(200, shape=(), dtype=int64)
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . n+1= tf.Tensor(300, shape=(), dtype=int64)
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . n+1= tf.Tensor(400, shape=(), dtype=int64)
. . . . . . . . . . . . . . 