In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Load and preprocess MNIST dataset
def load_data():
    """Return the MNIST training images as a shuffled, batched tf.data pipeline.

    Labels and the test split are discarded. Pixel values are rescaled to
    [-1, 1] (the same range as the generator's tanh output) and a trailing
    channel axis is appended before the array is wrapped in a dataset.

    Returns:
        A ``tf.data.Dataset`` yielding float32 image batches of up to 128
        samples, shuffled with a 60000-element buffer and prefetched.
    """
    (train_images, _), _ = tf.keras.datasets.mnist.load_data()

    # Map raw pixel values onto [-1, 1].
    pixels = train_images.astype(np.float32)
    pixels = (pixels - 127.5) / 127.5

    # Append the single grayscale channel axis.
    pixels = pixels[..., np.newaxis]

    pipeline = tf.data.Dataset.from_tensor_slices(pixels)
    pipeline = pipeline.shuffle(60000)
    pipeline = pipeline.batch(128)
    return pipeline.prefetch(tf.data.AUTOTUNE)

# Build generator
def build_generator():
    """Build the generator: 100-dim noise vector -> single-channel image in [-1, 1].

    A dense projection is reshaped to a 7x7x256 feature map, then passed
    through three 5x5 transposed convolutions (strides 1, 2, 2). Every stage
    except the last is followed by batch normalization and ReLU; the last
    uses tanh so outputs lie in [-1, 1].

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    # Project the noise vector and reshape it into a small feature map.
    stack = [
        layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Reshape((7, 7, 256)),
    ]

    # Hidden transposed-conv stages: (filters, stride).
    for filters, stride in ((128, 1), (64, 2)):
        stack.append(
            layers.Conv2DTranspose(
                filters,
                (5, 5),
                strides=(stride, stride),
                padding='same',
                use_bias=False,
            )
        )
        stack.append(layers.BatchNormalization())
        stack.append(layers.ReLU())

    # Output stage: one channel, tanh keeps pixel values in [-1, 1].
    stack.append(
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh')
    )
    return tf.keras.Sequential(stack)
# Batch normalization normalizes the previous layer's activations over each batch, which helps stabilize and speed up training.

# Build discriminator
def build_discriminator():
    """Build the convolutional discriminator for 28x28x1 images.

    Two strided 5x5 convolutions (64, then 128 filters), each followed by
    LeakyReLU and dropout, feed a single-unit dense head.

    Returns:
        A ``tf.keras.Sequential`` model that outputs one raw logit per image;
        a higher score means the image is judged more likely to be real.
    """
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=2, padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=2, padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Flatten(),
        # No sigmoid here: the loss used in this file is
        # BinaryCrossentropy(from_logits=True), which applies the sigmoid
        # itself. A sigmoid in this layer would squash the scores twice and
        # distort the loss and its gradients.
        layers.Dense(1)
    ])
    return model

# Optimizers
def get_optimizers():
    """Create two independent Adam optimizers, each with learning rate 1e-4.

    Returns:
        A 2-tuple of optimizers; the caller in this file unpacks it as
        (generator optimizer, discriminator optimizer).
    """
    first_adam = tf.keras.optimizers.Adam(1e-4)
    second_adam = tf.keras.optimizers.Adam(1e-4)
    return first_adam, second_adam

# Instantiate both networks and one optimizer for each.
generator = build_generator()
discriminator = build_discriminator()
g_opt, d_opt = get_optimizers()  # g_opt updates the generator, d_opt the discriminator (see train_step)
# from_logits=True makes the loss apply the sigmoid itself, so it expects raw
# (pre-sigmoid) scores from the discriminator rather than probabilities.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# The discriminator wants to label real images as 1 and fake images as 0.
# The generator wants the discriminator to label its fake images as real (1).

# Training step
@tf.function  # Compiles to a TensorFlow graph for speed
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        images: Batch of real images to show the discriminator.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple with the generator and discriminator
        losses computed on this batch.
    """
    # Draw one 100-dimensional noise vector per real image so the fake batch
    # always matches the real batch in size (e.g. a smaller final batch),
    # instead of assuming a fixed batch of 128.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, 100])

    # Two tapes: each network's loss is differentiated with respect to its own
    # variables only.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)  # fake images produced from noise
        real_output = discriminator(images, training=True)  # discriminator scores for real images
        fake_output = discriminator(generated_images, training=True)  # discriminator scores for fake images

        # Generator loss: binary cross-entropy of the fake scores against
        # "real" labels (ones). It is low when the discriminator is fooled and
        # high when the fakes are detected.
        gen_loss = loss_fn(tf.ones_like(fake_output), fake_output)

        # Discriminator loss: real images should be labelled 1 and fake images
        # labelled 0. The second term uses the same fake scores as gen_loss
        # but the opposite target, which is what makes the two networks
        # adversaries.
        real_loss = loss_fn(tf.ones_like(real_output), real_output)
        fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
        disc_loss = real_loss + fake_loss

    # Compute the gradient of each loss with respect to its own network's parameters.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # zip pairs each gradient with the variable it updates.
    g_opt.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    d_opt.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Training loop
def train(dataset, epochs=20):
    """Train the GAN for a number of passes over ``dataset``.

    Calls ``train_step`` on every batch and, after each epoch, prints the
    generator and discriminator losses of that epoch's final batch.

    Args:
        dataset: Iterable of image batches; it is iterated once per epoch.
        epochs: Number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        # Reset each epoch so an empty pass is detected rather than crashing
        # on an unbound name or silently re-printing stale losses.
        gen_loss = disc_loss = None
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch)

        if gen_loss is None:
            print(f"Epoch {epoch+1}, no batches in dataset; nothing to report")
            continue
        print(f"Epoch {epoch+1}, Gen Loss: {gen_loss.numpy():.4f}, Disc Loss: {disc_loss.numpy():.4f}")

# Function to generate and display sample images from generator
def generate_and_show():
    """Sample 16 images from the generator and display them in a 4x4 grid.

    The generator runs in inference mode on 16 random 100-dimensional noise
    vectors. Its outputs are rescaled from [-1, 1] to [0, 1] and drawn as
    grayscale images with the axes hidden.
    """
    latent = tf.random.normal([16, 100])
    samples = generator(latent, training=False)
    samples = (samples + 1) / 2  # [-1, 1] -> [0, 1] for display

    _, grid = plt.subplots(4, 4, figsize=(4, 4))
    for idx, panel in enumerate(grid.flat):
        # Each sample has one channel; take it as a 2-D grayscale image.
        panel.imshow(samples[idx, ..., 0], cmap='gray')
        panel.axis('off')
    plt.show()

# Load the dataset and begin training
mnist_data = load_data()  # shuffled, batched tf.data pipeline of MNIST training images
train(mnist_data, epochs=10)  # overrides train()'s default of 20 epochs
generate_and_show()  # display a 4x4 grid of samples drawn from the generator after training

    
