# Vision Transformer Exploration

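This notebook explores vision transformers interactively. It classifies a sample image with three pretrained architectures (ViT, DeiT and Swin) and compares their top predictions. It then visualizes how ViT splits its input image into 16x16 patches.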

In [None]:
# Setup
import io
import sys

sys.path.append('../src')

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Reaching this line means every import earlier in this cell succeeded.
print("Libraries loaded successfully!")

## 1. Load and Visualize an Image

In [None]:
# Load sample image over the network.
# - The timeout keeps the cell from hanging forever on a stalled connection.
# - raise_for_status() surfaces HTTP errors instead of handing PIL an error
#   page it cannot decode.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_classification_parrots.png"
response = requests.get(url, timeout=30)
response.raise_for_status()

# A PNG may be RGBA or palette-based. Normalize to 3-channel RGB so the
# image processors and the later np.asarray() call always see the same mode.
image = Image.open(io.BytesIO(response.content)).convert("RGB")

# Display
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(image)
ax.axis('off')
ax.set_title("Sample Image for Classification")
plt.show()

## 2. Compare Vision Transformer Architectures

In [None]:
# Define models to compare
models = {
    "ViT": "google/vit-base-patch16-224",
    "DeiT": "facebook/deit-base-patch16-224",
    "Swin": "microsoft/swin-tiny-patch4-window7-224"
}

TOP_K = 3  # number of highest-probability classes kept per model

# Classify with each model; results maps model name -> [(label, probability), ...]
results = {}

for name, model_id in models.items():
    print(f"\nProcessing with {name}...")

    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    model.eval()  # explicit: make sure dropout etc. are in inference behaviour

    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        outputs = model(**inputs)

    # Batch of one image -> take row 0 of the (batch, num_labels) probabilities.
    probs = outputs.logits.softmax(dim=-1)[0]
    top = torch.topk(probs, k=min(TOP_K, probs.shape[-1]))

    results[name] = [
        (model.config.id2label[idx], score)
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    ]
    for rank, (label, score) in enumerate(results[name], start=1):
        print(f"  {rank}. {label}: {score:.2%}")

## 3. Visualize Model Predictions

In [None]:
# Create comparison chart: one horizontal-bar panel per model.
# The panel count comes from `results` instead of assuming exactly three models.
# squeeze=False keeps `axes` 2-D even when only one model is present.
n_models = len(results)
fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 5), squeeze=False)

for ax, (model_name, predictions) in zip(axes[0], results.items()):
    labels = [label for label, _ in predictions]
    scores = [score for _, score in predictions]
    positions = list(range(len(predictions)))

    # Plot against numeric positions rather than the label strings.
    # Class names are not guaranteed unique (ImageNet-1k repeats e.g. "crane"),
    # and categorical string labels would stack such bars on a single row.
    ax.barh(positions, scores)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # barh draws the first entry at the bottom; put rank 1 on top
    ax.set_xlim(0, 1.15)  # headroom so the label of a near-100% bar stays inside the axes
    ax.set_xlabel('Confidence')
    ax.set_title(f'{model_name} Predictions')

    for pos, score in zip(positions, scores):
        ax.text(score + 0.01, pos, f'{score:.1%}', va='center')

plt.tight_layout()
plt.show()

## 4. Understanding Patch-Based Processing

In [None]:
# Visualize how ViT divides its *input* into patches.
# ViT does not patch the original-resolution photo: the checkpoint name
# "vit-base-patch16-224" indicates a 224x224 input split into 16x16 patches.
# So the grid is drawn on a resized copy. Drawing it on the original image
# would misrepresent both the patches and their count.
patch_size = 16   # ViT-Base/16 patch side, in pixels
input_size = 224  # ViT input side, in pixels

vit_input = image.resize((input_size, input_size))
img_array = np.asarray(vit_input)
h, w = img_array.shape[:2]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Original image
ax1.imshow(image)
ax1.set_title("Original Image")
ax1.axis('off')

# Resized model input with patch grid.
# extent=(0, w, h, 0) puts pixel edges on integer coordinates, so the grid
# lines fall exactly on patch borders (the default extent is offset by 0.5).
ax2.imshow(vit_input, extent=(0, w, h, 0))
ax2.set_title(f"{input_size}x{input_size} Model Input Divided into {patch_size}x{patch_size} Patches")

# Draw grid; the +1 also draws the closing border at h and w.
for y in range(0, h + 1, patch_size):
    ax2.axhline(y=y, color='red', linewidth=0.5, alpha=0.5)
for x in range(0, w + 1, patch_size):
    ax2.axvline(x=x, color='red', linewidth=0.5, alpha=0.5)

ax2.axis('off')
plt.tight_layout()
plt.show()

print(f"Total patches: {(h // patch_size) * (w // patch_size)}")