# Vision Transformer (ViT) Image Classification

This notebook demonstrates how to use the Hugging Face Transformers library to classify images using a pre-trained Vision Transformer (ViT) model.

## About Vision Transformer
The Vision Transformer (ViT) is a BERT-like transformer encoder model pretrained on ImageNet-21k (14M images, 21k classes) and fine-tuned on ImageNet (1M images, 1k classes) at a resolution of 224x224 pixels.

For more information, see [the Hugging Face model page](https://huggingface.co/google/vit-base-patch16-224).
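
The `patch16-224` in the model name refers to 16x16-pixel patches and a 224x224 input. The plain-arithmetic sketch below shows how many tokens the encoder processes for one image. It assumes square, non-overlapping patches plus one [CLS] token prepended to the patch sequence, as in the original ViT design. It does not call the Transformers API, and the toy grid is made up for illustration.

```python
def vit_sequence_length(image_size: int, patch_size: int, cls_tokens: int = 1) -> int:
    """Number of tokens a ViT encoder sees for one square image.

    Assumes the image side is divisible by the patch size and that patches
    do not overlap (the standard ViT patch embedding).
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + cls_tokens              # 196 + 1 [CLS] token = 197


def patchify(image, patch_size):
    """Split a 2D grid (list of rows) into non-overlapping square patches.

    Patches are returned in row-major order, each flattened to a list.
    Assumes both grid dimensions are divisible by patch_size.
    """
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches


print(vit_sequence_length(224, 16))  # 197

# Toy 4x4 "image" split into 2x2 patches
toy = [[r * 4 + c for c in range(4)] for r in range(4)]
print(patchify(toy, 2))  # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

In the real model, each flattened patch is linearly projected to an embedding before it enters the encoder. The sketch stops at the splitting step.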

## 1. Set Up Environment

First, let's install the necessary packages. Note that **you need to restart the kernel after this cell completes**.

In [None]:
# Install required libraries if not already installed
# =========================================================================================================
# CRITICAL STEP: AFTER RUNNING THIS CELL, YOU **MUST** RESTART THE KERNEL BEFORE PROCEEDING!
#                 Failure to restart the kernel will result in 'ModuleNotFoundError' for PyTorch.
#                 In VS Code: Click the 'Restart' button for the kernel, or run 'Restart Kernel' from the Command Palette.
# =========================================================================================================
%pip install torch torchvision torchaudio transformers pillow requests --quiet
print("\n*** Installation cell complete. PLEASE RESTART THE KERNEL NOW before running the next cell! ***")


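If you are unsure whether the restarted kernel can see the freshly installed packages, a small standard-library check like the one below lists the modules the current interpreter cannot locate. It is a sketch built on `importlib.util.find_spec`. Finding a module this way does not guarantee that it imports cleanly. The list of import names is our own assumption, based on the packages in the `%pip install` line. Note that import names can differ from pip names (`pillow` is imported as `PIL`).

```python
import importlib
import importlib.util


def missing_packages(module_names):
    """Return the top-level import names the current interpreter cannot find."""
    # Refresh the import finders' caches so recently installed modules can be detected
    importlib.invalidate_caches()
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Import names assumed from the %pip install line (pillow is imported as PIL)
required = ["torch", "torchvision", "torchaudio", "transformers", "PIL", "requests"]
missing = missing_packages(required)
if missing:
    print(f"Not importable in this kernel: {missing}. Restart the kernel and check again.")
else:
    print("All required packages can be found by this kernel.")
```
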
## 2. Verify the Installation and Import Libraries

After restarting the kernel, run the next cell to confirm that PyTorch can be imported and to check whether a CUDA-capable GPU is available.

In [None]:
# Step 2: Verify the PyTorch installation and CUDA availability.
# IMPORTANT: Run this cell only AFTER restarting the kernel following the installation cell.
# Note: installing `transformers` alone does NOT install PyTorch; PyTorch comes from the installation cell above.
import sys
print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
try:
    import torch
    print(f"Successfully imported PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        if torch.cuda.device_count() > 0:
            print(f"Current GPU: {torch.cuda.current_device()}")
            print(f"GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("CUDA not available. PyTorch will run on CPU.")
    print("\n*** PyTorch verification successful. You can proceed with the next cells. ***")
except ImportError as e:
    print("\n" + "!" * 100)
    print(f"ERROR: Failed to import PyTorch: {e}")
    print("PyTorch is either not installed in this kernel's environment, or the kernel was not restarted after installation.")
    print("Re-run the installation cell, RESTART THE KERNEL, and then re-run this verification cell.")
    print("Subsequent cells WILL FAIL until PyTorch is installed and importable.")
    print("!" * 100)
except Exception as e:
    print(f"An unexpected error occurred during PyTorch verification: {e}")

In [None]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from IPython.display import display  # Renders images inline in Jupyter
import requests
import sys
import torch  # Required by ViTForImageClassification; also used for tensor operations

# Print Python and library information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("Transformers, PyTorch, and Pillow imported successfully")

## 3. Load an Image from the Web
We will use an image from the COCO dataset.

In [None]:
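url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # Stop with a clear error on HTTP failures
image = Image.open(response.raw).convert("RGB")  # Ensure 3 RGB channels for the model
display(image)  # Show inline in Jupyter; image.show() opens an external viewer instead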

## 4. Load the ViT Processor and Model
We use the pre-trained `google/vit-base-patch16-224` model.
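
`ViTForImageClassification` adds a classification head on top of the ViT encoder. As we understand the Transformers implementation, this head is a single linear layer applied to the final hidden state of the [CLS] token. For this checkpoint, it maps 768 hidden units to 1,000 ImageNet logits. The pure-Python sketch below shows that computation at toy sizes. The weights and dimensions are invented for illustration, and nothing here loads the real model.

```python
def linear_head(cls_embedding, weights, bias):
    """Compute logits = W @ h + b for a single example.

    cls_embedding: list of length hidden_size (final [CLS] hidden state)
    weights: num_labels rows, each of length hidden_size
    bias: list of length num_labels
    """
    if any(len(row) != len(cls_embedding) for row in weights):
        raise ValueError("each weight row must match the embedding size")
    return [sum(w * h for w, h in zip(row, cls_embedding)) + b
            for row, b in zip(weights, bias)]


# Toy sizes: hidden_size=3, num_labels=2 (the real checkpoint uses 768 -> 1000)
h = [1.0, 0.5, -1.0]
W = [[0.2, 0.4, 0.0],
     [-0.5, 1.0, 0.5]]
b = [0.1, -0.1]
print(linear_head(h, W, b))  # approximately [0.5, -0.6]
```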

In [None]:
# Download (on first run) and load the image processor and the fine-tuned classification model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')


## 5. Preprocess the Image and Run Inference
We preprocess the image with the ViT image processor, then use the model to predict its class.
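
The image processor roughly does three things. It resizes the image to 224x224, rescales pixel values from [0, 255] to [0, 1], and normalizes each channel with a mean and standard deviation. For this checkpoint, we believe both the mean and standard deviation are 0.5 per channel. Treat that as an assumption and check `processor.image_mean`, `processor.image_std`, and `processor.size` in your own environment. The sketch below applies only the rescale-and-normalize step to individual pixels. It does not reproduce the resampling used for resizing.

```python
def normalize_pixel(value, mean=0.5, std=0.5, rescale_factor=1 / 255):
    """Rescale one 8-bit channel value, then normalize: (value * factor - mean) / std.

    The default mean/std of 0.5 are an assumption for google/vit-base-patch16-224;
    read the actual values from processor.image_mean / processor.image_std.
    """
    return (value * rescale_factor - mean) / std


def normalize_rgb(pixel, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Apply per-channel normalization to one (R, G, B) pixel."""
    return tuple(normalize_pixel(v, m, s) for v, m, s in zip(pixel, mean, std))


# Black maps to -1.0 and white to (approximately) 1.0 under these assumed settings
print(normalize_rgb((0, 128, 255)))
```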

In [None]:
# Step 5: Preprocess the image, run inference, and report the top class
# Requires `torch`, `processor`, `model`, and `image` from the previous cells.
try:
    # Select device (CPU or GPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')
    
    # Prepare the image for the model
    inputs = processor(images=image, return_tensors="pt").to(device)
    model = model.to(device)
    
    # Run inference
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    # Get the predicted class
    predicted_class_idx = logits.argmax(-1).item()
    
    # Print the result
    predicted_class = model.config.id2label[predicted_class_idx]
    confidence = logits.softmax(dim=-1)[0][predicted_class_idx].item()
    
    print("✅ Prediction completed successfully!")
    print(f"Predicted class: {predicted_class}")
    print(f"Confidence: {confidence:.2%}")
except Exception as e:
    print(f"Error during prediction: {e}")
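
The cell above takes `argmax` of the raw logits and applies softmax only to report a confidence. This works because softmax is strictly increasing in each logit, so it never changes which class scores highest. The pure-Python sketch below makes that concrete with a numerically stable softmax. Subtracting the maximum logit before exponentiating is a standard trick, used here for illustration. The logit values are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats.

    Subtracting the largest logit keeps math.exp() from overflowing and does
    not change the result, because the shift cancels in the ratio.
    """
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one if there are ties)."""
    return max(range(len(values)), key=values.__getitem__)


example_logits = [2.0, 5.5, -1.0, 5.0]  # Hypothetical logits for four classes
example_probs = softmax(example_logits)
predicted = argmax(example_logits)

# Softmax preserves ordering, so the winning index is the same either way
assert predicted == argmax(example_probs)
print(f"Predicted index: {predicted}, confidence: {example_probs[predicted]:.2%}")

# Without the max-subtraction, math.exp(1000.0) would raise OverflowError
print(softmax([1000.0, 1000.0]))  # [0.5, 0.5]
```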

## 6. Display Top Predictions

Let's show the top 5 predictions with their confidence scores.

In [None]:
# Step 6: Display top predictions
try:
    # Get probabilities with softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    
    # Get the top 5 predictions
    top5_prob, top5_indices = torch.topk(probs, 5)
    
    # Print the results
    print("Top 5 Predictions:")
    print("-" * 50)
    print(f"{'Class':<30} | {'Confidence':>10}")
    print("-" * 50)
    
    for prob, idx in zip(top5_prob, top5_indices):
        class_name = model.config.id2label[idx.item()]
        print(f"{class_name:<30} | {prob.item():>10.2%}")
except Exception as e:
    print(f"Error displaying top predictions: {e}")
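
`torch.topk` returns the k largest values together with their indices. You can sketch the same selection in pure Python with `heapq.nlargest`, which is handy for checking the formatting logic without loading the model. The probabilities and the small label dictionary below are hypothetical stand-ins for `probs` and `model.config.id2label`. They use different names so they do not overwrite the notebook's variables.

```python
import heapq


def top_k(probs, id2label, k=5):
    """Return up to k (label, probability) pairs, most probable first."""
    k = min(k, len(probs))
    best = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    return [(id2label[i], probs[i]) for i in best]


# Hypothetical 4-class distribution and label map (not real model output)
example_probs = [0.05, 0.60, 0.10, 0.25]
example_id2label = {0: "class_a", 1: "class_b", 2: "class_c", 3: "class_d"}

print(f"{'Class':<30} | {'Confidence':>10}")
print("-" * 50)
for label, prob in top_k(example_probs, example_id2label, k=3):
    print(f"{label:<30} | {prob:>10.2%}")
```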

## Conclusion

In this notebook, we successfully:
1. Set up the environment
2. Loaded an image from the web
3. Used a pre-trained Vision Transformer (ViT) model from Hugging Face
4. Processed the image and obtained classification results
5. Displayed the top predictions

The model was pretrained on ImageNet-21k and fine-tuned on ImageNet-1k, so its classification head predicts one of 1,000 ImageNet classes of common objects and animals.