<a href="https://colab.research.google.com/github/Snojj25/Machine_Learning/blob/main/horse2zebra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Path of the horse2zebra dataset archive; presumably on a Google Drive mounted
# at /content/drive (TODO confirm the drive is mounted before running).
zip_path = "/content/drive/MyDrive/Datasets/horse2zebra.zip"

In [None]:
import os
import zipfile
import shutil

In [None]:
# Start from a clean extraction directory. Only this notebook's own folder is
# removed: deleting all of /tmp would also remove files other tools in the
# runtime keep there.
shutil.rmtree('/tmp/horse2zebra', ignore_errors=True)

local_zip = zip_path
# The context manager closes the archive even if extraction fails part-way.
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
  zip_ref.extractall('/tmp/horse2zebra')

parent_dir = "/tmp/horse2zebra/horse2zebra/"

train_path = os.path.join(parent_dir, "train")
test_path = os.path.join(parent_dir, "test")

# Extracted folders: A = horse images, B = zebra images.
train_H = os.path.join(parent_dir, "trainA")
train_Z = os.path.join(parent_dir, "trainB")
test_H = os.path.join(parent_dir, "testA")
test_Z = os.path.join(parent_dir, "testB")

# One destination directory per split and domain.
train_dir_H = os.path.join(train_path, "horse")
train_dir_Z = os.path.join(train_path, "zebra")
test_dir_H = os.path.join(test_path, "horse")
test_dir_Z = os.path.join(test_path, "zebra")

# makedirs also creates train/ and test/; exist_ok keeps re-runs from failing.
for new_dir in (train_dir_H, train_dir_Z, test_dir_H, test_dir_Z):
  os.makedirs(new_dir, exist_ok=True)

# Every destination already exists, so shutil.move() places the source folder
# *inside* it (e.g. train/horse/trainA). flow_from_directory() expects images
# inside a sub-directory, and the moved folder serves as that sub-directory.
for src, dst in ((train_H, train_dir_H), (train_Z, train_dir_Z),
                 (test_H, test_dir_H), (test_Z, test_dir_Z)):
  shutil.move(src, dst)

del parent_dir, train_H, train_Z, test_H, test_Z, train_path, test_path, new_dir, src, dst

In [None]:
# The extraction cell moved trainA *into* the existing train/horse directory,
# so the horse training images live one level deeper: train/horse/trainA.
train_horse_path = os.path.join(train_dir_H, "trainA")

In [None]:
# Plotting the image ======================================
from PIL import Image

image = Image.open(os.path.join(train_horse_path, "n02381460_1025.jpg"))
# Report the sample's file format, (width, height) and colour mode, one per line.
print(image.format, image.size, image.mode, sep="\n")

In [None]:
%matplotlib inline

# Display the sample horse image opened in the previous cell.
plt.imshow(image)

In [None]:
# ===== Generators ==============================

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
def get_train_generators(batch_size):
  '''
    Build the augmented training iterators for both domains.
    Parameters:
        batch_size: number of images per yielded batch
    Returns: (horse_iterator, zebra_iterator). Each yields unlabeled batches
        (class_mode=None) resized to 256x256 and scaled by 1/255, with random
        shifts, rotations, zooms and horizontal flips.
  '''
  augmenter = ImageDataGenerator(rescale=1. / 255,
                                 width_shift_range=0.2,
                                 height_shift_range=0.2,
                                 rotation_range=25,
                                 zoom_range=0.2,
                                 horizontal_flip=True,
                                 fill_mode='reflect')

  # Both domains are read with identical flow settings.
  flow_options = dict(batch_size=batch_size,
                      target_size=(256, 256),
                      class_mode=None)
  horses = augmenter.flow_from_directory(train_dir_H, **flow_options)
  zebras = augmenter.flow_from_directory(train_dir_Z, **flow_options)
  return horses, zebras



def get_test_generators(batch_size):
  '''
    Build the non-augmented test iterators for both domains.
    Parameters:
        batch_size: number of images per yielded batch
    Returns: (horse_iterator, zebra_iterator). Each yields unlabeled batches
        (class_mode=None) resized to 256x256 and scaled by 1/255.
  '''
  rescaler = ImageDataGenerator(rescale=1. / 255)
  flow_options = dict(batch_size=batch_size,
                      target_size=(256, 256),
                      class_mode=None)
  # Tuple elements are evaluated left to right: horses first, then zebras.
  return (rescaler.flow_from_directory(test_dir_H, **flow_options),
          rescaler.flow_from_directory(test_dir_Z, **flow_options))


In [None]:
# Defining the model components =====================================

In [None]:
!pip install -U tensorflow-addons

In [None]:
from tensorflow.keras.layers import Layer 
from tensorflow.keras.layers import Dense, Flatten, Dropout, Conv2D, Conv2DTranspose, ReLU, LeakyReLU, Activation
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers 
# from tensorflow.keras.activations import tanh

from tensorflow_addons.layers import InstanceNormalization

In [None]:
class ResidualBlock(Model):

  '''
    ResidualBlock Class:
    Performs two 3x3 "same"-padded convolutions, each followed by an instance
    normalization (the first also by a ReLU), and adds the block input to the
    result to form the residual block output.
    Values:
        hidden_channels: number of filters of both convolutions; must equal the
            channel count of the input, otherwise the residual addition fails.
    '''

  def __init__(self, hidden_channels=256, **kwargs):
    super(ResidualBlock, self).__init__(**kwargs)
    self.conv1 = Conv2D(hidden_channels, kernel_size=3, padding="same")
    self.conv2 = Conv2D(hidden_channels, kernel_size=3, padding="same")
    # InstanceNormalization has a learnable scale/offset by default, so each
    # normalization position gets its own layer instead of sharing one.
    # The first keeps its original attribute name so existing checkpoints
    # still restore it.
    self.instanceNorm = InstanceNormalization()
    self.activation = ReLU()
    self.instanceNorm2 = InstanceNormalization()

  def call(self, inputs):
    '''
    Forward pass: conv -> norm -> ReLU -> conv -> norm, plus the skip connection.
    Parameters:
        inputs: image tensor; with Keras' default channels-last data format this
            is (batch size, height, width, channels), and channels must equal
            hidden_channels.
    '''
    x = self.conv1(inputs)
    x = self.instanceNorm(x)
    x = self.activation(x)
    x = self.conv2(x)
    x = self.instanceNorm2(x)
    return inputs + x



In [None]:
class ContractingBlock(Model):
  '''
    ContractingBlock Class
    Downsamples with a strided convolution (Keras default "valid" padding),
    optionally instance-normalizes the result, then applies the activation.
    Values:
        hidden_channels: number of output channels
        kernel_size: convolution kernel size
        stride: stride used in both spatial dimensions
        use_bn: whether to apply InstanceNormalization after the convolution
        activation: "relu" selects ReLU; any other value selects LeakyReLU
    '''

  def __init__(self, hidden_channels, kernel_size=3, stride=2, use_bn=True, activation="relu", **kwargs):
    super().__init__(**kwargs)
    self.conv1 = Conv2D(filters=hidden_channels, kernel_size=kernel_size, strides=(stride, stride))
    if activation == "relu":
      self.relu = ReLU()
    else:
      self.relu = LeakyReLU()
    if use_bn:
      self.instanceNorm = InstanceNormalization()
    self.use_bn = use_bn

  def call(self, inputs):
    '''
      Forward pass: strided conv, optional instance norm, activation.
      Parameters:
          inputs: image tensor (channels-last under Keras' default data format)
    '''
    downsampled = self.conv1(inputs)
    normed = self.instanceNorm(downsampled) if self.use_bn else downsampled
    return self.relu(normed)


In [None]:
class ExpandingBlock(Model):
  '''
    ExpandingBlock Class
    Upsamples with a strided transposed convolution (Keras default "valid"
    padding), optionally instance-normalizes the result, then applies the
    activation.
    Values:
        hidden_channels: number of output channels
        kernel_size: transposed-convolution kernel size
        stride: stride used in both spatial dimensions
        use_bn: whether to apply InstanceNormalization after the convolution
        activation: "relu" selects ReLU; any other value selects LeakyReLU
    '''

  def __init__(self, hidden_channels=64, kernel_size=3, stride=2, use_bn=True, activation="relu", **kwargs):
    super().__init__(**kwargs)
    self.convT1 = Conv2DTranspose(filters=hidden_channels, kernel_size=kernel_size, strides=(stride, stride))
    if activation == "relu":
      self.activation = ReLU()
    else:
      self.activation = LeakyReLU()
    if use_bn:
      self.instanceNorm = InstanceNormalization()
    self.use_bn = use_bn

  def call(self, inputs):
    upsampled = self.convT1(inputs)
    normed = self.instanceNorm(upsampled) if self.use_bn else upsampled
    return self.activation(normed)

In [None]:
class FeatureMapLayer(Model):
  '''
    FeatureMapLayer Class
    A single 7x7, stride-1, "same"-padded convolution that maps its input to
    `out_channels` channels without changing the spatial size.
    '''

  def __init__(self, out_channels, **kwargs):
    super().__init__(**kwargs)
    self.conv = Conv2D(filters=out_channels, kernel_size=7, strides=(1, 1), padding="same")

  def call(self, inputs):
    return self.conv(inputs)

In [None]:
# Defining the Model

In [None]:
class Generator(Model):
  '''
    Generator Class
    7x7 feature-map convolution -> two contracting (downsampling) blocks ->
    six residual blocks -> two expanding (upsampling) blocks -> 7x7
    convolution to 3 channels -> tanh.
    Values:
        hidden_channels: channel count after the first feature-map layer
    '''

  def __init__(self, hidden_channels=64, **kwargs):
    super(Generator, self).__init__(**kwargs)
    self.up_feature = FeatureMapLayer(out_channels= hidden_channels, name="up_feature")
    self.contracting1 = ContractingBlock(hidden_channels= 2*hidden_channels, name="contracting1")
    self.contracting2 = ContractingBlock(hidden_channels= 4*hidden_channels, name="contracting2")
    self.res1 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual1")
    self.res2 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual2")
    self.res3 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual3")
    self.res4 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual4")
    self.res5 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual5")
    self.res6 = ResidualBlock(hidden_channels= 4*hidden_channels, name="residual6")
    self.expanding1 = ExpandingBlock(hidden_channels= 2*hidden_channels, name="expanding1")
    self.expanding2 = ExpandingBlock(hidden_channels= hidden_channels, kernel_size=4, name="expanding2")
    self.down_feature = FeatureMapLayer(out_channels= 3, name="down_feature")
    self.tanh = Activation("tanh")

  def call(self, inputs):
    '''
    Propagates the images through the contracting layers, then the six
    residual layers and finally through the expanding layers, returning the
    output image (values in (-1, 1) from the tanh).
    '''
    x = self.up_feature(inputs)
    x = self.contracting1(x)
    x = self.contracting2(x)
    # Each of the six residual blocks is a separate layer with its own weights
    # and must be applied once, in order (not res1 six times).
    for res_block in (self.res1, self.res2, self.res3,
                      self.res4, self.res5, self.res6):
      x = res_block(x)
    x = self.expanding1(x)
    x = self.expanding2(x)
    x = self.down_feature(x)
    return self.tanh(x)


In [None]:
# gen = Generator()

# gen.build((32, 256,256,3))
# gen.get_layer("residual2").build((32,256,256,256))
# gen.get_layer("residual3").build((32,256,256,256))
# gen.get_layer("residual4").build((32,256,256,256))
# gen.get_layer("residual5").build((32,256,256,256))
# gen.get_layer("residual6").build((32,256,256,256))
# gen.summary()

# inp = tf.random.uniform([1,256,256,3], dtype=tf.dtypes.float32)
# out = gen(inp)

In [None]:

class Discriminator(Model):

  '''
    Discriminator Class
    Structured like the contracting path of the U-Net, the discriminator will
    output a matrix of values classifying corresponding portions of the image as real or fake.
    The final 1x1 convolution plus sigmoid gives one value in (0, 1) per spatial position.
    Parameters:
        hidden_channels: the initial number of discriminator convolutional filters;
            the three contracting blocks use 2x, 4x and 8x that many.
        **kwargs: forwarded to keras Model (e.g. name), like the other blocks.
    '''

  def __init__(self, hidden_channels=64, **kwargs):
    super(Discriminator, self).__init__(**kwargs)

    # Use hidden_channels here too; a hard-coded 64 would break the 2x/4x/8x
    # scaling whenever a non-default width is requested.
    self.up_feature = FeatureMapLayer(hidden_channels)
    self.contracting1 = ContractingBlock(hidden_channels=2*hidden_channels, kernel_size=4, activation="LReLU")
    self.contracting2 = ContractingBlock(hidden_channels=4*hidden_channels, kernel_size=4, activation="LReLU")
    self.contracting3 = ContractingBlock(hidden_channels=8*hidden_channels, kernel_size=4, activation="LReLU")
    self.final = Conv2D(filters=1, kernel_size=1)
    self.sigmoid = Activation("sigmoid")

  def call(self, inputs):
    x = self.up_feature(inputs)
    x = self.contracting1(x)
    x = self.contracting2(x)
    x = self.contracting3(x)
    x = self.final(x)
    return self.sigmoid(x)




In [None]:
# disc = Discriminator()

# disc.build((32,256,256,3))

# disc.summary()

In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError

In [None]:
# test1 = tf.random.uniform([1, 30, 30])
# test2 = tf.random.uniform([1, 30, 30])

# loss = MeanSquaredError(reduction=tf.keras.losses.Reduction.AUTO)
# print(loss(test1,test2))

In [None]:
def get_disc_loss(real_X, fake_X, disc_X, adv_critereon):
  '''
    Return the loss of the discriminator given inputs.
    Parameters:
        real_X: the real images from pile X
        fake_X: the generated images of class X
        disc_X: the discriminator for class X; takes images and returns real/fake class X
            prediction matrices
        adv_critereon: the adversarial loss function, called Keras-style as
            adv_critereon(y_true, y_pred); compares the discriminator predictions
            with the target labels (1 = real, 0 = fake) and returns a loss to minimize
    Returns: the mean of the real-image loss and the fake-image loss.
    '''
  real_pred = disc_X(real_X)
  fake_pred = disc_X(fake_X)
  # Keras losses take (y_true, y_pred), so the targets come first. The order is
  # irrelevant for symmetric losses (MSE/MAE) but matters for e.g. cross-entropy.
  real_loss = adv_critereon(tf.ones_like(real_pred), real_pred)
  fake_loss = adv_critereon(tf.zeros_like(fake_pred), fake_pred)
  return (real_loss + fake_loss) / 2
  

In [None]:
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    '''
    Return the adversarial loss of the generator given inputs
    (and the generated images for testing purposes).
    Parameters:
        real_X: the real images from pile X
        disc_Y: the discriminator for class Y; takes images and returns real/fake class Y
            prediction matrices
        gen_XY: the generator for class X to Y; takes images and returns the images
            transformed to class Y
        adv_criterion: the adversarial loss function, called Keras-style as
            adv_criterion(y_true, y_pred); returns a loss to minimize
    Returns: (generator_loss, fake_Y) where fake_Y = gen_XY(real_X).
    '''
    fake_Y = gen_XY(real_X)
    fake_Y_pred = disc_Y(fake_Y)
    # The generator wants its fakes classified as real (target 1). Keras losses
    # take (y_true, y_pred), so the target goes first.
    generator_loss = adv_criterion(tf.ones_like(fake_Y_pred), fake_Y_pred)
    return generator_loss, fake_Y



In [None]:
def get_identity_loss(real_X, gen_YX, identity_criterion):
    '''
    Return the identity loss of the generator given inputs.
    Parameters:
        real_X: the real images from pile X
        gen_YX: the generator for class Y to X; takes images and returns the images
            transformed to class X. Fed images already in X, it should leave them unchanged.
        identity_criterion: the identity loss function, called Keras-style as
            identity_criterion(y_true, y_pred) with the real images as y_true and
            gen_YX(real_X) as y_pred; returns a loss to minimize
    '''
    identity_X = gen_YX(real_X)
    # Keras losses take (y_true, y_pred): the real images are the target.
    return identity_criterion(real_X, identity_X)

In [None]:
def get_cycle_consistancy_loss(real_X, fake_Y, gen_YX, cycle_critereon):
    '''
    Return the cycle consistency loss of the generator given inputs.
    Parameters:
        real_X: the real images from pile X
        fake_Y: the generated images of class Y (i.e. real_X mapped X->Y)
        gen_YX: the generator for class Y to X; takes images and returns the images
            transformed to class X
        cycle_critereon: the cycle consistency loss function, called Keras-style as
            cycle_critereon(y_true, y_pred) with real_X as y_true and the
            reconstruction gen_YX(fake_Y) as y_pred; returns a loss to minimize
    '''
    cycle_X = gen_YX(fake_Y)
    # Keras losses take (y_true, y_pred): the original images are the target.
    return cycle_critereon(real_X, cycle_X)


In [None]:
def get_generator_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B,
                       adv_critereon, identity_critereon, cycle_critereon,
                       lambda_identity=0.1, lambda_cycle=10):
  '''
  Return the loss of the generator given inputs.
    Parameters:
        real_A: the real images from pile A
        real_B: the real images from pile B
        gen_AB: the generator for class A to B; takes images and returns the images
            transformed to class B
        gen_BA: the generator for class B to A; takes images and returns the images
            transformed to class A
        disc_A: the discriminator for class A; takes images and returns real/fake class A
            prediction matrices
        disc_B: the discriminator for class B; takes images and returns real/fake class B
            prediction matrices
        adv_critereon: the adversarial loss function; takes the true labels and the
            discriminator predictions and returns an adversarial loss (to minimize)
        identity_critereon: the reconstruction loss function used for the identity loss;
            takes two sets of images and returns their pixel differences (to minimize)
        cycle_critereon: the cycle consistency loss function; takes the real images and
            those images put through one generator and back through the other,
            and returns the cycle consistency loss (to minimize)
        lambda_identity: the weight of the identity loss
        lambda_cycle: the weight of the cycle-consistency loss
    Returns: (generator_loss, fake_A, fake_B) with fake_A = gen_BA(real_B)
        and fake_B = gen_AB(real_A).
    '''
  adv_loss_A, fake_B = get_gen_adversarial_loss(real_X=real_A, disc_Y=disc_B, gen_XY=gen_AB, adv_criterion=adv_critereon)
  adv_loss_B, fake_A = get_gen_adversarial_loss(real_X=real_B, disc_Y=disc_A, gen_XY=gen_BA, adv_criterion=adv_critereon)

  # Identity loss: a generator fed an image already in its *target* domain
  # should return it unchanged, so A images go through B->A and B images
  # through A->B.
  identity_loss_A = get_identity_loss(real_X=real_A, gen_YX=gen_BA, identity_criterion=identity_critereon)
  identity_loss_B = get_identity_loss(real_X=real_B, gen_YX=gen_AB, identity_criterion=identity_critereon)

  # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the inputs.
  cycle_loss_A = get_cycle_consistancy_loss(real_X=real_A, fake_Y=fake_B, gen_YX=gen_BA, cycle_critereon=cycle_critereon)
  cycle_loss_B = get_cycle_consistancy_loss(real_X=real_B, fake_Y=fake_A, gen_YX=gen_AB, cycle_critereon=cycle_critereon)

  adversarial_loss = adv_loss_A + adv_loss_B
  identity_loss = identity_loss_A + identity_loss_B
  cycle_consistancy_loss = cycle_loss_A + cycle_loss_B

  generator_loss = adversarial_loss + lambda_identity * identity_loss + lambda_cycle * cycle_consistancy_loss
  return generator_loss, fake_A, fake_B



In [None]:
# Training the networks

In [None]:
# Training hyper-parameters (discriminator learns 20x slower than the generators).
batch_size, gen_lr, disc_lr = 1, 2e-4, 1e-5

# Mean-squared-error adversarial loss; L1 (mean absolute error) for the
# identity and cycle-consistency terms. Each criterion is its own instance.
adv_critereon = MeanSquaredError()
identity_critereon, cycle_critereon = MeanAbsoluteError(), MeanAbsoluteError()



In [None]:
def _load_weights_if_saved(model, name):
  '''Restore `model` from the checkpoint <models dir>/<name> when its ".index" file exists.'''
  weights_path = os.path.join("/content/drive/MyDrive/Models/horse2zebra", name)
  if os.path.exists(weights_path + ".index"):
    print("loaded weigths for " + name)
    model.load_weights(weights_path)


# Generator models and optimizer
gen_HZ = Generator()
gen_ZH = Generator()
gen_optim = Adam(learning_rate=gen_lr, beta_1=0.5, name="gen_optim")

# Resume the generators from saved weights, if any
_load_weights_if_saved(gen_HZ, "gen_HZ")
_load_weights_if_saved(gen_ZH, "gen_ZH")


# Discriminator models and optimizer
disc_H = Discriminator()
disc_Z = Discriminator()
disc_optim = Adam(learning_rate=disc_lr, beta_1=0.5, name="disc_optim")

# Resume the discriminators from saved weights, if any
_load_weights_if_saved(disc_H, "disc_H")
_load_weights_if_saved(disc_Z, "disc_Z")



In [None]:
#  CALLBACKS

def change_lr(optimizer, factor=1.001):
    '''
    Decay the optimizer's learning rate in place by dividing it by `factor`.
      Parameters:
        optimizer: a Keras optimizer with a non-schedule learning rate.
        factor: divisor applied to the current learning rate (default 1.001,
          as originally hard-coded).
    '''
    # `learning_rate` is the canonical attribute; `lr` is only a legacy alias.
    new_lr = optimizer.learning_rate / factor
    tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
    
def earlyStopping(loss_list, min_delta=0.1, patience=8):
    '''
    Decide whether training should stop early.
      Parameters:
        loss_list: per-epoch losses, oldest first.
        min_delta: how far below the first loss of the window a loss in the
          window must fall to count as an improvement.
        patience: window length, in epochs.
      Returns: True when no loss among the last `patience` entries is at least
        `min_delta` below the first of them; always False until
        2 * patience losses have been recorded.
    '''
    if len(loss_list) // patience < 2:
        return False
    window = loss_list[-patience:]
    baseline = window[0]
    return not any(loss <= baseline - min_delta for loss in window)
 
def save_best_weights(model, loss_list, model_path):
    '''
    Save the model's weights only when the newest loss is the lowest so far.
      Parameters:
        model: object exposing save_weights(path); the model to checkpoint.
        loss_list: history of this model's losses; the last entry is the
          current run.
        model_path: file name appended to the checkpoint directory.
    '''
    is_best = loss_list[-1] == min(loss_list)
    if is_best:
        target = os.path.join("/content/drive/MyDrive/Models/horse2zebra", model_path)
        model.save_weights(target)

In [None]:
# Define the gradient functions

@tf.function
def get_gen_loss_and_grads(real_H, real_Z):
  '''
  Compute the joint generator loss and its gradients for both generators.
      Parameters:
        real_H: batch of real class-H (horse) images
        real_Z: batch of real class-Z (zebra) images
      Returns: (gen_loss, grads_HZ, grads_ZH, fake_H, fake_Z), where fake_H and
        fake_Z are the translated images, reused for the discriminator step.
  '''
  with tf.GradientTape() as tape:
    gen_loss, fake_H, fake_Z = get_generator_loss(
        real_A=real_H, real_B=real_Z,
        gen_AB=gen_HZ, gen_BA=gen_ZH,
        disc_A=disc_H, disc_B=disc_Z,
        adv_critereon=adv_critereon,
        identity_critereon=identity_critereon,
        cycle_critereon=cycle_critereon)

  # One tape is enough: a nested list of sources yields one gradient list per
  # generator, in the same structure.
  grads_HZ, grads_ZH = tape.gradient(
      gen_loss, [gen_HZ.trainable_variables, gen_ZH.trainable_variables])
  return gen_loss, grads_HZ, grads_ZH, fake_H, fake_Z


@tf.function
def get_disc_loss_and_grads(real_H, real_Z, fake_H, fake_Z):
  '''
  Returns the loss and gradients for both discriminators.
      Parameters:
        real_H: real examples of class H
        real_Z: real examples of class Z
        fake_H: generated images of class H
        fake_Z: generated images of class Z
      Returns: (disc_loss, grads_H, grads_Z), where disc_loss is the mean of the
        two discriminators' losses.
  '''
  with tf.GradientTape() as tape_H, tf.GradientTape() as tape_Z:
    disc_loss_H = get_disc_loss(real_H, fake_H, disc_H, adv_critereon)
    disc_loss_Z = get_disc_loss(real_Z, fake_Z, disc_Z, adv_critereon)

  # Take gradients only after leaving the tape context. Inside it, the
  # still-recording tape_Z would also record tape_H's entire backward pass.
  grads_H = tape_H.gradient(disc_loss_H, disc_H.trainable_variables)
  grads_Z = tape_Z.gradient(disc_loss_Z, disc_Z.trainable_variables)

  disc_loss = (disc_loss_H + disc_loss_Z) / 2
  return disc_loss, grads_H, grads_Z



In [None]:
# Training the network

In [None]:
# Use the configured batch size instead of repeating it as a literal.
datagen_H, datagen_Z = get_train_generators(batch_size)

In [None]:
 !pip install colorama
 
import time
# from tqdm import tqdm, 
from tqdm.notebook import tqdm
from colorama import Fore

In [None]:
epochs = 10

epoch_gen_losses = []
epoch_disc_losses = []
mean_epoch_gen_loss = 0
mean_epoch_disc_loss = 0

# Snapshots collected during training for later inspection
fake_H_imgs = []
fake_Z_imgs = []
temp_gen_losses = []
temp_disc_losses = []

# The directory iterators cycle forever, so one epoch is defined as one pass
# over the zebra iterator: exactly len(datagen_Z) batches.
steps_per_epoch = len(datagen_Z)

print("Training started!")
for epoch in range(epochs):
  mean_batch_gen_loss = 0
  mean_batch_disc_loss = 0

  print("Epoch number "+ str(epoch))

  start_time = time.time()
  for i, (real_H, real_Z) in tqdm(enumerate(zip(datagen_H, datagen_Z)), total=steps_per_epoch, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.BLUE, Fore.RESET)):

    # Get the loss and gradient values from the generators
    gen_loss, grads_HZ, grads_ZH, fake_H, fake_Z = get_gen_loss_and_grads(real_H, real_Z)

    # Update BOTH generators, each with its own gradients, in one
    # apply_gradients call. The optimizer's step counter then advances once
    # per batch, and it sees the full variable list at once (newer Keras
    # optimizers reject variables they were not built with).
    gen_optim.apply_gradients(
        list(zip(grads_HZ, gen_HZ.trainable_variables)) +
        list(zip(grads_ZH, gen_ZH.trainable_variables)))

    # Get the loss and gradient values from the discriminators
    disc_loss, grads_H, grads_Z = get_disc_loss_and_grads(real_H, real_Z, fake_H, fake_Z)

    # Same single-call update for the two discriminators
    disc_optim.apply_gradients(
        list(zip(grads_H, disc_H.trainable_variables)) +
        list(zip(grads_Z, disc_Z.trainable_variables)))

    # Record losses and some fake images periodically to monitor progress
    if i % 100 == 0:
      temp_gen_losses.append(gen_loss)
      temp_disc_losses.append(disc_loss)
    if i % 250 == 0:
      fake_H_imgs.append(fake_H)
      fake_Z_imgs.append(fake_Z)
      print("i = " + str(i) + ":  gen loss = " + str(gen_loss))
      print("i = " + str(i) + ":  disc loss = " + str(disc_loss))

    # Accumulate the mean loss over this epoch's batches
    mean_batch_gen_loss += gen_loss.numpy()/steps_per_epoch
    mean_batch_disc_loss += disc_loss.numpy()/steps_per_epoch

    # i is zero-based, so the epoch is complete after index steps_per_epoch - 1.
    # Stopping at i == steps_per_epoch would process one batch too many.
    if i + 1 >= steps_per_epoch:
      break

  # == End of the batch loop =============================

  # Save the losses for the current epoch
  epoch_gen_losses.append(mean_batch_gen_loss)
  mean_epoch_gen_loss += mean_batch_gen_loss/epochs

  epoch_disc_losses.append(mean_batch_disc_loss)
  mean_epoch_disc_loss += mean_batch_disc_loss/epochs

  # Update the learning rate
  change_lr(gen_optim)
  change_lr(disc_optim)

  # Check for early stopping on the generator
  if earlyStopping(epoch_gen_losses):
    break

  # Save the model weights if needed
  save_best_weights(gen_HZ, epoch_gen_losses, "gen_HZ")
  save_best_weights(gen_ZH, epoch_gen_losses, "gen_ZH")

  save_best_weights(disc_H, epoch_disc_losses, "disc_H")
  save_best_weights(disc_Z, epoch_disc_losses, "disc_Z")

  end_time = time.time()
  print("time: " + str(end_time-start_time))

  plt.plot(temp_gen_losses)
  plt.ylabel("generator losses")
  plt.show()

  plt.plot(temp_disc_losses)
  plt.ylabel("Discriminator losses")
  plt.show()

In [None]:
# Preview the 16th generated-zebra snapshot collected during training, with
# the singleton batch axis removed.
plt.imshow(fake_Z_imgs[15].numpy().squeeze())

In [None]:
# Draw one batch of real horse images from the training iterator, to compare
# the horse discriminator's output on real vs. generated images.
reals = next(datagen_H)

In [None]:
# disc_H's per-patch predictions on the first generated-horse snapshot vs. a real batch.
pred_fake, pred_real = disc_H(fake_H_imgs[0]), disc_H(reals)