# Using Vision Transformer (ViT) Foundation Model with CIFAR-10

This notebook demonstrates how to use a pre-trained Vision Transformer (ViT) foundation model for image classification on the CIFAR-10 benchmark dataset.

## What is a Foundation Model?

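Foundation models are large-scale models pre-trained on broad data that can be adapted to a wide range of downstream tasks. The Vision Transformer (ViT), introduced by Google Research in 2020, splits an image into fixed-size patches and processes them with a standard Transformer encoder; it has become a popular foundation model for computer vision.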

## CIFAR-10 Dataset

CIFAR-10 is a well-known benchmark dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. It is split into 50,000 training images and 10,000 test images; this notebook uses only the test split.

In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTImageProcessor
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## Load CIFAR-10 Dataset

We'll load the CIFAR-10 dataset using torchvision. The dataset will be automatically downloaded if not present.

In [None]:
# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Download and load test dataset
test_dataset = datasets.CIFAR10(root='./data', train=False, 
                                download=True, transform=transform)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f'Test dataset size: {len(test_dataset)}')

## Visualize Sample Images

Let's visualize a few sample images from the dataset.

In [None]:
# Function to denormalize images for visualization
def denormalize(img):
    img = img * 0.5 + 0.5  # Denormalize from [-1, 1] to [0, 1]
    return img

# Display sample images
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    img, label = test_dataset[i]
    img = denormalize(img)
    ax.imshow(img.permute(1, 2, 0))
    ax.set_title(class_names[label])
    ax.axis('off')
plt.tight_layout()
plt.show()
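
The denormalize helper works because it inverts the Normalize step in the transform: with mean 0.5 and std 0.5, a pixel value x in [0, 1] is mapped to (x - 0.5) / 0.5 in [-1, 1], and multiplying by 0.5 and adding 0.5 recovers x. A minimal pure-Python sketch of that round trip on scalar values (the names normalize_value and denormalize_value are illustrative and not part of the notebook):

```python
# Scalar version of the Normalize / denormalize round trip used above.
# normalize_value and denormalize_value are illustrative names only.
MEAN = 0.5
STD = 0.5

def normalize_value(x, mean=MEAN, std=STD):
    """Map a pixel value from [0, 1] to [-1, 1], as Normalize does per channel."""
    return (x - mean) / std

def denormalize_value(y, mean=MEAN, std=STD):
    """Invert normalize_value, mapping [-1, 1] back to [0, 1]."""
    return y * std + mean

pixels = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize_value(p) for p in pixels]
restored = [denormalize_value(n) for n in normalized]
print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```

Skipping the denormalization step before plotting would hand imshow values below 0, which it clips, so the images would look too dark.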

## Load Pre-trained Vision Transformer Model

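We'll use a pre-trained ViT model from Hugging Face's transformers library. The google/vit-base-patch16-224 checkpoint was pre-trained on ImageNet-21k and then fine-tuned on ImageNet-1k at 224x224 resolution, so its classification head outputs the 1,000 ImageNet-1k classes. It can be used as-is to predict those classes or fine-tuned for specific tasks.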

In [None]:
# Load pre-trained ViT model and processor
model_name = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
model = model.to(device)
model.eval()

print(f'Model loaded: {model_name}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
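
The model name encodes the input geometry: patch16 means the image is cut into 16x16-pixel patches, and 224 is the expected input resolution. The sketch below works out the resulting sequence length, assuming the standard ViT layout of one extra classification ([CLS]) token prepended to the patch tokens. The helper name vit_sequence_length is illustrative.

```python
def vit_sequence_length(image_size=224, patch_size=16, num_special_tokens=1):
    """Return (patches_per_side, num_patches, sequence_length) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError('image_size must be divisible by patch_size')
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return patches_per_side, num_patches, num_patches + num_special_tokens

per_side, num_patches, seq_len = vit_sequence_length()
print(per_side, num_patches, seq_len)  # 14 196 197

# Each RGB patch is flattened to patch_size * patch_size * 3 values before projection
values_per_patch = 16 * 16 * 3
print(values_per_patch)  # 768
```

At CIFAR-10's native 32x32 resolution the same patch size would give only 4 patches; the notebook instead upsamples to the 224x224 resolution the checkpoint was trained with.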

## Make Predictions

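Let's use the foundation model to make predictions on a few test images. The model outputs one of the 1,000 ImageNet labels rather than one of the 10 CIFAR-10 classes, so a prediction can be semantically right (for example, a specific dog breed for a CIFAR-10 dog) without matching the CIFAR-10 class name.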

In [None]:
# Function to predict on a single image
def predict_image(img_tensor, model, processor):
    # Undo the dataset normalization and convert to a PIL image; the processor
    # then applies its own resizing and normalization
    img = denormalize(img_tensor)
    img_pil = transforms.ToPILImage()(img)
    
    # Process image and make prediction
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class_idx = logits.argmax(-1).item()
    
    return predicted_class_idx, logits

# Make predictions on sample images
fig, axes = plt.subplots(2, 5, figsize=(15, 8))
for i, ax in enumerate(axes.flat):
    img, label = test_dataset[i]
    pred_idx, logits = predict_image(img, model, processor)
    
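    # Get predicted class name from ImageNet labels
    predicted_label = model.config.id2label[pred_idx]
    # Truncate only long labels so short ones are not shown with a trailing '...'
    short_label = predicted_label if len(predicted_label) <= 20 else predicted_label[:20] + '...'
    
    img_display = denormalize(img)
    ax.imshow(img_display.permute(1, 2, 0))
    ax.set_title(f'True: {class_names[label]}\nPred: {short_label}', 
                 fontsize=8)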
    ax.axis('off')
plt.tight_layout()
plt.show()
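
predict_image returns the raw logits alongside the top class index. To see how confident the model is, the logits can be converted to probabilities with a softmax and the top-k entries inspected. The sketch below does this in plain Python on a made-up five-element logit vector with placeholder labels (both are illustrative, not model output). On the real tensor, logits.softmax(-1).topk(k) would be the PyTorch equivalent.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    max_logit = max(logits)
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Made-up logits and labels for illustration only
example_logits = [2.0, 0.5, -1.0, 3.0, 0.0]
example_labels = ['label_a', 'label_b', 'label_c', 'label_d', 'label_e']

probs = softmax(example_logits)
for name, p in top_k(probs, example_labels, k=3):
    print(f'{name}: {p:.3f}')
```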

## Evaluate Model Performance

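Because the model predicts ImageNet labels rather than CIFAR-10 classes, accuracy can't be computed directly. Instead, let's run the model on a subset of the test set and tally which ImageNet labels it predicts for each CIFAR-10 class.

In [None]:
from collections import Counter, defaultdict

# Run on a subset (100 images) for demonstration
num_samples = 100
# For each CIFAR-10 class, count the ImageNet labels the model predicts
predictions_by_class = defaultdict(Counter)

print(f'Evaluating on {num_samples} images...')
for i in tqdm(range(num_samples)):
    img, label = test_dataset[i]
    pred_idx, _ = predict_image(img, model, processor)
    predictions_by_class[class_names[label]][model.config.id2label[pred_idx]] += 1

# Show the most frequent ImageNet predictions for each CIFAR-10 class
for cifar_class, counts in predictions_by_class.items():
    print(f'{cifar_class}: {counts.most_common(3)}')

print('\nNote: Direct evaluation is challenging because the model was trained on ImageNet,')
print('which has different classes than CIFAR-10. For accurate results, the model should be fine-tuned.')
print('See the fine-tuning notebook for an example of adapting this model to CIFAR-10.')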
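
One rough way to get an approximate accuracy without fine-tuning is to map each predicted ImageNet label string onto a CIFAR-10 class by keyword. The sketch below is an illustration under loud assumptions: the keyword table is hypothetical and far from complete, and the sample predictions are hand-written strings in the style of ImageNet labels rather than actual model output.

```python
import re

# Hypothetical, incomplete keyword table: CIFAR-10 class -> whole words to look for
CIFAR_KEYWORDS = {
    'airplane': {'airliner', 'warplane'},
    'automobile': {'car', 'cab', 'limousine'},
    'cat': {'cat'},
    'frog': {'frog', 'bullfrog'},
    'ship': {'ship', 'liner'},
    'truck': {'truck'},
}

def map_to_cifar(imagenet_label):
    """Return the first CIFAR-10 class whose keywords appear as whole words, else None."""
    words = set(re.findall(r'[a-z]+', imagenet_label.lower()))
    for cifar_class, keywords in CIFAR_KEYWORDS.items():
        if words & keywords:
            return cifar_class
    return None

# Hand-written (true CIFAR-10 class, predicted ImageNet-style label) pairs
sample_predictions = [
    ('airplane', 'airliner'),
    ('cat', 'tabby, tabby cat'),
    ('ship', 'container ship, containership, container vessel'),
    ('truck', 'garbage truck, dustcart'),
    ('dog', 'tabby, tabby cat'),
    ('frog', 'tree frog, tree-frog'),
    ('horse', 'sorrel'),
]

matched = sum(map_to_cifar(pred) == true for true, pred in sample_predictions)
unmapped = sum(map_to_cifar(pred) is None for _, pred in sample_predictions)
print(f'Approximate accuracy: {matched}/{len(sample_predictions)}')
print(f'Predictions with no CIFAR-10 mapping: {unmapped}')
```

Unmapped predictions count as errors here, so the estimate is pessimistic and depends heavily on the keyword table. Fine-tuning remains the reliable route to a meaningful CIFAR-10 accuracy.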

## Key Takeaways

1. **Foundation Models**: Vision Transformers are powerful foundation models pre-trained on large datasets (ImageNet-21k)
2. **Out-of-the-Box Predictions**: The model can make predictions without additional training, but it predicts ImageNet labels, so it works best on classes similar to its training data
3. **Transfer Learning Ready**: The model can be fine-tuned for specific tasks like CIFAR-10 classification
4. **Benchmark Datasets**: CIFAR-10 is a standard benchmark for evaluating image classification models

## Next Steps

- See the fine-tuning notebook to learn how to adapt this model specifically for CIFAR-10
- Explore other foundation models like CLIP, DINO, or Swin Transformers
- Try different benchmark datasets like CIFAR-100, ImageNet, or domain-specific datasets