<a href="https://colab.research.google.com/github/Srini-c28/GEN-AI---lab-work/blob/main/9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gradio
# gan_gradio.py

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import gradio as gr
import os

# File the trained generator is written to after training, and loaded from
# on later runs so training is skipped.
MODEL_SAVE_PATH = 'simple_gan_generator.h5'

# Length of the noise vector fed to the generator.
LATENT_DIM = 100

# Image geometry: 28x28 pixels, single (grayscale) channel, as in MNIST.
IMAGE_SIZE = 28
CHANNELS = 1

def build_generator(latent_dim, image_size, channels):
    """Build the DCGAN-style generator that maps noise vectors to images.

    The latent vector is projected by a Dense layer onto a
    (image_size // 4) x (image_size // 4) x 256 feature map. That map passes
    through one stride-1 and two stride-2 transposed convolutions, giving an
    image_size x image_size output.

    Args:
        latent_dim: Length of the input noise vector.
        image_size: Height/width of the square output image. Must be a
            positive multiple of 4, because the two stride-2 layers each
            double the spatial size.
        channels: Number of output channels (1 for grayscale).

    Returns:
        An uncompiled ``Sequential`` model producing
        (batch, image_size, image_size, channels) tensors in [-1, 1]
        (tanh activation on the last layer).

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """
    if image_size <= 0 or image_size % 4 != 0:
        raise ValueError(
            f"image_size must be a positive multiple of 4, got {image_size}"
        )
    # Spatial size before the two 2x upsampling stages. Computed explicitly:
    # `image_size // 4 * image_size // 4` evaluates left to right and is
    # only correct when image_size is divisible by 4.
    base = image_size // 4

    model = models.Sequential()
    # Declare the input with an Input layer; passing `input_shape=` to the
    # first Dense layer triggers a deprecation UserWarning in recent Keras.
    model.add(layers.Input(shape=(latent_dim,)))
    model.add(layers.Dense(base * base * 256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((base, base, 256)))

    # Stride 1: refine features without changing the spatial size.
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Stride 2: base -> 2 * base.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Stride 2: 2 * base -> image_size; tanh keeps pixels in [-1, 1].
    model.add(layers.Conv2DTranspose(channels, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model

def build_discriminator(image_size, channels):
    """Build the convolutional discriminator that scores images as real or fake.

    Two stride-2 convolutions, each followed by LeakyReLU and 30% dropout,
    downsample the image. A single Dense unit then outputs an unbounded
    logit; there is no sigmoid here, so it must be paired with a loss built
    with ``from_logits=True``.

    Args:
        image_size: Height/width of the square input image.
        channels: Number of input channels (1 for grayscale).

    Returns:
        An uncompiled ``Sequential`` model mapping
        (batch, image_size, image_size, channels) inputs to (batch, 1) logits.
    """
    model = models.Sequential()
    # Declare the input with an Input layer; passing `input_shape=` to the
    # first Conv2D layer triggers a deprecation UserWarning in recent Keras.
    model.add(layers.Input(shape=(image_size, image_size, channels)))
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))  # raw logit
    return model

# Shared loss for both networks. The discriminator ends in a plain Dense(1)
# (no sigmoid), so the loss must apply the sigmoid itself: from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam optimizer per network, so each keeps its own optimizer state;
# both use a learning rate of 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator's summed binary cross-entropy loss.

    Real images are given target 1 and generated images target 0.

    Args:
        real_output: Discriminator logits for a batch of real images.
        fake_output: Discriminator logits for a batch of generated images.

    Returns:
        Scalar loss tensor: loss on the real batch plus loss on the fake batch.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    loss_on_real = cross_entropy(real_targets, real_output)
    loss_on_fake = cross_entropy(fake_targets, fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator's loss for a batch of discriminator logits.

    The generator wants its images classified as real, so every target is 1
    (the non-saturating GAN objective).

    Args:
        fake_output: Discriminator logits for generated images.

    Returns:
        Scalar binary cross-entropy loss tensor.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

@tf.function
def train_step(images, generator, discriminator, latent_dim):
    """Run one simultaneous generator/discriminator update on a real batch.

    Args:
        images: Batch of real images, expected in the same [-1, 1] range as
            the generator's tanh output.
        generator: Generator model being trained.
        discriminator: Discriminator model being trained.
        latent_dim: Length of each noise vector.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors for this step.
    """
    # Draw exactly one noise vector per real image. A fixed global batch size
    # would give the last (smaller) batch of an epoch a mismatched number of
    # fakes whenever the dataset size is not a multiple of the batch size.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, latent_dim])

    # Two tapes so each network's loss is differentiated only w.r.t. its own
    # weights.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

def generate_image(latent_vector, *more_latent_values):
    """Generate one grayscale image from a latent vector using ``generator``.

    Two calling conventions are accepted:

    * a single sequence of ``LATENT_DIM`` numbers, or
    * ``LATENT_DIM`` separate scalar arguments. Gradio uses this form when
      the interface has one slider per latent dimension, because it passes
      one positional argument per input component.

    Args:
        latent_vector: A sequence of LATENT_DIM values, or the first latent
            value when the remaining ones are passed positionally.
        *more_latent_values: The remaining latent values in the
            one-scalar-per-dimension form.

    Returns:
        A 2-D ``uint8`` array (height x width) with values in [0, 255].

    Raises:
        ValueError: If the number of latent values is not LATENT_DIM.
    """
    if more_latent_values:
        values = (latent_vector, *more_latent_values)
    else:
        values = latent_vector
    # float32 matches the generator's weights and avoids a float64 input.
    z = np.asarray(values, dtype=np.float32).reshape(1, -1)
    if z.shape[1] != LATENT_DIM:
        raise ValueError(
            f"expected {LATENT_DIM} latent values, got {z.shape[1]}"
        )

    generated = generator(z, training=False)
    # Map the tanh range [-1, 1] to [0, 255]. Clip so float rounding just
    # outside that range cannot wrap around in the uint8 cast.
    pixels = generated[0, :, :, 0].numpy() * 127.5 + 127.5
    return np.clip(pixels, 0, 255).astype(np.uint8)

# Training parameters
EPOCHS = 50
BATCH_SIZE = 256
BUFFER_SIZE = 60000  # shuffle buffer size; equals the MNIST training-set size

# Build both networks. The generator is replaced below if a saved copy exists.
generator = build_generator(LATENT_DIM, IMAGE_SIZE, CHANNELS)
discriminator = build_discriminator(IMAGE_SIZE, CHANNELS)

if os.path.exists(MODEL_SAVE_PATH):
    print(f"Loading pre-trained generator from {MODEL_SAVE_PATH}")
    # The generator is never compiled (training uses a custom loop with
    # standalone optimizers), so there is no training config to restore.
    generator = tf.keras.models.load_model(MODEL_SAVE_PATH, compile=False)
else:
    print("Training the GAN...")
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # Add the channel axis, then rescale pixels from [0, 255] to [-1, 1] so
    # real images share the generator's tanh output range.
    train_images = train_images.reshape(
        train_images.shape[0], IMAGE_SIZE, IMAGE_SIZE, CHANNELS
    ).astype('float32')
    train_images = (train_images - 127.5) / 127.5
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(train_images)
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
        # Prepare the next batch while the current training step runs.
        .prefetch(tf.data.AUTOTUNE)
    )

    for epoch in range(EPOCHS):
        for image_batch in train_dataset:
            train_step(image_batch, generator, discriminator, LATENT_DIM)
        print(f"Epoch {epoch+1} complete")

    generator.save(MODEL_SAVE_PATH)
    print(f"Generator model saved to {MODEL_SAVE_PATH}")

# Gradio Interface
def _generate_from_sliders(*slider_values):
    """Bundle one value per latent slider and render the generated image.

    Gradio calls ``fn`` with one positional argument per input component, so
    the LATENT_DIM slider values arrive separately. Pass them on to
    ``generate_image`` as a single sequence.
    """
    return generate_image(slider_values)


iface = gr.Interface(
    fn=_generate_from_sliders,
    inputs=[
        # Start every slider at 0, the mean of the N(0, 1) training noise,
        # rather than relying on the component's default value.
        gr.Slider(minimum=-5, maximum=5, value=0, step=0.1, label=f"Latent Dimension {i+1}")
        for i in range(LATENT_DIM)
    ],
    # `shape=` was removed from gr.Image in Gradio 4, and the installed 5.x
    # release rejects it. The displayed image takes the returned array's size.
    outputs=gr.Image(label="Generated Image"),
    title="Simple GAN Image Generator (MNIST)",
    description="Generate handwritten digits using a trained simple GAN. Adjust the latent space sliders to explore the generated output."
)

iface.launch(share=True)

Collecting gradio
  Downloading gradio-5.26.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.9.0 (from gradio)
  Downloading gradio_client-1.9.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Training the GAN...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
