<a href="https://colab.research.google.com/github/AiJared/Generative_Adversarial_Network/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
 import os

In [2]:
class ImageGenerationGAN:
    """
    DCGAN-style generator/discriminator pair trained with a custom loop.

    The generator maps a latent noise vector to an image whose values lie in
    [-1, 1] (tanh output). The discriminator maps an image to a single
    sigmoid probability, trained towards 1 for real images and 0 for
    generated ones.
    """

    # The generator stacks three stride-2 transposed convolutions, so its
    # output is 2 ** 3 = 8 times larger than the seed feature map per axis.
    _UPSAMPLING_FACTOR = 8

    def __init__(self, img_shape=(64, 64, 3), latent_dim=100):
        """
        Initialize GAN with specified image dimensions and latent space

        Args:
            img_shape (tuple): (height, width, channels) of the images.
                Height and width must be positive multiples of 8.
            latent_dim (int): Dimensionality of the random noise vector

        Raises:
            ValueError: If img_shape cannot be produced by the generator.
        """
        # Fail early with a clear message instead of a shape error from Keras.
        self._generator_seed_size(img_shape)

        self.img_shape = tuple(img_shape)
        self.latent_dim = latent_dim

        # Build generator and discriminator
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

        # Build the stacked generator -> discriminator model
        self.adversarial_model = self.build_gan()

        # One persistent optimizer per network. Adam keeps running moment
        # estimates, so the same instance has to be reused across steps.
        self.generator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0002, beta_1=0.5
        )
        self.discriminator_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0002, beta_1=0.5
        )

    @classmethod
    def _generator_seed_size(cls, img_shape):
        """
        Compute the spatial size of the generator's first feature map

        Args:
            img_shape (tuple): (height, width, channels) of the target images

        Returns:
            tuple: (seed_height, seed_width), i.e. height and width divided
                by the generator's total upsampling factor.

        Raises:
            ValueError: If img_shape does not have three positive entries or
                its height/width are not multiples of the upsampling factor.
        """
        if len(img_shape) != 3:
            raise ValueError(
                f"img_shape must be (height, width, channels), got {img_shape!r}"
            )

        height, width, channels = img_shape
        factor = cls._UPSAMPLING_FACTOR
        if height <= 0 or width <= 0 or channels <= 0:
            raise ValueError(
                f"img_shape entries must be positive, got {img_shape!r}"
            )
        if height % factor or width % factor:
            raise ValueError(
                f"img_shape height and width must be multiples of {factor}, "
                f"got {img_shape!r}"
            )
        return height // factor, width // factor

    @staticmethod
    def _grid_side(num_examples):
        """
        Return the side length of the smallest square grid holding num_examples

        Args:
            num_examples (int): Number of cells required

        Returns:
            int: Smallest side >= 1 such that side * side >= num_examples
        """
        side = 1
        while side * side < num_examples:
            side += 1
        return side

    def build_generator(self):
        """
        Build the generator network

        The seed feature map and the number of output channels are derived
        from self.img_shape so that the output matches the discriminator input.

        Returns:
            tf.keras.Model: Generator neural network
        """
        seed_height, seed_width = self._generator_seed_size(self.img_shape)
        channels = self.img_shape[2]

        model = tf.keras.Sequential([
            # Project the noise vector onto the seed feature map
            tf.keras.layers.Dense(seed_height * seed_width * 256, input_dim=self.latent_dim),
            tf.keras.layers.Reshape((seed_height, seed_width, 256)),

            # Upsampling blocks (each doubles height and width)
            tf.keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha=0.2),

            tf.keras.layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha=0.2),

            # tanh keeps pixel values in [-1, 1]
            tf.keras.layers.Conv2DTranspose(channels, kernel_size=4, strides=2, padding='same', activation='tanh')
        ])
        return model

    def build_discriminator(self):
        """
        Build the discriminator network

        Returns:
            tf.keras.Model: Discriminator neural network
        """
        model = tf.keras.Sequential([
            # Input layer
            tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding='same', input_shape=self.img_shape),
            tf.keras.layers.LeakyReLU(alpha=0.2),

            tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha=0.2),

            # Single sigmoid unit: probability that the input is real
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])
        return model

    def build_gan(self):
        """
        Build the Generative Adversarial Network

        The discriminator is intentionally NOT frozen here. It is the same
        object that train_step updates, and a model with trainable=False
        reports no trainable variables, which would leave the discriminator
        update with nothing to optimize. The generator update in train_step
        only applies gradients to the generator's variables, so the
        discriminator is not modified during that step anyway.

        Returns:
            tf.keras.Model: Combined GAN model
        """
        # Connect generator and discriminator
        model = tf.keras.Sequential([
            self.generator,
            self.discriminator
        ])

        return model

    def train_step(self, real_images):
        """
        Single training step for the GAN

        Updates the discriminator on one real batch and one generated batch,
        then updates the generator on a fresh generated batch.

        Args:
            real_images (tf.Tensor): Batch of real training images

        Returns:
            tuple: (disc_loss, gen_loss) scalar tensors for this step
        """
        batch_size = tf.shape(real_images)[0]

        # Generate noise
        noise = tf.random.normal([batch_size, self.latent_dim])

        # Generate fake images (outside the tape: the discriminator update
        # does not need gradients with respect to the generator)
        generated_images = self.generator(noise, training=True)

        # Prepare labels
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))

        # Train Discriminator
        with tf.GradientTape() as disc_tape:
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(generated_images, training=True)

            # binary_crossentropy returns one value per sample; average them
            # so the loss is a scalar that does not scale with batch size.
            disc_real_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(real_labels, real_predictions)
            )
            disc_fake_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(fake_labels, fake_predictions)
            )

            disc_loss = 0.5 * (disc_real_loss + disc_fake_loss)

        # Compute gradients and update discriminator
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(disc_gradients, self.discriminator.trainable_variables)
        )

        # Train Generator
        with tf.GradientTape() as gen_tape:
            noise = tf.random.normal([batch_size, self.latent_dim])
            generated_images = self.generator(noise, training=True)

            fake_predictions = self.discriminator(generated_images, training=True)
            # The generator is rewarded when fakes are classified as real
            gen_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(real_labels, fake_predictions)
            )

        # Compute gradients and update generator (generator variables only)
        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )

        return disc_loss, gen_loss

    def train(self, dataset, epochs=100, batch_size=32):
        """
        Train the GAN on the provided dataset

        Args:
            dataset (tf.data.Dataset): Training dataset; each element is
                passed to train_step as one batch of real images
            epochs (int): Number of training epochs
            batch_size (int): Unused; kept for backward compatibility. The
                batch size is whatever `dataset` yields.
        """
        for epoch in range(epochs):
            for batch in dataset:
                self.train_step(batch)

            # Save a grid of sample images every 10th epoch (including epoch 0)
            if epoch % 10 == 0:
                self.generate_and_save_images(epoch)

    def generate_and_save_images(self, epoch, num_examples=16):
        """
        Generate sample images during training

        The images are laid out on the smallest square grid that fits
        num_examples and written to 'generated_images_epoch_{epoch}.png' in
        the current working directory.

        Args:
            epoch (int): Current training epoch
            num_examples (int): Number of images to generate
        """
        noise = tf.random.normal([num_examples, self.latent_dim])
        generated_images = self.generator(noise, training=False)

        grid_side = self._grid_side(num_examples)
        # imshow does not accept (H, W, 1); single-channel images are
        # squeezed to (H, W) and drawn with a gray colormap instead.
        single_channel = self.img_shape[2] == 1

        plt.figure(figsize=(10, 10))
        for i in range(num_examples):
            plt.subplot(grid_side, grid_side, i + 1)
            image = (generated_images[i] + 1) / 2.0  # Rescale to [0,1]
            if single_channel:
                plt.imshow(tf.squeeze(image, axis=-1), cmap='gray')
            else:
                plt.imshow(image)
            plt.axis('off')

        plt.savefig(f'generated_images_epoch_{epoch}.png')
        plt.close()

def preprocess_images(image_directory, target_size=(64, 64), batch_size=32):
    """
    Load and preprocess images from a directory

    Files directly inside image_directory whose names end in .png, .jpg or
    .jpeg (case-insensitive) are decoded as 3-channel images, resized to
    target_size and scaled to [-1, 1].

    Args:
        image_directory (str): Path to image directory
        target_size (tuple): Desired image dimensions as (height, width)
        batch_size (int): Number of images per batch

    Returns:
        tf.data.Dataset: Preprocessed image dataset

    Raises:
        ValueError: If the directory contains no matching image files.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    # Sorted so the dataset order does not depend on os.listdir's arbitrary order.
    image_paths = sorted(
        os.path.join(image_directory, file_name)
        for file_name in os.listdir(image_directory)
        if file_name.lower().endswith(image_extensions)
    )
    if not image_paths:
        # An empty list would otherwise become a non-string tensor and fail
        # later inside the map function with an unrelated-looking error.
        raise ValueError(
            f"No .png/.jpg/.jpeg images found in {image_directory!r}"
        )

    def load_and_preprocess_image(path):
        image = tf.io.read_file(path)
        # expand_animations=False makes decode_image return a 3-D tensor of
        # known rank, which tf.image.resize requires inside Dataset.map.
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.image.resize(image, target_size)
        image = tf.cast(image, tf.float32)
        image = (image / 127.5) - 1  # Normalize to [-1, 1]
        return image

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

In [None]:
# Example usage: build the input pipeline, then fit a GAN on it
if __name__ == '__main__':
    # Replace the placeholder with a folder containing your training images
    dataset = preprocess_images(image_directory='path/to/your/image/directory')

    # Create a GAN for 64x64 RGB images and run a short two-epoch training
    gan = ImageGenerationGAN(img_shape=(64, 64, 3))
    gan.train(dataset=dataset, epochs=2)