# In [7]:
"""
Simple Vision Transformer (ViT) image classifier: a Hugging Face Transformers
model served through a Gradio app, with sample input images.

Requirements:
    pip install torch torchvision transformers gradio pillow
"""

import gradio as gr
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub checkpoint shared by the image processor (resize/normalise settings)
# and the classification model.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Place the weights on the chosen device and switch to evaluation mode,
# since this script only runs inference.
model.to(device)
model.eval()

def classify_image(image, top_k=5):
    """Return the ViT model's most likely classes for *image*.

    Args:
        image: The image to classify. The Gradio input is declared with
            ``type="pil"``, so this is normally a ``PIL.Image.Image``; it is
            ``None`` when Submit is clicked without uploading anything.
        top_k: Number of highest-probability classes to return. Clamped to
            the number of classes the model has. Defaults to 5, matching
            ``gr.Label(num_top_classes=5)`` in the interface.

    Returns:
        A ``{label: probability}`` dict ordered from most to least likely
        (the format ``gr.Label`` consumes), or ``None`` when no image was
        given (presumably leaves the Label output empty instead of erroring).
    """
    # Nothing uploaded: the processor cannot preprocess None, so stop early.
    if image is None:
        return None

    # The model presumably expects 3-channel RGB input; convert RGBA,
    # greyscale or palette PIL images so the channel count matches.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        # A single image was processed, so take row 0 of the batch.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k=k)

    id2label = model.config.id2label
    results = {}
    # .tolist() transfers values to Python once instead of one .item() each.
    for prob, class_id in zip(top_probs.tolist(), top_ids.tolist()):
        # Label strings are not guaranteed unique across class ids. topk
        # yields descending probabilities, so setdefault keeps the highest
        # one rather than letting a lower-ranked duplicate overwrite it.
        results.setdefault(id2label[class_id], prob)
    return results

# One-input / one-output Gradio UI around classify_image: a PIL image goes in
# and a ranked label/probability display comes out. The example URLs let users
# try the model without uploading their own picture.
demo = gr.Interface(
    classify_image,
    gr.Image(type="pil", label="Upload an Image"),
    gr.Label(num_top_classes=5, label="Predictions"),
    examples=[
        ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/dog.jpg"],
        ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat.jpg"],
        ["https://huggingface.co/datasets/mishig/sample_images/resolve/main/plane.jpeg"],
    ],
    description="Upload an image to classify it using Hugging Face's Vision Transformer.",
    title="Simple Vision Transformer (ViT) Classifier",
)

if __name__ == "__main__":
    # Start the local Gradio server for the interface defined above.
    demo.launch()


# Output:
# Running on local URL:  http://127.0.0.1:7864
#
# To create a public link, set `share=True` in `launch()`.
