<a href="https://colab.research.google.com/github/ShreyNaik123/GAN-Implementations/blob/main/Cyclegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import plot_model

In [None]:
class Block(Layer):
  """Conv(4x4) -> instance norm -> LeakyReLU(0.2) unit used by the Discriminator.

  Args:
    filters: number of output channels of the convolution.
    stride: convolution stride (with 'same' padding, 2 halves H and W and
      1 keeps them).
  """

  def __init__(self, filters, stride):
    super(Block, self).__init__()
    self.conv = Sequential([
        layers.Conv2D(filters=filters, kernel_size=4, strides=stride, padding='same'),
        # One group per channel makes GroupNormalization behave as
        # InstanceNormalization. groups=1 would instead normalise across all
        # channels at once, i.e. LayerNormalization.
        layers.GroupNormalization(groups=filters),
        layers.LeakyReLU(0.2),
    ])

  def call(self, x):
    return self.conv(x)

In [None]:
class Discriminator(Model):
  """PatchGAN-style discriminator: image batch -> grid of per-patch scores in (0, 1).

  The first layer has no normalisation. Three Blocks follow (strides 2, 2, 1),
  and a final single-channel convolution produces the scores. For a
  (1, 256, 256, 3) input the output is (1, 32, 32, 1).
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.initial = Sequential([
        layers.Conv2D(filters=64,
                      kernel_size=4,
                      strides=2,
                      padding='same'),
        layers.LeakyReLU(0.2),
    ])

    self.block1 = Block(128, 2)
    self.block2 = Block(256, 2)
    self.block3 = Block(512, 1)
    # Score projection must be a bare convolution. Normalising a 1-channel
    # map would force every image's patch scores to zero mean, so the
    # discriminator could never call an image entirely real or entirely fake.
    # A LeakyReLU in front of the sigmoid would also be redundant.
    self.final = layers.Conv2D(filters=1, kernel_size=4, strides=1, padding='same')

  def call(self, x):
    x = self.initial(x)
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    return tf.keras.activations.sigmoid(self.final(x))

In [None]:
# Smoke test: a 256x256 RGB batch should map to a 32x32 grid of sigmoid scores.
disc_test_input = tf.random.uniform((1, 256, 256, 3))
disc_test_output = Discriminator()(disc_test_input)
print(disc_test_output.shape)
assert disc_test_output.shape == (1, 32, 32, 1), disc_test_output.shape
assert float(tf.reduce_min(disc_test_output)) >= 0.0
assert float(tf.reduce_max(disc_test_output)) <= 1.0

(1, 32, 32, 1)


In [None]:
class ConvBlock(Layer):
  """(Transposed) convolution -> instance norm -> optional ReLU; Generator building block.

  Args:
    filters: number of output channels.
    kernel_size: convolution kernel size (default 4).
    down: if True use Conv2D, otherwise Conv2DTranspose (upsampling path).
    activation: if True finish with ReLU; otherwise the block ends linearly
      after the normalisation.
    stride: convolution stride (default 1).
    **kwargs: extra keyword arguments forwarded to the convolution layer,
      e.g. `padding` or `output_padding` (Conv2DTranspose only). `padding`
      defaults to 'same'. These used to be accepted and silently dropped.
  """

  def __init__(self, filters, kernel_size=4, down=True, activation=True, stride=1, **kwargs):
    super(ConvBlock, self).__init__()
    conv_kwargs = {'padding': 'same', **kwargs}
    conv_layer = layers.Conv2D if down else layers.Conv2DTranspose
    self.conv = Sequential([
        conv_layer(filters=filters, kernel_size=kernel_size, strides=stride, **conv_kwargs),
        # One group per channel == InstanceNormalization; groups=1 would
        # normalise all channels together, which is LayerNormalization.
        layers.GroupNormalization(groups=filters),
        layers.ReLU() if activation else layers.Identity(),
    ])

  def call(self, x):
    return self.conv(x)

In [None]:
class ResidualBlock(Layer):
  """Residual unit: y = x + F(x), where F is two 3x3 ConvBlocks.

  The second ConvBlock is built with activation=False, so the residual
  branch is linear right before the addition. `filters` must match the
  input's channel count for the addition to be shape-compatible.
  """

  def __init__(self, filters):
    super(ResidualBlock, self).__init__()
    first_conv = ConvBlock(filters=filters, kernel_size=3)
    second_conv = ConvBlock(filters=filters, kernel_size=3, activation=False)
    self.block = Sequential([first_conv, second_conv])

  def call(self, x):
    residual = self.block(x)
    return x + residual

In [None]:
def pad(input_tensor, padding_size, mode='CONSTANT'):
  """Pad axes 1 and 2 of a rank-4 tensor by `padding_size` on both sides.

  Axes 0 and 3 are left untouched (presumably batch and channels for an
  NHWC image tensor).

  Args:
    input_tensor: rank-4 tensor.
    padding_size: rows/columns added before and after each padded axis.
    mode: tf.pad mode: 'CONSTANT' (zeros, the original fixed behaviour),
      'REFLECT' or 'SYMMETRIC'. Reflection padding is what the CycleGAN
      reference generator uses.

  Returns:
    The padded tensor; axes 1 and 2 each grow by 2 * padding_size.
  """
  spatial = [padding_size, padding_size]
  return tf.pad(input_tensor, [[0, 0], spatial, spatial, [0, 0]], mode=mode)

In [None]:
class Generator(Model):
  """ResNet-style CycleGAN generator.

  Pipeline:
  - a 7x7 conv to 64 channels;
  - two stride-2 downsampling ConvBlocks (128, 256 channels);
  - `num_residual` ResidualBlock(256) stages;
  - two stride-2 upsampling ConvBlocks (128, 64 channels);
  - a 7x7 conv to 3 channels followed by tanh.

  Outputs therefore lie in [-1, 1], the same range `rescale` maps training
  images into. For a (1, 256, 256, 3) input the output is (1, 256, 256, 3).

  Args:
    num_residual: number of residual blocks (the notebook uses 9).
  """

  def __init__(self, num_residual):
    super(Generator, self).__init__()
    self.initial = Sequential([
        layers.Conv2D(64, kernel_size=7, strides=1, padding='same'),
        layers.ReLU(),
    ])

    # Each stride-2 ConvBlock halves H and W.
    self.down = Sequential([
        ConvBlock(128, kernel_size=3, stride=2),
        ConvBlock(256, kernel_size=3, stride=2),
    ])

    self.residual = Sequential()
    for _ in range(num_residual):
      self.residual.add(ResidualBlock(256))

    # Conv2DTranspose with padding='same' and stride 2 doubles H and W exactly,
    # so no output_padding is needed.
    self.up = Sequential([
        ConvBlock(128, down=False, kernel_size=3, stride=2),
        ConvBlock(64, down=False, kernel_size=3, stride=2),
    ])

    # Output projection must be a bare convolution (tanh is applied in call()).
    # The previous ConvBlock applied normalisation plus ReLU before tanh. That
    # clamped every pixel to [0, 1), so the generator could never emit the
    # negative half of the [-1, 1] image range it is trained against.
    self.last = layers.Conv2D(3, kernel_size=7, strides=1, padding='same')

  def call(self, x):
    x = self.initial(x)
    x = self.down(x)
    x = self.residual(x)
    x = self.up(x)
    return tf.keras.activations.tanh(self.last(x))


In [None]:
# Smoke test: the generator must preserve spatial size, emit 3 channels,
# and stay within the tanh range [-1, 1].
gen_test_input = tf.random.uniform((1, 256, 256, 3))
gen_test_output = Generator(9)(gen_test_input)
print(gen_test_output.shape)
assert gen_test_output.shape == (1, 256, 256, 3), gen_test_output.shape
assert float(tf.reduce_max(tf.abs(gen_test_output))) <= 1.0

(1, 256, 256, 3)


In [None]:
# MSE is applied to discriminator outputs (least-squares adversarial loss);
# L1 / mean absolute error is used for the cycle-consistency terms.
mse = tf.keras.losses.MeanSquaredError()
l1 = tf.keras.losses.MeanAbsoluteError()


def make_adam():
  """Build an Adam optimizer with the hyper-parameters shared by all four networks."""
  return tf.keras.optimizers.legacy.Adam(learning_rate=1e-5, beta_1=0.5, beta_2=0.999)


# One optimizer per network, so each keeps its own Adam moment estimates.
disc_opt_z = make_adam()
disc_opt_h = make_adam()
gen_opt_z = make_adam()
gen_opt_h = make_adam()

In [None]:
class CycleGan(Model):
  """CycleGAN training wrapper around two generators and two discriminators.

  Naming: `gen_h` maps zebra -> horse and `gen_z` maps horse -> zebra.
  `disc_h` scores horse images and `disc_z` scores zebra images.

  Label convention (kept from the original notebook): discriminators are
  trained to output 0 for REAL patches and 1 for FAKE patches, and the
  generators are trained to make their fakes score 0. This is the inverse of
  the common 1=real convention but is used consistently below.

  Args:
    disc_h, disc_z: discriminators for the horse / zebra domains.
    gen_h, gen_z: generators producing horse / zebra images.
    lambda_cycle: weight of each cycle-consistency (L1) term (default 10).
    lambda_identity: weight of the optional identity (L1) term. 0 (the
      default) disables it and skips the extra forward passes.
  """

  def __init__(self, disc_h, disc_z, gen_h, gen_z, lambda_cycle=10.0, lambda_identity=0.0):
    super(CycleGan, self).__init__()
    self.disc_h = disc_h
    self.disc_z = disc_z
    self.gen_h = gen_h
    self.gen_z = gen_z
    self.lambda_cycle = lambda_cycle
    self.lambda_identity = lambda_identity

  def compile(self, mse, l1, disc_opt_z, disc_opt_h, gen_opt_z, gen_opt_h):
    """Store the loss functions and one optimizer per network.

    Args:
      mse: loss(y_true, y_pred) applied to discriminator outputs.
      l1: loss(y_true, y_pred) for cycle / identity terms (mean absolute error).
      disc_opt_z, disc_opt_h, gen_opt_z, gen_opt_h: optimizers for the
        matching networks.
    """
    super().compile()
    self.mse = mse
    self.l1 = l1
    self.disc_opt_z = disc_opt_z
    self.disc_opt_h = disc_opt_h
    self.gen_opt_z = gen_opt_z
    self.gen_opt_h = gen_opt_h

  def _discriminator_step(self, disc, optimizer, real, fake):
    """Apply one update to `disc` and return its loss (mean of real and fake terms)."""
    with tf.GradientTape() as tape:
      real_pred = disc(real)
      fake_pred = disc(fake)
      loss_real = self.mse(tf.zeros_like(real_pred), real_pred)  # real -> 0
      loss_fake = self.mse(tf.ones_like(fake_pred), fake_pred)   # fake -> 1
      loss = (loss_real + loss_fake) / 2
    grads = tape.gradient(loss, disc.trainable_variables)
    optimizer.apply_gradients(zip(grads, disc.trainable_variables))
    return loss

  def train_step(self, batch):
    """One CycleGAN update on a (real_horse, real_zebra) batch.

    Order: update disc_h, then disc_z, then both generators jointly.

    Returns:
      Dict with keys 'd_loss_z', 'd_loss_h', 'g_loss_z', 'g_loss_h'.
    """
    real_horse, real_zebra = batch

    # Fakes for the discriminator updates. Generators are not updated here,
    # so no tape is needed.
    fake_horse = self.gen_h(real_zebra)
    fake_zebra = self.gen_z(real_horse)

    h_total_loss = self._discriminator_step(self.disc_h, self.disc_opt_h, real_horse, fake_horse)
    z_total_loss = self._discriminator_step(self.disc_z, self.disc_opt_z, real_zebra, fake_zebra)

    # Both generators are optimised on the full CycleGAN objective, so each one
    # also receives the gradient of the cycle in which it does the
    # reconstruction. Previously each generator was updated only from its own
    # forward cycle, e.g. gen_h never learned from horse -> zebra -> horse.
    gen_h_vars = self.gen_h.trainable_variables
    gen_z_vars = self.gen_z.trainable_variables
    with tf.GradientTape() as g_tape:
      fake_horse = self.gen_h(real_zebra)
      fake_zebra = self.gen_z(real_horse)

      # Adversarial terms: fakes should be scored as real (0).
      horse_pred = self.disc_h(fake_horse)
      zebra_pred = self.disc_z(fake_zebra)
      adv_h = self.mse(tf.zeros_like(horse_pred), horse_pred)
      adv_z = self.mse(tf.zeros_like(zebra_pred), zebra_pred)

      # Cycle terms: zebra -> horse -> zebra and horse -> zebra -> horse.
      cycle_zebra = self.l1(real_zebra, self.gen_z(fake_horse))
      cycle_horse = self.l1(real_horse, self.gen_h(fake_zebra))

      total_h_loss = adv_h + self.lambda_cycle * cycle_zebra
      total_z_loss = adv_z + self.lambda_cycle * cycle_horse

      if self.lambda_identity:
        # Identity terms: an image already in a generator's target domain
        # should pass through it unchanged.
        total_h_loss += self.lambda_identity * self.l1(real_horse, self.gen_h(real_horse))
        total_z_loss += self.lambda_identity * self.l1(real_zebra, self.gen_z(real_zebra))

      gen_loss = total_h_loss + total_z_loss

    grads = g_tape.gradient(gen_loss, gen_h_vars + gen_z_vars)
    split = len(gen_h_vars)
    self.gen_opt_h.apply_gradients(zip(grads[:split], gen_h_vars))
    self.gen_opt_z.apply_gradients(zip(grads[split:], gen_z_vars))

    return {'d_loss_z': z_total_loss, 'd_loss_h': h_total_loss,
            'g_loss_z': total_z_loss, 'g_loss_h': total_h_loss}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
# load the data
!gdown --id 18aOUuJVV6kIgpnqKoP-AW3cY6ZxhP9Aa

!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download balraj98/horse2zebra-dataset

Downloading...
From: https://drive.google.com/uc?id=18aOUuJVV6kIgpnqKoP-AW3cY6ZxhP9Aa
To: /content/kaggle.json
100% 68.0/68.0 [00:00<00:00, 460kB/s]
Downloading horse2zebra-dataset.zip to /content
 99% 110M/111M [00:06<00:00, 19.6MB/s]
100% 111M/111M [00:06<00:00, 16.8MB/s]


In [None]:
import zipfile

# Extract the Kaggle archive into the current working directory. The `with`
# block closes the archive even if extraction raises partway through.
with zipfile.ZipFile('/content/horse2zebra-dataset.zip') as zip_ref:
  zip_ref.extractall()

In [None]:
# Unpaired training folders: domain A holds horses, domain B holds zebras.
trainA, trainB = '/content/trainA', '/content/trainB'

In [None]:
# (number of horse images, number of zebra images); the domains differ in size.
tuple(len(os.listdir(folder)) for folder in (trainA, trainB))

(1067, 1334)

In [None]:
def load_image_folder(directory, image_size=(256, 256), batch_size=1, seed=None):
  """Load every image in `directory` as an unlabeled, shuffled tf.data.Dataset.

  Uses tf.keras.utils.image_dataset_from_directory, the non-deprecated home
  of the loader formerly reached through tf.keras.preprocessing. Images are
  resized to `image_size` and yielded as float batches with raw 0-255 values;
  `rescale` maps them to [-1, 1] later. Pass `seed` for a reproducible
  shuffle order.
  """
  return tf.keras.utils.image_dataset_from_directory(
      directory,
      image_size=image_size,
      label_mode=None,
      batch_size=batch_size,
      seed=seed)


horse_dataset = load_image_folder(trainA)
zebra_dataset = load_image_folder(trainB)

Found 1067 files belonging to 1 classes.
Found 1334 files belonging to 1 classes.


In [None]:
# Pair one horse batch with one zebra batch per training step. Dataset.zip
# stops at the shorter input, so an epoch has as many pairs as there are horse
# images; the remaining zebra images in that epoch's order go unused.
# NOTE(review): the loader appears to shuffle only within a small buffer each
# epoch, so the same zebra files may be skipped every epoch. Confirm.
unpaired_sources = (horse_dataset, zebra_dataset)
combined_dataset = tf.data.Dataset.zip(unpaired_sources)

In [None]:
# Number of (horse, zebra) pairs per epoch. Each element is a batch of ONE pair
# (batch_size=1 when loading), so no multiplier applies. The former `* 32`
# over-reported the count (34144 instead of 1067).
combined_dataset.cardinality().numpy()

34144

In [None]:
def rescale(horse, zebra):
  """Map a (horse, zebra) pair of 0-255 images into [-1, 1].

  Each value x becomes (x - 127.5) / 127.5, so 0 -> -1 and 255 -> 1, which
  matches the generator's tanh output range.
  """
  def to_unit_range(image):
    return (image - 127.5) / 127.5

  return to_unit_range(horse), to_unit_range(zebra)

In [None]:
# Rescale both images to [-1, 1] in parallel, then prefetch so input loading
# overlaps with training.
# NOTE: this rebinds `combined_dataset`; running the cell twice would apply
# `rescale` twice. Re-run the zip cell first if you re-execute it.
combined_dataset = (
    combined_dataset
    .map(rescale, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# One discriminator and one generator (9 residual blocks) per domain.
disc_h, disc_z = Discriminator(), Discriminator()
gen_h, gen_z = Generator(9), Generator(9)

In [None]:
cgan = CycleGan(disc_h=disc_h, disc_z=disc_z, gen_h=gen_h, gen_z=gen_z)

In [None]:
cgan.compile(
    mse=mse,
    l1=l1,
    disc_opt_z=disc_opt_z,
    disc_opt_h=disc_opt_h,
    gen_opt_z=gen_opt_z,
    gen_opt_h=gen_opt_h,
)

In [None]:
EPOCHS = 10
cgan.fit(combined_dataset, epochs=EPOCHS)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7c66d6f13880>