In [None]:
import tensorflow as tf

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import datetime
import matplotlib.pyplot as plt
from IPython.display import clear_output
import numpy as np
import pathlib
import random
import cv2

import resnet_network
import image_preprocess
import network_losses

In [None]:
# --- Data pipeline settings -------------------------------------------------
BUFFER_SIZE = 1000   # shuffle buffer size used by the dataset pipelines
BATCH_SIZE = 2       # images per batch
IMG_WIDTH = 512      # not referenced in the visible cells; presumably the size
IMG_HEIGHT = 512     # produced by image_preprocess.load_image -- TODO confirm
EPOCHS = 5           # number of passes over the training data

# --- Fake-image pool settings -----------------------------------------------
# Number of generated batches collected before the pool starts being sampled.
STEPS_TO_FILL_POOL = 10
# Pool capacity expressed in images (batches * images per batch).
FAKE_POOL_SIZE = BATCH_SIZE*STEPS_TO_FILL_POOL
# Once the pool is full, a training step draws its discriminator fakes from the
# pool with this probability; otherwise it uses freshly generated images.
FAKE_POOL_PROBABILITY = 0.8

AUTOTUNE = tf.data.AUTOTUNE

# --- Paths (placeholders: replace with real directories) --------------------
load_checkpoint_path = 'directory_for_loading_checkpoints'
checkpoint_path = 'directory_for_saving_checkpoints'
log_dir = 'directory_for_saving_logs'
PATH = pathlib.Path(r'data_directory')

# Each run logs into its own timestamped sub-directory of `log_dir`.
# os.path.join inserts the separator when `log_dir` has no trailing slash, so
# the run directory is always created *inside* `log_dir` rather than next to it.
summary_writer = tf.summary.create_file_writer(
    os.path.join(log_dir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S')))

## Load data from folder and set up the dataset pipeline

In [None]:
SEED = 2  # fixed shuffle seed shared by all four pipelines


def list_png_files(subdir):
    """Return a dataset of the ``*.png`` file paths under ``PATH/<subdir>``.

    Files are listed with ``shuffle=False``; shuffling is applied later in
    ``build_image_pipeline``.
    """
    return tf.data.Dataset.list_files(str(PATH / subdir / '*.png'), shuffle=False)


def build_image_pipeline(file_dataset):
    """Turn a dataset of file paths into shuffled batches of loaded images.

    The file list is cached, each path is passed through
    ``image_preprocess.load_image``, and the results are shuffled with
    ``BUFFER_SIZE``/``SEED``, batched by ``BATCH_SIZE`` and prefetched so that
    loading overlaps with training.
    """
    return (
        file_dataset
        .cache()
        .map(image_preprocess.load_image, num_parallel_calls=AUTOTUNE)
        .shuffle(BUFFER_SIZE, seed=SEED)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )


train_x_files = list_png_files('trainA')
train_y_files = list_png_files('trainB')
test_x_files = list_png_files('testA')
test_y_files = list_png_files('testB')

train_x = build_image_pipeline(train_x_files)
train_y = build_image_pipeline(train_y_files)
test_x = build_image_pipeline(test_x_files)
test_y = build_image_pipeline(test_y_files)

# Count files rather than batches * BATCH_SIZE: the last batch may be partial,
# so the product would over-report the number of images.
print(f'Number of training images: {len(train_x_files)}')
print(f'Number of test images: {len(test_x_files)}')

## Define the CycleGAN class and create an instance

In [None]:
class CycleGAN():
    """Two generators, two discriminators and their optimizers for CycleGAN training.

    ``generator_g`` translates X images to Y and ``generator_f`` translates Y
    images to X. ``discriminator_x`` scores X-domain images and
    ``discriminator_y`` scores Y-domain images. Every network has its own Adam
    optimizer (learning rate 2e-4, beta_1 0.5).

    ``train_step`` writes its scalar summaries to the module-level
    ``summary_writer``.
    """

    def __init__(self):
        super().__init__()

        self.generator_g = resnet_network.build_generator_resnet_9blocks(skip=False)
        self.generator_f = resnet_network.build_generator_resnet_9blocks(skip=False)

        self.discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
        self.discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

        self.generator_g_optimizer = tf.keras.optimizers.Adam(2e-04, beta_1=0.5)
        self.generator_f_optimizer = tf.keras.optimizers.Adam(2e-04, beta_1=0.5)

        self.discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-04, beta_1=0.5)
        self.discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-04, beta_1=0.5)

    def call(self, real_x, real_y, fake_pool_x_batch=None, fake_pool_y_batch=None, training=False, fake_pool=False):
        """Run one forward pass through all four networks.

        Args:
            real_x: batch of real X-domain images.
            real_y: batch of real Y-domain images.
            fake_pool_x_batch: previously generated X images; only read when
                ``fake_pool`` is True.
            fake_pool_y_batch: previously generated Y images; only read when
                ``fake_pool`` is True.
            training: forwarded as ``training=`` to every network call.
            fake_pool: when True, ``disc_fake_x``/``disc_fake_y`` are the
                discriminator outputs on the pool batches instead of on the
                freshly generated images.

        Returns:
            Tuple ``(fake_x, fake_y, cycled_x, cycled_y, disc_real_x,
            disc_real_y, disc_fake_x, disc_fake_y)``.

        Raises:
            ValueError: if ``fake_pool`` is True but a pool batch is missing.
        """
        if fake_pool and (fake_pool_x_batch is None or fake_pool_y_batch is None):
            raise ValueError('fake_pool=True requires fake_pool_x_batch and fake_pool_y_batch')

        # X -> Y -> X cycle.
        fake_y = self.generator_g(real_x, training=training)
        cycled_x = self.generator_f(fake_y, training=training)

        # Y -> X -> Y cycle.
        fake_x = self.generator_f(real_y, training=training)
        cycled_y = self.generator_g(fake_x, training=training)

        disc_real_x = self.discriminator_x(real_x, training=training)
        disc_real_y = self.discriminator_y(real_y, training=training)

        if fake_pool:
            disc_fake_x = self.discriminator_x(fake_pool_x_batch, training=training)
            disc_fake_y = self.discriminator_y(fake_pool_y_batch, training=training)
        else:
            disc_fake_x = self.discriminator_x(fake_x, training=training)
            disc_fake_y = self.discriminator_y(fake_y, training=training)

        return fake_x, fake_y, cycled_x, cycled_y, disc_real_x, disc_real_y, disc_fake_x, disc_fake_y

    def _forward_and_losses(self, real_x, real_y, fake_pool_x_batch, fake_pool_y_batch, fake_pool, training):
        """Forward pass plus all four losses; also returns the fresh fakes.

        The generator adversarial losses are always computed from the
        discriminator outputs on the *freshly generated* images, because only
        those depend on the generator weights. Pool batches are inputs that do
        not depend on the generators, so they are used solely for the
        discriminator losses.
        """
        (fake_x, fake_y, cycled_x, cycled_y,
         disc_real_x, disc_real_y,
         disc_fresh_x, disc_fresh_y) = self.call(real_x, real_y, training=training, fake_pool=False)

        gen_g_adver_loss = network_losses.generator_loss(disc_fresh_y)
        gen_f_adver_loss = network_losses.generator_loss(disc_fresh_x)

        total_cycle_loss = (network_losses.calc_cycle_loss(real_x, cycled_x)
                            + network_losses.calc_cycle_loss(real_y, cycled_y))

        gen_g_total_loss = gen_g_adver_loss + total_cycle_loss
        gen_f_total_loss = gen_f_adver_loss + total_cycle_loss

        if fake_pool:
            if fake_pool_x_batch is None or fake_pool_y_batch is None:
                raise ValueError('fake_pool=True requires fake_pool_x_batch and fake_pool_y_batch')
            disc_fake_x = self.discriminator_x(fake_pool_x_batch, training=training)
            disc_fake_y = self.discriminator_y(fake_pool_y_batch, training=training)
        else:
            disc_fake_x, disc_fake_y = disc_fresh_x, disc_fresh_y

        disc_x_loss = network_losses.discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = network_losses.discriminator_loss(disc_real_y, disc_fake_y)

        return fake_x, fake_y, gen_g_total_loss, gen_f_total_loss, disc_x_loss, disc_y_loss

    def losses(self, real_x, real_y, fake_pool_x_batch, fake_pool_y_batch, fake_pool=False, training=False):
        """Compute generator and discriminator losses for one batch pair.

        Args:
            real_x: batch of real X-domain images.
            real_y: batch of real Y-domain images.
            fake_pool_x_batch: pooled fake X images (used when ``fake_pool``).
            fake_pool_y_batch: pooled fake Y images (used when ``fake_pool``).
            fake_pool: when True the discriminator losses use the pool batches
                as fakes; the generator losses always use the fresh fakes.
            training: forwarded as ``training=`` to every network call.

        Returns:
            Tuple ``(gen_g_total_loss, gen_f_total_loss, disc_x_loss, disc_y_loss)``.
        """
        __, __, gen_g_total_loss, gen_f_total_loss, disc_x_loss, disc_y_loss = self._forward_and_losses(
            real_x, real_y, fake_pool_x_batch, fake_pool_y_batch, fake_pool, training)

        return gen_g_total_loss, gen_f_total_loss, disc_x_loss, disc_y_loss

    def generate_images(self, real_x, real_y, training=False):
        """Translate ``real_x`` to Y and ``real_y`` to X; returns ``(fake_x, fake_y)``."""
        fake_y = self.generator_g(real_x, training=training)
        fake_x = self.generator_f(real_y, training=training)

        return fake_x, fake_y

    @tf.function
    def train_step(self, real_x, real_y, step, fake_pool_x_batch=None, fake_pool_y_batch=None, fake_pool=False):
        """Run one optimization step on all four networks and log the losses.

        Args:
            real_x: batch of real X-domain images.
            real_y: batch of real Y-domain images.
            step: global step used for the scalar summaries.
            fake_pool_x_batch: pooled fake X images (used when ``fake_pool``).
            fake_pool_y_batch: pooled fake Y images (used when ``fake_pool``).
            fake_pool: when True the discriminators are trained against the
                pool batches instead of the freshly generated images.

        Returns:
            ``(fake_x, fake_y)`` generated from this batch before the weight
            update (taken from the same forward pass that produced the losses).
        """
        # The tape is persistent because four separate gradients are taken from it.
        with tf.GradientTape(persistent=True) as tape:
            (fake_x, fake_y,
             gen_g_total_loss, gen_f_total_loss,
             disc_x_loss, disc_y_loss) = self._forward_and_losses(
                 real_x, real_y, fake_pool_x_batch, fake_pool_y_batch, fake_pool, training=True)

        generator_g_gradient = tape.gradient(gen_g_total_loss, self.generator_g.trainable_variables)
        generator_f_gradient = tape.gradient(gen_f_total_loss, self.generator_f.trainable_variables)

        discriminator_x_gradient = tape.gradient(disc_x_loss, self.discriminator_x.trainable_variables)
        discriminator_y_gradient = tape.gradient(disc_y_loss, self.discriminator_y.trainable_variables)

        # A persistent tape keeps its resources until it is dropped explicitly.
        del tape

        self.generator_g_optimizer.apply_gradients(zip(generator_g_gradient, self.generator_g.trainable_variables))
        self.generator_f_optimizer.apply_gradients(zip(generator_f_gradient, self.generator_f.trainable_variables))

        self.discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradient, self.discriminator_x.trainable_variables))
        self.discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradient, self.discriminator_y.trainable_variables))

        with summary_writer.as_default():
            tf.summary.scalar('X TO Y generator: Total loss', gen_g_total_loss, step=step)
            tf.summary.scalar('Y TO X generator: Total loss', gen_f_total_loss, step=step)
            # discriminator_y judges generator_g's (X to Y) output and
            # discriminator_x judges generator_f's (Y to X) output, so each
            # loss is logged under the tag of the direction it belongs to.
            tf.summary.scalar('X TO Y: Discriminator loss', disc_y_loss, step=step)
            tf.summary.scalar('Y TO X: Discriminator loss', disc_x_loss, step=step)

        return fake_x, fake_y

In [None]:
# Build the model bundle (generators, discriminators, optimizers).
cyclegan = CycleGAN()

# History buffers of generated batches; both start empty and are filled and
# then updated in place by the training loop.
fake_pool_x, fake_pool_y = [], []

In [None]:
## Set up checkpointing and load the current model

# Create the output directory (including missing parents); harmless when it
# already exists, unlike a separate exists()/mkdir() pair.
os.makedirs(checkpoint_path, exist_ok=True)

# Track all four networks together with their optimizers so that training can
# resume from the saved state.
ckpt = tf.train.Checkpoint(generator_f = cyclegan.generator_f,
                           generator_g = cyclegan.generator_g,
                           discriminator_x = cyclegan.discriminator_x,
                           discriminator_y = cyclegan.discriminator_y,
                           generator_f_optimizer = cyclegan.generator_f_optimizer,
                           generator_g_optimizer = cyclegan.generator_g_optimizer,
                           discriminator_x_optimizer = cyclegan.discriminator_x_optimizer,
                           discriminator_y_optimizer = cyclegan.discriminator_y_optimizer)

# Manager pointed at the directory to *load* from; it is only used to look up
# the most recent checkpoint there.
load_ckpt_manager = tf.train.CheckpointManager(ckpt, load_checkpoint_path, max_to_keep=5)

if load_ckpt_manager.latest_checkpoint:
    ckpt.restore(load_ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored')
else:
    print('Train from scratch')

# Separate manager for *saving*, so new checkpoints go to `checkpoint_path`
# (keeping at most the 5 most recent) and do not touch the load directory.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

## Start training process

In [None]:
CHECKPOINT_EVERY_N_EPOCHS = 1   # save a checkpoint after every N-th epoch
PROGRESS_DOT_EVERY = 10         # print a dot every N batches
PREVIEW_EVERY = 200             # refresh the output and show samples every N batches

# The loop iterates over the zipped dataset, which stops at the shorter of the
# two inputs, so its length (not len(train_x)) is the real step count per epoch.
paired_train = tf.data.Dataset.zip((train_x, train_y))
steps_per_epoch = len(paired_train)

for epoch in range(EPOCHS):
    start = time.time()

    for n, (image_x, image_y) in paired_train.enumerate():
        # `n` stays a tensor for the global step so the tf.function inside
        # train_step is not retraced for every new Python integer.
        global_step = n + epoch*steps_per_epoch
        batch_idx = int(n)

        if len(fake_pool_x) < STEPS_TO_FILL_POOL:
            # Pool not full yet: train on fresh fakes and store them. Keying on
            # the actual pool length keeps this correct for short epochs and
            # when the cell is re-run with an already filled pool.
            fake_x, fake_y = cyclegan.train_step(image_x, image_y, global_step,
                                                 fake_pool=False)
            fake_pool_x.append(fake_x)
            fake_pool_y.append(fake_y)
        elif tf.random.uniform(shape=[]) > FAKE_POOL_PROBABILITY:
            # Occasionally train the discriminators on the fresh fakes only;
            # the pool is left untouched.
            cyclegan.train_step(image_x, image_y, global_step, fake_pool=False)
        else:
            # Train the discriminators on a random pool entry per domain, then
            # replace those entries with the fakes produced by this step.
            rd_ind_x = random.randrange(len(fake_pool_x))
            rd_ind_y = random.randrange(len(fake_pool_y))

            fake_x, fake_y = cyclegan.train_step(image_x, image_y, global_step,
                                                 fake_pool_x[rd_ind_x], fake_pool_y[rd_ind_y],
                                                 fake_pool=True)

            fake_pool_x[rd_ind_x] = fake_x
            fake_pool_y[rd_ind_y] = fake_y

        if batch_idx % PROGRESS_DOT_EVERY == 0:
            print('.', end='')

        if batch_idx % PREVIEW_EVERY == 0:
            clear_output(wait=True)
            print(f'current batch: {batch_idx} out of {steps_per_epoch} batches, epoch {epoch+1}')

            image_preprocess.generate_images(cyclegan.generator_g, image_x)
            image_preprocess.generate_images(cyclegan.generator_f, image_y)

    if (epoch + 1) % CHECKPOINT_EVERY_N_EPOCHS == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                            ckpt_save_path))

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                    time.time()-start))