In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
import numpy as np


In [2]:
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # Normalize the images to [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)  # Reshape for the network


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
generator = Sequential([
    Dense(256, activation='relu', input_shape=(100,)),
    Dense(512, activation='relu'),
    Dense(1024, activation='relu'),
    Dense(28 * 28 * 1, activation='tanh'),
    Reshape((28, 28, 1))
])


In [4]:
discriminator = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(512, activation='relu'),
    Dense(256, activation='relu'),
    Dense(1, activation='sigmoid')
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')


In [5]:
discriminator.trainable = False
gan_input = tf.keras.Input(shape=(100,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)

gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')


In [None]:
epochs = 10000
batch_size = 32
for epoch in range(epochs):
    # Train Discriminator
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]
    fake_imgs = generator.predict(np.random.normal(0, 1, (batch_size, 100)))

    d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train Generator
    noise = np.random.normal(0, 1, (batch_size, 100))
    g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    if epoch % 1000 == 0:
        print(f"Epoch: {epoch} \t Discriminator Loss: {d_loss} \t Generator Loss: {g_loss}")


Epoch: 0 	 Discriminator Loss: 0.9586802124977112 	 Generator Loss: 0.6321882009506226












Epoch: 1000 	 Discriminator Loss: 3.4363901590950263e-06 	 Generator Loss: 14.12440013885498












Epoch: 2000 	 Discriminator Loss: 1.7568931234279717e-06 	 Generator Loss: 16.10293960571289














Epoch: 3000 	 Discriminator Loss: 2.5780396128993743e-07 	 Generator Loss: 17.60544204711914












Epoch: 4000 	 Discriminator Loss: 2.884664525304448e-07 	 Generator Loss: 14.346651077270508












Epoch: 5000 	 Discriminator Loss: 2.5390836100314118e-08 	 Generator Loss: 17.03658676147461














Epoch: 6000 	 Discriminator Loss: 2.423335672574467e-06 	 Generator Loss: 13.019712448120117












Epoch: 7000 	 Discriminator Loss: 1.279388173180962e-08 	 Generator Loss: 17.66193389892578




In [None]:
def show_generated_images(generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    noise = np.random.normal(0, 1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

show_generated_images(generator)


In [None]:
import os

def save_generated_images(generator, examples=10, prefix="image"):
    noise = np.random.normal(0, 1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5  # Rescale images from [-1, 1] to [0, 1]

    for i, img in enumerate(generated_images):
        plt.figure(figsize=(2.8, 2.8))
        plt.imshow(img.reshape(28, 28), cmap='gray_r')
        plt.axis('off')

        # Save the figure
        filename = f"{prefix}_{i+1}.png"
        plt.savefig(filename, bbox_inches='tight', pad_inches=0)
        plt.close()

# Example usage
save_generated_images(generator, examples=10, prefix="mnist_generated")
