# Imports

In [0]:
%tensorflow_version 1.x
import os

import numpy as np
import tensorflow as tf
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Seed both libraries for reproducibility.
# NOTE: tf.set_random_seed sets the seed of the *current* default graph only;
# the model classes below call tf.reset_default_graph() when they build their
# graph, which discards this graph-level seed.
tf.set_random_seed(123)
# NumPy's global RNG drives batch shuffling and noise sampling in training.
np.random.seed(123)

print(tf.__version__)
# Only log TF messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Data

In [0]:
# Only the MNIST training images are used (labels and test split are discarded).
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Scale pixels to [-1, 1] to match the generator's tanh output, and cast once to
# float32 (the placeholders' dtype) instead of on every feed; halves memory.
x_train = ((x_train - 127.5) / 127.5).astype(np.float32)

# Utilities

In [0]:
def batch_generator(X, batch_size=32, shuffle=True):
  """Yield mini-batches of rows of `X`.

  Args:
    X: numpy array; batches are taken with integer-array indexing.
    batch_size: rows per batch; the final batch may be shorter.
    shuffle: if True, visit rows in an order drawn from np.random.permutation,
      otherwise in their stored order.

  Yields:
    Copies of `X` restricted to consecutive chunks of the visiting order.
  """
  n_rows = len(X)
  order = np.random.permutation(n_rows) if shuffle else np.arange(n_rows)
  for offset in range(0, n_rows, batch_size):
    # slicing clamps at the end of `order`, so the last batch is just shorter
    yield X[order[offset:offset + batch_size]]

def visualize_imgs(imgs, shape):
  """Tile 2-D images into one grid image, filled row by row.

  Args:
    imgs: indexable sequence of 2-D images that share the shape of imgs[0].
    shape: (rows, cols) of the grid. Images beyond rows * cols are ignored
      (instead of failing on an out-of-range slice); unused cells stay 0.

  Returns:
    float array of shape (height * rows, width * cols).
  """
  rows, cols = shape
  height, width = imgs[0].shape[:2]
  grid = np.zeros((height * rows, width * cols))
  capacity = rows * cols
  for idx, img in enumerate(imgs):
    if idx >= capacity:
      break
    # integer row/column of this cell (avoids float division)
    r, c = divmod(idx, cols)
    grid[r * height:(r + 1) * height, c * width:(c + 1) * width] = img
  return grid

def plot_img(img, title, prefix=''):
  """Display `img` in grayscale with `title`, then save it to figs/<prefix><title>.png.

  The `figs` directory is created on first use so saving does not fail with
  FileNotFoundError on a fresh checkout/runtime.
  """
  plt.imshow(img, cmap='gray')
  plt.axis('off')
  plt.title(title)
  plt.show()
  os.makedirs('figs', exist_ok=True)
  plt.imsave(f'figs/{prefix}{title}.png', img, cmap='gray')

# GAN

Paper: Goodfellow et al., *Generative Adversarial Nets* (NIPS 2014) — https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

In [0]:
class GAN:
  """Fully-connected GAN built as a TF1 graph.

  Subclasses customise the model by overriding `generator`, `discriminator`
  and/or `build_model`; `train`, `plot_losses`, `save` and `load` are shared.

  `build_model` must set the attributes `train` relies on:
    z, image_real            -- placeholders for noise and real images
    image_fake               -- generator output, [None, image_size, image_size]
    disc_loss, gen_loss      -- scalar losses
    disc_update, gen_update  -- optimizer ops
    weight_clip              -- op run after each discriminator step, or None
  """

  def __init__(self, image_size, z_dim=128, lr=1e-4):
    """Store hyper-parameters and build the model graph.

    Args:
      image_size: side length of the square single-channel images.
      z_dim: dimensionality of the generator's noise input.
      lr: learning rate for both optimizers.
    """
    self.image_size = image_size
    self.z_dim = z_dim
    self.lr = lr
    # Per-epoch mean losses, appended by `train` and plotted by `plot_losses`.
    self.disc_epoch_loss, self.gen_epoch_loss = [], []

    # Read by subclass generators (e.g. DCGAN) to toggle batch normalization.
    self.use_bn = True
    self.build_model()

  def build_model(self):
    """Create placeholders, both networks, losses and optimizers in a fresh graph."""
    tf.reset_default_graph()
    # reset_default_graph() discards any graph-level seed set earlier in the
    # notebook; re-seed the new graph (same value as the notebook seed) so
    # variable initialisation is reproducible.
    tf.set_random_seed(123)

    self.image_real = tf.placeholder(
        tf.float32, [None, self.image_size, self.image_size])

    self.image_fake = self.generator()

    # Same discriminator variables score both inputs (AUTO_REUSE scope).
    logits_real = self.discriminator(self.image_real)
    logits_fake = self.discriminator(self.image_fake)

    # discriminator loss: label real images 1 and generated images 0
    loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits_real, labels=tf.ones_like(logits_real)
        )
    )

    loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits_fake, labels=tf.zeros_like(logits_fake)
        )
    )

    self.disc_loss = loss_real + loss_fake

    # generator loss: push D to label generated images as real
    self.gen_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits_fake, labels=tf.ones_like(logits_fake)
        )
    )

    # optimizers: each network only updates the variables in its own scope
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    self.disc_update = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(self.disc_loss, var_list=disc_vars)
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    self.gen_update = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(self.gen_loss, var_list=gen_vars)

    # no weight clipping for the standard GAN
    self.weight_clip = None

  def generator(self):
    """Map noise `self.z` through a 2-hidden-layer MLP to tanh images."""
    relu = tf.nn.relu
    linear = tf.layers.dense

    with tf.variable_scope('generator'):
      self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])

      x = self.z
      x = relu(linear(x, 1024))
      x = relu(linear(x, 1024))
      x = tf.nn.tanh(linear(x, self.image_size * self.image_size))

      return tf.reshape(x, [-1, self.image_size, self.image_size])

  def discriminator(self, input_image):
    """Return one unnormalised logit per image from a 2-hidden-layer MLP."""
    relu = tf.nn.relu
    linear = tf.layers.dense

    # AUTO_REUSE: the first call creates the variables, later calls share them.
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
      x = tf.reshape(input_image, [-1, self.image_size * self.image_size])
      x = relu(linear(x, 256))
      x = relu(linear(x, 256))
      x = linear(x, 1)

      return x

  def train(self, sess, X, nepochs=50, batch_size=32, ncritic=1, nsamples=100, show_samples=True):
    """Alternate generator and discriminator updates for `nepochs` epochs.

    Batches are consumed in groups of `ncritic + 1`: one generator step
    followed by `ncritic` discriminator steps (the group boundary carries over
    between epochs). At roughly log-spaced epochs a grid of `nsamples`
    generated images is shown and saved via `plot_img`.

    Args:
      sess: tf.Session on this model's graph; all global variables are
        (re-)initialised here.
      X: array of real images shaped [n, image_size, image_size]; presumably
        scaled to [-1, 1] to match the generator's tanh (TODO confirm).
      nepochs: number of passes over X.
      batch_size: real images, and noise vectors, per step.
      ncritic: discriminator steps per generator step.
      nsamples: number of images in each sample grid.
      show_samples: whether to show/save sample grids during training.
    """
    sess.run(tf.global_variables_initializer())

    counter = 0
    # 1-based epochs after which samples are shown: ~sqrt(nepochs) log-spaced
    # points in [1, nepochs]. Rounding (not truncating) keeps floating-point
    # error in logspace from pulling the last point below nepochs.
    num_points = int(np.sqrt(nepochs)) + 1
    schedule = np.unique(
        np.rint(np.logspace(0, np.log10(nepochs), num=num_points)).astype(int))
    sample_epochs = iter(schedule)
    next_sample_epoch = next(sample_epochs, None)

    # Near-square grid that fits nsamples images (10x10 for nsamples=100).
    grid_cols = max(1, int(np.ceil(np.sqrt(nsamples))))
    grid_rows = max(1, int(np.ceil(nsamples / grid_cols)))

    for epoch in tqdm(range(nepochs), desc='Train loop'):
      disc_losses, gen_losses = [], []

      for batch in batch_generator(X, batch_size=batch_size):
        # one noise vector per real image, so real and fake halves match in size
        z_batch = np.random.uniform(-1.0, 1.0, size=(len(batch), self.z_dim)).astype(np.float32)

        if counter % (ncritic + 1) != 0:
          # train discriminator
          disc_loss, _ = sess.run([self.disc_loss, self.disc_update],
                                  feed_dict={self.image_real: batch, self.z: z_batch})

          if self.weight_clip is not None:
            sess.run(self.weight_clip)

          disc_losses.append(disc_loss)
        else:
          # train generator
          gen_loss, _ = sess.run([self.gen_loss, self.gen_update], feed_dict={self.z: z_batch})
          gen_losses.append(gen_loss)

        counter += 1

      # arithmetic mean over the epoch's steps (0.0 if no step of that kind ran)
      self.disc_epoch_loss.append(float(np.mean(disc_losses)) if disc_losses else 0.0)
      self.gen_epoch_loss.append(float(np.mean(gen_losses)) if gen_losses else 0.0)

      # sample images from generator
      if show_samples and (epoch + 1) == next_sample_epoch:
        next_sample_epoch = next(sample_epochs, None)
        z_sample = np.random.uniform(-1.0, 1.0, size=(nsamples, self.z_dim))
        sample_images = sess.run(self.image_fake, feed_dict={self.z: z_sample})
        img = visualize_imgs(sample_images, shape=(grid_rows, grid_cols))
        # prefix with the class name so different models don't overwrite each other's figures
        plot_img(img, f'Epoch {epoch + 1}', prefix=f'{self.__class__.__name__}_')

  def plot_losses(self):
    """Plot per-epoch D and G losses and save the figure.

    The file is named after the model class (`gan_losses.png`,
    `dcgan_losses.png`, ...) so successive models don't overwrite it.
    """
    plt.plot(range(len(self.disc_epoch_loss)), self.disc_epoch_loss, color='blue', label='D')
    plt.plot(range(len(self.gen_epoch_loss)), self.gen_epoch_loss, color='red', label='G')
    plt.legend(loc="upper right")
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training loss of D & G')
    plt.savefig(f'{self.__class__.__name__.lower()}_losses.png', dpi=300)
    plt.show()

  def save(self, sess):
    """Checkpoint the variables in `sess` under a prefix equal to the class name."""
    saver = tf.train.Saver()
    saver.save(sess, f'{self.__class__.__name__}')

  def load(self, sess):
    """Restore variables into `sess` from the checkpoint written by `save`."""
    saver = tf.train.Saver()
    saver.restore(sess, f'{self.__class__.__name__}')


In [0]:
# Build the model before opening the session: the constructor calls
# build_model(), which resets the default graph, and tf.Session() launches the
# graph that is the default at the moment it is created.
gan = GAN(image_size=28, lr=2e-4)

# NOTE(review): this session is never closed; later cells rebind `sess`.
sess = tf.Session()
gan.train(sess, x_train, nepochs=500)
gan.plot_losses()
# Checkpoint prefix is the class name (see GAN.save).
gan.save(sess)

# DCGAN

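Paper: Radford, Metz & Chintala, *Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks* (2015) — https://arxiv.org/pdf/1511.06434.pdf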

In [0]:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://www.tensorflow.org/tutorials/generative/dcgan

class DCGAN(GAN):
  """Convolutional GAN that reuses GAN's losses, optimizers and training loop.

  The generator upsamples a (image_size/4)-sided feature map twice by a factor
  of 2, so `image_size` must be divisible by 4.
  """

  def generator(self):
    """Map noise `self.z` to tanh images of shape [None, image_size, image_size].

    Raises:
      ValueError: if image_size is not divisible by 4. Otherwise the deconv
        output would not match the final reshape, and `[-1, ...]` could
        silently regroup pixels into the wrong number of images.
    """
    if self.image_size % 4 != 0:
      raise ValueError(
          f'{self.__class__.__name__} requires image_size divisible by 4, '
          f'got {self.image_size}')

    relu = tf.nn.relu
    deconv2d = tf.layers.conv2d_transpose
    # Identity stand-in with the same call signature when batch norm is off.
    bn = tf.layers.batch_normalization if self.use_bn else lambda x, training: x
    linear = tf.layers.dense
    dim = self.image_size // 4

    with tf.variable_scope('generator'):
      self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])

      x = self.z
      # training=True everywhere: batch statistics are always used (also when
      # sampling), so the moving averages are never read.
      x = relu(bn(linear(x, dim * dim * 128, use_bias=False), training=True))
      x = tf.reshape(x, [-1, dim, dim, 128])
      # spatial size: dim -> dim -> 2*dim -> 4*dim (== image_size)
      x = relu(bn(deconv2d(x, 128, kernel_size=5, strides=1, padding='same', use_bias=False), training=True))
      x = relu(bn(deconv2d(x, 64, kernel_size=5, strides=2, padding='same', use_bias=False), training=True))
      x = tf.nn.tanh(deconv2d(x, 1, kernel_size=5, strides=2, padding='same', use_bias=False))

      return tf.reshape(x, [-1, self.image_size, self.image_size])

  def discriminator(self, input_image):
    """Return one unnormalised score per image from two strided convolutions."""
    lrelu = tf.nn.leaky_relu
    conv2d = tf.layers.conv2d
    linear = tf.layers.dense

    # AUTO_REUSE: real and fake inputs share the same variables.
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
      # add the single channel axis expected by conv2d
      x = tf.reshape(input_image, [-1, self.image_size, self.image_size, 1])
      x = lrelu(conv2d(x, 64, kernel_size=5, strides=2, padding='same'))
      x = lrelu(conv2d(x, 128, kernel_size=5, strides=2, padding='same'))
      x = tf.layers.flatten(x)
      x = linear(x, 1)

      return x

In [0]:
# Build the model (which resets the default graph) before creating the session,
# so tf.Session() launches this model's graph.
gan = DCGAN(image_size=28, lr=2e-4)

# NOTE(review): the previous cell's session is left open when `sess` is rebound.
sess = tf.Session()
gan.train(sess, x_train, nepochs=500)
gan.plot_losses()

# WGAN

Paper: Arjovsky, Chintala & Bottou, *Wasserstein GAN* (2017) — https://arxiv.org/pdf/1701.07875.pdf

In [0]:
class WGAN(DCGAN):
  """Wasserstein GAN with weight clipping, on the DCGAN networks.

  The discriminator is used as an unbounded critic. After every critic update,
  `train` runs `weight_clip`, which clips each critic variable to
  [-clip_value, clip_value].
  """

  # Clipping bound for critic weights (0.01 is the value used in the WGAN paper).
  clip_value = 0.01

  def build_model(self):
    """Build the WGAN graph: critic/generator losses, RMSProp updates, clip op."""
    tf.reset_default_graph()
    # reset_default_graph() discards any earlier graph-level seed; re-seed
    # with the notebook's value so initialisation is reproducible.
    tf.set_random_seed(123)

    self.image_real = tf.placeholder(
        tf.float32, [None, self.image_size, self.image_size])

    self.image_fake = self.generator()

    critic_real = self.discriminator(self.image_real)
    critic_fake = self.discriminator(self.image_fake)

    # critic loss: mean score on fakes minus mean score on reals
    self.disc_loss = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real)

    # generator loss: raise the critic's mean score on generated images
    self.gen_loss = -tf.reduce_mean(critic_fake)

    # optimizers (RMSProp, as in the WGAN paper)
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    self.disc_update = tf.train.RMSPropOptimizer(self.lr).minimize(self.disc_loss, var_list=disc_vars)
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    self.gen_update = tf.train.RMSPropOptimizer(self.lr).minimize(self.gen_loss, var_list=gen_vars)

    # weight clipping of every critic variable (kernels and biases)
    self.weight_clip = tf.group(*[
        var.assign(tf.clip_by_value(var, -self.clip_value, self.clip_value))
        for var in disc_vars
    ])

In [0]:
# Build the model (which resets the default graph) before creating the session,
# so tf.Session() launches this model's graph.
gan = WGAN(image_size=28, lr=2e-4)

# NOTE(review): the WGAN paper uses n_critic=5 and lr=5e-5 with RMSProp;
# confirm that ncritic=1 and lr=2e-4 here are intentional.
sess = tf.Session()
gan.train(sess, x_train, nepochs=500, ncritic=1)
gan.plot_losses()

# WGAN &mdash; GP

Paper: Gulrajani et al., *Improved Training of Wasserstein GANs* (2017) — https://arxiv.org/pdf/1704.00028.pdf

In [0]:
class WGANGP(DCGAN):
  """WGAN with gradient penalty on the DCGAN networks (no weight clipping).

  The critic loss adds gp_weight * mean((||grad D(x_hat)||_2 - 1)^2), where
  x_hat interpolates each real image with a generated one using a
  per-sample random coefficient. Batch norm is disabled for this model.
  The constructor is inherited from GAN; `build_model` switches batch norm
  off itself, so the graph is built only once.
  """

  # Penalty coefficient lambda (10 is the value used in the WGAN-GP paper).
  gp_weight = 10.0

  def build_model(self):
    """Build the WGAN-GP graph: critic loss with gradient penalty, Adam updates."""
    # Must be set before self.generator() reads it.
    self.use_bn = False

    tf.reset_default_graph()
    # reset_default_graph() discards any earlier graph-level seed; re-seed
    # with the notebook's value so initialisation is reproducible.
    tf.set_random_seed(123)

    self.image_real = tf.placeholder(
        tf.float32, [None, self.image_size, self.image_size])

    self.image_fake = self.generator()

    critic_real = self.discriminator(self.image_real)
    critic_fake = self.discriminator(self.image_fake)

    # gradient penalty
    # One interpolation coefficient per sample, broadcast over the pixels. The
    # batch size comes from the fed real images, so the noise batch fed to
    # `self.z` must have the same length.
    batch_size = tf.shape(self.image_real)[0]
    alpha = tf.random_uniform([batch_size, 1, 1], 0.0, 1.0)
    interpolated = alpha * self.image_real + (1.0 - alpha) * self.image_fake
    # Differentiate the raw critic output, the same quantity the losses use.
    # The critic has no cross-sample layers, so the gradient of the summed
    # output w.r.t. each input is that sample's own gradient.
    gradients = tf.gradients(self.discriminator(interpolated), [interpolated])[0]
    # Per-sample L2 norm over the pixel axes; the epsilon keeps sqrt
    # differentiable at zero (https://github.com/tensorflow/tensorflow/issues/12071).
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]) + 1e-12)
    grad_pen = self.gp_weight * tf.reduce_mean(tf.square(grad_norm - 1.0))

    # critic loss
    self.disc_loss = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real) + grad_pen

    # generator loss
    self.gen_loss = -tf.reduce_mean(critic_fake)

    # optimizers
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    self.disc_update = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.9).minimize(self.disc_loss, var_list=disc_vars)
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    self.gen_update = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.9).minimize(self.gen_loss, var_list=gen_vars)

    # the gradient penalty replaces weight clipping
    self.weight_clip = None
  

In [0]:
# Build the model (which resets the default graph) before creating the session,
# so tf.Session() launches this model's graph.
gan = WGANGP(image_size=28, lr=2e-4)

# NOTE(review): the WGAN-GP paper uses n_critic=5, lr=1e-4, beta1=0, beta2=0.9;
# confirm the defaults used here (ncritic=1, lr=2e-4, beta1=0.5) are intended.
sess = tf.Session()
gan.train(sess, x_train, nepochs=500)
gan.plot_losses()

# WGAN &mdash; CT

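Paper: Wei et al., *Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect* (ICLR 2018) — https://arxiv.org/pdf/1803.01541.pdf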