In [1]:
import os
import numpy as np
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt

from keras.api.optimizers import Adam
from keras.api.models import Sequential
from keras.api.losses import BinaryCrossentropy
from keras.api.layers import (
    Input,
    Dense,
    LeakyReLU,
    Reshape,
    Conv2D,
    Conv2DTranspose,
    BatchNormalization,
    LeakyReLU,
    Flatten,
)

In [2]:
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all GPU
# memory up front. The setting is applied to every visible GPU, since TensorFlow
# refuses mixed memory-growth settings across GPUs. On CPU-only machines the list
# is empty and the loop is a no-op, instead of indexing [0] and raising IndexError.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Output directory for the per-epoch generated-image grids.
os.makedirs("images", exist_ok=True)

In [10]:
# Load and preprocess the MNIST dataset
def load_data() -> np.ndarray:
    """Load the MNIST training images and scale them to [-1, 1].

    Labels and the test split are discarded. The shape of the raw training array
    is printed as a quick sanity check.

    Returns:
        float32 array of shape (num_images, 28, 28, 1) with values in [-1, 1].
    """
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    print(train_images.shape)
    # Add an explicit single-channel axis so each image is (28, 28, 1).
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype(
        "float32"
    )
    # Assumes raw pixels in [0, 255]; maps them onto [-1, 1].
    train_images = (train_images - 127.5) / 127.5
    return train_images

In [21]:
def generate_and_save_images(epoch: int, test_input=None) -> None:
    """Render a 5x5 grid of generator samples and save it to ``images/``.

    The file is written as ``images/image_at_epoch_{epoch:02d}.png``. The figure
    is closed after saving, so it is not left open (or displayed inline) once the
    function returns.

    Args:
        epoch: Epoch number used in the output filename.
        test_input: Optional latent batch of shape (25, 100). When omitted, a
            fresh standard-normal batch is drawn on every call, so successive
            grids show different latent points. Pass a fixed tensor to follow the
            same samples across epochs.
    """
    if test_input is None:
        test_input = tf.random.normal([25, 100])
    predictions = generator(test_input, training=False)

    fig, axes = plt.subplots(5, 5, figsize=(8, 8))
    fig.subplots_adjust(wspace=0.05, hspace=0.05)

    for i, ax in enumerate(axes.flat):
        # Undo the [-1, 1] scaling so intensities are back on a 0-255 scale.
        sns.heatmap(
            predictions[i, :, :, 0] * 127.5 + 127.5,
            cbar=False,
            ax=ax,
            xticklabels=False,
            yticklabels=False,
        )
        ax.axis("off")

    fig.savefig(
        f"images/image_at_epoch_{epoch:02d}.png", bbox_inches="tight", pad_inches=0
    )
    # This runs once per epoch; unclosed pyplot figures would accumulate in memory.
    plt.close(fig)


def plot_losses(gen_losses: list[float], disc_losses: list[float]) -> None:
    """Plot per-step generator and discriminator losses on one labelled axis.

    Args:
        gen_losses: Generator loss for each training step. Entries may be plain
            floats or scalar tensors; each is converted with ``float``.
        disc_losses: Discriminator loss for each training step, same convention.
    """
    gen_curve = np.array([float(loss) for loss in gen_losses])
    disc_curve = np.array([float(loss) for loss in disc_losses])

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=gen_curve, label="Generator Loss", color="blue", ax=ax)
    sns.lineplot(data=disc_curve, label="Discriminator Loss", color="red", ax=ax)
    ax.set(title="GAN training losses", xlabel="Steps", ylabel="Loss")
    ax.legend()
    plt.show()

In [22]:
# Create the generator model
def create_generator() -> Sequential:
    """Build the generator: a 100-d latent vector -> a 28x28x1 image.

    Spatial path: 7x7x256 -> 7x7x128 (stride 1) -> 14x14x64 (stride 2)
    -> 28x28x1 (stride 2). The final tanh bounds outputs to [-1, 1].
    """
    stack = [
        Input((100,)),
        Dense(256 * 7 * 7),
        LeakyReLU(0.01),
        Reshape((7, 7, 256)),
    ]
    # Two upsampling stages, each Conv2DTranspose -> BatchNormalization -> LeakyReLU.
    for n_filters, stride in ((128, 1), (64, 2)):
        stack.extend(
            [
                Conv2DTranspose(n_filters, kernel_size=5, strides=stride, padding="same"),
                BatchNormalization(),
                LeakyReLU(0.01),
            ]
        )
    # Project to a single channel; tanh output.
    stack.append(
        Conv2DTranspose(1, kernel_size=5, strides=2, padding="same", activation="tanh")
    )
    return Sequential(stack)


# Create the discriminator model
def create_discriminator() -> Sequential:
    """Build the discriminator: a 28x28x1 image -> one real/fake logit.

    Three stride-2 convolutions downsample 28 -> 14 -> 7 -> 4. The closing
    Dense(1) has no activation, so the model outputs raw logits.
    """
    stack = [
        Input((28, 28, 1)),
        Conv2D(32, kernel_size=3, strides=2, padding="same"),
        LeakyReLU(0.01),
    ]
    # Deeper stages insert BatchNormalization between convolution and activation.
    for n_filters in (64, 128):
        stack += [
            Conv2D(n_filters, kernel_size=3, strides=2, padding="same"),
            BatchNormalization(),
            LeakyReLU(0.01),
        ]
    stack += [Flatten(), Dense(1)]
    return Sequential(stack)

In [23]:
def discriminator_loss(real_output: tf.Tensor, fake_output: tf.Tensor) -> tf.Tensor:
    """Discriminator loss: BCE on real logits vs. 1 plus BCE on fake logits vs. 0.

    Relies on a module-level ``cross_entropy`` loss object existing before the
    first call.

    Args:
        real_output: Discriminator outputs for a batch of real images.
        fake_output: Discriminator outputs for a batch of generated images.

    Returns:
        The summed real and fake loss as a tensor.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


def generator_loss(fake_output: tf.Tensor) -> tf.Tensor:
    """Generator loss: BCE of the discriminator's outputs on fakes against label 1.

    Rewards the generator when the discriminator scores its samples as real.
    Relies on a module-level ``cross_entropy`` loss object existing before the
    first call.

    Args:
        fake_output: Discriminator outputs for a batch of generated images.

    Returns:
        The loss as a tensor.
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [24]:
# Training step
@tf.function
def train_step(images: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
    """Run one simultaneous update of the generator and the discriminator.

    Uses the module-level ``generator``, ``discriminator``,
    ``generator_optimizer`` and ``discriminator_optimizer``.

    Args:
        images: A batch of real images (any batch size).

    Returns:
        ``(gen_loss, disc_loss)`` for this step.
    """
    # Size the noise from the actual batch, not a global batch size. The last
    # batch of an epoch may be smaller, and a Python int read inside a
    # tf.function would be frozen at trace time anyway.
    noise = tf.random.normal([tf.shape(images)[0], 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each tape differentiates only its own network's loss w.r.t. its own weights.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables
    )

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables)
    )

    return gen_loss, disc_loss


# Training loop
def train(dataset: tf.data.Dataset, epochs: int) -> None:
    """Train the GAN for ``epochs`` passes over ``dataset``.

    Each step's losses are appended to the module-level ``gen_losses`` and
    ``disc_losses`` lists. After every epoch, the last step's losses are printed
    and a sample grid is saved via ``generate_and_save_images``.

    Args:
        dataset: Iterable of real-image batches.
        epochs: Number of full passes over ``dataset``.

    Raises:
        ValueError: If ``dataset`` yields no batches in an epoch.
    """
    for epoch in range(1, epochs + 1):
        gen_loss = disc_loss = None
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

        # Without this guard an empty dataset surfaces as an UnboundLocalError below.
        if gen_loss is None:
            raise ValueError("dataset yielded no batches; nothing to train on")

        print(
            f"Epoch {epoch}, Gen Loss: {float(gen_loss):.4f}, "
            f"Disc Loss: {float(disc_loss):.4f}"
        )
        generate_and_save_images(epoch)
        generate_and_save_images(epoch + 1)

In [27]:
batch_size = 64
epochs = 50

# Per-step loss history, appended to by the training loop.
gen_losses = []
disc_losses = []

train_images = load_data()
# Shuffle buffer spans the whole training set; a smaller buffer only
# approximates a uniform shuffle.
buffer_size = train_images.shape[0]
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(buffer_size)
    .batch(batch_size)
)

In [None]:
# One Adam instance per network so their moment estimates stay independent.
# NOTE(review): DCGAN setups commonly pair lr=2e-4 with beta_1=0.5; Keras's Adam
# defaults beta_1 to 0.9 — consider tuning.
generator_optimizer, discriminator_optimizer = (
    Adam(learning_rate=0.0002),
    Adam(learning_rate=0.0002),
)

# The discriminator ends in Dense(1) with no activation, so the loss takes logits.
cross_entropy = BinaryCrossentropy(from_logits=True)

# Evaluated left to right: the generator is built before the discriminator.
generator, discriminator = create_generator(), create_discriminator()

train(train_dataset, epochs)

In [None]:
# Visualise the loss history recorded during training.
plot_losses(gen_losses=gen_losses, disc_losses=disc_losses)