<a href="https://colab.research.google.com/github/Ahtesham519/Genrative_Deep_learning_v2_2023/blob/main/WGAN_CELEBAA_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np

import tensorflow as tf
from tensorflow.keras import (
    Sequential,
    layers,
    optimizers,
    losses,
    metrics,
    callbacks,
    models,
)



#0. Parameters

In [None]:
# Hyperparameters for the WGAN-GP run. Each name is assigned exactly once.
IMAGE_SIZE = 64  # training images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3  # RGB
BATCH_SIZE = 512
NUM_FEATURES = 64
Z_DIM = 128  # length of the latent vector fed to the generator
LEARNING_RATE = 0.0002
# Adam moment decay rates, shared by the critic and generator optimizers.
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
EPOCHS = 200
CRITIC_STEPS = 3  # critic loop iterations per training step
GP_WEIGHT = 10.0  # multiplier on the gradient-penalty term of the critic loss
LOAD_MODEL = False  # when True, weights are restored from ./checkpoint before training


#1. Prepare the data

In [None]:
# Load the data
# Builds an unlabeled, shuffled dataset of RGB image batches read from disk,
# resized to IMAGE_SIZE x IMAGE_SIZE with bilinear interpolation.
# The loader is reached through `tf.keras.utils` because no bare `utils`
# module is imported in this notebook.
train_data = tf.keras.utils.image_dataset_from_directory(
    "/app/data/celeba-dataset/img_align_celeba/img_align_celeba",
    labels=None,
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    shuffle=True,
    seed=42,  # fixed seed so the shuffle order is reproducible
    interpolation="bilinear",
)

In [None]:
# Preprocess the data
def preprocess_image(image):
    """
    Cast an image batch to float32 and rescale it to the [-1, 1] range.

    The generator's output layer uses a tanh activation, so real images must
    occupy the same [-1, 1] range as generated ones before both are shown to
    the critic. Assumes incoming pixel values are in [0, 255].
    """
    img = tf.cast(image, tf.float32)
    return (img - 127.5) / 127.5


# Apply the normalisation lazily to every batch of the loaded dataset.
train = train_data.map(preprocess_image)

In [None]:
# Show some faces from the training set
# NOTE(review): `sample_batch` is neither defined nor imported in the visible
# notebook; presumably it returns one batch from the dataset — confirm its source.
train_sample = sample_batch(train)


In [None]:
# NOTE(review): `display` is not defined or imported in the visible notebook;
# presumably a project plotting helper that accepts `cmap` — confirm its source.
display(train_sample, cmap = None)

#2. Build the WGAN-GP

In [None]:
# Critic: maps an image to a single unbounded score (no final activation).
# With IMAGE_SIZE = 64, the four stride-2 "same" convolutions shrink the
# spatial size 64 -> 32 -> 16 -> 8 -> 4, and the last 4x4 "valid" convolution
# collapses that 4x4 map to one value per image.
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(critic_input)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
# Every hidden convolution is followed by a LeakyReLU; without one here the
# 128- and 256-filter convolutions would have no non-linearity between them.
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(256, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(512, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model(critic_input, critic_output)
critic.summary()




In [None]:
# Generator: maps a latent vector of length Z_DIM to an image in [-1, 1].
# The latent vector is reshaped to a 1x1 feature map; the first 4x4 "valid"
# transposed convolution expands it to 4x4, and each following stride-2
# "same" transposed convolution doubles it: 4 -> 8 -> 16 -> 32 -> 64.
generator_input = layers.Input(shape=(Z_DIM,))
x = layers.Reshape((1, 1, Z_DIM))(generator_input)
x = layers.Conv2DTranspose(
    512, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    256, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
# tanh bounds the output to [-1, 1].
generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model(generator_input, generator_output)
generator.summary()

In [None]:
class WGANGP(models.Model):
    """Wasserstein GAN with gradient penalty built from a critic and a generator.

    Each training step updates the critic `critic_steps` times and then the
    generator once. The critic loss is the Wasserstein loss plus `gp_weight`
    times a gradient penalty.
    """

    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):
        """Store the two sub-models and the training hyperparameters.

        Args:
            critic: model scoring images; called as `critic(images, training=True)`.
            generator: model mapping latent vectors to images.
            latent_dim: length of each latent vector sampled for the generator.
            critic_steps: number of critic updates per training step.
            gp_weight: multiplier applied to the gradient penalty.
        """
        super().__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight

    def compile(self, c_optimizer, g_optimizer):
        """Attach one optimizer per sub-model and create the loss trackers.

        Args:
            c_optimizer: optimizer applied to the critic's trainable variables.
            g_optimizer: optimizer applied to the generator's trainable variables.
        """
        super().compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        self.c_wass_loss_metric = metrics.Mean(name="c_wass_loss")
        self.c_gp_metric = metrics.Mean(name="c_gp")
        self.c_loss_metric = metrics.Mean(name="c_loss")
        self.g_loss_metric = metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Return the running-mean trackers reported by `train_step`."""
        return [
            self.c_loss_metric,
            self.c_wass_loss_metric,
            self.c_gp_metric,
            self.g_loss_metric,
        ]

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Return the mean squared deviation of the critic's gradient norm from 1.

        The gradient is taken with respect to images interpolated between each
        real/fake pair.

        Args:
            batch_size: number of images in each of the two batches.
            real_images: batch of real images.
            fake_images: batch of generated images, same shape as `real_images`.

        Returns:
            Scalar tensor: mean over the batch of `(||grad|| - 1) ** 2`.
        """
        # One mixing coefficient per image, drawn uniformly from [0, 1) so each
        # interpolated image lies on the segment between its real and fake
        # endpoints; a normal draw would place points outside that segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            # `interpolated` is a plain tensor, so it must be watched explicitly.
            gp_tape.watch(interpolated)
            pred = self.critic(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # L2 norm of the gradient per image (reduce over height, width, channels).
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, real_images):
        """Run `critic_steps` critic updates followed by one generator update.

        Args:
            real_images: batch of real training images.

        Returns:
            Dict mapping each tracked metric name to its current mean.
        """
        batch_size = tf.shape(real_images)[0]

        # Critic updates: each iteration draws fresh latent vectors and applies
        # its own gradient step, so the critic is updated `critic_steps` times.
        for _ in range(self.critic_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(random_latent_vectors, training=True)
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)

                # Minimised when real images score higher than fake ones.
                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(
                    real_predictions
                )
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images)
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )

        # Generator update: raise the critic's score on generated images.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        # Critic metrics reflect the last critic iteration of this step.
        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)

        return {m.name: m.result() for m in self.metrics}

In [None]:
# Create a GAN
# Wraps the critic and generator built earlier in this notebook, together with
# the training hyperparameters, in a single trainable WGAN-GP model.
wgangp = WGANGP(
    critic=critic,
    generator=generator,
    latent_dim=Z_DIM,
    critic_steps=CRITIC_STEPS,
    gp_weight=GP_WEIGHT,
)

In [None]:
# Optionally resume from previously saved weights; the path matches the
# `filepath` given to the ModelCheckpoint callback in this notebook.
if LOAD_MODEL:
  wgangp.load_weights("./checkpoint/checkpoint.ckpt")

#Train the GAN

In [None]:
# Compile the GAN
# The critic and generator each get their own Adam optimizer, configured with
# the same learning rate and beta values.
wgangp.compile(
    c_optimizer = optimizers.Adam(
        learning_rate = LEARNING_RATE,
        beta_1 = ADAM_BETA_1,
        beta_2 = ADAM_BETA_2,
    ),
    g_optimizer = optimizers.Adam(
        learning_rate = LEARNING_RATE ,
        beta_1 = ADAM_BETA_1,
        beta_2 = ADAM_BETA_2,
    ),
)

In [None]:
# Create a model save checkpoint
# Saves only the model weights (not the full model) at the end of every epoch,
# always to the same path.
# NOTE(review): newer Keras releases require weight files to end in
# `.weights.h5` — confirm against the installed version.
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath = "./checkpoint/checkpoint.ckpt",
    save_weights_only = True,
    save_freq  = "epoch",
    verbose = 0,
)

# Writes training logs for TensorBoard under ./logs.
tensorboard_callback = callbacks.TensorBoard(log_dir = "./logs")


class ImageGenerator(callbacks.Callback):
    """Callback that samples images from the model's generator after each epoch."""

    def __init__(self, num_img, latent_dim):
        """Record how many images to sample and the latent vector length.

        Args:
            num_img: number of images to generate at the end of each epoch.
            latent_dim: length of each latent vector fed to the generator.
        """
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Generate `num_img` images from fresh latent vectors.

        Args:
            epoch: index of the epoch that just finished (unused).
            logs: metric results for the epoch (unused).
        """
        random_latent_vectors = tf.random.normal(
            shape=(self.num_img, self.latent_dim)
        )

        generated_images = self.model.generator(random_latent_vectors)
        # Rescale the generator's [-1, 1] tanh output to the [0, 255] pixel range.
        generated_images = generated_images * 127.5 + 127.5
        generated_images = generated_images.numpy()
        # NOTE(review): the images are computed but not yet shown or saved;
        # hook in a display/save call here once the helper to use is confirmed.


In [None]:
# Train the WGAN-GP on the preprocessed dataset with checkpointing,
# TensorBoard logging and per-epoch image sampling.
# NOTE(review): steps_per_epoch = 2 means each epoch consumes only two batches
# of `train` — confirm this is intended rather than a leftover debug setting.
wgangp.fit(
    train,
    epochs = EPOCHS,
    steps_per_epoch = 2,
    callbacks= [
        model_checkpoint_callback,
        tensorboard_callback,
        ImageGenerator(num_img = 10, latent_dim = Z_DIM),
    ],
)

In [None]:
# Save the final models
# The generator and critic are saved separately, each as a full model
# (not weights only), under ./models.
generator.save("./models/generator")
critic.save("./models/critic")

#Generate images

In [None]:
# Draw 10 latent vectors from a standard normal distribution and run them
# through the trained generator.
z_sample = np.random.normal(size = (10, Z_DIM))
imgs = wgangp.generator.predict(z_sample)
# NOTE(review): `imgs` is raw generator output (tanh range [-1, 1]); confirm
# that the external `display` helper expects that range.
display(imgs, cmap = None)