In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os

  if not hasattr(np, "object"):


In [2]:
# =========================
# 1) Load CIFAR-10
# =========================
# Seed Python `random`, NumPy and TensorFlow up front so dataset shuffling,
# weight initialisation and the fixed sample noise (SEED, drawn in a later
# cell) repeat across Restart & Run All (up to nondeterministic GPU kernels).
RANDOM_SEED = 42
keras.utils.set_random_seed(RANDOM_SEED)

# Only the training images are used; labels and the test split are discarded.
(x_train, _), (_, _) = keras.datasets.cifar10.load_data()

# Scale pixels from [0, 255] to [-1, 1] so real data matches the generator's tanh range.
x_train = (x_train.astype("float32") - 127.5) / 127.5
assert x_train.shape[1:] == (32, 32, 3), f"unexpected image shape {x_train.shape}"
assert x_train.min() >= -1.0 and x_train.max() <= 1.0, "normalisation out of [-1, 1]"

BUFFER_SIZE = x_train.shape[0]  # shuffle buffer spans the whole training set
BATCH_SIZE = 256
LATENT_DIM = 128
EPOCHS = 50

# drop_remainder=True keeps every batch exactly BATCH_SIZE images; the partial
# final batch of each shuffled epoch is discarded.
train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

In [3]:
# =========================
# 2) Generator (32x32x3)
# =========================
def build_generator(latent_dim=128):
    """Build the DCGAN-style generator that maps a latent vector to a 32x32x3 image.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``keras.Sequential`` named "Generator". Its final
        activation is ``tanh``, so outputs lie in [-1, 1], the same range the
        training images are scaled to.
    """

    def _norm_act():
        # Post-projection pair used after every hidden layer.
        return [layers.BatchNormalization(), layers.LeakyReLU(0.2)]

    stack = [
        layers.Input(shape=(latent_dim,)),
        # Project the noise and reshape it into a 4x4x512 feature map.
        layers.Dense(4 * 4 * 512, use_bias=False),
        *_norm_act(),
        layers.Reshape((4, 4, 512)),
    ]

    # Each stride-2 transposed conv doubles the spatial size: 4 -> 8 -> 16 -> 32.
    for filters in (256, 128, 64):
        stack.append(
            layers.Conv2DTranspose(filters, 4, strides=2, padding="same", use_bias=False)
        )
        stack.extend(_norm_act())

    # Stride-1 head collapses the features to 3 RGB channels in [-1, 1].
    stack.append(layers.Conv2DTranspose(3, 3, strides=1, padding="same", activation="tanh"))

    return keras.Sequential(stack, name="Generator")

In [4]:
# =========================
# 3) Discriminator (32x32x3)
# =========================
def build_discriminator():
    """Build a convolutional discriminator that scores 32x32x3 images.

    Returns:
        An uncompiled ``keras.Sequential`` named "Discriminator" that outputs
        a single raw logit per image. There is no sigmoid, so pair it with a
        ``from_logits=True`` loss.
    """
    net = [layers.Input(shape=(32, 32, 3))]

    # Three stride-2 conv stages, each followed by LeakyReLU and dropout.
    for width in (64, 128, 256):
        net += [
            layers.Conv2D(width, 4, strides=2, padding="same"),
            layers.LeakyReLU(0.2),
            layers.Dropout(0.3),
        ]

    net += [layers.Flatten(), layers.Dense(1)]  # logits
    return keras.Sequential(net, name="Discriminator")


generator = build_generator(LATENT_DIM)
discriminator = build_discriminator()

In [5]:
# =========================
# 4) Loss + Optimizers
# =========================
# The discriminator returns raw logits, so the BCE applies the sigmoid itself.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def d_loss_fn(real_logits, fake_logits):
    """Discriminator loss: real logits are pushed toward label 1, fake toward 0.

    Returns the sum of the two BCE terms (real term first, then fake).
    """
    return (
        bce(tf.ones_like(real_logits), real_logits)
        + bce(tf.zeros_like(fake_logits), fake_logits)
    )


def g_loss_fn(fake_logits):
    """Generator loss: BCE of the fake logits against label 1 ("real")."""
    target = tf.ones_like(fake_logits)
    return bce(target, fake_logits)


# Both networks use Adam with lr=2e-4 and beta_1=0.5.
ADAM_KWARGS = dict(learning_rate=2e-4, beta_1=0.5)
gen_opt = tf.keras.optimizers.Adam(**ADAM_KWARGS)
disc_opt = tf.keras.optimizers.Adam(**ADAM_KWARGS)

In [6]:
# =========================
# 5) Save samples
# =========================
SAMPLE_DIR = "cifar_gan_samples"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Fixed latent vectors reused every epoch so the saved grids are comparable
# across training. Despite its name, SEED is a noise tensor, not an RNG seed.
SEED = tf.random.normal([16, LATENT_DIM])


def save_generated(epoch, out_dir=SAMPLE_DIR):
    """Render the generator's output for the fixed SEED noise and save it as a PNG grid.

    Args:
        epoch: Epoch number, used in the file name ``epoch_{epoch:03d}.png``.
        out_dir: Directory to write into (created if missing). Defaults to
            ``SAMPLE_DIR``.

    Returns:
        The path of the written PNG file.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Inference mode so BatchNorm uses its moving statistics.
    imgs = generator(SEED, training=False).numpy()   # tanh output in [-1, 1]
    imgs = np.clip((imgs + 1.0) / 2.0, 0.0, 1.0)      # [0, 1] for imshow

    # Near-square grid sized from the number of noise vectors (4x4 for 16).
    n_imgs = imgs.shape[0]
    n_cols = int(np.ceil(np.sqrt(n_imgs)))
    n_rows = int(np.ceil(n_imgs / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6, 6), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_imgs:
            ax.imshow(imgs[i])
        ax.axis("off")  # also blanks any unused grid cells
    fig.tight_layout()

    path = os.path.join(out_dir, f"epoch_{epoch:03d}.png")
    fig.savefig(path)
    # Close this specific figure; called once per epoch, so leaks would accumulate.
    plt.close(fig)
    return path

In [7]:
# =========================
# 6) Train Step
# =========================
@tf.function
def train_step(real_images):
    """Run one simultaneous generator + discriminator update on a batch.

    Args:
        real_images: Batch of real images scaled to [-1, 1]. Assumed to be
            shaped (batch, 32, 32, 3) to match the discriminator input
            (TODO confirm against the dataset pipeline).

    Returns:
        ``(g_loss, d_loss)``: scalar loss tensors from the forward pass,
        computed before either optimizer step is applied.
    """
    # Size the noise from the actual batch, not the global BATCH_SIZE, so a
    # smaller batch (e.g. without drop_remainder) cannot desynchronise the
    # fake and real batches.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, LATENT_DIM])

    # Both tapes record the same forward pass; each is later differentiated
    # only with respect to its own network's weights.
    with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
        fake_images = generator(noise, training=True)

        real_logits = discriminator(real_images, training=True)
        fake_logits = discriminator(fake_images, training=True)

        d_loss = d_loss_fn(real_logits, fake_logits)
        g_loss = g_loss_fn(fake_logits)

    g_grads = gtape.gradient(g_loss, generator.trainable_variables)
    d_grads = dtape.gradient(d_loss, discriminator.trainable_variables)

    # Both gradients are taken before either update, so neither network's
    # step sees the other's new weights within this call.
    gen_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    return g_loss, d_loss

In [None]:
# =========================
# 7) Training Loop
# =========================
for epoch in range(1, EPOCHS + 1):
    # Per-batch loss tensors for this epoch, averaged once the epoch ends.
    g_losses = []
    d_losses = []
    for real_batch in train_ds:
        g_loss, d_loss = train_step(real_batch)
        g_losses += [g_loss]
        d_losses += [d_loss]

    # Snapshot the generator on the fixed noise, then report epoch means.
    save_generated(epoch)
    mean_g_loss = tf.reduce_mean(g_losses)
    mean_d_loss = tf.reduce_mean(d_losses)
    print(f"Epoch {epoch}/{EPOCHS} | G Loss: {mean_g_loss:.4f} | D Loss: {mean_d_loss:.4f}")

print(" Done! Check folder: cifar_gan_samples")