<a href="https://colab.research.google.com/github/AhmedBaari/Deep-Learning-Essentials/blob/main/11%20-%20GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ===== PART 1: DOWNLOAD & SAVE MNIST TO FOLDER =====
import os
from PIL import Image
import tensorflow as tf

def save_mnist_to_folder(folder='mnist_data'):
    """Download MNIST and write the training split to disk as PNG files.

    Layout produced: ``<folder>/<digit>/train_<index>.png`` with one
    sub-directory per digit 0-9 (the per-class layout read back by
    ``image_dataset_from_directory`` in the next cell). Only the training
    split is written; the test split returned by ``load_data`` is discarded.

    Args:
        folder: Destination directory; created (with sub-folders) if missing.

    Returns:
        The ``folder`` argument, unchanged.
    """
    (train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()

    # One sub-directory per class label: <folder>/0/, <folder>/1/, ...
    for digit_class in range(10):
        os.makedirs(f'{folder}/{digit_class}', exist_ok=True)

    print("Saving images to folder...")
    for index in range(len(train_images)):
        # The label picks the class folder; the running index keeps names unique.
        destination = f'{folder}/{train_labels[index]}/train_{index}.png'
        Image.fromarray(train_images[index]).save(destination)

    print(f"Saved {len(train_images)} images to '{folder}' folder!")
    return folder

# Materialize MNIST on disk only when the target folder is absent.
# NOTE: this only checks that the directory exists, so a partially written
# folder from an interrupted earlier run is NOT re-populated.
folder_name = 'mnist_data'
needs_download = not os.path.exists(folder_name)
if needs_download:
    save_mnist_to_folder(folder_name)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Saving images to folder...
Saved 60000 images to 'mnist_data' folder!


In [None]:
# ULTRA-MINIMAL GAN - EASIEST TO MEMORIZE
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, LeakyReLU, Input, Reshape
from tensorflow.keras.models import Sequential, Model
import os
from PIL import Image

# 1. LOAD DATA using image_dataset_from_directory
def preprocess_image(image):
    """Rescale raw pixel intensities from [0, 255] to [-1, 1].

    The target range matches the generator's ``tanh`` output, so real and
    generated images live on the same scale for the discriminator.

    ``image_dataset_from_directory`` yields float32 tensors whose values are
    still in [0, 255] (it applies no rescaling), and
    ``tf.image.convert_image_dtype`` only rescales *integer* inputs; on a
    float input it is a plain cast. Dividing by 255 explicitly gives the
    correct range for both uint8 and float inputs.

    Args:
        image: Tensor of pixel intensities in [0, 255], integer or float dtype.

    Returns:
        float32 tensor of the same shape with values in [-1, 1].
    """
    image = tf.cast(image, tf.float32) / 255.0  # [0, 255] -> [0, 1]
    image = (image - 0.5) / 0.5                 # [0, 1]   -> [-1, 1]
    return image

batch_size = 128
image_size = (28, 28)
data_dir = 'mnist_data'
shuffle_buffer = 10_000  # images held in the shuffle buffer: larger = better mixing, more RAM

raw_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode=None,          # GAN training needs only the images, not the digit labels
    image_size=image_size,
    color_mode='grayscale',   # single-channel images
    interpolation='nearest',  # files are already 28x28, so no real resampling happens
    batch_size=batch_size,
    shuffle=True,
)

# Preprocess, cache the *unbatched* images once, then reshuffle and re-batch
# every epoch. cache() replays exactly the elements it recorded on the first
# pass, so caching the already-batched stream would give every later epoch the
# same batches in the same order. Shuffling after cache() restores per-epoch
# randomness.
train_ds = (
    raw_ds
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .unbatch()
    .cache()
    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)


# 2. BUILD GENERATOR
def make_generator(latent_dim=100):
    """Build the generator: a one-hidden-layer MLP from noise to a 28x28x1 image.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the
            noise size sampled elsewhere in this notebook.

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, latent_dim)``
        noise to ``(batch, 28, 28, 1)`` images in [-1, 1] (tanh output).
    """
    model = Sequential([
        # Explicit Input layer instead of the deprecated `input_dim=` kwarg,
        # which triggers a UserWarning on Keras 3.
        Input(shape=(latent_dim,)),
        Dense(128),
        LeakyReLU(0.2),
        Dense(784, activation='tanh'),  # 784 = 28 * 28 pixels; tanh -> [-1, 1]
        Reshape((28, 28, 1)),           # flat vector -> image tensor
    ])
    return model

# 3. BUILD DISCRIMINATOR
def make_discriminator():
    """Build the discriminator: a one-hidden-layer MLP scoring 28x28x1 images.

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, 28, 28, 1)``
        images to a ``(batch, 1)`` sigmoid score (1 = "real", 0 = "fake").
    """
    model = Sequential([
        # Explicit Input layer replaces the deprecated `input_shape=` kwarg.
        Input(shape=(28, 28, 1)),
        Reshape((784,)),                 # flatten 28x28x1 -> 784 for the Dense layers
        Dense(128),                      # `input_dim` removed: only meaningful on a first layer
        LeakyReLU(0.2),
        Dense(1, activation='sigmoid'),  # probability that the input image is real
    ])
    return model

# 4. CREATE MODELS
# Discriminator is compiled on its own; D.train_on_batch in the training loop
# uses this compiled model to update D's weights.
D = make_discriminator()
D.compile(optimizer='adam', loss='binary_crossentropy')

G = make_generator()
# Freeze D *after* compiling it and *before* building/compiling the stacked
# model, so that GAN.train_on_batch updates only the generator's weights.
# NOTE(review): this classic pattern relies on compile() snapshotting the
# `trainable` flag (tf.keras 2.x behavior). On Keras 3, confirm that D's
# weights still change during D.train_on_batch.
D.trainable = False
# Combined model: noise -> G -> D -> score. It is trained with target 1 so the
# generator learns to make D label its samples as real.
GAN = Sequential([G, D])
GAN.compile(optimizer='adam', loss='binary_crossentropy')

# 5. TRAIN
# One epoch = one full pass over the 60k cached images (469 batches of 128).
# 5000 epochs would be ~2.3M steps and never finish in a notebook session;
# 50 epochs (~23k steps) is a practical budget. Adjust as time allows.
epochs = 50
log_every = 5
latent_dim = 100  # must match the generator's input size

# Initialized so logging cannot fail with NameError if train_ds yields nothing.
d_loss_real = d_loss_fake = g_loss = float('nan')

for epoch in range(epochs):
    for image_batch in train_ds:
        current_batch_size = tf.shape(image_batch)[0]  # last batch may be < batch_size

        # Train Discriminator: real images labelled 1, generated images labelled 0.
        # Calling G directly avoids G.predict(), which builds a fresh inference
        # loop on every call and dominates runtime inside a per-batch loop.
        fake = G(tf.random.normal([current_batch_size, latent_dim]), training=False)
        d_loss_real = D.train_on_batch(image_batch, tf.ones((current_batch_size, 1)))
        d_loss_fake = D.train_on_batch(fake, tf.zeros((current_batch_size, 1)))

        # Train Generator through the frozen-D stack: target 1 means "fool D".
        g_loss = GAN.train_on_batch(
            tf.random.normal([current_batch_size, latent_dim]),
            tf.ones((current_batch_size, 1)),
        )

    if epoch % log_every == 0:
        print(f'Epoch {epoch}  d_loss_real={float(d_loss_real):.4f}  '
              f'd_loss_fake={float(d_loss_fake):.4f}  g_loss={float(g_loss):.4f}')

# 6. GENERATE & SHOW
# Draw 25 noise vectors, map the generator's tanh output from [-1, 1] back to
# [0, 1] for display, and lay the samples out on a 5x5 grid.
noise = tf.random.normal([25, 100])
imgs = (G.predict(noise, verbose=0) + 1) / 2

fig, axes = plt.subplots(5, 5, figsize=(8, 8))
for ax, img in zip(axes.flat, imgs):
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()

Found 60000 files.


  super().__init__(**kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 0
