In [16]:
import gradio as gr
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import torch

# Checkpoint identifier on the Hugging Face Hub; both the preprocessing
# configuration and the classifier weights are resolved from this one name.
model_name = "google/vit-base-patch16-224"

# Preprocessor that turns images into the tensors the model consumes.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor - consider migrating when the pinned version allows.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

# Pretrained Vision Transformer with its image-classification head.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

# Classification callback used by the Gradio interface
def classify_image(image, top_k=5):
    """Classify an image with the ViT model and report its most likely labels.

    Args:
        image: The picture to classify as a PIL image (what
            ``gr.Image(type="pil")`` delivers), or ``None`` when the user
            submitted the form without providing an image.
        top_k: Maximum number of classes to report. Clamped to the range
            ``[1, number of classes]``.

    Returns:
        A dict mapping class label to softmax probability, ordered from most
        to least likely. ``gr.Label`` renders such a mapping as a ranked list
        of confidences.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of an opaque preprocessing failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The checkpoint expects 3-channel input; normalise grayscale / RGBA /
    # palette images so direct callers do not hit a channel mismatch.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Preprocess the image into model-ready tensors
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip gradient bookkeeping
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was passed, so the first row holds its class scores.
    probabilities = logits.softmax(dim=-1)[0]
    k = max(1, min(int(top_k), probabilities.numel()))
    scores, indices = probabilities.topk(k)

    id2label = model.config.id2label
    return {
        id2label[class_idx]: score
        for class_idx, score in zip(indices.tolist(), scores.tolist())
    }

# Input widget: hands the uploaded picture to the callback as a PIL image.
image_input = gr.Image(type="pil")

# Output widget: a label display capped at the five highest-ranked classes.
label_output = gr.Label(num_top_classes=5)

# Wire the classifier callback to the widgets.
iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title="Vision Transformer Image Classification",
    description="Upload an image to Classify",
)

# Start the local web server; pass share=True to launch() for a public link.
iface.launch()




Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


