In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

# Path to your local CelebA dataset folder
DATASET_PATH = "celeba-dataset"
BATCH_SIZE = 128
IMG_SIZE = (64, 64)

# Load dataset from directory
def load_dataset():
    dataset = tf.keras.preprocessing.image_dataset_from_directory(
        DATASET_PATH,
        label_mode=None,  # No labels needed for GAN
        image_size=IMG_SIZE,  # Resize images to 64x64
        batch_size=BATCH_SIZE
    )
    dataset = dataset.map(lambda x: (tf.cast(x, tf.float32) - 127.5) / 127.5)  # Normalize [-1,1]
    return dataset

# Load dataset
dataset = load_dataset()

def build_generator():
    inputs = tf.keras.layers.Input(shape=(100,))
    x = tf.keras.layers.Dense(8 * 8 * 512, use_bias=False)(inputs)
    x = tf.keras.layers.Reshape((8, 8, 512))(x)
    
    x = tf.keras.layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding="same", use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    
    x = tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding="same", use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    
    x = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding="same", use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    
    outputs = tf.keras.layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding="same", activation="tanh")(x)
    
    model = tf.keras.models.Model(inputs, outputs)
    return model

def build_discriminator():
    inputs = tf.keras.layers.Input(shape=(64, 64, 3))
    x = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same")(inputs)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    x = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same")(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    x = tf.keras.layers.Conv2D(256, (5, 5), strides=(2, 2), padding="same")(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    x = tf.keras.layers.Flatten()(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    
    model = tf.keras.models.Model(inputs, outputs)
    return model

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(1e-4))

def train_dcgan(epochs=50):
    seed = tf.random.normal([16, 100])
    
    for epoch in range(epochs):
        for real_images in dataset:
            noise = tf.random.normal([BATCH_SIZE, 100])
            generated_images = generator(noise)
            
            real_labels = tf.ones((BATCH_SIZE, 1))
            fake_labels = tf.zeros((BATCH_SIZE, 1))
            
            d_loss_real = discriminator.train_on_batch(real_images, real_labels)
            d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
            
            noise = tf.random.normal([BATCH_SIZE, 100])
            g_loss = gan.train_on_batch(noise, tf.ones((BATCH_SIZE, 1)))
            
        print(f"Epoch {epoch+1}/{epochs} - D Loss: {d_loss_real + d_loss_fake:.4f}, G Loss: {g_loss:.4f}")
        
        if (epoch + 1) % 10 == 0:
            generate_and_save_images(generator, epoch + 1, seed)

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow((predictions[i] + 1) / 2)
        plt.axis("off")
    plt.savefig(f"generated_image_epoch_{epoch}.png")
    plt.show()

discriminator.trainable = False
gan_input = tf.keras.layers.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.models.Model(gan_input, gan_output)
gan.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(1e-4))

train_dcgan(epochs=50)


Found 202599 files.


ValueError: Input 0 of layer "functional_1" is incompatible with the layer: expected shape=(None, 64, 64, 3), found shape=(None, 128, 128, 3)