In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt


In [2]:
# Fetch MNIST and keep only the training images (labels and the test split are discarded)
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Convert to float32, then shift and scale the 0..255 pixel range in place to [-1, 1]
train_images = train_images.astype(np.float32)
train_images -= 127.5
train_images /= 127.5

# Append a trailing channel axis so each grayscale image has an explicit single channel
train_images = train_images[..., np.newaxis]


In [3]:
def build_generator():
    """Assemble the generator: a 100-dimensional noise vector in, a tanh-activated image out.

    The noise is projected by a dense layer, folded into a 7x7x256 feature map,
    and passed through three transposed convolutions (strides 1, 2, 2) that end
    in a single output channel.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    latent_shape = (100,)

    net = tf.keras.Sequential()

    # Project the noise vector and fold it into a 7x7 feature map with 256 channels
    net.add(Dense(7 * 7 * 256, input_shape=latent_shape))
    net.add(Reshape((7, 7, 256)))
    net.add(BatchNormalization())

    # Two ReLU transposed-conv stages, each followed by batch normalization
    for filters, stride in ((128, 1), (64, 2)):
        net.add(Conv2DTranspose(filters, kernel_size=5, strides=stride, padding='same', activation='relu'))
        net.add(BatchNormalization())

    # Final transposed conv collapses to one channel; tanh keeps outputs in [-1, 1]
    net.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding='same', activation='tanh'))

    return net

def build_discriminator():
    """Assemble the discriminator: a 28x28x1 image in, a single sigmoid score out.

    Two stride-2 convolutions with LeakyReLU activations feed a flattened
    dense layer with one sigmoid unit.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    # First downsampling stage; also declares the expected input shape
    stack = [
        Conv2D(64, kernel_size=5, strides=2, padding='same', input_shape=(28, 28, 1)),
        LeakyReLU(alpha=0.2),
    ]

    # Second downsampling stage
    stack += [
        Conv2D(128, kernel_size=5, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
    ]

    # Classification head: flatten the feature map and emit one sigmoid output
    stack += [
        Flatten(),
        Dense(1, activation='sigmoid'),
    ]

    return tf.keras.Sequential(stack)

# Build one generator and one discriminator for use by the cells below
generator, discriminator = build_generator(), build_discriminator()


In [4]:
# Compile the standalone discriminator with binary cross-entropy.
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras releases.
discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')



In [5]:
def build_gan(generator, discriminator):
    """Stack the generator and discriminator into one model for generator updates.

    Args:
        generator: model mapping a (100,) noise vector to an image.
        discriminator: model scoring an image. Its `trainable` attribute is
            set to False as a side effect of this call.

    Returns:
        A compiled `Model` mapping noise to the discriminator's output,
        using binary cross-entropy and Adam(learning_rate=0.0002, beta_1=0.5).
    """
    # Freeze the discriminator inside the stacked model so that training the
    # stacked model only updates the generator.
    # NOTE(review): this relies on the discriminator having been compiled on its
    # own beforehand so that its direct train_on_batch calls still update
    # weights; confirm this holds for the Keras version in use.
    discriminator.trainable = False

    gan_input = Input(shape=(100,))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)

    gan = Model(gan_input, gan_output)
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras rejects
    gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

    return gan

# Wire the two networks into the stacked model that is used for generator updates
gan = build_gan(generator=generator, discriminator=discriminator)




In [6]:
def train_gan(generator, discriminator, gan, images, epochs=100, batch_size=128, sample_interval=10):
    """Alternate discriminator and generator updates for a fixed number of steps.

    Each "epoch" here is a single mini-batch update, not a full pass over
    `images`.

    Args:
        generator: model exposing `predict(noise, verbose=0)`; it receives
            noise of shape (batch_size, 100).
        discriminator: model exposing `train_on_batch(x, y)`.
        gan: stacked model exposing `train_on_batch(noise, labels)`; used to
            update the generator.
        images: array of real samples, indexed along axis 0.
        epochs: number of training iterations to run.
        batch_size: number of real samples drawn per iteration (sampled with
            replacement) and number of fake samples generated.
        sample_interval: losses are printed and `save_generated_images` is
            called whenever `epoch % sample_interval == 0`.

    Returns:
        None.
    """
    # Targets: 1 for real images, 0 for generated ones
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Select a random batch of real images
        idx = np.random.randint(0, images.shape[0], batch_size)
        real_images = images[idx]

        # Generate random noise as input to the generator
        noise = np.random.normal(0, 1, (batch_size, 100))

        # Generate fake images; verbose=0 stops predict() from printing a
        # progress bar on every training step
        generated_images = generator.predict(noise, verbose=0)

        # Train the discriminator on the real and fake batches separately,
        # then report the mean of the two losses
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator (via the GAN model): the fakes are labelled as
        # real so the loss rewards fooling the discriminator
        g_loss = gan.train_on_batch(noise, real_labels)

        # Print training progress and save a sample grid
        if epoch % sample_interval == 0:
            print(f"Epoch {epoch}: D loss = {d_loss}, G loss = {g_loss}")
            save_generated_images(epoch, generator)

def save_generated_images(epoch, generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    """Generate sample images and save them as one grid PNG in the working directory.

    Args:
        epoch: value embedded in the output filename.
        generator: model exposing `predict(noise, verbose=0)`; it receives
            noise of shape (examples, 100) and its output is indexed as
            [i, :, :, 0].
        examples: number of images to generate and plot.
        dim: (rows, cols) of the subplot grid; it must have room for
            `examples` panels.
        figsize: figure size passed to matplotlib.

    Returns:
        None. Writes `gan_generated_image_epoch_{epoch}.png`.
    """
    noise = np.random.normal(0, 1, (examples, 100))
    # verbose=0 stops predict() from printing a progress bar on each call
    generated_images = generator.predict(noise, verbose=0)
    # Map [-1, 1] to [0, 1] for display (assumes a tanh-range generator output)
    generated_images = 0.5 * generated_images + 0.5

    fig = plt.figure(figsize=figsize)
    try:
        for i in range(examples):
            ax = fig.add_subplot(dim[0], dim[1], i + 1)
            ax.imshow(generated_images[i, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(f"gan_generated_image_epoch_{epoch}.png")
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls during training do not accumulate open figures
        plt.close(fig)

# Run training: 200 iterations with batches of 128, reporting and sampling every 20
train_gan(
    generator,
    discriminator,
    gan,
    train_images,
    epochs=200,
    batch_size=128,
    sample_interval=20,
)


Epoch 0: D loss = 0.6718755662441254, G loss = 0.5675057172775269
Epoch 20: D loss = 0.00114081273204647, G loss = 0.040155261754989624
Epoch 40: D loss = 0.03497873508331395, G loss = 0.30469995737075806
Epoch 60: D loss = 0.012010554084554315, G loss = 0.21024367213249207
Epoch 80: D loss = 1.2113357660951891, G loss = 0.6726780533790588
Epoch 100: D loss = 0.03863152489066124, G loss = 0.050005000084638596
Epoch 120: D loss = 0.03778000921010971, G loss = 0.030853452160954475
Epoch 140: D loss = 0.02581699239090085, G loss = 0.01418351661413908


Epoch 160: D loss = 0.021660474129021168, G loss = 0.009275875985622406
Epoch 180: D loss = 0.022297436371445656, G loss = 0.015936240553855896
