<a href="https://colab.research.google.com/github/Pragna235/ACM-Winter-School-2023-Hands-on-Labs/blob/main/Lab_8_TF_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Colab Prelims


### Steps to run this notebook

This notebook should be run in Colaboratory. If you are viewing this from GitHub, follow the GitHub instructions. If you are viewing this from Colaboratory, you should skip to the Colaboratory instructions.

#### Steps from GitHub

1. Navigate your web browser to the main Colaboratory website: https://colab.research.google.com.
1. Click the `GitHub` tab.
1. In the field marked `Enter a GitHub URL or search by organization or user`, put in the URL of this notebook in GitHub and click the magnifying glass icon next to it.
1. Run the notebook in Colaboratory by following the instructions below.

#### Steps from Colaboratory

This colab will run much faster on a GPU. To use a Google Cloud GPU:

1. Go to `Runtime > Change runtime type`.
1. Click `Hardware accelerator`.
1. Select `GPU` and click `Save`.
1. Click `Connect` in the upper right corner and select `Connect to hosted runtime`.

In [None]:
# Check that imports for the rest of the file work.
import tensorflow.compat.v1 as tf
!pip install tensorflow-gan
import tensorflow_gan as tfgan
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
# Allow matplotlib images to render immediately.
%matplotlib inline
tf.logging.set_verbosity(tf.logging.ERROR)  # Disable noisy outputs.

### Frequently Asked Questions (FAQs) for TensorFlow GAN Tutorial Code:

**Q1: Why are there imports at the beginning of the code?**
    
**A:** These imports bring in the necessary libraries and modules for the code to function. `tensorflow.compat.v1` is imported as `tf`, and the `tensorflow_gan`, `tensorflow_datasets`, `matplotlib.pyplot`, and `numpy` libraries are also imported. These are essential for building and training Generative Adversarial Networks (GANs) and visualizing results.

---

**Q2: What does `!pip install tensorflow-gan` do?**

**A:** This line installs the `tensorflow-gan` library, which is a TensorFlow extension specifically designed for Generative Adversarial Networks. It provides additional tools and functions to simplify GAN implementation.

---

**Q3: Why is `tensorflow.compat.v1` used instead of just `tensorflow`?**

**A:** The tutorial is written against the TensorFlow 1.x API (graph mode, `tf.layers`, `tf.train`, `tf.logging`). Importing `tensorflow.compat.v1` as `tf` exposes those 1.x-style symbols on a TensorFlow 2.x installation, so the code can run without being rewritten for the 2.x API.

---

**Q4: What is `tfgan` used for?**

**A:** `tfgan` is the alias given to the `tensorflow_gan` package by `import tensorflow_gan as tfgan`. The package provides tools and utilities for building and training GANs, such as loss functions, evaluation helpers, and the `GANEstimator` used later in this notebook.

---

**Q5: What is the purpose of `tensorflow_datasets` in this code?**

**A:** `tensorflow_datasets` is a library for downloading and loading ready-to-use datasets. In this notebook, `input_fn` uses it to load MNIST, the dataset the GAN is trained on.

---

**Q6: Why is `matplotlib.pyplot` imported?**

**A:** `matplotlib.pyplot` is a plotting interface for Python. In this notebook it is used to display image grids, such as the sample of MNIST digits drawn from the input pipeline.

---

**Q7: What does `%matplotlib inline` do?**

**A:** This is a Jupyter Notebook magic command. It enables the inline rendering of Matplotlib plots, meaning that the plots will be displayed directly in the notebook below the code cell.

---

**Q8: Why is `tf.logging.set_verbosity(tf.logging.ERROR)` used?**

**A:** This line sets the TensorFlow logging verbosity to error level, suppressing unnecessary log outputs. It helps keep the console clean by reducing the amount of information displayed during the training process.

## Overview

This colab will walk you through the basics of using [TF-GAN](https://github.com/tensorflow/gan) to define, train, and evaluate Generative Adversarial Networks (GANs). We describe the library's core features as well as some extra features. This colab assumes a familiarity with TensorFlow's Python API. For more on TensorFlow, please see [TensorFlow tutorials](https://www.tensorflow.org/tutorials/).

## Learning objectives

In this Colab, you will learn how to:
*   Use TF-GAN Estimators to quickly train a GAN

## Unconditional MNIST with GANEstimator

This exercise uses TF-GAN's GANEstimator and the MNIST dataset to create a GAN for generating fake handwritten digits.

### MNIST

The [MNIST dataset](https://wikipedia.org/wiki/MNIST_database) contains tens of thousands of images of handwritten digits. We'll use these images to train a GAN to generate fake images of handwritten digits. This task is small enough that you'll be able to train the GAN in a matter of minutes.

### GANEstimator

TensorFlow's Estimator API makes it easy to train models. TF-GAN offers `GANEstimator`, an Estimator for training GANs.

### Input Pipeline

We set up our input pipeline by defining an `input_fn`. In the "Train and eval loop" section below we pass this function to our GANEstimator's `train` method to initiate training. The `input_fn`:

1.  Generates the random inputs for the generator.
2.  Uses `tensorflow_datasets` to retrieve the MNIST data.
3.  Uses the tf.data API to format the data.

In [None]:
import tensorflow_datasets as tfds
import tensorflow.compat.v1 as tf

def input_fn(mode, params):
  assert 'batch_size' in params
  assert 'noise_dims' in params
  bs = params['batch_size']
  nd = params['noise_dims']
  split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
  shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
  just_noise = (mode == tf.estimator.ModeKeys.PREDICT)

  noise_ds = (tf.data.Dataset.from_tensors(0).repeat()
              .map(lambda _: tf.random.normal([bs, nd])))

  if just_noise:
    return noise_ds

  def _preprocess(element):
    # Map [0, 255] to [-1, 1].
    images = (tf.cast(element['image'], tf.float32) - 127.5) / 127.5
    return images

  images_ds = (tfds.load('mnist:3.*.*', split=split)
               .map(_preprocess)
               .cache()
               .repeat())
  if shuffle:
    images_ds = images_ds.shuffle(
        buffer_size=10000, reshuffle_each_iteration=True)
  images_ds = (images_ds.batch(bs, drop_remainder=True)
               .prefetch(tf.data.experimental.AUTOTUNE))

  return tf.data.Dataset.zip((noise_ds, images_ds))

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code:

**Q1: What is the purpose of the `input_fn` function?**

**A:** The `input_fn` function is designed to create an input pipeline for the GAN model. It prepares and provides input data in the required format for training, evaluation, or prediction. The function takes a `mode` argument to distinguish between training, testing, and prediction modes.

---

**Q2: What are `params['batch_size']` and `params['noise_dims']` used for?**

**A:** `params['batch_size']` and `params['noise_dims']` are parameters that determine the batch size and the dimensionality of the noise vector used as input to the GAN. Batch size controls how many samples are processed in each training iteration, and `noise_dims` specifies the size of the random noise vector that is an input to the generator.

---

**Q3: How is the noise dataset generated in this code?**

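**A:** The noise dataset is created using `tf.data.Dataset.from_tensors(0).repeat().map(lambda _: tf.random.normal([bs, nd]))`. `from_tensors(0).repeat()` produces an infinite stream of placeholder elements, and `map` replaces each one with a fresh standard-normal tensor of shape `[batch_size, noise_dims]`, so every dataset element is one full batch of noise vectors.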
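
As a framework-free illustration of the same idea, the sketch below uses a Python generator to yield an endless stream of noise batches, each a `batch_size` x `noise_dims` nested list of standard-normal samples. The name `noise_batches` is hypothetical and not part of the notebook; the real pipeline produces tensors through `tf.data`.

```python
import itertools
import random

def noise_batches(batch_size, noise_dims, seed=None):
    """Yield an endless stream of [batch_size, noise_dims] Gaussian batches."""
    rng = random.Random(seed)
    while True:
        yield [[rng.gauss(0.0, 1.0) for _ in range(noise_dims)]
               for _ in range(batch_size)]

# Take three batches from the infinite stream.
first_three = list(itertools.islice(noise_batches(4, 64, seed=0), 3))
print(len(first_three), len(first_three[0]), len(first_three[0][0]))
```

Like the `tf.data` version, the stream never ends on its own; the consumer decides how many batches to draw.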

---

**Q4: What does the `_preprocess` function do?**

**A:** The `_preprocess` function is applied to each element of the dataset and is responsible for preprocessing the images. It scales pixel values from the range [0, 255] to the range [-1, 1], which is a common practice in GANs to normalize the input images.
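
The same scaling can be checked on a handful of pixel values with NumPy. This is a minimal sketch that mirrors the arithmetic in `_preprocess`; the helper name `scale_to_unit_range` is an assumption made for illustration.

```python
import numpy as np

def scale_to_unit_range(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = scale_to_unit_range(pixels)
print(scaled)
```

The endpoints 0 and 255 land exactly on -1 and 1, and the two middle values sit symmetrically around 0.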

---

**Q5: Why is the MNIST dataset used, and how is it loaded?**

**A:** MNIST is a standard benchmark of handwritten digit images, and it is small enough that the GAN can be trained in minutes. It is loaded with `tfds.load('mnist:3.*.*', split=split)`, where the version pattern `3.*.*` accepts any 3.x.x release of the dataset and `split` is `'train'` in training mode and `'test'` otherwise. The `_preprocess` function is then applied to each example.

---

**Q6: What does `images_ds = images_ds.shuffle(buffer_size=10000, reshuffle_each_iteration=True)` do?**

**A:** This line shuffles the images in the dataset. Shuffling is important during training to prevent the model from learning patterns based on the order of the input data. The `buffer_size` parameter determines the number of elements from which the shuffling is done.
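
The role of `buffer_size` is easier to see in a small simulation. The sketch below follows the usual description of buffer-based shuffling (fill a buffer, emit a random element, refill its slot); it is an assumption-laden illustration in plain Python, not `tf.data`'s actual implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # Fill the buffer first.
            continue
        # Emit a random buffered element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
unshuffled = list(buffered_shuffle(range(10), buffer_size=1, seed=0))
print(shuffled)
```

A buffer of size 1 leaves the order unchanged, while a buffer as large as the dataset gives a full shuffle; sizes in between only mix elements that are close together in the input.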

---

**Q7: Why is `.batch(bs, drop_remainder=True)` used in the `images_ds` pipeline?**

**A:** This operation groups the images into batches of size `bs`, and `drop_remainder=True` drops any final batch with fewer than `bs` elements. Every batch therefore has a fixed, statically known size, which keeps the image batches aligned with the `[bs, nd]` noise batches they are zipped with.
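
The effect of `drop_remainder` can be sketched with plain lists. The `batch` helper here is hypothetical and only mimics the grouping behaviour described above.

```python
def batch(items, batch_size, drop_remainder=False):
    """Group items into lists of batch_size, optionally dropping a short tail."""
    batches = [items[i:i + batch_size]
               for i in range(0, len(items), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

samples = list(range(10))
kept = batch(samples, 4)                          # last batch has 2 items
dropped = batch(samples, 4, drop_remainder=True)  # short batch removed
print(kept)
print(dropped)
```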

---

**Q8: What does `.prefetch(tf.data.experimental.AUTOTUNE)` do?**

**A:** This line prefetches batches of data to be ready for the next iteration, improving the efficiency of training by overlapping the data loading and model execution. `tf.data.experimental.AUTOTUNE` automatically determines the appropriate number of batches to prefetch based on available system resources.

Download the data and sanity check the inputs.

In [None]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import numpy as np

params = {'batch_size': 100, 'noise_dims':64}
with tf.Graph().as_default():
  ds = input_fn(tf.estimator.ModeKeys.TRAIN, params)
  numpy_imgs = next(iter(tfds.as_numpy(ds)))[1]
img_grid = tfgan.eval.python_image_grid(numpy_imgs, grid_shape=(10, 10))
plt.axis('off')
plt.imshow(np.squeeze(img_grid))
plt.show()

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code:

**Q1: What is the purpose of this code snippet?**

**A:** This snippet is a sanity check on the input pipeline. It calls `input_fn` in training mode, pulls one batch of real MNIST images from the resulting dataset, arranges them into a grid with `tfgan.eval.python_image_grid`, and displays the grid with Matplotlib. No generator exists yet at this point, so the digits shown are real data, not GAN output.

---

**Q2: What are `params = {'batch_size': 100, 'noise_dims':64}` used for?**

**A:** `params` is a dictionary specifying the parameters needed for generating the input data. In this case, it sets the batch size to 100 and the dimensionality of the noise vector to 64. These parameters are then used when calling the `input_fn` function.

---

**Q3: How is the dataset (`ds`) created, and what does `next(iter(tfds.as_numpy(ds)))[1]` do?**

**A:** The dataset `ds` is created by calling `input_fn` with the training mode and the specified parameters. `tfds.as_numpy(ds)` wraps the dataset so that iterating over it yields NumPy arrays, and `next(iter(...))` takes the first element, a `(noise, images)` tuple. Indexing with `[1]` selects the image batch and discards the noise.

---

**Q4: What does `tfgan.eval.python_image_grid(numpy_imgs, grid_shape=(10, 10))` do?**

**A:** This function, `tfgan.eval.python_image_grid`, takes a batch of images (`numpy_imgs`) and arranges them into a single image laid out as a grid. Here the 100 images in the batch fill a 10x10 grid, which makes it easy to inspect many real MNIST samples at once.
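
One way to tile a batch into a grid is a reshape followed by a transpose. The sketch below is a NumPy illustration of that idea under the assumption that images arrive as `[N, H, W, C]`; `make_grid` is a hypothetical helper and is not claimed to match TF-GAN's implementation.

```python
import numpy as np

def make_grid(images, grid_shape):
    """Tile a [N, H, W, C] batch into a [rows*H, cols*W, C] image."""
    rows, cols = grid_shape
    n, h, w, c = images.shape
    assert n == rows * cols, "batch size must fill the grid exactly"
    grid = images.reshape(rows, cols, h, w, c)
    # Interleave so each grid row holds the pixel rows of its tiles.
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)

# Four tiny 2x3 single-channel "images" arranged in a 2x2 grid.
images = np.arange(4 * 2 * 3 * 1).reshape(4, 2, 3, 1)
grid = make_grid(images, (2, 2))
print(grid.shape)
```

Image `k` ends up in grid row `k // cols` and grid column `k % cols`, so the batch order reads left to right, top to bottom.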

---

**Q5: Why is `plt.axis('off')` used?**

**A:** `plt.axis('off')` turns off the axis labels and ticks in the Matplotlib plot. In this context, it's used to create a cleaner visualization without distracting axis information.

---

**Q6: What does `plt.imshow(np.squeeze(img_grid))` do?**

**A:** `plt.imshow` is the Matplotlib function for displaying an image. `np.squeeze` removes dimensions of size 1 from `img_grid`; here that is the single grayscale channel, leaving a 2-D array that `plt.imshow` can render directly.

---

**Q7: Why is `plt.show()` used at the end?**

**A:** `plt.show()` is a Matplotlib function that displays the plot. It is necessary to actually visualize the generated image grid. Once all configurations and operations are set, calling `plt.show()` renders the image grid in the output.

### Neural Network Architecture

To build our GAN we need two separate networks:

*  A generator that takes input noise and outputs generated MNIST digits
*  A discriminator that takes images and outputs a score indicating how real or fake each image looks

We define functions that build these networks. In the GANEstimator section below we pass the builder functions to the `GANEstimator` constructor. `GANEstimator` handles hooking the generator and discriminator together into the GAN.


In [None]:
def _dense(inputs, units, l2_weight):
  return tf.layers.dense(
      inputs, units, None,
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _batch_norm(inputs, is_training):
  return tf.layers.batch_normalization(
      inputs, momentum=0.999, epsilon=0.001, training=is_training)

def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d_transpose(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride],
      activation=tf.nn.relu, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride],
      activation=None, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (Utility Functions):

**Q1: What does the `_dense` function do?**

**A:** The `_dense` function defines a dense (fully connected) layer. It uses the `tf.layers.dense` function to create a dense layer with Glorot uniform initialization for the weights and L2 regularization for both the kernel and bias. This function is commonly used in neural network architectures.

---

**Q2: What is L2 regularization, and how is it applied in these functions?**

**A:** L2 regularization is a technique to prevent overfitting in machine learning models by penalizing large weights. In these functions, L2 regularization is applied to both the kernel and bias of the layers. The strength of the regularization is controlled by the `l2_weight` parameter.

---

**Q3: What does the `_batch_norm` function do?**

**A:** The `_batch_norm` function defines a batch normalization layer using `tf.layers.batch_normalization`. Batch normalization is a technique to improve the training stability and speed of deep neural networks by normalizing the input of each layer. The `is_training` parameter indicates whether the model is in training mode.

---

**Q4: How does the `_deconv2d` function differ from `_conv2d`?**

**A:** The `_deconv2d` function is used for transpose convolution (also known as deconvolution or fractionally strided convolution), which is often used in the generator part of a GAN to upsample the input. `_conv2d` is a standard convolution layer. Both wrappers share the same initializer and regularizers; they differ in the operation they wrap and in that `_deconv2d` applies ReLU while `_conv2d` leaves its output without an activation.

---

**Q5: What does the `_deconv2d` function's `activation=tf.nn.relu` parameter do?**

**A:** The `activation=tf.nn.relu` parameter specifies that Rectified Linear Unit (ReLU) activation should be applied to the output of the `_deconv2d` layer. ReLU is a common activation function that introduces non-linearity to the model by thresholding the output at zero.

---

**Q6: Why is the `padding='same'` parameter used in convolutional layers?**

**A:** With `padding='same'`, TensorFlow pads the input with zeros so that the output size depends only on the stride, not on the kernel size. For a stride of 1 the output has the same height and width as the input; for a stride of 2 a convolution halves them (rounding up) and a transposed convolution doubles them.
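
The spatial sizes used in this notebook can be worked out by hand. The sketch below assumes TensorFlow's usual `'same'` rule (output size `ceil(input / stride)` for a convolution and `input * stride` for a transposed convolution); the helper names are hypothetical.

```python
import math

def conv_same_output_size(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

def deconv_same_output_size(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

# Generator: 7 -> 14 -> 28 via two stride-2 transposed convolutions.
gen_sizes = [7]
for _ in range(2):
    gen_sizes.append(deconv_same_output_size(gen_sizes[-1], 2))

# Discriminator: 28 -> 14 -> 7 via two stride-2 convolutions.
disc_sizes = [28]
for _ in range(2):
    disc_sizes.append(conv_same_output_size(disc_sizes[-1], 2))

print(gen_sizes, disc_sizes)
```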

---

**Q7: How are these functions related to GANs?**

**A:** These functions are building blocks commonly used in the architecture of Generative Adversarial Networks (GANs). They define layers like dense, batch normalization, and convolutional layers, which are essential for constructing the generator and discriminator networks in a GAN.

In [None]:
def unconditional_generator(noise, mode, weight_decay=2.5e-5):
  """Generator to produce unconditional MNIST images."""
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  net = _dense(noise, 1024, weight_decay)
  net = _batch_norm(net, is_training)
  net = tf.nn.relu(net)

  net = _dense(net, 7 * 7 * 256, weight_decay)
  net = _batch_norm(net, is_training)
  net = tf.nn.relu(net)

  net = tf.reshape(net, [-1, 7, 7, 256])
  net = _deconv2d(net, 64, 4, 2, weight_decay)
  net = _deconv2d(net, 64, 4, 2, weight_decay)
  # Make sure that generator output is in the same range as `inputs`
  # ie [-1, 1].
  net = _conv2d(net, 1, 4, 1, 0.0)
  net = tf.tanh(net)

  return net

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (Generator Function):

**Q1: What is the purpose of the `unconditional_generator` function?**

**A:** The `unconditional_generator` function defines the generator network for an unconditional GAN designed to produce MNIST-like images. Given a noise vector as input, the function generates synthetic images.

---

**Q2: What does `is_training = (mode == tf.estimator.ModeKeys.TRAIN)` do?**

**A:** This line sets the `is_training` flag based on the mode of the GAN (training or not). It is used to control the behavior of batch normalization layers. During training, batch normalization normalizes the input based on the batch statistics, while during inference, it uses population statistics.

---

**Q3: How is the generator network structured?**

**A:** The generator network consists of fully connected layers (`_dense`), batch normalization layers (`_batch_norm`), ReLU activation functions (`tf.nn.relu`), and transpose convolutional layers (`_deconv2d`). The network starts with a dense layer followed by batch normalization and ReLU activation. It then proceeds with another dense layer, batch normalization, and ReLU activation. The subsequent layers reshape the tensor and apply transpose convolutional layers for upsampling. The final layer ensures the output is in the range [-1, 1] using a convolutional layer followed by the hyperbolic tangent (`tf.tanh`) activation function.

---

**Q4: Why is `tf.reshape(net, [-1, 7, 7, 256])` used?**

**A:** This line reshapes the tensor to have dimensions `[-1, 7, 7, 256]`. It prepares the tensor for the subsequent transpose convolutional layers, shaping it into a 4D tensor that can be interpreted as an image with dimensions 7x7 and 256 channels.

---

**Q5: What is the purpose of `_deconv2d(net, 64, 4, 2, weight_decay)` being called twice?**

**A:** Each call is a stride-2 transposed convolution, so it doubles the height and width: the first takes the 7x7x256 tensor to 14x14x64 and the second takes it to 28x28x64, the spatial size of an MNIST image. Only the first call reduces the channel count (256 to 64). Upsampling in stages like this is a common pattern in GAN generators.

---

**Q6: Why is `_conv2d(net, 1, 4, 1, 0.0)` used with a kernel size of 4 and a stride of 1?**

**A:** The final `_conv2d` layer collapses the 64 feature channels into a single grayscale channel. A stride of 1 with `'same'` padding keeps the 28x28 spatial size, and the `0.0` argument sets the L2 penalty weight to zero for this layer. The range constraint comes from the `tf.tanh` applied afterwards, which squashes values into [-1, 1] to match the preprocessed MNIST images.

---

**Q7: What does `tf.tanh(net)` do?**

**A:** `tf.tanh(net)` applies the hyperbolic tangent activation function to the output of the generator. It squashes the output values into the range [-1, 1], so generated images share the value range of the real images after preprocessing.
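
A few numbers make the squashing concrete. The sketch below applies `math.tanh` to some pre-activation values and then maps the results back to 8-bit pixels by inverting the [-1, 1] scaling; that inverse mapping and the name `to_pixel` are assumptions for illustration, not code from the notebook.

```python
import math

def to_pixel(value):
    """Map a tanh output in [-1, 1] back to an integer pixel in [0, 255]."""
    return int(round(value * 127.5 + 127.5))

pre_activations = [-10.0, -1.0, 0.0, 1.0, 10.0]
squashed = [math.tanh(x) for x in pre_activations]
pixels = [to_pixel(v) for v in squashed]
print(squashed)
print(pixels)
```

However large the pre-activation is, the output stays inside [-1, 1], so every generated value corresponds to a valid pixel intensity.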

In [None]:
_leaky_relu = lambda net: tf.nn.leaky_relu(net, alpha=0.01)

def unconditional_discriminator(img, unused_conditioning, mode, weight_decay=2.5e-5):
  del unused_conditioning
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  net = _conv2d(img, 64, 4, 2, weight_decay)
  net = _leaky_relu(net)

  net = _conv2d(net, 128, 4, 2, weight_decay)
  net = _leaky_relu(net)

  net = tf.layers.flatten(net)

  net = _dense(net, 1024, weight_decay)
  net = _batch_norm(net, is_training)
  net = _leaky_relu(net)

  net = _dense(net, 1, weight_decay)

  return net

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (Discriminator Function):

**Q1: What is the purpose of the `_leaky_relu = lambda net: tf.nn.leaky_relu(net, alpha=0.01)` line?**

**A:** This line defines a leaky ReLU activation function with a small negative slope (alpha=0.01). Leaky ReLU is a variation of the standard ReLU activation function that allows a small gradient when the input is negative, which can help with learning in certain situations.
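
The difference from a standard ReLU shows up only for negative inputs. This is a minimal scalar sketch of the two functions in plain Python, using the same `alpha=0.01` as the notebook; the function names are chosen for illustration.

```python
def relu(x):
    """Standard ReLU: negative inputs are clamped to zero."""
    return max(0.0, x)

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: negative inputs are scaled by alpha instead of zeroed."""
    return x if x >= 0 else alpha * x

inputs = [-2.0, -0.5, 0.0, 3.0]
relu_out = [relu(x) for x in inputs]
leaky_out = [leaky_relu(x) for x in inputs]
print(relu_out)
print(leaky_out)
```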

---

**Q2: What does the `del unused_conditioning` line do?**

**A:** The `del unused_conditioning` line deletes the reference to the `unused_conditioning` variable. This is done to avoid any warnings about an unused variable and to make it clear that this variable is intentionally not being used in the function.

---

**Q3: How is the discriminator network structured?**

**A:** The discriminator network consists of convolutional layers (`_conv2d`), leaky ReLU activation functions (`_leaky_relu`), flattening (`tf.layers.flatten`), fully connected layers (`_dense`), and a batch normalization layer (`_batch_norm`), with L2 weight regularization on the kernels and biases. The network takes a batch of images (`img`) as input and produces one unnormalized scalar score (a logit) per image; higher scores mean the discriminator considers the image more likely to be real.

---

**Q4: Why are there two convolutional layers with leaky ReLU activation functions?**

**A:** The two convolutional layers with leaky ReLU activation functions are responsible for capturing hierarchical features from the input image. These layers learn to extract increasingly complex and abstract features, helping the discriminator distinguish between real and generated images.

---

**Q5: What does `tf.layers.flatten(net)` do?**

**A:** This operation flattens every dimension except the batch dimension. The 4-D output of the convolutions, of shape `[batch, height, width, channels]`, becomes a 2-D tensor of shape `[batch, height * width * channels]`, which is the form the subsequent fully connected layers expect.
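
The size of the flattened vector follows from the layer settings. The sketch below traces a 28x28 input through the two stride-2 convolutions, assuming `'same'` padding halves each spatial dimension (rounding up), and then multiplies out the per-image feature count; `flattened_size` is a hypothetical helper.

```python
import math

def flattened_size(height, width, channels):
    """Number of features per example after flattening [H, W, C]."""
    return height * width * channels

h = w = 28  # MNIST image size
for _ in range(2):  # two stride-2 convolutions with padding='same'
    h, w = math.ceil(h / 2), math.ceil(w / 2)

features = flattened_size(h, w, 128)  # the second convolution has 128 filters
print((h, w), features)
```

Under these assumptions the first dense layer in the discriminator receives 7 * 7 * 128 features per image.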

---

**Q6: Why is `_dense(net, 1, weight_decay)` used for the final layer?**

**A:** The final `_dense` layer with one unit produces a single scalar per image. No sigmoid is applied, so the value is an unbounded score rather than a probability, which is what the Wasserstein losses used in this notebook operate on. The weight decay regularizes this layer's weights.

---

**Q7: How is the leaky ReLU activation applied in this discriminator?**

**A:** The leaky ReLU activation is applied after each convolutional layer and the first fully connected layer using the `_leaky_relu` function. It introduces non-linearity to the model while allowing a small, non-zero gradient for negative inputs. This can help prevent issues like dead neurons and improve the learning of complex features.

### Evaluating Generative Models, and evaluating GANs


TF-GAN provides some standard methods of evaluating generative models. In this example, we measure:

*  Inception Score: called `mnist_score` below.
*  Frechet Inception Distance

We apply a pre-trained classifier to both the real data and the generated data to calculate the *Inception Score*. The Inception Score is designed to measure both quality and diversity. See [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498) by Salimans et al. for more information about the Inception Score.

*Frechet Inception Distance* measures how close the generated image distribution is to the real image distribution. See [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) by Heusel et al. for more information about the Frechet Inception Distance.

In [None]:
from tensorflow_gan.examples.mnist import util as eval_util
import os

def get_eval_metric_ops_fn(gan_model):
  real_data_logits = tf.reduce_mean(gan_model.discriminator_real_outputs)
  gen_data_logits = tf.reduce_mean(gan_model.discriminator_gen_outputs)
  real_mnist_score = eval_util.mnist_score(gan_model.real_data)
  generated_mnist_score = eval_util.mnist_score(gan_model.generated_data)
  frechet_distance = eval_util.mnist_frechet_distance(
      gan_model.real_data, gan_model.generated_data)
  return {
      'real_data_logits': tf.metrics.mean(real_data_logits),
      'gen_data_logits': tf.metrics.mean(gen_data_logits),
      'real_mnist_score': tf.metrics.mean(real_mnist_score),
      'mnist_score': tf.metrics.mean(generated_mnist_score),
      'frechet_distance': tf.metrics.mean(frechet_distance),
  }

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (Evaluation Metrics):

**Q1: What is the purpose of the `get_eval_metric_ops_fn` function?**

**A:** The `get_eval_metric_ops_fn` function is designed to define a set of evaluation metrics for assessing the performance of a GAN model. These metrics provide insights into how well the GAN is generating realistic images compared to real data.

---

**Q2: What do `real_data_logits` and `gen_data_logits` represent?**

**A:** `real_data_logits` and `gen_data_logits` are the average discriminator outputs (logits) for real and generated data, respectively. Together they show how well the discriminator separates the two: a large gap (high for real, low for generated) means it tells them apart easily, while a shrinking gap suggests the generator's samples are becoming harder to distinguish from real data.

---

**Q3: What does `eval_util.mnist_score` do, and how is it used in this code?**

**A:** `eval_util.mnist_score` is a utility function that computes an Inception-Score-style metric for a batch of images, using a pre-trained classifier as described above. The score is intended to reflect both the quality and the diversity of the images. It is computed on both real and generated data, so the real-data score serves as a reference point for the generated-data score.
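
For intuition, the Inception Score from Salimans et al. is the exponential of the average KL divergence between each image's predicted class distribution and the marginal class distribution. The sketch below implements that textbook definition on plain lists of class probabilities; it is not the code behind `eval_util.mnist_score`, and the `eps` smoothing term is an assumption added to avoid `log(0)`.

```python
import math

def inception_score(probs, eps=1e-12):
    """exp(mean KL(p(y|x) || p(y))) for a list of per-image class probabilities."""
    n = len(probs)
    num_classes = len(probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(row[k] for row in probs) / n for k in range(num_classes)]
    kl_per_image = [
        sum(p * (math.log(p + eps) - math.log(m + eps))
            for p, m in zip(row, marginal))
        for row in probs
    ]
    return math.exp(sum(kl_per_image) / n)

# Every image gets the same uniform prediction: lowest possible score.
uniform = [[0.1] * 10 for _ in range(10)]
# Confident predictions spread evenly over 10 classes: highest possible score.
one_hot = [[1.0 if k == i else 0.0 for k in range(10)] for i in range(10)]

low = inception_score(uniform)
high = inception_score(one_hot)
print(low, high)
```

The score ranges from 1 (uninformative or collapsed predictions) up to the number of classes (confident predictions covering all classes evenly).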

---

**Q4: What is the purpose of `frechet_distance` in the evaluation metrics?**

**A:** `frechet_distance` represents the Frechet Inception Distance (FID) between the distributions of real and generated images. FID is another widely used metric for evaluating GANs. It measures the similarity between the distribution of real images and the distribution of generated images. Lower FID values indicate better performance.
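
The Frechet distance compares Gaussians fitted to two sets of features. In one dimension it reduces to the squared difference of the means plus the squared difference of the standard deviations, which the sketch below computes. This univariate special case is an illustration only: the real metric works on multivariate classifier features with full covariance matrices, and `frechet_distance_1d` is a hypothetical name.

```python
import statistics

def frechet_distance_1d(xs, ys):
    """Squared Frechet distance between Gaussians fitted to two 1-D samples."""
    mu_x, mu_y = statistics.fmean(xs), statistics.fmean(ys)
    sd_x, sd_y = statistics.pstdev(xs), statistics.pstdev(ys)
    return (mu_x - mu_y) ** 2 + (sd_x - sd_y) ** 2

real = [0.0, 1.0, 2.0]
same = frechet_distance_1d(real, [0.0, 1.0, 2.0])     # identical samples
shifted = frechet_distance_1d(real, [3.0, 4.0, 5.0])  # same spread, mean + 3
print(same, shifted)
```

Identical distributions give a distance of zero, and the value grows as the means or spreads drift apart.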

---

**Q5: Why are these evaluation metrics important?**

**A:** Evaluation metrics provide quantitative measures of the GAN's performance. They help assess the quality, diversity, and similarity of generated images to real data. Monitoring these metrics during training can guide model development and tuning, ensuring that the GAN generates high-quality and diverse samples.

---

**Q6: How are these metrics used in practice?**

**A:** These metrics can be used to compare different GAN models, identify potential issues during training (such as mode collapse or lack of diversity), and guide hyperparameter tuning. Monitoring the metrics over training epochs helps practitioners understand how well the GAN is learning and whether adjustments are needed to improve performance.

### GANEstimator

The `GANEstimator` assembles and manages the pieces of the whole GAN model. The `GANEstimator` constructor takes the following components for both the generator and discriminator:

*  Network builder functions: we defined these in the "Neural Network Architecture" section above.
*  Loss functions: here we use the Wasserstein loss for both.
*  Optimizers: here we use `tf.train.AdamOptimizer` for both generator and discriminator training.

In [None]:
train_batch_size = 32 #@param
noise_dimensions = 64 #@param
generator_lr = 0.001 #@param
discriminator_lr = 0.0002 #@param

def gen_opt():
  gstep = tf.train.get_or_create_global_step()
  base_lr = generator_lr
  # Halve the learning rate at 1000 steps.
  lr = tf.cond(gstep < 1000, lambda: base_lr, lambda: base_lr / 2.0)
  return tf.train.AdamOptimizer(lr, 0.5)

gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=unconditional_generator,
    discriminator_fn=unconditional_discriminator,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    params={'batch_size': train_batch_size, 'noise_dims': noise_dimensions},
    generator_optimizer=gen_opt,
    discriminator_optimizer=tf.train.AdamOptimizer(discriminator_lr, 0.5),
    get_eval_metric_ops_fn=get_eval_metric_ops_fn)

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (GAN Estimator Configuration):

**Q1: What is the purpose of the `gan_estimator` object?**

**A:** The `gan_estimator` object is an instance of the `tfgan.estimator.GANEstimator` class, which is part of the TensorFlow GAN library. It is configured to train a Generative Adversarial Network (GAN) using the specified generator and discriminator functions, loss functions, and optimization parameters.

---

**Q2: How are the learning rates for the generator and discriminator determined?**

**A:** The generator learning rate (`generator_lr`) is 0.001 and the discriminator learning rate (`discriminator_lr`) is 0.0002. Both are plain Python variables exposed as Colab form fields through the `#@param` annotations, so they can be changed before running the cell. The generator's rate is halved once the global step reaches 1000.

---

**Q3: What does `gen_opt` do, and how is it related to the generator's learning rate?**

**A:** `gen_opt` is a function that returns an Adam optimizer (with `beta1=0.5`) for the generator. Its learning rate follows a step schedule: `generator_lr` while the global step (`gstep`) is below 1000, and half of that afterwards. It is passed to `GANEstimator` as a callable rather than as an optimizer instance, so the `tf.cond` that reads the global step is built when the Estimator constructs its graph.
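
The schedule itself is a one-line piecewise function. The sketch below reproduces it in plain Python with the notebook's values as defaults; `generator_learning_rate` is a hypothetical name, and the real code expresses the same branch with `tf.cond` so it can run inside the graph.

```python
def generator_learning_rate(global_step, base_lr=0.001, boundary=1000):
    """Return base_lr before the boundary step and half of it from then on."""
    return base_lr if global_step < boundary else base_lr / 2.0

for step in (0, 999, 1000, 5000):
    print(step, generator_learning_rate(step))
```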

---

**Q4: What are the generator and discriminator loss functions used in this GAN setup?**

**A:** The generator loss function is specified as `tfgan.losses.wasserstein_generator_loss`, and the discriminator loss function is specified as `tfgan.losses.wasserstein_discriminator_loss`. Wasserstein GAN (WGAN) loss functions are known for promoting more stable training compared to traditional GAN loss functions.
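
In the standard WGAN formulation, the discriminator (critic) is trained to score real images higher than generated ones, and the generator is trained to raise the scores of its own samples. The sketch below writes those two objectives on plain lists of discriminator outputs. It shows the textbook form only; the TF-GAN loss functions accept additional arguments (such as weights and reduction options) that are not modelled here, and the function names are chosen for illustration.

```python
from statistics import fmean

def critic_loss(real_outputs, gen_outputs):
    """Wasserstein discriminator loss: mean(D(fake)) - mean(D(real))."""
    return fmean(gen_outputs) - fmean(real_outputs)

def generator_loss(gen_outputs):
    """Wasserstein generator loss: -mean(D(fake))."""
    return -fmean(gen_outputs)

real_scores = [2.0, 4.0]  # discriminator outputs on real images
fake_scores = [0.0, 2.0]  # discriminator outputs on generated images

d_loss = critic_loss(real_scores, fake_scores)
g_loss = generator_loss(fake_scores)
print(d_loss, g_loss)
```

Because both losses act directly on unbounded scores, the discriminator's final layer has no sigmoid.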

---

**Q5: How are the batch size and noise dimensionality configured for training?**

**A:** The batch size (`train_batch_size`) is set to 32, and the noise dimensionality (`noise_dimensions`) is set to 64. These parameters control the number of samples in each training batch and the dimensionality of the random noise vector used as input to the generator.

---

**Q6: How are the generator and discriminator functions (`unconditional_generator` and `unconditional_discriminator`) used in the GANEstimator?**

**A:** The generator and discriminator functions are provided as arguments to the GANEstimator. The GANEstimator uses these functions to construct the generator and discriminator networks. During training, the generator is optimized to generate realistic images, and the discriminator is optimized to distinguish between real and generated images.

---

**Q7: What role does `get_eval_metric_ops_fn` play in the GANEstimator?**

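**A:** `get_eval_metric_ops_fn` supplies the evaluation metrics that the `GANEstimator` computes when `evaluate()` is called, not while `train()` is running. The training loop below reads five of them from the returned dictionary: the average discriminator logits on real and generated data (`real_data_logits`, `gen_data_logits`), the MNIST scores for real and generated images (`real_mnist_score`, `mnist_score`, which the loop prints under the label "Inception Score"), and the Frechet distance between real and generated data (`frechet_distance`). These metrics can be monitored to assess the quality and diversity of the generated images.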

### Train and eval loop

The `GANEstimator`'s `train()` method initiates GAN training, including the alternating generator and discriminator training phases.

The loop in the code below calls `train()` repeatedly so that it can periodically display generator output and evaluation results. The loop itself does not manage the alternation between discriminator and generator; that is handled automatically inside `train()`.

In [None]:
# Disable noisy output.
tf.autograph.set_verbosity(0, False)

import time
steps_per_eval = 500 #@param
max_train_steps = 5000 #@param
batches_for_eval_metrics = 100 #@param

# Used to track metrics.
steps = []
real_logits, fake_logits = [], []
real_mnist_scores, mnist_scores, frechet_distances = [], [], []

cur_step = 0
start_time = time.time()
while cur_step < max_train_steps:
  next_step = min(cur_step + steps_per_eval, max_train_steps)

  start = time.time()
  gan_estimator.train(input_fn, max_steps=next_step)
  steps_taken = next_step - cur_step
  time_taken = time.time() - start
  print('Time since start: %.2f min' % ((time.time() - start_time) / 60.0))
  print('Trained from step %i to %i at %.2f steps / sec' % (
      cur_step, next_step, steps_taken / time_taken))
  cur_step = next_step

  # Calculate some metrics.
  metrics = gan_estimator.evaluate(input_fn, steps=batches_for_eval_metrics)
  steps.append(cur_step)
  real_logits.append(metrics['real_data_logits'])
  fake_logits.append(metrics['gen_data_logits'])
  real_mnist_scores.append(metrics['real_mnist_score'])
  mnist_scores.append(metrics['mnist_score'])
  frechet_distances.append(metrics['frechet_distance'])
  print('Average discriminator output on Real: %.2f  Fake: %.2f' % (
      real_logits[-1], fake_logits[-1]))
  print('Inception Score: %.2f / %.2f  Frechet Distance: %.2f' % (
      mnist_scores[-1], real_mnist_scores[-1], frechet_distances[-1]))

  # Visualize some images.
  iterator = gan_estimator.predict(
      input_fn, hooks=[tf.train.StopAtStepHook(num_steps=21)])
  # Collect up to 20 images, enough to fill the 2 x 10 grid.
  imgs = []
  for img in iterator:
    imgs.append(img)
    if len(imgs) == 20:
      break
  tiled = tfgan.eval.python_image_grid(np.array(imgs), grid_shape=(2, 10))
  plt.axis('off')
  plt.imshow(np.squeeze(tiled))
  plt.show()


# Plot the metrics vs step.
plt.title('MNIST Frechet distance per step')
plt.plot(steps, frechet_distances)
plt.figure()
plt.title('MNIST Score per step')
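plt.plot(steps, mnist_scores, label='Generated')
plt.plot(steps, real_mnist_scores, label='Real')
plt.legend()  # Distinguish the generated-data curve from the real-data curve.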
plt.show()

### Frequently Asked Questions (FAQs) for the TensorFlow GAN Tutorial Code (Training and Evaluation Loop):

**Q1: What is the purpose of the training and evaluation loop in this code?**

**A:** The training and evaluation loop is designed to iteratively train a Generative Adversarial Network (GAN) and evaluate its performance over multiple steps. It monitors various metrics, such as discriminator outputs, Inception Score, and Frechet Distance, to assess the quality and diversity of the generated images.

---

**Q2: How is the training loop structured?**

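**A:** The loop trains in chunks rather than one step at a time. Each iteration calls `gan_estimator.train(input_fn, max_steps=next_step)`, which advances training by up to `steps_per_eval` steps (500 by default), and the loop repeats until `max_train_steps` (5000 by default) is reached. With the defaults that gives 10 iterations. After each chunk the code prints the time elapsed since the start and the training throughput in steps per second, then runs an evaluation and displays sample images.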
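
The chunk boundaries can be checked without running any training. The sketch below reuses the loop's own arithmetic; the function name `training_chunks` is hypothetical and exists only for this illustration.

```python
def training_chunks(max_train_steps, steps_per_eval):
    """Return the (start_step, end_step) pair for each train() call."""
    chunks = []
    cur_step = 0
    while cur_step < max_train_steps:
        # The last chunk is shortened so training never passes max_train_steps.
        next_step = min(cur_step + steps_per_eval, max_train_steps)
        chunks.append((cur_step, next_step))
        cur_step = next_step
    return chunks


chunks = training_chunks(max_train_steps=5000, steps_per_eval=500)
print(len(chunks), chunks[0], chunks[-1])
```

Because `max_steps` is an absolute step count, each `train()` call resumes from the previous checkpoint and stops at `next_step`; when `max_train_steps` is not a multiple of `steps_per_eval`, only the final chunk is shorter.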

---

**Q3: What metrics are being tracked during training?**

**A:** The following metrics are recorded after each evaluation and appended to Python lists:
- Average discriminator output on real data (`real_logits`).
- Average discriminator output on generated (fake) data (`fake_logits`).
- MNIST score for generated data (`mnist_scores`), printed under the label "Inception Score".
- MNIST score for real data (`real_mnist_scores`), which serves as a reference value for the generated-data score.
- Frechet distance between real and generated data (`frechet_distances`).

---

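**Q4: What does `batches_for_eval_metrics` control, and how often are metrics evaluated?**

**A:** `batches_for_eval_metrics` (100 by default) is passed to `evaluate()` as `steps`, so each evaluation computes its metrics over 100 batches, which gives less noisy estimates than a single batch would. It does not set how often evaluation happens. That is controlled by `steps_per_eval`: metrics are computed once after every 500 training steps, which makes it possible to monitor how well the GAN is learning to generate realistic images over time.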

---

**Q5: What is the purpose of visualizing images during training?**

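**A:** Visualizing images during training provides a qualitative check that complements the numeric metrics. After each evaluation (every `steps_per_eval` training steps) the code takes 20 generated images from `predict()` and displays them as a 2 x 10 grid. The `tf.train.StopAtStepHook(num_steps=21)` argument does not set how often images are shown; it is a stop condition attached to the `predict()` call, and the loop only ever consumes 20 images from the iterator. The grid lets the user visually inspect the quality and diversity of the generated images.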

---

**Q6: How are the metrics and visualizations plotted over training steps?**

**A:** The code collects the metrics (e.g., Inception Score, Frechet Distance) at each evaluation step and plots them against the corresponding training steps using Matplotlib. This provides a graphical representation of the GAN's performance trends over time.

---

**Q7: What do the plotted metrics reveal about the GAN's performance?**

**A:** The plotted metrics offer insights into the quality and diversity of the generated images. A decreasing Frechet distance and an increasing score for generated images are generally indicative of improved GAN performance. The real-data score, plotted in the same figure, gives a reference level that the generated-data score should approach. When both curves level off, further training is likely to bring little additional improvement.
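
For intuition about why a lower Frechet distance is better, the sketch below implements the standard closed-form Frechet distance between two Gaussian distributions, which is the formula behind Frechet-style GAN metrics. It is not the notebook's metric function: the name `gaussian_frechet_distance` is hypothetical, and the means and covariances here are toy values rather than statistics of real image features.

```python
import numpy as np


def gaussian_frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    """
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1 = np.asarray(sigma1, dtype=float)
    sigma2 = np.asarray(sigma2, dtype=float)
    diff = mu1 - mu2
    # Tr(sqrt(sigma1 @ sigma2)) is the sum of the square roots of the
    # eigenvalues of the product, which are real and non-negative for
    # covariance matrices; clip to remove small numerical noise.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * trace_sqrt)


identity = np.eye(2)
same = gaussian_frechet_distance([0.0, 0.0], identity, [0.0, 0.0], identity)
shifted = gaussian_frechet_distance([0.0, 0.0], identity, [3.0, 4.0], identity)
wider = gaussian_frechet_distance([0.0, 0.0], 4.0 * identity, [0.0, 0.0], identity)
print(same, shifted, wider)
```

Identical distributions give a distance of zero, and the distance grows when either the means or the spreads differ, which is why the curve is expected to fall as generated images become statistically closer to real ones.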