# In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
import tensorflow.keras.models as km
import tensorflow.keras.layers as kl
import tensorflow_addons as tfa
import time

def load_data(name='ukiyoe2photo'):
    """Load the four splits of a `cycle_gan/<name>` TensorFlow Datasets entry.

    Args:
        name: Suffix appended to ``'cycle_gan/'`` to form the dataset name.

    Returns:
        A 4-tuple ``(trainA, trainB, testA, testB)`` taken from the mapping
        returned by ``tfds.load``.
    """
    # `with_info=True` makes tfds.load return (splits, info); the info object
    # is not used here.
    splits, _info = tfds.load('cycle_gan/' + name, with_info=True, as_supervised=True)
    return tuple(splits[key] for key in ('trainA', 'trainB', 'testA', 'testB'))

def generate_and_save_images(model, epoch, test_input,moldename):
    """Run `model` on `test_input`, then plot, save and show the first result.

    The image is written to ``image_at_epoch_<epoch:04d>_<moldename>.png`` in
    the current working directory.

    Args:
        model: Callable accepting ``(test_input, training=False)``.
            Assumed to output pixel values in [-1, 1], as the models built by
            ``generator()`` do through their final ``tanh``.
        epoch: Integer used in the output file name.
        test_input: Batch passed to the model; only the first prediction is
            plotted.
        moldename: Tag appended to the output file name.
    """
    # `training=False` so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(8, 8))
    # Map [-1, 1] back to [0, 1] before plotting, exactly as generate_images
    # does; otherwise imshow clips every negative value to black.
    plt.imshow(predictions[0] * 0.5 + 0.5)
    plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}_{}.png'.format(epoch, moldename))
    plt.show()
    # This function is called every epoch; release the figure so they do not
    # accumulate in memory.
    plt.close(fig)

def generate_images(model, test_input):
    """Show the first input image and the model's prediction side by side.

    Args:
        model: Callable applied to `test_input`.
        test_input: Batch of images; only element 0 is displayed.
    """
    prediction = model(test_input)

    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )

    plt.figure(figsize=(12, 12))
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        # Shift the pixel values into [0, 1] so they can be plotted.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

def normalize(image):
    """Cast `image` to float32 and rescale it with ``x / 127.5 - 1``.

    For inputs in [0, 255] this yields values in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def preprocess_image(image,label):
    """Dataset map function: return the normalized image and drop `label`.

    Args:
        image: Image tensor handed to ``normalize``.
        label: Second element of the dataset pair; ignored.
    """
    return normalize(image)

def normes(norm):
    """Return a zero-argument factory for the requested normalization layer.

    Callers use the result as ``normes(norm)()(x)``: the first call builds a
    layer (or an identity callable), the second applies it.

    Args:
        norm: One of ``'none'``, ``'batch_norm'``, ``'instance_norm'`` or
            ``'layer_norm'``.

    Returns:
        A callable that, when called with no arguments, produces the
        normalization callable.

    Raises:
        ValueError: If `norm` is not one of the supported names. Previously
            this case silently returned ``None`` and only failed later with
            an opaque "'NoneType' object is not callable".
    """
    if norm == 'none':
        # Factory of an identity function, so call sites need no special case.
        return lambda: lambda x: x
    if norm == 'batch_norm':
        return kl.BatchNormalization
    if norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    if norm == 'layer_norm':
        return kl.LayerNormalization
    raise ValueError(
        "Unknown norm {!r}; expected 'none', 'batch_norm', "
        "'instance_norm' or 'layer_norm'".format(norm)
    )

# Define Generator architecture
def generator(input_shape=(256, 256, 3), output_channels=3, dim=64, n_downsamplings=2, n_blocks=9, norm='instance_norm'):
    """Build the ResNet-style generator as a Keras functional model.

    Layout: 7x7 stem convolution, `n_downsamplings` stride-2 convolutions
    (doubling the filter count each time), `n_blocks` residual blocks, the
    same number of stride-2 transposed convolutions (halving the filter count
    each time), and a final 7x7 convolution followed by ``tanh``.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        output_channels: Number of channels of the generated image.
        dim: Number of filters of the stem convolution.
        n_downsamplings: Number of stride-2 down- and up-sampling stages.
        n_blocks: Number of residual blocks.
        norm: Normalization name understood by ``normes``.

    Returns:
        A ``tf.keras`` model mapping the input image to the ``tanh`` output.
    """
    make_norm = normes(norm)

    def reflect_pad(tensor, amount):
        # Reflection-pad the two middle axes only; batch and channels untouched.
        return tf.pad(tensor, [[0, 0], [amount, amount], [amount, amount], [0, 0]], mode='REFLECT')

    def resnet_block(x):
        # Two (pad -> 3x3 conv -> norm) stages with a ReLU after the first
        # only, added back onto the block input.
        channels = x.shape[-1]
        out = x
        for with_relu in (True, False):
            out = reflect_pad(out, 1)
            out = kl.Conv2D(channels, 3, padding='valid', use_bias=False)(out)
            out = make_norm()(out)
            if with_relu:
                out = tf.nn.relu(out)
        return kl.add([x, out])

    # 0: input
    inputs = tf.keras.Input(shape=input_shape)

    # 1: stem convolution with `dim` filters
    net = reflect_pad(inputs, 3)
    net = kl.Conv2D(dim, 7, padding='valid', use_bias=False)(net)
    net = make_norm()(net)
    net = tf.nn.relu(net)

    # 2: downsampling convolutions, doubling the filter count at each stage
    filters = dim
    for _ in range(n_downsamplings):
        filters *= 2
        net = kl.Conv2D(filters, 3, strides=2, padding='same', use_bias=False)(net)
        net = make_norm()(net)
        net = tf.nn.relu(net)

    # 3: residual blocks
    for _ in range(n_blocks):
        net = resnet_block(net)

    # 4: transposed convolutions, halving the filter count at each stage
    for _ in range(n_downsamplings):
        filters //= 2
        net = kl.Conv2DTranspose(filters, 3, strides=2, padding='same', use_bias=False)(net)
        net = make_norm()(net)
        net = tf.nn.relu(net)

    # 5: output convolution with `output_channels` filters, squashed by tanh
    net = reflect_pad(net, 3)
    net = kl.Conv2D(output_channels, 7, padding='valid')(net)
    net = tf.tanh(net)

    return km.Model(inputs=inputs, outputs=net)

def discriminator(input_shape=(256, 256, 3), dim=64, n_downsamplings=3, norm='instance_norm'):
    """Build the convolutional discriminator as a Keras functional model.

    Layout: one stride-2 4x4 convolution with leaky ReLU and no normalization,
    ``n_downsamplings - 1`` normalized stride-2 convolutions, one normalized
    stride-1 convolution, and a final single-filter convolution without
    activation. The filter count doubles at each normalized stage and is
    capped at ``8 * dim``.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        dim: Number of filters of the first convolution.
        n_downsamplings: Total number of stride-2 convolutions.
        norm: Normalization name understood by ``normes``.

    Returns:
        A ``tf.keras`` model mapping an image to a one-channel score map.
    """
    make_norm = normes(norm)
    max_filters = dim * 8
    filters = dim

    # 0: input
    inputs = tf.keras.Input(shape=input_shape)

    # 1: first stride-2 convolution (no normalization), then the remaining
    #    stride-2 stages
    net = kl.Conv2D(filters, 4, strides=2, padding='same')(inputs)
    net = tf.nn.leaky_relu(net, alpha=0.2)

    # 2: the final entry of this list is the stride-1 stage
    stage_strides = [2] * (n_downsamplings - 1) + [1]
    for stride in stage_strides:
        filters = min(filters * 2, max_filters)
        net = kl.Conv2D(filters, 4, strides=stride, padding='same', use_bias=False)(net)
        net = make_norm()(net)
        net = tf.nn.leaky_relu(net, alpha=0.2)

    # 3: single-filter output convolution, no activation
    net = kl.Conv2D(1, 4, strides=1, padding='same')(net)

    return km.Model(inputs=inputs, outputs=net)

# Define losses
# Weight applied to the cycle-consistency loss in cycle_loss.
# NOTE(review): identity_loss uses a fixed 0.5 factor and does not read LAMBDA.
LAMBDA = 10

# Generator loss
def gen_loss(generated):
    """Adversarial loss for a generator.

    Binary cross-entropy (on logits) between an all-ones target and the
    discriminator output for generated images, i.e. the generator is rewarded
    when its output is scored as real.

    Args:
        generated: Discriminator logits for generated images.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_target = tf.ones_like(generated)
    return bce(real_target, generated)

# Discriminator Loss
def disc_loss(real, generated):
    """Adversarial loss for a discriminator.

    Averages two binary cross-entropy terms (on logits): real images against
    an all-ones target and generated images against an all-zeros target.

    Args:
        real: Discriminator logits for real images.
        generated: Discriminator logits for generated images.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # Real images should be scored 1 ...
    loss_on_real = bce(tf.ones_like(real), real)
    # ... and generated images 0.
    loss_on_generated = bce(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) / 2

# Cycle loss
def cycle_loss(real_image, cycled_image):
    """Mean absolute difference between an image and its cycled version, times LAMBDA.

    Args:
        real_image: Original image.
        cycled_image: Image after passing through both generators.
    """
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_diff

# Identity loss
def identity_loss(real_image, same_image):
    """Half of the mean absolute difference between the two images.

    Args:
        real_image: Reference image.
        same_image: Generator output compared with `real_image`.
    """
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    # The weight is the constant 1 * 0.5; LAMBDA is not applied here.
    return 1 * 0.5 * mean_abs_diff

@tf.function
def train_step(real_a, real_b, gen_a, gen_b, disc_a, disc_b, gen_a_opt, gen_b_opt, disc_a_opt, disc_b_opt):
    """Run one optimization step of both generators and both discriminators.

    As used here, `gen_a` translates domain A to domain B and `gen_b`
    translates domain B to domain A; `disc_a` scores domain-A images and
    `disc_b` scores domain-B images.

    Args:
        real_a: Batch of real images from domain A.
        real_b: Batch of real images from domain B.
        gen_a: Generator A -> B.
        gen_b: Generator B -> A.
        disc_a: Discriminator for domain A.
        disc_b: Discriminator for domain B.
        gen_a_opt: Optimizer for `gen_a`.
        gen_b_opt: Optimizer for `gen_b`.
        disc_a_opt: Optimizer for `disc_a`.
        disc_b_opt: Optimizer for `disc_b`.

    Returns:
        None. The four models are updated in place through their optimizers.
    """
    # persistent is set to True because the tape is used more than
    # once to calculate the gradients.
    with tf.GradientTape(persistent=True) as tape:
        # Generators: translate, then translate back to close the cycle.
        fake_b = gen_a(real_a, training=True)
        cycled_a = gen_b(fake_b, training=True)
        fake_a = gen_b(real_b, training=True)
        cycled_b = gen_a(fake_a, training=True)

        # Identity mapping: a generator fed an image that already belongs to
        # its OUTPUT domain should leave it unchanged. gen_b outputs domain A,
        # so it receives real_a; gen_a outputs domain B, so it receives real_b.
        # (Feeding gen_a with real_a would penalize the A -> B translation
        # itself and fight the adversarial loss.)
        same_a = gen_b(real_a, training=True)
        same_b = gen_a(real_b, training=True)

        # Discriminators
        disc_real_a = disc_a(real_a, training=True)
        disc_real_b = disc_b(real_b, training=True)
        disc_fake_a = disc_a(fake_a, training=True)
        disc_fake_b = disc_b(fake_b, training=True)

        # Adversarial losses: each generator is judged by the discriminator
        # of the domain it produces.
        gen_a_loss = gen_loss(disc_fake_b)
        gen_b_loss = gen_loss(disc_fake_a)

        total_cycle_loss = cycle_loss(real_a, cycled_a) + cycle_loss(real_b, cycled_b)

        # Total generator loss = adversarial loss + cycle loss + identity loss,
        # where the identity term is the one computed from that generator's
        # own output (same_b for gen_a, same_a for gen_b).
        total_gen_a_loss = gen_a_loss + total_cycle_loss + identity_loss(real_b, same_b)
        total_gen_b_loss = gen_b_loss + total_cycle_loss + identity_loss(real_a, same_a)

        disc_a_loss = disc_loss(disc_real_a, disc_fake_a)
        disc_b_loss = disc_loss(disc_real_b, disc_fake_b)

    # Calculate the gradients for generators and discriminators
    generator_b_gradients = tape.gradient(total_gen_b_loss, gen_b.trainable_variables)
    generator_a_gradients = tape.gradient(total_gen_a_loss, gen_a.trainable_variables)

    discriminator_a_gradients = tape.gradient(disc_a_loss, disc_a.trainable_variables)
    discriminator_b_gradients = tape.gradient(disc_b_loss, disc_b.trainable_variables)

    # Apply the gradients through the matching optimizers
    gen_b_opt.apply_gradients(zip(generator_b_gradients, gen_b.trainable_variables))
    gen_a_opt.apply_gradients(zip(generator_a_gradients, gen_a.trainable_variables))
    disc_a_opt.apply_gradients(zip(discriminator_a_gradients, disc_a.trainable_variables))
    disc_b_opt.apply_gradients(zip(discriminator_b_gradients, disc_b.trainable_variables))

# Set Global Variables
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Load data
train_a, train_b, test_a, test_b = load_data(name='ukiyoe2photo')

# Transform data: normalize, cache the normalized images, shuffle, then batch.
# BATCH_SIZE is used instead of a literal so the constant above takes effect.
train_a = train_a.map(preprocess_image, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train_b = train_b.map(preprocess_image, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_a = test_a.map(preprocess_image, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_b = test_b.map(preprocess_image, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Instantiate the networks: gen_a maps A -> B, gen_b maps B -> A (see train_step).
gen_a = generator((IMG_HEIGHT, IMG_WIDTH, 3))
gen_b = generator((IMG_HEIGHT, IMG_WIDTH, 3))
disc_a = discriminator((IMG_HEIGHT, IMG_WIDTH, 3))
disc_b = discriminator((IMG_HEIGHT, IMG_WIDTH, 3))

# Optimizers: one Adam instance per network
generator_a_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_b_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_a_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_b_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Create a checkpoint
checkpoint_path = "./checkpoints"

# NOTE: the keyword names below (generator_g, generator_f, ...) do not match
# the variable names, but they are deliberately left unchanged so that
# checkpoints already written under ./checkpoints can still be restored.
ckpt = tf.train.Checkpoint(generator_g=gen_a,
                           generator_f=gen_b,
                           discriminator_x=disc_a,
                           discriminator_y=disc_b,
                           generator_g_optimizer=generator_a_optimizer,
                           generator_f_optimizer=generator_b_optimizer,
                           discriminator_x_optimizer=discriminator_a_optimizer,
                           discriminator_y_optimizer=discriminator_b_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')

# Draw one fixed sample batch per domain for the per-epoch previews
sample_a = next(iter(train_a))
sample_b = next(iter(train_b))

# Train the models
for epoch in range(10):
    start = time.time()
    paired_batches = tf.data.Dataset.zip((train_a, train_b))
    for n, (image_A, image_B) in enumerate(paired_batches):
        train_step(image_A, image_B, gen_a, gen_b, disc_a, disc_b, generator_a_optimizer,
                   generator_b_optimizer, discriminator_a_optimizer, discriminator_b_optimizer)
        # Progress marker every 10 steps
        if n % 10 == 0:
            print(n / 10, end=' ')

    # Using the same sample batches every epoch so that the progress of the
    # models is clearly visible.
    generate_and_save_images(gen_a, epoch + 1, sample_a, 'A')
    generate_and_save_images(gen_b, epoch + 1, sample_b, 'B')

    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                       time.time() - start))
#
