<a href="https://colab.research.google.com/github/ShaneZhong/CircleGAN-And-Pix2Pix/blob/master/CycleGAN_Monet_Painting_to_Photo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CycleGAN Monet Painting to Photo

Update Date: 17 July 2019

Reference:
* https://www.tensorflow.org/datasets/datasets#cycle_gan
* https://towardsdatascience.com/cyclegans-and-pix2pix-5e6a5f0159c4

##### Copyright 2019 The TensorFlow Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Setup - Connect to your Drive to save the model

In [0]:
from google.colab import drive
drive.mount('/content/drive')

# CycleGAN

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/beta/tutorials/generative/cyclegan"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/r2/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This notebook demonstrates unpaired image-to-image translation using conditional GANs, as described in [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593), also known as CycleGAN. The paper proposes a method that captures the characteristics of one image domain and works out how they could be translated into another image domain, all in the absence of any paired training examples.

This notebook assumes you are familiar with Pix2Pix, which you can learn about in the [Pix2Pix tutorial](https://www.tensorflow.org/beta/tutorials/generative/pix2pix). The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.

CycleGAN uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domain. 

This opens up the possibility to do a lot of interesting tasks like photo-enhancement, image colorization, style transfer, etc. All you need is the source and the target dataset (which is simply a directory of images).

## Set up the input pipeline

Install the [tensorflow_examples](https://github.com/tensorflow/examples) package that enables importing of the generator and the discriminator.

In [0]:
# install the tensorflow_examples package, which provides the pre-defined generator and discriminator
!pip install git+https://github.com/tensorflow/examples.git

In [0]:
# using tensorflow 2.0
!pip install tensorflow-gpu==2.0.0-beta1
import tensorflow as tf

In [0]:
# Python 2/3 compatibility imports
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow Datasets and the pix2pix models from tensorflow_examples
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

# other libraries
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Disable the tfds progress bar
# https://www.tensorflow.org/datasets/api_docs/python/tfds/disable_progress_bar
# https://www.tensorflow.org/beta/guide/data_performance
# AUTOTUNE lets tf.data pick the level of parallelism at runtime, which helps
#   overlap CPU preprocessing with GPU training and speeds up the pipeline.
tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Input Pipeline

This tutorial trains a model to translate images of Monet's paintings into photos. You can find this dataset and similar ones [here](https://www.tensorflow.org/datasets/datasets#cycle_gan).

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/beta/tutorials/generative/pix2pix#load_the_dataset):
* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.
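
The two augmentation steps above can be sketched in plain Python on a toy single-channel grid. This is only an illustration, not the notebook's TensorFlow code: the helper names are hypothetical, and the simple nearest-neighbour rule used here may differ from TensorFlow's exact pixel-centre convention.

```python
import random

def resize_nearest(image, new_h, new_w):
    """Nearest-neighbour resize of a 2-D grid (a list of rows)."""
    old_h, old_w = len(image), len(image[0])
    return [[image[r * old_h // new_h][c * old_w // new_w]
             for c in range(new_w)]
            for r in range(new_h)]

def crop_random_window(image, crop_h, crop_w, rng):
    """Cut a crop_h x crop_w window at a random offset."""
    top = rng.randint(0, len(image) - crop_h)
    left = rng.randint(0, len(image[0]) - crop_w)
    return [row[left:left + crop_w] for row in image[top:top + crop_h]]

def maybe_flip_left_right(image, rng):
    """Mirror every row with probability 0.5."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return image

def random_jitter_sketch(image, rng, enlarged=286, target=256):
    image = resize_nearest(image, enlarged, enlarged)       # 256 -> 286
    image = crop_random_window(image, target, target, rng)  # 286 -> 256
    return maybe_flip_left_right(image, rng)

rng = random.Random(0)
# Toy "image": each pixel value encodes its own position.
toy = [[r * 256 + c for c in range(256)] for r in range(256)]
jittered = random_jitter_sketch(toy, rng)
print(len(jittered), len(jittered[0]))
```

The output has the same size as the input, but each call shifts the visible window by up to 30 pixels and may mirror it, so the network rarely sees exactly the same pixels twice.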

### Pipeline:

Without pipelining, the CPU and the GPU/TPU sit idle much of the time:
![without Pipeline](https://www.tensorflow.org/images/datasets_without_pipelining.png)<br>


With pipelining, idle time diminishes significantly:<br>
![with pipeline](https://www.tensorflow.org/images/datasets_with_pipelining.png)


In [0]:
# https://www.tensorflow.org/datasets/datasets#cycle_gan
dataset, metadata = tfds.load('cycle_gan/monet2photo',
                              with_info=True, as_supervised=True)

# trainA/trainB are the training splits and testA/testB the test splits
# of the Monet (A) and photo (B) domains.
# Images in every split are already in 256*256*3 shape.
train_monet, train_photo = dataset['trainA'], dataset['trainB']
test_monet, test_photo = dataset['testA'], dataset['testB']

In [0]:
type(train_monet)

In [0]:
# Constants - using CAPITAL LETTERS

# BUFFER_SIZE = number of elements held in the shuffle buffer of
#    tf.data.Dataset.shuffle(); a larger buffer gives more thorough mixing.
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
BUFFER_SIZE = 1000
# BATCH_SIZE = number of images in each batch.
BATCH_SIZE = 1

# resize to 256*256
IMG_WIDTH = 256
IMG_HEIGHT = 256
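
To see what the shuffle buffer size controls, here is a plain-Python sketch of buffered shuffling in the style of `tf.data.Dataset.shuffle`. The function `buffered_shuffle` is a hypothetical stand-in written for illustration, not TensorFlow's implementation.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in an order mixed through a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)        # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]              # emit a random buffered element
        buffer[idx] = item             # and replace it with the new one
    rng.shuffle(buffer)                # input exhausted: drain what is left
    yield from buffer

data = list(range(10))
print(list(buffered_shuffle(data, buffer_size=1)))   # no mixing at all
print(list(buffered_shuffle(data, buffer_size=10)))  # buffer holds everything
```

With a buffer of 1 the order is unchanged, and with a buffer at least as large as the dataset every ordering is possible. A buffer of 1000 sits in between for datasets larger than 1000 images.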

In [0]:
def random_crop(image):
  # crop the image to 256*256*3
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image

In [0]:
# normalizing the images to [-1, 1]
def normalize(image):
  # Convert the image to float 32
  image = tf.cast(image, tf.float32)
  # maps [0, 255] to [-1, 1]: 0/127.5 - 1 = -1, 255/127.5 - 1 = 1
  image = (image / 127.5) - 1
  return image
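
As a quick check of the normalization arithmetic, the sketch below applies the same formula to single pixel values and then the inverse mapping `* 0.5 + 0.5` that the plotting cells use. The two helper names are hypothetical and exist only for this example.

```python
def normalize_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1]."""
    return value / 127.5 - 1

def denormalize_pixel(value):
    """Map a value in [-1, 1] to [0, 1], the range plt.imshow expects for floats."""
    return value * 0.5 + 0.5

for raw in (0, 127.5, 255):
    scaled = normalize_pixel(raw)
    print(raw, scaled, denormalize_pixel(scaled))
```

Black (0) maps to -1, mid-grey (127.5) to 0 and white (255) to 1, and the display mapping brings them back to 0, 0.5 and 1.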

In [0]:
def random_jitter(image):
  # resizing to 286 x 286 x 3 using nearest-neighbour interpolation
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image

In [0]:
# to train image: resize, crop, mirror and normalise
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

In [0]:
# to test image: normalise image.
# since the test dataset is already in 256*256*3 size, no reshape required
def preprocess_image_test(image, label):
  image = normalize(image)
  return image

In [0]:
# map function - map(function, iterables)
def myfunc(n):
  return len(n)

x = map(myfunc, ('apple', 'banana', 'cherry'))
print(x)
print(list(x))

In [0]:
# Apply the preprocess_image_train function to the train_monet images
# cache() keeps the mapped images in memory to speed up later epochs
# shuffle().batch():
#     shuffle(BUFFER_SIZE) mixes elements through a buffer of BUFFER_SIZE elements,
#     then batch(BATCH_SIZE) groups them into batches.
#     https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
# num_parallel_calls=AUTOTUNE lets tf.data choose how many images to map in parallel

train_monet = train_monet.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_photo = train_photo.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# apply preprocess_image_test to test set
test_monet = test_monet.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_photo = test_photo.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [0]:
# next(iter(...)) - each call to next() returns the following element
lst = iter([1,2,3])
print(next(lst))
print(next(lst))

In [0]:
# iterate one at a time
sample_monet = next(iter(train_monet))
sample_photo = next(iter(train_photo))

In [0]:
type(sample_monet)

In [0]:
sample_monet[0]

In [0]:
# normalised photo
sample_photo[0]

In [0]:
# Since the batch size is 1, only one painting can be visualised
# i.e. sample_monet[0] is Ok, but sample_monet[1] is invalid
plt.subplot(121)
plt.title('Monet')
plt.imshow(sample_monet[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Monet with random jitter')
plt.imshow(random_jitter(sample_monet[0]) * 0.5 + 0.5)

In [0]:
plt.subplot(121)
plt.title('Photo')
plt.imshow(sample_photo[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Photo with random jitter')
plt.imshow(random_jitter(sample_photo[0]) * 0.5 + 0.5)

## Import and reuse the Pix2Pix models

Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.

The model architecture used in this tutorial is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:
* Cyclegan uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).
* The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet` based generator. This tutorial is using a modified `unet` generator for simplicity.

There are 2 generators (`G` and `F`) and 2 discriminators (`D_X` and `D_Y`) being trained here.
* Generator `G` learns to transform image `X` to image `Y`. $(G: X \rightarrow Y)$
* Generator `F` learns to transform image `Y` to image `X`. $(F: Y \rightarrow X)$
* Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).
* Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).

![Cyclegan model](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/images/cyclegan_model.png?raw=1)

* `X` is a Monet painting
* `Y` is a photo
* `generator_g` converts a Monet painting to a photo
* `generator_f` converts a photo to a Monet painting
* `discriminator_x` checks whether the input is a Monet painting
* `discriminator_y` checks whether the input is a photo


In [0]:
# RGB channels
OUTPUT_CHANNELS = 3

# import unet generator structure from the github
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

[pix2pix script link](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py)

```
def unet_generator(output_channels, norm_type='batchnorm'):
  """Modified u-net generator model (https://arxiv.org/abs/1611.07004).
  Args:
    output_channels: Output channels
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
  Returns:
    Generator model
  """
```

----

```
def discriminator(norm_type='batchnorm', target=True):
  """PatchGan discriminator model (https://arxiv.org/abs/1611.07004).
  Args:
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
    target: Bool, indicating whether target image is an input or not.
  Returns:
    Discriminator model
  """
  # ...
  if target:
    return tf.keras.Model(inputs=[inp, tar], outputs=last)
  else:
    return tf.keras.Model(inputs=inp, outputs=last)
```

In [0]:
# generator_g converts a Monet painting to a photo
# generator_f converts a photo to a Monet painting
to_photo = generator_g(sample_monet)
to_monet = generator_f(sample_photo)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_monet, to_photo, sample_photo, to_monet]
title = ['Monet', 'To Photo', 'Photo', 'To Monet']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()

In [0]:
# discriminator_x checks whether the input is a Monet painting
# discriminator_y checks whether the input is a photo
# The discriminators are PatchGANs with a 30*30 output

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real photo?')
plt.imshow(discriminator_y(sample_photo)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real Monet painting?')
plt.imshow(discriminator_x(sample_monet)[0, ..., -1], cmap='RdBu_r')

plt.show()

In [0]:
# one row (index 28) of the 30*30 patch output; these are raw, unbounded
# scores, not values limited to [-1, 1]
discriminator_y(sample_photo)[0, ..., -1][28]
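
The 30*30 output size can be reproduced with a little arithmetic. The sketch below assumes the layer layout of the TensorFlow pix2pix example discriminator (three stride-2 'same' convolutions, followed by two zero-padded, stride-1, 4x4 'valid' convolutions). That layout is an assumption taken from the linked script, not something this notebook prints, and the helper names are hypothetical.

```python
import math

def conv_same(size, stride):
    """Spatial output size of a 'same'-padded convolution."""
    return math.ceil(size / stride)

def conv_valid(size, kernel, stride=1):
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1

def zero_pad(size, pad=1):
    """Spatial size after padding `pad` pixels on both sides."""
    return size + 2 * pad

size = 256
for _ in range(3):                             # three stride-2 downsampling blocks
    size = conv_same(size, stride=2)           # 256 -> 128 -> 64 -> 32
size = conv_valid(zero_pad(size), kernel=4)    # 32 -> 34 -> 31
size = conv_valid(zero_pad(size), kernel=4)    # 31 -> 33 -> 30
print(size)
```

Each of the resulting 30*30 values scores one overlapping patch of the input rather than the image as a whole.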

## Loss functions

In CycleGAN, there is no paired data to train on, hence there is no guarantee that the input `x` and the target `y` pair are meaningful during training. Thus in order to enforce that the network learns the correct mapping, the authors propose the cycle consistency loss.

The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://www.tensorflow.org/beta/tutorials/generative/pix2pix#define_the_loss_functions_and_the_optimizer).

In [0]:
# weight applied to the L1 cycle-consistency loss
LAMBDA = 10

In [0]:
# https://www.tensorflow.org/api_docs/python/tf/keras/losses/binary_crossentropy
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def discriminator_loss(real, generated):
  # if the image is real, disc should identify it as 1
  real_loss = loss_obj(tf.ones_like(real), real)
  
  # if the image is generated, disc should identify it as 0
  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  # sum up
  total_disc_loss = real_loss + generated_loss

  # the total loss divided by two
  return total_disc_loss * 0.5

In [0]:
# the generator succeeds when the discriminator labels the generated image as 1,
# i.e. when it successfully fools the disc
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)
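
To make the two losses concrete, the sketch below computes binary cross-entropy directly from a single scalar logit in plain Python. It is a simplified stand-in: the real losses average over the whole patch output, and the function names here are hypothetical.

```python
import math

def bce_with_logits(label, logit):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def discriminator_loss_sketch(real_logit, fake_logit):
    real_loss = bce_with_logits(1.0, real_logit)   # real should score 1
    fake_loss = bce_with_logits(0.0, fake_logit)   # generated should score 0
    return 0.5 * (real_loss + fake_loss)

def generator_loss_sketch(fake_logit):
    return bce_with_logits(1.0, fake_logit)        # generator wants a 1

print(discriminator_loss_sketch(0.0, 0.0))    # undecided discriminator
print(discriminator_loss_sketch(5.0, -5.0))   # confident and correct
print(generator_loss_sketch(-5.0))            # generator clearly caught
```

An undecided discriminator (logit 0 for both inputs) has a loss of ln 2, about 0.693. The loss drops towards 0 as it separates real from generated images, which is exactly when the generator loss grows.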

Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French, and then translates it back from French to English, the resulting sentence should be the same as the original sentence.

In our case:
* `X` is a Monet painting
* `Y` is a photo
* `generator_g` converts a Monet painting to a photo
* `generator_f` converts a photo to a Monet painting
* `discriminator_x` checks whether the input is a Monet painting
* `discriminator_y` checks whether the input is a photo

In cycle consistency loss, 
* Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
* Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
* Mean absolute error is calculated between $X$ and $\hat{X}$.

$$forward\ cycle\ consistency\ loss: X \rightarrow G(X) \rightarrow F(G(X)) \sim \hat{X}$$

$$backward\ cycle\ consistency\ loss: Y \rightarrow F(Y) \rightarrow G(F(Y)) \sim \hat{Y}$$

We multiply the cycle-consistency loss by `LAMBDA` afterwards.

![Cycle loss](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/images/cycle_loss.png?raw=1)

In [0]:
# It is important that the generator recreates something similar to the original,
# so the loss is multiplied by LAMBDA (10 here).
def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1
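
A toy example may help show what the cycle loss rewards. The sketch below uses hypothetical stand-in "generators" that act on short lists of numbers instead of images. `LAMBDA_SKETCH` mirrors the weight of 10 used in this notebook.

```python
LAMBDA_SKETCH = 10  # same value as LAMBDA in this notebook

def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def cycle_loss_sketch(real, cycled):
    return LAMBDA_SKETCH * mean_abs_error(real, cycled)

# Hypothetical stand-ins for the generators, acting on flat pixel lists.
def toy_g(pixels):             # "X -> Y": brighten
    return [p + 0.25 for p in pixels]

def toy_f_exact(pixels):       # exact inverse of toy_g
    return [p - 0.25 for p in pixels]

def toy_f_sloppy(pixels):      # inverse that undershoots by 0.125
    return [p - 0.125 for p in pixels]

real_x = [0.0, 0.5, -0.5, 0.25]
print(cycle_loss_sketch(real_x, toy_f_exact(toy_g(real_x))))
print(cycle_loss_sketch(real_x, toy_f_sloppy(toy_g(real_x))))
```

When the second mapping undoes the first exactly, the loss is zero. A consistent error of 0.125 per pixel is scaled by the weight of 10 into a loss of 1.25.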

As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that, if you feed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$

In [0]:
# to make sure that generator_g, when fed a real photo,
# outputs something similar to that photo
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss
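
Collecting the weights used in `calc_cycle_loss` and `identity_loss`, the two reconstruction terms can be written out as below. This is a sketch in this section's notation, with $\lambda$ standing for `LAMBDA` and $|\cdot|$ for the mean absolute error over pixels. It reads the weights off the code above rather than quoting the paper.

$$\mathcal{L}_{cyc} = \lambda \left( |F(G(X)) - X| + |G(F(Y)) - Y| \right)$$

$$\mathcal{L}_{identity} = \frac{\lambda}{2} \left( |G(Y) - Y| + |F(X) - X| \right)$$

With `LAMBDA = 10`, a mean absolute error of 0.1 in one cycle direction contributes 1.0 to the loss, while the same error in one identity term contributes 0.5.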

Initialize the optimizers for all the generators and the discriminators.

In [0]:
# Adam is used
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

## Checkpoints

In [0]:
# Save your model to GDrive
#checkpoint_path = "./checkpoints/train"
checkpoint_path = "/content/drive/My Drive/Models/CircleGAN/Monet2Photo/checkpoints/"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

## Training

Note: The paper trains for 200 epochs, and `EPOCHS` below is set to match. To keep training time shorter, reduce it (for example, to 40), at the cost of less accurate predictions.

In [0]:
EPOCHS = 200

In [0]:
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Even though the training loop looks complicated, it consists of four basic steps:
* Get the predictions.
* Calculate the loss.
* Calculate the gradients using backpropagation.
* Apply the gradients to the optimizer.
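
The four steps can be seen in miniature in the plain-Python sketch below. It fits a single weight with a hand-derived gradient and plain gradient descent, so it is only an analogy for the loop in this notebook, which uses `tf.GradientTape` and Adam. All names are hypothetical.

```python
def predict(w, x):
    return w * x

def squared_error(pred, target):
    return (pred - target) ** 2

def gradient_wrt_w(w, x, target):
    # d/dw (w*x - target)^2 = 2 * (w*x - target) * x
    return 2 * (predict(w, x) - target) * x

w = 0.0                  # single trainable weight
learning_rate = 0.1
x, target = 1.0, 2.0     # toy data: the ideal weight is 2.0

losses = []
for step in range(50):
    pred = predict(w, x)                    # 1. get the prediction
    loss = squared_error(pred, target)      # 2. calculate the loss
    grad = gradient_wrt_w(w, x, target)     # 3. calculate the gradient
    w -= learning_rate * grad               # 4. apply the gradient
    losses.append(loss)

print(w, losses[0], losses[-1])
```

The loss shrinks at every step and the weight approaches 2.0. The CycleGAN loop repeats the same pattern for four networks at once.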

### GradientTape
Reference: https://www.tensorflow.org/api_docs/python/tf/GradientTape <br>
Record operations for automatic differentiation.

In [0]:
# Simple GradientTape
x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x**2
dy_dx = g.gradient(y, x) # Will compute to 6.0 (dy/dx)
print(dy_dx)

By default, the resources held by a GradientTape are released as soon as GradientTape.gradient() method is called. To compute multiple gradients over the same computation, create a persistent gradient tape. This allows multiple calls to the gradient() method as resources are released when the tape object is garbage collected. For example:

In [0]:
# GradientTape with persistent
x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y
dz_dx = g.gradient(z, x)  # 108.0 (4*x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0
del g  # Drop the reference to the tape

print(dz_dx, dy_dx)

In [0]:
# The use of zip
a = ("John", "Charles", "Mike")
b = ("Jenny", "Christy", "Monica", "Vicky")

x = zip(a, b)

#use the tuple() function to display a readable version of the result:

print(tuple(x))


In [0]:
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because gen_tape and disc_tape are each used
  # more than once to calculate the gradients.
  with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(
      persistent=True) as disc_tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    # create cycled images for cycle loss.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)
    
    # feed real image to disc, expect to return 1
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    # feed fake image to disc, expect to return 0
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the adversarial loss - to see if we fooled the disc
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
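    # Total cycle loss covers both directions: X -> Y -> X and Y -> X -> Y
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss + identity loss
    # The identity term for G is |G(Y) - Y| (same_y); for F it is |F(X) - X| (same_x)
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    # Total discriminator loss = 1 for real and 0 for fake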
    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = gen_tape.gradient(total_gen_g_loss, 
                                            generator_g.trainable_variables)
  generator_f_gradients = gen_tape.gradient(total_gen_f_loss, 
                                            generator_f.trainable_variables)
  
  discriminator_x_gradients = disc_tape.gradient(
      disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = disc_tape.gradient(
      disc_y_loss, discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                             generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                             generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(
      zip(discriminator_x_gradients,
      discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(
      zip(discriminator_y_gradients,
      discriminator_y.trainable_variables))

In [0]:
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_monet, train_photo)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_monet) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_monet)

  # Save the model every 5 epochs
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

## Generate using test dataset

In [0]:
# Run the trained model on the test dataset
for inp in test_monet.take(20):
  generate_images(generator_g, inp)

In [0]:
# Run the trained model on the test dataset
for inp in test_photo.take(20):
  generate_images(generator_f, inp)

# Save the images to your drive

### Function to save files to local drive

In [0]:
# Function
def generate_images2(model, test_input, index = 0, folder_dir = "", prefix =""):
  prediction = model(test_input)
    
  fig = plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
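
  # Save the figure before plt.show(), since some Matplotlib backends
  # release the figure once it has been shown
  output_dir = folder_dir + str(index) +"_"+ prefix + "_output.png"
  #output_dir = F"output/test_Monet Painting.png"
  print("Save images to:" + output_dir)
  fig.savefig(output_dir)

  plt.show()
  plt.close(fig)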
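
Because the output path is built by plain string concatenation, `folder_dir` needs to end with a path separator (or be empty) for the file to land inside the folder. The sketch below mirrors that concatenation and compares it with a hypothetical alternative based on `os.path.join`. Neither helper is part of the notebook.

```python
import os

def concat_path(folder_dir, index, prefix):
    """Mirror of the string concatenation used in generate_images2."""
    return folder_dir + str(index) + "_" + prefix + "_output.png"

def joined_path(folder_dir, index, prefix):
    """Hypothetical alternative that lets os.path.join add the separator."""
    return os.path.join(folder_dir, "{}_{}_output.png".format(index, prefix))

print(concat_path("output/", 3, "M2P_comparison"))
print(concat_path("output", 3, "M2P_comparison"))   # lands next to the folder
print(joined_path("output", 3, "M2P_comparison"))
```

The second call produces `output3_M2P_comparison_output.png`, a file beside the folder rather than inside it, which is easy to miss when copying results elsewhere.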

### Create a local directory

In [0]:
# delete current output directory
!if [ -d "output" ]; then rm -Rf output; fi

# Create a dir in Colab env
!mkdir -p output

### Run model and save results to a local directory

In [0]:
folder_dir = "output/"
prefix = "M2P_comparison"

# Run the trained model on the test dataset
for index, inp in enumerate(test_monet.take(500)):
  print(index)
  generate_images2(generator_g, inp, index, folder_dir, prefix)

# Copy the output to Google Drive
!cp -r /content/output /content/drive/My\ Drive/Models/CircleGAN/Monet2Photo/
print("Outputs are saved to GDrive.")

In [0]:
folder_dir = "output/"
prefix = "P2M_comparison"

# Run the trained model on the test dataset
for index, inp in enumerate(test_photo.take(500)):
  print(index)
  generate_images2(generator_f, inp, index, folder_dir, prefix)

# Copy the output to Google Drive
!cp -r /content/output /content/drive/My\ Drive/Models/CircleGAN/Monet2Photo/
print("Outputs are saved to GDrive.")

## Alternative - Using PyDrive to upload output to GDrive

In [0]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

In [0]:
# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [0]:
# 2. Check folders in GDrive
file_list = drive.ListFile({'q': "'root' in parents and trashed=false"}).GetList()
for file1 in file_list:
    print('title: %s, id: %s' % (file1['title'], file1['id']))

In [0]:
# 3. Upload local file to GDrive
output_file = drive.CreateFile({'title' : 'test_Monet Painting2.png'})
output_file.SetContentFile('test_Monet Painting.png')
output_file.Upload()
drive.CreateFile({'id': output_file.get('id')})

## Convert your photo to Monet Style
[Tutorial](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/load_data/images.ipynb#scrollTo=LJfkyC_Qkt7A)

### Functions

In [0]:
# Crop to the center of the image
# https://stackoverflow.com/questions/54865717/tensorflow-crop-largest-central-square-region-of-image
def crop_center(image):
    h, w = image.shape[-3], image.shape[-2]
    if h > w:
        cropped_image = tf.image.crop_to_bounding_box(image, (h - w) // 2, 0, w, w)
    else:
        cropped_image = tf.image.crop_to_bounding_box(image, 0, (w - h) // 2, h, h)
    return cropped_image

# Load the image in TF
def load_and_preprocess_image(path, center_crop = True):
  image = tf.io.read_file(path)
  return preprocess_image(image, center_crop)

# Convert image into a normalised [256,256,3] tensor
def preprocess_image(image, center_crop = True):
  image = tf.image.decode_jpeg(image, channels=3)
  
  # crop to the center of the image
  if center_crop: image = crop_center(image)
        
  # Resize the image to 256*256
  image = tf.image.resize(image, [256, 256],
                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  # normalise the image
  image = normalize(image)
  
  # Expand to [1,256,256,3] tensor as input of the model
  # Ref: https://www.tensorflow.org/api_docs/python/tf/expand_dims
  image = tf.expand_dims(image, 0)
  
  return image
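
The offsets that `crop_center` passes to `tf.image.crop_to_bounding_box` can be checked without TensorFlow. The helper below is a hypothetical plain-Python mirror of that arithmetic, returning the top offset, left offset and side length of the central square.

```python
def center_square_box(height, width):
    """Return (offset_height, offset_width, size) of the central square."""
    if height > width:
        return (height - width) // 2, 0, width
    return 0, (width - height) // 2, height

print(center_square_box(300, 200))   # portrait: trim top and bottom
print(center_square_box(200, 300))   # landscape: trim left and right
print(center_square_box(256, 256))   # already square: nothing trimmed
```

Cropping to a square first means the later resize to 256*256 does not distort the aspect ratio of the photo.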

### Load a single photo to the environment and convert it to tensor

In [0]:
from google.colab import files
files.upload()

In [0]:
# Input your image name here, including the image type.
image_name = "IMG_8644.jpg" #@param 

In [0]:
# Convert image to tensor
image = load_and_preprocess_image(image_name, center_crop = True)

if image.shape == [1, 256, 256, 3]: print("The image is ready to go!")

### Let the model do the magic

In [0]:
folder_dir = ""
prefix = "YOUR_PHOTO"

# Run the trained model on your image
generate_images2(generator_f, image, 1, folder_dir, prefix)

## Convert multiple photos to Monet filter

In [0]:
import re
# Match files named IMG*.jpg or IMG*.JPG in the working directory.
# Note: this rebinds `files`, shadowing the google.colab `files` module imported earlier.
files = [f for f in os.listdir('.') if re.match(r'IMG.*\.(jpg|JPG)$', f)]

print(files)
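
A note on the pattern: a character class such as `[IMG]+` matches any run of the letters I, M and G, whereas the literal `IMG` requires that exact prefix. The sketch below compares the two on a few made-up file names. The names are illustrative only.

```python
import re

names = ["IMG_8644.jpg", "IMG_0001.JPG", "Margin.jpg", "notes.txt", "photo.jpg"]

char_class = re.compile(r'[IMG]+.*\.jpg')       # any run of the letters I, M, G
literal = re.compile(r'IMG.*\.(jpg|JPG)$')      # the literal prefix "IMG"

class_hits = [n for n in names if char_class.match(n)]
literal_hits = [n for n in names if literal.match(n)]
print(class_hits)
print(literal_hits)
```

The character-class pattern also accepts `Margin.jpg`, because a single leading `M` satisfies `[IMG]+`. The literal pattern accepts only the two `IMG_` files.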

In [0]:
!mkdir -p output_your_photo
folder_dir = "output_your_photo/"
prefix = "Monet"

for index, image_name in enumerate(files):
  # Convert image to tensor
  image = load_and_preprocess_image(image_name, center_crop = True)
  
  # Run the trained model on your image
  generate_images2(generator_f, image, index, folder_dir, prefix)

## Next Steps

This tutorial has shown how to implement CycleGAN starting from the generator and discriminator implemented in the [Pix2Pix](https://www.tensorflow.org/beta/tutorials/generative/pix2pix) tutorial. As a next step, you could try using a different dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets/datasets#cycle_gan). 

You could also train for a larger number of epochs to improve the results, or you could implement the modified ResNet generator used in the [paper](https://arxiv.org/abs/1703.10593) instead of the U-Net generator used here.


