##### Copyright 2019 The TensorFlow Authors.

## Set up the input pipeline

Install the [tensorflow_examples](https://github.com/tensorflow/examples) package that enables importing of the generator and the discriminator.

In [None]:
pip install git+https://github.com/tensorflow/examples.git

In [None]:
!pip install keras

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from zipfile import ZipFile
from PIL import Image

from tensorflow import keras
from tensorflow.keras import layers
# import tensorflow_addons as tfa

AUTOTUNE = tf.data.AUTOTUNE

import numpy as np, pandas as pd, os
import matplotlib.pyplot as plt, cv2
import tensorflow as tf, re, math

## Input Pipeline

This notebook trains a model to translate between landscape photos and Van Gogh paintings, using the `cycle_gan/vangogh2photo` dataset. You can find this dataset and similar ones [here](https://www.tensorflow.org/datasets/catalog/cycle_gan).

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#load_the_dataset).

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally i.e., left to right.

In [None]:
# import os
# import random
# import tensorflow_datasets as tfds

# # Converting Kaggle dataset to TFDS format

# landscape_dir = '/kaggle/input/landscape-pictures/'
# vangogh_dir = '/kaggle/input/van-gogh-paintings/VincentVanGogh/'

# # Check if the directories exist
# if not os.path.exists(landscape_dir):
#     raise FileNotFoundError(f"Directory '{landscape_dir}' not found.")
# if not os.path.exists(vangogh_dir):
#     raise FileNotFoundError(f"Directory '{vangogh_dir}' not found.")

# train_ratio = 0.8

# landscape_images = [os.path.join(landscape_dir, filename) for filename in os.listdir(landscape_dir)]

# vangogh_images = [os.path.join(vangogh_dir, filename) for filename in os.listdir(vangogh_dir)]

# # Randomly shuffle images
# random.shuffle(landscape_images)
# random.shuffle(vangogh_images)

# # Split the images into training and testing sets
# num_train_ln = int(len(landscape_images) * train_ratio)
# num_train_vg = int(len(vangogh_images) * train_ratio)

# train_ln_files = landscape_images[:num_train_ln]
# test_ln_files = landscape_images[num_train_ln:]

# train_vg_files = vangogh_images[:num_train_vg]
# test_vg_files = vangogh_images[num_train_vg:]


In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

dataset, metadata = tfds.load(
    'cycle_gan/vangogh2photo',
    with_info=True,
    as_supervised=True,
)

# This notebook treats split "A" as Van Gogh paintings (vg) and split "B"
# as landscape photos (ln).
train_vg = dataset['trainA']
train_ln = dataset['trainB']
test_vg = dataset['testA']
test_ln = dataset['testB']


# kaggle_dataset = pd.read_csv('/kaggle/input/van-gogh-paintings/VanGoghPaintings.csv')

# # Split the Kaggle dataset into train and test sets
# train_kaggle, test_kaggle = train_test_split(kaggle_dataset, test_size=0.2, random_state=42)

# # Convert the train and test sets into TensorFlow Datasets
# tf_train_kaggle = tf.data.Dataset.from_tensor_slices(train_kaggle)
# tf_test_kaggle = tf.data.Dataset.from_tensor_slices(test_kaggle)

# # Combine the Kaggle dataset with train_vg and test_vg
# train_vg_combined = train_vg.concatenate(tf_kaggle_dataset)
# test_vg_combined = test_vg.concatenate(tf_kaggle_dataset)



# # Select only the image paths from the DataFrame
# image_paths = kaggle_dataset['image_path'].tolist()

# # Split the image paths into train and test sets
# train_paths, test_paths = train_test_split(image_paths, test_size=0.2, random_state=42)

# # Convert the train and test sets into TensorFlow Datasets
# tf_train_kaggle = tf.data.Dataset.from_tensor_slices(train_paths)
# tf_test_kaggle = tf.data.Dataset.from_tensor_slices(test_paths)

# # Combine the Kaggle dataset with train_vg and test_vg
# train_vg_combined = train_vg.concatenate(tf_train_kaggle)
# test_vg_combined = test_vg.concatenate(tf_test_kaggle)



In [None]:
# tf.data shuffle buffer size and number of images per training step.
BUFFER_SIZE = 1000
BATCH_SIZE = 1
# Final (cropped) image size fed to the models.
IMG_HEIGHT = 256
IMG_WIDTH = 256

In [None]:
def random_crop(image):
  """Return a random IMG_HEIGHT x IMG_WIDTH x 3 crop of ``image``.

  Args:
    image: a 3-channel image tensor at least IMG_HEIGHT x IMG_WIDTH in size
      (in this notebook, the 286 x 286 resize produced by random_jitter).
  """
  crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, size=crop_shape)

In [None]:
# normalizing the images to [-1, 1]
def normalize(image):
  """Cast ``image`` to float32 and rescale it via x / 127.5 - 1.

  Assumes pixel values in [0, 255] (e.g. decoded uint8 images), which are
  mapped onto [-1, 1].
  """
  as_float = tf.cast(image, tf.float32)
  return as_float / 127.5 - 1

In [None]:
def random_jitter(image):
  """Apply the training-time augmentation to a single image.

  The image is upscaled to 286 x 286 with nearest-neighbour resizing, a
  random IMG_HEIGHT x IMG_WIDTH crop is taken via random_crop, and the crop
  is randomly mirrored left-to-right.
  """
  upscaled = tf.image.resize(
      image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  cropped = random_crop(upscaled)
  return tf.image.random_flip_left_right(cropped)


# def random_jitter(image):
#     # Randomly resize the image
#     image = tf.image.resize(image, [286, 286])

#     # Random crop to 256x256
#     image = tf.image.random_crop(image, size=[256, 256, 3])

#     # Randomly flip the image horizontally
#     image = tf.image.random_flip_left_right(image)

#     return image


In [None]:
def preprocess_image_train(image, label=None):
  """Augment and normalize one training image.

  Args:
    image: a decoded image tensor.
    label: ignored. Accepted so the function can be mapped directly over
      ``(image, label)`` datasets such as ``tfds.load(..., as_supervised=True)``;
      it defaults to ``None`` so image-only callers (e.g. process_path_train)
      may pass just the image.

  Returns:
    The randomly jittered image, scaled to [-1, 1].
  """
  del label  # unused
  image = random_jitter(image)
  image = normalize(image)
  return image

def preprocess_image_test(image, label=None):
  """Normalize one test image to [-1, 1] without augmentation.

  ``label`` is ignored and optional, mirroring preprocess_image_train.
  """
  del label  # unused
  image = normalize(image)
  return image

In [None]:
def preprocess_tfdataset_train(dataset):
    """Build a shuffled, batched training pipeline from an (image, label) dataset.

    preprocess_image_train accepts the ``(image, label)`` pair, so it is
    mapped directly over the dataset elements. Stripping the label first
    would hand it a single tensor and break the two-argument call.
    """
    return (dataset
            .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
            .shuffle(BUFFER_SIZE)
            .batch(BATCH_SIZE))

def preprocess_tfdataset_test(dataset):
    """Build a shuffled, batched test pipeline from an (image, label) dataset.

    Same structure as preprocess_tfdataset_train, but only normalizes
    (no random augmentation).
    """
    return (dataset
            .map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
            .shuffle(BUFFER_SIZE)
            .batch(BATCH_SIZE))

# Functions for loading in the flickr scraped cityscape images
def process_path_train(file_path):
    """Read a JPEG file and apply the training preprocessing to it.

    preprocess_image_train takes an ``(image, label)`` pair; a bare file has
    no label, so ``None`` is passed explicitly.

    Returns:
        The decoded 3-channel image, jittered and scaled to [-1, 1].
    """
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    return preprocess_image_train(img, None)

def process_path_test(file_path):
    """Read a JPEG file and apply the (normalize-only) test preprocessing.

    As with process_path_train, ``None`` stands in for the missing label.
    """
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    return preprocess_image_test(img, None)

In [None]:
# # cityscape train, test preprocessing
# train_ln = train_ln.map(process_path_train, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)
# test_ln = test_ln.map(process_path_test, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

# # train_pic = preprocess_tfdataset_train(train_pic)
# # test_pic = preprocess_tfdataset_test(test_pic)
# train_vg = preprocess_tfdataset_train(train_vg)
# test_vg = preprocess_tfdataset_test(test_vg)

In [None]:
# train_ln = tf.data.Dataset.from_tensor_slices(train_ln_files)
# test_ln = tf.data.Dataset.from_tensor_slices(test_ln_files)
# train_vg = tf.data.Dataset.from_tensor_slices(train_vg_files)
# test_vg = tf.data.Dataset.from_tensor_slices(test_vg_files)

# Training splits: cache the raw images *before* the random augmentation so
# preprocess_image_train (random jitter) is re-run on every pass.
train_vg = (
    train_vg.cache()
    .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

train_ln = (
    train_ln.cache()
    .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test splits: preprocessing only normalizes, so the mapped result is cached.
test_vg = (
    test_vg.map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

test_ln = (
    test_ln.map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)


# # Apply preprocessing functions to the datasets
# train_vg = train_vg.map(lambda x: preprocess_image_train(x, label), num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# train_ln = train_ln.map(lambda x: preprocess_image_train(x, label), num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# test_vg = test_vg.map(lambda x: preprocess_image_test(x, label), num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# test_ln = test_ln.map(lambda x: preprocess_image_test(x, label), num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# One preprocessed batch from each training domain, reused for visual checks.
sample_vg, sample_ln = (next(iter(ds)) for ds in (train_vg, train_ln))

In [None]:
# Undo the [-1, 1] normalization (x * 0.5 + 0.5) for display.
vg_image = sample_vg[0]

plt.subplot(121)
plt.title('Van Gogh Painting')
plt.imshow(vg_image * 0.5 + 0.5)

plt.subplot(122)
plt.title('Van Gogh Painting with random jitter')
jittered_vg = random_jitter(vg_image)
plt.imshow(jittered_vg * 0.5 + 0.5)

In [None]:
# Undo the [-1, 1] normalization (x * 0.5 + 0.5) for display.
ln_image = sample_ln[0]

plt.subplot(121)
plt.title('Landscape')
plt.imshow(ln_image * 0.5 + 0.5)

plt.subplot(122)
plt.title('Landscape with random jitter')
jittered_ln = random_jitter(ln_image)
plt.imshow(jittered_ln * 0.5 + 0.5)

## Import and reuse the Pix2Pix models

Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.

The model architecture used in this tutorial is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:

* CycleGAN uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).
* The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet` based generator. This tutorial is using a modified `unet` generator for simplicity.

There are 2 generators (`G` and `F`) and 2 discriminators (`D_X` and `D_Y`) being trained here. In this notebook, domain `X` is landscape photos and domain `Y` is Van Gogh paintings.

* Generator `G` learns to transform image `X` to image `Y`. $(G: X -> Y)$
* Generator `F` learns to transform image `Y` to image `X`. $(F: Y -> X)$
* Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).
* Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).

![Cyclegan model](images/cyclegan_model.png)

In [None]:
# # Making the functions
# def _get_norm_layer(norm):
#   if norm == "none":
#     return lambda: lambda x: x
#   elif norm == "batch_norm":
#     return tf.keras.layers.BatchNormalization
#   elif norm == "instance_norm":
#     return tfa.layers.InstanceNormalization
#   elif norm == "layer_norm":
#     return tf.keras.layers.LayerNormalization

In [None]:
# # Making the resnet block
# def ResnetGenerator(input_shape=(256, 256, 3), output_channels=3, dim=64,
#                     n_downsamplings=2, n_blocks=9, norm='instance_norm'):
#     Norm = _get_norm_layer(norm)
    
#     def _residual_block(x):
#         dim = x.shape[-1]
#         h = x
#         h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
#         h = tf.keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
#         h = Norm()(h)
#         h = tf.nn.relu(h)
#         h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
#         h = tf.keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
#         h = Norm()(h)
#         return tf.keras.layers.add([x, h])
# # 0
#     h = inputs = tf.keras.Input(shape=input_shape)
# # 1
#     h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
#     h = tf.keras.layers.Conv2D(dim, 7, padding='valid', use_bias=False)(h)
#     h = Norm()(h)
#     h = tf.nn.relu(h)
# # 2
#     for _ in range(n_downsamplings):
#         dim *= 2
#         h = tf.keras.layers.Conv2D(dim, 3, strides=2, padding='same', use_bias=False)(h)
#         h = Norm()(h)
#         h = tf.nn.relu(h)
# # 3
#     for _ in range(n_blocks):
#         h = _residual_block(h)
# # 4
#     for _ in range(n_downsamplings):
#         dim //= 2
#         h = tf.keras.layers.Conv2DTranspose(dim, 3, strides=2, padding='same', use_bias=False)(h)
#         h = Norm()(h)
#         h = tf.nn.relu(h)
# # 5
#     h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
#     h = tf.keras.layers.Conv2D(output_channels, 7, padding='valid')(h)
#     h = tf.tanh(h)
    
#     return tf.keras.Model(inputs=inputs, outputs=h)

In [None]:
# def ConvDiscriminator(input_shape=(256, 256, 3), dim=64, n_downsamplings=3, norm='instance_norm'):
#     dim_ = dim
#     Norm = _get_norm_layer(norm)
# # 0
#     h = inputs = tf.keras.Input(shape=input_shape)
# # 1
#     h = tf.keras.layers.Conv2D(dim, 4, strides=2, padding='same')(h)
#     h = tf.nn.leaky_relu(h, alpha=0.2)
#     for _ in range(n_downsamplings - 1):
#         dim = min(dim * 2, dim_ * 8)
#         h = tf.keras.layers.Conv2D(dim, 4, strides=2, padding='same', use_bias=False)(h)
#         h = Norm()(h)
#         h = tf.nn.leaky_relu(h, alpha=0.2)
#     # 2
#         dim = min(dim * 2, dim_ * 8)
#         h = tf.keras.layers.Conv2D(dim, 4, strides=1, padding='same', use_bias=False)(h)
#         h = Norm()(h)
#         h = tf.nn.leaky_relu(h, alpha=0.2)
#     # 3
#         h = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same')(h)
#     return tf.keras.Model(inputs=inputs, outputs=h)

In [None]:
# !pip install keras

In [None]:
# from mediapipe_model_maker import image_classifier

In [None]:
OUTPUT_CHANNELS = 3

# Two U-Net generators (G: X -> Y, F: Y -> X) and two discriminators
# (D_X, D_Y) from the pix2pix example package, all using instance norm.
# Built in the order G, F, D_X, D_Y.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)

discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

# ## Building the Resnet Generator as per the cycle gan paper
# generator_g = ResnetGenerator()
# generator_f = ResnetGenerator()
# ## Building the Discriminator as per the cucle gan paper
# discriminator_x = ConvDiscriminator()
# discriminator_y = ConvDiscriminator()

In [None]:
# Generator G maps domain X (landscapes) to domain Y (Van Gogh); the training
# loop feeds train_ln as real_x. So translate the landscape sample, which is
# also the image shown next to the result.
to_vg = generator_g(sample_ln)
plt.figure(figsize=(8, 8))
# The translated image is scaled by `contrast` before display (presumably to
# make the low-amplitude output of an untrained generator visible).
contrast = 8

imgs = [sample_ln, to_vg]
title = ['Landscape', 'To Van Gogh Painting']

for i, (name, img) in enumerate(zip(title, imgs)):
  plt.subplot(2, 2, i + 1)
  plt.title(name)
  if i % 2 == 0:
    # Real input: just undo the [-1, 1] normalization.
    plt.imshow(img[0] * 0.5 + 0.5)
  else:
    plt.imshow(img[0] * 0.5 * contrast + 0.5)
plt.show()

In [None]:
plt.figure(figsize=(8, 8))

# discriminator_x judges domain X (landscapes) and discriminator_y judges
# domain Y (Van Gogh paintings): the training loop passes landscapes as
# real_x and Van Gogh paintings as real_y to train_step.
plt.subplot(121)
plt.title('Is a real Landscape?')
plt.imshow(discriminator_x(sample_ln)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real Van Gogh Painting?')
plt.imshow(discriminator_y(sample_vg)[0, ..., -1], cmap='RdBu_r')

plt.show()

## Loss functions

In CycleGAN, there is no paired data to train on, hence there is no guarantee that the input `x` and the target `y` pair are meaningful during training. Thus in order to enforce that the network learns the correct mapping, the authors propose the cycle consistency loss.

The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#build_the_generator).

In [None]:
# Weight of the cycle-consistency loss; identity_loss uses half of it.
LAMBDA = 12

# Binary cross-entropy on raw (un-sigmoided) discriminator outputs.
loss_obj = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)


def discriminator_loss(real, generated):
  """Discriminator loss: mean of BCE on real (target 1) and generated (target 0).

  Args:
    real: discriminator outputs on real images.
    generated: discriminator outputs on generated images.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_fake) * 0.5

def generator_loss(generated):
  """Adversarial generator loss: BCE of discriminator outputs against all-ones.

  The generator is rewarded when the discriminator labels its output as real.
  """
  all_real = tf.ones_like(generated)
  return loss_obj(all_real, generated)


def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA times the mean absolute error.

  Compares an image with its round trip through both generators.
  """
  mae = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mae


def identity_loss(real_image, same_image):
  """Identity loss: mean absolute error weighted by LAMBDA * 0.5.

  Compares a real image with the output of the generator that maps *into*
  that image's own domain.
  """
  mae = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mae

Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French, and then translates it back from French to English, then the resulting sentence should be the same as the original sentence.

In cycle consistency loss, 

* Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
* Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
* Mean absolute error is calculated between $X$ and $\hat{X}$.

$$\text{forward cycle consistency loss: } X \rightarrow G(X) \rightarrow F(G(X)) \sim \hat{X}$$

$$\text{backward cycle consistency loss: } Y \rightarrow F(Y) \rightarrow G(F(Y)) \sim \hat{Y}$$


![Cycle loss](images/cycle_loss.png)

In [None]:
# def calc_cycle_loss(real_image, cycled_image):
#   loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
#   return LAMBDA * loss1

As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that, if you feed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.

If you run the landscape-to-Van-Gogh generator ($G$) on a Van Gogh painting, or the Van-Gogh-to-landscape generator ($F$) on a landscape, it should not modify the image much, since the image already belongs to the target domain.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$

In [None]:
# def identity_loss(real_image, same_image):
#   loss = tf.reduce_mean(tf.abs(real_image - same_image))
#   return LAMBDA * 0.5 * loss

Initialize the optimizers for all the generators and the discriminators.

In [None]:
# Adam (learning rate 2e-4, beta_1 = 0.5); each network gets its own instance.
generator_g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

## Checkpoints

In [None]:
checkpoint_path = "./checkpoints/train"

# Track the four networks and their optimizers so training can resume.
ckpt = tf.train.Checkpoint(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)

# Keep at most the five most recent checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint, if a previous run saved one.
latest = ckpt_manager.latest_checkpoint
if latest:
  ckpt.restore(latest)
  print('Latest checkpoint restored!!')

In [None]:
# NOTE: the training loop runs EPOCHS // 5 epochs.
EPOCHS = 100

def generate_images(model, test_input):
  """Show the first image of ``test_input`` next to ``model``'s output for it.

  Args:
    model: a generator, called as ``model(test_input)``.
    test_input: a batch of images scaled to [-1, 1].
  """
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))
  panels = zip(['Input Image', 'Predicted Image'],
               [test_input[0], prediction[0]])
  for position, (label, img) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(label)
    # Map [-1, 1] back to [0, 1] so imshow renders it correctly.
    plt.imshow(img * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Even though the training loop looks complicated, it consists of four basic steps:

* Get the predictions.
* Calculate the loss.
* Calculate the gradients using backpropagation.
* Apply the gradients to the optimizer.

In [None]:
@tf.function
def train_step(real_x, real_y):
  """Run one optimization step for both generators and both discriminators.

  Args:
    real_x: a batch from domain X (the training loop passes landscape
      batches from train_ln).
    real_y: a batch from domain Y (the training loop passes Van Gogh
      batches from train_vg).
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    # Forward cycle: X -> G(X) -> F(G(X)).
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Backward cycle: Y -> F(Y) -> G(F(Y)).
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    # F maps into domain X, so F(real_x) should stay close to real_x;
    # likewise G(real_y) should stay close to real_y.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    # Discriminator outputs (treated as logits by loss_obj) on real images...
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    # ...and on generated images.
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # Cycle loss for both directions; added to BOTH generators' objectives.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    # (+ the identity loss for the domain each generator maps into).
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  # (each loss w.r.t. only its own network's variables).
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

In [None]:
for epoch in range(EPOCHS // 5):
  start = time.time()

  # Pair landscape batches (domain X) with Van Gogh batches (domain Y).
  paired_batches = tf.data.Dataset.zip((train_ln, train_vg))
  for step, (image_x, image_y) in enumerate(paired_batches):
    train_step(image_x, image_y)
    # One progress dot every 10 steps.
    if step % 10 == 0:
      print('.', end='')

  #clear_output(wait=True)
  # NOTE(review): with EPOCHS = 100 the loop runs 20 epochs, so this only
  # prints at epoch 0.
  if epoch % 50 == 0:
    print("EPOCH:", epoch)
  # Visualize progress on the same fixed landscape sample every epoch.
  generate_images(generator_g, sample_ln)

  epoch_number = epoch + 1
  if epoch_number % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch_number,
                                                        ckpt_save_path))

  print('Time taken for epoch {} is {} sec\n'.format(epoch_number,
                                                     time.time()-start))

## Generate using test dataset

In [None]:
# Translate held-out landscapes with generator G (landscape -> Van Gogh).
preview_batches = test_ln.take(EPOCHS // 5)
for test_batch in preview_batches:
    generate_images(generator_g, test_batch)