In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

In [2]:
# Hyperparameters
latent_dim = 100  # Size of the random noise vector fed to the generator
epochs = 10000    # Number of training iterations; each iteration trains on ONE batch, not a full pass over X_train
batch_size = 64   # Samples per batch (used for both the real and the generated batch)

# Seed Python's `random`, NumPy's global RNG (used for noise and batch sampling in
# the training loop) and TensorFlow (weight initialisation) so that
# Restart & Run All reproduces the same run. This must execute before the models
# are built. It does not, by itself, force deterministic GPU kernels.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

In [3]:
# Load MNIST. Only the training images are kept: the GAN is unsupervised, so the
# labels and the whole test split are discarded.
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Cast to float32 and divide by 255 so pixels lie in [0, 1] (the same range as the
# generator's sigmoid output). Then flatten every image into a 784-long row vector.
X_train = np.reshape(X_train.astype('float32') / 255.0, (len(X_train), 784))

In [4]:
# Generator: maps a latent noise vector of size `latent_dim` to a flattened
# 784-value image. The sigmoid output keeps values in [0, 1], the same range as
# the scaled X_train.
generator = tf.keras.Sequential([
    # Declare the input explicitly rather than passing `input_shape=` to the first
    # Dense layer (Keras 3 warns against the latter).
    tf.keras.Input(shape=(latent_dim,)),
    layers.Dense(128, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(784, activation='sigmoid'),
])

In [5]:
# Discriminator: maps a flattened 784-value image to a single sigmoid score,
# trained with binary cross-entropy (1 = real, 0 = generated).
discriminator = tf.keras.Sequential([
    # Declare the input explicitly rather than passing `input_shape=` to the first
    # Dense layer (Keras 3 warns against the latter).
    tf.keras.Input(shape=(784,)),
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

In [6]:
# Compile the discriminator on its own while it is still trainable. This compiled
# model is the one used by `discriminator.train_on_batch` in the training loop.
# NOTE(review): the GAN cell sets `discriminator.trainable = False` after this
# compile. The pattern relies on tf.keras 2.x fixing trainability at compile time,
# so this model keeps updating while the combined model does not. Verify under Keras 3.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

In [7]:
# Combined model: noise -> generator -> discriminator.
# The discriminator is frozen before `gan` is compiled, so training `gan` only
# updates the generator's weights.
discriminator.trainable = False

gan_input = tf.keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(inputs=gan_input, outputs=gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')

In [None]:
# Training loop. Each iteration (called an "epoch" here) trains on a single batch,
# not a full pass over X_train.

# Target labels are the same for every iteration, so build them once.
real_labels = np.ones((batch_size, 1))   # 1 = real
fake_labels = np.zeros((batch_size, 1))  # 0 = generated

for epoch in range(epochs):
    # --- Train the discriminator on one real batch and one generated batch ---
    noise = np.random.normal(0, 1, (batch_size, latent_dim)).astype('float32')
    # Call the model directly instead of `generator.predict(...)`. predict() is
    # built for large datasets, is slow for one small batch, and in recent TF
    # versions prints a progress bar on every call. That would flood the output
    # over thousands of iterations.
    generated_images = generator(noise, training=False).numpy()

    real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)

    # --- Train the generator through the frozen discriminator ---
    # The targets are "real" labels: the generator's loss falls when the
    # discriminator classifies its samples as real.
    noise = np.random.normal(0, 1, (batch_size, latent_dim)).astype('float32')
    g_loss = gan.train_on_batch(noise, real_labels)

    # Log periodically, and always on the final iteration so the end state is visible.
    if epoch % 1000 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch} | Discriminator Loss: {0.5 * (d_loss_real + d_loss_fake)} | Generator Loss: {g_loss}")