# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Dense, LeakyReLU, Reshape, Flatten, BatchNormalization, Conv2D, 
                                     Conv2DTranspose, Input, AveragePooling2D, UpSampling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.constraints import Constraint
import tensorflow_datasets as tfds
import os

# Constants
# --- Training schedule ---
EPOCHS = 1000  # passes over the training dataset (see train_gan)
BATCH_SIZE = 20  # examples per batch, used for both real and generated images
TRAINING_RATIO = 5  # intended discriminator updates per generator update
GRADIENT_PENALTY_WEIGHT = 10  # multiplier on the gradient penalty in the discriminator loss

# --- Model parameters ---
NOISE_DIM = 100  # length of the latent vector fed to the generator
IMG_SHAPE = (256, 256, 3)  # (height, width, channels) of real and generated images

# Spectral Normalisation
class SpectralNormalization(Constraint):
    """Weight constraint that divides a kernel by its estimated largest singular value.

    Intended for use as a Keras ``kernel_constraint``. The kernel is viewed as a 2-D
    matrix of shape (fan_in, units): all leading axes are collapsed and the last
    (output) axis is kept. For a Conv2D kernel (kh, kw, in_ch, out_ch) this gives
    (kh * kw * in_ch, out_ch); for a Dense kernel (in, out) it is the kernel itself.

    The largest singular value ``sigma`` is estimated by power iteration, and the
    constraint returns ``w / sigma``. The singular-vector estimate ``u`` (shape
    (1, units)) is stored between calls so a few iterations per call are enough
    to keep refining it.

    Note: the division is unconditional, so kernels whose sigma is below 1 are
    scaled *up* as well.

    Args:
        power_iterations: number of power-iteration steps per call; must be >= 1.

    Raises:
        ValueError: if ``power_iterations`` is less than 1.
    """

    def __init__(self, power_iterations=1):
        if power_iterations < 1:
            # With zero iterations the left vector ``v`` would never be computed.
            raise ValueError(f"power_iterations must be >= 1, got {power_iterations}")
        self.power_iterations = power_iterations
        self.u = None  # lazily-created non-trainable tf.Variable of shape (1, units)

    def __call__(self, w):
        # (fan_in, units): keep the output axis, collapse everything else.
        w_mat = tf.reshape(w, [-1, w.shape[-1]])

        if self.u is None:
            # Created on first use, once the kernel's output width and dtype are known.
            self.u = tf.Variable(
                initial_value=tf.random.normal([1, w_mat.shape[1]], dtype=w_mat.dtype),
                trainable=False,
                dtype=w_mat.dtype,
            )

        u = self.u
        for _ in range(self.power_iterations):
            v = tf.math.l2_normalize(tf.matmul(u, w_mat, transpose_b=True), axis=1)  # (1, fan_in)
            u = tf.math.l2_normalize(tf.matmul(v, w_mat), axis=1)  # (1, units)
        self.u.assign(u)

        # sigma ~= v W u^T, shape (1, 1); reshape to a scalar before dividing.
        sigma = tf.matmul(v, tf.matmul(w_mat, u, transpose_b=True))
        return w / tf.reshape(sigma, [])

    def get_config(self):
        """Return the constructor arguments so the constraint can be serialized."""
        return {'power_iterations': self.power_iterations}


# Build Generator
def build_generator():
    """Build the generator: a NOISE_DIM latent vector -> a 256x256x3 image.

    Architecture: Dense projection to a 16x16x128 feature map, then four 2x
    upsampling stages (16 -> 32 -> 64 -> 128 -> 256), then two convolutions at
    full resolution. The final tanh keeps pixel values in [-1, 1], the same
    range produced by ``(img - 127.5) / 127.5`` in preprocess_dataset.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()

    # Project the latent vector to a 16x16x128 feature map, then upsample to 32x32.
    model.add(Dense(128 * 16 * 16, activation="relu", input_shape=(NOISE_DIM,)))
    model.add(Reshape((16, 16, 128)))
    model.add(BatchNormalization())
    model.add(UpSampling2D())

    # Conv -> BN -> LeakyReLU -> 2x upsample at each stage: 32 -> 64 -> 128 -> 256.
    for filters in (128, 64, 32):
        model.add(Conv2D(filters, kernel_size=3, padding="same"))
        model.add(BatchNormalization())
        model.add(LeakyReLU(alpha=0.2))
        model.add(UpSampling2D())

    # Refine at 256x256, then map to 3 channels in [-1, 1].
    model.add(Conv2D(32, kernel_size=3, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(3, kernel_size=3, padding="same", activation='tanh'))

    return model


# Build Discriminator
def build_discriminator():
    """Build the critic: two stride-2 spectrally-normalized convolutions and a linear score.

    Each convolution halves the spatial size and is followed by LeakyReLU(0.2).
    The features are flattened into a single unbounded output per image
    (no activation), as used by wasserstein_loss. The input shape is not fixed
    here; the model is built on its first call.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    stack = []
    for width in (32, 64):
        stack.append(
            Conv2D(width, kernel_size=3, strides=2, padding="same",
                   kernel_constraint=SpectralNormalization())
        )
        stack.append(LeakyReLU(alpha=0.2))
    stack.extend([Flatten(), Dense(1)])
    return tf.keras.Sequential(stack)

# Wasserstein Loss
def wasserstein_loss(y_true, y_pred):
    """Return the negated mean of ``y_true * y_pred``.

    With +1/-1 labels this is the Wasserstein critic objective: minimizing it
    raises scores where the label is +1 and lowers them where it is -1.
    """
    weighted_scores = tf.multiply(y_true, y_pred)
    return tf.negative(tf.reduce_mean(weighted_scores))

# Gradient Penalty
def gradient_penalty(batch_size, real_images, fake_images, discriminator):
    """Compute the WGAN-GP gradient penalty, mean((||grad D(x_hat)||_2 - 1)^2).

    ``x_hat`` is a random point on the straight line between each real image and
    its paired fake image.

    Args:
        batch_size: number of image pairs; must match the leading dimension of
            both image tensors.
        real_images: batch of real images, assumed rank-4 (batch, h, w, c).
        fake_images: batch of generated images with the same shape.
        discriminator: the critic model.

    Returns:
        Scalar tensor with the penalty (unweighted).
    """
    # Interpolation coefficients must lie in [0, 1] so x_hat stays on the
    # real-fake segment; a normal distribution would extrapolate past either end.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)

    grads = gp_tape.gradient(pred, interpolated)
    # The epsilon keeps sqrt differentiable when a gradient is exactly zero
    # (d/dx sqrt(x) is infinite at 0, which would yield NaN gradients).
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((norm - 1.0) ** 2)

# Load and preprocess the tf_flowers dataset
def preprocess_dataset(dataset, image_size=(256, 256)):
    """Resize images and rescale pixels, dropping the labels.

    Args:
        dataset: ``tf.data.Dataset`` of (image, label) pairs, e.g. from
            ``tfds.load(..., as_supervised=True)``.
        image_size: (height, width) to resize to; the default matches the
            spatial dims of IMG_SHAPE and the generator's output.

    Returns:
        Dataset of float32 images computed as ``(img - 127.5) / 127.5`` — i.e.
        in [-1, 1] assuming the input pixels are in [0, 255].
    """
    def _preprocess_img(img, _label):
        # Resize (returns float32), then normalize to the generator's tanh range.
        img = tf.image.resize(img, image_size)
        return (img - 127.5) / 127.5

    # Parallel map keeps element order (deterministic by default).
    return dataset.map(_preprocess_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

data, info = tfds.load('tf_flowers', split='train', with_info=True, as_supervised=True)

# Shuffle individual (raw, not-yet-resized) images *before* batching so batch
# composition changes every epoch. Shuffling after .batch() only permutes whole
# batches whose contents never change, and a 1024-batch buffer would hold the
# whole preprocessed dataset in memory.
processed_data = (
    preprocess_dataset(data.shuffle(1024))
    .batch(BATCH_SIZE, drop_remainder=True)  # fixed batch size, as train_step expects
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# optimisers: the discriminator uses a 4x higher learning rate than the generator.
generator_optimizer = Adam(0.0001, beta_1=0.5)
discriminator_optimizer = Adam(0.0004, beta_1=0.5)

@tf.function
def train_step(real_images):
    """Run one discriminator (critic) update followed by one generator update.

    Uses the module-level ``generator``, ``discriminator`` and their optimizers.

    Args:
        real_images: a batch of real images; assumed to hold exactly BATCH_SIZE
            examples, since the label tensors below are built with that size.

    Returns:
        (d_loss, g_loss): the discriminator loss (Wasserstein term plus the
        weighted gradient penalty) and the generator loss for this step.

    NOTE(review): TRAINING_RATIO describes several critic updates per generator
    update, but this step performs exactly one of each — confirm intent.
    """
    # ---- Discriminator (critic) update ----
    random_latent_vectors = tf.random.normal(shape=(BATCH_SIZE, NOISE_DIM))

    # training=True is required: a direct Keras model call otherwise runs in
    # inference mode, so the generator's BatchNormalization layers would never
    # normalize with batch statistics during training.
    generated_images = generator(random_latent_vectors, training=True)

    # Fakes first, then reals, to line up with the label layout below.
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Label convention: +1 for generated images, -1 for real images.
    labels = tf.concat([tf.ones((BATCH_SIZE, 1)), -tf.ones((BATCH_SIZE, 1))], axis=0)

    # Add small uniform noise in [0, 0.05) to the labels.
    labels += 0.05 * tf.random.uniform(labels.shape)

    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images, training=True)
        d_cost = wasserstein_loss(labels, predictions)
        gp = gradient_penalty(BATCH_SIZE, real_images, generated_images, discriminator)
        d_loss = d_cost + gp * GRADIENT_PENALTY_WEIGHT

    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # ---- Generator update ----
    random_latent_vectors = tf.random.normal(shape=(BATCH_SIZE, NOISE_DIM))
    # The generator wants its samples scored like real images (label -1).
    misleading_labels = -tf.ones((BATCH_SIZE, 1))

    with tf.GradientTape() as tape:
        # Only generator weights are updated below; the discriminator is just evaluated.
        predictions = discriminator(
            generator(random_latent_vectors, training=True), training=True
        )
        g_loss = wasserstein_loss(misleading_labels, predictions)

    grads = tape.gradient(g_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

    return d_loss, g_loss

    
# Save the generated images
def save_generated_images(epoch, generator, save_path="saved_images", num_samples=10):
    """
    Saves generated images individually without any borders.

    Samples ``num_samples`` latent vectors from N(0, 1), runs them through
    ``generator.predict`` and writes one PNG per sample to ``save_path``
    (created if missing), named
    ``gan_generated_image_epoch_{epoch}_sample_{i}.png``.

    Args:
        epoch: epoch number embedded in the file names.
        generator: model whose outputs are assumed to be in [-1, 1] (tanh).
        save_path: output directory.
        num_samples: number of images to generate and save.
    """
    os.makedirs(save_path, exist_ok=True)

    noise = np.random.normal(0, 1, size=[num_samples, NOISE_DIM])
    generated_images = generator.predict(noise)

    # Rescale images from [-1, 1] to [0, 1]; clipping guards against any
    # round-off outside the range imsave accepts for float images.
    generated_images = np.clip((generated_images + 1) / 2.0, 0.0, 1.0)

    for i, image in enumerate(generated_images):
        image_path = os.path.join(save_path, f"gan_generated_image_epoch_{epoch}_sample_{i}.png")
        # imsave writes the array at its native resolution with no figure,
        # axes or padding, so the file is truly borderless.
        plt.imsave(image_path, image)

# Module-level models used by train_step and train_gan (generator built first).
generator, discriminator = build_generator(), build_discriminator()

def train_gan(dataset, epochs, save_interval=25, log_interval=200):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    Each batch is passed to ``train_step``. Losses are printed every
    ``log_interval`` steps within an epoch, and sample images are saved through
    ``save_generated_images`` every ``save_interval`` epochs.

    Args:
        dataset: iterable of real-image batches (re-iterated every epoch).
        epochs: number of epochs to run.
        save_interval: save samples when the 1-based epoch number is a multiple
            of this; 0 or None disables saving.
        log_interval: print losses when the 0-based step index is a multiple
            of this; 0 or None disables logging.
    """
    for epoch in range(1, epochs + 1):
        print("\nStart of epoch %d" % (epoch,))
        for step, real_images in enumerate(dataset):
            # Train the discriminator & generator
            d_loss, g_loss = train_step(real_images)
            if log_interval and step % log_interval == 0:
                print("discriminator loss at step %d: %.2f" % (step, d_loss))
                print("adversarial loss at step %d: %.2f" % (step, g_loss))

        if save_interval and epoch % save_interval == 0:
            save_generated_images(epoch, generator)

# Start the training loop only when run as a script or notebook (where
# __name__ is "__main__"), not when this module is imported.
if __name__ == "__main__":
    train_gan(processed_data, EPOCHS)
