In [1]:
!pip install gradio
import gradio as gr




In [2]:
import tensorflow as tf
from PIL import Image
import numpy as np

# Restore the trained colorization generator from its native `.keras` archive.
GENERATOR_PATH = 'colorization_generator_model.keras'
generator = tf.keras.models.load_model(GENERATOR_PATH)

# Function to preprocess the uploaded grayscale image for the model
def preprocess_image(image, target_size=(256, 256)):
    """Turn an uploaded image into a normalized single-image batch for the generator.

    Args:
        image: PIL image from the Gradio input, or ``None``.
        target_size: ``(width, height)`` passed to PIL ``resize``; defaults to the
            256x256 resolution this notebook feeds the generator.

    Returns:
        Tuple ``(batch, original_size)``:
        - ``batch`` is a float32 array of shape ``(1, height, width, 1)`` with
          values in [-1, 1].
        - ``original_size`` is the input's ``(width, height)``, kept so the
          output can be resized back.

    Raises:
        ValueError: if ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("Invalid input: No image was uploaded.")

    original_size = image.size  # PIL reports (width, height)
    gray = image.convert("L").resize(target_size)  # single-channel 8-bit grayscale
    # Build float32 directly: uint8 / 127.5 on a default np.array yields float64,
    # doubling memory and forcing a cast to the model's (typically float32) dtype.
    pixels = np.asarray(gray, dtype=np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    return pixels[np.newaxis, :, :, np.newaxis], original_size

# Function to post-process the generated image for display and resize to original dimensions
def postprocess_image(image, original_size):
    """Convert a generator output array into a displayable PIL image.

    Args:
        image: NumPy array of generator output with values nominally in [-1, 1]
            (presumably (H, W, 3) RGB — TODO confirm against the generator).
        original_size: ``(width, height)`` tuple, as reported by PIL ``Image.size``,
            that the result is resized back to.

    Returns:
        PIL image of size ``original_size``.
    """
    # [-1, 1] -> [0, 1], clamping any values the generator pushed out of range.
    unit = np.clip((image + 1) / 2, 0, 1)
    # Round rather than truncate: a bare astype(uint8) floors, so e.g.
    # 0.999 * 255 would become 254 and every pixel would be biased darker.
    pixels = np.rint(unit * 255).astype(np.uint8)
    return Image.fromarray(pixels).resize(original_size)

# Gradio function to take grayscale input, process, and return the colorized image
def colorize_image(image):
    """Colorize an uploaded grayscale image with the loaded GAN generator.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        The colorized PIL image, resized back to the input's original size.

    Raises:
        gr.Error: if no image was uploaded. Gradio shows this message in the UI.
            Returning a plain string would not work: the string would be handed
            to the image output component, which treats it as a file path
            rather than a message.
    """
    if image is None:
        raise gr.Error("Please upload a valid image.")

    bw_batch, original_size = preprocess_image(image)
    # training=False runs the generator's layers in inference mode; the model
    # receives a batch of one, so take element 0 of its output.
    generated = generator(bw_batch, training=False)[0]
    # Tensor -> NumPy, then rescale to uint8 and restore the original size.
    return postprocess_image(generated.numpy(), original_size)

# Assemble the Gradio UI. The input component itself accepts any image; the
# colorize pipeline converts uploads to single-channel grayscale before inference.
grayscale_input = gr.Image(type="pil", label="Upload a Grayscale Image (B/W)")

interface = gr.Interface(
    fn=colorize_image,
    inputs=grayscale_input,
    outputs="image",
    title="Grayscale to Color Image GAN",
    description="Upload a grayscale image, and the model will colorize it.",
)

# Start the local web server (launch(share=True) would also create a public link).
interface.launch()

Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.


