In [None]:

# Install required libraries
!pip install tensorflow numpy matplotlib

# Background: how a GAN works
# A GAN consists of two networks: the Generator and the Discriminator.
# The Generator maps random noise to fake images, and the Discriminator tries to distinguish real images from fake ones.
# They are trained together in a zero-sum game: the Generator tries to improve its fake images to fool the Discriminator,
# while the Discriminator tries to get better at distinguishing real from fake.


In [10]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import mnist


In [11]:
def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the MLP generator: noise vector -> image with pixels in [-1, 1].

    Args:
        latent_dim: length of the input noise vector (defaults to 100).
        img_shape: output image shape (H, W, C); defaults to MNIST's 28x28x1.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(128, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # tanh bounds outputs to [-1, 1], the same range the MNIST data is scaled to.
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))
    return model

# Define the Discriminator model
def build_discriminator(img_shape=(28, 28, 1)):
    """Build the MLP discriminator: image -> probability that it is real.

    Args:
        img_shape: input image shape (H, W, C); defaults to MNIST's 28x28x1.

    Returns:
        An uncompiled ``Sequential`` model whose single sigmoid output lies in (0, 1).
    """
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one end-to-end model.

    The result maps noise straight to the discriminator's real/fake score.
    It is returned uncompiled.
    """
    return Sequential([generator, discriminator])



In [12]:
# Seed Python, NumPy and TensorFlow RNGs so weight init, batch sampling and
# noise are repeatable across Restart & Run All (GPU kernels may still add
# some nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

generator = build_generator()
discriminator = build_discriminator()
# NOTE: training happens in the custom `train_step` loop with its own Adam
# optimizers, so these compile settings are not used by it. Do NOT set
# `discriminator.trainable = False` here: that would empty
# `discriminator.trainable_variables`, which `train_step` updates.
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# `gan` (generator -> discriminator) is built for inspection via summary().
gan = build_gan(generator, discriminator)
gan.compile(optimizer='adam', loss='binary_crossentropy')

In [None]:
# One Adam optimizer per network for the custom loop in `train_step`.
# learning_rate=2e-4 with beta_1=0.5 are the usual DCGAN settings.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Print each model's architecture.
generator.summary()
discriminator.summary()
gan.summary()

In [22]:
@tf.function
def train_step(real_images, batch_size):
    """Run one alternating GAN update: discriminator first, then generator.

    Args:
        real_images: a batch of real images. ``train_gan`` passes arrays of
            shape (batch_size, 28, 28, 1) scaled to [-1, 1].
        batch_size: number of fake images to generate. It is a Python int, so
            each distinct value triggers a retrace of this ``tf.function``.

    Returns:
        ``(disc_loss, gen_loss)`` as scalar tensors holding the batch-mean
        binary cross-entropy of each network.
    """
    real_images = tf.cast(real_images, tf.float32)

    # Generate fake images outside any tape. The discriminator update below
    # must not propagate gradients into the generator.
    noise = tf.random.normal([batch_size, 100])
    fake_images = generator(noise, training=True)

    # Train Discriminator: push real -> 1 and fake -> 0.
    with tf.GradientTape() as disc_tape:
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)
        # binary_crossentropy returns one value per sample. Average over the
        # batch so the loss (and gradient scale) does not grow with batch_size.
        disc_loss_real = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output))
        disc_loss_fake = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_output), fake_output))
        disc_loss = disc_loss_real + disc_loss_fake

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Train Generator: regenerate under the tape (same noise) and reward fakes
    # that the just-updated discriminator scores as real.
    with tf.GradientTape() as gen_tape:
        fake_images = generator(noise, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(tf.ones_like(fake_output), fake_output))

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    return disc_loss, gen_loss

In [23]:
def train_gan(epochs, batch_size=64, log_interval=1000):
    """Train the global ``generator``/``discriminator`` pair on MNIST.

    Despite its name, ``epochs`` counts *training steps*. Each step draws one
    random batch of ``batch_size`` images (with replacement); it is not a
    full pass over the dataset.

    Args:
        epochs: number of training steps to run.
        batch_size: real images per step.
        log_interval: print losses and show a sample grid every this many
            steps. The final step is always reported as well, so runs shorter
            than ``log_interval`` still show trained samples.

    Raises:
        ValueError: if ``log_interval`` < 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    # Load MNIST and scale uint8 [0, 255] -> float32 [-1, 1] to match the
    # generator's tanh range. float32 also halves memory versus float64.
    (X_train, _), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) / 127.5) - 1.0
    X_train = np.expand_dims(X_train, axis=-1)  # add a channel axis

    for step in range(epochs):
        # Select a random batch of real images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]

        disc_loss, gen_loss = train_step(real_images, batch_size)

        if step % log_interval == 0 or step == epochs - 1:
            print(f"{step} [D loss: {np.mean(disc_loss):.4f}] [G loss: {np.mean(gen_loss):.4f}]")
            generate_and_display_images(generator)

def generate_and_display_images(generator, num_images=16, latent_dim=100):
    """Sample ``num_images`` images from ``generator`` and show them in a grid.

    Args:
        generator: model mapping (num_images, latent_dim) noise to a 4-D batch
            indexed as [image, row, col, channel]. Outputs are presumably in
            [-1, 1] (tanh) and are rescaled to [0, 1] for display.
        num_images: number of samples to draw; must be >= 1. They are laid out
            in a near-square grid (16 -> 4x4).
        latent_dim: noise vector length; must match the generator's input.

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    noise = np.random.randn(num_images, latent_dim)
    # verbose=0 keeps predict()'s progress bar out of the training log.
    generated_images = generator.predict(noise, verbose=0)
    generated_images = 0.5 * generated_images + 0.5  # Rescale images to [0, 1]

    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = int(np.ceil(num_images / n_cols))
    # squeeze=False keeps `axs` 2-D even for a single row/column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for cnt, ax in enumerate(axs.flat):
        if cnt < num_images:
            # Fixed vmin/vmax: without them imshow stretches each image's own
            # min/max, hiding true intensities.
            ax.imshow(generated_images[cnt, :, :, 0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')  # also blanks unused grid cells
    plt.show()

# Training is launched in the next cell.


In [None]:
# `epochs` counts single-batch training steps, not full passes over MNIST.
# Raise it for better samples; GAN losses need not decrease steadily.
train_gan(epochs=1000, batch_size=64)