In [None]:
import os
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def load_and_preprocess_image(path, image_size=(64, 64)):
    """Read one image file and return it resized and scaled to [-1, 1].

    Args:
        path: file path (str or scalar string tensor, as passed by Dataset.map).
        image_size: (height, width) to resize to. The default 64x64 matches the
            discriminator's input shape and the generator's output size.

    Returns:
        float32 tensor of shape (height, width, 3) with values in [-1, 1].
    """
    raw = tf.io.read_file(path)
    # decode_image detects the format itself (JPEG, PNG, BMP, GIF), replacing the
    # JPEG-only decoder. expand_animations=False keeps the result rank-3 (first
    # GIF frame only), so it can be resized inside a tf.data map.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.resize(image, image_size)  # bilinear resize also casts to float32
    # Map [0, 255] -> [-1, 1] to match the generator's tanh output range.
    return (image - 127.5) / 127.5

In [None]:
def load_normalized_dataset(path, batch_size=8,
                            extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif')):
    """Build a shuffled, batched tf.data pipeline of normalized images from one directory.

    Every regular file directly inside `path` whose name ends (case-insensitively)
    with one of `extensions` is loaded through `load_and_preprocess_image`.

    Args:
        path: directory containing the training images (not searched recursively).
        batch_size: images per batch; the final batch may be smaller.
        extensions: accepted file suffixes. Filtering keeps stray entries such as
            `.DS_Store` or sub-directories from crashing the decoder mid-training.

    Returns:
        A `tf.data.Dataset` of image batches whose order is reshuffled on every
        pass (i.e. every epoch). Its cardinality stays known, so `len()` works.

    Raises:
        ValueError: if no matching image files are found in `path`.
    """
    accepted = tuple(ext.lower() for ext in extensions)
    # sorted() makes the file list independent of os.listdir's arbitrary order.
    image_paths = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.lower().endswith(accepted) and os.path.isfile(os.path.join(path, name))
    )
    if not image_paths:
        raise ValueError('No image files with extensions {} found in {!r}'.format(accepted, path))

    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    # Shuffle the (cheap) path strings with a full-size buffer and reshuffle every
    # epoch; a single up-front shuffle would replay the same batch order each epoch.
    dataset = dataset.shuffle(len(image_paths), reshuffle_each_iteration=True)
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    # Decode the next batch while the current training step runs.
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
scene_dataset = load_normalized_dataset("../data/Toyota/")
# Peek at a single batch instead of printing every image: listing the whole
# dataset decodes all files into memory and floods the notebook output.
sample_batch = next(iter(scene_dataset))
print('sample batch shape:', sample_batch.shape,
      '| value range:', float(tf.reduce_min(sample_batch)), 'to', float(tf.reduce_max(sample_batch)))
scene_dataset

In [None]:
def generator():
    """Build the generator network.

    Maps a 100-dimensional noise vector to a 64x64x3 image whose values lie in
    [-1, 1] (tanh output, matching how training images are scaled).
    Spatial path: dense -> 8x8 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
    """
    # Project the noise vector and reshape it into an 8x8x256 feature map.
    stack = [
        layers.Dense(units=8 * 8 * 256, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),  # NOTE: library-default slope here, unlike the 0.01 used below
        layers.Reshape((8, 8, 256)),
    ]

    # (filters, stride) per upsampling stage; with padding="same", stride 2 doubles H and W.
    for n_filters, step in ((64, 1), (128, 2), (256, 2)):
        stack.append(layers.Conv2DTranspose(filters=n_filters, kernel_size=(3, 3), strides=(step, step),
                                            padding="same", use_bias=False))
        stack.append(layers.BatchNormalization())
        stack.append(layers.LeakyReLU(alpha=0.01))

    # Final upsample to 64x64 with 3 RGB channels squashed into [-1, 1].
    stack.append(layers.Conv2DTranspose(filters=3, kernel_size=(3, 3), strides=(2, 2), padding="same",
                                        use_bias=False, activation="tanh"))
    return keras.Sequential(stack)

In [None]:
# Instantiate the generator and print its layer-by-layer summary.
gen_model = generator()
gen_model.summary()

In [None]:
def generator_loss(fake_output):
    """Generator loss: binary cross-entropy of the fakes against the "real" label.

    The generator is rewarded when the discriminator scores its samples as
    real (target 1). `fake_output` holds raw discriminator logits, hence
    from_logits=True.
    """
    real_labels = tf.ones_like(fake_output)
    bce = keras.losses.BinaryCrossentropy(from_logits=True)
    return bce(real_labels, fake_output)

In [None]:
def generator_optimizer(learning_rate=1e-4):
    """Create the Adam optimizer used to update the generator.

    Args:
        learning_rate: Adam step size; defaults to the original 1e-4.

    Returns:
        A new `keras.optimizers.Adam` instance.
    """
    # keras.optimizers matches the notebook's other `keras.*` usage;
    # tf.optimizers is only a legacy alias for the same namespace.
    return keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Generator optimizer, created once and passed into train() below.
gen_optimizer = generator_optimizer()

In [None]:
def discriminator():
    """Build the convolutional discriminator.

    Takes a 64x64x3 image and returns one raw logit per image (the final
    Dense layer has no activation).
    """
    stages = []
    # Three stride-2 conv stages, each halving H and W (64 -> 32 -> 16 -> 8).
    for idx, n_filters in enumerate((64, 128, 256)):
        # Only the first layer declares the input shape.
        extra = {"input_shape": [64, 64, 3]} if idx == 0 else {}
        stages += [
            layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(2, 2), padding='same', **extra),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(rate=0.3),
        ]

    # Collapse the feature map to a single real/fake logit.
    head = [layers.Flatten(), layers.Dense(units=1)]
    return keras.Sequential(stages + head)

In [None]:
# Instantiate the discriminator and print its layer-by-layer summary.
dis_model = discriminator()
dis_model.summary()

In [None]:
def discriminator_loss(real_output,
                       fake_output):
    """Discriminator loss on raw logits.

    Sums binary cross-entropy for real images against target 1 and for
    generated images against target 0.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.

    Returns:
        The real-image term plus the fake-image term.
    """
    bce = keras.losses.BinaryCrossentropy(from_logits=True)
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

In [None]:
def discriminator_optimizer(learning_rate=1e-4):
    """Create the Adam optimizer used to update the discriminator.

    Args:
        learning_rate: Adam step size; defaults to the original 1e-4.

    Returns:
        A new `keras.optimizers.Adam` instance.
    """
    # keras.optimizers matches the notebook's other `keras.*` usage;
    # tf.optimizers is only a legacy alias for the same namespace.
    return keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Discriminator optimizer, created once and passed into train() below.
dis_optimizer = discriminator_optimizer()

In [None]:
def train_step(images, gen_model, gen_optimizer, dis_model, dis_optimizer, gen_loss_metric, dis_loss_metric,
               latent_dim=100):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Args:
        images: batch of real images (leading axis is the batch).
        gen_model, dis_model: generator and discriminator Keras models.
        gen_optimizer, dis_optimizer: optimizers applied to each model's
            trainable variables.
        gen_loss_metric, dis_loss_metric: callables (e.g. keras.metrics.Mean)
            that are fed this step's generator / discriminator loss.
        latent_dim: length of each noise vector; must match the generator's
            input size (100 for `generator()`).
    """
    # One noise vector per real image. A fixed count would not match the final,
    # smaller batch when the dataset size is not a multiple of the batch size.
    noise = tf.random.normal([tf.shape(images)[0], latent_dim])

    # Separate tapes so each loss is differentiated only w.r.t. its own network.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = gen_model(noise, training=True)

        real_output = dis_model(images, training=True)
        fake_output = dis_model(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, gen_model.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, dis_model.trainable_variables)

    gen_optimizer.apply_gradients(zip(gradients_of_generator, gen_model.trainable_variables))
    dis_optimizer.apply_gradients(zip(gradients_of_discriminator, dis_model.trainable_variables))

    gen_loss_metric(gen_loss)
    dis_loss_metric(disc_loss)



In [None]:
def test_step(real_images, gen_model, dis_model, latent_dim=100):
    """Measure discriminator accuracy on a real batch and an equally sized fake batch.

    Both models run in inference mode, and fresh random noise is drawn on every call.

    Args:
        real_images: batch of real images (leading axis is the batch).
        gen_model, dis_model: generator and discriminator Keras models.
        latent_dim: length of each noise vector; must match the generator's
            input size (100 for `generator()`).

    Returns:
        (real_dis_acc, fake_dis_acc, combined_dis_acc) as Python floats in [0, 1]:
        the fraction of real images scored as real, the fraction of fakes scored
        as fake, and their unweighted mean.
    """
    # Generate as many fakes as there are real images, so the two accuracies
    # are computed over equal-sized samples even for a short final batch.
    noise = tf.random.normal([tf.shape(real_images)[0], latent_dim])
    fake_images = gen_model(noise, training=False)

    real_logits = dis_model(real_images, training=False)
    fake_logits = dis_model(fake_images, training=False)

    # A logit >= 0 corresponds to sigmoid >= 0.5, i.e. a "real" verdict.
    real_dis_acc = float(tf.reduce_mean(tf.cast(real_logits >= 0.0, tf.float64)))
    fake_dis_acc = float(tf.reduce_mean(tf.cast(fake_logits < 0.0, tf.float64)))
    combined_dis_acc = (real_dis_acc + fake_dis_acc) / 2

    return real_dis_acc, fake_dis_acc, combined_dis_acc

In [None]:
def loss_and_accuracy(gen_loss_metric, epoch, dis_loss_metric, real_dis_acc, fake_dis_acc, combined_dis_acc):
    """Print the current loss metrics and discriminator accuracies, then reset both metrics.

    Args:
        gen_loss_metric, dis_loss_metric: metric objects exposing `result()` and
            `reset_state()` (or the legacy `reset_states()`).
        epoch: current epoch index (accepted for the caller's convenience; not printed).
        real_dis_acc, fake_dis_acc, combined_dis_acc: accuracies to report.
    """
    print('Loss')
    print('Generator:{}'.format(gen_loss_metric.result()))
    print('Discriminator:{}'.format(dis_loss_metric.result()))
    print('Accuracy')
    print('Real Discriminator:{}'.format(real_dis_acc))
    print('Fake Discriminator:{}'.format(fake_dis_acc))
    print('Combined Discriminator:{}'.format(combined_dis_acc))
    for metric in (gen_loss_metric, dis_loss_metric):
        # reset_state() is the current Keras name (reset_states() is deprecated
        # and may be missing in Keras 3); fall back for very old Keras versions.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

In [None]:
def plot_graph(gen_loss_values, dis_loss_values, real_disc_acc_values, fake_disc_acc_values, combined_disc_acc_values,
               out_dir='../res/Toyota2/metrics', x_label='Epoch'):
    """Plot, save and display one curve per training metric.

    Args:
        gen_loss_values, dis_loss_values: loss per recorded step.
        real_disc_acc_values, fake_disc_acc_values, combined_disc_acc_values:
            discriminator accuracies per recorded step.
        out_dir: directory for the PNG files (created if missing).
        x_label: x-axis label. train() records one value per epoch, hence 'Epoch'.
    """
    os.makedirs(out_dir, exist_ok=True)
    # (values, y-axis label, title, file name) for each figure, in the original order.
    panels = (
        (gen_loss_values, 'Loss', 'Generator loss', 'generator_loss.png'),
        (dis_loss_values, 'Loss', 'Discriminator loss', 'discriminator_loss.png'),
        (real_disc_acc_values, 'Accuracy', 'Real discriminator accuracy', 'real_accuracy.png'),
        (fake_disc_acc_values, 'Accuracy', 'Fake discriminator accuracy', 'fake_accuracy.png'),
        (combined_disc_acc_values, 'Accuracy', 'Combined discriminator accuracy', 'combined_accuracy.png'),
    )
    for values, y_label, title, filename in panels:
        fig, ax = plt.subplots(figsize=(14, 4))
        ax.plot(values, marker='o', linestyle='-', color='b')
        ax.set(xlabel=x_label, ylabel=y_label, title=title)
        fig.savefig(os.path.join(out_dir, filename))
        plt.show()
        # Release the figure; many open figures accumulate memory in a long session.
        plt.close(fig)

In [None]:
def generate_and_save_images(model, epoch, test_input=None, out_dir='../res/Toyota2'):
    """Render a 3x3 grid of generator samples, save it as a PNG, and display it.

    Args:
        model: generator to sample from.
        epoch: number embedded in the file name (image_at_epoch_XXXX.png).
        test_input: optional noise of shape (9, 100). Pass the same tensor on
            every call to watch the same latent points evolve across epochs;
            if None, fresh random noise is drawn.
        out_dir: directory for the PNG (created if missing).
    """
    if test_input is None:
        test_input = tf.random.normal([9, 100])
    # Use the `model` argument; the global `gen_model` would silently ignore it.
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(3, 3))
    for i in range(predictions.shape[0]):
        plt.subplot(3, 3, i + 1)
        # Undo the [-1, 1] scaling; clip guards against float rounding past the ends.
        img = np.clip(np.array(predictions[i]) * 127.5 + 127.5, 0, 255).astype(np.uint8)
        plt.imshow(img)
        plt.axis('off')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    plt.close(fig)

In [None]:
def train(real_image_dataset,  epochs, gen_model, gen_optimizer, dis_model, dis_optimizer):
    """Train the GAN for `epochs` passes over the dataset, then plot the metric curves.

    Each batch gets one train_step (updating both networks) and one test_step
    (discriminator accuracy on that batch plus fresh fakes). Losses and
    accuracies are averaged over the epoch's batches, reported once per epoch
    via loss_and_accuracy, and recorded for plot_graph.

    Args:
        real_image_dataset: iterable of real-image batches (e.g. a tf.data.Dataset).
        epochs: number of passes over the dataset.
        gen_model, gen_optimizer: generator model and its optimizer.
        dis_model, dis_optimizer: discriminator model and its optimizer.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    gen_loss_values = []
    dis_loss_values = []
    real_disc_acc_values = []
    fake_disc_acc_values = []
    combined_disc_acc_values = []

    # Running means of the per-batch losses within one epoch; reset by loss_and_accuracy.
    gen_loss_metric = keras.metrics.Mean('gen_loss', dtype=tf.float32)
    dis_loss_metric = keras.metrics.Mean('dis_loss', dtype=tf.float32)
    for epoch in range(epochs):
        print('Epoch {}'.format(epoch))
        real_acc = fake_acc = comb_acc = 0.0
        num_batches = 0  # counted directly, so datasets of unknown length also work
        for image_batch in real_image_dataset:
            train_step(image_batch, gen_model, gen_optimizer, dis_model, dis_optimizer,
                       gen_loss_metric, dis_loss_metric)
            real_dis_acc, fake_dis_acc, combined_dis_acc = test_step(image_batch, gen_model, dis_model)
            real_acc += real_dis_acc
            fake_acc += fake_dis_acc
            comb_acc += combined_dis_acc
            num_batches += 1
        if num_batches == 0:
            raise ValueError('real_image_dataset produced no batches')

        # Read the epoch means BEFORE loss_and_accuracy resets the metrics;
        # reading them after a reset yields 0.
        gen_loss_values.append(float(gen_loss_metric.result()))
        dis_loss_values.append(float(dis_loss_metric.result()))
        real_disc_acc_values.append(real_acc / num_batches)
        fake_disc_acc_values.append(fake_acc / num_batches)
        combined_disc_acc_values.append(comb_acc / num_batches)
        loss_and_accuracy(gen_loss_metric, epoch, dis_loss_metric,
                          real_disc_acc_values[-1], fake_disc_acc_values[-1], combined_disc_acc_values[-1])
        #generate_and_save_images(gen_model, epoch + 1)
    plot_graph(gen_loss_values, dis_loss_values, real_disc_acc_values, fake_disc_acc_values, combined_disc_acc_values)
    

In [None]:
# Full training run: 5000 passes over scene_dataset; metric curves are plotted at the end.
train(
    real_image_dataset=scene_dataset,
    epochs=5000,
    gen_model=gen_model,
    gen_optimizer=gen_optimizer,
    dis_model=dis_model,
    dis_optimizer=dis_optimizer,
)

In [None]:
# import imageio
# import glob
# import PIL
# def display_image(epoch_no):
#   return PIL.Image.open('../res/Toyota2/image_at_epoch_{:04d}.png'.format(epoch_no))

In [None]:
#display_image(2)

In [None]:
# anim_file = '../res/Toyota2/gan.gif'

# with imageio.get_writer(anim_file, mode='I') as writer:
#   filenames = glob.glob('../res/Toyota2/image*.png')
#   filenames = sorted(filenames)
#   for filename in filenames:
#     image = imageio.v2.imread(filename)
#     writer.append_data(image)
#   image = imageio.v2.imread(filename)
#   writer.append_data(image)

In [None]:
# from PIL import Image
# im = Image.open('../res/Toyota2/gan.gif')
# im