# Creating a baseline Generative Adversarial Network

In [None]:
# Building a simple GAN
# The generator is similar to an autoencoder's decoder, and the discriminator is
# a regular binary classifier (it takes an image as input and ends with a Dense
# layer containing a single unit and using the sigmoid activation function). For
# the second phase of each training iteration, we also need the full GAN model
# containing the generator followed by the discriminator.

codings_size = 30

# Generator: codings vector -> 28x28 image. The final sigmoid keeps every pixel in [0, 1].
generator = keras.models.Sequential()
generator.add(keras.layers.Dense(100, activation='selu', input_shape=[codings_size]))
generator.add(keras.layers.Dense(150, activation='selu'))
generator.add(keras.layers.Dense(28 * 28, activation='sigmoid'))
generator.add(keras.layers.Reshape([28, 28]))

# Discriminator: 28x28 image -> a single sigmoid score (binary classifier).
discriminator = keras.models.Sequential()
discriminator.add(keras.layers.Flatten(input_shape=[28, 28]))
discriminator.add(keras.layers.Dense(150, activation='selu'))
discriminator.add(keras.layers.Dense(100, activation='selu'))
discriminator.add(keras.layers.Dense(1, activation='sigmoid'))

# Full GAN: the generator's output is fed straight into the discriminator.
gan = keras.models.Sequential()
gan.add(generator)
gan.add(discriminator)

# Next, we need to compile these models. As the discriminator is a binary classifier,
# we can naturally use the binary cross-entropy loss. The generator will only
# be trained through the gan model, so we do not need to compile it at all.
# The gan model is also a binary classifier, so it can use the binary cross-entropy
# loss. Importantly, the discriminator should not be trained during the second phase,
# so we make it non-trainable before compiling the gan model.

discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')
# Order matters: the discriminator was compiled above while still trainable, and the
# gan model is compiled below with it frozen, so each model keeps its own setting.
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='rmsprop')

# The trainable attribute is taken into account by Keras only when compiling a model,
# so after running this code, the discriminator is trainable if we call its fit() method
# or its train_on_batch() method (which we will be using), while it is not trainable
# when we call these methods on the gan model.

# Since the training loop is unusual, we cannot use the regular fit() method. Instead,
# we will write a custom training loop. For this, we first need to create a Dataset
# to iterate through the images.

batch_size = 32
# Shuffle with a 1,000-image buffer, then batch. drop_remainder=True guarantees that
# every batch holds exactly batch_size images, which the training loop relies on
# when it builds its fixed-length label tensors.
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

# We are now ready to write the training loop. Let's wrap it in a train_gan() function.

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    """Train a GAN by alternating discriminator and generator updates.

    Args:
        gan: Sequential model whose two layers are the generator followed by
            the discriminator.
        dataset: Iterable of batches of real images. Every batch must contain
            exactly ``batch_size`` images (e.g. built with ``drop_remainder=True``),
            because the label tensors below have a fixed length.
        batch_size: Number of real images per batch; the same number of fake
            images is generated at each step.
        codings_size: Dimensionality of the random noise vectors fed to the generator.
        n_epochs: Number of passes over ``dataset``.
    """
    generator, discriminator = gan.layers
    # The labels never change between steps, so build them once.
    y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)  # fake = 0, real = 1
    y2 = tf.constant([[1.]] * batch_size)  # the generator wants its fakes judged "real"
    for _ in range(n_epochs):
        for X_batch in dataset:
            # Phase 1: train the discriminator on a mix of fake and real images.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            # tf.concat requires matching dtypes; cast the real images (e.g. float64
            # NumPy data) to the generator's output dtype.
            X_batch = tf.cast(X_batch, generated_images.dtype)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # Phase 2: train the generator through the GAN with the discriminator
            # frozen, using fresh noise rather than the samples just shown to it.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)

# Run the custom training loop for the default 50 epochs.
train_gan(gan, dataset, batch_size=batch_size, codings_size=codings_size, n_epochs=50)

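# A GAN can only reach a single Nash equilibrium, where neither player benefits from
# changing strategy: the generator produces perfectly realistic images and the
# discriminator is forced to guess (50% real, 50% fake). Unfortunately, there is no
# guarantee that this equilibrium will ever be reached.
# Mode collapse: the generator's outputs gradually become less diverse. The generator
# picks one kind of output that tricks the discriminator, focuses on it, and forgets
# everything else.
# Experience replay is one remedy: store the images produced by the generator in a
# replay buffer (gradually dropping the oldest images) and train the discriminator on
# real images plus fake images drawn from this buffer, rather than only on the current
# generator's outputs, so that it is less likely to overfit the latest generator.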

# Creating a Deep Convolutional GAN

In [None]:
# Deep Convolutional GANs (DCGANs)
# Guidelines for building convolutional GANs that work on larger images:
# Replace any pooling layers with strided convolutions (in the discriminator)
# and transposed convolutions (in the generator).
# Use Batch Normalization in both the generator and the discriminator, except in
# the generator's output layer and the discriminator's input layer.
# Remove fully connected hidden layers for deeper architectures.
# Use ReLU activation in the generator for all layers except the output layer,
# which should use tanh.
# Use leaky ReLU activation in the discriminator for all layers.

codings_size = 100

# DCGAN generator: project the codings to a 7x7x128 feature map, then upsample it
# twice with stride-2 transposed convolutions (7x7 -> 14x14 -> 28x28).
# NOTE: the hidden transposed convolution uses SELU rather than the ReLU suggested
# by the guidelines above.
generator = keras.models.Sequential([
            keras.layers.Dense(7 * 7 * 128, input_shape=[codings_size]),
            keras.layers.Reshape([7, 7, 128]),
            keras.layers.BatchNormalization(),
            keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same',
                                         activation='selu'),
            keras.layers.BatchNormalization(),
            # tanh outputs lie in [-1, 1], so the training images must be rescaled to match.
            keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding='same',
                                         activation='tanh'),
])

# DCGAN discriminator: a regular CNN binary classifier that downsamples with
# stride-2 convolutions instead of pooling layers (28x28 -> 14x14 -> 7x7).
discriminator = keras.models.Sequential([
                keras.layers.Conv2D(64, kernel_size=5, strides=2, padding='same',
                                    activation=keras.layers.LeakyReLU(0.2),
                                    input_shape=[28, 28, 1]),
                keras.layers.Dropout(0.4),
                # Double the depth of the previous conv layer (64 -> 128).
                keras.layers.Conv2D(128, kernel_size=5, strides=2, padding='same',
                                    activation=keras.layers.LeakyReLU(0.2)),
                keras.layers.Dropout(0.4),
                keras.layers.Flatten(),
                keras.layers.Dense(1, activation='sigmoid'),
])

# Chain the DCGAN generator and discriminator into the full GAN model.
gan = keras.models.Sequential(layers=[generator, discriminator])

# The generator takes codings of size 100 and projects them to 6,272 (7*7*128)
# dimensions, then reshapes the result into a 7x7x128 tensor. This tensor is
# batch-normalized and fed to a transposed convolutional layer with a stride of 2,
# which upsamples it from 7x7 to 14x14 and reduces its depth from 128 to 64. The
# result is batch-normalized again and fed to another transposed convolutional layer
# with a stride of 2, which upsamples it from 14x14 to 28x28 and reduces the depth
# from 64 to 1. This layer uses the tanh activation function, so the outputs will
# range from -1 to 1. For this reason, before training the GAN, we need to rescale
# the training set to that same range. We also need to reshape it to add the
# channel dimension.

X_train = X_train.reshape(-1, 28, 28, 1)  # add the channel dimension
X_train = X_train * 2. - 1.  # rescale to [-1, 1] (assumes pixels start in [0, 1])

# The discriminator looks much like a regular CNN for binary classification, except
# that instead of using max pooling layers to downsample the image, we use strided
# convolutions. Lastly, to build the dataset and then compile and train this model, we
# use the exact same code as earlier. After 50 epochs, the generator produces images.

# If you add each image's class as an extra input to both the generator
# and the discriminator, they will both learn what each class looks like,
# and thus you will be able to control the class of each image produced
# by the generator. This is called a Conditional GAN.