# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

# Define the generator network
def build_generator():
    """Build the MLP generator that maps a 100-dim noise vector to an image.

    Two ReLU hidden layers (256 and 512 units), each followed by batch
    normalisation, feed a 784-unit sigmoid layer. Its output is reshaped to
    (28, 28, 1), so every generated pixel lies in the open interval (0, 1).

    Returns:
        An uncompiled ``tf.keras`` ``Sequential`` model.
    """
    stack = [
        layers.Dense(256, input_dim=100, activation='relu'),
        layers.BatchNormalization(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        # 28 * 28 = 784 pixel intensities, squashed into (0, 1) by the sigmoid.
        layers.Dense(784, activation='sigmoid'),
        layers.Reshape((28, 28, 1)),
    ]
    return models.Sequential(stack)

# Define the discriminator network
def build_discriminator():
    """Build the MLP discriminator that scores a (28, 28, 1) image.

    The image is flattened and passed through ReLU layers of 512 and 256
    units. A single sigmoid unit ends the network, giving a score in (0, 1).
    The training loop labels real images 1 and generated images 0.

    Returns:
        An uncompiled ``tf.keras`` ``Sequential`` model.
    """
    return models.Sequential([
        layers.Flatten(input_shape=(28, 28, 1)),
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ])

# Create the generator and discriminator
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator on its own so the training loop can update it
# directly with train_on_batch on batches of real and generated images.
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Symbolic noise input (100-dim vectors) fed through the generator to
# produce generated images for the combined model.
z = layers.Input(shape=(100,))
img = generator(z)

# Freeze the discriminator for the combined model only. Statement order
# matters: this comes AFTER discriminator.compile() and BEFORE
# combined.compile(). That is the usual Keras GAN pattern in which the
# standalone discriminator keeps training while the combined model updates
# only the generator's weights.
# NOTE(review): this relies on tf.keras capturing `trainable` at compile time.
# Newer Keras versions may require re-compiling after toggling `trainable`.
# Confirm that the discriminator's weights still change during training on
# the installed version.
discriminator.trainable = False

# Score the generated images (only generated images here, not real ones)
# with the frozen discriminator.
validity = discriminator(img)

# The combined GAN model (generator stacked with discriminator): noise in,
# discriminator score out. The training loop fits it to all-ones labels, which
# pushes the generator toward images the discriminator scores as real.
combined = models.Model(z, validity)
combined.compile(optimizer='adam', loss='binary_crossentropy')

# Load the MNIST training images; the labels and the test split are unused.
# Preprocessing:
#   * cast to float32 first (Keras' default float dtype). Dividing the raw
#     uint8 array by 255.0 would otherwise allocate a float64 copy twice as
#     large, only for Keras to down-cast it batch by batch anyway;
#   * scale pixels to [0, 1] to match the generator's sigmoid output range;
#   * add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1), the shape
#     both networks use.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# Training parameters. Each "epoch" below is a single mini-batch step (one
# discriminator update on a real and a fake batch, then one generator update),
# not a full pass over x_train.
batch_size = 64
epochs = 10000
sample_interval = 1000

# Training the GAN
for epoch in range(epochs):
    # --- Discriminator step: real images labelled 1, generated images 0. ---
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]
    # verbose=0: predict() otherwise prints a progress bar on every call,
    # i.e. once per training iteration.
    fake_imgs = generator.predict(np.random.rand(batch_size, 100), verbose=0)
    d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))
    # Element-wise mean of [loss, accuracy] over the real and fake batches.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # --- Generator step: fit the combined model so the frozen discriminator
    # scores generated images as real (label 1). The batch is named `noise`
    # to avoid rebinding `z`, which holds the symbolic Input that `combined`
    # was built from. ---
    noise = np.random.rand(batch_size, 100)
    g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}")

    # Show a 4x4 grid of generated samples at specified intervals.
    if epoch % sample_interval == 0:
        # The generator ends in a sigmoid, so samples are already in (0, 1).
        # No `0.5 * x + 0.5` rescale is needed; that maps tanh outputs from
        # [-1, 1]. vmin/vmax are pinned so gray levels show true intensity
        # instead of being auto-stretched per image.
        samples = generator.predict(np.random.rand(16, 100), verbose=0)
        fig, axs = plt.subplots(4, 4)
        for ax, sample in zip(axs.flat, samples):
            ax.imshow(sample[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
            ax.axis('off')
        plt.show()
        # Release the figure so repeated sampling doesn't accumulate open figures.
        plt.close(fig)
