In [None]:
# GAN Stippling Project
# Description: This project implements a Generative Adversarial Network (GAN) to generate stippled images.
# This project demonstrates the use of TensorFlow in building and training both generator
# and discriminator models, incorporating custom layers and training procedures to simulate
# stippling effects in images.

import tensorflow as tf
import os
from tensorflow.keras import layers, models, optimizers, losses

# Utility functions for data preparation and image preprocessing
# ----------------------------------------------------------------------------------

# Function to load and preprocess images from a given path
def load_and_preprocess_image(path):
    """Decode one PNG file into a normalized single-channel tensor.

    Args:
        path: String (or scalar string tensor) naming a PNG file.

    Returns:
        A float32 tensor of shape (256, 256, 1) with values in [0, 1].
    """
    target_size = [256, 256]
    # channels=1 forces a single (grayscale) channel regardless of the file's mode.
    pixels = tf.image.decode_png(tf.io.read_file(path), channels=1)
    # tf.image.resize returns float32, so this is a float division into [0, 1].
    return tf.image.resize(pixels, target_size) / 255.0

# Function to prepare the dataset by loading images and setting up batching and prefetching
def prepare_dataset(real_image_path, batch_size, shuffle_buffer=1000):
    """Build a shuffled, batched ``tf.data`` pipeline of (input, target) image pairs.

    Every ``*.png`` file directly inside ``real_image_path`` is loaded with
    ``load_and_preprocess_image`` and paired with itself, so each element is
    ``(image, image)``.

    Args:
        real_image_path: Directory containing the PNG images.
        batch_size: Number of pairs per batch (the last batch may be smaller).
        shuffle_buffer: Size of the shuffle buffer (default 1000).

    Returns:
        A ``tf.data.Dataset`` yielding ``(input_batch, target_batch)`` tuples.

    Raises:
        tf.errors.InvalidArgumentError: from ``list_files`` when no file
            matches the pattern.
    """
    pattern = os.path.join(real_image_path, '*.png')
    image_paths = tf.data.Dataset.list_files(pattern, shuffle=False)
    images = image_paths.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    # Duplicate each decoded image into an (input, target) pair. Zipping the
    # `images` dataset with itself would run the read/decode/resize map twice
    # per element; mapping once and duplicating yields the same pairs.
    pairs = images.map(lambda image: (image, image))
    return (
        pairs.shuffle(buffer_size=shuffle_buffer)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

# Custom layers and model architecture
# ----------------------------------------------------------------------------------

# Custom layer to create a stippling effect in images
class StippleEffectLayer(layers.Layer):
    """Element-wise ``x * sin(x)`` non-linearity with no trainable weights.

    Used in the generator to modulate feature maps with an oscillating
    response, intended to produce a stipple-like texture.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer keyword arguments (e.g. ``name``, ``dtype``,
        # ``trainable``). Keras' default ``from_config`` passes these back to
        # ``__init__`` when rebuilding a layer, so rejecting them would break
        # reconstruction from a config.
        super().__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs * sin(inputs)`` element-wise; the shape is unchanged."""
        return inputs * tf.math.sin(inputs)

# Function to build the generator model of the GAN
def build_generator(input_shape=(256, 256, 1)):
    """Build the image-to-image generator of the GAN.

    Architecture:

    1. A 3x3 single-filter convolution initialized with a Sobel-style kernel,
       followed by ReLU.
    2. Two 64-filter convolutions.
    3. 2x2 max-pooling.
    4. Two 128-filter convolutions.
    5. A stride-2 transposed convolution back to the input resolution.
    6. The ``StippleEffectLayer`` (``x * sin(x)``) and dropout.
    7. A single-channel sigmoid output.

    Args:
        input_shape: (height, width, channels) of the input images. Defaults
            to the 256x256 grayscale images produced by this file's loader.
            ``channels`` must be 1, because the 9-value edge kernel must fill
            the first layer's (3, 3, channels, 1) kernel. Height and width
            should be even; otherwise pooling followed by the stride-2
            upsampling returns an image one pixel smaller than the input.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose output has shape
        (height, width, 1) with values in [0, 1].
    """
    # Sobel-style kernel. Rows go from -1/-2/-1 to +1/+2/+1, so it responds to
    # intensity changes along the vertical axis (horizontal edges). It is only
    # the initial value; the kernel remains trainable.
    sobel_kernel = [
        [-1, -2, -1],
        [0, 0, 0],
        [1, 2, 1],
    ]
    model = models.Sequential([
        layers.InputLayer(input_shape=input_shape),
        # activation='relu' already rectifies the edge response, so no extra
        # standalone ReLU layer is needed (relu(relu(x)) == relu(x)).
        layers.Conv2D(1, (3, 3), padding='same', activation='relu',
                      kernel_initializer=tf.constant_initializer(sobel_kernel),
                      bias_initializer='zeros'),
        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        layers.MaxPooling2D((2, 2)),  # halves height and width
        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
        # Stride 2 doubles height and width again, restoring the input resolution.
        layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'),
        StippleEffectLayer(),  # element-wise x * sin(x)
        layers.Dropout(0.5),  # regularization; only active when training=True
        # Sigmoid keeps outputs in [0, 1], the same range as the normalized inputs.
        layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid'),
    ])
    return model

# Function to build the discriminator model of the GAN
def build_discriminator(input_shape=(256, 256, 1)):
    """Build the discriminator that scores images as real (1) or generated (0).

    Two 5x5 stride-2 convolutions, each followed by LeakyReLU(0.2) and 30%
    dropout, downsample the image. The features are then flattened into a
    single sigmoid unit.

    Args:
        input_shape: (height, width, channels) of the images to classify. It
            must match the generator's output shape. Defaults to the 256x256
            grayscale images produced by this file's loader.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping a batch of images
        to a (batch, 1) tensor of probabilities in [0, 1].
    """
    model = models.Sequential([
        layers.InputLayer(input_shape=input_shape),
        # Strided convolutions reduce resolution instead of pooling layers.
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Flatten(),
        # Sigmoid output: training uses BinaryCrossentropy without from_logits.
        layers.Dense(1, activation='sigmoid'),
    ])
    return model

# Training and utility functions
# ----------------------------------------------------------------------------------

# Function to train the GAN, including both the generator and discriminator
def train_gan(generator, discriminator, dataset, epochs=10,
              save_path='../output/generated_images'):
    """Adversarially train ``generator`` against ``discriminator``.

    For every batch:

    1. The discriminator is updated, with target images labelled 1 and
       generator outputs labelled 0.
    2. The generator is updated so that the freshly-updated discriminator
       labels its outputs 1.

    Losses are printed after every batch. At the end of each epoch, generator
    outputs for that epoch's last batch are written via
    ``save_generated_images``.

    Args:
        generator: Keras model mapping an input image batch to a generated batch.
        discriminator: Keras model mapping an image batch to probabilities in
            [0, 1]. Binary cross-entropy is applied without ``from_logits``.
        dataset: Iterable of ``(input_images, target_images)`` batches. As
            built by ``prepare_dataset``, each target is the input image itself.
        epochs: Number of passes over ``dataset``.
        save_path: Directory passed to ``save_generated_images`` for the
            per-epoch sample images.
    """
    binary_cross_entropy = losses.BinaryCrossentropy()
    generator_optimizer = optimizers.Adam(1e-4)
    discriminator_optimizer = optimizers.Adam(1e-4)

    for epoch in range(epochs):
        last_batch = None  # stays None if the dataset yields no batches
        for real_images, stippled_images in dataset:
            last_batch = real_images

            # --- Discriminator update ---
            # The generator forward pass runs outside the tape: only
            # discriminator gradients are taken here, so recording generator
            # ops would just waste memory.
            generated_images = generator(real_images, training=True)
            with tf.GradientTape() as disc_tape:
                real_output = discriminator(stippled_images, training=True)
                generated_output = discriminator(generated_images, training=True)
                real_loss = binary_cross_entropy(tf.ones_like(real_output), real_output)
                generated_loss = binary_cross_entropy(
                    tf.zeros_like(generated_output), generated_output)
                disc_loss = real_loss + generated_loss

            gradients_of_discriminator = disc_tape.gradient(
                disc_loss, discriminator.trainable_variables)
            discriminator_optimizer.apply_gradients(
                zip(gradients_of_discriminator, discriminator.trainable_variables))

            # --- Generator update ---
            # Non-saturating generator loss: push D(G(x)) towards 1, using the
            # discriminator weights that were just updated above.
            with tf.GradientTape() as gen_tape:
                generated_images = generator(real_images, training=True)
                gen_output = discriminator(generated_images, training=True)
                gen_loss = binary_cross_entropy(tf.ones_like(gen_output), gen_output)

            gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
            generator_optimizer.apply_gradients(
                zip(gradients_of_generator, generator.trainable_variables))

            # Output training progress
            print(f"Epoch {epoch+1}, Disc Loss: {disc_loss.numpy()}, Gen Loss: {gen_loss.numpy()}")

        if last_batch is None:
            # An empty dataset leaves no batch to render; report it instead of
            # failing on an undefined batch variable.
            print(f"Epoch {epoch+1}: dataset produced no batches, skipping image export")
        else:
            # Save generated images after each epoch to visualize progress.
            save_generated_images(generator, epoch, last_batch, save_path=save_path)
        print(f"Epoch {epoch+1} completed")


# Function to save generated images during training for review and analysis
def save_generated_images(generator, epoch, real_images, save_path='../output/generated_images'):
    """Run ``generator`` on a batch and write each output image as a PNG.

    Args:
        generator: Keras model. It is called with ``training=False``, so
            layers such as dropout run in inference mode.
        epoch: Zero-based epoch index, embedded in the file names.
        real_images: Input batch fed to the generator.
        save_path: Output directory; created if it does not exist.

    Files are named ``image_at_epoch_{epoch:04d}_{i}.png``, where ``i`` is the
    image's position within the batch.
    """
    predictions = generator(real_images, training=False)
    # Outputs are treated as intensities in [0, 1] (build_generator ends in a
    # sigmoid; anything outside that range is clipped). They are converted to
    # 0-255 explicitly, and array_to_img's per-image min-max stretching is
    # disabled, so saved brightness reflects the model's actual output instead
    # of being re-normalized image by image. The previous `(p + 1) / 2`
    # assumed a tanh output in [-1, 1], which does not match a sigmoid.
    pixels = tf.round(tf.clip_by_value(predictions, 0.0, 1.0) * 255.0)
    os.makedirs(save_path, exist_ok=True)

    for i, img in enumerate(pixels):  # one file per image in the batch
        pil_image = tf.keras.preprocessing.image.array_to_img(img, scale=False)
        pil_image.save(os.path.join(save_path, f"image_at_epoch_{epoch:04d}_{i}.png"))

# Main execution: setup and train the GAN
# Guarded so that importing this module (e.g. to reuse the model builders)
# does not start training. In a notebook, __name__ is "__main__", so the
# cell still runs as before.
if __name__ == "__main__":
    real_images_path = '../data/input_images'  # Path to real images
    batch_size = 5
    dataset = prepare_dataset(real_images_path, batch_size)

    generator = build_generator()
    discriminator = build_discriminator()

    train_gan(generator, discriminator, dataset, epochs=2)
