# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Flatten, Input, BatchNormalization, Conv2D, Conv2DTranspose, Dropout, UpSampling2D, AveragePooling2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Fetch the CelebA training split through TensorFlow Datasets. Passing
# with_info=True makes tfds.load return a (dataset, info) pair.
celeba_data, info = tfds.load(
    'celeb_a',
    with_info=True,
    split='train',
)

def preprocess_dataset(dataset, image_size=(256, 256)):
    """
    Preprocesses the dataset images:
    - Resize the images to ``image_size`` (default 256x256).
    - Normalize the images to [-1, 1].

    Args:
        dataset: A tf.data dataset whose elements are dicts with at least
            the keys 'image' and 'attributes' (as read below).
        image_size: Target (height, width) for the resize. The default
            matches the 256x256 resolution of the generator/critic.

    Returns:
        A dataset of ``(image, attributes)`` tuples, where ``image`` is the
        resized image scaled into [-1, 1].
    """
    def _preprocess_img(img):
        # Bilinear resize to the target size. This stretches the image to
        # the requested shape; it does NOT preserve the aspect ratio.
        img = tf.image.resize(img, image_size)
        # Map [0, 255] -> [-1, 1] to match the generator's tanh output range
        # (assumes 8-bit pixel values — TODO confirm for non-CelebA inputs).
        return (img - 127.5) / 127.5

    return dataset.map(
        lambda x: (_preprocess_img(x['image']), x['attributes']),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )


# Constants
BATCH_SIZE = 8  # * can reduce if memory issues
# One "epoch" in train_gan_512 is one generator update (preceded by
# TRAINING_RATIO critic updates), not a full pass over the dataset.
EPOCHS = 2000
NOISE_DIM = 200  # Noise dimension for generator input
SAVE_INTERVAL = 25  # Frequency to save generated images for visualisation
TRAINING_RATIO = 5  # Number of discriminator updates per generator update

# Preprocess and batch the dataset. drop_remainder=True guarantees every batch
# has the static shape (BATCH_SIZE, 256, 256, 3), so a trailing partial batch
# can never reach the training loop.
celeba_dataset_processed = (
    preprocess_dataset(celeba_data)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# NOTE(review): not referenced by the training code in this file (training
# consumes the tf.data pipeline directly); kept in case another cell uses it.
num_samples = 1000

# Shape of the images the generator produces and the critic consumes.
IMG_SHAPE = (256, 256, 3)

def build_simplified_generator():
    """Build the generator mapping a NOISE_DIM noise vector to a 256x256x3 image.

    A Dense layer produces an 8x8x128 seed. Five stride-2 transposed
    convolutions then double the spatial size each time
    (8 -> 16 -> 32 -> 64 -> 128 -> 256). The final layer uses tanh, so
    pixel values lie in [-1, 1].
    """
    gen = Sequential()

    # Project the noise vector to an 8x8x128 feature map.
    gen.add(Dense(128 * 8 * 8, activation="relu", input_shape=(NOISE_DIM,)))
    gen.add(Reshape((8, 8, 128)))
    gen.add(BatchNormalization())

    # Four identical upsampling stages (8 -> 128), halving channels each time.
    for channels in (64, 32, 16, 8):
        gen.add(Conv2DTranspose(channels, kernel_size=4, strides=2, padding="same"))
        gen.add(BatchNormalization())
        gen.add(LeakyReLU(0.2))

    # Final stage: 128 -> 256 with 3 output channels squashed to [-1, 1].
    gen.add(Conv2DTranspose(3, kernel_size=4, strides=2, padding="same", activation="tanh"))
    return gen

def build_simplified_discriminator():
    """Build the WGAN critic for 256x256x3 images.

    Four stride-2 convolutions (each followed by LeakyReLU) shrink the
    input from 256x256 to 16x16. The result is flattened and mapped to one
    unbounded score; there is no sigmoid, as expected for a Wasserstein critic.
    """
    critic = Sequential()

    # The first conv also declares the input shape.
    critic.add(Conv2D(16, kernel_size=4, strides=2, padding="same", input_shape=(256, 256, 3)))
    critic.add(LeakyReLU(0.2))

    # Remaining downsampling stages, doubling channels each time.
    for width in (32, 64, 128):
        critic.add(Conv2D(width, kernel_size=4, strides=2, padding="same"))
        critic.add(LeakyReLU(0.2))

    # Linear scoring head.
    critic.add(Flatten())
    critic.add(Dense(1))
    return critic


# Build both networks. NOTE: despite the "_512" suffix, the generator outputs
# and the critic consumes 256x256x3 images.
generator_512 = build_simplified_generator()
discriminator_512 = build_simplified_discriminator()

# Two Timescale Update Rule (TTUR): the critic's learning rate (4e-4) is four
# times the generator's (1e-4); both use beta_1=0.5, beta_2=0.9.
optimizer_gen = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
optimizer_disc = Adam(learning_rate=0.0004, beta_1=0.5, beta_2=0.9)

# Loss function (Wasserstein loss)
def wasserstein_loss(y_true, y_pred):
    """Return the negated mean of ``y_true * y_pred``.

    ``y_true`` carries +/-1 labels and ``y_pred`` the critic's raw scores.
    The sign convention chosen for the labels decides in which direction each
    term pushes the scores.
    """
    signed_scores = tf.multiply(y_true, y_pred)
    return -tf.reduce_mean(signed_scores)

# Compilation
# In the code shown, the generator is never updated via its own train_on_batch;
# it is trained through `combined` below. Compiling it here is harmless.
generator_512.compile(optimizer=optimizer_gen, loss=wasserstein_loss)
# The critic is compiled while still trainable. Its own train_on_batch calls are
# expected to keep updating its weights after `trainable` is set to False below.
# NOTE(review): this relies on tf.keras capturing the trainable state at
# compile() time — confirm for the installed Keras version.
discriminator_512.compile(optimizer=optimizer_disc, loss=wasserstein_loss)

# Combined model for training the generator
# Graph: noise z -> generator -> critic score.
z = Input(shape=(NOISE_DIM,))
img = generator_512(z)
# Freeze the critic inside `combined` so generator steps leave critic weights
# untouched. Order matters: after discriminator_512.compile, before combined.compile.
discriminator_512.trainable = False
valid = discriminator_512(img)
combined = Model(z, valid)
combined.compile(loss=wasserstein_loss, optimizer=optimizer_gen)

def train_gan_512(dataset, epochs, batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL):
    """Train the weight-clipped WGAN on ``dataset``.

    Each outer iteration ("epoch") runs ``TRAINING_RATIO`` critic updates, each
    on a fresh real batch, followed by one generator update through
    ``combined``.

    Args:
        dataset: tf.data dataset yielding ``(images, attributes)`` batches with
            images already scaled to [-1, 1].
        epochs: Number of outer iterations (generator updates), not passes
            over the data.
        batch_size: Number of noise vectors used for each generator update.
        save_interval: Save sample images whenever ``epoch % save_interval == 0``.
    """
    # Create ONE iterator over an endlessly repeating dataset. The previous
    # `dataset.take(1)` restarted iteration on every call, so the critic only
    # ever saw the dataset's very first batch.
    batches = iter(dataset.repeat())

    d_loss = None  # defined even if TRAINING_RATIO == 0
    for epoch in range(epochs):
        for _ in range(TRAINING_RATIO):
            imgs, _attrs = next(batches)

            # Size labels and noise from the actual batch, which can be smaller
            # than batch_size when the pipeline keeps a trailing partial batch.
            n_real = int(imgs.shape[0])
            valid = -np.ones((n_real, 1))
            fake = np.ones((n_real, 1))

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (n_real, NOISE_DIM)).astype(np.float32)

            # Generate fake images. A direct call with training=False keeps
            # BatchNorm in inference mode (as predict() does) but skips
            # predict()'s per-call overhead and progress bar in this hot loop.
            gen_imgs = generator_512(noise, training=False)

            # Train the discriminator (critic)
            d_loss_real = discriminator_512.train_on_batch(imgs, valid)
            d_loss_fake = discriminator_512.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Clip critic weights to [-0.01, 0.01] (WGAN Lipschitz constraint).
            for layer in discriminator_512.layers:
                layer.set_weights([np.clip(w, -0.01, 0.01) for w in layer.get_weights()])

        # Train the generator on fresh noise. Its target is the "real" label (-1).
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM)).astype(np.float32)
        g_loss = combined.train_on_batch(noise, -np.ones((batch_size, 1)))

        # Print progress
        print(f"{epoch}/{epochs} [D loss: {d_loss} | G loss: {g_loss}]")

        # Save generated images at save intervals
        if epoch % save_interval == 0:
            save_imgs(generator_512, epoch)


def save_imgs(generator, epoch, save_path="gan_images", num_samples=25):
    """
    Saves generated images for visualization.

    Draws ``num_samples`` noise vectors, runs them through ``generator`` and
    writes each result to
    ``{save_path}/image_at_epoch_{epoch}_sample_{i}.png``.

    Args:
        generator: Model mapping (N, NOISE_DIM) noise to images in [-1, 1].
        epoch: Training iteration; used only in the file names.
        save_path: Output directory, created if missing.
        num_samples: Number of images to generate and save.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)

    noise = np.random.normal(0, 1, (num_samples, NOISE_DIM))
    gen_imgs = generator.predict(noise, verbose=0)

    # Rescale images from [-1, 1] to [0, 1]. Clip to remove float round-off,
    # which would otherwise make imshow warn about clipping.
    gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

    for i, image in enumerate(gen_imgs):
        plt.imshow(image)
        plt.axis('off')
        plt.savefig(f"{save_path}/image_at_epoch_{epoch}_sample_{i}.png")
        plt.close()


# Entry point: run training on the preprocessed CelebA pipeline for EPOCHS
# outer iterations, using the default batch size and save interval.
train_gan_512(dataset=celeba_dataset_processed, epochs=EPOCHS)