In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, LeakyReLU
from keras.optimizers import Adam
import ssl

# Allow MNIST download over unverified context
ssl._create_default_https_context = ssl._create_unverified_context

# Load and preprocess MNIST
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)


# Build the Generator
def build_generator():
    model = Sequential([
        Input(shape=(100,)),
        Dense(256),
        LeakyReLU(negative_slope=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(negative_slope=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(negative_slope=0.2),
        BatchNormalization(momentum=0.8),
        Dense(784, activation='tanh'),
        Reshape((28, 28, 1))
    ])
    return model

# Build the Discriminator
def build_discriminator():
    model = Sequential([
        Input(shape=(28, 28, 1)),
        Flatten(),
        Dense(512),
        LeakyReLU(negative_slope=0.2),
        Dense(256),
        LeakyReLU(negative_slope=0.2),
        Dense(1, activation='sigmoid')
    ])
    return model

# Instantiate models
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator (must be trainable here)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])


# Build the GAN
# Freeze the discriminator for the combined GAN
discriminator.trainable = False

gan_input = Input(shape=(100,))
img = generator(gan_input)
validity = discriminator(img)
gan = Model(gan_input, validity)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))


# GAN Training Function
def train(epochs, batch_size=64, save_interval=100):
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        #  Train Discriminator
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)

        # Enable training on discriminator
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        discriminator.trainable = False  # Re-freeze for GAN


        #  Train Generator

        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, real)

        # Logging
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            save_images(epoch)

# Save Generated Images
def save_images(epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0,1]

    fig, axs = plt.subplots(r, c, figsize=(5, 5))
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[count, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"gan_image_epoch_{epoch}.png")
    plt.close()

# Train the GAN
train(epochs=10000, batch_size=64, save_interval=100)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step 
0 [D loss: 0.7360, acc.: 39.84%] [G loss: 0.5885]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 338ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step
[1m2/2[0m 