<a href="https://colab.research.google.com/github/filsto/GAN/blob/main/GAN_keras_inf%C3%A9rieur.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Create the discriminator: maps a 28x28x1 image to a single unnormalized
# score (logit). It ends in a linear Dense(1) with no activation, so the loss
# used to train it must be configured with from_logits=True.
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        # Two stride-2 'same' convolutions downsample 28x28 -> 14x14 -> 7x7.
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        # Collapse the 7x7x128 feature map to a 128-vector (max over space).
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name='discriminator',
)

# Create the generator: maps a latent vector of length `latent_dim` to a
# 28x28x1 image whose pixel values lie in [0, 1] (sigmoid output).
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # Project the latent vector to 7*7*128 coefficients, which are then
        # reshaped into a 7x7x128 feature map.
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        # Two stride-2 'same' transposed convolutions upsample
        # 7x7 -> 14x14 -> 28x28.
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        # Reduce to a single channel; the sigmoid keeps pixels in [0, 1].
        # NOTE(review): real training images presumably need the same [0, 1]
        # scaling so the discriminator sees comparable inputs - confirm in the
        # data pipeline.
        layers.Conv2D(1, (7, 7), padding='same', activation='sigmoid'),
    ],
    name='generator',
)


In [None]:
class GAN(keras.Model):
    """Generative adversarial network with a custom training step.

    Wraps a discriminator and a generator so that ``fit()`` performs one
    discriminator update followed by one generator update per batch.

    Label convention used in ``train_step``: 1 means "fake" (generated) and
    0 means "real".
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-models and the size of the latent space.

        Args:
            discriminator: model that scores a batch of images.
            generator: model that maps latent vectors of length
                ``latent_dim`` to images.
            latent_dim: length of the random normal vectors fed to
                ``generator``.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Configure the model for training.

        Args:
            d_optimizer: optimizer that updates the discriminator weights.
            g_optimizer: optimizer that updates the generator weights.
            loss_fn: callable invoked as ``loss_fn(labels, predictions)`` for
                both networks. If the discriminator outputs raw logits (no
                final activation), it should be configured with
                ``from_logits=True``.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: a batch of real images, or a tuple whose first
                element is that batch (e.g. ``(x, y)`` yielded by a dataset).

        Returns:
            dict with the discriminator loss under ``"d_loss"`` and the
            generator loss under ``"g_loss"`` for this batch.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        # --- Discriminator step ----------------------------------------------
        # Sample random points in the latent space and decode them to fake
        # images. This happens outside the tape, so this step never
        # back-propagates into the generator.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(random_latent_vectors)

        # Stack fakes first, then reals, and label them 1 (fake) / 0 (real)
        # in the same order.
        combined_images = tf.concat([generated_images, real_images], axis=0)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add uniform noise in [0, 0.05) to the labels - a label-noise trick
        # commonly used to stabilize GAN training.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # --- Generator step --------------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Label the fakes as 0 ("real") so the generator's loss is low when it
        # fools the discriminator.
        misleading_labels = tf.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        # Gradients are taken and applied only for the generator's weights; the
        # discriminator is not updated in this step.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}