# In [None]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

# ========================
# 1. Load the MNIST Dataset
# ========================

# Load the dataset. Only the training images are kept; the labels and the
# test split returned by load_data() are discarded.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Preprocess the data: rescale to [-1, 1] so real samples share the range of
# the generator's tanh output.
x_train = x_train.astype('float32') / 255.0  # Normalize to [0, 1]
x_train = x_train * 2 - 1  # Normalize to [-1, 1]
x_train = x_train.reshape(-1, 784)  # Flatten each image to a 784-vector

batch_size = 100

# Create a tf.data.Dataset.
# The shuffle buffer spans the whole training set: a buffer smaller than the
# dataset only shuffles within a sliding window, so batches would stay
# correlated with the on-disk sample order.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=x_train.shape[0]).batch(batch_size)

# ========================
# 2. Define the Generator Network
# ========================

def build_generator(input_size=100, output_size=784):
    """Build the generator: a three-layer fully connected network.

    Args:
        input_size: Length of the noise vector fed to the network.
        output_size: Length of each generated (flattened) sample.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model with hidden widths 256
        and 512 (ReLU) and a tanh output layer, so outputs lie in [-1, 1].
    """
    model = tf.keras.Sequential([
        # Declare the input shape with an explicit Input object; passing
        # ``input_dim`` to the first layer is the legacy form and emits a
        # warning on current Keras releases.
        tf.keras.Input(shape=(input_size,)),
        tf.keras.layers.Dense(256),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(512),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(output_size, activation='tanh'),
    ])
    return model

# ========================
# 3. Define the Discriminator Network
# ========================

def build_discriminator(input_size=784):
    """Build the discriminator: a three-layer fully connected classifier.

    Args:
        input_size: Length of each (flattened) input sample.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model with hidden widths 512
        and 256 (LeakyReLU with slope 0.2) and a single sigmoid output unit,
        i.e. one probability-like score in (0, 1) per sample.
    """
    model = tf.keras.Sequential([
        # Declare the input shape with an explicit Input object; passing
        # ``input_dim`` to the first layer is the legacy form and emits a
        # warning on current Keras releases.
        tf.keras.Input(shape=(input_size,)),
        tf.keras.layers.Dense(512),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dense(256),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    return model

# ========================
# 4. Initialize Networks, Loss Function, and Optimizers
# ========================

# Build both networks with their default layer sizes
generator = build_generator()
discriminator = build_discriminator()

# Binary cross-entropy on probabilities (the discriminator ends in a sigmoid,
# so its outputs are not logits)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

# One Adam optimizer per network, both using the same learning rate
lr = 0.0002
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

# Directory that receives the sample grids and the loss plot
os.makedirs('generated_images', exist_ok=True)

# 64 noise vectors drawn once and reused at every checkpoint, so the saved
# grids are directly comparable across epochs
fixed_noise = tf.random.normal(shape=[64, 100])

# ========================
# 5. Training Loop
# ========================

num_epochs = 100
G_losses = []
D_losses = []
img_list = []
epochs_to_save = [1, 10, 50, 100]

# Width of the generator's noise input, taken from fixed_noise so the noise
# drawn during training always matches the noise used for the saved samples.
noise_dim = fixed_noise.shape[-1]


def save_images(images, epoch):
    """Save a square grid of generated samples as a PNG.

    Args:
        images: Array indexed as ``images[k, :, :, 0]``; the first
            ``floor(sqrt(len(images))) ** 2`` entries are drawn in
            row-major order.
        epoch: Epoch number, used only to build the output file name
            ``generated_images/generated_epoch_<epoch>.png``.
    """
    grid_size = int(np.sqrt(images.shape[0]))
    # squeeze=False keeps ``axs`` two-dimensional even for a 1x1 grid.
    fig, axs = plt.subplots(
        grid_size, grid_size, figsize=(grid_size, grid_size), squeeze=False
    )
    for idx, ax in enumerate(axs.flat):
        ax.imshow(images[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f'generated_images/generated_epoch_{epoch}.png')
    plt.close(fig)  # release the figure so repeated saves do not accumulate


@tf.function
def train_step(real_images):
    """Run one discriminator update followed by one generator update.

    Args:
        real_images: Batch of real samples, already flattened and scaled.

    Returns:
        ``(d_loss, g_loss)``: the discriminator loss (real + fake terms) and
        the generator loss for this batch, as scalar tensors.
    """
    # Use the batch's own size rather than a fixed constant.
    n_samples = tf.shape(real_images)[0]
    noise = tf.random.normal([n_samples, noise_dim])

    # ---------------------
    # Train Discriminator
    # ---------------------
    # The generator pass sits outside the tape: only discriminator gradients
    # are taken here, so recording the generator's ops would be wasted work.
    fake_images = generator(noise, training=True)
    with tf.GradientTape() as disc_tape:
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        # Real samples are labelled 1, generated samples 0.
        d_loss_real = cross_entropy(tf.ones_like(real_output), real_output)
        d_loss_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
        d_loss = d_loss_real + d_loss_fake

    gradients_of_discriminator = disc_tape.gradient(
        d_loss, discriminator.trainable_variables
    )
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables)
    )

    # ---------------------
    # Train Generator
    # ---------------------
    # Re-run both networks on the same noise so the generator is scored by
    # the discriminator as it stands after the update above.
    with tf.GradientTape() as gen_tape:
        fake_images = generator(noise, training=True)
        fake_output = discriminator(fake_images, training=True)

        # Generator wants discriminator to believe generated images are real
        g_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    gradients_of_generator = gen_tape.gradient(
        g_loss, generator.trainable_variables
    )
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables)
    )

    return d_loss, g_loss


# Training loop
for epoch in range(1, num_epochs + 1):
    for real_images in train_dataset:
        d_loss, g_loss = train_step(real_images)

    # Record losses.
    # NOTE: these are the losses of the epoch's last batch, not an average
    # over the epoch.
    G_losses.append(g_loss.numpy())
    D_losses.append(d_loss.numpy())

    # Save generated images at specified epochs
    if epoch in epochs_to_save:
        # Generate images from fixed noise
        fake_images = generator(fixed_noise, training=False)
        fake_images = tf.reshape(fake_images, [-1, 28, 28, 1])
        fake_images = (fake_images + 1) / 2.0  # Rescale images to [0,1]

        save_images(fake_images.numpy(), epoch)
        print(f'Epoch [{epoch}/{num_epochs}]  Loss D: {d_loss.numpy():.4f}, Loss G: {g_loss.numpy():.4f}')

# ========================
# 6. Plot Loss Curves
# ========================

# Draw both per-epoch loss series on one set of axes, write the figure to
# disk, then display it.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
epoch_axis = range(1, num_epochs + 1)
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(epoch_axis, G_losses, label="Generator")
loss_ax.plot(epoch_axis, D_losses, label="Discriminator")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_fig.savefig('generated_images/loss_curves.png')
plt.show()

# ========================
# 7. Visualize Generated Images
# ========================

def show_saved_images(epochs):
    """Display the saved sample grid for each epoch in ``epochs``.

    For every entry, reads ``generated_images/generated_epoch_<epoch>.png``
    and shows it in its own 8x8-inch figure with the axes hidden.

    Args:
        epochs: Iterable of epoch numbers whose grids should be shown.
    """
    for saved_epoch in epochs:
        grid = mpimg.imread(f'generated_images/generated_epoch_{saved_epoch}.png')
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.axis('off')
        ax.set_title(f'Generated Images at Epoch {saved_epoch}')
        ax.imshow(grid)
        plt.show()


# Show the grids written during training for the checkpoint epochs
show_saved_images(epochs_to_save)
