In [1]:
import tensorflow as tf


In [2]:
def build_generator(latent_dim, output_dim):
    """Build a small fully-connected generator network.

    Args:
        latent_dim: Length of the noise vector the generator consumes.
        output_dim: Length of the flat sample the generator produces.

    Returns:
        A functional ``tf.keras.Model`` named "Generator" mapping a
        ``(batch, latent_dim)`` input to a ``(batch, output_dim)`` output whose
        values lie in [-1, 1] because of the final tanh activation.
    """
    noise_in = tf.keras.Input(shape=(latent_dim,))

    # Layers are created in the same order as they are applied, so Keras
    # assigns the same automatic layer names as a construct-and-call chain.
    hidden_layer = tf.keras.layers.Dense(128, activation="relu")
    sample_layer = tf.keras.layers.Dense(output_dim, activation="tanh")

    sample_out = sample_layer(hidden_layer(noise_in))
    return tf.keras.Model(noise_in, sample_out, name="Generator")

# Generator configuration: length of the noise vector and of each flat sample.
# NOTE(review): 784 == 28 * 28, presumably flattened 28x28 images; confirm
# against the dataset actually used for training.
latent_dim, output_dim = 100, 784
generator = build_generator(latent_dim=latent_dim, output_dim=output_dim)
generator.summary()

In [4]:
def build_discriminator(input_dim):
    """Build a small fully-connected discriminator network.

    Args:
        input_dim: Length of the flat sample the discriminator scores.

    Returns:
        A functional ``tf.keras.Model`` named "Discriminator" mapping a
        ``(batch, input_dim)`` input to a ``(batch, 1)`` sigmoid output,
        i.e. a probability that each sample is real.
    """
    sample_in = tf.keras.Input(shape=(input_dim,))

    features = sample_in
    for units in (256, 128):
        features = tf.keras.layers.Dense(units, activation="relu")(features)
    # Dropout only after the last hidden layer regularises the final features.
    features = tf.keras.layers.Dropout(0.3)(features)

    real_prob = tf.keras.layers.Dense(1, activation="sigmoid")(features)
    return tf.keras.Model(sample_in, real_prob, name="Discriminator")


# The discriminator scores samples of the same flat length the generator emits.
discriminator = build_discriminator(input_dim=output_dim)
discriminator.summary()

In [5]:
# Shared binary cross-entropy. The discriminator ends in a sigmoid, so the
# loss receives probabilities rather than raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=False,
)

def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real samples are scored against a target of 1 and generated samples
    against a target of 0; the two losses are added together.

    Args:
        real_output: Discriminator outputs for real samples (assumed to be
            probabilities, as produced by a sigmoid head).
        fake_output: Discriminator outputs for generated samples.

    Returns:
        Loss on real samples plus loss on generated samples.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Binary cross-entropy loss for the generator.

    The generator does well when the discriminator labels its samples as
    real, so the discriminator's scores on generated samples are compared
    against an all-ones target.

    Args:
        fake_output: Discriminator outputs for generated samples (assumed to
            be probabilities, as produced by a sigmoid head).

    Returns:
        Binary cross-entropy between an all-ones target and ``fake_output``.
    """
    # Target must be shaped like ``fake_output`` itself; the previous
    # ``f-ake_output`` typo evaluated ``f - ake_output`` and raised NameError.
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [6]:
# Separate Adam optimizers so the generator and the discriminator keep
# independent optimizer state.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Alias for the previously misspelled name, kept so any later cell that still
# references it continues to work. Prefer ``discriminator_optimizer``.
discriminator_optmizer = discriminator_optimizer

In [7]:
@tf.function
def train_step(real_images, generator, discriminator, gen_optimizer, disc_optimizer, batch_size, latent_dim):
    """Run one adversarial update of both the generator and the discriminator.

    Both networks run forward under two separate gradient tapes, so each loss
    can be differentiated independently. Each optimizer then updates only its
    own network's trainable variables.

    Args:
        real_images: Batch of real samples fed to the discriminator.
            NOTE(review): presumably shaped (batch, output_dim) and scaled to
            [-1, 1] to match the generator's tanh output; confirm against the
            input pipeline.
        generator: Model mapping ``(batch_size, latent_dim)`` noise to samples.
        discriminator: Model scoring samples as real vs. generated.
        gen_optimizer: Optimizer applied to ``generator.trainable_variables``.
        disc_optimizer: Optimizer applied to
            ``discriminator.trainable_variables``.
        batch_size: Number of noise vectors (generated samples) per step.
        latent_dim: Length of each noise vector.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` with the losses computed this step.
    """
    noise = tf.random.normal([batch_size, latent_dim])

    # Each tape must be an instance (``tf.GradientTape()``); the bare class is
    # not a usable context manager.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # calc. losses
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # calc. gradients
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    # Discriminator gradients were taken w.r.t. the discriminator's variables,
    # so they must be paired with those same variables. Pairing them with the
    # generator's variables would mismatch shapes or update the wrong network.
    disc_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss