In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
# Load the MNIST training images only: the GAN is unsupervised, so the labels
# and the test split are unpacked into throwaway names.
(X_train, _), (_, _) = keras.datasets.mnist.load_data()
# uint8 [0, 255] -> float32 [0, 1] (the range of the generator's sigmoid output),
# plus a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
X_train = X_train.astype(np.float32)[..., np.newaxis] / 255.0


In [None]:
# Parameters
latent_dim = 100        # Dimension of the random noise vector fed to the generator
batch_size = 128        # Real (and fake) images per discriminator update
epochs = 10000          # NOTE: counts single-batch training *iterations*, not passes over X_train
sample_interval = 1000  # Iterations between progress prints / saved sample grids
SEED = 42               # Fixed seed so Restart & Run All reproduces the run

# Seed Python's `random`, NumPy and TensorFlow in one call. This runs before the
# models are built so weight initialisation is seeded as well.
# (GPU kernels may still introduce some nondeterminism.)
keras.utils.set_random_seed(SEED)


In [None]:

# Generator model
def build_generator(noise_dim=None, img_shape=None):
    """Build the fully-connected generator: noise vector -> image.

    Args:
        noise_dim: Length of the input noise vector. Defaults to the
            notebook-level ``latent_dim`` from the parameters cell.
        img_shape: Shape of one generated image, e.g. ``(28, 28, 1)``.
            Defaults to the per-sample shape of ``X_train``.

    Returns:
        An uncompiled ``keras.Sequential`` producing images of shape
        ``(batch, *img_shape)`` with values in [0, 1] (sigmoid output), the
        same range ``X_train`` was scaled to.
    """
    noise_dim = latent_dim if noise_dim is None else noise_dim
    img_shape = tuple(X_train.shape[1:]) if img_shape is None else tuple(img_shape)

    model = keras.Sequential([
        # Explicit Input layer instead of the deprecated `input_dim=` argument.
        layers.Input(shape=(noise_dim,)),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(1024, activation='relu'),
        layers.BatchNormalization(momentum=0.8),
        # np.prod returns a NumPy integer; cast so `units` is a plain Python int.
        # Sigmoid keeps pixels in [0, 1] to match the data scaling.
        layers.Dense(int(np.prod(img_shape)), activation='sigmoid'),
        layers.Reshape(img_shape),
    ])
    return model

In [None]:
# Discriminator model
def build_discriminator(img_shape=None):
    """Build the fully-connected discriminator: image -> P(image is real).

    Args:
        img_shape: Shape of one input image, e.g. ``(28, 28, 1)``.
            Defaults to the per-sample shape of ``X_train``.

    Returns:
        An uncompiled ``keras.Sequential`` with a single sigmoid output unit,
        suitable for ``binary_crossentropy``.
    """
    img_shape = tuple(X_train.shape[1:]) if img_shape is None else tuple(img_shape)

    model = keras.Sequential([
        # Explicit Input layer instead of the deprecated `input_shape=` argument.
        layers.Input(shape=img_shape),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ])
    return model


In [None]:
# Instantiate both networks (generator first, then discriminator). Only the
# discriminator is compiled on its own: it is fit directly on real/fake batches,
# while the generator is trained through the combined `gan` model below.
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# GAN model
# Combined model used only to train the generator:
#   noise z -> generator -> image -> discriminator -> P(real).
# train_gan fits it with all-ones targets, pushing the generator toward images
# the discriminator scores as real.
z = layers.Input(shape=(latent_dim,))
img = generator(z)
# Freeze the discriminator when training GAN, so gan.train_on_batch updates only
# the generator's weights. This is deliberately set AFTER discriminator.compile()
# in the previous cell, which compiled the standalone discriminator while it was
# still trainable.
# NOTE(review): tf.keras 2.x captures `trainable` at compile time, so the
# standalone discriminator keeps learning. Keras 3 may resolve `trainable`
# differently; confirm discriminator weights still change during
# discriminator.train_on_batch.
discriminator.trainable = False
validity = discriminator(img)
gan = keras.Model(z, validity)
gan.compile(optimizer='adam', loss='binary_crossentropy')


In [None]:
# Function to save generated images
def sample_images(epoch, rows=5, cols=5):
    """Save a grid of generator samples to ``gan_generated_epoch_{epoch}.png``.

    Args:
        epoch: Training iteration number; used only in the output filename.
        rows: Number of grid rows (default 5).
        cols: Number of grid columns (default 5). ``rows * cols`` noise
            vectors are drawn from N(0, 1).
    """
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    # verbose=0: predict() otherwise prints a progress bar on every call.
    generated_imgs = generator.predict(noise, verbose=0)

    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for ax, image in zip(axes.flat, generated_imgs):
        # The generator's sigmoid output is already in [0, 1]. Pinning vmin/vmax
        # stops matplotlib from re-stretching each image's contrast independently.
        ax.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(f"gan_generated_epoch_{epoch}.png")
    # Close this specific figure; the function is called repeatedly during training.
    plt.close(fig)

In [None]:
# Training the GAN
def train_gan(epochs, batch_size, sample_interval):
    """Alternately train the discriminator and generator on random MNIST batches.

    Each of the ``epochs`` iterations performs one discriminator update on a
    real batch, one on a generated batch, and one generator update through
    ``gan``. It is a single-batch step, not a full pass over ``X_train``.

    Args:
        epochs: Number of training iterations.
        batch_size: Number of real (and of generated) images per update.
        sample_interval: Every ``sample_interval`` iterations, print the losses
            and save a sample grid via ``sample_images``.

    Returns:
        ``(d_losses, g_losses)``: discriminator and generator losses as lists
        of floats, one entry per iteration.
    """
    # Targets never change, so build them once rather than every iteration.
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    d_losses = []
    g_losses = []

    for epoch in range(epochs):
        # Discriminator step: real images labelled 1, generated images labelled 0.
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0 suppresses the progress bar predict() would otherwise print
        # on every single iteration.
        fake_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
        # The discriminator was compiled with metrics=['accuracy'], so each result
        # is [loss, accuracy]; average the real and fake halves element-wise.
        d_loss = np.ravel(0.5 * np.add(d_loss_real, d_loss_fake))

        # Generator step: train through `gan` (discriminator frozen there) so
        # that generated images get classified as real.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_labels)
        # `gan` has no metrics: depending on the Keras version, train_on_batch
        # returns a bare scalar or a one-element list. np.ravel handles both.
        g_loss_value = float(np.ravel(g_loss)[0])

        # NOTE(review): some Keras versions accumulate metric state across
        # train_on_batch calls, so these may be running means; verify.
        d_losses.append(float(d_loss[0]))
        g_losses.append(g_loss_value)

        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, D acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss_value:.4f}]")
            sample_images(epoch)

    # The loss lists hold one entry per iteration, so iteration e lives at index e.
    print("\nFinal Loss Values:")
    for e in range(0, epochs, sample_interval):
        print(f"Epoch {e}: D loss: {d_losses[e]:.4f}, G loss: {g_losses[e]:.4f}")

    return d_losses, g_losses


# Start training; keep the per-iteration loss histories for later inspection/plotting.
d_losses, g_losses = train_gan(epochs, batch_size, sample_interval)


[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
0 [D loss: 2.1368, D acc.: 50.36%] [G loss: 2.1446]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 107ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m4/4[0m