# A Simple GAN in TensorFlow

## The Generator Network

In [3]:
import tensorflow as tf

In [4]:
def build_generator(latent_dim, output_shape):
    """Build the GAN generator as a small fully connected Keras model.

    The model maps a latent noise vector of length ``latent_dim`` through two
    ReLU hidden layers (128, then 256 units) to ``output_shape`` output units.
    A ``tanh`` activation squashes the outputs into [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
        output_shape: Number of output units (an int, e.g. 784 below).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named 'Generator'.
    """
    layers = tf.keras.layers
    stack = [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(output_shape, activation='tanh'),
    ]
    # Sequential(list) adds the layers in order, same as repeated .add().
    return tf.keras.Sequential(stack, name='Generator')

# Size of the latent noise vector, and of the flat vector the generator emits
# (784 -- presumably a flattened 28x28 image; confirm against the dataset).
latent_dim, output_dim = 100, 784

generator = build_generator(latent_dim=latent_dim, output_shape=output_dim)
generator.summary()  # print the layer-by-layer architecture

## The Discriminator Network

In [6]:
def build_discriminator(input_shape):
    """Build the GAN discriminator as a small fully connected Keras model.

    The model takes a flat vector of length ``input_shape`` and passes it
    through two ReLU hidden layers (256, then 128 units) followed by dropout.
    It outputs a single sigmoid probability that the input is real.

    Args:
        input_shape: Length of each input vector (an int).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named 'Discriminator'.
    """
    layers = tf.keras.layers
    return tf.keras.Sequential(
        [
            layers.Input(shape=(input_shape,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dropout(0.3),  # active only when called with training=True
            layers.Dense(1, activation='sigmoid'),
        ],
        name='Discriminator',
    )

# The discriminator consumes vectors of the same size the generator produces.
discriminator = build_discriminator(input_shape=output_dim)
discriminator.summary()  # print the layer-by-layer architecture

## Defining the Loss Functions

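Adversarial training requires separate loss functions for the Generator and the Discriminator. We use Binary Cross-Entropy loss (`tf.keras.losses.BinaryCrossentropy`) because the Discriminator performs binary classification (real vs. fake). Because the Discriminator's last layer already applies a sigmoid, its outputs are probabilities rather than logits, so we pass `from_logits=False`.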

In [7]:
# Shared BCE loss. The discriminator ends in a sigmoid, so its outputs are
# already probabilities rather than raw logits -- hence from_logits=False.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator's BCE loss over one real and one fake batch.

    Predictions on real samples are scored against target 1 and predictions
    on generated samples against target 0. The two losses are summed.

    Args:
        real_output: Discriminator outputs for real samples.
        fake_output: Discriminator outputs for generated samples.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator's BCE loss on the discriminator's fake predictions.

    The targets are all ones: the generator is rewarded when the
    discriminator labels its samples as real.
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)

## Optimizers

In [9]:
# One Adam optimizer per network, so each keeps its own moment estimates.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Backward-compatible alias for the previously misspelled name; the training
# loop uses `discriminator_optimizer`.
discriminator_optmizer = discriminator_optimizer

Learning rates might require tuning; sometimes different rates are used for the generator and discriminator.

## The Training Loop

GAN training requires a custom training loop because the updates for the Generator and the Discriminator must be carefully orchestrated, so the standard `model.fit()` is not directly applicable. We use `tf.GradientTape` to compute gradients for each network separately.

### Structure of a Single Training Step

In [10]:
@tf.function
def train_step(real_images, generator, discriminator, gen_optimizer,
               disc_optimizer, batch_size, latent_dim):
    """Run one adversarial update of both the generator and the discriminator.

    The step samples ``batch_size`` normal latent vectors and generates fake
    samples from them. It then scores the real and the fake batch with the
    discriminator. Finally, it applies one optimizer step to each network
    using that network's own loss.

    Args:
        real_images: A batch of real samples matching the discriminator's
            input (presumably ``(batch, output_dim)`` -- confirm against the
            dataset pipeline).
        generator: Keras model mapping noise vectors to samples.
        discriminator: Keras model mapping samples to a real/fake probability.
        gen_optimizer: Optimizer that updates ``generator``'s weights.
        disc_optimizer: Optimizer that updates ``discriminator``'s weights.
        batch_size: Number of noise vectors to sample.
        latent_dim: Length of each noise vector.

    Returns:
        A ``(gen_loss, disc_loss)`` pair of loss tensors for this step.
    """
    noise = tf.random.normal([batch_size, latent_dim])

    # Two independent tapes, so each network's gradients come from its own
    # loss. Both must be tape *instances*: the bare class `tf.GradientTape`
    # cannot be used as a context manager.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gradients_of_generator = gen_tape.gradient(gen_loss, gen_vars)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, disc_vars)

    # Each optimizer must update exactly the variables its gradients were
    # taken with respect to.
    gen_optimizer.apply_gradients(zip(gradients_of_generator, gen_vars))
    disc_optimizer.apply_gradients(zip(gradients_of_discriminator, disc_vars))

    return gen_loss, disc_loss

## Training Epoch

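A full training run iterates `train_step` over every batch of the real dataset, for multiple epochs. The real batches must match what the networks expect. Each sample must be a flat vector of length `output_dim` (the Discriminator's input size). Its values should be scaled to [-1, 1] to match the range of the Generator's `tanh` output.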

In [None]:
# Placeholders: fill these in before running. `dataset` should be an iterable
# of real-sample batches (e.g. a batched tf.data.Dataset) whose samples match
# the discriminator's input size.
epochs = ...
batch_size = ...
dataset = ...

# Running means of the per-batch losses. They are created once and reset at
# the end of every epoch, so each printed value covers exactly one epoch.
epoch_gen_loss_avg = tf.keras.metrics.Mean()
epoch_disc_loss_avg = tf.keras.metrics.Mean()

for epoch in range(epochs):
    for image_batch in dataset:
        gen_loss, disc_loss = train_step(
            image_batch,
            generator,
            discriminator,
            generator_optimizer,
            discriminator_optimizer,
            batch_size,
            latent_dim,
        )
        epoch_gen_loss_avg.update_state(gen_loss)
        epoch_disc_loss_avg.update_state(disc_loss)

    print(
        f"Epoch {epoch + 1}/{epochs}: "
        f"gen_loss={float(epoch_gen_loss_avg.result()):.4f}, "
        f"disc_loss={float(epoch_disc_loss_avg.result()):.4f}"
    )

    # Reset only after the whole epoch. Resetting inside the batch loop would
    # discard the running average after every batch.
    epoch_gen_loss_avg.reset_state()
    epoch_disc_loss_avg.reset_state()