In [0]:
# VARIABLES WHICH MUST BE SET
BUCKET_NAME = "" # name of a Google Cloud Storage bucket; create it on Google Cloud Platform and make it public so that no additional permissions are required
ARCHITECTURE = "dcgan" # either "mlpgan" or "dcgan"
model_name = "experiments_dcgan_17042019_142218" # name of the model directory inside BUCKET_NAME/models_dir
from_step = 116000 # training step of the checkpoint to use for prediction

In [0]:
!pip -q install tensorflow-gan scikit-image

In [0]:
from datetime import datetime
import math
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
%matplotlib inline  
import numpy as np
import os
import pprint
from PIL import Image
import re
import requests
import tensorflow as tf
import tensorflow_gan as tfgan
import xml.etree.ElementTree as ET

# Dataset and batching constants shared by the rest of the notebook.
DATASET_SIZE = 60000
TRAIN_BATCH_SIZE_PER_TPU = 128
TPU_CORES = 8
PREDICT_BATCH_SIZE = 128

# Global training batch: 128 examples per TPU core, 8 TPU cores, so 1024 examples overall.
TRAIN_BATCH_SIZE = TRAIN_BATCH_SIZE_PER_TPU * TPU_CORES

# Largest multiple of the global batch size that fits into the dataset.
TRAIN_SET_SIZE = DATASET_SIZE // TRAIN_BATCH_SIZE * TRAIN_BATCH_SIZE
# Compare against DATASET_SIZE (not a hard-coded literal) and allow equality:
# a dataset that is an exact multiple of the batch size is perfectly valid.
assert TRAIN_SET_SIZE <= DATASET_SIZE, "Train set size can't be bigger than whole dataset"
assert TRAIN_SET_SIZE > 0, "Train batch size can't be bigger than whole dataset"

# Evaluation consumes the whole train set as a single batch.
EVAL_BATCH_SIZE = TRAIN_SET_SIZE

In [0]:
# Echo the derived dataset/batch sizes so they can be checked at a glance.
print("Dataset size: {}".format(DATASET_SIZE))
print("Train set size: {}".format(TRAIN_SET_SIZE))
print("Train batch size: {}".format(TRAIN_BATCH_SIZE))
# NOTE(review): the "Test set size" line reports TRAIN_SET_SIZE; no separate
# test-set constant exists in this notebook — confirm this is intended.
print("Test set size: {}".format(TRAIN_SET_SIZE))
print("Eval batch size: {}".format(EVAL_BATCH_SIZE))

In [0]:
# Look up the TPU address from the COLAB_TPU_ADDR environment variable and,
# when it is present, list the devices visible through a session on that address.
if 'COLAB_TPU_ADDR' in os.environ:
  tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  print('TPU address is', tpu_address)

  # The session is only needed long enough to enumerate the devices.
  with tf.Session(tpu_address) as session:
    devices = session.list_devices()

  print('TPU devices:')
  pprint.pprint(devices)
else:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')

In [0]:
# Load CIFAR-10 through Keras and pool the first array of every returned split
# (dataset[0][0] and dataset[1][0]) into one array along the sample axis.
dataset = tf.keras.datasets.cifar10.load_data()
x = np.concatenate(tuple(split[0] for split in dataset), axis=0)

In [0]:
def noise_input_fn(params):
  """Input function that yields one fixed batch of Gaussian noise.

  NumPy's global RNG is re-seeded with 0 on every call, so the same noise
  values are produced each time for a given shape.

  Args:
    params: dict providing "batch_size" and "noise_dim".

  Returns:
    A tf.data.Dataset built with `from_tensors` from a float32 constant of
    shape (batch_size, noise_dim).
  """
  noise_shape = (params["batch_size"], params["noise_dim"])

  np.random.seed(0)
  noise = np.random.randn(*noise_shape)
  noise_tensor = tf.constant(noise, dtype=tf.float32)
  return tf.data.Dataset.from_tensors(noise_tensor)

In [0]:
def _leaky_relu(x):
  """Apply `tf.nn.leaky_relu` to `x` with a negative slope (alpha) of 0.2."""
  return tf.nn.leaky_relu(x, alpha=0.2)

def _relu(x):
  """Apply `tf.nn.relu` to `x`."""
  return tf.nn.relu(x)

def _tanh(x):
  """Apply `tf.tanh` to `x`."""
  return tf.tanh(x)

def _batch_norm(x, is_training, name):
  """Batch-normalize `x` with momentum 0.9 and epsilon 1e-5.

  Args:
    x: input tensor.
    is_training: forwarded as the `training` flag of the layer.
    name: name given to the batch-normalization layer.

  Returns:
    The result of `tf.layers.batch_normalization`.
  """
  return tf.layers.batch_normalization(
      x, momentum=0.9, epsilon=1e-5, training=is_training, name=name)


def _dense(x, channels, name):
  """Fully connected layer with truncated-normal (stddev 0.02) kernel init.

  Args:
    x: input tensor.
    channels: number of output units.
    name: name given to the dense layer.

  Returns:
    The result of `tf.layers.dense` (no activation is requested here).
  """
  kernel_init = tf.truncated_normal_initializer(stddev=0.02)
  return tf.layers.dense(
      x, channels, kernel_initializer=kernel_init, name=name)


def _conv2d(x, filters, kernel_size, stride, name):
  """2-D convolution with a square kernel, square stride and 'same' padding.

  Args:
    x: input tensor.
    filters: number of output filters.
    kernel_size: side length of the square kernel.
    stride: stride applied along both spatial dimensions.
    name: name given to the convolution layer.

  Returns:
    The result of `tf.layers.conv2d`, with a truncated-normal (stddev 0.02)
    kernel initializer.
  """
  kernel_shape = [kernel_size, kernel_size]
  stride_shape = [stride, stride]
  kernel_init = tf.truncated_normal_initializer(stddev=0.02)
  return tf.layers.conv2d(
      x, filters, kernel_shape,
      strides=stride_shape,
      padding='same',
      kernel_initializer=kernel_init,
      name=name)


def _deconv2d(x, filters, kernel_size, stride, name):
  """Transposed 2-D convolution with a square kernel/stride and 'same' padding.

  Args:
    x: input tensor.
    filters: number of output filters.
    kernel_size: side length of the square kernel.
    stride: stride applied along both spatial dimensions.
    name: name given to the transposed-convolution layer.

  Returns:
    The result of `tf.layers.conv2d_transpose`, with a truncated-normal
    (stddev 0.02) kernel initializer.
  """
  kernel_shape = [kernel_size, kernel_size]
  stride_shape = [stride, stride]
  kernel_init = tf.truncated_normal_initializer(stddev=0.02)
  return tf.layers.conv2d_transpose(
      x, filters, kernel_shape,
      strides=stride_shape,
      padding='same',
      kernel_initializer=kernel_init,
      name=name)

def _flatten(x):
  """Flatten `x` with `tf.layers.flatten`."""
  return tf.layers.flatten(x)

def _dropout(x, is_training):
  """Apply dropout with rate 0.25; `is_training` is forwarded as `training`."""
  return tf.layers.dropout(x, rate=0.25, training=is_training)

In [0]:
# MLPGAN
def mlpgan_discriminator_fn(images, unused_conditioning, mode):
  """MLP discriminator: flatten -> dense(256) -> dropout -> tanh -> dense(1).

  Args:
    images: batch of images to score.
    unused_conditioning: ignored.
    mode: estimator mode key; dropout is enabled only for
      tf.estimator.ModeKeys.TRAIN.

  Returns:
    Output of the final single-unit dense layer (no activation applied here).
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  # AUTO_REUSE lets repeated calls share the same 'discriminator' variables.
  with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
    flat = _flatten(images)

    hidden = _dense(flat, 256, 'd_fc1')
    hidden = _tanh(_dropout(hidden, training))

    logits = _dense(hidden, 1, 'd_fc2')

  return logits

def mlpgan_generator_fn(random_noise, mode):
  """MLP generator: dense(256) -> dropout -> tanh -> dense(32*32*3) -> tanh.

  Args:
    random_noise: batch of noise vectors.
    mode: estimator mode key; dropout is enabled only for
      tf.estimator.ModeKeys.TRAIN.

  Returns:
    Tensor reshaped to [-1, 32, 32, 3] with tanh applied.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  # AUTO_REUSE lets repeated calls share the same 'generator' variables.
  with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
    hidden = _dense(random_noise, 256, 'g_fc1')
    hidden = _tanh(_dropout(hidden, training))

    pixels = _dense(hidden, 32 * 32 * 3, 'g_fc2')

    # Arrange the flat output as 32x32 images with 3 channels.
    generated = _tanh(tf.reshape(pixels, [-1, 32, 32, 3]))

  return generated

In [0]:
# DCGAN
def dcgan_discriminator_fn(images, unused_conditioning, mode):
  """Convolutional discriminator: three stride-2 5x5 convs, then dense(1).

  Args:
    images: batch of images to score. The reshape to 4 * 4 * 256 below
      presumably relies on 32x32 inputs (three stride-2 'same' convs) — TODO confirm.
    unused_conditioning: ignored.
    mode: estimator mode key; batch norm runs in training mode only for
      tf.estimator.ModeKeys.TRAIN.

  Returns:
    Output of the final single-unit dense layer (no activation applied here).
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  # AUTO_REUSE lets repeated calls share the same 'discriminator' variables.
  with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
    # First block: conv + leaky ReLU, no batch norm.
    net = _leaky_relu(_conv2d(images, 64, 5, 2, 'd_conv1'))

    # Second and third blocks: conv -> batch norm -> leaky ReLU.
    net = _conv2d(net, 128, 5, 2, 'd_conv2')
    net = _leaky_relu(_batch_norm(net, training, 'd_bn2'))

    net = _conv2d(net, 256, 5, 2, 'd_conv3')
    net = _leaky_relu(_batch_norm(net, training, 'd_bn3'))

    # One flat vector of 4 * 4 * 256 features per example.
    net = tf.reshape(net, [-1, 4 * 4 * 256])

    logits = _dense(net, 1, 'd_fc4')

  return logits

def dcgan_generator_fn(random_noise, mode):
  """Convolutional generator: dense(4096) -> 4x4x256 -> three stride-2 deconvs.

  Args:
    random_noise: batch of noise vectors.
    mode: estimator mode key; batch norm runs in training mode only for
      tf.estimator.ModeKeys.TRAIN.

  Returns:
    Output of the last transposed convolution (3 filters) with tanh applied.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  # AUTO_REUSE lets repeated calls share the same 'generator' variables.
  with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
    # Project the noise and normalize it before reshaping into a feature map.
    net = _dense(random_noise, 4096, 'g_fc1')
    net = _relu(_batch_norm(net, training, 'g_bn1'))

    net = tf.reshape(net, [-1, 4, 4, 256])

    # Stride-2 transposed convolutions; note the kernel size is 5 here and 4 below.
    net = _deconv2d(net, 128, 5, 2, 'g_dconv2')
    net = _relu(_batch_norm(net, training, 'g_bn2'))

    net = _deconv2d(net, 64, 4, 2, 'g_dconv3')
    net = _relu(_batch_norm(net, training, 'g_bn3'))

    generated = _tanh(_deconv2d(net, 3, 4, 2, 'g_dconv4'))

  return generated

In [0]:
# Per-architecture GAN settings keyed by the ARCHITECTURE name. Both variants
# share losses, optimizer hyperparameters, estimator class and noise size;
# only the generator/discriminator networks differ.
architectures = {
    arch_name: {
        "generator_fn": generator_fn,
        "discriminator_fn": discriminator_fn,
        "generator_loss_fn": tfgan.losses.modified_generator_loss,
        "discriminator_loss_fn": tfgan.losses.modified_discriminator_loss,
        # A fresh optimizer pair is constructed for every entry.
        "generator_optimizer": tf.train.AdamOptimizer(0.0002, 0.5),
        "discriminator_optimizer": tf.train.AdamOptimizer(0.0002, 0.5),
        "estimator_class": tfgan.estimator.TPUGANEstimator,
        "noise_dim": 128
    }
    for arch_name, generator_fn, discriminator_fn in (
        ("mlpgan", mlpgan_generator_fn, mlpgan_discriminator_fn),
        ("dcgan", dcgan_generator_fn, dcgan_discriminator_fn),
    )
}

In [0]:
# Fail fast on an unknown architecture: merely printing would let the notebook
# continue and crash later with a bare KeyError on architectures[ARCHITECTURE].
if ARCHITECTURE not in architectures:
  raise ValueError("Architecture {} not found; expected one of: {}".format(
      ARCHITECTURE, ", ".join(sorted(architectures))))

In [0]:
# Build the TPU run configuration and the GAN estimator used for prediction.
# Only log messages at ERROR level.
tf.logging.set_verbosity(tf.logging.ERROR)

# Resolve the TPU master address; the resolver is created without arguments.
cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver()
master = cluster_resolver.get_master()
  
# Checkpoints for the selected model are read from this GCS directory.
model_dir = "gs://{}/models_dir/{}".format(BUCKET_NAME, model_name)

batches_per_epoch = int(TRAIN_SET_SIZE / TRAIN_BATCH_SIZE)
steps_per_epoch = int(batches_per_epoch) * 2
# NOTE(review): iterations_per_loop has the same value as steps_per_epoch, so the
# "2 epochs per iteration" remark below looks inconsistent with that name — confirm.
iterations_per_loop = int(batches_per_epoch) * 2 # 2 epochs per iteration
iterations_steps = iterations_per_loop

my_tpu_run_config = tf.contrib.tpu.RunConfig(
    master=master,
    evaluation_master=master,
    model_dir=model_dir,
    session_config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True),
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=iterations_per_loop),
    # NOTE(review): this is batches_per_epoch, i.e. half of iterations_per_loop — confirm
    # that the "every TPU iteration" remark is still accurate.
    save_summary_steps=batches_per_epoch, # save summary every TPU iteration
    save_checkpoints_steps=None, # see below
    save_checkpoints_secs=3600, # save checkpoints every one hour
    keep_checkpoint_max=None        
) 

# Instantiate the estimator class configured for the selected architecture.
selected_architecture_config = architectures[ARCHITECTURE]
my_gan_estimator = selected_architecture_config['estimator_class'](
  generator_fn=selected_architecture_config['generator_fn'],
  discriminator_fn=selected_architecture_config['discriminator_fn'],
  generator_loss_fn=selected_architecture_config['generator_loss_fn'],
  discriminator_loss_fn=selected_architecture_config['discriminator_loss_fn'],
  generator_optimizer=selected_architecture_config['generator_optimizer'],
  discriminator_optimizer=selected_architecture_config['discriminator_optimizer'],
  train_batch_size=TRAIN_BATCH_SIZE,
  joint_train=False,
  config=my_tpu_run_config,
  use_tpu=True,
  # Passed through as `params`; noise_input_fn reads "noise_dim" from its params.
  params={"noise_dim": selected_architecture_config['noise_dim']},
  # EVAL
  eval_on_tpu=True,
  eval_batch_size=EVAL_BATCH_SIZE,
  # PREDICT
  predict_batch_size=PREDICT_BATCH_SIZE)

In [0]:
import heapq
import itertools
from operator import itemgetter
from skimage import data, img_as_float
from skimage.measure import compare_ssim as ssim
from tqdm import tqdm_notebook as tqdm

In [0]:
# Show the first 64 training images as an 8x8 tile.
# Each group of 8 consecutive images is joined along axis 0, and the 8 resulting
# strips are then joined along axis 1 (assumes each x[i] is a height x width x
# channels image, so the strips are vertical — TODO confirm).
image_rows = [np.concatenate(x[i:i+8], axis=0)
              for i in range(0, 64, 8)]
tiled_image = np.concatenate(image_rows, axis=1)

fig, ax = plt.subplots()

ax.imshow(tiled_image)
ax.set_title('training set')

In [0]:
# Epoch that the selected checkpoint step corresponds to. Uses steps_per_epoch
# (2 * batches_per_epoch) instead of a hard-coded "2 / 58", so the value stays
# correct when the dataset or batch-size constants change.
from_epoch = from_step / steps_per_epoch

# Run the generator on the fixed noise batch, restoring the requested checkpoint.
generated_iter = my_gan_estimator.predict(
    input_fn=noise_input_fn,
    checkpoint_path=os.path.join(model_dir, 'model.ckpt-{}'.format(from_step)))
# Shift and scale each generated image: (value + 1) / 2. Both generators end in
# tanh, so this presumably maps [-1, 1] to [0, 1] for display.
images = [((p['generated_data'] + 1.0) / 2.0) for p in generated_iter]

# Tile the first 64 generated images 8x8: groups of 8 are joined along axis 0,
# then the 8 strips are joined along axis 1.
image_rows = [np.concatenate(images[i:i+8], axis=0)
              for i in range(0, 64, 8)]
tiled_image = np.concatenate(image_rows, axis=1)

fig, ax = plt.subplots()

ax.imshow(tiled_image)
ax.set_title('{}, epoch: {}'.format(ARCHITECTURE, from_epoch))

In [0]:

# Compare every generated image with every dataset image by SSIM and show the
# 20 most similar (generated, dataset) pairs side by side.

# Scale the generated images by 255 and cast to uint8 before comparing them
# with the dataset images.
images_new = [np.array(image * 255.0, dtype=np.uint8) for image in images]
all_elements = len(images_new) * len(x)

# Enumerate the pairs lazily. Every pair is visited exactly once, so there is
# no need to materialize the full product as a list or to index it through a
# random permutation; this keeps memory use flat.
cartesian_product = itertools.product(images_new, x)

x_images_new_sims = []
for generated, real in tqdm(cartesian_product, total=all_elements):

  similarity = ssim(generated, real, multichannel=True)

  # Keep only reasonably similar pairs so the candidate list stays small.
  if similarity > 0.4:
    x_images_new_sims.append((generated, real, similarity))

top20 = heapq.nlargest(20, x_images_new_sims, itemgetter(2))

if not top20:
  # np.concatenate raises on an empty list, so report the situation instead.
  print("No generated/training image pair exceeded the SSIM threshold of 0.4")
else:
  # One row per pair: generated image on the left, dataset image on the right.
  image_rows = [np.concatenate([example[0], example[1]], axis=1)
                for example in top20]
  tiled_image = np.concatenate(image_rows, axis=0)

  dpi = 80
  shape = tiled_image.shape

  # Size the figure from the tile's pixel dimensions and strip all decorations.
  fig, ax = plt.subplots(figsize=(shape[1]/float(dpi), shape[0]/float(dpi)), dpi=dpi, frameon=False)
  ax.imshow(tiled_image, extent=(0, 1, 1 ,0), aspect='auto')
  ax.set_xticks([])
  ax.set_yticks([])
  ax.axis('off')
  fig.subplots_adjust(bottom=0, top=1, left=0, right=1, wspace=0, hspace=0)
  plt.show()