## Overview


This notebook is hosted on GitHub. To view it in its original repository, open the notebook and select **File > View on GitHub**.

### TPUs

TPUs are chips optimized for machine learning training and inference. GAN training uses a lot of compute power, so TPUs expand the range of what we can realistically accomplish with GANs.

### CelebA Task

In this colab we'll use TPUs to train a GAN for a more difficult task than the MNIST task. We'll use the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset to train a model to generate faces.

### DCGAN Architecture

Our model implements the Deep Convolutional Generative Adversarial Network (DCGAN) architecture.

## Instructions

<h3>  &nbsp;&nbsp;Train on TPU&nbsp;&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a></h3>

### Steps to run this notebook

This notebook should be run in Colaboratory. If you are viewing this from GitHub, follow the GitHub instructions. If you are viewing this from Colaboratory, you should skip to the Colaboratory instructions.

#### Steps from GitHub

1. Click the `Open in Colab` badge.
1. Run the notebook in Colaboratory by following the instructions below.

#### Steps from Colaboratory

1. Go to `Runtime > Change runtime type`.
1. Click `Hardware accelerator`.
1. Select `TPU` and click `Save`.
1. Select `Runtime > Run all`. You can also run the cells one at a time with Shift+Enter.

### Authentication

In [130]:
%tensorflow_version 2.x
import tensorflow as tf
import os
import time

# Google Cloud Storage bucket for Estimator logs and storing
# the training dataset.
bucket = 'celeba-public' #@param {type:"string"}

assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))

model_dir = 'gs://{}/{}'.format(
    bucket, time.strftime('tpuestimator-tfgan/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(model_dir))

from google.colab import auth
auth.authenticate_user()

assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'
tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])


Using bucket: celeba-public
Using model dir: gs://celeba-public/tpuestimator-tfgan/2020-05-04-08-41-06


### Check imports

In [0]:
# Check that imports for the rest of the file work.
!pip install tensorflow-gan
import tensorflow_gan as tfgan
import numpy as np
import matplotlib.pyplot as plt
# Allow matplotlib images to render immediately.
%matplotlib inline
# Disable noisy outputs.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.autograph.set_verbosity(0, False)
import warnings
warnings.filterwarnings("ignore")

## Train and evaluate a GAN model on TPU using TF-GAN


### Input pipeline


Our input data is stored on Google Cloud Storage. To more fully use the parallelism TPUs offer us, and to avoid bottlenecking on data transfer, we've stored our input data in TFRecord files, 2025 images per file.

Below, we make heavy use of `tf.data.experimental.AUTOTUNE` to optimize different parts of input loading.
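
To make the idea of interleaved reads concrete, here is a minimal pure-Python sketch. It is not the `tf.data` implementation: the `interleave` helper and the shard contents are hypothetical, and the real pipeline reads TFRecord files in parallel and may order records differently. The sketch only shows how taking one record at a time from each shard mixes records from several files.

```python
from itertools import zip_longest

_MISSING = object()  # sentinel marking a shard that has run out of records


def interleave(shards):
    """Yield records from several shards in round-robin order."""
    for group in zip_longest(*shards, fillvalue=_MISSING):
        for record in group:
            if record is not _MISSING:
                yield record


# Hypothetical shards standing in for TFRecord files of unequal length.
shards = [["a0", "a1", "a2"], ["b0", "b1"], ["c0"]]
mixed = list(interleave(shards))
print(mixed)
```

Because consecutive records come from different shards, no single file has to be read to the end before the next one starts, which is the property that lets several reads proceed at once.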

In [0]:
AUTO = tf.data.experimental.AUTOTUNE

gcs_pattern = 'gs://celeba-public/tfrecord_*.tfrec'

filenames = tf.io.gfile.glob(gcs_pattern)

GENERATE_RES = 1 # Generation resolution factor (1=32, 2=64, 3=96, 4=128, etc.)

GENERATE_SQUARE = 32 * GENERATE_RES # rows/cols (should be square)

IMAGE_CHANNELS = 3

IMAGE_SIZE = (GENERATE_SQUARE, GENERATE_SQUARE, IMAGE_CHANNELS)

def parse_attribute_list(example):
  features = {
      "names": tf.io.FixedLenFeature([], tf.string),
  }

  example = tf.io.parse_single_example(example, features)
  attributes_names = example['names']
  return attributes_names

def get_names():
  record = tf.data.TFRecordDataset('gs://celeba-test/attribute_list.tfrec')
  attributes = record.map(parse_attribute_list)
  att_names = next(attributes.as_numpy_iterator()).decode("utf-8")
  att_names_list = [elem.strip()[1:-1] for elem in att_names.split(',')]
  return att_names_list

att_names_list = get_names()

def input_fn(mode, params):
  assert 'batch_size' in params
  assert 'noise_dims' in params
  bs = params['batch_size']
  nd = params['noise_dims']
  shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
  just_noise = (mode == tf.estimator.ModeKeys.PREDICT)
  
  lambda_noise = lambda _: tf.random.normal([bs, nd])

  noise_ds = (tf.data.Dataset.from_tensors(0)
              .map(lambda_noise)
              # If 'predict', just generate one batch.
              .repeat(1 if just_noise else None))
  if just_noise:
    return noise_ds


  feature_dict = {
        "filename": tf.io.FixedLenFeature([], tf.string),
        "height": tf.io.FixedLenFeature([], tf.int64),
        "width": tf.io.FixedLenFeature([], tf.int64),
        "depth": tf.io.FixedLenFeature([], tf.int64),
        "image": tf.io.FixedLenFeature([], tf.string),
    }

  attributes_dict = dict(zip(att_names_list, [tf.io.FixedLenFeature([], tf.int64) for elem in att_names_list]))

  feature_dict.update(attributes_dict) 

  def parse_tfrecord(example):
    features = feature_dict
    example = tf.io.parse_single_example(example, features)
    #filename = example['filename']
    width = tf.cast(example['width'],tf.int64)
    height = tf.cast(example['height'],tf.int64)
    decoded = tf.image.decode_image(example['image'])  
    normalized = tf.cast(decoded, tf.float32) / 255.0 # convert each 0-255 value to floats in [0, 1] range
    image_tensor = tf.reshape(normalized, [height, width, 3])
    image_tensor = tf.image.resize(image_tensor[45:173,25:153], (IMAGE_SIZE[0], IMAGE_SIZE[1])) # crop to 128x128, then resize to the target resolution
    #attr_dict = {}
    #for name in att_names_list:
    #  attr_dict[name] = example[name]

    # return filename, image_tensor, attr_dict
    return image_tensor

  def load_dataset(filenames):
    # Read from TFRecords. For optimal performance, we interleave reads from multiple files.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(parse_tfrecord, num_parallel_calls=AUTO)

  images_ds = load_dataset(filenames).cache().repeat()
  
  if shuffle:
    images_ds = images_ds.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
    
  images_ds = (images_ds.batch(bs, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))

  return tf.data.Dataset.zip((noise_ds, images_ds))

def noise_input_fn(params):
  np.random.seed(0)
  np_noise = np.random.randn(params['batch_size'], params['noise_dims'])
  return tf.data.Dataset.from_tensors(tf.constant(np_noise, dtype=tf.float32))

Sanity check the inputs.


In [0]:
import tensorflow_datasets as tfds

params = {'batch_size': 100, 'noise_dims':64}
ds = input_fn(tf.estimator.ModeKeys.EVAL, params)
imgs = next(tfds.as_numpy(ds))[1]

# plot a list of loaded faces
def plot_faces(faces, n):
  plt.figure(figsize=(13,13))
  for i in range(n * n):
    # define subplot
    plt.subplot(n, n, 1 + i)
    # turn off axis
    plt.axis('off')
    # plot
    plt.imshow(faces[i])
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()

plot_faces(imgs, 10)

### Neural Net Architecture

As usual, our GAN has two separate networks:

*  A generator that takes input noise and outputs images
*  A discriminator that takes images and outputs a probability of being real

We define `discriminator()` and `generator()` builder functions that assemble these networks. In the "Estimator" section below, we pass both functions to the `TPUGANEstimator`.
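
The networks in this section use three stride-2 layers each with `'same'` padding, so spatial sizes halve in the discriminator and double in the generator. The following sketch, with hypothetical helper names, works through that arithmetic for a 32x32 input to show where the `4 * 4 * 256` flattened size comes from. It assumes the standard `'same'`-padding rules (output size `ceil(size / stride)` for a convolution and `size * stride` for a transposed convolution).

```python
import math


def conv_same_out(size, stride):
    """Spatial size after a 'same'-padded convolution."""
    return math.ceil(size / stride)


def deconv_same_out(size, stride):
    """Spatial size after a 'same'-padded transposed convolution."""
    return size * stride


def discriminator_sizes(input_size, num_layers=3, stride=2):
    """Spatial sizes seen by the discriminator, input first."""
    sizes = [input_size]
    for _ in range(num_layers):
        sizes.append(conv_same_out(sizes[-1], stride))
    return sizes


def generator_sizes(start_size=4, num_layers=3, stride=2):
    """Spatial sizes produced by the generator, starting from the reshaped noise."""
    sizes = [start_size]
    for _ in range(num_layers):
        sizes.append(deconv_same_out(sizes[-1], stride))
    return sizes


d_sizes = discriminator_sizes(32)
g_sizes = generator_sizes(4)
flat_features = d_sizes[-1] * d_sizes[-1] * 256  # 256 filters in the last conv layer
print(d_sizes, g_sizes, flat_features)
```

With `GENERATE_RES = 1` both networks agree on 32x32 images. For a 64x64 input the same arithmetic gives a final 8x8 feature map, which no longer matches the hard-coded `4 * 4 * 256`, so the layer stack would need adjusting before raising the resolution.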

In [0]:
def _leaky_relu(x):
  return tf.nn.leaky_relu(x, alpha=0.2)


def _batch_norm(x, is_training, name):
  return tf.compat.v1.layers.batch_normalization(
      x, momentum=0.9, epsilon=1e-5, training=is_training, name=name)


def _dense(x, channels, name):
  return tf.compat.v1.layers.dense(
      x, channels,
      kernel_initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02),
      name=name)


def _conv2d(x, filters, kernel_size, stride, name):
  return tf.compat.v1.layers.conv2d(
      x, filters, [kernel_size, kernel_size],
      strides=[stride, stride], padding='same',
      kernel_initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02),
      name=name)


def _deconv2d(x, filters, kernel_size, stride, name):
  return tf.compat.v1.layers.conv2d_transpose(
      x, filters, [kernel_size, kernel_size],
      strides=[stride, stride], padding='same',
      kernel_initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02),
      name=name)


def discriminator(images, unused_conditioning, is_training=True,
                  scope='Discriminator'):
  """Discriminator for CelebA images.

  Args:
    images: A Tensor of shape [batch size, height, width, channels], that can be
      either real or generated. It is the discriminator's goal to distinguish
      between the two.
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      would require extra `condition` information to both the generator and the
      discriminator. Since this example is not conditional, we do not use this
      argument.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.
    scope: A variable scope or string for the discriminator.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    x = _conv2d(images, 64, 5, 2, name='d_conv1')
    x = _leaky_relu(x)

    x = _conv2d(x, 128, 5, 2, name='d_conv2')
    x = _leaky_relu(_batch_norm(x, is_training, name='d_bn2'))

    x = _conv2d(x, 256, 5, 2, name='d_conv3')
    x = _leaky_relu(_batch_norm(x, is_training, name='d_bn3'))

    x = tf.reshape(x, [-1, 4 * 4 * 256])

    x = _dense(x, 1, name='d_fc_4')

    return x


def generator(noise, is_training=True, scope='Generator'):
  """Generator to produce CelebA images.

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
      does not use conditioning, this Tensor represents a noise vector of some
      kind that will be reshaped by the generator into CelebA examples.
    is_training: If `True`, batch norm uses batch statistics. If `False`, batch
      norm uses the exponential moving average collected from population
      statistics.
    scope: A variable scope or string for the generator.

  Returns:
    A single Tensor with a batch of generated CelebA images.
  """
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    net = _dense(noise, 4096, name='g_fc1')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn1'))

    net = tf.reshape(net, [-1, 4, 4, 256])

    net = _deconv2d(net, 128, 5, 2, name='g_dconv2')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn2'))

    net = _deconv2d(net, 64, 4, 2, name='g_dconv3')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn3'))

    net = _deconv2d(net, 3, 4, 2, name='g_dconv4')
    net = tf.tanh(net)

    return net
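
The generator's final `tanh` produces values in [-1, 1], whereas `parse_tfrecord` scales real images to [0, 1], and `plt.imshow` expects float images in [0, 1]. One common convention, which is an assumption here and not something this notebook does, is to rescale between the two ranges explicitly. The helper names below are hypothetical. They use only arithmetic operators, so the same functions should also work elementwise on NumPy arrays and TensorFlow tensors.

```python
def to_tanh_range(x):
    """Map values from [0, 1] to [-1, 1]."""
    return x * 2.0 - 1.0


def to_unit_range(x):
    """Map values from [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0


# Hypothetical pixel intensities already scaled to [0, 1].
pixels = [0.0, 0.25, 0.5, 1.0]
scaled = [to_tanh_range(p) for p in pixels]
restored = [to_unit_range(s) for s in scaled]
print(scaled)
print(restored)
```

Under this convention, real images would pass through `to_tanh_range` before reaching the discriminator, and generated images through `to_unit_range` before plotting.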

### Estimator

TF-GAN's `TPUGANEstimator` is like `GANEstimator`, but it extends TensorFlow's `TPUEstimator` class. `TPUEstimator` handles the details of deploying the network on a TPU.

In [0]:
import tensorflow.compat.v1 as tf_compat

noise_dims = 1024 #@param
generator_lr = 0.0002  #@param
discriminator_lr = 0.0002  #@param
train_batch_size = 1024  #@param
# Used below both as the TPU `iterations_per_loop` and as `predict_batch_size`.
images_per_batch = 2000 #@param

config = tf.compat.v1.estimator.tpu.RunConfig(
    model_dir=model_dir,
    master=tpu_address,
    tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(iterations_per_loop=images_per_batch),
    #save_summary_steps=None,
    #save_checkpoints_secs=None
    )
    
est = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator,
    discriminator_fn=discriminator,
    generator_loss_fn=tfgan.losses.modified_generator_loss,
    discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(generator_lr, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(discriminator_lr, 0.5),
    joint_train=True,  # train G and D jointly instead of sequentially.
    train_batch_size=train_batch_size,
    predict_batch_size=images_per_batch,
    use_tpu=True,
    params={'noise_dims': noise_dims},
    config=config)
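
The estimator above is configured with TF-GAN's "modified" losses, which follow the non-saturating GAN formulation. As a rough illustration of that formulation only, here is a pure-Python sketch that operates on raw discriminator logits. It is not TF-GAN's code: the library functions accept further options (such as label smoothing and weights) that this sketch omits, and the function names here are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def generator_loss(fake_logits):
    """Non-saturating generator loss: mean of -log(sigmoid(D(G(z))))."""
    return sum(softplus(-logit) for logit in fake_logits) / len(fake_logits)


def discriminator_loss(real_logits, fake_logits):
    """Mean -log(sigmoid(D(x))) on real data plus mean -log(1 - sigmoid(D(G(z))))."""
    real_term = sum(softplus(-logit) for logit in real_logits) / len(real_logits)
    fake_term = sum(softplus(logit) for logit in fake_logits) / len(fake_logits)
    return real_term + fake_term


# An undecided discriminator outputs logit 0, i.e. probability 0.5.
undecided_g = generator_loss([0.0, 0.0])
undecided_d = discriminator_loss([0.0], [0.0])
# A confident, correct discriminator drives its own loss toward zero.
confident_d = discriminator_loss([10.0], [-10.0])
print(undecided_g, undecided_d, confident_d)
```

The generator minimizes `-log(sigmoid(D(G(z))))` instead of `log(1 - sigmoid(D(G(z))))`, which keeps its gradient from vanishing early in training when the discriminator rejects generated samples easily.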

### Train and Eval Loop

Train and visualize.

In [0]:
def hms_string(sec_elapsed):
  h = int(sec_elapsed / (60 * 60))
  m = int((sec_elapsed % (60 * 60)) / 60)
  s = sec_elapsed % 60
  return "{}:{:>02}:{:>05.2f}".format(h, m, s)

In [0]:
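# Despite its name, `epochs` is the number of training steps (batches), not passes over the dataset.
epochs = 10000 #@param

start = time.time()
est.train(input_fn, steps=epochs)
elapsed = time.time() - start
print("Training time: {}".format(hms_string(elapsed)))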

# Generate and show some predictions.
predictions = np.array([x['generated_data'] for x in est.predict(noise_input_fn)])[:100]
plot_faces(predictions,10)
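
Since `est.train` receives `epochs` as its `steps` argument, the value counts batches, not passes over the data. The sketch below estimates how many passes 10,000 steps correspond to. It assumes that the TFRecord files hold the full public CelebA dataset (202,599 images) and that `train_batch_size` is the global batch size. Both are assumptions, and the helper names are hypothetical.

```python
CELEBA_NUM_IMAGES = 202_599  # size of the public CelebA dataset (assumed to match the TFRecords)


def steps_per_epoch(num_images, batch_size):
    """Approximate number of full batches in one pass over the data."""
    return num_images // batch_size


def approx_epochs(steps, num_images, batch_size):
    """Approximate number of passes over the data covered by `steps` batches."""
    return steps * batch_size / num_images


spe = steps_per_epoch(CELEBA_NUM_IMAGES, 1024)
passes = approx_epochs(10000, CELEBA_NUM_IMAGES, 1024)
print(spe, round(passes, 1))
```

Under these assumptions, one pass takes roughly 197 steps, so the default of 10,000 steps is on the order of 50 passes over the dataset.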