# Imports

In [None]:
import os
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from IPython import display
from pathlib import Path

# Seed TensorFlow and NumPy up front so their random draws are repeatable.
SEED = 123
tf.random.set_seed(SEED)
np.random.seed(SEED)

print(tf.__version__)

In [None]:
#@title Global Parameters

BUFFER_SIZE = 10_000            # tf.data shuffle buffer size
BATCH_SIZE = 64                 # images per training step
EPOCHS = 200                    # default number of training epochs
noise_dim = 100                 # length of the generator's latent vector z
num_examples_to_generate = 16   # samples in the per-epoch preview grid (drawn 4x4)
image_size = 128                # side length of the square training images

# Data

In [None]:
# Mount Google Drive: the dataset archive and the saved figures/checkpoints are
# read from (and later written back to) '/content/gdrive/My Drive'.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
!cp '/content/gdrive/My Drive/img_align_celeba.zip' 'img_align_celeba.zip'
!unzip -q img_align_celeba.zip
!rm img_align_celeba.zip

In [None]:
#@title Load tf.dataset

@tf.function
def resize_image_keep_aspect(image):
  """Resize `image` so its shorter side becomes `image_size`, keeping the aspect ratio.

  Args:
    image: float tensor laid out as (height, width, channels).

  Returns:
    The resized image. Both spatial sides are at least `image_size`, so a
    subsequent (image_size, image_size) random crop always fits.
  """
  dims = tf.cast(tf.shape(image), tf.float32)
  # tf.shape of an HWC image is [height, width, channels].
  height, width = dims[0], dims[1]

  ratio = tf.minimum(height, width) / image_size

  # Truncating to int can land one pixel short of image_size through float
  # rounding (e.g. when image_size is not a power of two); clamp to the target.
  new_height = tf.maximum(tf.cast(height / ratio, tf.int32), image_size)
  new_width = tf.maximum(tf.cast(width / ratio, tf.int32), image_size)

  # tf.image.resize expects [new_height, new_width].
  return tf.image.resize(image, [new_height, new_width])

@tf.function
def process_image(image_file):
  """Load one JPEG and turn it into a random RGB crop scaled to [-1, 1].

  Args:
    image_file: scalar string tensor holding the path of a JPEG file.

  Returns:
    float32 tensor of shape (image_size, image_size, 3).
  """
  raw = tf.io.read_file(image_file)
  # channels=3 forces RGB output, so a grayscale JPEG cannot break the
  # 3-channel crop below.
  image = tf.image.decode_jpeg(raw, channels=3)
  image = tf.cast(image, tf.float32)

  # Shorter side -> image_size, then a random square crop.
  image = resize_image_keep_aspect(image)
  image = tf.image.random_crop(image, size=[image_size, image_size, 3])

  # [0, 255] -> [-1, 1], the range of the generators' tanh output.
  return (image / 127.5) - 1


# Shuffle the cheap file-name strings rather than decoded images: a buffer of
# BUFFER_SIZE decoded 128x128x3 float32 images is ~2 GB and would have to be
# filled (10k JPEG decodes) before the first batch of every epoch.
# Decode in parallel and prefetch so input preparation overlaps training.
image_files = tf.data.Dataset.list_files('img_align_celeba/0[0-1]*.jpg')
train_dataset = (image_files
                 .shuffle(BUFFER_SIZE)
                 .map(process_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.experimental.AUTOTUNE))
print(train_dataset.element_spec)
# Exact number of training images (batches * BATCH_SIZE over-counts the last, partial batch).
print(tf.data.experimental.cardinality(image_files).numpy())

In [None]:
#@title Utilities

def visualize_imgs(imgs, shape):
  """Tile equally-sized RGB images into one mosaic, filling it row by row.

  Args:
    imgs: sequence of (H, W, 3) arrays/tensors; cells beyond len(imgs) stay zero.
    shape: (rows, cols) of the grid.

  Returns:
    float ndarray of shape (rows * H, cols * W, 3).
  """
  n_rows, n_cols = shape
  tile_h, tile_w = imgs[0].shape[:2]
  canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
  for k, tile in enumerate(imgs):
    r, c = divmod(k, n_cols)
    top, left = r * tile_h, c * tile_w
    canvas[top:top + tile_h, left:left + tile_w, :] = tile
  return canvas


In [None]:
#@title See Ground Truth

# Take the first 16 images of one input batch, tile them 4x4 and map the
# pixels from [-1, 1] back to [0, 1] for display; also save the grid.
batch = next(train_dataset.prefetch(1).as_numpy_iterator())[:16]
print(np.shape(batch))

img = 0.5 + 0.5 * visualize_imgs(batch, (4, 4))
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()
plt.imsave('ground_truth.png', img)

In [None]:
#@title Load Pretrained weights

!cp '/content/gdrive/My Drive/plots_imgs_loss.zip' 'plots_imgs_loss.zip'
!unzip -q plots_imgs_loss.zip
!rm plots_imgs_loss.zip

!cp '/content/gdrive/My Drive/train_ckpts.zip' 'train_ckpts.zip'
!unzip -q train_ckpts.zip
!rm train_ckpts.zip

# GAN

Paper: Goodfellow et al., [*Generative Adversarial Nets*](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf) (NIPS 2014)

In [None]:
# BCE for the discriminators below, whose last Dense layer emits raw logits (no sigmoid).
bce = K.losses.BinaryCrossentropy(from_logits=True)

# DCGAN-style initialisers: N(0, 0.02) for kernels, N(1, 0.02) for BatchNorm gammas.
W_INIT = K.initializers.RandomNormal(mean=0.0, stddev=0.02)
G_INIT = K.initializers.RandomNormal(mean=1.0, stddev=0.02)

class BaseGAN:
  """Shared scaffolding for the GAN variants in this notebook.

  Subclasses supply the networks (`create_generator`, `create_discriminator`)
  and the adversarial losses (`generator_loss`, `discriminator_loss`). This
  class owns the optimizers, checkpointing, the training loop, and the loss
  plots / sample grids written under `figs/<name>/`.

  Args:
    load: if True, restore the latest checkpoint in
      `training_checkpoints/<name>` when one exists.
  """

  # Seed of the fixed latent batch rendered after every epoch. A stateless
  # draw gives the same preview grid in every session, so the frames of a
  # resumed run line up with the earlier ones in the progress GIF.
  preview_seed = (123, 0)

  def __init__(self, load=True):
    self.G = self.create_generator()
    self.D = self.create_discriminator()

    self.gen_update = K.optimizers.Adam(1e-4, beta_1=0.5)
    self.disc_update = K.optimizers.Adam(1e-4, beta_1=0.5)

    # Directory managed by the CheckpointManager (keeps the 3 newest checkpoints).
    self.ckpt_prefix = f'training_checkpoints/{self.name}'
    self.ckpt = tf.train.Checkpoint(generator_optimizer=self.gen_update,
                                    discriminator_optimizer=self.disc_update,
                                    generator=self.G,
                                    discriminator=self.D)

    self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.ckpt_prefix, max_to_keep=3)

    if load:
      self.load()

    Path(f'./figs/{self.name}').mkdir(parents=True, exist_ok=True)

    # Per-iteration losses accumulated across every `train` call on this instance.
    self.global_g_loss, self.global_d_loss = [], []

  @property
  def name(self):
    """Lower-cased class name, used for the checkpoint and figure directories."""
    return type(self).__name__.lower()

  def create_generator(self):
    """Return the generator model (noise -> image). Must be overridden."""
    raise NotImplementedError

  def create_discriminator(self):
    """Return the discriminator/critic model (image -> score). Must be overridden."""
    raise NotImplementedError

  def discriminator_loss(self, real_output, fake_output, *args):
    """Return the discriminator loss from its outputs on real and fake images. Must be overridden."""
    raise NotImplementedError

  def generator_loss(self, fake_output):
    """Return the generator loss from the discriminator's output on fakes. Must be overridden."""
    raise NotImplementedError

  @tf.function
  def train_step(self, images):
    """Apply one simultaneous generator and discriminator update.

    Args:
      images: batch of real images.

    Returns:
      (gen_loss, disc_loss) for this step.
    """
    # One latent vector per real image, so the last (possibly smaller) batch
    # still shows D as many fake samples as real ones.
    z = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      real_logits = self.D(images, training=True)
      fake_logits = self.D(self.G(z, training=True), training=True)

      gen_loss = self.generator_loss(fake_logits)
      disc_loss = self.discriminator_loss(real_logits, fake_logits)

    gen_grads = gen_tape.gradient(gen_loss, self.G.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, self.D.trainable_variables)

    self.gen_update.apply_gradients(zip(gen_grads, self.G.trainable_variables))
    self.disc_update.apply_gradients(zip(disc_grads, self.D.trainable_variables))

    return gen_loss, disc_loss

  def train(self, dataset, epochs, start_epoch=0):
    """Train epochs `start_epoch + 1` through `epochs`.

    After every epoch: plot and save that epoch's per-iteration losses, save a
    4x4 grid of samples drawn from a fixed latent batch, and write a checkpoint.

    Args:
      dataset: iterable of real-image batches passed to `train_step`.
      epochs: number of the last epoch to run.
      start_epoch: number of epochs already completed (set when resuming).
    """
    # Fixed across sessions (see `preview_seed`), unlike a stateful random draw.
    seed = tf.random.stateless_normal([num_examples_to_generate, noise_dim],
                                      seed=self.preview_seed)

    for epoch in range(start_epoch, epochs):
      g_loss, d_loss = [], []

      for image_batch in dataset:
        gl, dl = self.train_step(image_batch)
        g_loss.append(gl.numpy())
        d_loss.append(dl.numpy())

      display.clear_output(wait=True)

      # Plot epoch loss and add it to the running history
      self.plot_epoch_loss(g_loss, d_loss, epoch + 1)
      self.global_g_loss.extend(g_loss)
      self.global_d_loss.extend(d_loss)

      # Produce images for the GIF as we go
      self.generate_and_save_images(epoch + 1, seed)

      # Save the model
      self.ckpt_manager.save()

    # Summary after the final epoch
    display.clear_output(wait=True)
    self.plot_epoch_loss(self.global_g_loss, self.global_d_loss, 0, 'Global Loss')
    self.generate_and_save_images(epochs, seed)

  def generate_and_save_images(self, epoch, test_input):
    """Render G(test_input) as a 4x4 grid, show it and save it as image_at_epoch_<epoch>.png."""
    predictions = self.G(test_input, training=False)
    # tanh output [-1, 1] -> [0, 1] for display
    img = visualize_imgs(predictions, (4, 4)) * 0.5 + 0.5

    plt.figure()
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    plt.imsave(f'figs/{self.name}/image_at_epoch_{epoch:04d}.png', img)

  def plot_epoch_loss(self, G_losses, D_losses, epoch, title=''):
    """Plot generator/discriminator loss curves and save them as loss_at_epoch_<epoch>.png.

    Args:
      G_losses, D_losses: per-iteration loss values.
      epoch: epoch number used in the default title and in the file name.
      title: figure title; when empty, a title naming `epoch` is used.
    """
    plt.figure(figsize=(6,3))
    plt.title(title or f'Generator and Discriminator Loss - Epoch {epoch}')
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f'figs/{self.name}/loss_at_epoch_{epoch:04d}.png')

  def load(self):
    """Restore the newest checkpoint tracked by `ckpt_manager`, if there is one."""
    if self.ckpt_manager.latest_checkpoint:
      self.ckpt.restore(self.ckpt_manager.latest_checkpoint)

  def sample(self, save=False):
    """Generate one image, display it, and print the discriminator's output for it.

    Args:
      save: if True, also write the image to `<name>_sample.png`.
    """
    # Latent size must match the generator's input (was hard-coded to 100).
    noise = tf.random.normal([1, noise_dim])
    generated_image = self.G(noise, training=False)
    print(f'generated image shape: {np.shape(generated_image)}')

    img = generated_image[0, :, :].numpy() * 0.5 + 0.5

    plt.imshow(img)
    if save:
      plt.imsave(f'{self.name}_sample.png', img)
    plt.axis('off')
    plt.show()

    decision = self.D(generated_image)
    print(f'decision: {decision}')


In [None]:
class GAN(BaseGAN):
  """Vanilla GAN with fully-connected (MLP) generator and discriminator and BCE losses."""

  def create_generator(self):
    """MLP generator: z (noise_dim,) -> 3 x Dense(1024, ReLU) -> tanh image (image_size, image_size, 3)."""
    x_in = K.layers.Input(shape=(noise_dim,))

    x = K.layers.Dense(1024, activation=tf.nn.relu, kernel_initializer=W_INIT)(x_in)
    x = K.layers.Dense(1024, activation=tf.nn.relu, kernel_initializer=W_INIT)(x)
    x = K.layers.Dense(1024, activation=tf.nn.relu, kernel_initializer=W_INIT)(x)
    x = K.layers.Dense(image_size * image_size * 3, activation=tf.nn.tanh, kernel_initializer=W_INIT)(x)
    x_out = K.layers.Reshape((image_size, image_size, 3))(x)

    assert x_out.shape[1:] == (image_size, image_size, 3)

    return K.Model(x_in, x_out)

  def create_discriminator(self):
    """MLP discriminator: flattened image -> 3 x Dense(256, ReLU) -> one raw logit."""
    x_in = K.layers.Input(shape=(image_size, image_size, 3))
    x = K.layers.Flatten()(x_in)

    x = K.layers.Dense(256, activation=tf.nn.relu, kernel_initializer=W_INIT)(x)
    x = K.layers.Dense(256, activation=tf.nn.relu, kernel_initializer=W_INIT)(x)
    x = K.layers.Dense(256, activation=tf.nn.relu, kernel_initializer=W_INIT)(x)
    # No activation: `bce` is configured with from_logits=True.
    x_out = K.layers.Dense(1, kernel_initializer=W_INIT)(x)

    return K.Model(x_in, x_out)

  def discriminator_loss(self, real_output, fake_output):
    """BCE pushing real logits toward label 1 and fake logits toward label 0."""
    real_loss = bce(y_true=tf.ones_like(real_output), y_pred=real_output)
    fake_loss = bce(y_true=tf.zeros_like(fake_output), y_pred=fake_output)
    return real_loss + fake_loss

  def generator_loss(self, fake_output):
    """Non-saturating generator loss: BCE of fake logits against label 1."""
    return bce(y_true=tf.ones_like(fake_output), y_pred=fake_output)


In [None]:
# Build the MLP GAN and restore its latest checkpoint, if any.
gan = GAN(True)

In [None]:
# Show one sample (also saved as gan_sample.png) and the discriminator's output on it.
gan.sample(True)

In [None]:
gan.train(dataset=train_dataset, epochs=EPOCHS)

# DCGAN

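Paper: Radford, Metz & Chintala, [*Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks*](https://arxiv.org/pdf/1511.06434.pdf) (2015)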

In [None]:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://www.tensorflow.org/tutorials/generative/dcgan

class DCGAN(GAN):
  """Deep Convolutional GAN: GAN's BCE losses with convolutional G and D.

  G upsamples with stride-2 Conv2DTranspose + BatchNorm + ReLU and ends in
  tanh; D downsamples with stride-2 Conv2D + BatchNorm + LeakyReLU(0.2) and
  ends in a single raw logit. Shape comments assume image_size = 128.
  """

  def create_generator(self):
    """z (noise_dim,) -> (image_size/16)^2 x 512 -> four 2x upsamplings -> (image_size, image_size, 3)."""
    dim = image_size // 16
    x_in = K.layers.Input(shape=(noise_dim,))

    # Same N(0, 0.02) kernel init as every other layer (the DCGAN paper's weight init).
    x = K.layers.Dense(dim * dim * 512, use_bias=False, kernel_initializer=W_INIT)(x_in)
    x = K.layers.Reshape((dim, dim, 512))(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.ReLU()(x) # 8x8x512

    x = K.layers.Conv2DTranspose(256, 5, strides=2, padding='same', use_bias=False, kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.ReLU()(x) # 16x16x256

    x = K.layers.Conv2DTranspose(128, 5, strides=2, padding='same', use_bias=False, kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.ReLU()(x) # 32x32x128

    x = K.layers.Conv2DTranspose(64, 5, strides=2, padding='same', use_bias=False, kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.ReLU()(x) # 64x64x64

    x_out = K.layers.Conv2DTranspose(3, 5, strides=2, padding='same', use_bias=False,
                                     activation=tf.nn.tanh, kernel_initializer=W_INIT)(x) # 128x128x3
    assert x_out.shape[1:] == (image_size, image_size, 3)

    return K.Model(x_in, x_out)

  def create_discriminator(self):
    """(image_size, image_size, 3) -> four stride-2 convs (no BatchNorm on the first) -> one raw logit."""
    x_in = K.layers.Input(shape=(image_size, image_size, 3)) # 128x128x3

    x = K.layers.Conv2D(64, 5, strides=2, padding='same', kernel_initializer=W_INIT)(x_in)
    x = K.layers.LeakyReLU(0.2)(x) # 64x64x64

    x = K.layers.Conv2D(128, 5, strides=2, padding='same', kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.LeakyReLU(0.2)(x) # 32x32x128

    x = K.layers.Conv2D(256, 5, strides=2, padding='same', kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.LeakyReLU(0.2)(x) # 16x16x256

    x = K.layers.Conv2D(512, 5, strides=2, padding='same', kernel_initializer=W_INIT)(x)
    x = K.layers.BatchNormalization(gamma_initializer=G_INIT)(x)
    x = K.layers.LeakyReLU(0.2)(x) # 8x8x512

    x = K.layers.Flatten()(x)
    x_out = K.layers.Dense(1, kernel_initializer=W_INIT)(x)

    return K.Model(x_in, x_out)

In [None]:
# Build the DCGAN and restore its latest checkpoint, if any.
dcgan = DCGAN(True)

In [None]:
# Show one sample (also saved as dcgan_sample.png) and the discriminator's output on it.
dcgan.sample(True)

In [None]:
dcgan.train(dataset=train_dataset, epochs=EPOCHS)

# WGAN

Paper: Arjovsky, Chintala & Bottou, [*Wasserstein GAN*](https://arxiv.org/pdf/1701.07875.pdf) (2017)

In [None]:
# https://github.com/hcnoh/WGAN-tensorflow2/blob/master/train.py#L78

class WGAN(DCGAN):
  """Wasserstein GAN with weight clipping, on the DCGAN architecture.

  Differs from DCGAN by using RMSprop optimizers, linear critic/generator
  losses instead of BCE, and clipping every critic variable to
  [-clip_value, clip_value] after each update.

  NOTE(review): Algorithm 1 of the WGAN paper trains the critic n_critic = 5
  times per generator step; `train_step` here does one of each.
  """

  # Critic weight-clipping bound. NOTE(review): the WGAN paper's default is
  # c = 0.01; 0.1 is kept from the original notebook - confirm it is intended.
  clip_value = 0.1

  def __init__(self, load=True):
    # Let the parent build the networks, but don't restore yet: the parent's
    # checkpoint tracks its Adam optimizers, which are replaced below.
    super().__init__(load=False)

    # RMSprop with lr = 5e-5, as in the WGAN paper.
    self.gen_update = K.optimizers.RMSprop(0.00005)
    self.disc_update = K.optimizers.RMSprop(0.00005)

    # Rebuild checkpointing around the optimizers actually used, so their
    # state is what gets saved and restored.
    self.ckpt = tf.train.Checkpoint(generator_optimizer=self.gen_update,
                                    discriminator_optimizer=self.disc_update,
                                    generator=self.G,
                                    discriminator=self.D)
    self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.ckpt_prefix, max_to_keep=3)

    if load:
      self.load()

  def discriminator_loss(self, real_output, fake_output):
    """Critic loss E[D(fake)] - E[D(real)] (the negated Wasserstein estimate)."""
    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)
    return fake_loss - real_loss

  def generator_loss(self, fake_output):
    """Generator loss -E[D(fake)]."""
    return -tf.reduce_mean(fake_output)

  @tf.function
  def train_step(self, images):
    """One generator + critic update followed by critic weight clipping.

    Returns:
      (gen_loss, disc_loss) for this step.
    """
    # One latent vector per real image (the last batch may be smaller).
    z = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      real_logits = self.D(images, training=True)
      fake_logits = self.D(self.G(z, training=True), training=True)

      gen_loss = self.generator_loss(fake_logits)
      disc_loss = self.discriminator_loss(real_logits, fake_logits)

    gen_grads = gen_tape.gradient(gen_loss, self.G.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, self.D.trainable_variables)

    self.gen_update.apply_gradients(zip(gen_grads, self.G.trainable_variables))
    self.disc_update.apply_gradients(zip(disc_grads, self.D.trainable_variables))

    # Weight clipping keeps the critic (roughly) Lipschitz-bounded.
    for var in self.D.trainable_variables:
      var.assign(tf.clip_by_value(var, -self.clip_value, self.clip_value))

    return gen_loss, disc_loss

In [None]:
# Build the WGAN and restore its latest checkpoint, if any.
wgan = WGAN(True)

In [None]:
# Show one sample (also saved as wgan_sample.png) and the critic's output on it.
wgan.sample(True)

In [None]:
wgan.train(dataset=train_dataset, epochs=EPOCHS)

# WGAN &mdash; GP

Paper: Gulrajani et al., [*Improved Training of Wasserstein GANs*](https://arxiv.org/pdf/1704.00028.pdf) (2017)

In [None]:
# https://github.com/KUASWoodyLIN/TF2-WGAN/blob/master/utils/losses.py
class WGANGP(DCGAN):
  """WGAN with gradient penalty (WGAN-GP), on the DCGAN architecture.

  The critic loss adds gp_weight * E[(||grad_x D(x_hat)||_2 - 1)^2], where
  x_hat are random interpolations between real and generated images.

  NOTE(review): the optimizers are inherited (Adam, lr 1e-4, beta_1 0.5). The
  WGAN-GP paper uses Adam(1e-4, beta_1=0, beta_2=0.9) with 5 critic steps per
  generator step. It also advises against batch norm in the critic, which the
  inherited DCGAN discriminator uses.
  """

  # Penalty coefficient lambda (10 in the WGAN-GP paper).
  gp_weight = 10

  def discriminator_loss(self, real_output, fake_output, images_real, images_fake):
    """Critic loss E[D(fake)] - E[D(real)] + gradient penalty.

    Args:
      real_output, fake_output: critic outputs on the real / fake batches.
      images_real: the real image batch.
      images_fake: the generated image batch.
    """
    # Match the fake batch to the real one (the last dataset batch may be
    # smaller). No tf.squeeze: it would also drop the batch axis for a batch of one.
    images_fake = images_fake[:tf.shape(images_real)[0]]

    def _interpolate(a, b):
      # One mixing weight per sample, broadcast over the remaining axes.
      shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1)
      alpha = tf.random.uniform(shape=shape, minval=0., maxval=1.)
      inter = (alpha * a) + ((1 - alpha) * b)
      inter.set_shape(a.shape)
      return inter

    x = _interpolate(images_real, images_fake)
    # Runs inside the caller's disc_tape, so the penalty is itself differentiated
    # w.r.t. the critic's weights (second-order gradients).
    with tf.GradientTape() as tape:
      tape.watch(x)
      # NOTE(review): called without training=True, so BatchNorm layers use
      # their moving statistics here, unlike the real/fake passes.
      pred = self.D(x)
    grad = tape.gradient(pred, x)
    # Per-sample L2 norm; the epsilon keeps sqrt's derivative finite (tf.norm
    # yields NaN gradients when a per-sample gradient is exactly zero).
    sq_norm = tf.reduce_sum(tf.square(tf.reshape(grad, [tf.shape(grad)[0], -1])), axis=1)
    norm = tf.sqrt(sq_norm + 1e-12)
    gp = self.gp_weight * tf.reduce_mean((norm - 1.) ** 2)

    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)

    return fake_loss - real_loss + gp

  def generator_loss(self, fake_output):
    """Generator loss -E[D(fake)]."""
    return -tf.reduce_mean(fake_output)

  @tf.function
  def train_step(self, images):
    """One simultaneous generator and critic update; the critic loss includes the gradient penalty.

    Returns:
      (gen_loss, disc_loss) for this step.
    """
    # One latent vector per real image (the last batch may be smaller).
    z = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      images_fake = self.G(z, training=True)
      real_logits = self.D(images, training=True)
      fake_logits = self.D(images_fake, training=True)

      gen_loss = self.generator_loss(fake_logits)
      disc_loss = self.discriminator_loss(real_logits, fake_logits, images, images_fake)

    gen_grads = gen_tape.gradient(gen_loss, self.G.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, self.D.trainable_variables)

    self.gen_update.apply_gradients(zip(gen_grads, self.G.trainable_variables))
    self.disc_update.apply_gradients(zip(disc_grads, self.D.trainable_variables))

    return gen_loss, disc_loss

In [None]:
# Build the WGAN-GP and restore its latest checkpoint, if any.
wgangp = WGANGP(True)

In [None]:
# Show one sample (also saved as wgangp_sample.png) and the critic's output on it.
wgangp.sample(True)

In [None]:
# Resume WGAN-GP where the restored checkpoint left off and train through epoch 100.
# `train` writes one checkpoint per epoch, so the restored save_counter is the number
# of completed epochs (0 for a fresh model); no hard-coded start epoch to keep in sync.
wgangp.train(train_dataset, 100, start_epoch=int(wgangp.ckpt.save_counter.numpy()))

# Save Weights

In [None]:
# Archive the figures and the checkpoints and copy both archives to Drive, where the
# "Load Pretrained weights" cell restores them in a later session; local copies are removed.
!zip -qr plots_imgs_loss.zip figs
!cp 'plots_imgs_loss.zip' '/content/gdrive/My Drive/plots_imgs_loss.zip'
!rm plots_imgs_loss.zip

!zip -qr train_ckpts.zip training_checkpoints
!cp 'train_ckpts.zip' '/content/gdrive/My Drive/train_ckpts.zip'
!rm train_ckpts.zip

# Create GIF

In [None]:
# imageio is used below to assemble the per-epoch sample grids into a GIF.
!pip install -q imageio

import imageio
import glob
from google.colab import files

In [None]:
# Assemble the per-epoch sample grids of one model into an animated GIF and download it.
model = 'wgangp'
anim_file = f'{model}.gif'

filenames = sorted(glob.glob(f'figs/{model}/image*.png'))
# Fail before creating the GIF: with no frames, the final append below has nothing to use.
if not filenames:
  raise FileNotFoundError(f'no frames matching figs/{model}/image*.png - train {model} first')

with imageio.get_writer(anim_file, mode='I') as writer:
  # Keep a frame only when round(2*sqrt(i)) increases: roughly 2*sqrt(N) frames,
  # densest over the early epochs.
  last = -1
  for i, filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  # Always end on the newest frame.
  image = imageio.imread(filenames[-1])
  writer.append_data(image)

files.download(anim_file)

# WGAN &mdash; CT

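Paper: Wei et al., [*Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect*](https://arxiv.org/pdf/1803.01541.pdf) (ICLR 2018)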