<a href="https://colab.research.google.com/github/MMaggieZhou/FunModels/blob/main/GAN-Draw_Anime_Faces/model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Draw Anime Faces With Generative Adversarial Network

- The model uses DCGAN architecture per https://arxiv.org/abs/1511.06434
- Tensorflow is used as the training framework 
- The code is not yet very robust: the validation steps (marked as TODOs) are still left to be implemented.

## SET UP


### Connect to Google Drive

In [19]:
from google.colab import drive
# Mount Google Drive into the Colab VM; force_remount=True is passed so the
# mount is redone even when the cell is re-run.
drive.mount('/content/gdrive/', force_remount=True)

# Root folder on the mounted drive. Training runs write their samples and
# checkpoints under f'{CHECKPOINT_DIR}/{model}/{training_name}'.
CHECKPOINT_DIR = '/content/gdrive/MyDrive/Colab Notebooks'

Mounted at /content/gdrive/


### Download the dataset 
The dataset is collected by https://speech.ee.ntu.edu.tw/~hylee/ml/2021-spring.php

In [20]:
DATA_DIR = '.' 
# a pypi package to download large file from google drive 
!gdown --id 1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p -O "{DATA_DIR}/crypko_data.zip"
!unzip -q -o "{DATA_DIR}/crypko_data.zip" -d "{DATA_DIR}/"

Downloading...
From: https://drive.google.com/uc?id=1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p
To: /content/crypko_data.zip
100% 452M/452M [00:01<00:00, 231MB/s]


### Imports 

In [21]:
import os 
import glob

import tensorflow as tf
from tensorflow.keras import layers

import matplotlib.pyplot as plt

## Data Preprocessing 
1. Load Dataset From Directory
2. Resize the image
3. **Normalize the image: it is very important that the image data is within [-1, 1] for the neural network!**

In [4]:
def load_dataset(directory_path, batch_size, image_size):
    """Build a batched, unlabeled, unshuffled image dataset rescaled for the GAN.

    Args:
        directory_path: directory handed to `image_dataset_from_directory`.
        batch_size: number of images per batch.
        image_size: (height, width) the images are resized to.

    Returns:
        A `tf.data.Dataset` of image batches in which every value has been
        mapped through `x * 2/255 - 1` (i.e. [0, 255] -> [-1, 1]).
    """
    to_unit_range = tf.keras.layers.Rescaling(2.0/255, offset=-1)

    def _normalize(batch):
        return to_unit_range(batch)

    raw_batches = tf.keras.utils.image_dataset_from_directory(
        directory_path,
        labels=None,
        batch_size=batch_size,
        shuffle=False,
        image_size=image_size,
    )
    return raw_batches.map(_normalize)

def data_loading_test():
    """Sanity-check `load_dataset` on the first batch and display 16 images.

    Checks that the batch is (batch, height, width, 3), that all values lie
    in [-1, 1] (up to float rounding), then plots a 4x4 grid of images
    converted back to the 0-255 pixel range.

    Raises:
        AssertionError: if the shape or the value range is not as expected.
    """
    image_size = (64, 64)
    # Same directory that train() reads, so the test exercises the real data.
    image_batches = load_dataset(f'{DATA_DIR}/faces', 64, image_size)

    # A tf.data.Dataset cannot be subscripted; pull the first batch by iteration.
    first_batch = next(iter(image_batches))

    # 1. validate dimension is (batch_size, height, width, 3)
    batch_shape = tuple(first_batch.shape)
    assert len(batch_shape) == 4 and batch_shape[1:] == (*image_size, 3), batch_shape

    # 2. validate that all values are within [-1, 1] (small tolerance for float32 rounding)
    tolerance = 1e-5
    lowest = float(tf.reduce_min(first_batch))
    highest = float(tf.reduce_max(first_batch))
    assert -1.0 - tolerance <= lowest and highest <= 1.0 + tolerance, (lowest, highest)

    # 3. display 16 images, undoing the [-1, 1] normalization first
    to_pixels = tf.keras.layers.Rescaling(255/2.0, offset=127.5)
    sample_images = to_pixels(first_batch[:16])
    plt.figure(figsize=(10, 10))
    for i, image in enumerate(sample_images):
        plt.subplot(4, 4, i + 1)
        plt.imshow(image.numpy().astype("uint8"))
        plt.axis("off")

## Define Model
Two models are defined with Keras layers: a generator model and a discriminator model.

DCGAN key points: 
- The generator consists of transposed-convolution layers that, given a latent vector of small dimension, generate a 2D image of larger dimension.
- The discriminator consists of convolutional layers that take the large 2D image, convolve it down and eventually produce a binary (real/fake) output.
- Batch normalization is applied after each layer, except for the output layer of the generator and the input layer of the discriminator.
- The weights of the convolution (and transposed-convolution) layers are initialized from a random normal distribution.
- ReLU activation is used for the transposed-convolution layers and leaky ReLU for the convolution layers.





### Components for DCGAN tricks

In [5]:
class Clip(tf.keras.constraints.Constraint):
    """Weight constraint that optionally clamps weights to [-clip_val, clip_val].

    When `enable` is falsy the constraint returns the weights unchanged.
    """

    def __init__(self, clip_val, enable):
        self.clip_val = clip_val
        self.enable = enable

    def __call__(self, w):
        # Guard clause: a disabled constraint leaves the weights untouched.
        if not self.enable:
            return w
        lower_bounded = tf.math.maximum(w, -self.clip_val)
        return tf.math.minimum(self.clip_val, lower_bounded)

    def get_config(self):
        """Return the constructor arguments as a dict."""
        return dict(clip_val=self.clip_val, enable=self.enable)

def clip_test():
    """Check `Clip` in both modes on the weights (-1.0, 1.0).

    Disabled: the weights must come back unchanged.
    Enabled with clip_val=0.1: the weights must be clamped to (-0.1, 0.1).

    Raises:
        tf.errors.InvalidArgumentError: if either expectation is violated.
    """
    weight = tf.constant((-1.0, 1.0))

    untouched = Clip(clip_val=0.1, enable=False)(weight)
    tf.debugging.assert_equal(untouched, weight)

    clipped = Clip(clip_val=0.1, enable=True)(weight)
    tf.debugging.assert_near(clipped, tf.constant((-0.1, 0.1)))

# Initializers shared by the layer-building helpers below:
# - w_init: convolution / transposed-convolution kernels, N(0, 0.02)
# - gamma_initializer: batch-normalization gamma, N(1, 0.02)
# NOTE(review): both instances are unseeded and reused by several layers; some
# Keras versions warn that a reused unseeded initializer returns identical
# values on each call — confirm against the TensorFlow version in use.
w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_initializer = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

def add_dense_layer_for_noise(
    model,
    input_dim,
    output_dim,
):
    """Append the noise-projection stack to `model`: Dense -> BatchNorm -> ReLU.

    Args:
        model: Sequential model the layers are appended to.
        input_dim: length of the input noise vector.
        output_dim: number of units of the bias-free Dense layer.
    """
    projection = layers.Dense(
        units=output_dim,
        input_shape=(input_dim,),
        use_bias=False
    )
    normalization = layers.BatchNormalization(gamma_initializer=gamma_initializer)
    for layer in (projection, normalization, layers.ReLU()):
        model.add(layer)

def add_conv2d_transpose(
    model,
    num_output_filters,
    add_batch_norm=True
):
    """Append a stride-2, 5x5 transposed convolution ('same' padding) to `model`.

    With stride 2 and 'same' padding the spatial size is doubled. When
    `add_batch_norm` is true, BatchNormalization and ReLU follow the
    convolution; otherwise nothing is added after it.
    """
    upsample = layers.Conv2DTranspose(
        filters=num_output_filters,
        kernel_size=5,
        strides=2,
        padding='same',
        use_bias=False,
        kernel_initializer=w_init,
    )
    model.add(upsample)

    if not add_batch_norm:
        return
    model.add(layers.BatchNormalization(gamma_initializer=gamma_initializer))
    model.add(layers.ReLU())

def add_conv2d_for_input(model, input_dim, num_output_filters, weight_cilp):
    """Append the discriminator's input block: stride-2 5x5 Conv2D -> LeakyReLU(0.2).

    The convolution declares an input of shape (input_dim, input_dim, 3) and,
    with stride 2 and 'same' padding, halves the spatial size. No batch
    normalization is added. `weight_cilp` is used as the kernel constraint.
    """
    input_conv = layers.Conv2D(
        filters=num_output_filters,
        kernel_size=5,
        strides=2,
        padding='same',
        input_shape=[input_dim, input_dim, 3],
        kernel_initializer=w_init,
        kernel_constraint=weight_cilp,
    )
    for layer in (input_conv, layers.LeakyReLU(0.2)):
        model.add(layer)

def add_conv2d(
    model, num_output_filters, weight_clip, filter_size=5,
    use_batch_norm=True, padding='same', stride=2
):
    """Append a Conv2D layer to `model`, optionally followed by BatchNorm + LeakyReLU(0.2).

    With the defaults (stride 2, 'same' padding) the spatial size is halved.
    `weight_clip` constrains the convolution kernel and, when batch
    normalization is added, its beta and gamma as well. When
    `use_batch_norm` is false nothing is added after the convolution.
    """
    convolution = layers.Conv2D(
        filters=num_output_filters,
        kernel_size=filter_size,
        strides=stride,
        padding=padding,
        kernel_initializer=w_init,
        kernel_constraint=weight_clip,
    )
    model.add(convolution)

    if not use_batch_norm:
        return
    normalization = layers.BatchNormalization(
        gamma_initializer=gamma_initializer,
        beta_constraint=weight_clip,
        gamma_constraint=weight_clip,
    )
    model.add(normalization)
    model.add(layers.LeakyReLU(0.2))

### Model Architecture

In [6]:
def create_unconditional_generator(
    noise_dim,
    image_dim, # output image
):
    """Build the DCGAN generator mapping a noise vector to an RGB image.

    The noise is projected to an (image_dim/16, image_dim/16, image_dim*8)
    feature map, upsampled by four stride-2 transposed convolutions and
    squashed with tanh, giving an (image_dim, image_dim, 3) output.

    Args:
        noise_dim: length of the input noise vector.
        image_dim: side length of the square output image; must be a positive
            multiple of 16 because the network upsamples by 2 four times.

    Returns:
        A `tf.keras.Sequential` model.

    Raises:
        ValueError: if `image_dim` is not a positive multiple of 16.
    """
    if image_dim <= 0 or image_dim % 16 != 0:
        raise ValueError(
            f"image_dim must be a positive multiple of 16, got {image_dim}"
        )

    # Integer arithmetic throughout: Dense units and the Reshape target must be
    # ints and must describe the same number of elements.
    base_dim = image_dim // 16
    base_filters = image_dim * 8

    model = tf.keras.Sequential()
    add_dense_layer_for_noise(
        model, input_dim=noise_dim,
        output_dim=base_filters * base_dim * base_dim
    )

    # image_dim/16 * image_dim/16 * filters
    model.add(layers.Reshape((base_dim, base_dim, base_filters)))

    # Each step doubles the spatial size: image_dim/8, image_dim/4, image_dim/2.
    for filter_multiplier in (4, 2, 1):
        add_conv2d_transpose(model, image_dim * filter_multiplier)

    # Output layer: image_dim * image_dim * 3, no batch norm, tanh for [-1, 1].
    add_conv2d_transpose(model, 3, add_batch_norm=False)
    model.add(layers.Activation("tanh"))
    return model

def create_discriminator(image_dim, enable_clip=False, clip_value=0.01):
    """Build the DCGAN/WGAN discriminator producing one raw score per image.

    Four stride-2 convolutions shrink the (image_dim, image_dim, 3) input to
    (image_dim/16, image_dim/16, image_dim*8); a final 'valid' convolution
    whose kernel covers that whole map reduces it to (1, 1, 1), which is
    flattened. No activation is applied to the output.

    Args:
        image_dim: side length of the square input image; must be a positive
            multiple of 16, otherwise the final convolution would not reduce
            the feature map to a single value.
        enable_clip: whether the `Clip` weight constraint is active.
        clip_value: clipping bound handed to `Clip`.

    Returns:
        A `tf.keras.Sequential` model.

    Raises:
        ValueError: if `image_dim` is not a positive multiple of 16.
    """
    if image_dim <= 0 or image_dim % 16 != 0:
        raise ValueError(
            f"image_dim must be a positive multiple of 16, got {image_dim}"
        )

    model = tf.keras.Sequential()
    clip = Clip(clip_val=clip_value, enable=enable_clip)
    add_conv2d_for_input(model, image_dim, image_dim, clip) # (image_dim /2, image_dim /2, image_dim)

    # (image_dim/4, ..., image_dim*2) -> (image_dim/8, ..., image_dim*4) -> (image_dim/16, ..., image_dim*8)
    for filter_multiplier in (2, 4, 8):
        add_conv2d(model, image_dim * filter_multiplier, clip)

    # Kernel spans the entire remaining feature map -> (1, 1, 1)
    add_conv2d(
        model, 1, clip, filter_size=image_dim // 16,
        use_batch_norm=False, padding='valid', stride=1
    )

    model.add(layers.Flatten())

    return model

In [7]:
def validate_generator():
  """Build a 100 -> 64x64 generator and check its output shape.

  Raises:
      AssertionError: if the model does not output (None, 64, 64, 3).
  """
  generator = create_unconditional_generator(100, 64)
  assert generator.output_shape == (None, 64, 64, 3), generator.output_shape


def validate_discriminator():
  """Build a 64x64 discriminator (clipping disabled) and check its output shape.

  Raises:
      AssertionError: if the model does not output one value per image.
  """
  # enable_clip and clip_value are passed explicitly to create_discriminator.
  discriminator = create_discriminator(64, enable_clip=False, clip_value=0.01)
  assert discriminator.output_shape == (None, 1), discriminator.output_shape

def test_output_values():
    """Run 10 random noise vectors through a fresh generator and discriminator.

    Prints the mean discriminator output and the raw outputs so the scale of
    the untrained scores can be inspected by eye.
    """
    generator = create_unconditional_generator(100, 64)
    # enable_clip and clip_value are passed explicitly to create_discriminator.
    discriminator = create_discriminator(64, enable_clip=False, clip_value=0.01)
    noise = tf.random.normal([10, 100])
    fake_images = generator(noise, training=True)
    output = discriminator(fake_images, training=True)
    print(tf.reduce_mean(output))
    print(output)

## Train

### Loss Functions

In [8]:
# loss functions 
def discriminator_loss_wasserstein(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

def generator_loss_wasserstein(fake_output):
    """Wasserstein generator loss: the negated mean critic score on fakes."""
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score

# Shared binary cross-entropy for the non-Wasserstein (DCGAN) losses below.
# from_logits=True: inputs are treated as raw logits rather than probabilities.
# label_smoothing=0.1: Keras squeezes the 0/1 targets towards 0.5.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)

def discriminator_loss_entropy(real_output, fake_output):
    """Cross-entropy discriminator loss: reals labelled 1 plus fakes labelled 0."""
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return cross_entropy(real_targets, real_output) + cross_entropy(fake_targets, fake_output)

def generator_loss_entropy(fake_output):
    """Cross-entropy generator loss: fakes scored against the label 1."""
    wanted_targets = tf.ones_like(fake_output)
    return cross_entropy(wanted_targets, fake_output)

### Single Training Step

In [9]:
# training step 

# may use @tf.function for optimization, but have to deal with dynamic variable step
def train_step(
    image_batch, batch_size, noise_dim, generator, discriminator,
    generator_optimizer, discriminator_optimizer, discriminator_loss_func,
    generator_loss_func, step, generator_update_interval=5
):
    """Run one training step: always update the discriminator, sometimes the generator.

    Args:
        image_batch: batch of real images fed to the discriminator.
        batch_size: number of noise vectors (hence fake images) to sample.
        noise_dim: length of each noise vector.
        generator, discriminator: the two models being trained.
        generator_optimizer, discriminator_optimizer: their optimizers.
        discriminator_loss_func: callable(real_output, fake_output) -> loss.
        generator_loss_func: callable(fake_output) -> loss.
        step: 0-based global step counter.
        generator_update_interval: the generator is updated on every step
            where (step + 1) is a multiple of this value (default 5, the
            previously hard-coded ratio).

    Returns:
        dict with 'D_loss' and, only on steps where the generator was
        updated, 'G_loss'.

    Raises:
        ValueError: if `generator_update_interval` is smaller than 1.
    """
    if generator_update_interval < 1:
        raise ValueError(
            f"generator_update_interval must be >= 1, got {generator_update_interval}"
        )

    # TODO: validate that image_batch size is same as batch_size
    metrics = {}
    # 1. update discriminator
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(image_batch, training=True)
        fake_output = discriminator(generated_images, training=True)
        disc_loss = discriminator_loss_func(real_output, fake_output)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    metrics['D_loss'] = disc_loss

    # 2. update generator, once every `generator_update_interval` steps,
    #    using freshly sampled noise
    if (step + 1) % generator_update_interval == 0:
        noise = tf.random.normal([batch_size, noise_dim])
        with tf.GradientTape() as gen_tape:
            generated_images = generator(noise, training=True)
            fake_output = discriminator(generated_images, training=True)
            gen_loss = generator_loss_func(fake_output)
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        metrics['G_loss'] = gen_loss

    return metrics


### Progress Tracking

In [34]:
from IPython.display import clear_output

class Metrics():
    """Per-epoch history of discriminator/generator losses, plotted every 5 epochs."""

    def __init__(self, epoch, loss_d, loss_g):
        """Start from a (possibly restored) history.

        Args:
            epoch: number of epochs already completed.
            loss_d: array-like with a `.tolist()` method, one D loss per epoch.
            loss_g: array-like with a `.tolist()` method, one G loss per epoch.

        Histories whose length differs from `epoch` are padded with NaN or
        trimmed, so the x axis and both curves always have the same length
        (a mismatch would make `plot` raise).
        """
        self.x = list(range(1, epoch + 1))
        self.loss_d = self._fit_to_epochs(loss_d.tolist(), epoch)
        self.loss_g = self._fit_to_epochs(loss_g.tolist(), epoch)

    @staticmethod
    def _fit_to_epochs(history, epoch):
        """Return `history` padded with NaN / trimmed to exactly `epoch` entries."""
        missing = epoch - len(history)
        if missing > 0:
            return history + [float('nan')] * missing
        return history[:epoch]

    @staticmethod
    def _loss_or_nan(metrics, key):
        """Return metrics[key] as a float, or NaN when the key is absent."""
        if key in metrics:
            return float(metrics[key])
        return float('nan')

    def on_epoch_end(self, epoch, metrics):
        """Append this epoch's losses and redraw the curves every 5th epoch.

        A missing 'D_loss' or 'G_loss' (e.g. the generator was not updated on
        the step the metrics come from) is recorded as NaN instead of raising
        KeyError.
        """
        self.x.append(epoch)
        self.loss_d.append(self._loss_or_nan(metrics, 'D_loss'))
        self.loss_g.append(self._loss_or_nan(metrics, 'G_loss'))

        #clear_output(wait=True)
        if epoch % 5 == 0:
            f, (ax1, ax2) = plt.subplots(1, 2, sharex=True)
            ax1.plot(self.x, self.loss_d, label="D_loss")
            ax1.legend()
            ax2.plot(self.x, self.loss_g, label="G_loss")
            ax2.legend()
            plt.show()

    def get_metrics(self):
        """Return the full loss histories keyed by 'D_loss' and 'G_loss'."""
        return {
            'D_loss': self.loss_d,
            'G_loss': self.loss_g
        }

class ProgressManager:
    """Tracks training progress: progress bar, loss history, samples, checkpoints.

    Samples are written to '<output_path>/samples' and checkpoints to
    '<output_path>/checkpoints'. If a checkpoint already exists there, the
    models, optimizers, epoch counter and loss history are restored from it.
    """

    def __init__(
        self,
        output_path,
        generator,
        discriminator,
        generator_optimizer,
        discriminator_optimizer,
        num_batches,
        persist_enabled
    ):
        """
        Args:
            output_path: root directory of this training run.
            generator, discriminator: models tracked by the checkpoint.
            generator_optimizer, discriminator_optimizer: optimizers tracked
                by the checkpoint.
            num_batches: target of the per-epoch progress bar.
            persist_enabled: when false, no sample images or checkpoints are
                written.
        """
        self.generator = generator
        self.num_batches = num_batches
        self.persist_enabled = persist_enabled
        # Most recent value seen for each metric (see on_step_end).
        self.last_metrics = {}

        self.output_path = output_path
        self.samples_path = f'{output_path}/samples'
        if self.persist_enabled:
            os.makedirs(self.samples_path, exist_ok=True)
        self.checkpoint_path = f'{output_path}/checkpoints'

        # The loss histories are stored as shape-less variables so that their
        # length can grow from one save to the next.
        self.ckpt = tf.train.Checkpoint(
            epoch=tf.Variable(0),
            generator=generator,
            discriminator=discriminator,
            generator_optimizer=generator_optimizer,
            discriminator_optimizer=discriminator_optimizer,
            generator_loss=tf.Variable([], shape=tf.TensorShape(None)),
            discriminator_loss=tf.Variable([], shape=tf.TensorShape(None)),
        )

        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=20)
        if self.ckpt_manager.latest_checkpoint:
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print("Restored from {}".format(self.ckpt_manager.latest_checkpoint))

        self.last_epoch = int(self.ckpt.epoch)
        self.metrics = Metrics(
            self.last_epoch,
            self.ckpt.discriminator_loss.numpy(),
            self.ckpt.generator_loss.numpy()
        )

    def on_new_epoch(self, epoch):
        """Announce the epoch and start a fresh progress bar."""
        print("Epoch: ", epoch)
        self.pbar = tf.keras.utils.Progbar(target=self.num_batches, stateful_metrics=[])

    def on_step_end(self, batch_number, metrics):
        """Advance the progress bar and remember the latest value of every metric.

        A step that did not update the generator carries no 'G_loss'; merging
        into the previous values keeps the last known 'G_loss' available for
        on_epoch_end instead of dropping it.
        """
        self.pbar.update(batch_number, values=metrics.items(), finalize=False)
        self.last_metrics = {**self.last_metrics, **metrics}

    def on_epoch_end(self, test_noises, test_image_grid_size):
        """Close the epoch: record losses, plot samples, checkpoint every 10 epochs.

        Args:
            test_noises: noise batch fed to the generator for the sample grid.
            test_image_grid_size: n, for an n x n grid of sample images.
        """
        self.ckpt.epoch.assign_add(1) # epoch should be bumped right after the training of an epoch is finished. as it indicate the epoch reflecting its current values
        self.last_epoch += 1

        self.pbar.update(self.num_batches, values=self.last_metrics.items(), finalize=True)

        # Record this epoch's losses BEFORE copying the history into the
        # checkpoint variables; otherwise a saved checkpoint holds one entry
        # fewer than `epoch`, and the restored history no longer lines up
        # with the epoch axis. Metrics never observed are recorded as NaN.
        epoch_metrics = {'D_loss': float('nan'), 'G_loss': float('nan'), **self.last_metrics}
        self.metrics.on_epoch_end(self.last_epoch, epoch_metrics)
        metrics_history = self.metrics.get_metrics()
        self.ckpt.discriminator_loss.assign(metrics_history['D_loss'])
        self.ckpt.generator_loss.assign(metrics_history['G_loss'])

        self._save_plot(self.generator(test_noises, training=False), test_image_grid_size)
        if self.persist_enabled and self.last_epoch % 10 == 0:
            ckpt_save_path = self.ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(self.last_epoch, ckpt_save_path))

    def _save_plot(self, examples, n):
        """Show an n x n grid of generated images and, if persisting, save it as PNG."""
        examples = (examples + 1) / 2.0 # (0-1 float or 0-255 int) for RGB values
        for i in range(n * n):
            plt.subplot(n, n, i+1)
            plt.axis("off")
            plt.imshow(examples[i])
        filename = f"{self.samples_path}/generated_plot_epoch-{self.last_epoch}.png"
        if self.persist_enabled:
            plt.savefig(filename)
        plt.show()
        plt.close()

### Resumable Training Process

In [32]:
# Hyper-parameters read by initiate_objects() and train().
BATCH_SIZE = 64            # images per batch, and noise vectors sampled per step
IMAGE_DIM = 64             # images are resized to IMAGE_DIM x IMAGE_DIM
NOISE_DIM = 100            # length of the generator's input noise vector
TEST_IMAGE_GRID_SIZE = 4   # sample images are shown as an n x n grid

CLIP_VALUE_FOR_WGAN = 0.03 # clip value handed to the discriminator; clipping is enabled only for WGAN
LEARNING_RATE = 1e-4       # used for both the generator and discriminator optimizers
NUM_EPOCH = 50             # epochs run per train() call, counted from the last restored epoch


def initiate_objects(use_wgan, training_name, save_model, num_batches):
    """Create the losses, optimizers, models and progress manager for a run.

    WGAN uses the Wasserstein losses, RMSprop and weight clipping; otherwise
    the cross-entropy losses and Adam (beta_1=0.5) are used without clipping.
    Both model summaries are printed.

    Returns:
        (generator, discriminator, generator_loss_func,
         discriminator_loss_func, generator_optimizer,
         discriminator_optimizer, progress_manager)
    """
    if use_wgan:
        model_name = 'WGAN'
        loss_pair = (generator_loss_wasserstein, discriminator_loss_wasserstein)
        new_optimizer = lambda: tf.keras.optimizers.RMSprop(LEARNING_RATE)
    else:
        model_name = 'DCGAN'
        loss_pair = (generator_loss_entropy, discriminator_loss_entropy)
        new_optimizer = lambda: tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=0.5)

    generator_loss_func, discriminator_loss_func = loss_pair
    generator_optimizer = new_optimizer()
    discriminator_optimizer = new_optimizer()
    enable_weight_clip = bool(use_wgan)

    generator = create_unconditional_generator(NOISE_DIM, IMAGE_DIM)
    generator.summary()
    discriminator = create_discriminator(IMAGE_DIM, enable_weight_clip, CLIP_VALUE_FOR_WGAN)
    discriminator.summary()

    run_directory = f'{CHECKPOINT_DIR}/{model_name}/{training_name}'
    progress_manager = ProgressManager(
        run_directory,
        generator,
        discriminator,
        generator_optimizer,
        discriminator_optimizer,
        num_batches,
        save_model,
    )

    return (
        generator, discriminator, generator_loss_func, discriminator_loss_func,
        generator_optimizer, discriminator_optimizer, progress_manager
    )

def train(training_name, save_model=True, speed_up_epoch=False, use_wgan=True):
    """Train (or resume training) the GAN for NUM_EPOCH epochs.

    Args:
        training_name: name of the run; selects the output/checkpoint folder.
        save_model: whether samples and checkpoints are persisted.
        speed_up_epoch: when true, each epoch stops after 10 batches.
        use_wgan: WGAN setup when true, DCGAN setup otherwise.
    """
    # Fixed seed so the same test noises are drawn after resuming from a checkpoint.
    tf.random.set_seed(2022)
    test_noises = tf.random.normal([TEST_IMAGE_GRID_SIZE ** 2, NOISE_DIM])
    image_batches = load_dataset(f'{DATA_DIR}/faces', BATCH_SIZE, (IMAGE_DIM, IMAGE_DIM))

    # Models, losses, optimizers and the progress manager.
    training_objects = initiate_objects(
        use_wgan, training_name, save_model, int(image_batches.cardinality())
    )
    (
        generator,
        discriminator,
        generator_loss_func,
        discriminator_loss_func,
        generator_optimizer,
        discriminator_optimizer,
        progress_manager,
    ) = training_objects

    # Arguments that are the same for every training step.
    step_arguments = dict(
        batch_size=BATCH_SIZE,
        noise_dim=NOISE_DIM,
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        discriminator_loss_func=discriminator_loss_func,
        generator_loss_func=generator_loss_func,
    )

    # Epoch numbers continue from the last restored epoch; `step` counts
    # steps across all epochs of this call.
    first_epoch = progress_manager.last_epoch + 1
    step = 0
    for epoch in range(first_epoch, first_epoch + NUM_EPOCH):
        progress_manager.on_new_epoch(epoch)
        for i, image_batch in enumerate(image_batches):
            if speed_up_epoch and i == 10:
                break
            step_metrics = train_step(image_batch=image_batch, step=step, **step_arguments)
            progress_manager.on_step_end(i, step_metrics)
            step += 1
        progress_manager.on_epoch_end(test_noises, TEST_IMAGE_GRID_SIZE)

## Run the Process

In [None]:
# Earlier runs, kept for reference (training_name selects the output folder):
#train(training_name='11_05_2022_test1', save_model=True, speed_up_epoch=True,)
#train(training_name='11_05_2022_test2', save_model=True, speed_up_epoch=True,)
#train(training_name='11_05_2022', save_model=True, speed_up_epoch=False,)
#train(training_name='11_07_2022_test1', save_model=True, speed_up_epoch=True, use_wgan=False)
# Current run: full-length epochs with the DCGAN (cross-entropy) setup.
train(training_name='11_07_2022', save_model=True, speed_up_epoch=False, use_wgan=False)

Found 71314 files belonging to 1 classes.
Model: "sequential_16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_8 (Dense)             (None, 8192)              819200    
                                                                 
 batch_normalization_55 (Bat  (None, 8192)             32768     
 chNormalization)                                                
                                                                 
 re_lu_32 (ReLU)             (None, 8192)              0         
                                                                 
 reshape_8 (Reshape)         (None, 4, 4, 512)         0         
                                                                 
 conv2d_transpose_32 (Conv2D  (None, 8, 8, 256)        3276800   
 Transpose)                                                      
                                                                 
 batch_norm

In [None]:
# Scratch check: a variable created with an unknown shape (TensorShape(None))
# can later be assigned a value of a different length. ProgressManager uses
# the same construction for its loss-history checkpoint variables.
var=tf.Variable([],shape=tf.TensorShape(None))
var.assign([1,2])

<tf.Variable 'UnreadVariable' shape=<unknown> dtype=float32, numpy=array([1., 2.], dtype=float32)>