# In [None]:
# ======================================
# 1. Imports
# ======================================
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# ======================================
# 2. Load & Preprocess Dataset (CIFAR-10)
# ======================================
# cifar10.load_data() returns ((x_train, y_train), (x_test, y_test)). An
# unconditional GAN only needs the training images, so they are taken by
# index instead of unpacking into `_`: assigning `_` at module level would
# shadow IPython's last-output variable for the rest of the notebook session.
X_train = tf.keras.datasets.cifar10.load_data()[0][0]

# Normalize uint8 pixels from [0, 255] to [-1, 1], the same range as the
# generator's tanh output.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5

BATCH_SIZE = 128  # generator step batch; the discriminator sees half real + half fake
LATENT_DIM = 100  # length of the noise vector fed to the generator

# ======================================
# 3. Generator Model (Dense + BatchNorm) -> 32x32x3
# ======================================
def build_generator():
    """Build the MLP generator that maps a latent vector to a 32x32 RGB image.

    Architecture: three Dense -> BatchNormalization -> LeakyReLU(0.2) stages
    with 256, 512 and 1024 units, then a Dense layer with ``tanh`` activation
    whose 32*32*3 outputs are reshaped into an image. ``tanh`` keeps pixel
    values in [-1, 1].

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping inputs of shape
        ``(batch, LATENT_DIM)`` to images of shape ``(batch, 32, 32, 3)``.
    """
    model = tf.keras.Sequential([
        # Declare the input with an Input object rather than the legacy
        # ``input_dim`` layer kwarg, which Keras 3 deprecates (it warns).
        layers.Input(shape=(LATENT_DIM,)),

        layers.Dense(256),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),

        layers.Dense(512),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),

        layers.Dense(1024),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),

        layers.Dense(32 * 32 * 3, activation="tanh"),
        layers.Reshape((32, 32, 3))
    ])
    return model

# ======================================
# 4. Discriminator Model (Dense)
# ======================================
def build_discriminator():
    """Build the MLP discriminator that scores 32x32 RGB images.

    The image is flattened and passed through Dense(512) and Dense(256), each
    followed by LeakyReLU(0.2), then a single sigmoid unit whose output lies
    in [0, 1]. Training in this file labels real images ~1 and generated
    images 0, so a higher score means "looks real".

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping inputs of shape
        ``(batch, 32, 32, 3)`` to scores of shape ``(batch, 1)``.
    """
    model = tf.keras.Sequential([
        # Declare the input with an Input object rather than the legacy
        # ``input_shape`` layer kwarg, which Keras 3 deprecates (it warns).
        layers.Input(shape=(32, 32, 3)),
        layers.Flatten(),

        layers.Dense(512),
        layers.LeakyReLU(0.2),

        layers.Dense(256),
        layers.LeakyReLU(0.2),

        layers.Dense(1, activation="sigmoid")
    ])
    return model

# ======================================
# 5. Build & Compile Models
# ======================================
generator = build_generator()
discriminator = build_discriminator()

# The discriminator is trained on its own with binary cross-entropy and Adam
# (learning_rate=2e-4, beta_1=0.5); the stacked model below uses the same settings.
discriminator.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
)

# Mark the discriminator non-trainable after its own compile() but before the
# stacked model is compiled, so the stacked model treats D's weights as frozen.
# NOTE(review): whether discriminator.train_on_batch still updates D after
# this depends on the Keras version (tf.keras 2.x keeps the trainable state
# captured at compile time; Keras 3 appears to read the flag when the train
# step is traced) — confirm for the installed version.
discriminator.trainable = False

# Stacked model: latent noise -> generator -> discriminator score.
gan_input = layers.Input(shape=(LATENT_DIM,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)
gan = tf.keras.Model(inputs=gan_input, outputs=gan_output)

gan.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
)

# ======================================
# 6. Image Generation Function (RGB)
# ======================================
def generate_images(epoch, rows=4, cols=4):
    """Generate ``rows * cols`` samples from the current generator and plot them.

    Args:
        epoch: Training step number; used only in the figure title.
        rows: Number of grid rows (default 4, as before).
        cols: Number of grid columns (default 4, as before).
    """
    n_images = rows * cols
    noise = np.random.normal(0, 1, (n_images, LATENT_DIM)).astype(np.float32)

    # Calling the model directly is much cheaper than predict() for a single
    # small batch; training=False keeps BatchNormalization in inference mode,
    # matching what predict() does.
    images = np.asarray(generator(noise, training=False))

    # Generator output is tanh in [-1, 1]; map it to [0, 1] for imshow and
    # clip so imshow never has to clip (and warn about) out-of-range floats.
    images = np.clip((images + 1) / 2.0, 0.0, 1.0)

    fig, axes = plt.subplots(
        rows, cols, figsize=(1.5 * cols, 1.5 * rows), squeeze=False
    )
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)  # RGB
        ax.axis("off")
    fig.suptitle(f"Epoch {epoch}")
    plt.show()
    # Release the figure; with non-inline backends, figures would otherwise
    # accumulate across calls during training.
    plt.close(fig)

# ======================================
# 7. Training Loop
# ======================================
def train_gan(epochs=300, log_every=50):
    """Train the GAN with alternating discriminator / generator updates.

    Each "epoch" here is a single training step on one randomly sampled
    mini-batch, not a full pass over ``X_train``. Each step:

    1. Trains the discriminator on ``BATCH_SIZE // 2`` real images (soft label
       0.9) and ``BATCH_SIZE // 2`` generated images (label 0).
    2. Trains the generator through the stacked ``gan`` model on
       ``BATCH_SIZE`` noise vectors labelled 1. With the discriminator frozen,
       the generator is rewarded when its samples are scored as real.

    Args:
        epochs: Number of training steps to run.
        log_every: Print losses and show a sample grid every ``log_every``
            steps. The final step is always reported too.

    Raises:
        ValueError: If ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    half_batch = BATCH_SIZE // 2

    # Targets are identical on every step, so build them once.
    real_labels = np.full((half_batch, 1), 0.9)  # one-sided label smoothing
    fake_labels = np.zeros((half_batch, 1))
    misleading_labels = np.ones((BATCH_SIZE, 1))  # G wants D to output 1

    for epoch in range(epochs):

        # --------------------
        # Train Discriminator
        # --------------------
        real_images = X_train[np.random.randint(0, X_train.shape[0], half_batch)]
        noise = np.random.normal(0, 1, (half_batch, LATENT_DIM)).astype(np.float32)
        # Direct call instead of predict(): predict() has large per-call
        # overhead for one small batch. training=False matches predict().
        fake_images = generator(noise, training=False)

        # Set the flag explicitly instead of relying on the state captured at
        # compile time. Keras 3 reads `trainable` when the train step is
        # traced, so leaving it False (as set after compile) would stop D from
        # learning. tf.keras 2.x restores the compiled state itself, so the
        # toggle is harmless there.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * (float(d_loss_real) + float(d_loss_fake))

        # --------------------
        # Train Generator
        # --------------------
        # Freeze D so the stacked model only updates the generator's weights.
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIM)).astype(np.float32)
        g_loss = float(gan.train_on_batch(noise, misleading_labels))

        # --------------------
        # Logging & Visualization (the last step is always reported)
        # --------------------
        if epoch % log_every == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch} | D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}")
            generate_images(epoch)

# ======================================
# 8. Start Training
# ======================================
# Short run: 20 training steps, each on a single randomly sampled mini-batch.
train_gan(20)