In [None]:
import tensorflow as tf
# Record which TensorFlow build this notebook ran against.
tf_version = tf.__version__
print(tf_version)

In [None]:
# Choose where the training models are placed: a GPU device name if TensorFlow
# can see any GPU, otherwise the CPU. The API is `list_physical_devices`
# (plural); the singular name does not exist and raised AttributeError,
# leaving `device_name` undefined for the later `tf.device(device_name)` call.
if tf.config.list_physical_devices("GPU"):
    device_name = tf.test.gpu_device_name()
else:
    device_name = "CPU"

print(f"Using device: {device_name}")

In [None]:
import tensorflow_datasets as tfds
import numpy as np

def make_generator_network(
        num_hidden_layers=1,
        num_hidden_units=100,
        num_output_units=784,
):
    """Build an MLP generator that maps a noise vector to a flattened image.

    Each hidden block is ``Dense(num_hidden_units, use_bias=False)`` followed by
    ``LeakyReLU(alpha=0.2)``. The final layer has ``num_output_units`` units with
    a tanh activation, so every generated value lies in (-1, 1).

    Args:
        num_hidden_layers: number of hidden Dense/LeakyReLU blocks.
        num_hidden_units: width of each hidden Dense layer.
        num_output_units: size of the generated (flattened) sample.

    Returns:
        An unbuilt ``tf.keras.Sequential`` model; call ``build`` or invoke it on
        a noise batch to create its weights.
    """
    model = tf.keras.Sequential()
    for _ in range(num_hidden_layers):
        model.add(tf.keras.layers.Dense(num_hidden_units, use_bias=False))
        model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    # The output width must be `num_output_units` (the flattened image size),
    # not the hidden width; otherwise fake samples cannot be fed to the
    # discriminator or reshaped back into images.
    model.add(tf.keras.layers.Dense(num_output_units, activation="tanh"))
    return model


def make_discriminator_network(
    num_hidden_layers=1,
    num_hidden_units=100,
    num_output_units=1,
):
    """Build an MLP discriminator that scores an input vector with raw logits.

    Every hidden block is ``Dense(num_hidden_units, use_bias=False)`` ->
    ``LeakyReLU(alpha=0.2)`` -> ``Dropout(rate=0.5)``. The last layer is a plain
    linear ``Dense(num_output_units)`` with no activation, so outputs are logits
    (the notebook pairs them with ``BinaryCrossentropy(from_logits=True)``).

    Args:
        num_hidden_layers: number of hidden blocks.
        num_hidden_units: width of each hidden Dense layer.
        num_output_units: number of output logits.

    Returns:
        An unbuilt ``tf.keras.Sequential`` model.
    """
    stack = []
    for _ in range(num_hidden_layers):
        stack.extend([
            tf.keras.layers.Dense(num_hidden_units, use_bias=False),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.Dropout(rate=0.5),
        ])
    # Linear head: leave the sigmoid to the loss function.
    stack.append(tf.keras.layers.Dense(num_output_units, activation=None))
    return tf.keras.Sequential(stack)

In [None]:
# Hyperparameters for a first, untrained sanity check of both networks.
image_size = (28, 28)
z_size = 20
mode_z = "uniform"
gen_hidden_layers, gen_hidden_size = 1, 100
disc_hidden_layers, disc_hidden_size = 1, 100
tf.random.set_seed(1)

# Both networks work on images flattened to a single vector.
flat_image_size = np.prod(image_size)

gen_model = make_generator_network(
    num_hidden_layers=gen_hidden_layers,
    num_hidden_units=gen_hidden_size,
    num_output_units=flat_image_size,
)
gen_model.build(input_shape=(None, z_size))
gen_model.summary()

disc_model = make_discriminator_network(
    num_hidden_layers=disc_hidden_layers,
    num_hidden_units=disc_hidden_size,
    num_output_units=1,
)
disc_model.build(input_shape=(None, flat_image_size))
disc_model.summary()

In [None]:
# Fetch MNIST through TensorFlow Datasets. `as_dataset()` without a split
# argument returns every split keyed by name (the notebook reads
# mnist["train"]). shuffle_files=False keeps the file read order fixed.
# Order matters: the data must be prepared before it can be loaded.
mnist_bldr = tfds.builder("mnist")
mnist_bldr.download_and_prepare()
mnist = mnist_bldr.as_dataset(shuffle_files=False)

def preprocess(ex, mode="uniform"):
    """Turn a dataset example into a ``(noise, image)`` training pair.

    The ``"image"`` entry is converted to float32 with
    ``tf.image.convert_image_dtype`` (which rescales integer pixels to [0, 1]),
    flattened to a vector, and mapped to [-1, 1] to match the range of the
    generator's tanh output. A fresh noise vector of length ``z_size`` (read
    from the module-level global) is drawn for each example.

    Args:
        ex: a dataset element with an ``"image"`` entry.
        mode: ``"uniform"`` for U(-1, 1) noise or ``"normal"`` for N(0, 1) noise.

    Returns:
        Tuple ``(input_z, image)``.

    Raises:
        ValueError: if ``mode`` is not ``"uniform"`` or ``"normal"``. Because
            ``mode`` is a Python value, this fires when ``Dataset.map`` traces
            the function, i.e. immediately rather than mid-training.
    """
    image = ex["image"]
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.reshape(image, [-1])
    image = image * 2 - 1.0

    if mode == "uniform":
        input_z = tf.random.uniform(shape=(z_size,), minval=-1.0, maxval=1.0)
    elif mode == "normal":
        input_z = tf.random.normal(shape=(z_size,))
    else:
        # Previously fell through and failed with UnboundLocalError on input_z.
        raise ValueError(f"Invalid mode: {mode}")
    return input_z, image

# Pair every training image with a freshly drawn noise vector. The result must
# be bound back to `mnist_train`: the next cell batches `mnist_train` and
# unpacks `(input_z, input_real)`, which fails on the raw dict examples that
# were left there when the mapped dataset was assigned to `mnist_test` instead.
mnist_train = mnist["train"]
mnist_train = mnist_train.map(lambda ex: preprocess(ex, mode=mode_z))

In [None]:
# Grab one batch of 32 (noise, image) pairs to smoke-test the untrained models.
mnist_train = mnist_train.batch(32, drop_remainder=True)
input_z, input_real = next(iter(mnist_train))

# Noise -> generator -> fake images; then score the real and fake batches
# (evaluated left to right: real first, then fake).
g_output = gen_model(input_z)
d_logits_real, d_logits_fake = disc_model(input_real), disc_model(g_output)

In [None]:
# Discriminator outputs are logits, so the loss applies the sigmoid itself.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Generator loss: fake logits are scored against all-ones ("real") targets.
g_labels_real = tf.ones_like(d_logits_real)
g_labels_fake = tf.zeros_like(d_logits_fake)  # not used in this cell
g_loss = loss_fn(g_labels_real, d_logits_fake)
print("Generator loss: ", g_loss.numpy())

# Discriminator losses: real logits vs. ones, fake logits vs. zeros.
d_labels_real, d_labels_fake = tf.ones_like(d_logits_real), tf.zeros_like(d_logits_fake)
d_loss_real, d_loss_fake = (
    loss_fn(d_labels_real, d_logits_real),
    loss_fn(d_labels_fake, d_logits_fake),
)
print("Discriminator loss: Real", d_loss_real.numpy(), "Fake", d_loss_fake.numpy())

In [None]:
import time

# Training hyperparameters.
num_epochs, batch_size = 100, 64
image_size, z_size, mode_z = (28, 28), 20, "uniform"
gen_hidden_layers, gen_hidden_size = 1, 100
disc_hidden_layers, disc_hidden_size = 1, 100

# Seed the TensorFlow and NumPy global RNGs.
tf.random.set_seed(1)
np.random.seed(1)

# One fixed noise batch, fed to the generator after every epoch so samples
# are comparable across epochs.
if mode_z not in ("uniform", "normal"):
    raise ValueError(f"Invalid mode_z: {mode_z}")
fixed_z = (
    tf.random.uniform(shape=(batch_size, z_size), minval=-1.0, maxval=1.0)
    if mode_z == "uniform"
    else tf.random.normal(shape=(batch_size, z_size))
)

def create_samples(g_model, input_z):
    """Generate images from a noise batch and rescale them to [0, 1].

    Runs ``g_model`` in inference mode (``training=False``), reshapes the flat
    outputs to ``(num_samples, *image_size)`` (``image_size`` is the module-level
    global) and maps the [-1, 1] tanh range back to [0, 1].

    The leading dimension is inferred from the generator output (``-1``) instead
    of being pinned to the global ``batch_size``, so any number of noise vectors
    can be rendered; for a ``batch_size``-row input the result is unchanged.

    Args:
        g_model: generator returning flattened images.
        input_z: noise batch, presumably ``(num_samples, z_size)``.

    Returns:
        Tensor of shape ``(num_samples, *image_size)``.
    """
    g_output = g_model(input_z, training=False)
    images = tf.reshape(g_output, (-1, *image_size))
    return (images + 1) / 2.0

# Training input pipeline: (noise, image) pairs, shuffled with a 10k buffer and
# grouped into full batches of `batch_size` (the final partial batch is dropped).
mnist_train = (
    mnist['train']
    .map(lambda ex: preprocess(ex, mode=mode_z))
    .shuffle(10000)
    .batch(batch_size, drop_remainder=True)
)

# Fresh generator and discriminator for the training run, created and built
# inside the `tf.device(device_name)` scope chosen earlier.
flat_image_size = np.prod(image_size)
with tf.device(device_name):
    gen_model = make_generator_network(
        num_hidden_layers=gen_hidden_layers,
        num_hidden_units=gen_hidden_size,
        num_output_units=flat_image_size,
    )
    gen_model.build(input_shape=(None, z_size))

    disc_model = make_discriminator_network(
        num_hidden_layers=disc_hidden_layers,
        num_hidden_units=disc_hidden_size,
    )
    disc_model.build(input_shape=(None, flat_image_size))

# Logit-based BCE shared by both players.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Generator and discriminator each get their own Adam instance with identical settings.
adam_kwargs = dict(learning_rate=0.001, beta_1=0.5, beta_2=0.9)
g_optimizer = tf.keras.optimizers.Adam(**adam_kwargs)
d_optimizer = tf.keras.optimizers.Adam(**adam_kwargs)

# Per-epoch histories filled by the training loop.
# NOTE(review): `add_d_vals` is presumably meant to read `all_d_vals`; the name is
# kept because other cells may refer to it.
all_losses, add_d_vals, epoch_samples = [], [], []

# Alternating GAN training: one generator update, then one discriminator update
# per batch. History lists record per-batch losses and mean D probabilities,
# plus one sample grid (from `fixed_z`) per epoch.
start_time = time.time()
for epoch in range(1, num_epochs + 1):
    epoch_losses, epoch_d_vals = [], []
    for input_z, input_real in mnist_train:
        # Generator step: push D to label generated images as real (target 1).
        # Only the fake logits are needed here, so D is not run on real images.
        with tf.GradientTape() as g_tape:
            g_output = gen_model(input_z, training=True)
            d_logits_fake = disc_model(g_output, training=True)

            g_labels_real = tf.ones_like(d_logits_fake)
            g_loss = loss_fn(g_labels_real, d_logits_fake)

        g_grads = g_tape.gradient(g_loss, gen_model.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grads, gen_model.trainable_variables))

        # Discriminator step: real -> 1, fake -> 0. `g_output` was produced
        # outside this tape, so it acts as a constant input and gradients are
        # taken w.r.t. the discriminator's variables only.
        with tf.GradientTape() as d_tape:
            d_logits_real = disc_model(input_real, training=True)
            d_logits_fake = disc_model(g_output, training=True)

            d_labels_real = tf.ones_like(d_logits_real)
            d_labels_fake = tf.zeros_like(d_logits_fake)
            d_loss_real = loss_fn(d_labels_real, d_logits_real)
            d_loss_fake = loss_fn(d_labels_fake, d_logits_fake)
            # The total must be formed inside the tape: an op executed after
            # the `with` block is not recorded, and differentiating it yields
            # None gradients (apply_gradients then fails).
            d_loss = d_loss_real + d_loss_fake

        d_grads = d_tape.gradient(d_loss, disc_model.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grads, disc_model.trainable_variables))

        epoch_losses.append(
            (g_loss.numpy(), d_loss.numpy(), d_loss_real.numpy(), d_loss_fake.numpy())
        )
        d_probs_real = tf.reduce_mean(tf.sigmoid(d_logits_real))
        d_probs_fake = tf.reduce_mean(tf.sigmoid(d_logits_fake))
        epoch_d_vals.append((d_probs_real.numpy(), d_probs_fake.numpy()))

    all_losses.append(epoch_losses)
    # Record this epoch's D probabilities (previously the list was appended to itself).
    add_d_vals.append(epoch_d_vals)
    # The timer is never reset, so this is cumulative time since training began.
    print(f"Epoch {epoch} completed | elapsed: {time.time() - start_time:.2f} seconds")

    epoch_samples.append(create_samples(gen_model, fixed_z).numpy())