<a href="https://colab.research.google.com/github/NicoEssi/GAN_Synthetic_Medical_Image_Augmentation/blob/master/Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Dependencies
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import PIL

# Data Imports



# Project Description
Generative adversarial networks (GANs) are a class of machine learning systems that learn to generate new data with the same statistical properties as a given data set. Our work aims to show the effectiveness of generated data as an image augmentation technique for improving the predictive performance of a convolutional neural network. For comparison, we report results with and without the synthetic image augmentation.

![alt text](https://i.imgur.com/7ouwAEA.png)

# 1. Data Exploration

We start by inspecting the data to build an intuition for which visual features might be particularly useful.

# 2. Data Preparation

Now we prepare our data. We create small batches of 200 images belonging to three classes. The classes are imbalanced, with between 50 and 100 images in each. Because of this imbalance, we will report our findings using sensitivity and specificity.

!["recall"](https://i.imgur.com/1tWumNb.png)
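
As a minimal sketch of how these two metrics could be computed per class (one-vs-rest), assuming integer class labels 0 to 2. The function name `sensitivity_specificity` and the toy labels are made up for illustration and are not part of the original notebook.

```python
import numpy as np

def sensitivity_specificity(y_true, y_pred, num_classes=3):
    """Per-class sensitivity and specificity, treating each class one-vs-rest."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    sens, spec = [], []
    for c in range(num_classes):
        tp = np.sum((y_true == c) & (y_pred == c))  # true positives
        fn = np.sum((y_true == c) & (y_pred != c))  # false negatives
        tn = np.sum((y_true != c) & (y_pred != c))  # true negatives
        fp = np.sum((y_true != c) & (y_pred == c))  # false positives
        sens.append(tp / (tp + fn) if (tp + fn) else float("nan"))
        spec.append(tn / (tn + fp) if (tn + fp) else float("nan"))
    return np.array(sens), np.array(spec)

# Toy labels, invented purely for illustration
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
sens, spec = sensitivity_specificity(y_true, y_pred)
print(sens)  # sensitivity (recall) per class
print(spec)  # specificity per class
```

Reporting both numbers per class keeps a majority class from hiding poor performance on a smaller one.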

We will also look into using classic augmentations to enhance our dataset, including rotations, flips, translations, and scaling, and note how our metrics change when using them.
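
A rough NumPy sketch of three of these classic augmentations (flip, rotation, translation) is shown here under simplifying assumptions: images are square `H x W x C` arrays, rotations are restricted to multiples of 90 degrees, and translations wrap around the image border. Scaling is left out. The helper name `classic_augment` is hypothetical.

```python
import numpy as np

def classic_augment(image, rng):
    """Randomly flip, rotate (by multiples of 90 degrees), and shift an HxWxC image."""
    out = image
    if rng.random() < 0.5:
        out = np.flip(out, axis=1)              # horizontal flip
    k = int(rng.integers(0, 4))
    out = np.rot90(out, k=k, axes=(0, 1))       # rotate by k * 90 degrees
    dy, dx = rng.integers(-2, 3, size=2)        # shift by up to 2 pixels
    out = np.roll(out, shift=(int(dy), int(dx)), axis=(0, 1))  # wrap-around shift
    return out

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))                 # placeholder image
augmented = classic_augment(image, rng)
print(augmented.shape)
```

Every operation here only rearranges pixels, so the label of the image is unchanged; whether each transform is appropriate for a given medical imaging modality is a separate question.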

Later on, we will create more data synthetically, using a DCGAN that we build and train ourselves, to see whether our model improves and, if so, by how much.

# 3. CNN Construction
We will implement a fairly simple convolutional neural network, as our intention is to show that the model will overfit the data and then how the synthetic data augmentation helps us improve it. Below is a visualization of how our network is constructed.

![Our CNN](https://i.imgur.com/30fh1DJ.png)

In [0]:
# PLEASE VERIFY THAT THE MODEL IS BUILT CORRECTLY

model = tf.keras.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(256, activation = 'relu'))
model.add(layers.Dense(3, activation = 'softmax'))
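
To help check the model, the spatial sizes can be traced by hand. The sketch below assumes the Keras defaults used above: unpadded (`'valid'`) 3x3 convolutions with stride 1, and 2x2 max pooling with stride 2. The helper `conv_out` is a hypothetical name.

```python
def conv_out(size, kernel, stride=1):
    """Output size of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1

size = 32                 # input images are 32 x 32 x 3
trace = [size]
for _ in range(3):        # three Conv2D + MaxPooling2D pairs
    size = conv_out(size, kernel=3)             # Conv2D 3x3, stride 1
    trace.append(size)
    size = conv_out(size, kernel=2, stride=2)   # MaxPooling2D 2x2, stride 2
    trace.append(size)

flat_features = size * size * 128               # the last conv layer has 128 filters
print(trace)           # [32, 30, 15, 13, 6, 4, 2]
print(flat_features)   # 512
```

Under these assumptions the `Flatten` layer feeds 512 features into the 256-unit dense layer.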

# 4. Training and Results

Now we train our simple CNN with our prepared data, and note down our results for both the non-augmented and augmented dataset.

# 5. DCGAN Construction

While the paper itself also implemented the AC-GAN, it is significantly more complex to implement than the DCGAN, which is the most straightforward model to understand. DCGAN improves upon the original GAN design by incorporating upsampling convolutional layers, along with other details such as Batch Normalization and Leaky ReLU activations with the slope set to 0.2.

The paper also reports significantly better results using the synthetic data from the DCGAN than from the AC-GAN. For these reasons, we exclusively build the DCGAN for this particular project.

Note that the DCGAN is trained on one class at a time, so the generated output will also consist of only that class.
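
One way to arrange this, sketched under the assumption that the images and their integer labels are held in NumPy arrays, is to split the data by class before training. The helper `split_by_class` and the placeholder arrays are hypothetical.

```python
import numpy as np

def split_by_class(images, labels):
    """Return {class_id: images of that class} so each GAN sees a single class."""
    labels = np.asarray(labels)
    return {int(c): images[labels == c] for c in np.unique(labels)}

# Placeholder data standing in for the real images and labels
images = np.zeros((6, 28, 28, 1), dtype=np.float32)
labels = np.array([0, 1, 2, 1, 0, 1])
per_class = split_by_class(images, labels)
for class_id, class_images in per_class.items():
    print(class_id, class_images.shape)
```

Each entry of `per_class` could then be used to train a separate DCGAN for that class.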

## 5.1. Generator Construction

![alt text](https://i.imgur.com/01krVDp.png)

In [0]:
# Note that 'None' in the "assert" is our batch size ;)
def dcgenerator_model():

    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

## 5.2. Discriminator Construction

In [0]:
def discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

## 5.3. Loss Functions

In [0]:
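# Binary cross-entropy on raw logits: the discriminator's final Dense(1) has no activation
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Loss function for our generator: low when the discriminator labels fakes as real (1)
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Loss function for our discriminator: real images should score 1, fakes 0
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss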
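
To make the behaviour of these losses concrete, here is a NumPy re-implementation, assuming `cross_entropy` is binary cross-entropy applied to raw logits (the discriminator ends in `Dense(1)` with no activation). The `_np` function names are hypothetical and exist only for this illustration.

```python
import numpy as np

def bce_from_logits(labels, logits):
    """Mean sigmoid cross-entropy, computed stably from raw logits."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    per_example = (np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return per_example.mean()

def generator_loss_np(fake_logits):
    # The generator wants its fakes to be labelled real (1)
    return bce_from_logits(np.ones_like(fake_logits), fake_logits)

def discriminator_loss_np(real_logits, fake_logits):
    # The discriminator wants real -> 1 and fake -> 0
    real_loss = bce_from_logits(np.ones_like(real_logits), real_logits)
    fake_loss = bce_from_logits(np.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

undecided = np.zeros(4)                   # logit 0 means probability 0.5
g = generator_loss_np(undecided)
d = discriminator_loss_np(undecided, undecided)
confident = discriminator_loss_np(np.full(4, 10.0), np.full(4, -10.0))
print(g, d, confident)
```

With an undecided discriminator, the generator loss is ln 2 and the discriminator loss is 2 ln 2. A discriminator that confidently separates real from fake drives its own loss toward zero.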

In [0]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

## 5.4. Defining Training

In [0]:
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

In [0]:
# Instantiate the two networks that train_step and the checkpoint refer to
generator = dcgenerator_model()
discriminator = discriminator_model()

@tf.function
def train_step(images):
    # Draw one noise vector per real image in the current batch
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each network is updated only with respect to its own loss
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [0]:
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

  # It might be a good idea to add a function that generates an image from time
  # to time during training.

## 5.5. Checkpointing

Checkpointing will be very helpful in case training our model takes a long time.

In [0]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# 6. Training the GAN

In [0]:
train(train_dataset, EPOCHS)

In [0]:
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# 7. Synthetic Data Generation
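
A possible post-processing step for this stage is sketched here. It relies only on the generator ending in a `tanh` activation, so its outputs lie in [-1, 1]. A batch of outputs could be obtained with something like `generator(tf.random.normal([n, noise_dim]), training=False)`, where `n` is a hypothetical sample count. The helper `to_uint8_images` is also hypothetical.

```python
import numpy as np

def to_uint8_images(generated):
    """Map generator outputs from [-1, 1] (tanh range) to uint8 pixels in [0, 255]."""
    generated = np.clip(np.asarray(generated, dtype=np.float32), -1.0, 1.0)
    return np.round((generated + 1.0) * 127.5).astype(np.uint8)

# Stand-in for a batch of generator outputs with shape (batch, height, width, channels)
fake_batch = np.array([-1.0, 0.0, 1.0], dtype=np.float32).reshape(3, 1, 1, 1)
pixels = to_uint8_images(fake_batch)
print(pixels.ravel())
```

The converted arrays can then be saved as images or rescaled to whatever range the CNN's training data uses.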

# 8. Training CNN with Synthetic Data

# 9. Separability Visualization (optional)

An interesting optional step is to visualize how the synthetic data strengthens the separability between the three classes, using a dimensionality reduction method such as t-SNE (which is what the authors used). The result should look something like the image below, where the left panel is without synthetic data and the right panel is with it. To emphasize, though: this step is optional :)

![separability](https://www.henryailabs.com/ArticlePictures/MedicalImageAugmentation-8.jpg)