In [0]:
# VARIABLES WHICH MUST BE SET
# Name of the Storage Bucket on Google Cloud Platform. Create it beforehand and
# make it public, so that no additional permissions are required.
BUCKET_NAME = ""
# Architecture to train: either "mlpgan" or "dcgan".
ARCHITECTURE = "dcgan"

In [0]:
# Install the tensorflow-gan package (imported below as `tensorflow_gan`).
!pip -q install tensorflow-gan

In [0]:
from datetime import datetime
import math
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
%matplotlib inline  
import numpy as np
import os
import pprint
from PIL import Image
import re
import requests
import tensorflow as tf
import tensorflow_gan as tfgan
import xml.etree.ElementTree as ET

# Number of examples assumed to be available in the whole dataset.
DATASET_SIZE = 60000
TRAIN_BATCH_SIZE_PER_TPU = 128
TPU_CORES = 8
PREDICT_BATCH_SIZE = 64

# Global batch size: 128 examples per TPU core, 8 TPU cores, so 1024 examples overall.
TRAIN_BATCH_SIZE = TRAIN_BATCH_SIZE_PER_TPU * TPU_CORES

# Largest multiple of the global batch size that fits into the dataset, so that
# every batch is full.
TRAIN_SET_SIZE = DATASET_SIZE // TRAIN_BATCH_SIZE * TRAIN_BATCH_SIZE
assert TRAIN_SET_SIZE > 0, "Train batch size can't be bigger than whole dataset"
# "<=" (not "<"): a batch size that divides the dataset evenly uses all of it.
assert TRAIN_SET_SIZE <= DATASET_SIZE, "Train set size can't be bigger than whole dataset"

# Evaluation consumes the whole train set as a single batch.
EVAL_BATCH_SIZE = TRAIN_SET_SIZE

In [0]:
# Print the dataset split sizes and batch sizes, one "<label>: <value>" per line.
for template, value in (
    ("Dataset size: {}", DATASET_SIZE),
    ("Train set size: {}", TRAIN_SET_SIZE),
    ("Train batch size: {}", TRAIN_BATCH_SIZE),
    ("Test set size: {}", TRAIN_SET_SIZE),  # reported with the same value as the train set
    ("Eval batch size: {}", EVAL_BATCH_SIZE),
):
  print(template.format(value))

In [0]:
# Check for the TPU address exported through the COLAB_TPU_ADDR environment
# variable and, when present, list the devices reachable through it.
if 'COLAB_TPU_ADDR' in os.environ:
  tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  print ('TPU address is', tpu_address)

  # The session is only opened to enumerate the devices and is closed right after.
  with tf.Session(tpu_address) as session:
    devices = session.list_devices()

  print('TPU devices:')
  pprint.pprint(devices)
else:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')

In [0]:
def generate_cifar_tfrecords():
  """Writes a shuffled CIFAR-10 subset to sharded TFRecord files in the bucket.

  An existing ``gs://BUCKET_NAME/tfrecords`` directory is deleted first. Both
  splits returned by ``tf.keras.datasets.cifar10.load_data()`` are concatenated,
  permuted with a fixed seed and truncated to ``TRAIN_SET_SIZE`` examples, which
  are then written as two shards.

  Raises:
    ValueError: if ``BUCKET_NAME`` has not been set.
  """
  if not BUCKET_NAME:
    raise ValueError("BUCKET_NAME must be set to the name of an existing Storage Bucket")

  tfrecords_dir = "gs://{}/tfrecords".format(BUCKET_NAME)
  if tf.gfile.Exists(tfrecords_dir):
    tf.gfile.DeleteRecursively(tfrecords_dir)

  dataset = tf.keras.datasets.cifar10.load_data()
  x = np.concatenate([dataset[0][0], dataset[1][0]], axis=0)
  y = np.concatenate([dataset[0][1], dataset[1][1]], axis=0)

  # Fixed seed so that the same examples are selected on every run.
  np.random.seed(0)
  indices_permutated = np.random.permutation(x.shape[0])

  x = np.take(x, indices_permutated[:TRAIN_SET_SIZE], axis=0)
  y = np.take(y, indices_permutated[:TRAIN_SET_SIZE], axis=0)

  def convert_to_tfrecords(output_path, file_pattern, x, y, number_of_shards = 1):
    """Writes the (x, y) pairs as `number_of_shards` TFRecord files under output_path.

    `file_pattern` is formatted with the 1-based shard index and the shard count.
    """
    assert x.shape[0] == y.shape[0], "Number of examples in x and y must be equal."

    # output_path is a gs:// URL, so the directory has to be created through
    # tf.gfile; os.makedirs would only create a stray directory on the local disk.
    tf.gfile.MakeDirs(output_path)

    # np.split requires the number of examples to be divisible by number_of_shards.
    x_chunks = np.split(x, number_of_shards)
    y_chunks = np.split(y, number_of_shards)

    for i, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
      shard_path = os.path.join(output_path, file_pattern.format(i + 1, number_of_shards))
      with tf.python_io.TFRecordWriter(shard_path) as writer:
        for image, label_id in zip(x_chunk, y_chunk):
          # The image is stored flattened; readers must restore its shape.
          example = tf.train.Example(features=tf.train.Features(feature={'image': tf.train.Feature(int64_list=tf.train.Int64List(value=image.flatten())),
                                                                         'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label_id]))}))

          writer.write(example.SerializeToString())

  convert_to_tfrecords(tfrecords_dir, 'train-{}-of-{}.tfrecords', x, y, number_of_shards=2)
  
generate_cifar_tfrecords()

In [0]:
# List the objects stored under the bucket's "tfrecords/" prefix through the
# public XML listing endpoint. The prefix keeps the listing small: listing
# responses are paginated, so an unfiltered listing of a bucket that also holds
# many other objects (e.g. model checkpoints) may not contain the TFRecords.
r = requests.get('https://storage.googleapis.com/{}/'.format(BUCKET_NAME),
                 params={'prefix': 'tfrecords/'})
# Surface HTTP errors (missing or non-public bucket) here instead of parsing an error document.
r.raise_for_status()
root = ET.fromstring(r.content)
bucket_files = [contents_key.text for contents_key in root.findall('{http://doc.s3.amazonaws.com/2006-03-01}Contents/{http://doc.s3.amazonaws.com/2006-03-01}Key')]

# Keep only keys that are exactly "tfrecords/train-<n>-of-<m>.tfrecords".
train_tfrecord_pattern = re.compile(r"tfrecords/train-[0-9]+-of-[0-9]+\.tfrecords")
train_tfrecord_files = ["gs://{}/{}".format(BUCKET_NAME, item) for item in bucket_files if train_tfrecord_pattern.fullmatch(item)]

if not train_tfrecord_files:
  raise RuntimeError("No train TFRecord files found in bucket '{}'".format(BUCKET_NAME))

In [0]:
# Download the ngrok Linux (amd64) archive and unpack it into the current
# directory, overwriting files from a previous run (-o).
!wget -q -O ngrok-stable-linux-amd64.zip https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip -q -o ngrok-stable-linux-amd64.zip -d .

In [0]:
# Start TensorBoard in the background on port 6006, reading the experiment
# directory named below from the bucket.
LOG_DIR = "gs://{}/models_dir/experiments_mlpgan_17042019_151813".format(BUCKET_NAME)
tensorboard_command = 'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'.format(LOG_DIR)
get_ipython().system_raw(tensorboard_command)

In [0]:
# Start an ngrok HTTP tunnel to local port 6006 in the background.
get_ipython().system_raw('./ngrok http 6006 &')
# Query the tunnel list on localhost:4040 and print the public URL of the first tunnel.
# NOTE(review): this runs immediately after ngrok is launched; if the tunnel is
# not up yet the lookup presumably fails and the cell has to be re-run — confirm.
! curl -s http://localhost:4040/api/tunnels | python3 -c "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

In [0]:
def make_input_with_noise_fn(filenames, shuffle=False):
  """Builds an input function that pairs random noise with images from TFRecords.

  Args:
    filenames: TFRecord file paths handed to `tf.data.TFRecordDataset`.
    shuffle: when true, examples are shuffled through a 10000-element buffer.

  Returns:
    A function that takes a `params` dict with the keys "batch_size" and
    "noise_dim" and returns a `(random_noise, images)` tuple of tensors.
  """

  def parser(serialized_example):
    """Parses one serialized example and returns its image as float32."""
    feature_spec = {
        "image": tf.FixedLenFeature([32, 32, 3], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    # Scale by 2/255 and shift by -1 (0..255 maps onto -1..1); the label is not used.
    return tf.cast(parsed["image"], tf.float32) * (2. / 255) - 1

  def input_fn_with_noise(params):
    batch_size, noise_dim = params["batch_size"], params["noise_dim"]

    # The dataset repeats indefinitely, so the caller decides when to stop.
    records = tf.data.TFRecordDataset(filenames)
    records = records.map(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE).repeat()
    if shuffle:
      records = records.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
    batches = records.batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

    images = batches.make_one_shot_iterator().get_next()
    random_noise = tf.random_normal([batch_size, noise_dim], dtype=tf.float32)
    return random_noise, images

  return input_fn_with_noise

def noise_input_fn(params):
  """Returns a single-element dataset holding one fixed batch of Gaussian noise.

  The noise is identical on every call, so images generated at different points
  of the training can be compared with each other.

  Args:
    params: dict with the keys "batch_size" and "noise_dim".

  Returns:
    A `tf.data.Dataset` containing one float32 tensor of shape
    [batch_size, noise_dim].
  """
  batch_size = params["batch_size"]
  noise_dim = params["noise_dim"]

  # A private generator seeded with 0 draws the same values as the global one
  # after np.random.seed(0), but leaves the global NumPy random state untouched.
  fixed_noise = np.random.RandomState(0).randn(batch_size, noise_dim)
  return tf.data.Dataset.from_tensors(tf.constant(fixed_noise, dtype=tf.float32))

# Both input functions read the same TFRecord files; only the training one shuffles.
train_input_with_noise_fn = make_input_with_noise_fn(filenames=train_tfrecord_files, shuffle=True)
test_input_with_noise_fn = make_input_with_noise_fn(filenames=train_tfrecord_files, shuffle=False)

In [0]:
def _leaky_relu(x):
  """Applies `tf.nn.leaky_relu` to `x` with alpha=0.2."""
  return tf.nn.leaky_relu(x, alpha=0.2)

def _relu(x):
  """Applies `tf.nn.relu` to `x`."""
  return tf.nn.relu(x)

def _tanh(x):
  """Applies `tf.tanh` to `x`."""
  return tf.tanh(x)

def _batch_norm(x, is_training, name):
  """Batch-normalizes `x` (momentum 0.9, epsilon 1e-5) in a layer called `name`.

  `is_training` is forwarded as the layer's `training` argument.
  """
  return tf.layers.batch_normalization(
      x, momentum=0.9, epsilon=1e-5, training=is_training, name=name)


def _dense(x, channels, name):
  """Fully connected layer `name` with `channels` output units.

  The kernel is initialized from a truncated normal with stddev 0.02.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  return tf.layers.dense(x, channels, kernel_initializer=weight_init, name=name)


def _conv2d(x, filters, kernel_size, stride, name):
  """2-D convolution `name` with a square kernel, square stride and 'same' padding.

  The kernel is initialized from a truncated normal with stddev 0.02.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  kernel_shape = [kernel_size, kernel_size]
  stride_shape = [stride, stride]
  return tf.layers.conv2d(
      x, filters, kernel_shape,
      strides=stride_shape,
      padding='same',
      kernel_initializer=weight_init,
      name=name)


def _deconv2d(x, filters, kernel_size, stride, name):
  """Transposed 2-D convolution `name` with a square kernel, square stride and 'same' padding.

  The kernel is initialized from a truncated normal with stddev 0.02.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  kernel_shape = [kernel_size, kernel_size]
  stride_shape = [stride, stride]
  return tf.layers.conv2d_transpose(
      x, filters, kernel_shape,
      strides=stride_shape,
      padding='same',
      kernel_initializer=weight_init,
      name=name)

def _flatten(x):
  """Applies `tf.layers.flatten` to `x`."""
  return tf.layers.flatten(x)

def _dropout(x, is_training):
  """Applies dropout with rate 0.25 to `x`.

  `is_training` is forwarded as the layer's `training` argument.
  """
  return tf.layers.dropout(x, rate=0.25, training=is_training)

In [0]:
# MLPGAN
def mlpgan_discriminator_fn(images, unused_conditioning, mode):
  """MLP discriminator: flatten -> dense(256) -> dropout -> tanh -> dense(1).

  Args:
    images: batch of images to score.
    unused_conditioning: ignored.
    mode: a `tf.estimator.ModeKeys` value; dropout is active only in TRAIN.

  Returns:
    The output of the final single-unit dense layer.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
    hidden = _dense(_flatten(images), 256, 'd_fc1')
    hidden = _tanh(_dropout(hidden, training))

    logits = _dense(hidden, 1, 'd_fc2')

  return logits

def mlpgan_generator_fn(random_noise, mode):
  """MLP generator: dense(256) -> dropout -> tanh -> dense(32*32*3) -> reshape -> tanh.

  Args:
    random_noise: batch of noise vectors.
    mode: a `tf.estimator.ModeKeys` value; dropout is active only in TRAIN.

  Returns:
    A tensor reshaped to [-1, 32, 32, 3] and passed through tanh.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
    hidden = _dense(random_noise, 256, 'g_fc1')
    hidden = _tanh(_dropout(hidden, training))

    flat_image = _dense(hidden, 32 * 32 * 3, 'g_fc2')
    generated = _tanh(tf.reshape(flat_image, [-1, 32, 32, 3]))

  return generated

In [0]:
# DCGAN
def dcgan_discriminator_fn(images, unused_conditioning, mode):
  """Convolutional discriminator: three stride-2 5x5 convolutions and a dense(1) head.

  Args:
    images: batch of images to score.
    unused_conditioning: ignored.
    mode: a `tf.estimator.ModeKeys` value; batch norm trains only in TRAIN.

  Returns:
    The output of the final single-unit dense layer.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
    # The first convolution is not followed by batch normalization.
    features = _leaky_relu(_conv2d(images, 64, 5, 2, 'd_conv1'))

    features = _conv2d(features, 128, 5, 2, 'd_conv2')
    features = _leaky_relu(_batch_norm(features, training, 'd_bn2'))

    features = _conv2d(features, 256, 5, 2, 'd_conv3')
    features = _leaky_relu(_batch_norm(features, training, 'd_bn3'))

    # Flatten the 4x4x256 feature map for the dense head.
    flat_features = tf.reshape(features, [-1, 4 * 4 * 256])

    logits = _dense(flat_features, 1, 'd_fc4')

  return logits

def dcgan_generator_fn(random_noise, mode):
  """Convolutional generator: dense(4096) -> 4x4x256 map -> three stride-2 transposed convolutions.

  Args:
    random_noise: batch of noise vectors.
    mode: a `tf.estimator.ModeKeys` value; batch norm trains only in TRAIN.

  Returns:
    The tanh of the last transposed convolution, which has 3 filters.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
    # Project the noise to 4096 units and view them as a 4x4 map with 256 channels.
    projected = _dense(random_noise, 4096, 'g_fc1')
    projected = _relu(_batch_norm(projected, training, 'g_bn1'))
    feature_map = tf.reshape(projected, [-1, 4, 4, 256])

    feature_map = _deconv2d(feature_map, 128, 5, 2, 'g_dconv2')
    feature_map = _relu(_batch_norm(feature_map, training, 'g_bn2'))

    feature_map = _deconv2d(feature_map, 64, 4, 2, 'g_dconv3')
    feature_map = _relu(_batch_norm(feature_map, training, 'g_bn3'))

    # The output layer has no batch normalization.
    generated = _tanh(_deconv2d(feature_map, 3, 4, 2, 'g_dconv4'))

  return generated

In [0]:
def _gan_architecture(generator_fn, discriminator_fn):
  """Returns the estimator settings for one generator/discriminator pair.

  Every call creates its own pair of Adam optimizers (learning rate 0.0002,
  beta1 0.5); the losses, estimator class and noise dimension are the same for
  every architecture.
  """
  return {
      "generator_fn": generator_fn,
      "discriminator_fn": discriminator_fn,
      "generator_loss_fn": tfgan.losses.modified_generator_loss,
      "discriminator_loss_fn": tfgan.losses.modified_discriminator_loss,
      "generator_optimizer": tf.train.AdamOptimizer(0.0002, 0.5),
      "discriminator_optimizer": tf.train.AdamOptimizer(0.0002, 0.5),
      "estimator_class": tfgan.estimator.TPUGANEstimator,
      "noise_dim": 128
  }

# Selectable architectures, keyed by the names accepted in ARCHITECTURE.
architectures = {
    "mlpgan": _gan_architecture(mlpgan_generator_fn, mlpgan_discriminator_fn),
    "dcgan": _gan_architecture(dcgan_generator_fn, dcgan_discriminator_fn),
}

In [0]:
# Fail fast on an unknown architecture name instead of only printing a message
# and hitting a KeyError later, when the architecture is looked up.
if ARCHITECTURE not in architectures:
  raise ValueError("Architecture {} not found; available architectures: {}".format(
      ARCHITECTURE, sorted(architectures)))

In [0]:
tf.logging.set_verbosity(tf.logging.ERROR)

# Timestamp that gives every run its own model directory.
time = datetime.now().strftime('%d%m%Y_%H%M%S')

cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver()
master = cluster_resolver.get_master()

model_name = "experiments_{}_{}".format(ARCHITECTURE, time)
model_dir = "gs://{}/models_dir/{}".format(BUCKET_NAME, model_name)

batches_per_epoch = TRAIN_SET_SIZE // TRAIN_BATCH_SIZE
# NOTE(review): the factor 2 presumably accounts for separate generator and
# discriminator steps (joint_train=False below) — confirm.
steps_per_epoch = batches_per_epoch * 2
# One TPU loop runs as many steps as steps_per_epoch.
iterations_per_loop = batches_per_epoch * 2
iterations_steps = iterations_per_loop

session_config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=True)
tpu_config = tf.contrib.tpu.TPUConfig(
    iterations_per_loop=iterations_per_loop)

my_tpu_run_config = tf.contrib.tpu.RunConfig(
    master=master,
    evaluation_master=master,
    model_dir=model_dir,
    session_config=session_config,
    tpu_config=tpu_config,
    save_summary_steps=batches_per_epoch, # summaries every batches_per_epoch steps
    save_checkpoints_steps=None, # checkpoints are time-based, see next line
    save_checkpoints_secs=3600, # save checkpoints every hour
    keep_checkpoint_max=None
)

selected_architecture_config = architectures[ARCHITECTURE]
estimator_class = selected_architecture_config['estimator_class']
my_gan_estimator = estimator_class(
  # networks, losses and optimizers of the selected architecture
  generator_fn=selected_architecture_config['generator_fn'],
  discriminator_fn=selected_architecture_config['discriminator_fn'],
  generator_loss_fn=selected_architecture_config['generator_loss_fn'],
  discriminator_loss_fn=selected_architecture_config['discriminator_loss_fn'],
  generator_optimizer=selected_architecture_config['generator_optimizer'],
  discriminator_optimizer=selected_architecture_config['discriminator_optimizer'],
  # TRAIN
  train_batch_size=TRAIN_BATCH_SIZE,
  joint_train=False,
  config=my_tpu_run_config,
  use_tpu=True,
  params={"noise_dim": selected_architecture_config['noise_dim']},
  # EVAL
  eval_on_tpu=True,
  eval_batch_size=EVAL_BATCH_SIZE,
  # PREDICT
  predict_batch_size=PREDICT_BATCH_SIZE)

In [0]:
# With PREVIEW enabled the model is trained straight to epochs_total, evaluation
# and image saving are skipped, and the model directory is deleted afterwards.
PREVIEW = False

epochs_total = 1000

# create dir for generated images
tf.gfile.MakeDirs(os.path.join(model_dir, "generated_images"))

# Epochs after which the model is evaluated and a grid of images is generated.
# Use an empty list to predict only after epochs_total.
eval_predict_epochs = [10, 20, 50, 100, 200, 400, 500, 600, 700, 800, 900]

# add epochs_total
eval_predict_epochs.append(epochs_total)

if PREVIEW:
  eval_predict_epochs = [epochs_total]

dpi = 80.0
stop = False
for to_epoch in eval_predict_epochs:
  # Never train past epochs_total; the milestone that reaches it is the last one.
  if to_epoch >= epochs_total:
    to_epoch = epochs_total
    stop = True
    
  # max_steps is an absolute step count, so training resumes where it stopped.
  my_gan_estimator.train(train_input_with_noise_fn, max_steps=to_epoch * steps_per_epoch)
  
  if not PREVIEW:
    evaluate_ret = my_gan_estimator.evaluate(test_input_with_noise_fn, steps=1)
    evaluate_ret['epoch'] = int(evaluate_ret['global_step'] / float(steps_per_epoch))
    print(evaluate_ret)
  
  # Generate one batch of images and rescale them with (x + 1) / 2 for display.
  generated_iter = my_gan_estimator.predict(input_fn=noise_input_fn)
  images = [((p['generated_data'][:, :, :] + 1.0) / 2.0) for p in generated_iter]
  assert len(images) == PREDICT_BATCH_SIZE
  # Groups of 8 images are joined along axis 0, then the groups along axis 1.
  image_rows = [np.concatenate(images[i:i+8], axis=0)
                for i in range(0, PREDICT_BATCH_SIZE, 8)]
  tiled_image = np.concatenate(image_rows, axis=1)
  
  shape = tiled_image.shape

  # Size the figure so that one image pixel corresponds to one figure pixel.
  fig, ax = plt.subplots(figsize=(shape[1]/float(dpi), shape[0]/float(dpi)), dpi=dpi, frameon=False)
  ax.imshow(tiled_image, extent=(0, 1, 1 ,0))
  ax.set_xticks([])  # remove xticks
  ax.set_yticks([])  # remove yticks
  ax.axis('off')     # hide axis
  fig.subplots_adjust(bottom=0, top=1, left=0, right=1, wspace=0, hspace=0)  # stretches the image and removes margins
  plt.show()
  
  try:
    if not PREVIEW:
      step_string = str(to_epoch).zfill(5)
      image_path = os.path.join(model_dir, "generated_images", "{}.jpeg".format(step_string))
      # The format must be given explicitly: savefig cannot infer it from a file
      # object and would otherwise write its default format into the .jpeg file.
      # The with-block closes the file even if saving fails.
      with tf.gfile.Open(image_path, 'wb') as file_obj:
        fig.savefig(file_obj, format='jpeg', quality=100, dpi=int(dpi))
    else:
      tf.gfile.DeleteRecursively(model_dir)
  finally:
    # Release the figure in every case so figures do not accumulate over the loop.
    plt.close(fig)
  
  if stop:
    break