In [3]:
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image

# CIFAR-10 label names. The prediction code looks up the model's argmax
# output index in this list, so position i must correspond to output unit i.
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Load the pre-trained CIFAR-10 model.
# The path is relative, so 'cifar10Model.keras' must be in the current
# working directory (for a notebook, usually the notebook's folder).
try:
    model = tf.keras.models.load_model("cifar10Model.keras")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: report the problem
    # and bind `model` to None, so later code can check for a missing model
    # instead of failing with a NameError on an unbound name.
    model = None
    print(f"Error loading model: {e}")
    print("Please ensure 'cifar10Model.keras' is in the correct directory.")
else:
    print("Model loaded successfully!")

def predict_cifar10_image(image: np.ndarray) -> str:
    """
    Predict the CIFAR-10 class of an image using the loaded model.

    The image is converted to 3-channel RGB, resized to 32x32, scaled to
    [0, 1] and passed to the model as a batch of one.

    Args:
        image: A NumPy array representing the input image (e.g., from Gradio's
            gr.Image with type="numpy"). Grayscale (H, W), single-channel
            (H, W, 1) and RGBA (H, W, 4) arrays are converted to RGB.
            May be None when no image has been provided.

    Returns:
        A string indicating the predicted class, or a short message when no
        image was given or the model is not available.
    """
    if image is None:
        return "No image uploaded."

    # The module-level load step can fail, leaving `model` unbound (or None).
    # Report that instead of raising NameError/AttributeError inside the UI.
    loaded_model = globals().get("model")
    if loaded_model is None:
        return "Model is not loaded; cannot make a prediction."

    # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]

    # Go through PIL to normalise the channel layout and resize. convert("RGB")
    # turns grayscale/RGBA/palette input into the 3 channels the
    # (1, 32, 32, 3) model input needs.
    img_pil = Image.fromarray(image.astype("uint8")).convert("RGB")
    # CIFAR-10 images are 32x32.
    img_resized = img_pil.resize((32, 32))
    # Scale pixel values to [0, 1] and add the batch dimension -> (1, 32, 32, 3).
    img_array = np.asarray(img_resized, dtype="float32") / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # verbose=0 suppresses the Keras progress bar that predict() would
    # otherwise print on every call.
    predictions = loaded_model.predict(img_array, verbose=0)
    predicted_class_index = int(np.argmax(predictions, axis=1)[0])
    predicted_class_name = class_names[predicted_class_index]

    return f"Predicted Class: {predicted_class_name}"

# Build the Gradio UI: one image input, one text output.
iface_cifar10_prediction = gr.Interface(
    fn=predict_cifar10_image,
    inputs=gr.Image(type="numpy", label="Upload a CIFAR-10 Image"),
    outputs=gr.Textbox(label="Prediction Result"),
    title="CIFAR-10 Image Classification",
    description="Upload an image to get its predicted class from the CIFAR-10 dataset.",
    # Add example image paths to this list to show clickable samples in the
    # UI, e.g. "path/to/your/image1.png".
    examples=[],
    live=True,  # Re-run the prediction whenever the input image changes
)

# Launch the app on a local URL only; share=False means no public link.
iface_cifar10_prediction.launch(share=False)

Model loaded successfully!
* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


