<a href="https://colab.research.google.com/github/AnoushkaVijay/Leukemia_GAN/blob/main/VAE_Image_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (the google.colab module only exists inside a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Directory handed to image_dataset_from_directory as the image root.
# NOTE(review): left empty here — presumably must be set to a real folder before running; confirm.
PATH = ""

In [None]:
from tensorflow.keras.utils import image_dataset_from_directory

# Shuffled pipeline of 64x64 RGB images read from PATH, delivered in batches of 32.
image_dataset = image_dataset_from_directory(
    directory=PATH,
    batch_size=32,
    color_mode='rgb',
    image_size=(64, 64),
    shuffle=True,
)
# Divide pixel values by 255 so they fall between 0 and 1; labels pass through unchanged.
image_dataset = image_dataset.map(lambda images, labels: (images / 255.0, labels))

In [None]:
# for the imports
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import plot_model
from keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import LeakyReLU, Dense, Dropout, Input, BatchNormalization
from tensorflow.keras.optimizers import Adam, RMSprop
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import os
import sys

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Sanity check on a single batch: display its first image, print that image's
# maximum pixel value, then print its label.
for image_batch, label_batch in image_dataset.take(1):
    sample_image, sample_label = image_batch[0], label_batch[0]
    plt.imshow(sample_image)
    print(np.max(sample_image))
    plt.show()
    print(sample_label)

In [None]:
# Size of the latent (noise) vector; used by the VAE encoder/decoder and the GAN generator input.
latent_dim = 128


class Sampling(layers.Layer):
    """Draw a latent vector z from (z_mean, z_log_var) via the reparameterization trick.

    ``call`` expects ``inputs`` to be the pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is standard-normal
    noise with the same shape as ``z_mean``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # tf.random.normal replaces tf.keras.backend.random_normal, which is not
        # exposed by Keras 3; the sampled distribution is the same N(0, 1).
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
# VAE encoder: 64x64x3 image -> (z_mean, z_log_var, sampled z)
encoder_inputs = keras.Input(shape=(64, 64, 3))

# Four stride-2 "same" convolutions. Each one halves the spatial size while the
# number of feature maps doubles: 64x64x3 -> 32x32x16 -> 16x16x32 -> 8x8x64 -> 4x4x128.
x = encoder_inputs
for n_filters in (16, 32, 64, 128):
    x = layers.Conv2D(n_filters, 5, activation="relu", strides=2, padding="same")(x)

# Flatten the final feature volume into a one-dimensional vector per image.
x = layers.Flatten()(x)

# Two parallel dense heads produce the mean and log-variance of the latent
# distribution; Sampling draws z from them.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
# VAE decoder: latent vector -> 64x64x3 image with sigmoid-activated pixels
latent_inputs = keras.Input(shape=(latent_dim,))

# Project the latent vector onto an 8x8 grid with 64 channels.
x = layers.Dense(8 * 8 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((8, 8, 64))(x)

# Three stride-2 transposed convolutions upsample 8x8 -> 16x16 -> 32x32 -> 64x64.
for n_filters, kernel_size in ((128, 3), (64, 5), (64, 5)):
    x = layers.Conv2DTranspose(n_filters, kernel_size, activation="relu", strides=2, padding="same")(x)

# Final default-stride layer maps the features to 3 output channels.
decoder_outputs = layers.Conv2DTranspose(3, 5, activation="sigmoid", padding="same")(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
class VAE(keras.Model):
    """Variational autoencoder that pairs an encoder and a decoder with a custom training step.

    Args:
        encoder: model mapping an image batch to ``(z_mean, z_log_var, z)``.
        decoder: model mapping ``z`` back to an image batch.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        """Run one optimisation step and return the loss values for logging.

        Args:
            data: an image batch, or a tuple whose first element is the image
                batch (any further elements, e.g. labels, are ignored).

        Returns:
            dict with keys ``"loss"``, ``"reconstruction_loss"`` and ``"kl_loss"``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Use the sub-models owned by this instance rather than the
            # notebook-level globals, so the VAE trains what it was built with.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            # Scale the mean cross-entropy by the 64x64 image size used in this notebook.
            reconstruction_loss *= 64*64
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            # NOTE(review): the KL term is averaged over batch AND latent dimensions
            # (not summed over the latent axis), which weights it lightly relative
            # to the reconstruction term — confirm this is intended.
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }


In [None]:
# Adam optimizer with Keras' default arguments for the VAE.
vae_optimizer = keras.optimizers.Adam()
# learning_rate=0.00001
# VAE training hyperparameters.
# NOTE(review): image_dataset is already batched where it is built (batch_size=32),
# so batch_size here may not control the real batch size — confirm.
epochs = 500
batch_size = 15


# Create the VAE from the encoder/decoder built above and attach the optimizer.
vae = VAE(encoder, decoder)
vae.compile(optimizer=vae_optimizer)

In [None]:
# Parent directory under which the VAE checkpoint folder is created.
# NOTE(review): an empty string makes os.path.join produce a path relative to the
# current working directory — presumably meant to be a Drive folder; confirm.
BASE_DIR = ''

In [None]:
# Checkpoints live in <BASE_DIR>/VAE_checkpoints and share the file prefix "checkpoints".
vae_checkpoint_dir = os.path.join(BASE_DIR, 'VAE_checkpoints')
vae_checkpoint_prefix = os.path.join(vae_checkpoint_dir, "checkpoints")

# One checkpoint object tracks both the optimizer and the VAE model.
vae_checkpoint = tf.train.Checkpoint(vae_optimizer=vae_optimizer,
                                 vae=vae,)

In [None]:
# Set to True to reuse the latest saved VAE checkpoint instead of training from scratch.
usePreTrainVae = False

if usePreTrainVae:
    latest_vae_checkpoint = tf.train.latest_checkpoint(vae_checkpoint_dir)
    if latest_vae_checkpoint is None:
        # restore(None) would silently leave the VAE untrained, so fail loudly instead.
        raise FileNotFoundError(f"No VAE checkpoint found in {vae_checkpoint_dir!r}")
    vae_checkpoint.restore(latest_vae_checkpoint)
else:
    # image_dataset already yields batches, so batch_size is not passed to fit();
    # the epoch count comes from the `epochs` variable declared with the optimizer.
    vae.fit(image_dataset, epochs=epochs)
    vae_checkpoint.save(file_prefix=vae_checkpoint_prefix)

In [None]:
# Test the VAE: encode one real image, decode it again, and compare the two visually.
import numpy as np

for image_batch, label_batch in image_dataset.take(1):
    # Keep a leading batch axis of 1 so the models accept the single image.
    first_image = np.array(image_batch[0]).reshape(1, 64, 64, 3)

    # The encoder returns [z_mean, z_log_var, z]; index 2 is the sampled latent vector.
    z_p = encoder.predict(first_image)[2]
    x_p = decoder.predict(z_p)

    # Original image. axis('off') has to run before show(); otherwise it is
    # applied to the next (empty) figure instead of this one.
    plt.imshow(first_image[0])
    plt.axis('off')
    plt.show()

    # VAE reconstruction of the same image.
    plt.imshow(x_p[0])
    plt.axis('off')
    plt.show()

In [None]:
#adam = Adam(learning_rate=0.0002, beta_1=0.5)
# Separate Adam optimizers (learning rate 0.0002, beta_1 0.5) for the generator
# and the discriminator models.
g_optim = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
d_optim = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)

In [None]:
# Image generator: latent vector -> 64x64x3 image with sigmoid-activated pixels
generator_inputs = keras.Input(shape=(latent_dim,))

# Fully connected stem: two ReLU dense layers, each followed by batch normalisation.
x = generator_inputs
for units in (1024, 2048):
    x = layers.Dense(units, activation="relu")(x)
    x = BatchNormalization()(x)

# The 2048 activations are viewed as a 4x4 grid with 128 channels.
x = layers.Reshape((4, 4, 128))(x)

# Stride-2 transposed convolutions upsample 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
# Only the first two are batch-normalised.
for n_filters in (128, 64):
    x = layers.Conv2DTranspose(n_filters, 5, activation="relu", strides=2, padding="same")(x)
    x = BatchNormalization()(x)
x = layers.Conv2DTranspose(32, 5, activation="relu", strides=2, padding="same")(x)
generator_outputs = layers.Conv2DTranspose(3, 5, activation="sigmoid", strides=2, padding="same")(x)

generator = keras.Model(generator_inputs, generator_outputs, name="generator")
generator.compile(loss='binary_crossentropy', optimizer=g_optim, metrics=['accuracy'])
generator.summary()

In [None]:
# Discriminator: 64x64x3 image -> single sigmoid score (real vs. generated)
discriminator_inputs = keras.Input(shape=(64, 64, 3))

# One row per stage: (filters, stride, apply dropout?). Every stage is
# Conv2D(kernel 5, "same") -> LeakyReLU -> BatchNormalization [-> Dropout(0.2)].
conv_stages = (
    (16, 2, True),
    (32, 2, True),
    (32, 1, True),
    (64, 2, True),
    (64, 1, False),
    (128, 2, False),
)

x = discriminator_inputs
for n_filters, stride, use_dropout in conv_stages:
    x = layers.Conv2D(n_filters, 5, strides=stride, padding="same")(x)
    x = LeakyReLU(alpha=0.3)(x)
    x = BatchNormalization()(x)
    if use_dropout:
        x = Dropout(0.2)(x)

# Classification head: flatten, a linear 1024-unit layer, then one sigmoid unit.
x = layers.Flatten()(x)
x = layers.Dense(1024)(x)
discriminator_outputs = layers.Dense(1, activation="sigmoid")(x)

discriminator = keras.Model(discriminator_inputs, discriminator_outputs, name="discriminator")
discriminator.compile(loss='binary_crossentropy', optimizer=d_optim, metrics=['accuracy'])
# The trainable flag is switched off only after the discriminator's own compile() call.
discriminator.trainable = False

discriminator.summary()

In [None]:
# Stack generator and discriminator into one model (latent vector -> realness score)
# so the generator can be trained through the discriminator's output.
inputs = keras.Input(shape=(latent_dim, ))
hidden = generator(inputs)
output = discriminator(hidden)
# keras.Model (from tensorflow) is used here like for every other model in this
# notebook, instead of Model from the standalone `keras` package, so that all
# models come from the same Keras implementation.
gan = keras.Model(inputs, output)
# NOTE(review): 0.002 is ten times the 0.0002 used by g_optim/d_optim — confirm this is intended.
gan_optim = tf.keras.optimizers.Adam(0.002, beta_1=0.5)
gan.compile(loss='binary_crossentropy', optimizer=gan_optim, metrics=['accuracy'])
gan.summary()

In [None]:
# Helper functions to plot losses/generated images and observe if the loss is reducing with the number of epochs
def plot_loss(losses):
    """Print and plot the discriminator and generator loss histories.

    Args:
        losses: dict with keys "D" and "G". Each value is a sequence of records
            whose element 0 is the loss (element 1 is the accuracy).
    """
    history = {key: [record[0] for record in losses[key]] for key in ("D", "G")}
    print(history["D"])
    print(history["G"])

    plt.figure(figsize=(10,8))
    for key, label in (("D", "Discriminator loss"), ("G", "Generator loss")):
        plt.plot(history[key], label=label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_generated(n_ex=10, dim=(1, 10), figsize=(12, 2)):
    """Generate ``n_ex`` images and show them in a grid.

    Args:
        n_ex: number of images to generate.
        dim: (rows, columns) of the subplot grid.
        figsize: size of the matplotlib figure.
    """
    noise = np.random.normal(0, 1, size=(n_ex, latent_dim))
    # noise -> VAE decoder -> VAE encoder (output index 2, the sampled z) -> generator
    z_p = encoder.predict(decoder.predict(noise))[2]
    generated_images = generator.predict(z_p).reshape(n_ex, 64, 64,3)

    plt.figure(figsize=figsize)
    for position, image in enumerate(generated_images, start=1):
        plt.subplot(dim[0], dim[1], position)
        plt.imshow(image, interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Saving the model
def save_model(epochs, output_dir=None):
    """Save the generator and discriminator to separate, epoch-stamped files.

    The original target was ``"".format(epochs)``, which always evaluates to the
    empty string: both models were written to the same (empty) path and the
    epoch number never reached the file name.

    Args:
        epochs: epoch number embedded in both file names.
        output_dir: directory to write into; defaults to the notebook-level BASE_DIR.
    """
    target_dir = BASE_DIR if output_dir is None else output_dir
    generator.save(os.path.join(target_dir, "generator_epoch_{}.h5".format(epochs)))
    discriminator.save(os.path.join(target_dir, "discriminator_epoch_{}.h5".format(epochs)))
    generator.summary()

In [None]:
# Set up a vector (dict) to store the losses
losses = {"D":[], "G":[]}

def train(epochs=500, plt_frq=1, BATCH_SIZE=128):
    """Adversarially train the generator and discriminator on ``image_dataset``.

    Args:
        epochs: number of passes over ``image_dataset``.
        plt_frq: sample images are plotted on epoch 1 and on every ``plt_frq``-th epoch.
        BATCH_SIZE: number of generated (fake) images per step. The number of
            real images per step is whatever ``image_dataset`` yields.

    Raises:
        ValueError: if ``image_dataset`` yields no batches.

    Side effects:
        Appends the discriminator/generator results of the final batch of every
        epoch to the module-level ``losses`` dict, calls ``save_model`` every
        500 epochs, and plots the loss curves when training finishes.
    """
    # Imported locally: tqdm_notebook is deprecated in favour of tqdm.auto.
    from tqdm.auto import tqdm

    set_tf_loglevel(logging.FATAL)
    batchCount = len(image_dataset)
    print('Epochs:', epochs)
    print('Batch size:', BATCH_SIZE)
    print('Batches per epoch:', batchCount)

    for e in tqdm(range(1, epochs+1)):
        if e == 1 or e%plt_frq == 0:
            print('-'*15, 'Epoch %d' % e, '-'*15)

        # Reset per epoch so an empty dataset is detected instead of raising
        # NameError (first epoch) or silently re-logging stale values.
        d_loss = g_loss = None
        for image_batch, _ in image_dataset:
            # Create noise vectors for the generator
            noise = np.random.normal(0, 1, size=(BATCH_SIZE, latent_dim))

            # noise -> VAE decoder -> VAE encoder; index 2 is the sampled z.
            # verbose=0 keeps predict() from printing a progress bar on every step.
            x_p = decoder.predict(noise, verbose=0)
            z_p = encoder.predict(x_p, verbose=0)[2]

            # Generate the fake images and stack them behind the real ones
            generated_images = generator.predict(z_p, verbose=0)
            X = np.concatenate((image_batch, generated_images))

            # Labels: 0.9 for real images (one-sided label smoothing), 0 for fakes
            y = np.zeros(len(image_batch) + BATCH_SIZE)
            y[:len(image_batch)] = 0.9

            # Train discriminator on the mixed real/fake batch
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X, y)

            # Train generator through the combined model: the fakes are labelled
            # 1 so the generator is pushed towards fooling the discriminator.
            y2 = np.ones(BATCH_SIZE)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(z_p, y2)

        if d_loss is None:
            raise ValueError("image_dataset yielded no batches; nothing to train on")

        # Only store losses from final batch of epoch
        losses["D"].append(d_loss)
        losses["G"].append(g_loss)
        print('epochs', e)
        if e%500 ==0:
            save_model(e)
        # Update the plots
        if e == 1 or e%plt_frq == 0:
            plot_generated()
    plot_loss(losses)

In [None]:
import logging
import os

def set_tf_loglevel(level):
    """Map a ``logging`` level onto TensorFlow's log settings.

    Sets the ``TF_CPP_MIN_LOG_LEVEL`` environment variable to '3' for FATAL and
    above, '2' for ERROR, '1' for WARNING and '0' otherwise, and applies
    ``level`` to the ``'tensorflow'`` Python logger.

    Args:
        level: a numeric level from the ``logging`` module (e.g. ``logging.FATAL``).
    """
    # The branches must be mutually exclusive (elif): with independent ifs the
    # later, lower thresholds overwrite the value, so FATAL ended up as '1'.
    if level >= logging.FATAL:
        cpp_level = '3'
    elif level >= logging.ERROR:
        cpp_level = '2'
    elif level >= logging.WARNING:
        cpp_level = '1'
    else:
        cpp_level = '0'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = cpp_level
    logging.getLogger('tensorflow').setLevel(level)
# Request FATAL-only TensorFlow logging, then run GAN training for 2000 epochs,
# plotting samples every 20 epochs and drawing 20 generated images per step.
set_tf_loglevel(logging.FATAL)
train(epochs=2000, plt_frq=20, BATCH_SIZE=20)

In [None]:
# Number of synthetic images to produce.
num_images = 2232

# Sampling chain: noise -> VAE decoder -> VAE encoder (output index 2, the sampled z) -> generator.
noise = np.random.normal(0, 1, size=(num_images, latent_dim))
x_p = decoder.predict(noise)
z_p = encoder.predict(x_p)[2]
generated_images = generator.predict(z_p).reshape(num_images, 64, 64,3)

## CHANGE THE CLASS NAME TO THE TRAINED GAN CLASS
from matplotlib import pyplot as plt

# Write every generated image to its own PNG, numbered from 0.
for i, generated_image in enumerate(generated_images):
    plt.imsave("<Drive Path>" + f"gan_generated_normal_category_{i}.png", generated_image)


In [None]:
# Loads generator and discriminator to get the already trained version.
# NOTE(review): both paths are empty placeholders — presumably they must point at
# the saved model files; confirm. Also, `gan` was assembled from the earlier
# generator/discriminator objects, so rebinding these names does not change `gan`.
generator = tf.keras.models.load_model("")
discriminator = tf.keras.models.load_model("")