<a href="https://colab.research.google.com/github/HSE-LAMBDA/MLDM-2021/blob/master/11-gans/MLDM_2021_seminar11_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN example

Today we'll try to generate people's faces.
As always, let's start with the imports:

In [None]:
import numpy as np
import tensorflow as tf
print(tf.__version__)
import tensorflow_datasets as tfds
from tqdm import tqdm

from PIL import Image

from tensorflow import keras
from tensorflow.keras import layers

from IPython.display import clear_output

And now we'll get the dataset:

In [None]:
# Instantiate the TFDS builder for the LFW faces dataset.
lfw = tfds.image_classification.LFW()
# Fetch the data (if not already present) and prepare it for reading.
lfw.download_and_prepare()
# The cells below index the result by split name (`ds['train']`).
ds = lfw.as_dataset()

Original images are a bit too large for this exercise - we want to keep it lightweight (although feel free to try different image sizes for the homework if you want).

In [None]:
def get_img(x):
  """Return the central crop of a dataset example.

  Takes the 'image' entry of `x` and trims 80 pixels from each side of its
  first two axes; any remaining axes are left untouched.
  """
  image = x['image']
  cropped = image[80:-80, 80:-80]
  return cropped

# Crop every training image with `get_img`, shrink it to 36x36 via PIL and
# stack the results into a single NumPy array (tqdm shows the progress).
data = np.array([
  np.array(
      Image.fromarray(face.numpy()).resize((36, 36))
  )
  for face in tqdm(ds['train'].map(get_img))
])

Let's have a look at the result:

In [None]:
# Display the array shape; presumably (num_images, 36, 36, 3) given the
# 36x36 resize and the 3-channel reshape used for plotting — TODO confirm.
data.shape

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Tile the first 25 images into a 5x5 mosaic: view them as a (row, col, h, w, c)
# grid, swap the `col` and `h` axes so each mosaic row's pixels are contiguous,
# then merge (row, h) and (col, w) into a single 180x180 RGB picture.
plt.imshow(data[:25].reshape(5, 5, 36, 36, 3).transpose((0, 2, 1, 3, 4)).reshape(5 * 36, 5 * 36, 3));

Ok, now let's build our GAN!

First we'll preprocess the data:

In [None]:
# Convert the pixels to float32 and rescale them by 1/255, then print the
# value range and dtype as a sanity check.
X_train = np.true_divide(data.astype(np.float32), 255)
print(X_train.min(), X_train.max(), X_train.dtype)

Defining the architecture. 

In [None]:
ll = tf.keras.layers

# Size of the noise vector fed to the generator.
LATENT_DIM = 32

# Generator: latent vector -> two hidden ReLU layers -> sigmoid pixels,
# reshaped into a 36x36 RGB image.
generator = tf.keras.Sequential()
generator.add(ll.Dense(32, input_shape=(LATENT_DIM,), activation='relu'))
generator.add(ll.Dense(64, activation='relu'))
generator.add(ll.Dense(36 * 36 * 3, activation='sigmoid'))
generator.add(ll.Reshape((36, 36, 3)))

# Discriminator: flattened image -> two hidden ReLU layers with dropout ->
# a single output unit with no activation.
discriminator = tf.keras.Sequential()
discriminator.add(ll.Reshape((36 * 36 * 3,), input_shape=(36, 36, 3)))
discriminator.add(ll.Dense(64, activation='relu'))
discriminator.add(ll.Dropout(0.1))
discriminator.add(ll.Dense(32, activation='relu'))
discriminator.add(ll.Dropout(0.1))
discriminator.add(ll.Dense(1))

generator.summary()
discriminator.summary()

Here we'll define our loss functions and optimization steps. Implement all the parts below. (3 points)

In [None]:
def gen_images(num):
  """Generate `num` images from freshly sampled standard-normal noise."""
  latent = tf.random.normal(shape=(num, LATENT_DIM))
  return generator(latent)

# @tf.function decorator below compiles the function
# it decorates into a static graph. This improves the performance
# but there are some pitfalls one should be aware of when using it,
# check out https://www.tensorflow.org/guide/function
# for more details
@tf.function
def forward(batch):
  """Assemble a mixed real/fake batch with its targets and return the loss.

  `batch` provides the real images; an equally sized batch of fake images is
  drawn with `gen_images`. Real samples get the label 1 and fake samples the
  label 0. The loss expression itself is left as an exercise (`<YOUR CODE>`).
  """
  real = batch
  fake = gen_images(len(batch))

  # One label per sample, as a column vector.
  shape = (len(batch), 1)
  labels_real = tf.ones (shape=shape)
  ### Optional regularization technique:
  ### set small amount of the 'real' labels
  ### to being 'fake':
  # labels_real = tf.cast(
  #     tf.random.uniform(shape=shape) > 0.1,
  #     'float32'
  # )
  labels_fake = tf.zeros(shape=shape)

  # Real samples first, fakes second; `y` follows the same order.
  X = tf.concat([real, fake], axis=0)
  y = tf.concat([labels_real, labels_fake], axis=0)

  # Note: it's important to call the discriminator with `training=True`
  #       to make use of the Dropout layers.
  loss = <YOUR CODE>
  return loss

# Separate RMSprop optimizers (default settings) for the discriminator and
# the generator; the training loop decays both learning rates every epoch.
opt_d = tf.optimizers.RMSprop()
opt_g = tf.optimizers.RMSprop()

@tf.function
def disc_step(batch):
  """Run one discriminator update on `batch` and return its loss.

  The loss, gradient and `apply_gradients` arguments are exercise
  placeholders; the update is applied through `opt_d`.
  """
  # The tape records the loss computation so it can be differentiated below.
  with tf.GradientTape() as t:
    d_loss = <YOUR CODE>
  grads = <YOUR CODE>
  opt_d.apply_gradients(<YOUR CODE>)
  return d_loss

@tf.function
def gen_step(batch):
  """Run one generator update for `batch` and return its loss.

  The loss, gradient and `apply_gradients` arguments are exercise
  placeholders; the update is applied through `opt_g`.
  """
  # The tape records the loss computation so it can be differentiated below.
  with tf.GradientTape() as t:
    g_loss = <YOUR CODE>
  grads = <YOUR CODE>
  opt_g.apply_gradients(<YOUR CODE>)
  return g_loss

Finally, let's write our training loop:

In [None]:
from IPython.display import clear_output
from tqdm import trange

In [None]:
BATCH_SIZE = 256

N_EPOCHS = 25
# Number of consecutive discriminator updates before each generator update.
NUM_DISC_STEPS = 5

# Discriminator updates done since the last generator update. It is not reset
# at epoch boundaries, so the update schedule carries over between epochs.
i_disc_step = 0
# Per-epoch mean losses, plotted after every epoch.
losses_gen = []
losses_disc = []
for i_ep in range(N_EPOCHS):
  # Fresh random order of the samples for every epoch.
  shuffle_ids = np.random.permutation(len(X_train))
  epoch_loss_gen = []
  epoch_loss_disc = []
  for i_img in trange(0, len(X_train), BATCH_SIZE):
    # Slice the shuffled ids first so that only one batch is gathered;
    # `X_train[shuffle_ids][...]` would copy the whole dataset per batch.
    batch = X_train[shuffle_ids[i_img:i_img + BATCH_SIZE]]

    if i_disc_step < NUM_DISC_STEPS:
      # discriminator update
      i_disc_step += 1
      epoch_loss_disc.append(disc_step(batch).numpy())
    else:
      # generator update
      i_disc_step = 0
      epoch_loss_gen.append(gen_step(batch).numpy())

  losses_gen.append(np.mean(epoch_loss_gen))
  losses_disc.append(np.mean(epoch_loss_disc))

  # Decay both learning rates by 1% per epoch.
  opt_d.learning_rate.assign(opt_d.learning_rate * 0.99)
  opt_g.learning_rate.assign(opt_g.learning_rate * 0.99)

  # Left panel: 5x5 mosaic of fresh samples; right panel: loss history.
  imgs = (gen_images(25).numpy() * 255).astype('uint8')
  clear_output(wait=True)
  plt.figure(figsize=(12, 7))
  plt.subplot(1, 2, 1)
  plt.imshow(imgs.reshape((5, 5, 36, 36, 3)).transpose(0, 2, 1, 3, 4).reshape(36 * 5, 36 * 5, 3))
  plt.subplot(1, 2, 2)
  plt.plot(losses_gen, label='generator')
  plt.plot(losses_disc, label='discriminator')
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.legend()
  plt.show()
  print("Done with epoch #", i_ep)

# Interpolations in the latent space

In [None]:
def plot100(imgs):
  """Show 100 RGB images of size 36x36 as a 10x10 mosaic without axes."""
  # (row, col, h, w, c) -> (row, h, col, w, c) -> one 360x360 RGB picture.
  tiles = np.array(imgs).reshape((10, 10, 36, 36, 3))
  mosaic = tiles.transpose(0, 2, 1, 3, 4).reshape((360, 360, 3))
  plt.imshow(mosaic)
  plt.axis('off')

Fix two noise values, interpolate between them, and generate objects using the interpolated values as input for the generator. This is nearly the same as what we did for the AE. (2 points)

In [None]:
# Number of interpolation sequences; presumably one per row of the plotted
# grid — TODO confirm against the shapes chosen below.
num = 10

# Fix some noise values with corresponding shapes:
representation_1 = <YOUR CODE>
representation_2 = <YOUR CODE>

# Now create a matrix of linear interpolations between
# the two representations:
# `w` holds 10 evenly spaced weights from 0 to 1, shaped (1, 10, 1) so that it
# broadcasts against the representations after a new axis 1 is inserted.
# NOTE(review): the step count is a literal 10 rather than `num`; `plot100`
# also expects exactly 10 x 10 images.
w = np.linspace(0, 1, 10)[None,:,None]
representation_mixed = representation_1[:,None] * (1 - w) + representation_2[:,None] * w

# Then generate the images from the mixed representations:
mixed_imgs = <YOUR CODE>

plt.figure(figsize=(6, 6), dpi=100)
plot100(mixed_imgs)

# Conditional GAN

A simple GAN that we built didn't let us control any parameters (e.g. hair color, gender) of the samples we generated. 
Let's pick up a classical dataset that we have already worked with - MNIST. If we created and fit a GAN the same way as before, it wouldn't let us choose the class of the digits we generate. To be able to control what we generate, we need to somehow plug [*conditions*](https://arxiv.org/abs/1411.1784) into our model.

In this example, we'll build a Conditional GAN that can generate MNIST handwritten digits conditioned on a given class. We'll create a simple fully-connected (FC) model just to cover the basics, but you may try to replace it with a CNN model and play around with some parameters like the number of filters/layers, learning rates, etc., or check out the [Keras Conditional GAN example](https://keras.io/examples/generative/conditional_gan/) that is used as a reference for the current notebook.

In [None]:
# Image geometry: square, single-channel pictures.
image_size, num_channels = 28, 1
# Number of distinct class labels used for conditioning.
num_classes = 10
# Length of the noise vector fed to the generator.
latent_dim = 128
# Number of samples per training batch.
batch_size = 64

Let's load and preprocess the MNIST dataset.

In [None]:
# We'll use all the available examples from both the training and test
# sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale the pixel values to [0, 1] range, add a channel dimension to
# the images, and one-hot encode the labels. The shapes come from the
# constants declared in the previous cell instead of repeated literals.
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, image_size, image_size, num_channels))
all_labels = keras.utils.to_categorical(all_labels, num_classes)

# Create tf.data.Dataset yielding (images, one-hot labels) batches.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

print(f"Shape of training images: {all_digits.shape}")
print(f"Shape of training labels: {all_labels.shape}")

In the common unconditional GAN setup, we start by sampling noise (of some fixed dimension) from a normal distribution. In the conditional setting, we also need to account for the class labels or other conditions that we have. Try to figure out the proper shapes for the parameters below. (1 point)

In [None]:
# Width of the generator's input vector. `ConditionalGAN` feeds the generator
# the latent noise concatenated with the one-hot class labels.
generator_in_channels = <YOUR CODE>
# Size of the conditioning part of the discriminator input; the discriminator
# cell below adds it to `image_size * image_size`.
discriminator_in_channels = <YOUR CODE>
print(generator_in_channels, discriminator_in_channels)

In [None]:
# Create the discriminator.
# It scores a flat vector made of the flattened image followed by the
# conditioning values, which is what `ConditionalGAN.train_step` feeds it.
# Assumes `discriminator_in_channels` is the integer number of conditioning
# values appended to each flattened image - confirm against the value chosen
# in the cell above.
discriminator_input_dim = image_size * image_size + discriminator_in_channels
discriminator = keras.Sequential(
    [
        # The input shape must be a tuple: `(x)` without a trailing comma is
        # just `x`, and it has to match the width used by the Reshape below.
        keras.layers.InputLayer((discriminator_input_dim,)),
        layers.Reshape((discriminator_input_dim,)),
        layers.Dense(32, activation="elu"),
        layers.Dense(32, activation="elu"),
        layers.Dense(16, activation="elu"),
        # Single unbounded output; the loss is configured with from_logits=True.
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator.
# It maps a `generator_in_channels`-wide vector to an image of shape
# (image_size, image_size, num_channels); the last Dense layer has no
# activation, so outputs are unbounded.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        layers.Dense(64, activation="elu"),
        layers.Dense(64, activation="elu"),
        layers.Dense(64, activation="elu"),
        layers.Dense(image_size * image_size),
        layers.Reshape((image_size, image_size, num_channels)),
    ],
    name="generator",
)

Let's define the class for our model.

In [None]:
class ConditionalGAN(keras.Model):
    """GAN whose generator and discriminator are both conditioned on labels.

    The generator receives latent noise concatenated with one-hot labels; the
    discriminator receives flattened images concatenated with the same labels.

    Args:
        discriminator: model scoring `[flattened image, one-hot label]` rows.
        generator: model mapping `[noise, one-hot label]` rows to images.
        latent_dim: length of the noise vector sampled for each image.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Running means of the two losses, reported by `train_step`.
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them automatically
        # (see the Keras guide on customizing `fit()`).
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the loss shared by both update steps."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def _noise_with_labels(self, batch_size, one_hot_labels):
        """Sample `batch_size` latent vectors and append the labels to them."""
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        return tf.concat([random_latent_vectors, one_hot_labels], axis=-1)

    def _flat_with_labels(self, images, one_hot_labels):
        """Flatten `images` to rows and append the labels to every row."""
        flat_images = tf.reshape(images, (-1, image_size * image_size))
        return tf.concat([flat_images, one_hot_labels], -1)

    def generate_images(self, one_hot_labels):
        """Generate one image per row of `one_hot_labels`."""
        # Use the instance's latent size rather than a module-level global so
        # the sampled noise always matches the value passed to the constructor.
        random_vector_labels = self._noise_with_labels(
            one_hot_labels.shape[0], one_hot_labels
        )
        return self.generator(random_vector_labels)

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: `(real_images, one_hot_labels)` batch.

        Returns:
            Dict with the running means of the generator and discriminator
            losses under the keys "g_loss" and "d_loss".
        """
        # Unpack the data.
        real_images, one_hot_labels = data
        batch_size = tf.shape(real_images)[0]

        # Decode the noise (guided by labels) to fake images. This happens
        # outside the tape: only the discriminator is updated in this half.
        generated_images = self.generator(
            self._noise_with_labels(batch_size, one_hot_labels)
        )

        # Attach the labels to both fake and real images and stack them,
        # fakes first.
        combined_images = tf.concat(
            [
                self._flat_with_labels(generated_images, one_hot_labels),
                self._flat_with_labels(real_images, one_hot_labels),
            ],
            axis=0,
        )

        # Targets in the same order: fake images get 1, real images get 0.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample fresh random points in the latent space for the generator.
        random_vector_labels = self._noise_with_labels(batch_size, one_hot_labels)

        # Targets that say "all real images" (real is encoded as 0 above).
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = self._flat_with_labels(
                fake_images, one_hot_labels
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }


Function that helps us to plot the samples during the fitting procedure.

In [None]:
def plot_mn(images, m=10, n=10, shuffle=False):
    """Show the first `m * n` images as a mosaic of `m` rows and `n` columns.

    Args:
        images: array of shape (count, height, width, channels).
        m: number of mosaic rows.
        n: number of mosaic columns.
        shuffle: if true, draw the images in a random order instead of taking
            them from the start of `images`.

    Raises:
        ValueError: if fewer than `m * n` images are available.
    """
    if shuffle:
        images = images[np.random.permutation(len(images))[:m * n]]
    images = images[:m * n]
    if len(images) != m * n:
        raise ValueError(
            f"plot_mn needs at least {m * n} images, got {len(images)}"
        )
    _, h, w, c = images.shape
    # (row, col, h, w, c) -> (row, h, col, w, c) -> one (m*h, n*w, c) picture.
    mosaic = images.reshape(m, n, h, w, c)
    mosaic = mosaic.transpose(0, 2, 1, 3, 4).reshape(m * h, n * w, c)
    if c == 1:
        # Drop the channel axis so single-channel images are drawn as 2-D data.
        mosaic = mosaic[..., 0]
    plt.imshow(mosaic)

Let's create a `PlotImgsCallback` that is going to be triggered at the end of the epoch to plot samples using `plot_mn`.

In [None]:
class PlotImgsCallback(keras.callbacks.Callback):
    """Plot a grid of generated samples, one row per class, during training.

    Args:
        period_in_epochs: plot after every this many epochs.
        clear: if true, clear the cell output at the end of every epoch
            (including epochs in which nothing is plotted).
        num_classes: number of classes, i.e. rows of the grid.
        samples_per_class: images generated per class, i.e. grid columns.
    """

    def __init__(self, period_in_epochs=1, clear=True,
                 num_classes=10, samples_per_class=10):
        super().__init__()
        self.period_in_epochs = period_in_epochs
        self.num_classes = num_classes
        self.samples_per_class = samples_per_class
        # Class ids laid out row by row: 0,0,...,0,1,1,...,1,...
        labels = np.repeat(np.arange(num_classes), samples_per_class)
        self.one_hot_labels = keras.utils.to_categorical(labels, num_classes)
        self.clear = clear

    def on_epoch_end(self, epoch, logs=None):
        """Clear the output if requested and plot on every `period`-th epoch."""
        if self.clear:
            clear_output(wait=True)
        if (epoch + 1) % self.period_in_epochs != 0:
            return
        # The generator output is unbounded, so clip it to [0, 1] for display.
        samples = self.model.generate_images(self.one_hot_labels).numpy().clip(0, 1)
        plt.figure(figsize=(7, 7))
        plot_mn(samples, m=self.num_classes, n=self.samples_per_class)
        plt.title('Epoch = ' + str(epoch + 1))
        plt.axis("off")
        plt.show()

In [None]:
# Print the layer-by-layer summary of the conditional discriminator.
discriminator.summary()

In [None]:
# Print the layer-by-layer summary of the conditional generator.
generator.summary()

In [None]:
# Wrap the two networks into the trainable conditional GAN.
cond_gan = ConditionalGAN(discriminator, generator, latent_dim)

# There is no NUM_DISC_STEPS here: the discriminator and the generator are
# balanced through different learning rates of their optimizers instead.
cond_gan.compile(
    keras.optimizers.RMSprop(learning_rate=1e-3),
    keras.optimizers.RMSprop(learning_rate=5e-4),
    keras.losses.BinaryCrossentropy(from_logits=True),
)

# Train for 10 epochs, plotting samples after each one without clearing
# the previous output.
cond_gan.fit(
    dataset,
    epochs=10,
    callbacks=[PlotImgsCallback(clear=False)],
)

In [None]:
# Per-epoch discriminator and generator losses recorded by `fit`.
plt.plot(cond_gan.history.history['d_loss'], label='d_loss')
plt.plot(cond_gan.history.history['g_loss'], label='g_loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')
plt.show()