# Practical 3: Conditional GANs for image-to-image translation 
---

### **Authors:**

Original tutorial by **Luigi Celona** and **Flavio Piccoli**, modified by **[Nemanja Rakicevic](https://nemanja-rakicevic.github.io/)** and **[Manos Kirtas](https://scholar.google.com/citations?user=EyaKPkwAAAAJ&hl=en)**


### **Tutorial overview:**

In this tutorial you will implement, train and analyse the results of a conditional GAN model that translates architectural label maps of building facades into photographs of real buildings. This tutorial is adapted from "Image to image translation using conditional GANs", as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004).


We will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.

Each epoch takes around 15 seconds on a single V100 GPU.

Below is the output generated after training the model for 200 epochs.

![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)
![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)


### **Tutorial outline:**
- [Theory recap](#theory-recap)
- [Setup](#setup)
  - Install and Import Packages
  - Dataset
  - Helper Functions
  - Input Pipeline
- [Implementing conditional GAN components](#implement-components)
  - Generator
  - Discriminator
  - Loss functions
- [Training and Visualisation](#training)
  - Train Utils
  - Main Train Loop
  - Run Training
- [Analysis](#analysis)


---

## Theory recap <a class="anchor" id="theory-recap"></a>  

Conditional Generative Adversarial Networks (cGANs) are a type of GAN that can be used to generate images conditioned on some input data. This makes them well-suited for image-to-image translation tasks, where the goal is to translate an image from one domain to another, such as from sketches to photos, or from black and white to color.

cGANs work by training two neural networks against each other: a generator and a discriminator. The generator takes as input the image to be translated, which acts as the conditioning information, and produces a translated image. The discriminator also sees the input image and is trained to distinguish real pairs (input image, real target image) from fake pairs (input image, image produced by the generator).

![training setup](https://github.com/M2Lschool/tutorials2023/raw/main/2_generative/images/img2img_translation.png)
[Image credit [Image-to-image translation with conditional adversarial networks, Isola et al. (2017)](https://arxiv.org/abs/1611.07004)]

During training, the generator and discriminator are pitted against each other in a minimax game. The generator tries to produce images that are indistinguishable from real images, while the discriminator tries to get better at detecting fake images. This adversarial process forces the generator to produce increasingly realistic images.
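For reference, the objective optimised in the pix2pix paper combines the conditional adversarial loss with an L1 reconstruction term. Here $x$ is the input image, $y$ the target image, $z$ the noise (which pix2pix injects only through dropout) and $\lambda$ the weight that appears as `LAMBDA` in the code:

$$
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]
$$

$$
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right]
$$

$$
G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)
$$

The loss functions in this notebook use the common non-saturating variant for the generator: instead of minimising $\log(1 - D(x, G(x,z)))$, it minimises a cross-entropy against an array of ones, i.e. $-\log D(x, G(x,z))$.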

Once the cGAN is trained, it can be used to translate images from the source domain to the target domain: we simply feed the generator an image from the source domain, and it produces the corresponding image in the target domain.

## Setup  <a class="anchor" id="setup"></a>  

### Install and Import Packages

In [None]:
!pip install ipdb &> /dev/null
!pip install git+https://github.com/deepmind/dm-haiku &> /dev/null
!pip install -U tensorboard &> /dev/null
!pip install git+https://github.com/deepmind/optax.git &> /dev/null

In [None]:
import os
import time
import pickle
import functools
import numpy as np

# Dataset libraries.
import tensorflow as tf


import haiku as hk
import jax
import optax  # Package for optimizer.
import jax.numpy as jnp

# Plotting libraries.
from matplotlib import pyplot as plt
from IPython import display

from typing import Mapping, Optional, Tuple, NamedTuple, Any

In [3]:
# @title Hyperparameters

BUFFER_SIZE = 400  #@param
BATCH_SIZE = 1  #@param
IMG_WIDTH = 256  #@param
IMG_HEIGHT = 256  #@param
TRAIN_INIT_RANDOM_SEED = 1729  #@param
LAMBDA = 100  #@param
EPOCHS = 150

### Dataset

You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As described in the [paper](https://arxiv.org/abs/1611.07004), we apply random jittering and mirroring to the training dataset.

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e. left to right.

In [None]:
_URL = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz'
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')

### Helper functions

In [None]:
# We need a random key for initialization.
rng = jax.random.PRNGKey(TRAIN_INIT_RANDOM_SEED)

In [None]:
#@title Dataset loading and preprocessing
# We use tensorflow readers; JAX does not have support for input data reading
# and pre-processing.
def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    w = tf.shape(image)[1]

    w = w // 2
    real_image = image[:, :w, :]
    input_image = image[:, w:, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image
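Each file in the facades dataset stores the photo and its label map side by side, so `load` simply slices the width in half. A minimal NumPy sketch of the same slicing, using a small synthetic array in place of a decoded JPEG (the array and its sizes are made up for illustration):

```python
import numpy as np

# Synthetic side-by-side image: height 4, width 8, 3 channels.
# Left half is filled with 1s (stands in for the photo),
# right half with 2s (stands in for the label map).
combined = np.concatenate([np.ones((4, 4, 3)), 2 * np.ones((4, 4, 3))], axis=1)

w = combined.shape[1] // 2
real_image = combined[:, :w, :]   # left half: the target photo
input_image = combined[:, w:, :]  # right half: the conditioning label map

print(real_image.shape, input_image.shape)  # (4, 4, 3) (4, 4, 3)
```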

In [None]:
inp, re = load(PATH + 'train/100.jpg')
# Divide by 255 so matplotlib receives float pixel values in [0, 1].
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)

In [None]:
def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width],
                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width],
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return input_image, real_image

In [None]:
def random_crop(input_image, real_image):
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_image[0], cropped_image[1]
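Stacking the two images before cropping guarantees that the input and the target are cropped at exactly the same offset, so they stay pixel-aligned. The NumPy sketch below mimics that idea with a hand-written crop; `random_crop_pair` is a hypothetical helper for illustration, not part of TensorFlow:

```python
import numpy as np

def random_crop_pair(input_image, real_image, height, width, rng):
    """Crop both images at the same random offset (hypothetical helper)."""
    stacked = np.stack([input_image, real_image], axis=0)  # (2, H, W, C)
    top = rng.integers(0, stacked.shape[1] - height + 1)
    left = rng.integers(0, stacked.shape[2] - width + 1)
    cropped = stacked[:, top:top + height, left:left + width, :]
    return cropped[0], cropped[1]

rng = np.random.default_rng(0)
inp = rng.random((286, 286, 3))
tar = inp + 1.0  # target is a shifted copy, so alignment is easy to check
c_inp, c_tar = random_crop_pair(inp, tar, 256, 256, rng)
print(c_inp.shape, np.allclose(c_tar - c_inp, 1.0))  # (256, 256, 3) True
```

If the two images were cropped independently, the difference would no longer be a constant 1.0 and the pair would be misaligned.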

In [None]:
# Normalizes the input images to [-1, 1].
def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image
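`normalize` maps 8-bit pixel values from [0, 255] to [-1, 1], and the plotting code later maps images back with `x * 0.5 + 0.5`, which gives [0, 1] for `plt.imshow`. A quick NumPy check of both mappings:

```python
import numpy as np

def normalize(x):
    # [0, 255] -> [-1, 1], as in the notebook's `normalize`.
    return (x / 127.5) - 1

def to_display(x):
    # [-1, 1] -> [0, 1], as used before `plt.imshow` in the notebook.
    return x * 0.5 + 0.5

pixels = np.array([0.0, 127.5, 255.0])
print(normalize(pixels))              # values -1, 0, 1
print(to_display(normalize(pixels)))  # values 0, 0.5, 1
```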

Random jittering as described in the paper is composed of the following steps:
1. Resize an image to a bigger height and width
2. Randomly crop to the target size
3. Randomly flip the image horizontally

In [None]:
#@title Data augmentation { form-width: "40%"}
@tf.function()
def random_jitter(input_image, real_image):
    # Resizing to 286 x 286 x 3.
    input_image, real_image = resize(input_image, real_image, 286, 286)

    # Randomly cropping to 256 x 256 x 3.
    input_image, real_image = random_crop(input_image, real_image)

    if tf.random.uniform(()) > 0.5:
        # Random mirroring.
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
plt.figure(figsize=(6, 6))
for i in range(4):
    rj_inp, rj_re = random_jitter(inp, re)
    plt.subplot(2, 2, i + 1)
    plt.imshow(rj_inp / 255.0)
    plt.axis('off')
plt.show()

In [None]:
def load_image_train(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image

In [None]:
def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image, real_image,
                                     IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image

### Input Pipeline

In [None]:
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

## Implementing conditional GAN components <a class="anchor" id="implement-components"></a>  

### Generator  

The architecture of the generator is a modified U-Net [[U-net: Convolutional networks for biomedical image segmentation, Ronneberger et al. (2015)](https://arxiv.org/abs/1505.04597)]. The U-Net is an encoder-decoder with skip connections between mirrored layers in the encoder and decoder stacks. The skip connections let low-level information (e.g. the location of prominent edges) bypass the bottleneck, which would otherwise lose it.


![generator architecture](https://github.com/M2Lschool/tutorials2023/raw/main/2_generative/images/generator_architecture.png)
[Image credit [Image-to-image translation with conditional adversarial networks, Isola et al. (2017)](https://arxiv.org/abs/1611.07004)]

  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout (applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-net)
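The shape comments in the `Generator` class can be reproduced with a little arithmetic: every stride-2 `'SAME'` convolution halves the spatial size, every stride-2 transposed convolution doubles it, and concatenating a skip connection adds the encoder's channels. A plain-Python sketch of that bookkeeping (the channel lists are copied from the `Generator` definition; `unet_shapes` itself is only illustrative):

```python
import math

def unet_shapes(size=256,
                enc_channels=(64, 128, 256, 512, 512, 512, 512, 512),
                dec_channels=(512, 512, 512, 512, 256, 128, 64)):
    """(spatial size, channels) after each U-Net block (illustrative)."""
    enc = []
    for c in enc_channels:
        size = math.ceil(size / 2)  # stride-2 'SAME' conv halves H and W
        enc.append((size, c))

    skips = list(reversed(enc[:-1]))  # the bottleneck has no skip partner
    dec = []
    for c, (skip_size, skip_c) in zip(dec_channels, skips):
        size *= 2  # stride-2 transposed conv doubles H and W
        assert size == skip_size, "skip and decoder output must align"
        dec.append((size, c + skip_c))  # channels after concatenation
    return enc, dec

enc, dec = unet_shapes()
print(enc[-1])  # (1, 512): the 1x1 bottleneck
print(dec)      # [(2, 1024), (4, 1024), (8, 1024), (16, 1024), (32, 512), (64, 256), (128, 128)]
```

A final stride-2 transposed convolution with 3 output channels then brings the result back to `256 x 256 x 3`.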

In [None]:
# Used functions

In [None]:
hk.Conv2D?

In [None]:
hk.BatchNorm?

In [None]:
jax.nn.leaky_relu?

In [None]:
#@title Encoder definition (Conv -> Batchnorm -> Leaky ReLU) { form-width: "40%" }

class Encoder(hk.Module):
    def __init__(self,
                 channels: int,
                 size: int,
                 apply_batchnorm=True):
        super().__init__()
        self.channels = channels
        self.size = size
        self.initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)
        self.apply_batchnorm = apply_batchnorm

    def __call__(self, inputs, is_training):
        ##################################################################
        #  YOUR CODE HERE:
        
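        # Encoder steps:
        # 1. Apply hk.Conv2D layer (channels, size, stride=2, init, pad='SAME', nobias) to inputs.

        # 2. Apply hk.BatchNorm if self.apply_batchnorm is set
        #    (e.g. create_scale=True, create_offset=True, decay_rate=0.999, as in the Discriminator).

        # 3. Apply jax.nn.leaky_relu (negative_slope=0.2) on the output.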
        ##################################################################
        return out

In [None]:
# Used functions:

In [None]:
hk.Conv2DTranspose?

In [None]:
hk.BatchNorm?

In [None]:
hk.dropout?

In [None]:
#@title Decoder definition (Transposed Conv -> Batchnorm -> Dropout (applied to the first 3 blocks) -> ReLU)  { form-width: "40%" }

class Decoder(hk.Module):
    def __init__(self,
                 channels: int,
                 size: int,
                 apply_dropout=False):
        super().__init__()
        self.initializer = hk.initializers.RandomNormal(mean=0.0,
                                                        stddev=0.02)
        self.channels = channels
        self.size = size
        self.apply_dropout = apply_dropout

    def __call__(self, inputs, is_training):
        ##################################################################
        #  YOUR CODE HERE:
        
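        # Decoder steps:
        # 1. Apply hk.Conv2DTranspose layer (channels, size, stride=2, init, pad='SAME', nobias) to inputs.
        
        # 2. Apply hk.BatchNorm (the decoder always uses batch normalization).

        # 3. Apply hk.dropout (rate 0.5 in the paper) if self.apply_dropout is set;
        #    it needs a PRNG key, e.g. hk.next_rng_key().

        # 4. Apply jax.nn.relu on the output.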
        ##################################################################
        return out

In [None]:
class Generator(hk.Module):
    def __init__(self):
        super().__init__()
        # The comments give the output shape of each block; `bs` is the batch size.
        self.down_stack = [
            Encoder(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
            Encoder(128, 4),  # (bs, 64, 64, 128)
            Encoder(256, 4),  # (bs, 32, 32, 256)
            Encoder(512, 4),  # (bs, 16, 16, 512)
            Encoder(512, 4),  # (bs, 8, 8, 512)
            Encoder(512, 4),  # (bs, 4, 4, 512)
            Encoder(512, 4),  # (bs, 2, 2, 512)
            Encoder(512, 4),  # (bs, 1, 1, 512)
        ]

        self.up_stack = [
            Decoder(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)
            Decoder(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
            Decoder(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
            Decoder(512, 4),  # (bs, 16, 16, 1024)
            Decoder(256, 4),  # (bs, 32, 32, 512)
            Decoder(128, 4),  # (bs, 64, 64, 256)
            Decoder(64, 4),  # (bs, 128, 128, 128)
        ]

        initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)
        self.last = hk.Conv2DTranspose(3, 4,
                                       stride=2,
                                       padding='SAME',
                                       w_init=initializer)  # (bs, 256, 256, 3)

    def __call__(self, x, is_training):
        # Downsampling through the model
        skips = []
        for down in self.down_stack:
            x = down(x, is_training)
            ##################################################################
            #  YOUR CODE HERE:
            
            # Add encoder outputs to the list of skips.
            
            ##################################################################


        # Upsampling and establishing the skip connections
        skips = reversed(skips[:-1])
        for up, skip in zip(self.up_stack, skips):
            x = up(x, is_training)
            
            ##################################################################
            #  YOUR CODE HERE:
            
            # Concatenate the skip and the previous step output.
            
            ##################################################################

        # tanh maps the output to [-1, 1], the range of the normalized targets.
        x = jnp.tanh(self.last(x))
        return x

![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)


### Discriminator

The Discriminator is a PatchGAN: it classifies each small patch of an image as real or fake, rather than classifying the entire image at once.
This discriminator is run convolutionally across the image, and all patch responses are averaged to give the final output of D (in this notebook the averaging happens inside the loss, through the `jnp.mean` in `bce_w_logits`). Such a discriminator effectively models the image as a Markov random field, assuming independence between pixels separated by more than a patch diameter.
Because the L1 term already enforces low-frequency correctness, restricting D to local patches lets it focus on high-frequency structure, so it can be understood as a type of texture/style loss. It also has fewer parameters than a full-image discriminator and can be applied to images of arbitrary size.

  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU); the first block has no BatchNorm.
  * The shape of the output after the last layer is (batch_size, 30, 30, 1).
  * Each of the 30x30 output values classifies a 70x70 patch of the input image (such an architecture is called a PatchGAN).
  * The discriminator is applied to 2 kinds of input:
    * The input image and the target image, which it should classify as real.
    * The input image and the generated image (output of the generator), which it should classify as fake.
    * In both cases, the two images are concatenated along the channel axis before being fed to the discriminator (`jax.numpy.concatenate([inp, tar], axis=-1)`).
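The 30x30 output and the 70x70 patch size both follow from the layer configuration in the `Discriminator` class. The sketch below traces the spatial size forward through the layers and the receptive field backward from a single output unit, using the standard relation $r_{in} = (r_{out} - 1)\,s + k$ for a convolution with kernel size $k$ and stride $s$ (zero padding changes the output size but not the receptive field). The layer list is transcribed by hand from the class:

```python
import math

# (kernel, stride, padding) per conv layer, following the Discriminator class:
# three stride-2 'SAME' encoders, then two 'VALID' 4x4 convs, each preceded
# by 1 pixel of zero padding on every side.
layers = [(4, 2, "same"), (4, 2, "same"), (4, 2, "same"),
          (4, 1, 1), (4, 1, 1)]

def output_size(size):
    for k, s, p in layers:
        if p == "same":
            size = math.ceil(size / s)
        else:
            size = (size + 2 * p - k) // s + 1  # explicit pad + 'VALID' conv
    return size

def receptive_field():
    r = 1  # start from a single output unit
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r

print(output_size(256), receptive_field())  # 30 70
```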

In [None]:
class Discriminator(hk.Module):
    def __init__(self):
        super().__init__()
        initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)

        self.down1 = Encoder(64, 4, apply_batchnorm=False)
        self.down2 = Encoder(128, 4)
        self.down3 = Encoder(256, 4)

        self.conv = hk.Conv2D(512, 4, stride=1, w_init=initializer,
                              padding='VALID', with_bias=False)
        self.bn = hk.BatchNorm(create_scale=True, create_offset=True,
                               decay_rate=0.999, eps=0.001)
        self.last = hk.Conv2D(1, 4, stride=1, padding='VALID',
                              w_init=initializer)

    def __call__(self, x, is_training):  # (bs, 256, 256, channels*2)
        x = self.down1(x, is_training)  # (bs, 128, 128, 64)
        x = self.down2(x, is_training)  # (bs, 64, 64, 128)
        x = self.down3(x, is_training)  # (bs, 32, 32, 256)
        x = jnp.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # (bs, 34, 34, 256)
        x = self.conv(x)  # (bs, 31, 31, 512)
        x = self.bn(x, is_training)
        x = jax.nn.leaky_relu(x, negative_slope=0.2)
        x = jnp.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # (bs, 33, 33, 512)
        x = self.last(x)  # (bs, 30, 30, 1)
        return x

### Loss functions 

#### Generator loss

  * The adversarial term (`gan_loss`) is a sigmoid cross-entropy loss between the discriminator's output on the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes an L1 loss, i.e. the MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was chosen by the authors of the [paper](https://arxiv.org/abs/1611.07004).
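To get a feel for how `LAMBDA` balances the two terms, here is a NumPy sketch that evaluates them on made-up values: discriminator logits of zero everywhere (a maximally uncertain discriminator) and a generated image that is off by 0.1 at every pixel. The numbers are illustrative only, and `bce_with_logits` is a local helper for this sketch:

```python
import numpy as np

LAMBDA = 100

def bce_with_logits(logits, target):
    # Numerically stable binary cross-entropy on logits, averaged.
    return np.mean(np.maximum(logits, 0) - logits * target
                   + np.log1p(np.exp(-np.abs(logits))))

disc_generated_output = np.zeros((1, 30, 30, 1))  # D is unsure everywhere
target = np.zeros((1, 256, 256, 3))
gen_output = target + 0.1                         # uniform error of 0.1

gan_loss = bce_with_logits(disc_generated_output,
                           np.ones_like(disc_generated_output))
l1_loss = np.mean(np.abs(target - gen_output))
total_gen_loss = gan_loss + LAMBDA * l1_loss
print(gan_loss, l1_loss, total_gen_loss)  # about 0.693, 0.1, 10.69
```

With `LAMBDA = 100`, even a small average pixel error outweighs the adversarial term at chance level.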

In [None]:
# Computes binary cross entropy for classification.

def bce_w_logits(
    logits: jnp.ndarray,
    target: jnp.ndarray
) -> jnp.ndarray:
    """
    Binary Cross Entropy Loss
    :param logits: Input tensor
    :param target: Target tensor
    :return: Scalar value
    """
    ##################################################################
    #  YOUR CODE HERE:

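    # Refer to the first tutorial.
    # Numerically stable form: with max_val = max(-logits, 0), both
    # exponents below are non-positive, so jnp.exp never overflows.
    max_val = jnp.clip(-logits, 0, None)
    loss = logits - logits * target + max_val + \
    jnp.log(jnp.exp(-max_val) + jnp.exp((-logits - max_val)))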
    ##################################################################

    return jnp.mean(loss)
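Why not compute `-t * log(sigmoid(x)) - (1 - t) * log(1 - sigmoid(x))` directly? For large-magnitude logits the sigmoid saturates to exactly 0 or 1 in floating point, and the logarithm then produces `inf` or `nan`. The NumPy sketch below compares that naive form with the rearranged form `max(x, 0) - x * t + log(1 + exp(-|x|))`, which is algebraically equal to the expression used in `bce_w_logits`:

```python
import numpy as np

def bce_naive(logits, target):
    p = 1.0 / (1.0 + np.exp(-logits))  # sigmoid; saturates for large |logits|
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))

def bce_stable(logits, target):
    # Never exponentiates a positive number, so nothing overflows.
    return (np.maximum(logits, 0) - logits * target
            + np.log1p(np.exp(-np.abs(logits))))

logits = np.array([-1000.0, 0.0, 1000.0])
target = np.ones_like(logits)

with np.errstate(all="ignore"):  # silence overflow / invalid-value warnings
    naive = bce_naive(logits, target)
stable = bce_stable(logits, target)
print(naive)   # inf, ~0.693, nan
print(stable)  # 1000, ~0.693, 0
```

If you prefer a library routine, optax also provides `optax.sigmoid_binary_cross_entropy(logits, labels)`, which returns per-element losses that you would then average.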

In [None]:
def generator_loss(
    disc_generated_output: jnp.ndarray,
    gen_output: jnp.ndarray,
    target: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes the generator loss for the given batch."""
    ##################################################################
    #  YOUR CODE HERE:

    # Compute gan_loss with bce_w_logits: pass the discriminator output as logits and an array of ones of the same shape as the target.
    
    ##################################################################

    # Mean absolute error.
    l1_loss = jnp.mean(jnp.abs(target - gen_output))
    
    
    ##################################################################
    #  YOUR CODE HERE:

    # Calculate total generator loss as the GAN loss + scaled L1 loss.
    
    ##################################################################

    return total_gen_loss, gan_loss, l1_loss

#### Discriminator loss
  * The discriminator loss function takes 2 inputs: the discriminator's outputs on the **real image pairs** and on the **generated image pairs**.
  * `real_loss` is a sigmoid cross-entropy loss of the outputs on the **real images** and an **array of ones (since these are the real images)**.
  * `generated_loss` is a sigmoid cross-entropy loss of the outputs on the **generated images** and an **array of zeros (since these are the fake images)**.
  * Then `total_disc_loss` is the sum of `real_loss` and `generated_loss`.


In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = bce_w_logits(disc_real_output,
                             jnp.ones_like(disc_real_output))
    generated_loss = bce_w_logits(disc_generated_output,
                                  jnp.zeros_like(disc_generated_output))
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss

The training procedure for the discriminator is shown below.

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004).

![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)


## Model Training and Visualisation <a class="anchor" id="training"></a>  

* For each example, the generator produces an output from the input image.
* The discriminator receives the input image and the generated image as the first input. The second input is the input image and the target image.
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of each loss with respect to the corresponding model's parameters (generator or discriminator) and apply them with that model's optimizer.
* Last, we log the losses to TensorBoard.

### Define the Checkpoint-saver


In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
if not os.path.exists(checkpoint_prefix):
    os.makedirs(checkpoint_prefix)

### Define the main model


In [None]:
class P2PTuple(NamedTuple):
    gen: Any
    disc: Any


class P2PState(NamedTuple):
    params: P2PTuple
    states: P2PTuple
    opt_state: P2PTuple


class Pix2Pix:
    """Pix2Pix model."""

    def __init__(self):
        self.gen_transform = hk.transform_with_state(
            lambda *args: Generator()(*args)
        )
        self.disc_transform = hk.transform_with_state(
            lambda *args: Discriminator()(*args)
        )

        # Build the optimizers.
        self.gen_optimizer = optax.adam(2e-4, b1=0.5, b2=0.999)
        self.disc_optimizer = optax.adam(2e-4, b1=0.5, b2=0.999)

    @functools.partial(jax.jit, static_argnums=0)
    def initial_state(self,
                      rng: jnp.ndarray,
                      batch: Tuple[jnp.ndarray, jnp.ndarray]):
        """Returns the initial parameters, network states and optimizer
        states of the generator and the discriminator.
        """
        rng, gen_rng, disc_rng = jax.random.split(rng, 3)
        gen_params, gen_state = self.gen_transform.init(gen_rng, batch[0], True)
        disc_params, disc_state = \
            self.disc_transform.init(disc_rng,
                                     jnp.concatenate(batch, axis=-1),
                                     True)
        params = P2PTuple(gen=gen_params, disc=disc_params)
        states = P2PTuple(gen=gen_state, disc=disc_state)

        # Initialize the optimizers.
        opt_state = P2PTuple(gen=self.gen_optimizer.init(params.gen),
                             disc=self.disc_optimizer.init(params.disc)
                             )
        return P2PState(params=params, states=states, opt_state=opt_state)

    def generate_images(self,
                        params: P2PTuple,
                        state: P2PTuple,
                        test_input,
                        rng: Optional[jnp.ndarray] = None):
        # Note: The `training=True` is intentional here since
        #       we want the batch statistics while running the model
        #       on the test dataset. If we use training=False, we will get
        #       the accumulated statistics learned from the training dataset
        #       (which we don't want).
        # Training mode also enables dropout, which needs a PRNG key, so a
        # fixed default key is used when none is given.
        if rng is None:
            rng = jax.random.PRNGKey(0)
        prediction, _ = self.gen_transform.apply(
            params, state, rng, test_input, True
        )

        return prediction

    def gen_loss(self,
                 gen_params: P2PTuple,
                 gen_state: P2PTuple,
                 batch: Tuple[jnp.ndarray, jnp.ndarray],
                 disc_params: P2PTuple,
                 disc_state: P2PTuple,
                 rng_gen, rng_disc):
        """Computes the generator loss (adversarial + L1) for the given batch."""

        input, target = batch
        
        # Apply the generator to the input
        output, gen_state = self.gen_transform.apply(
            gen_params, gen_state, rng_gen, input, True
        )

        # Evaluate using the discriminator.
        ##################################################################
        #  YOUR CODE HERE:

        # Apply the disc_transform (like in the step above) to the concatenated input and output.
        # disc_generated_output, disc_state = ...
        
        ##################################################################


        states = P2PTuple(gen=gen_state, disc=disc_state)

        # Compute the loss.
        total_loss, gan_loss, l1_loss = generator_loss(
            disc_generated_output, output, target
            )

        return total_loss, (output, states, gan_loss, l1_loss)

    def disc_loss(self,
                  params: P2PTuple,
                  state: P2PTuple,
                  batch: Tuple[jnp.ndarray, jnp.ndarray],
                  gen_output: jnp.ndarray, rng):
        """Computes the discriminator loss for the given batch."""
        input, target = batch
        real_output, state = self.disc_transform.apply(
            params, state, rng, jnp.concatenate([input, target], axis=-1), True
        )

            
        ##################################################################
        #  YOUR CODE HERE:

        # Apply the disc_transform (like in the step above) to the concatenated input and generated output.
        # generated_output, state = ...
        
        ##################################################################

        # Compute discriminator loss.
        loss = discriminator_loss(real_output, generated_output)
        return loss, state

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, rng, p2p_state, batch):
        """ Performs a parameter update. """
        rng, gen_rng, disc_rng = jax.random.split(rng, 3)

        # Update the generator.
        (gen_loss, gen_aux), gen_grads = \
            jax.value_and_grad(self.gen_loss,
                               has_aux=True)(
            p2p_state.params.gen,
            p2p_state.states.gen,
            batch,
            p2p_state.params.disc,
            p2p_state.states.disc,
            gen_rng, disc_rng)

        generated_output, states, gan_loss, l1_loss = gen_aux
        gen_update, gen_opt_state = self.gen_optimizer.update(
            gen_grads, p2p_state.opt_state.gen)
        gen_params = optax.apply_updates(p2p_state.params.gen, gen_update)

        # Update the discriminator.
        (disc_loss, disc_state), disc_grads = \
            jax.value_and_grad(self.disc_loss,
                               has_aux=True)(
            p2p_state.params.disc,
            states.disc,
            batch,
            generated_output,
            disc_rng)

        disc_update, disc_opt_state = self.disc_optimizer.update(
            disc_grads, p2p_state.opt_state.disc)
        disc_params = optax.apply_updates(p2p_state.params.disc, disc_update)

        params = P2PTuple(gen=gen_params, disc=disc_params)
        states = P2PTuple(gen=states.gen, disc=disc_state)
        opt_state = P2PTuple(gen=gen_opt_state, disc=disc_opt_state)
        p2p_state = P2PState(params=params, states=states, opt_state=opt_state)

        return p2p_state, gen_loss, disc_loss, gan_loss, l1_loss

In [None]:
# The model.
net = Pix2Pix()

# Initialize the network and optimizer.
for input, target in train_dataset.take(1):
    net_state = net.initial_state(rng, (jnp.asarray(input),
                                        jnp.asarray(target)))

In [None]:
import datetime
log_dir = "logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

### The training loop

* Iterates over the number of epochs.
* On each epoch it clears the display, and runs `generate_images` to show its progress.
* On each epoch it iterates over the training dataset, printing a '.' for each example.
* It saves a checkpoint every 20 epochs.

In [None]:
def fit(train_ds, epochs, test_ds, net_state):
    key = rng  # PRNG key created in the "Helper functions" section.
    global_step = 0
    for epoch in range(epochs):
        start = time.time()

        display.clear_output(wait=True)

        for example_input, example_target in test_ds.take(1):
            prediction = net.generate_images(net_state.params.gen,
                                             net_state.states.gen,
                                             jnp.asarray(example_input))
            plt.figure(figsize=(15, 15))

            display_list = [example_input[0], example_target[0], prediction[0]]
            title = ['Input Image', 'Ground Truth', 'Predicted Image']

            for i in range(3):
                plt.subplot(1, 3, i+1)
                plt.title(title[i])
                # Getting the pixel values between [0, 1] to plot it.
                plt.imshow(display_list[i] * 0.5 + 0.5)
                plt.axis('off')
            # Show the three panels together in a single figure.
            plt.show()

        print("Epoch: ", epoch)

        # Train loop.
        for n, (input_image, target) in train_ds.enumerate():
            # Print a '.' per training step and a newline every 100 steps.
            print('.', end='')
            if (n+1) % 100 == 0:
                print()

            # Draw a fresh PRNG key for every step so that dropout masks differ.
            key, step_key = jax.random.split(key)

            # Main update step.
            net_state, gen_total_loss, disc_loss, \
            gen_gan_loss, gen_l1_loss = \
                net.update(step_key, net_state,
                           (jnp.asarray(input_image), jnp.asarray(target)))

            # Log one point per training step (not per epoch).
            with summary_writer.as_default():
                tf.summary.scalar('gen_total_loss', gen_total_loss,
                                  step=global_step)
                tf.summary.scalar('gen_gan_loss', gen_gan_loss,
                                  step=global_step)
                tf.summary.scalar('gen_l1_loss', gen_l1_loss,
                                  step=global_step)
                tf.summary.scalar('disc_loss', disc_loss, step=global_step)
            global_step += 1
        
        print()

        # Save (checkpoint) the model every 20 epochs.
        if (epoch + 1) % 20 == 0:
            with open(
                os.path.join(checkpoint_prefix, 'pix2pix_params.pkl'),
                    'wb') as handle:
                pickle.dump(net_state.params, handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

            with open(
                os.path.join(checkpoint_prefix, 'pix2pix_states.pkl'),
                    'wb') as handle:
                pickle.dump(net_state.states, handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                           time.time()-start))

    # Save the last checkpoint.
    with open(
        os.path.join(checkpoint_prefix, 'pix2pix_params.pkl'),
            'wb') as handle:
        pickle.dump(net_state.params, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(
        os.path.join(checkpoint_prefix, 'pix2pix_states.pkl'),
            'wb') as handle:
        pickle.dump(net_state.states, handle, protocol=pickle.HIGHEST_PROTOCOL)

This training loop saves logs that you can easily view in TensorBoard to monitor the training progress. When working locally you would launch a separate TensorBoard process. In a notebook, if you want to monitor with TensorBoard, it is easiest to launch the viewer before starting the training.

To launch the viewer run the following cell:

In [None]:
#docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir {log_dir}

Now run the training loop:

In [None]:
fit(train_dataset, EPOCHS, test_dataset, net_state)

If you want to share the TensorBoard results _publicly_ you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code-cell.

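Note: This requires a Google account. TensorBoard.dev was shut down at the start of 2024, so the upload command and the TensorBoard.dev links below may no longer work.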

```
!tensorboard dev upload --logdir  {log_dir}
```

Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the "interrupt execution" option in your notebook tool.

You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).

TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.

It can also be included inline using an `<iframe>`:

In [None]:
display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

Interpreting the logs from a GAN is more subtle than a simple classification or regression model. Things to look for:

* Check that neither model has "won". If either the `gen_gan_loss` or the `disc_loss` gets very low, it is an indicator that this model is dominating the other, and you are not successfully training the combined model.
* The value `log(2) = 0.69` is a good reference point for each cross-entropy term, as it corresponds to a perplexity of 2: the discriminator is, on average, equally uncertain about the two options.
* The `disc_loss` is the sum of the real and generated terms, so its reference point is `2 * log(2) = 1.39`; a value below this means the discriminator is doing better than random on the combined set of real + generated images.
* For the `gen_gan_loss`, a value below `0.69` means the generator is doing better than random at fooling the discriminator.
* As training progresses the `gen_l1_loss` should go down.
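These reference points are easy to verify: a discriminator that outputs a logit of 0 (probability 0.5) for every patch pays `log(2)` on each cross-entropy term, and because `disc_loss` adds a real and a generated term, its chance-level value is twice that. A NumPy sketch with a local stable cross-entropy helper:

```python
import numpy as np

def bce_with_logits(logits, target):
    # Stable binary cross-entropy on logits, averaged over all patch outputs.
    return np.mean(np.maximum(logits, 0) - logits * target
                   + np.log1p(np.exp(-np.abs(logits))))

uncertain = np.zeros((1, 30, 30, 1))  # logit 0 for every 70x70 patch
gen_gan_loss = bce_with_logits(uncertain, np.ones_like(uncertain))
disc_loss = (bce_with_logits(uncertain, np.ones_like(uncertain))
             + bce_with_logits(uncertain, np.zeros_like(uncertain)))
print(round(float(gen_gan_loss), 3), round(float(disc_loss), 3))  # 0.693 1.386
```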

## Restore the latest checkpoint and test

In [None]:
!ls {checkpoint_dir}

In [None]:
# Restore the latest checkpoint in checkpoint_dir.
with open(
    os.path.join(checkpoint_prefix, 'pix2pix_params.pkl'),
        'rb') as handle:
    params = pickle.load(handle)

with open(
    os.path.join(checkpoint_prefix, 'pix2pix_states.pkl'),
        'rb') as handle:
    states = pickle.load(handle)

## Generate using test dataset

* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output.
* Last step is to plot the predictions and **voila!**

In [None]:
# Run the trained model on a few examples from the test dataset

for test_input, tar in test_dataset.take(5):
    prediction = net.generate_images(
        params.gen, states.gen, jnp.asarray(test_input))
    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']
    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

## Analysis <a class="anchor" id="analysis"></a> 

1. What is the minimal number of EPOCHS necessary to train for, in order to get meaningful image generation?

2. How does removing the skip connections affect the generated image quality?

3. How does changing $\lambda$ of the L1 loss affect the training?