<a href="https://colab.research.google.com/github/GarlandZhang/gans_in_action_notes/blob/master/cgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Concatenate, Dense, Embedding, Flatten, Input, Multiply, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam

Using TensorFlow backend.


In [2]:
def build_generator(z_dim):
  model = Sequential([
                      Dense(256 * 7 * 7, input_dim=z_dim),
                      Reshape((7, 7, 256)),
                      Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'),
                      BatchNormalization(),
                      LeakyReLU(alpha=0.01),
                      Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'),
                      BatchNormalization(),
                      LeakyReLU(alpha=0.01),
                      Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'),
                      Activation('tanh')
  ])
  return model

def build_cgan_generator(z_dim):
  z = Input(shape=(z_dim, ))
  label = Input(shape=(1, ), dtype='int32')
  label_embedding = Embedding(num_classes, z_dim, input_length = 1)(label) # " Label embedding: turns labels into dense vectors of size z_dim; produces 3D tensor with shape (batch_size, 1, z_dim)" 
  label_embedding = Flatten()(label_embedding) # Flattens 3D tensor into 2D tensor (batch_size, z_dim)
  joined_representation = Multiply()([z, label_embedding])
  generator = build_generator(z_dim)
  conditioned_img = generator(joined_representation) # generates image for given label
  return Model([z, label], conditioned_img)

In [3]:
def build_discriminator(img_shape):
  combined_shape =  (img_shape[0], img_shape[1], img_shape[2] + 1)
  model = Sequential([
                      Conv2D(64, kernel_size=3, strides=2, input_shape=combined_shape, padding='same'), # 28 x 28 x 2 => 14 x 14 x 64
                      LeakyReLU(alpha=0.01),
                      Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape, padding='same'), # 14 x 14 x 64 => 7 x 7 x 64. we have input_shape=img_shape because...?
                      BatchNormalization(),
                      LeakyReLU(alpha=0.01),
                      Conv2D(128, kernel_size=3, strides=2, input_shape=img_shape, padding='same'),
                      BatchNormalization(),
                      LeakyReLU(alpha=0.01),
                      Flatten(),
                      Dense(1, activation='sigmoid')
  ])
  return model

def build_cgan_discriminator(img_shape):
  img = Input(shape=img_shape)
  label = Input(shape=(1, ), dtype='int32')
  label_embedding = Embedding(num_classes, np.prod(img_shape), input_length=1)(label)
  label_embedding = Flatten()(label_embedding)
  label_embedding = Reshape(img_shape)(label_embedding)
  concatenated = Concatenate(axis=-1)([img, label_embedding])
  discriminator = build_discriminator(img_shape)
  classification = discriminator(concatenated)
  return Model([img, label], classification)

In [4]:
def build_cgan(generator, discriminator):
  z = Input(shape=(z_dim, ))
  label = Input(shape=(1, ))
  img = generator([z, label])
  classification = discriminator([img, label])
  return Model([z, label], classification)

In [5]:
accuracies = []
losses = []

def train(iterations, batch_size, sample_interval):
  (x_train, y_train), _ = mnist.load_data()
  x_train = x_train / 127.5 - 1.
  x_train = np.expand_dims(x_train, axis=3)
  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs, labels = x_train[idx], y_train[idx]

    z = np.random.normal(0, 1, (batch_size, z_dim))

    gen_imgs = generator.predict([z, labels])

    disc_real_loss = discriminator.train_on_batch([imgs, labels], real)
    disc_fake_loss = discriminator.train_on_batch([gen_imgs, labels], fake)

    disc_loss = 0.5 * np.add(disc_real_loss, disc_fake_loss)

    z = np.random.normal(0, 1, (batch_size, z_dim))

    labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)

    gen_loss = cgan.train_on_batch([z, labels], real)

    if (iteration + 1) % sample_interval == 0:
      print(f'{iteration + 1} [D loss: {disc_loss[0]}, acc.: {100 * disc_loss[1]}] [G loss: {gen_loss}]')
      losses.append((disc_loss[0], gen_loss))
      accuracies.append(100 * disc_loss[1])
      sample_images()

def sample_images(image_grid_rows=2, image_grid_columns=5):
  z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
  labels = np.arange(0, 10).reshape(-1, 1)
  gen_imgs = generator.predict([z, labels])
  gen_imgs = 0.5 * gen_imgs + 0.5
  fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(10, 4), sharey=True, sharex=True)

  count = 0

  for i in range(image_grid_rows):
    for j in range(image_grid_columns):
      axs[i, j].imshow(gen_imgs[count, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      axs[i, j].set_title(f'Digit: {labels[count]}')
      count += 1

In [6]:
img_rows = 28
img_cols = 28
img_channels = 1

img_shape = (img_rows, img_cols, img_channels)

z_dim = 100

num_classes = 10

In [7]:
discriminator = build_cgan_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

generator = build_cgan_generator(z_dim)

discriminator.trainable = False
cgan = build_cgan(generator, discriminator)
cgan.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

In [None]:
iterations = 12000
batch_size = 32
sample_interval = 1000

train(iterations, batch_size, sample_interval)

  'Discrepancy between trainable weights and collected trainable'
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


1000 [D loss: 0.002916002180427313, acc.: 100.0] [G loss: [21.705742, 0.0]]
2000 [D loss: 0.0022667807061225176, acc.: 100.0] [G loss: [6.7411017, 0.0]]
3000 [D loss: 0.1522151529788971, acc.: 93.75] [G loss: [3.205821, 0.0625]]
4000 [D loss: 0.22683537006378174, acc.: 92.1875] [G loss: [0.6782111, 0.5625]]
