In [1]:
# Mount Google Drive at /content/drive; main() writes its checkpoints and
# generator weights to paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
pip install tensorflow_addons

Collecting tensorflow_addons
[?25l  Downloading https://files.pythonhosted.org/packages/74/e3/56d2fe76f0bb7c88ed9b2a6a557e25e83e252aec08f13de34369cd850a0b/tensorflow_addons-0.12.1-cp37-cp37m-manylinux2010_x86_64.whl (703kB)
[K     |▌                               | 10kB 26.7MB/s eta 0:00:01[K     |█                               | 20kB 15.0MB/s eta 0:00:01[K     |█▍                              | 30kB 13.1MB/s eta 0:00:01[K     |█▉                              | 40kB 12.2MB/s eta 0:00:01[K     |██▎                             | 51kB 7.7MB/s eta 0:00:01[K     |██▉                             | 61kB 7.4MB/s eta 0:00:01[K     |███▎                            | 71kB 8.3MB/s eta 0:00:01[K     |███▊                            | 81kB 9.1MB/s eta 0:00:01[K     |████▏                           | 92kB 8.6MB/s eta 0:00:01[K     |████▋                           | 102kB 7.6MB/s eta 0:00:01[K     |█████▏                          | 112kB 7.6MB/s eta 0:00:01[K     |█████▋     

In [3]:
import tensorflow as tf 
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import tensorflow_addons as tfa

In [4]:
class InputPipeline:
  """Builds the batched, repeating `tf.data` pipelines for horse2zebra.

  Training images are resized to (orig_height, orig_width), randomly cropped
  to (new_height, new_width) and rescaled with `x / 127.5 - 1`. Test images
  are resized directly to (new_height, new_width) and rescaled the same way.

  Args:
    orig_height: height training images are resized to before cropping.
    orig_width: width training images are resized to before cropping.
    new_height: height of the images fed to the models.
    new_width: width of the images fed to the models.
    batch_size: number of images per batch.
  """
  def __init__(self,orig_height=286,
               orig_width=286,
               new_height=256,
               new_width=256,
               batch_size=1):
    self.orig_height = orig_height
    self.orig_width = orig_width
    self.new_height = new_height
    self.new_width = new_width
    self.batch_size = batch_size

  @staticmethod
  def load_data():
    """Loads the four `cycle_gan/horse2zebra` splits from TFDS.

    Returns:
      (train_horses, train_zebras, test_horses, test_zebras); each element
      is an (image, label) pair because `as_supervised=True`.
    """
    dataset, _ = tfds.load('cycle_gan/horse2zebra',
                           with_info=True,
                           as_supervised = True)
    train_horses, train_zebras = dataset['trainA'],dataset['trainB']
    test_horses, test_zebras = dataset['testA'], dataset['testB']
    return train_horses, train_zebras, test_horses, test_zebras

  def train_preprocess_image(self,image,label):
    """Resize, random-crop and rescale one training image; `label` is dropped."""
    image = tf.image.resize(image, [self.orig_height,
                                    self.orig_width],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image,size=(self.new_height,
                                              self.new_width,
                                              3))
    image = (tf.cast(image,tf.float32)/127.5) - 1
    return image

  def test_preprocess_image(self,image,label):
    """Resize and rescale one test image (no random crop); `label` is dropped."""
    image = tf.image.resize(image, [self.new_height, self.new_width],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = (tf.cast(image,tf.float32)/127.5) - 1
    return image

  def data_generator(self,images,training=True):
    """Maps, batches, repeats and prefetches a dataset of (image, label) pairs.

    Args:
      images: dataset whose elements are (image, label) pairs.
      training: truthy selects `train_preprocess_image` (random crop);
        anything falsy selects `test_preprocess_image`.

    Returns:
      The repeating, batched dataset of preprocessed images.
    """
    # The two branches were previously swapped, so training data never got
    # the random-crop augmentation and test data did.
    if training:
      preprocess = self.train_preprocess_image
    else:
      preprocess = self.test_preprocess_image
    dataset = images.map(preprocess,
                         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(self.batch_size)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

  def __call__(self):
      """Builds every pipeline.

      Returns:
        (train_dataset, test_dataset, data_train_horses, length_dataset):
        the zipped (horse, zebra) training and test datasets, the horse-only
        training dataset, and the size of the larger training split.
      """
      train_horses, train_zebras, test_horses, test_zebras = self.load_data()
      data_train_horses = self.data_generator(train_horses)
      data_train_zebras = self.data_generator(train_zebras)
      data_test_horses = self.data_generator(test_horses,
                                             training=False)
      data_test_zebras = self.data_generator(test_zebras,
                                             training=False)
      # Both sides repeat forever, so an "epoch" is defined by the larger split.
      length_dataset = max(len(train_horses),len(train_zebras))

      train_dataset = tf.data.Dataset.zip((data_train_horses,data_train_zebras))
      test_dataset = tf.data.Dataset.zip((data_test_horses,data_test_zebras))

      return train_dataset, test_dataset, data_train_horses, length_dataset

In [5]:
class Visualize:
  """Plots original, translated and reconstructed images for a few samples.

  Args:
    data: dataset of image batches that can be passed to the generators.
    generator_model_AB: model applied to the dataset images.
    generator_model_BA: model applied to the output of `generator_model_AB`.
  """
  def __init__(self,data,
               generator_model_AB,
               generator_model_BA):
    self.data = data
    self.generator_model_AB = generator_model_AB
    self.generator_model_BA = generator_model_BA

  @staticmethod
  def _to_uint8(image):
    """Undo the `x / 127.5 - 1` scaling and drop size-1 dimensions."""
    return tf.cast((tf.squeeze(image)+1)*127.5,tf.uint8)

  @staticmethod
  def _show_row(images, title):
    """Draw `images` side by side in one figure, each captioned `title`."""
    plt.figure(figsize=(15,15))
    for column, image in enumerate(images):
      plt.subplot(1,len(images),column+1)
      plt.imshow(image)
      plt.title(title)
    plt.show()

  def test_show_image(self, num_images=4, start_index=29):
    """Shows three figures: originals, AB translations, BA reconstructions.

    Args:
      num_images: how many consecutive dataset elements to display.
      start_index: index of the first element displayed. The defaults
        select elements 29..32, the same ones the earlier `take(i+30)`
        loops ended up showing.
    """
    # Fetch the samples once so all three figures show the same inputs and
    # each generator runs once per displayed image, instead of once per
    # skipped element.
    samples = list(self.data.skip(start_index).take(num_images))
    if not samples:
      return

    translated = [self.generator_model_AB(sample) for sample in samples]
    reconstructed = [self.generator_model_BA(image) for image in translated]

    self._show_row([self._to_uint8(image) for image in samples],
                   "Original Image")
    self._show_row([self._to_uint8(image) for image in translated],
                   "Predicted Image")
    self._show_row([self._to_uint8(image) for image in reconstructed],
                   "Reconstructed Image")

  def __call__(self):
    """Shortcut for `test_show_image()` with its defaults."""
    self.test_show_image()


In [6]:
class Reflection_Pad(tf.keras.layers.Layer):
  """Reflection-pads the two middle axes of a 4-D tensor.

  Args:
    pad: pair `(pad_width, pad_height)`. `pad_height` is applied before and
      after axis 1 and `pad_width` before and after axis 2.
      # assumes channels-last (batch, height, width, channels) input, the
      # Keras Conv2D default — confirm if another data format is used.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).

  Raises:
    ValueError: if `pad` does not contain exactly two values.
  """
  def __init__(self, pad, **kwargs):
    super(Reflection_Pad,self).__init__(**kwargs)
    self.pad = tuple(pad)
    if len(self.pad) != 2:
      raise ValueError(
          "pad must be (pad_width, pad_height), got {!r}".format(self.pad))
  
  def call(self, input_tensor):
    pad_width, pad_height = self.pad
    # Each inner pair is [before, after] for one axis. The previous version
    # used [pad_width, pad_height] on both axes, which is only correct when
    # the two values are equal (as for every pad used in this file).
    return tf.pad(input_tensor, 
                  [[0, 0],
                   [pad_height, pad_height],
                   [pad_width, pad_width],
                   [0, 0]],
                  mode = "REFLECT")

  def get_config(self):
    """Returns the layer config including `pad`, so it can be re-created."""
    config = super(Reflection_Pad,self).get_config()
    config.update({"pad": self.pad})
    return config

In [7]:
class Resnet_Block(tf.keras.layers.Layer):
  """Residual block: two reflect-padded 3x3 convolutions plus a skip connection.

  Args:
    filters: number of output channels of both convolutions.
  """
  def __init__(self,filters):
    super().__init__()
    self.filters = filters
    self.weight_init = tf.keras.initializers.RandomNormal(stddev=0.02)

    def build_conv():
      # 'valid' padding: the preceding reflection pad of 1 keeps the size.
      return tf.keras.layers.Conv2D(self.filters,
                                    kernel_size=3,
                                    padding='valid',
                                    kernel_initializer=self.weight_init)

    # Sub-layers are created in the same order as before.
    self.reflection_pad_1 = Reflection_Pad(pad=(1,1))
    self.reflection_pad_2 = Reflection_Pad(pad=(1,1))
    self.conv_layer_1 = build_conv()
    self.conv_layer_2 = build_conv()
    self.instance_norm_1 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm_2 = tfa.layers.InstanceNormalization(axis=-1)
    self.activation = tf.keras.layers.ReLU()
    self.add_layer = tf.keras.layers.Add()

  def call(self,input_tensor,training=None):
    """Returns `input_tensor` plus the output of the two-convolution branch."""
    # First stage: pad -> conv -> instance norm -> ReLU.
    branch = self.conv_layer_1(self.reflection_pad_1(input_tensor))
    branch = self.activation(self.instance_norm_1(branch,training=training))
    # Second stage: pad -> conv -> instance norm (no activation).
    branch = self.conv_layer_2(self.reflection_pad_2(branch))
    branch = self.instance_norm_2(branch,training=training)
    return self.add_layer([input_tensor,branch])
  
class Generator(tf.keras.models.Model):
  """Encoder / residual-blocks / decoder image-to-image generator.

  Layout (see `call`): reflection pad + 7x7 conv, two stride-2 convs, a stack
  of `Resnet_Block`s, two stride-2 transposed convs and a final 7x7 conv with
  tanh activation producing 3 channels.

  Args:
    filters: channel count of the first convolution; deeper layers use
      multiples of it.
    weights: optional path passed to `load_weights`; skipped when falsy.
    num_residual_blocks: number of `Resnet_Block`s in the middle stack.

  Raises:
    ValueError: if `weights` is given and loading fails; the original error
      is attached as `__cause__`.
  """
  def __init__(self,filters,weights=None,num_residual_blocks=9):
    super(Generator,self).__init__()
    self.filters = filters
    self.num_residual_blocks = num_residual_blocks
    self.weight_init = tf.keras.initializers.RandomNormal(stddev=0.02)
    self.reflection_pad = Reflection_Pad(pad=(3,3))
    
    self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm3 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm4 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm5 = tfa.layers.InstanceNormalization(axis=-1)
    self.relu = tf.keras.layers.ReLU()

    self.conv_layer1 = tf.keras.layers.Conv2D(self.filters,7,padding='valid',
                                             kernel_initializer=self.weight_init)
    self.conv_layer2 = tf.keras.layers.Conv2D(self.filters*2,3,2,padding='same',
                                              kernel_initializer=self.weight_init)
    self.conv_layer3 = tf.keras.layers.Conv2D(self.filters*4,3,2,padding='same',
                                                  kernel_initializer=self.weight_init)
    self.make_resnet_block = self.make_resnet_blocks()
    self.conv_transpose_layer4 = tf.keras.layers.Conv2DTranspose(self.filters*2,
                                                                 3,2,padding='same',
                                                                 kernel_initializer = self.weight_init)
    self.conv_transpose_layer5 = tf.keras.layers.Conv2DTranspose(self.filters,3,2,padding='same',
                                                                 kernel_initializer = self.weight_init)
    self.conv_layer6 = tf.keras.layers.Conv2D(3,7,padding='same',
                                              kernel_initializer=self.weight_init,
                                              activation="tanh")
    if weights:
      try:
          self.load_weights(weights)
      except Exception as error:
          # Keep the ValueError type callers may catch, but preserve the
          # cause and say which path failed instead of raising a bare error.
          raise ValueError(
              "Could not load generator weights from {!r}".format(weights)
          ) from error
  
  def make_resnet_blocks(self):
    """Returns a Sequential of `num_residual_blocks` blocks with 4*filters channels."""
    blocks = [Resnet_Block(self.filters*4)
              for _ in range(self.num_residual_blocks)]
    return tf.keras.Sequential(blocks,name='residual_blocks')

  def call(self,input_tensor,training=None):
    """Runs the generator on a batch of images and returns the tanh output."""
    # Stem: reflection pad 3 + 7x7 'valid' conv.
    x = self.reflection_pad(input_tensor)
    x = self.conv_layer1(x)
    x = self.instance_norm1(x,training=training)
    x = self.relu(x)
    # Two stride-2 downsampling convolutions.
    x = self.conv_layer2(x)
    x = self.instance_norm2(x,training=training)
    x = self.relu(x)
    x = self.conv_layer3(x)
    x = self.instance_norm3(x,training=training)
    x = self.relu(x)
    # Residual stack.
    x = self.make_resnet_block(x,training=training)
    # Two stride-2 transposed convolutions.
    x = self.conv_transpose_layer4(x)
    x = self.instance_norm4(x,training=training)
    x = self.relu(x)
    x = self.conv_transpose_layer5(x)
    x = self.instance_norm5(x,training=training)
    x = self.relu(x)
    # Output projection to 3 channels with tanh.
    x = self.conv_layer6(x)
    return x

In [8]:
class Discriminator(tf.keras.layers.Layer):
  """Convolutional discriminator producing a single-channel patch map.

  Four stride-2 4x4 convolutions (the last three followed by instance
  normalisation) with LeakyReLU activations, then a 4x4 convolution with one
  output channel.

  Args:
    filters: channel count of the first convolution; later layers use
      2x, 4x and 8x this value.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """
  def __init__(self,filters,**kwargs):
    # kwargs used to be accepted and then silently discarded.
    super(Discriminator,self).__init__(**kwargs)
    self.filters = filters
    self.weight_init = tf.keras.initializers.RandomNormal(stddev=0.02)
    self.conv_layer1 = tf.keras.layers.Conv2D(self.filters,
                                              4,2,padding='same',
                                              kernel_initializer=self.weight_init)
    self.instance_norm1 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm2 = tfa.layers.InstanceNormalization(axis=-1)
    self.instance_norm3 = tfa.layers.InstanceNormalization(axis=-1)
    # Not used by call(); kept so the attribute set stays unchanged.
    self.instance_norm4 = tfa.layers.InstanceNormalization(axis=-1)
    # NOTE(review): LeakyReLU is built with its default slope; CycleGAN
    # reference implementations typically use 0.2 — confirm this is intended.
    self.l_relu = tf.keras.layers.LeakyReLU()
    self.conv_layer2 = tf.keras.layers.Conv2D(self.filters*2,
                                              4,2,padding='same',
                                              kernel_initializer=self.weight_init)
    self.conv_layer3 = tf.keras.layers.Conv2D(self.filters*4,
                                              4,2,padding='same',
                                              kernel_initializer=self.weight_init)
    self.conv_layer4 = tf.keras.layers.Conv2D(self.filters*8,4,2,
                                              padding='same',
                                              kernel_initializer=self.weight_init)
    self.conv_layer5 = tf.keras.layers.Conv2D(1,4,padding='same',
                                              kernel_initializer=self.weight_init)
  
  def call(self,input_tensor,training=None):
    """Returns the patch map for a batch of images.

    Args:
      input_tensor: batch of images.
      training: accepted because `Cycle_GAN.train_step` passes it; it is not
        forwarded to any sub-layer.
    """
    x = self.conv_layer1(input_tensor)
    x = self.l_relu(x)
    x = self.conv_layer2(x)
    x = self.instance_norm1(x)
    x = self.l_relu(x)
    x = self.conv_layer3(x)
    x = self.instance_norm2(x)
    x = self.l_relu(x)
    x = self.conv_layer4(x)
    x = self.instance_norm3(x)
    x = self.l_relu(x)
    patch_out = self.conv_layer5(x)
    return patch_out

In [9]:
class Cycle_GAN(tf.keras.models.Model):
  """Two generators and two discriminators trained with a custom `train_step`.

  `generator_model_AB` maps domain A images to domain B and is judged by
  `discriminator_model_B`; `generator_model_BA` and `discriminator_model_A`
  do the reverse.

  Args:
    generator_AB: optional weights path for the A->B generator.
    generator_BA: optional weights path for the B->A generator.
    **kwargs: forwarded to `tf.keras.Model` (e.g. `name`).
  """
  def __init__(self,generator_AB=None,generator_BA=None,**kwargs):
    # kwargs used to be accepted and then silently discarded.
    super(Cycle_GAN,self).__init__(**kwargs)
    self.discriminator_model_A = Discriminator(64)
    self.discriminator_model_B = Discriminator(64)
    self.generator_model_AB = Generator(32,weights=generator_AB)
    self.generator_model_BA = Generator(32,weights=generator_BA)

  def compile(self,generator_optimizer,discriminator_optimizer):
    """Stores the optimizers; both generators share one, as do both discriminators."""
    super(Cycle_GAN, self).compile()
    self.generator_AB_optimizer = generator_optimizer
    self.generator_BA_optimizer = generator_optimizer
    self.discriminator_A_optimizer = discriminator_optimizer
    self.discriminator_B_optimizer = discriminator_optimizer
  
  @staticmethod
  def discriminator_loss(real,fake):
    """Least-squares loss: real outputs pushed to 1, fake outputs to 0, averaged."""
    real_loss = tf.reduce_mean(tf.math.squared_difference(tf.ones_like(real),real))
    fake_loss = tf.reduce_mean(tf.math.squared_difference(tf.zeros_like(fake),fake))
    disc_loss = (real_loss+fake_loss)/2.0
    return disc_loss
  
  @staticmethod
  def identity_loss(real,identity,weight=0.05):
    """Mean absolute error between `real` and `identity`, scaled by `weight`."""
    # NOTE(review): CycleGAN reference implementations usually weight this
    # term at 0.5 * the cycle weight (5 here); confirm 0.05 is intentional.
    loss = tf.reduce_mean(tf.abs(real-identity))
    return loss*weight

  @staticmethod
  def generator_loss(fake):
    """Least-squares loss pushing the discriminator output on fakes to 1."""
    gen_loss = tf.reduce_mean(tf.math.squared_difference(tf.ones_like(fake),fake))
    return gen_loss

  @staticmethod
  def cycle_loss(real,fake,weight=10):
    """Mean absolute error between `real` and its reconstruction, scaled by `weight`."""
    loss = tf.reduce_mean(tf.abs(real-fake))
    return weight*loss

  @staticmethod
  def _apply_pair(optimizer_1,grads_and_vars_1,optimizer_2,grads_and_vars_2):
    """Applies two updates, merging them when both use the same optimizer.

    A shared optimizer must step once per train step: two separate
    `apply_gradients` calls would advance its `iterations` counter twice,
    which doubles the pace of any learning-rate schedule and gives the
    second model a different Adam step index than the first.
    """
    if optimizer_1 is optimizer_2:
      optimizer_1.apply_gradients(grads_and_vars_1 + grads_and_vars_2)
    else:
      optimizer_1.apply_gradients(grads_and_vars_1)
      optimizer_2.apply_gradients(grads_and_vars_2)
  
  def train_step(self,data):
    """Runs one optimisation step on a `(real_a, real_b)` batch pair.

    Returns:
      Dict with `generator_loss_A` (loss of the A->B generator),
      `generator_loss_B` (loss of the B->A generator), `disc_loss_A` and
      `disc_loss_B`.
    """
    real_a, real_b = data

    # Persistent because four separate gradients are taken from one tape.
    with tf.GradientTape(persistent=True) as tape:
      fake_a = self.generator_model_BA(real_b,training=True)
      fake_b = self.generator_model_AB(real_a,training=True)
    
      # Cycle reconstructions: A -> B -> A and B -> A -> B.
      generated_a = self.generator_model_BA(fake_b,training=True)
      generated_b = self.generator_model_AB(fake_a,training=True)

      # Identity mappings: each generator fed an image of its target domain.
      identity_a = self.generator_model_BA(real_a, training=True)
      identity_b = self.generator_model_AB(real_b, training=True)

      disc_real_b = self.discriminator_model_B(real_b,training=True)
      disc_fake_b = self.discriminator_model_B(fake_b,training=True)

      disc_real_a = self.discriminator_model_A(real_a,training=True)
      disc_fake_a = self.discriminator_model_A(fake_a,training=True)
      
      # The "_A" losses belong to generator AB, the "_B" losses to generator BA.
      identity_loss_A = self.identity_loss(real_b,identity_b)
      identity_loss_B = self.identity_loss(real_a,identity_a)
      cycle_loss_A = self.cycle_loss(real_b,generated_b)
      cycle_loss_B = self.cycle_loss(real_a,generated_a)
      gen_loss_A = self.generator_loss(disc_fake_b)
      gen_loss_B = self.generator_loss(disc_fake_a)
      
      generator_loss_A = gen_loss_A+cycle_loss_A+identity_loss_A
      generator_loss_B = gen_loss_B+cycle_loss_B+identity_loss_B

      disc_loss_A = self.discriminator_loss(disc_real_a,disc_fake_a)
      disc_loss_B = self.discriminator_loss(disc_real_b,disc_fake_b)

    generator_AB_variables = self.generator_model_AB.trainable_variables
    generator_BA_variables = self.generator_model_BA.trainable_variables
    discriminator_A_variables = self.discriminator_model_A.trainable_variables
    discriminator_B_variables = self.discriminator_model_B.trainable_variables

    generator_AB_gradients = tape.gradient(generator_loss_A,
                                           generator_AB_variables)
    generator_BA_gradients = tape.gradient(generator_loss_B,
                                           generator_BA_variables)
    discriminator_A_gradients = tape.gradient(disc_loss_A,
                                              discriminator_A_variables)
    discriminator_B_gradients = tape.gradient(disc_loss_B,
                                              discriminator_B_variables)
    # A persistent tape holds its resources until it is dropped.
    del tape

    self._apply_pair(
        self.generator_AB_optimizer,
        list(zip(generator_AB_gradients,generator_AB_variables)),
        self.generator_BA_optimizer,
        list(zip(generator_BA_gradients,generator_BA_variables)))
    self._apply_pair(
        self.discriminator_A_optimizer,
        list(zip(discriminator_A_gradients,discriminator_A_variables)),
        self.discriminator_B_optimizer,
        list(zip(discriminator_B_gradients,discriminator_B_variables)))

    return {"generator_loss_A":generator_loss_A,
            "generator_loss_B":generator_loss_B,
            "disc_loss_A":disc_loss_A,
            "disc_loss_B":disc_loss_B}


In [13]:
class Linear_Decay(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Constant learning rate until `step_decay`, then linear decay to zero.

  Args:
    initial_learning_rate: rate used for every step before `step_decay`.
    total_steps: step at which the rate reaches zero.
    step_decay: step at which the linear decay starts.

  If `total_steps <= step_decay` there is no decay window and the rate stays
  at `initial_learning_rate`.
  """
  def __init__(self, initial_learning_rate, total_steps, step_decay):
    super(Linear_Decay, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.total_steps = total_steps
    self.step_decay = step_decay
    # Created with the initial rate; this class never updates it.
    self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate,
                                             trainable=False, dtype=tf.float32)
  def __call__(self, step):
    """Returns the learning rate for `step` as a float32 tensor."""
    # Pure tensor arithmetic: no Python branch on a possibly symbolic step,
    # and no reference to the undefined bare name `initial_learning_rate`
    # that the old else-branch returned.
    step = tf.cast(step, tf.float32)
    initial = tf.cast(self.initial_learning_rate, tf.float32)
    step_decay = tf.cast(self.step_decay, tf.float32)
    total_steps = tf.cast(self.total_steps, tf.float32)

    elapsed = tf.maximum(step - step_decay, 0.0)     # 0 before decay starts
    span = tf.maximum(total_steps - step_decay, 0.0)  # 0 when no decay window
    # divide_no_nan yields 0 for an empty window, i.e. a constant rate.
    progress = tf.minimum(tf.math.divide_no_nan(elapsed, span), 1.0)
    # Same as initial * (1 - (step - step_decay) / (total_steps - step_decay))
    # inside the window, and clamped at zero after total_steps.
    return initial * (1.0 - progress)

  def get_config(self):
    """Returns the constructor arguments so the schedule can be serialised."""
    return {"initial_learning_rate": self.initial_learning_rate,
            "total_steps": self.total_steps,
            "step_decay": self.step_decay}

In [None]:
def main():
  """Trains the CycleGAN on horse2zebra, saves the weights and plots samples."""
  batch_size = 1
  epochs = 50
  epoch_decay = 100
  learning_rate = 0.0002
  image_width = 256
  image_height = 256

  save_path = "/content/drive/MyDrive/cycle_gan/"
  weights_dir_AB = "/content/drive/MyDrive/cycle_gan_AB/"
  weights_dir_BA = "/content/drive/MyDrive/cycle_gan_BA/"
  
  # Pass the settings above to the pipeline; previously it silently used its
  # own defaults, so changing batch_size or the image size here had no effect.
  data = InputPipeline(new_height=image_height,
                       new_width=image_width,
                       batch_size=batch_size)
  train_dataset, test_dataset, data_train_horses, length_dataset = data()

  cycle_gan = Cycle_GAN()
  input_shape = (batch_size,image_height,image_width,3)
  for model in (cycle_gan.discriminator_model_B,
                cycle_gan.discriminator_model_A,
                cycle_gan.generator_model_AB,
                cycle_gan.generator_model_BA):
    model.build(input_shape)

  # NOTE(review): these schedules are built but never handed to the
  # optimizers below, which train at the constant `learning_rate`. With
  # epochs (50) < epoch_decay (100) the decay would not start during this run.
  generator_learning_rate = Linear_Decay(learning_rate,epochs*length_dataset,
                                        epoch_decay*length_dataset)
  discriminator_learning_rate = Linear_Decay(learning_rate,epochs*length_dataset,
                                            epoch_decay*length_dataset)
  cycle_gan.compile(generator_optimizer=tf.keras.optimizers.Adam(learning_rate,
                                                                 beta_1=0.5),
                    discriminator_optimizer=tf.keras.optimizers.Adam(learning_rate,
                                                                     beta_1=0.5))
  # `period` is deprecated in favour of `save_freq`; 'epoch' already saves
  # once per epoch.
  callbacks = [
      tf.keras.callbacks.ModelCheckpoint(
          save_path,
          save_weights_only=True,
          save_best_only=True,
          monitor='generator_loss_A',
          verbose=1,
          save_freq='epoch'),
  ]
  
  # The dataset is already batched, so `batch_size` must not be passed to
  # fit(); the datasets repeat forever, hence the explicit steps_per_epoch.
  cycle_gan.fit(train_dataset,steps_per_epoch=length_dataset//batch_size,
              epochs=epochs,callbacks=callbacks)
  
  cycle_gan.generator_model_AB.save_weights(weights_dir_AB)
  cycle_gan.generator_model_BA.save_weights(weights_dir_BA)
  
  visualize = Visualize(data_train_horses,
                        cycle_gan.generator_model_AB,
                        cycle_gan.generator_model_BA)
  visualize()

if __name__ == "__main__":
  main()