# Module 3, Task 8: Vision Transformers in PyTorch

**Objective:** Implement and evaluate a Vision Transformer (ViT) model using PyTorch and the `torchvision` library.

In [None]:
# Install necessary libraries into the *kernel's* environment (%pip, not !pip).
# torchvision>=0.13 is required for the `weights=` API used to load the pre-trained ViT.
%pip install -q torch "torchvision>=0.13" matplotlib scikit-learn tqdm

### Introduction

In this notebook, we'll leverage the `torchvision.models` library, which provides a pre-built, high-quality implementation of the Vision Transformer together with pre-trained weights. Starting from a pre-built, pre-trained model lets us focus on the training and evaluation process, and it usually gives better results than implementing and training from scratch, especially without extensive hyperparameter tuning. We will adapt a ViT pre-trained on ImageNet to our task by training a new classification head on top of its frozen features.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Configuration ---
SEED = 42        # makes the train/val split and torch's global RNG reproducible
IMG_SIZE = 224   # ViT models are often pre-trained on 224x224 images
BATCH_SIZE = 32  # Use a smaller batch size for large models

# Seed torch's global RNG, which later drives the new head's initialisation and
# the DataLoader shuffle order.
torch.manual_seed(SEED)

# --- Data Loading and Preprocessing ---
# Use transforms expected by pre-trained models (ImageNet channel mean/std).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

full_dataset = torchvision.datasets.EuroSAT(root='./data', download=True, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# A dedicated, seeded generator yields the same split on every Restart & Run All.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Pinned host memory only speeds up host->GPU copies; on CPU it just emits a warning.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=pin_memory)

NUM_CLASSES = len(full_dataset.classes)
CLASS_NAMES = full_dataset.classes
print(f"Data pipelines ready. Using {NUM_CLASSES} classes.")

### Loading a Pre-trained ViT Model

We'll load a `vit_b_16` model pre-trained on ImageNet. We then need to replace the final classification head with a new one tailored to our number of classes (10 for EuroSAT).

In [None]:
# Load a ViT-B/16 model with ImageNet-1k pre-trained weights.
vit_model = models.vision_transformer.vit_b_16(weights='IMAGENET1K_V1')

# Freeze the whole pre-trained network; only the head created below will learn.
vit_model.requires_grad_(False)

# Replace the ImageNet classifier head with a fresh layer sized for our classes.
# Newly constructed layers have requires_grad=True, so the head stays trainable.
num_features = vit_model.heads.head.in_features
vit_model.heads.head = nn.Linear(num_features, NUM_CLASSES)

vit_model = vit_model.to(device)

# Sanity-check the freeze: the only trainable tensors must be the new head's.
trainable_names = {name for name, p in vit_model.named_parameters() if p.requires_grad}
assert trainable_names == {'heads.head.weight', 'heads.head.bias'}, trainable_names
n_trainable = sum(p.numel() for p in vit_model.parameters() if p.requires_grad)
print(f"ViT model loaded and classifier head replaced. Trainable parameters: {n_trainable:,}")

### Training and Fine-Tuning
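Since we froze the feature extraction layers, we are only training the weights of the new classification head. This form of transfer learning is usually called "linear probing" (or "feature extraction"). Full fine-tuning would also update some or all of the pre-trained backbone weights, typically with a smaller learning rate.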

In [None]:
# Optimisation settings: with a frozen backbone, a handful of epochs suffices.
EPOCHS = 10
LEARNING_RATE = 1e-3

# Multi-class classification objective.
criterion = nn.CrossEntropyLoss()

# The optimizer receives only the new head's weight and bias; the backbone is frozen.
head_params = vit_model.heads.head.parameters()
optimizer = optim.Adam(head_params, lr=LEARNING_RATE)

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over `loader`, updating weights only when `optimizer` is given.

    Args:
        model: the network to run.
        loader: iterable of (inputs, labels) batches.
        criterion: loss function applied to (outputs, labels).
        device: device the batches are moved to.
        optimizer: if given, the pass is a training pass (backward + step);
            if None, the pass is evaluation-only with autograd disabled.
        desc: label shown on the progress bar.

    Returns:
        (mean_loss, accuracy) over every sample seen in this pass.

    Raises:
        ValueError: if `loader` yields no samples (avoids a bare ZeroDivisionError).
    """
    is_train = optimizer is not None
    model.train(is_train)
    running_loss, correct, total = 0.0, 0, 0
    # Evaluation passes skip autograd bookkeeping entirely (same as torch.no_grad()).
    with torch.set_grad_enabled(is_train):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Weight the per-batch mean loss by batch size for an exact dataset mean.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError(f"{desc or 'loader'}: loader yielded no samples")
    return running_loss / total, correct / total


def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs, device=None):
    """Train `model` for `epochs` epochs, validating after each one.

    Args:
        model: network to train; its parameters not handed to `optimizer` stay fixed.
        train_loader: batches used for weight updates.
        val_loader: batches used for evaluation only.
        criterion: loss function applied to (outputs, labels).
        optimizer: optimizer that steps the trainable parameters.
        epochs: number of full passes over `train_loader`.
        device: where batches are sent; defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        dict with lists 'train_loss', 'train_acc', 'val_loss', 'val_acc',
        one entry per epoch.

    Raises:
        ValueError: if either loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer, desc=f"Epoch {epoch+1} Train"
        )
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, desc=f"Epoch {epoch+1} Val"
        )
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f"Epoch {epoch+1}/{epochs} -> Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    return history

# Train the classification head, validating after each epoch.
history = train_and_validate(
    vit_model, train_loader, val_loader, criterion, optimizer, epochs=EPOCHS
)

### Evaluation
Let's evaluate our fine-tuned ViT model.

In [None]:
# Plot training history: per-epoch accuracy and loss, training vs. validation.
epoch_numbers = range(1, len(history['train_acc']) + 1)  # 1-based, matching the training log
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epoch_numbers, history['train_acc'], label='Training Accuracy')
ax_acc.plot(epoch_numbers, history['val_acc'], label='Validation Accuracy')
ax_acc.set(title='ViT Accuracy (Fine-tuned)', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend()

ax_loss.plot(epoch_numbers, history['train_loss'], label='Training Loss')
ax_loss.plot(epoch_numbers, history['val_loss'], label='Validation Loss')
ax_loss.set(title='ViT Loss (Fine-tuned)', xlabel='Epoch', ylabel='Cross-entropy loss')
ax_loss.legend()

fig.tight_layout()
plt.show()

# Detailed per-class classification report on the validation split.
vit_model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in val_loader:
        # Only the model input needs to be on the device; labels stay on the CPU.
        outputs = vit_model(images.to(device))
        y_true.extend(labels.tolist())
        y_pred.extend(outputs.argmax(dim=1).cpu().tolist())

print("\nFine-tuned ViT Classification Report:\n")
# Pin `labels` so rows line up with CLASS_NAMES even if a class is missing from
# this split (without it, sklearn raises on a target_names size mismatch).
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(CLASS_NAMES))),
    target_names=CLASS_NAMES,
    zero_division=0,
))

### Conclusion

By reusing a pre-trained Vision Transformer and training only a new classification head for our specific task, we can achieve strong performance with significantly less training time and data than training from scratch. The features the ViT learned on the large ImageNet dataset transfer well to our EuroSAT land-cover classification problem. This demonstrates the power and efficiency of transfer learning, a cornerstone of modern computer vision. Unfreezing some of the backbone (full fine-tuning) is a natural next step if higher accuracy is needed.