In [None]:
pip install tensorflow-addons

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

import os
import pathlib
import time
import datetime
import math

from matplotlib import pyplot as plt
from IPython import display

print(tf.__version__)

## Load the dataset

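Load your dataset: the paired "before" / "after" image arrays (`ds_before_*.npy` / `ds_after_*.npy`) stored on Google Drive. Each "before" image is placed side by side with its "after" image, so every dataset element holds one (input, target) pair; a slice of each source array is held out as the test set.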

In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and log paths
# used below all live under this mount.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
import numpy as np

# Folder on the mounted Google Drive that holds the paired .npy image stacks.
DATASET_DIR = "/content/drive/My Drive/RhinoplasticPaper/Data/Datasets"


def _add_zero_channel(images):
  """Return `images` with one all-zero channel appended on axis 3.

  The zero block takes `images.dtype`, so the stack is not silently upcast to
  float64 (the default dtype of `np.zeros`). That upcast can multiply the
  memory used by the whole dataset. `load` casts every element to float32
  anyway, so the values the model sees are unchanged.
  """
  zeros = np.zeros(images.shape[:3] + (1,), dtype=images.dtype)
  return np.concatenate((images, zeros), axis=3)


before_1 = _add_zero_channel(np.load(os.path.join(DATASET_DIR, "ds_before_8.npy")))
after_1 = _add_zero_channel(np.load(os.path.join(DATASET_DIR, "ds_after_8.npy")))
before_2 = _add_zero_channel(np.load(os.path.join(DATASET_DIR, "ds_before_fem.npy")))
after_2 = _add_zero_channel(np.load(os.path.join(DATASET_DIR, "ds_after_fem.npy")))

# Train/test split: the test set is the first 50 examples of set 1 plus
# examples 843: of set 2; everything else up to the indices below is training.
# NOTE(review): set 1 is cut at index 1420 for training, so any rows beyond
# 1420 end up in neither split. Confirm this is intended.
train_before = np.concatenate((before_1[50:1420, ...], before_2[0:843, ...]), axis = 0)
train_after = np.concatenate((after_1[50:1420, ...], after_2[0:843, ...]), axis = 0)
test_before = np.concatenate((before_1[:50, ...], before_2[843:, ...]), axis = 0)
test_after = np.concatenate((after_1[:50, ...], after_2[843:, ...]), axis = 0)

# Put each "before" image next to its "after" image along axis 2 (width), so
# one element holds the whole pair; `load` splits it back into two halves.
train_df = np.concatenate((train_before, train_after), axis = 2)
test_df = np.concatenate((test_before, test_after), axis = 2)

train_df = tf.data.Dataset.from_tensor_slices(train_df)
test_df = tf.data.Dataset.from_tensor_slices(test_df)

print(len(train_df))
print(len(test_df))

In [None]:
def load(image):
  """Split one side-by-side dataset element into its (input, real) halves.

  The element is cut in two along axis 1. The left half is the input image
  and the right half is the real/target image. Both halves are returned as
  float32 tensors.
  """
  half_width = tf.shape(image)[1] // 2
  left = tf.cast(image[:, :half_width, :], tf.float32)
  right = tf.cast(image[:, half_width:, :], tf.float32)
  return left, right

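Set the input-pipeline constants. `load` splits each dataset element into an input image (the "before" half) and a real/target image (the "after" half); samples of both are plotted once the training pipeline is built below: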

In [None]:
# Shuffle-buffer size for the training pipeline (larger = more thorough shuffle).
BUFFER_SIZE = 9999
# The batch size of 1 produced better results in the original pix2pix experiment
BATCH_SIZE = 1
# Model input/target size (height x width). `random_crop`, the loaders and the
# generator/discriminator input shapes all use these values.
IMG_WIDTH = 128
IMG_HEIGHT = 256

In [None]:
def resize(input_image, real_image, height, width):
  """Resize an (input, real) image pair to `height` x `width`.

  Nearest-neighbour interpolation is used for both images, so no new pixel
  values are blended in.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)
  return resized_input, resized_real

In [None]:
def random_crop(input_image, real_image):
  """Cut one shared random IMG_HEIGHT x IMG_WIDTH window out of both images.

  The two images are stacked so that a single crop offset is drawn for the
  pair, which keeps input and target pixel-aligned. Both must be 3-channel
  images at least IMG_HEIGHT x IMG_WIDTH in size.
  """
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  both = tf.image.random_crop(tf.stack([input_image, real_image], axis=0),
                              size=crop_shape)
  return both[0], both[1]

In [None]:
def random_saturation(input_image, real_image):
  """Scale the saturation of both images by one shared random factor in [0.6, 1.6)."""
  pair = tf.stack([input_image, real_image], axis=0)
  adjusted = tf.image.random_saturation(pair, lower=0.6, upper=1.6)
  return adjusted[0], adjusted[1]

In [None]:
def random_hue(input_image, real_image):
  """Shift the hue of both images by one shared random delta in [-0.075, 0.075].

  Not called by `random_jitter` in this notebook.
  """
  pair = tf.stack([input_image, real_image], axis=0)
  shifted = tf.image.random_hue(pair, max_delta=0.075)
  return shifted[0], shifted[1]

In [None]:
def random_contrast(input_image, real_image):
  """Scale the contrast of both images by one shared random factor in [0.8, 1.2).

  The result is not clipped, so pixels may leave their original value range.
  """
  pair = tf.stack([input_image, real_image], axis=0)
  adjusted = tf.image.random_contrast(pair, lower=0.8, upper=1.2)
  return adjusted[0], adjusted[1]

In [None]:
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  """Rescale both images from pixel values in [0, 255] to [-1, 1]."""
  scaled = [image / 127.5 - 1 for image in (input_image, real_image)]
  return scaled[0], scaled[1]

In [None]:
@tf.function()
def random_jitter(input_image, real_image, last_layer, train):
  """Augment an (input, real) image pair and build the extra conditioning channels.

  Training mode (`train` true): both images are resized to 286 x 143 and given
  a shared random saturation and contrast change. Gaussian noise with a random
  std in [0, 15) is added to the input only. A shared random
  IMG_HEIGHT x IMG_WIDTH crop is then taken, and both images are clipped back
  to [0, 255].

  Both modes: the pair is mirrored left-right with probability 0.5.

  Args:
    input_image: 3-channel input image; values presumably in [0, 255].
    real_image: 3-channel target image, same size as `input_image`.
    last_layer: single-channel map, returned unchanged as `last_layer_1`.
    train: Python bool; `tf.function` traces each value separately.

  Returns:
    (input_image, real_image, last_layer_1, last_layer_2). `last_layer_2` is a
    single-channel map (same height/width as the returned input) that is all
    ones when the pair was mirrored and all zeros otherwise.

  NOTE(review): `last_layer` is not resized, cropped or mirrored along with the
  images. That is only consistent if it is spatially constant (e.g. all zeros)
  and already IMG_HEIGHT x IMG_WIDTH. Confirm.
  """
  if train:
    # Fixed upscale before the random crop; the commented expressions are an
    # earlier random-size variant.
    rs_f1 = 286 #tf.random.uniform(shape=(), minval=256, maxval=286, dtype=tf.int32)
    rs_f2 = 143 #tf.random.uniform(shape=(), minval=128, maxval=138, dtype=tf.int32)
    input_image, real_image = resize(input_image, real_image, rs_f1, rs_f2)
    input_image, real_image = random_saturation(input_image, real_image)
    input_image, real_image = random_contrast(input_image, real_image)

    # Input-only Gaussian noise with a per-example random std in [0, 15).
    rn = tf.random.uniform(()) * 15
    noise_1 = tf.random.normal(shape=tf.shape(input_image), mean=0.0, stddev=rn, dtype=tf.float32)
    input_image = tf.add(input_image, noise_1)
    input_image, real_image = random_crop(input_image, real_image)
    input_image = tf.clip_by_value(input_image, 0 , 255)
    # Clip the target as well: the contrast jitter can push it past [0, 255],
    # i.e. outside [-1, 1] after `normalize`, which the tanh generator output
    # can never reach.
    real_image = tf.clip_by_value(real_image, 0, 255)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)
    ones = tf.ones_like(input_image)
  else:
    ones = tf.zeros_like(input_image)
  # Flip-indicator channel: one channel of the all-ones/all-zeros tensor.
  ones = ones[..., -1]
  last_layer_1 = last_layer #tf.concat([last_layer, last_layer], axis = -2)
  last_layer_2 = tf.expand_dims(ones, axis = -1)

  return input_image, real_image, last_layer_1, last_layer_2

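Define the per-example loaders that turn one dataset element into a (generator input, target) pair. You can inspect some of the preprocessed output once the pipeline is built below: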

In [None]:
def load_image_train(image):
  """Build one training (generator input, target) pair from a dataset element.

  Channel 3 of the input half is split off as a single-channel map. The RGB
  channels of both halves go through `random_jitter` (training mode) and
  `normalize`. The generator input is then assembled as
  [RGB, last_layer_1, last_layer_2] along the channel axis.
  """
  source, target = load(image)
  condition = source[..., 3:4]
  rgb_in, rgb_target, layer_a, layer_b = random_jitter(
      source[..., :3], target[..., :3], condition, True)
  rgb_in, rgb_target = normalize(rgb_in, rgb_target)
  return tf.concat([rgb_in, layer_a, layer_b], axis=-1), rgb_target

In [None]:
def load_image_test(image):
  """Build one evaluation (generator input, target) pair from a dataset element.

  This matches the training loader, but without the training augmentations:
  the RGB halves are only resized to IMG_HEIGHT x IMG_WIDTH, and
  `random_jitter` is called with `train=False`.

  NOTE(review): that call still mirrors the pair at random and sets the flip
  channel accordingly, so evaluation samples are not deterministic. Confirm
  this is intended.
  """
  source, target = load(image)
  condition = tf.expand_dims(source[..., 3], axis=-1)
  rgb_in, rgb_target = resize(source[..., :3], target[..., :3],
                              IMG_HEIGHT, IMG_WIDTH)
  rgb_in, rgb_target, layer_a, layer_b = random_jitter(
      rgb_in, rgb_target, condition, False)
  rgb_in, rgb_target = normalize(rgb_in, rgb_target)
  return tf.concat([rgb_in, layer_a, layer_b], axis=-1), rgb_target

## Build an input pipeline with `tf.data`

In [None]:
# Training pipeline: augment/normalize each example, then shuffle and batch.
train_dataset = train_df
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

# Preview 16 (input, target) pairs. Both images are in [-1, 1] after
# `normalize`, so (x + 1) / 2 maps them to [0, 1] for imshow. Only the first
# three (RGB) channels of the generator input are shown.
# `im` (the last batch drawn) is reused by later cells as a sample input.
plt.figure(figsize=(24, 24))
it = iter(train_dataset)
for ii in range(0, 32, 2):
  im = next(it)
  plt.subplot(8, 8, ii + 1)
  plt.imshow((im[0][0][...,:3]+1)/2)
  plt.axis('off')
  plt.subplot(8, 8, ii + 2)
  plt.imshow((im[1][0]+1)/2)
  plt.axis('off')
plt.show()

In [None]:
# Evaluation pipeline: `load_image_test` per example, then batching. There is
# no shuffling, so the example order follows `test_df`. (The previous
# try/except had identical branches and could never change the result.)
test_dataset = test_df.map(load_image_test,
                           num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [None]:
# Channels produced by the generator; matches the 3-channel target images.
OUTPUT_CHANNELS = 3

In [None]:
def downsample(filters, size, apply_batchnorm=True, strides = (2,2)):
  """Encoder block: strided Conv2D -> optional BatchNorm -> LeakyReLU.

  Args:
    filters: number of output channels.
    size: convolution kernel size.
    apply_batchnorm: insert a BatchNormalization layer after the convolution.
    strides: convolution strides; with 'same' padding the default halves
      height and width.

  Returns:
    A `tf.keras.Sequential` implementing the block.
  """
  block = [
      tf.keras.layers.Conv2D(
          filters, size, strides=strides, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
  ]
  if apply_batchnorm:
    block.append(tf.keras.layers.BatchNormalization())
  block.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(block)

In [None]:
# Smoke test: one downsample block (3 filters, 4x4 kernel, stride 2) applied to
# the generator input of the last previewed training batch `im`; height and
# width should halve.
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(im[0][0], 0))
print (down_result.shape)

Define the upsampler (decoder):

In [None]:
def upsample(filters, size, apply_dropout=False, strides = 2):
  """Decoder block: Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

  Args:
    filters: number of output channels.
    size: transposed-convolution kernel size.
    apply_dropout: insert a 50% Dropout layer after batch normalization.
    strides: transposed-convolution strides; with 'same' padding the default
      doubles height and width.

  Returns:
    A `tf.keras.Sequential` implementing the block.
  """
  block = [
      tf.keras.layers.Conv2DTranspose(
          filters, size, strides=strides, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    block.append(tf.keras.layers.Dropout(0.5))
  block.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(block)

In [None]:
# Smoke test: one upsample block (stride 2) applied to `down_result`; height
# and width should double back to the size of the original input.
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

Define the generator with the downsampler and the upsampler:

In [None]:
def Generator():
  """Build the U-Net generator.

  Input: (batch_size, IMG_HEIGHT, IMG_WIDTH, 5), i.e. the 3 RGB channels plus
  the two single-channel maps appended by the loaders. Output:
  (batch_size, IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS), squashed by tanh.

  The shape comments below assume IMG_HEIGHT = 256 and IMG_WIDTH = 128. The
  last encoder block and the first decoder block use strides (2, 1) because
  the width reaches 1 one step before the height does.
  """
  inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 5])

  #norm = tf.norm(inputs, ord=1, axis=3, keepdims=True)

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 64, 64)
    downsample(128, 4),  # (batch_size, 64, 32, 128)
    downsample(256, 4),  # (batch_size, 32, 16, 256)
    downsample(512, 4),  # (batch_size, 16, 8, 512)
    downsample(512, 4),  # (batch_size, 8, 4, 512)
    downsample(512, 4),  # (batch_size, 4, 2, 512)
    downsample(512, 4),  # (batch_size, 2, 1, 512)
    downsample(512, 4, strides = (2,1)),  # (batch_size, 1, 1, 512)
  ]

  # Shapes are each block's output before the skip concatenation.
  up_stack = [
    upsample(1024, 4, apply_dropout=True, strides = (2,1)),  # (batch_size, 2, 1, 1024)
    upsample(1024, 4, apply_dropout=True),  # (batch_size, 4, 2, 1024)
    upsample(1024, 4, apply_dropout=True),  # (batch_size, 8, 4, 1024)
    upsample(1024, 4),  # (batch_size, 16, 8, 1024)
    upsample(512, 4),  # (batch_size, 32, 16, 512)
    upsample(256, 4),  # (batch_size, 64, 32, 256)
    upsample(128, 4),  # (batch_size, 128, 64, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 128, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck (last encoder output) is not used as a skip; the remaining
  # encoder outputs are matched to decoder blocks in reverse order.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Visualize the generator model architecture:

In [None]:
# Build the generator and draw its layer graph with tensor shapes.
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

Test the generator:

In [None]:
# Run the generator on the sample input and show its output. The tanh values
# in [-1, 1] are mapped to [0, 1] so imshow does not clip the negative half.
gen_output = generator(im[0][0][tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...] * 0.5 + 0.5)

### Define the generator loss

GANs learn a loss that adapts to the data, while cGANs learn a structured loss that penalizes a possible structure that differs from the network output and the target image, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- The generator loss is a sigmoid cross-entropy loss between the discriminator's output on the generated images and an **array of ones**.
- The pix2pix paper also mentions the L1 loss, which is a MAE (mean absolute error) between the generated image and the target image.
- This allows the generated image to become structurally similar to the target image.
- The formula to calculate the total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was decided by the authors of the paper.

In [None]:
# Weight of the L1 term in the generator loss (gan_loss + LAMBDA * l1_loss).
LAMBDA = 100

In [None]:
# Binary cross-entropy on raw logits, shared by both losses. The
# discriminator's final Conv2D has no activation, so its outputs are logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
  """pix2pix generator objective.

  Returns (total, gan, l1):
    gan   -- binary cross-entropy (via `loss_object`) between the
             discriminator's scores for the generated images and all-ones labels;
    l1    -- mean absolute error between `target` and `gen_output`;
    total -- gan + LAMBDA * l1.
  """
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
  all_real = tf.ones_like(disc_generated_output)
  gan_loss = loss_object(all_real, disc_generated_output)
  return gan_loss + LAMBDA * l1_loss, gan_loss, l1_loss

The training procedure for the generator is as follows:

![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)


## Build the discriminator

The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier—it tries to classify if each image _patch_ is real or not real, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.
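- The shape of the output after the last layer is `(batch_size, 30, 30, 1)` for 256 x 256 inputs; with the 256 x 128 inputs used in this notebook it is `(batch_size, 30, 14, 1)`.
- Each value in this output map classifies a `70 x 70` portion of the input image (its receptive field).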
- The discriminator receives 2 inputs:
    - The input image and the target image, which it should classify as real.
    - The input image and the generated image (the output of the generator), which it should classify as fake.
    - Use `tf.concat([inp, tar], axis=-1)` to concatenate these 2 inputs together.

Let's define the discriminator:

In [None]:
def Discriminator():
  """Build the PatchGAN discriminator.

  Inputs: [input_image (IMG_HEIGHT x IMG_WIDTH x 5), target_image
  (IMG_HEIGHT x IMG_WIDTH x 3)]. Output: a map of per-patch logits (the last
  Conv2D has no activation). The shape comments below assume IMG_HEIGHT = 256
  and IMG_WIDTH = 128.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 5], name='input_image')
  tar = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3], name='target_image')

  #norm, _ = tf.linalg.normalize(inp, ord=1, axis=3)

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 128, 5 + 3)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 64, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 32, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 16, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 18, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 31, 15, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 17, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 14, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

Visualize the discriminator model architecture:

In [None]:
# Build the discriminator and draw its layer graph with tensor shapes.
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

Test the discriminator:

In [None]:
# Score the generator output computed above against the sample input. The
# result is a map of per-patch logits, shown on a fixed [-20, 20] colour scale.
# (The previous unused `target` variable held the input image, not the target.)
disc_out = discriminator([im[0][0][tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

### Define the discriminator loss

- The `discriminator_loss` function takes 2 inputs: **real images** and **generated images**.
- `real_loss` is a sigmoid cross-entropy loss of the **real images** and an **array of ones (since these are the real images)**.
- `generated_loss` is a sigmoid cross-entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**.
- The `total_loss` is the sum of `real_loss` and `generated_loss`.

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Discriminator objective: BCE(real scores, 1) + BCE(generated scores, 0).

  Both terms use `loss_object`; the returned value is their sum.
  """
  on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  return on_real + on_fake

In [None]:
# Adam with learning rate 2e-4 and beta_1 = 0.5, one optimizer per network.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoint both models and both optimizer states. `fit` saves them under
# `checkpoint_dir` with the file prefix "ckpt".
checkpoint_dir = '/content/drive/My Drive/RhinoplasticPaper/Models/pix2pix/model_12'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# Restoring the latest checkpoint in checkpoint_dir.
# NOTE(review): if the directory holds no checkpoint, latest_checkpoint returns
# None and restore leaves the freshly initialized weights in place, so the
# training below then starts from scratch.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
import cv2
def generate_images(model, test_input, tar):
  """Show input, ground truth and prediction for the first example of a batch.

  Args:
    model: generator to run; it is called with `training=True`.
    test_input: batch of generator inputs in [-1, 1]; only the first three
      (RGB) channels of the first example are displayed.
    tar: batch of target images in [-1, 1].
  """
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0][...,:3], tar[0][...,:3], prediction[0][...,:3]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Resize to the pipeline's image size (cv2's dsize is (width, height));
    # this was previously hard-coded as (128, 256).
    panel = cv2.resize(np.float32(display_list[i]), dsize=(IMG_WIDTH, IMG_HEIGHT))
    # Getting the pixel values in the [0, 1] range to plot. Clipping keeps
    # slightly out-of-range pixels from triggering imshow's clipping warning.
    plt.imshow(np.clip(panel * 0.5 + 0.5, 0.0, 1.0))
    plt.axis('off')
  plt.show()

Test the function:

In [None]:
# Preview one test batch with the current generator.
for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)

## Training

- For each example input, the generator produces an output.
- The discriminator receives the `input_image` and the generated image as the first input. The second input is the `input_image` and the `target_image`.
- Next, calculate the generator and the discriminator loss.
- Then, calculate the gradients of each loss with respect to the trainable variables of the corresponding model (generator or discriminator), and apply them with that model's optimizer.
- Finally, log the losses to TensorBoard.

In [None]:
# TensorBoard logs: one timestamped run directory under `log_dir + "fit/"`.
# NOTE(review): this path uses ".../MyDrive/..." while the data and checkpoint
# paths use ".../My Drive/..."; confirm both resolve on this Colab mount.
log_dir="/content/drive/MyDrive/pix2pix/logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
@tf.function
def train_step(input_image, target, step):
  """Run one generator/discriminator update on a batch.

  Both networks are evaluated once inside two gradient tapes. Both gradients
  are computed from that single forward pass before either optimizer applies
  its update. The four losses are written to `summary_writer`.

  Args:
    input_image: batch of generator inputs.
    target: batch of target images.
    step: training-step counter (a scalar tensor when called from `fit`).
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)
    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  # NOTE(review): scalars are written every step, but with step=step//1000,
  # so 1000 consecutive steps share one x-value in TensorBoard.
  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

The actual training loop. Since this tutorial can run on more than one dataset, and the datasets vary greatly in size, the training loop is set up to work in steps instead of epochs.

- Iterates over the number of steps.
- Every 10 steps print a dot (`.`).
- Every 1k steps: clear the display and run `generate_images` to show the progress.
- Every 5k steps: save a checkpoint.

In [None]:
def fit(train_ds, test_ds, steps, display_every=1000, save_every=5000):
  """Train the pix2pix models for a fixed number of steps.

  Iterates over `train_ds` (repeated as needed) for `steps` batches and calls
  `train_step` on each. Every `display_every` steps it clears the output,
  reports the elapsed time, and shows one prediction for a batch of `test_ds`
  and one for a batch of `train_ds`. Every `save_every` steps it writes a
  checkpoint.

  Args:
    train_ds: batched dataset of (input_image, target) pairs to train on.
    test_ds: batched dataset used for the periodic test preview.
    steps: total number of training steps.
    display_every: preview/timing interval in steps (default 1000).
    save_every: checkpoint interval in steps (default 5000).
  """
  start = time.time()

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    # `step` is an int64 scalar tensor. Pass it to `train_step` unchanged (so
    # the tf.function is not retraced every step), but use a Python int for
    # the bookkeeping and printing below.
    step_num = int(step)

    if step_num % display_every == 0:
      display.clear_output(wait=True)

      if step_num != 0:
        print(f'Time taken for {display_every} steps: {time.time()-start:.2f} sec\n')

      start = time.time()

      # Preview on the datasets passed in, not on module-level globals.
      for example_input, example_target in test_ds.take(1):
        generate_images(generator, example_input, example_target)
      for example_input, example_target in train_ds.take(1):
        generate_images(generator, example_input, example_target)
      print(f"Step: {step_num // 1000}k")

    # Training step
    train_step(input_image, target, step)

    # Progress dot every 10 steps.
    if (step_num + 1) % 10 == 0:
      print('.', end='', flush=True)

    # Save (checkpoint) the model every `save_every` steps.
    if (step_num + 1) % save_every == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

This training loop saves logs that you can view in TensorBoard to monitor the training progress.

If you work on a local machine, you would launch a separate TensorBoard process. When working in a notebook, launch the viewer before starting the training to monitor with TensorBoard.

To launch the viewer paste the following into a code-cell:

In [None]:
%load_ext tensorboard
%tensorboard --logdir {log_dir}

Finally, run the training loop:

In [None]:
# Train for 500k steps (the training set is repeated as needed).
fit(train_dataset, test_dataset, steps=500000)

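If you want to share the TensorBoard results _publicly_, you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code-cell. (Note: TensorBoard.dev has since been shut down, so this upload command and the embedded run below may no longer work.)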

Note: This requires a Google account.

```
!tensorboard dev upload --logdir {log_dir}
```

Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the "interrupt execution" option in your notebook tool.

You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).

TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.

It can also be included inline using an `<iframe>`:

In [None]:
display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

Interpreting the logs is more subtle when training a GAN (or a cGAN like pix2pix) compared to a simple classification or regression model. Things to look for:

- Check that neither the generator nor the discriminator model has "won". If either the `gen_gan_loss` or the `disc_loss` gets very low, it's an indicator that this model is dominating the other, and you are not successfully training the combined model.
- The value `log(2) = 0.69` is a good reference point for these losses, as it indicates a perplexity of 2 - the discriminator is, on average, equally uncertain about the two options.
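- For the `disc_loss`, note that in this notebook it is the *sum* of the real and generated cross-entropy terms, so the chance-level reference is `2 * log(2) ≈ 1.39` rather than `0.69`. A value below that means the discriminator is doing better than random on the combined set of real and generated images.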
- For the `gen_gan_loss`, a value below `0.69` means the generator is doing better than random at fooling the discriminator.
- As training progresses, the `gen_l1_loss` should go down.

## Restore the latest checkpoint and test the network

## Generate some images using the test set

In [None]:
# Run the trained model on a few examples from the test set
# (up to 29 batches; with BATCH_SIZE = 1 that is up to 29 examples).
for inp, tar in test_dataset.take(29):
  generate_images(generator, inp, tar)