**Image classification project using Vision Transformer (ViT) in Hugging Face 🚀**

*ARYAM ASEIRI*

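This project builds an image classifier on top of Hugging Face's pre-trained Vision Transformer (ViT). The checkpoint used here predicts the 1,000 ImageNet classes out of the box; with fine-tuning on a suitable dataset, the same approach can be applied to tasks such as: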

- Identifying types of fruits and vegetables 🍎🥦
- Classifying types of clothing 👕👗
- Identifying animal breeds 🐶🐱
- Detecting diseases in X-ray images 🏥 (medical imaging)

In [None]:
# Install the required libraries
!pip install transformers datasets torch torchvision gradio

In [24]:
# Import libraries
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import requests

In [25]:
# Load the pre-trained ViT model and its image processor
# https://huggingface.co/google/vit-base-patch16-224
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
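
The checkpoint name `google/vit-base-patch16-224` describes the input geometry: 224x224 pixel images cut into 16x16 pixel patches. The sketch below works through that arithmetic with the standard library only. The helper name `vit_sequence_length` is hypothetical (it is not part of `transformers`), and the one extra token assumes the standard ViT design with a single classification ([CLS]) token.

```python
def vit_sequence_length(image_size=224, patch_size=16, extra_tokens=1):
    """Number of tokens the transformer encoder sees for one square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    # Patches along one side of the image, e.g. 224 // 16 = 14
    patches_per_side = image_size // patch_size
    # All patches plus the extra classification token (assumed: one [CLS] token)
    return patches_per_side ** 2 + extra_tokens

print(vit_sequence_length())  # 14 * 14 patches + 1 [CLS] token = 197
```

This is why every input image is resized to 224x224 before it reaches the model: the patch grid, and therefore the sequence length, is fixed by the checkpoint.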



In [26]:
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import requests

# Load the model and its image processor
# (the variable keeps the name feature_extractor because later cells use it)
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)

# Download a sample image to classify
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # Stop early if the download failed
image = Image.open(response.raw)
image = image.convert("RGB")  # Ensure the image has three RGB channels

# Preprocess the image into model-ready tensors
inputs = feature_extractor(images=image, return_tensors="pt")

# Run the model
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(-1).item()

# Read the ImageNet class labels from the loaded model's config
# (the same id2label mapping stored in the checkpoint's config.json)
# Keys are normalized to strings so the lookups below stay consistent
labels = {str(k): v for k, v in model.config.id2label.items()}

# Get the predicted class name
predicted_label = labels.get(str(predicted_class), "Unknown Class")
print(f"The image is classified as: {predicted_label} (Class No. {predicted_class})")

The image is classified as: tabby, tabby cat (Class No. 281)
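
The cell above reports only the single highest-scoring class. To see how confident the model is, the logits can be turned into probabilities with a softmax and the top few classes listed. The following is a minimal sketch using only the standard library; `top_k_predictions` is a hypothetical helper (not part of `transformers`), and the toy logits and labels are made up for illustration, not real model output.

```python
import math

def top_k_predictions(logits, id2label, k=5):
    """Return the k most likely (label, probability) pairs for one image."""
    # Subtract the largest logit before exponentiating for numerical stability
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first, and keep the top k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label.get(str(i), "Unknown Class"), probs[i]) for i in ranked]

# Toy example with three made-up classes (not real model output)
toy_logits = [2.0, 0.5, -1.0]
toy_labels = {"0": "cat", "1": "dog", "2": "bird"}
for label, prob in top_k_predictions(toy_logits, toy_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

With the notebook's variables, the call would look like `top_k_predictions(logits[0].tolist(), labels)`, assuming `logits` has shape `(1, num_classes)` and `labels` uses string keys as in the cell above.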


In [33]:
# Create a simple user interface with Gradio
import gradio as gr

def classify_image(img):
    # Gradio passes None when the user submits without uploading an image
    if img is None:
        return "Please upload an image first."
    # Preprocess the image, then run inference without tracking gradients
    inputs = feature_extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = logits.argmax(-1).item()
    class_name = labels.get(str(predicted_class), "Unknown Class")

    return f"🔍 Predicted Class: {class_name} (Class No. {predicted_class})"

gr.Interface(fn=classify_image, inputs="image", outputs="text").launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://ed24316115f694ba5e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


