## Building a Vision Encoder: Fine-Tuning a Vision Transformer (ViT)

This notebook is the next step toward understanding vision-language models (VLMs). Here, we build an image classifier by fine-tuning a pre-trained Vision Transformer (ViT).

**The Goal:** Create a custom vision encoder by adapting a large, general-purpose model to a specific dataset (CIFAR-10).

**Key Steps:**
1.  **Load CIFAR-10 Dataset**: Use `torchvision` to easily access and prepare the data.
2.  **Define Image Transforms**: Resize and normalize the images to match the pre-trained model's requirements.
3.  **Load a Pre-trained ViT**: Use the `timm` library to load `vit_base_patch16_224`, a ViT-Base model with weights pre-trained on ImageNet.
4.  **Adapt the Model**: Replace the model's final layer to fit our 10 CIFAR-10 classes.
5.  **Fine-Tune**: Write a training loop to train the model for a few epochs on the new data.
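
To see where the 224x224 requirement in step 2 comes from, the model name `vit_base_patch16_224` used later in this notebook can be read as patch size 16 and input resolution 224. A rough arithmetic sketch, assuming the standard ViT layout with one extra class token (the variable names are illustrative):

```python
# Token arithmetic for a ViT with 16x16 patches on 224x224 inputs.
IMAGE_SIZE = 224   # input resolution expected by the pre-trained model
PATCH_SIZE = 16    # side length of each square patch
CIFAR_SIZE = 32    # native CIFAR-10 resolution

patches_per_side = IMAGE_SIZE // PATCH_SIZE      # 14
num_patches = patches_per_side ** 2              # 196
sequence_length = num_patches + 1                # 197, assuming one class token
upsampling_factor = IMAGE_SIZE // CIFAR_SIZE     # 7

print(f"{patches_per_side}x{patches_per_side} grid -> {num_patches} patches")
print(f"Sequence length with class token: {sequence_length}")
print(f"CIFAR-10 images are upsampled {upsampling_factor}x per side")
```

Without resizing, a 32x32 image would yield only a 2x2 grid of patches, which does not match the position embeddings the pre-trained weights were learned with.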

### Step 0: Installation
First, install the `timm` (PyTorch Image Models) library, a widely used collection of pre-trained computer vision models for PyTorch. Run this in your terminal:

```bash
pip install timm
```

In [None]:
# Cell 1: Imports and Device Setup
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import timm 

# ----- Device -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Cell 2: Data Preparation and Loading

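# The pre-trained ViT expects 224x224 inputs, so the 32x32 CIFAR-10 images are upsampled.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet statistics. Some timm checkpoints use other values,
    # so compare these with the model's pretrained configuration.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])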
])

# --- Download and prepare the datasets ---
print("Downloading and preparing CIFAR-10 dataset...")
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=data_transforms)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=data_transforms)

# --- Create DataLoaders ---
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("Dataset prepared successfully.")

In [None]:
# Cell 3: Load Pre-trained Vision Transformer

print("Loading pre-trained Vision Transformer (vit_base_patch16_224)...")
model = timm.create_model('vit_base_patch16_224', pretrained=True)


# Replace the 1000-class ImageNet head with a new linear layer for CIFAR-10.
num_in_features = model.head.in_features
NUM_CLASSES = 10

model.head = nn.Linear(num_in_features, NUM_CLASSES)

model = model.to(device)

print("Model loaded and adapted for CIFAR-10.")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
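
To make the head swap and the `requires_grad` filter in the parameter count concrete, here is a small sketch on a stand-in module. `TinyBackbone` and `count_trainable` are hypothetical; the stand-in only mimics the one property used above, a linear `head` attribute, and 768 is the ViT-Base embedding size. Freezing the backbone is shown as an optional variant (often called linear probing), not something this notebook does.

```python
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a timm ViT: a feature layer plus a linear head."""
    def __init__(self, embed_dim=768, num_classes=1000):
        super().__init__()
        self.features = nn.Linear(32, embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.head(self.features(x))

def count_trainable(module):
    """Count only the parameters that will receive gradient updates."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

toy_model = TinyBackbone()
toy_model.head = nn.Linear(toy_model.head.in_features, 10)  # same swap as above
all_trainable = count_trainable(toy_model)

# Optional variant: freeze everything except the new head.
for name, param in toy_model.named_parameters():
    param.requires_grad = name.startswith("head.")
head_only = count_trainable(toy_model)

print(f"Trainable with full fine-tuning: {all_trainable:,}")
print(f"Trainable with frozen backbone:  {head_only:,}")  # 768 * 10 + 10 = 7,690
```

With a frozen backbone, a count filtered on `requires_grad` reports only the head's parameters, which is why the filtered count is better described as "trainable" than "total".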

In [None]:
# Cell 4: Define Loss Function, Optimizer, and Hyperparameters

# Hyperparameters
LEARNING_RATE = 1e-4 
EPOCHS = 3          

# Loss Function
loss_fn = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
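
`nn.CrossEntropyLoss` expects raw logits and integer class indices, so the model's outputs are passed to it without a softmax. The short check below, using made-up logits for a batch of two samples, compares it with an explicit log-softmax computation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for 2 samples and 10 classes, plus their true class indices.
torch.manual_seed(0)
logits = torch.randn(2, 10)
labels = torch.tensor([3, 7])

builtin_loss = nn.CrossEntropyLoss()(logits, labels)

# Manual equivalent: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(2), labels].mean()

print(f"Built-in: {builtin_loss.item():.6f} | Manual: {manual_loss.item():.6f}")
```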

In [None]:
# Cell 5: Training and Evaluation Loops

def train_epoch(model, dataloader, loss_fn, optimizer, device):
    model.train()  # Set the model to training mode
    total_loss = 0.0
    total_correct = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        # 1. Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # 2. Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. Track metrics
        total_loss += loss.item()
        predictions = torch.argmax(outputs, dim=1)
        total_correct += (predictions == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy

def evaluate(model, dataloader, loss_fn, device):
    model.eval() # Set the model to evaluation mode
    total_loss = 0.0
    total_correct = 0

    with torch.no_grad(): 
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            predictions = torch.argmax(outputs, dim=1)
            total_correct += (predictions == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy
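
One detail in both functions: `total_loss / len(dataloader)` averages the per-batch mean losses, which gives a smaller final batch the same weight as a full one. With 50,000 training images and a batch size of 32, the last batch holds 16 images. The pure-Python sketch below, using made-up per-batch losses, shows the difference from a sample-weighted average. For this dataset the gap is tiny, so the simpler form is a reasonable approximation.

```python
# Made-up per-batch mean losses and batch sizes; the last batch is smaller.
batch_losses = [0.50, 0.40, 0.90]
batch_sizes = [32, 32, 16]

# What the loops above compute: an unweighted mean over batches.
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

# Sample-weighted alternative: weight each batch mean by its size.
weighted_total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
per_sample_mean = weighted_total / sum(batch_sizes)

print(f"Mean of batch means: {mean_of_batch_means:.4f}")
print(f"Per-sample mean:     {per_sample_mean:.4f}")
```

In the training loop, the weighted form would correspond to accumulating `loss.item() * images.size(0)` and dividing by `len(dataloader.dataset)`.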

In [None]:
# Cell 6: Main Fine-Tuning Execution

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{EPOCHS} ---")
    
    train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer, device)
    print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f}")
    
    test_loss, test_acc = evaluate(model, test_loader, loss_fn, device)
    print(f"Test Loss : {test_loss:.4f} | Test Accuracy : {test_acc:.4f}")

print("\nFine-tuning finished.")

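# --- Save the fine-tuned model ---
import os

MODEL_SAVE_PATH = '../models/vit_cifar10_finetuned.pth'
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")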
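
A `state_dict` stores only tensors, not the architecture, so loading the file later means recreating the same model first. The sketch below shows the round trip on a small stand-in layer and a temporary file so it runs quickly. For this notebook, the model would instead be rebuilt with `timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)` before calling `load_state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the fine-tuned model; any nn.Module follows the same pattern.
trained = nn.Linear(768, 10)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pth")
    torch.save(trained.state_dict(), path)

    # Recreate the same architecture, then load the saved tensors into it.
    restored = nn.Linear(768, 10)
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)

restored.eval()  # switch to evaluation mode before inference
weights_match = torch.equal(trained.weight, restored.weight)
print("Weights match:", weights_match)
```

Passing `map_location="cpu"` lets a checkpoint saved on a GPU machine load on a machine without one.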

## Congratulations!

You have fine-tuned a ViT into a custom vision encoder. The file `vit_cifar10_finetuned.pth` contains the weights (the `state_dict`) of a ViT-Base model that is now specialized in classifying CIFAR-10 images.

This is a fundamental building block. The next step in your VLM journey will be to understand how to take the outputs from this vision encoder and your text encoder and combine them in a meaningful way.