# Notebook cell In [2]:
import tensorflow as tf
import numpy as np
import math
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Input
from tensorflow.keras.optimizers import Adam

# Load MNIST; only the training images are used (labels and the test split
# are discarded).
(X_train, _), (_, _) = mnist.load_data()

# Scale pixel values from [0, 255] to [-1, 1] so they match the range of the
# generator's tanh output.
X_train = X_train.astype(np.float32)
X_train /= 127.5
X_train -= 1.

# Flatten every 28x28 image into a 784-dimensional vector.
X_train = X_train.reshape(-1, 784)

# Size of the noise vector fed to the generator.
latent_dim = 100

# ---------------------
# Build the Generator
# ---------------------
# MLP mapping a latent noise vector to a flattened 28x28 image (784 values).
# Three widening hidden stages (256 -> 512 -> 1024), each a Dense layer followed
# by LeakyReLU and batch normalization; the tanh output lies in [-1, 1].
generator = Sequential([
    Dense(256, input_dim=latent_dim),
    LeakyReLU(alpha=0.2),
    BatchNormalization(momentum=0.8),
    Dense(512),
    LeakyReLU(alpha=0.2),
    BatchNormalization(momentum=0.8),
    Dense(1024),
    LeakyReLU(alpha=0.2),
    BatchNormalization(momentum=0.8),
    Dense(784, activation='tanh'),
])

# ---------------------
# Build the Discriminator
# ---------------------
# MLP mapping a flattened image (784 values) to a single sigmoid output,
# trained with binary cross-entropy against real (1) / fake (0) labels.
discriminator = Sequential([
    Dense(512, input_dim=784),
    LeakyReLU(alpha=0.2),
    Dense(256),
    LeakyReLU(alpha=0.2),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# ---------------------
# Build the Combined Model
# ---------------------
# Stacked model noise -> generator -> discriminator, used to train the generator.
# The discriminator is marked non-trainable *before* the stacked model is built
# and compiled, so updates through `combined` are meant to touch only the
# generator's weights.
discriminator.trainable = False

z = Input(shape=(latent_dim,))
img = generator(z)
valid = discriminator(img)

combined = Model(inputs=z, outputs=valid)
combined.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
)

# ---------------------
# Training Loop
# ---------------------
# NOTE: despite the name, each "epoch" below is one training step on a single
# randomly sampled mini-batch, not a full pass over the dataset.
epochs = 10000       # total training iterations
batch_size = 64
sample_interval = 1000  # print progress every this many iterations

# Number of full mini-batches in one pass over the data. The loop below draws
# batches at random and does not use this value; it is kept for reference.
num_batches = math.floor(X_train.shape[0] / batch_size)

# The target labels are the same on every iteration, so build them once
# instead of re-allocating them inside the loop.
valid = np.ones((batch_size, 1))   # label 1 = "real"
fake = np.zeros((batch_size, 1))   # label 0 = "generated"

for epoch in range(epochs):
    # ---------------------
    #  Train Discriminator
    # ---------------------
    # Select a random batch of real images (sampled with replacement)
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    # Generate a batch of fake images from random noise.
    # predict_on_batch runs one inference-mode forward pass on the batch;
    # calling predict() here would add its per-call setup cost and progress
    # output on every single iteration.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_imgs = generator.predict_on_batch(noise)

    # Train the discriminator on the real and the fake batch separately, then
    # average the two results. Each result is indexed below as
    # [loss, accuracy], matching the discriminator's compiled metrics.
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator
    # ---------------------
    # Draw fresh noise and train through the combined model with "real" labels:
    # the generator is rewarded when the discriminator is fooled.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = combined.train_on_batch(noise, valid)

    # Periodically report progress
    if epoch % sample_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

# ---------------------
# Sample and display some generated images after training
# ---------------------
# Draw 16 latent vectors and run them through the trained generator.
noise = np.random.normal(0, 1, (16, latent_dim))
gen_imgs = generator.predict(noise)

# Map the tanh output range [-1, 1] to [0, 1] for display, then restore the
# 28x28 image layout.
gen_imgs = (0.5 * gen_imgs + 0.5).reshape(16, 28, 28)

# Show the samples in a 4x4 grid without axes.
plt.figure(figsize=(6, 6))
for position, sample in enumerate(gen_imgs, start=1):
    plt.subplot(4, 4, position)
    plt.imshow(sample, cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()


# Notebook output: KeyboardInterrupt (execution of the cell was interrupted)