This task focuses on implementing and fine-tuning Vision Transformers (ViTs) for image classification.

Your task is to implement and fine-tune (or train from scratch) a Vision Transformer model for image classification on the CIFAR-10 dataset (the download is below 200 MB, depending on the version). The task is designed to give you experience in applying ViTs to real-world computer vision problems. Feel free to use helpers of your choice, e.g. the huggingface/transformers library, which provides pre-built Vision Transformer models. You can load the CIFAR-10 dataset with the torchvision.datasets module.

Got lost looking for a starting point? Here are three related examples (not tested):

example on colab
example from github
example on kaggle
Ideally, you write your own implementation. Note that you do not need to implement the attention computation by hand for this exercise.

a. Load the CIFAR-10 dataset: Preprocess the images by resizing them to the appropriate size for Vision Transformers (e.g., 224x224). Normalize the dataset using standard image normalization techniques, and describe what those are.

b. Load a pre-trained Vision Transformer model: Use a pre-trained ViT model from the transformers library. For example, you can use ViTForImageClassification from huggingface/transformers. Initialize the model with pre-trained weights (such as google/vit-base-patch16-224).

c. Fine-tune the model: Ideally you do not want to train all parameters. (You can if you can manage.) Freeze the lower layers of the Vision Transformer to leverage the pre-trained features. Fine-tune the top layers of the model for CIFAR-10 image classification (10 classes). Use a suitable optimizer (e.g., AdamW) and a learning rate scheduler (e.g., linear decay with warmup).

d. Train the model: Train the model for a fixed number of epochs (e.g., 10 epochs) and monitor the validation accuracy. Ensure that the model outputs the classification accuracy for both the training and validation sets after each epoch.

e. Evaluation: After training, evaluate the model on the test set and report the final accuracy. Plot the training and validation loss curves and the accuracy curves. Compare how well the model performs with and without pre-training.

f. (OPTIONAL) Visualize the attention maps from the Vision Transformer for some test images. Discuss how the model attends to different parts of the image for classification.

In [1]:
# a. Load the CIFAR-10 dataset: Preprocess the images by resizing them to the appropriate size for Vision Transformers (e.g., 224x224). Normalize the dataset using standard image normalization techniques, and describe what those are.

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Standard ImageNet normalization values
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT input size
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Download CIFAR-10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

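# Train/val split (90% / 10%); a fixed seed keeps the split reproducible across runs
train_size = int(0.9 * len(trainset))
val_size = len(trainset) - train_size
train_ds, val_ds = torch.utils.data.random_split(
    trainset, [train_size, val_size], generator=torch.Generator().manual_seed(42))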

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
test_loader = DataLoader(testset, batch_size=32, shuffle=False)



ModuleNotFoundError: No module named 'torch'
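
Task a also asks for a description of the normalization, which the cell above does not spell out. `ToTensor` scales 8-bit pixel values from [0, 255] to [0, 1], and `Normalize` then subtracts a per-channel mean and divides by a per-channel standard deviation, here the commonly used ImageNet statistics. The following is a minimal plain-Python sketch of that arithmetic, together with the token count a 224x224 input produces for 16x16 patches. The helper names `normalize_pixel` and `num_tokens` are hypothetical and exist only for illustration.

```python
# Illustrative only: reproduces what ToTensor + Normalize do to a single RGB pixel.
# `normalize_pixel` and `num_tokens` are hypothetical helpers, not torchvision functions.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb_uint8, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale 8-bit channel values to [0, 1], then standardize each channel."""
    scaled = [value / 255.0 for value in rgb_uint8]  # what ToTensor does
    return [(s - m) / d for s, m, d in zip(scaled, mean, std)]  # what Normalize does


def num_tokens(image_size=224, patch_size=16):
    """Number of patches of a square image plus one [CLS] token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


print(normalize_pixel([0, 0, 0]))        # a black pixel maps to negative values
print(normalize_pixel([255, 255, 255]))  # a white pixel maps to positive values
print(num_tokens())                      # 14 * 14 patches + 1 = 197
```

The checkpoint's own preprocessing configuration may specify different statistics than the ImageNet values used here; it is worth checking which mean and standard deviation the pre-trained model expects before fixing them in the transform.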

In [4]:
# b. Load a pre-trained Vision Transformer model: Use a pre-trained ViT model from the transformers library. For example, you can use ViTForImageClassification from huggingface/transformers. Initialize the model with pre-trained weights (such as google/vit-base-patch16-224).

from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True
)



Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
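
The warning above is expected: the checkpoint ships a 1000-class ImageNet head, so the 10-class classifier is created with fresh random weights and has to be trained. As a small sketch, the snippet below counts the parameters of that new head and builds readable label mappings for CIFAR-10. The class order is assumed to match `torchvision.datasets.CIFAR10.classes`, and passing the mappings to `from_pretrained` (shown only as a comment) is an optional suggestion, not something the original notebook does.

```python
# CIFAR-10 class names, assumed to be in the same order as torchvision's CIFAR10.classes
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

id2label = {index: name for index, name in enumerate(CIFAR10_CLASSES)}
label2id = {name: index for index, name in id2label.items()}

# Size of the re-initialized head reported in the warning:
# weight of shape (10, 768) plus bias of shape (10,)
hidden_size, num_labels = 768, len(CIFAR10_CLASSES)
head_params = hidden_size * num_labels + num_labels
print(head_params)  # 7690

# Possible use when loading the model (assumption, not run here):
# model = ViTForImageClassification.from_pretrained(
#     "google/vit-base-patch16-224", num_labels=10,
#     id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
```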


In [5]:
# c. Fine-tune the model: Ideally you do not want to train all parameters. (You can if you can manage.) Freeze the lower layers of the Vision Transformer to leverage the pre-trained features. Fine-tune the top layers of the model for CIFAR-10 image classification (10 classes). Use a suitable optimizer (e.g., AdamW) and a learning rate scheduler (e.g., linear decay with warmup).

# Freeze the entire ViT backbone (embeddings and all encoder blocks);
# only the newly initialized classifier head remains trainable.
for param in model.vit.parameters():
    param.requires_grad = False  # freeze backbone
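
The cell above freezes the whole backbone, whereas the task text speaks of freezing the lower layers and fine-tuning the top ones. One way to do that is to decide per parameter name, as in the sketch below. It assumes the Hugging Face ViT naming scheme (`vit.embeddings.*`, `vit.encoder.layer.<i>.*`, `vit.layernorm.*`, `classifier.*`) and 12 encoder blocks as in ViT-Base; `is_trainable` and its `unfreeze_last` argument are hypothetical names chosen for this example.

```python
import re


def is_trainable(param_name, num_layers=12, unfreeze_last=2):
    """Return True if the named parameter should keep requires_grad=True.

    Assumes Hugging Face ViT parameter names; hypothetical helper for illustration.
    """
    if param_name.startswith("classifier."):
        return True  # the new head is always trained
    if param_name.startswith("vit.layernorm."):
        return unfreeze_last > 0  # final LayerNorm follows the top blocks
    match = re.match(r"vit\.encoder\.layer\.(\d+)\.", param_name)
    if match:
        return int(match.group(1)) >= num_layers - unfreeze_last
    return False  # embeddings and anything unrecognized stay frozen


# Applying it to a loaded model would look like this (not run here):
# for name, param in model.named_parameters():
#     param.requires_grad = is_trainable(name)

example_names = [
    "vit.embeddings.cls_token",
    "vit.encoder.layer.0.attention.attention.query.weight",
    "vit.encoder.layer.11.output.dense.weight",
    "vit.layernorm.weight",
    "classifier.weight",
]
for name in example_names:
    print(f"{name}: {'train' if is_trainable(name) else 'frozen'}")
```

Unfreezing more blocks generally increases memory use and training time, so `unfreeze_last` is a knob to trade accuracy against compute.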



In [None]:
# d. Train the model: Train the model for a fixed number of epochs (e.g., 10 epochs) and monitor the validation accuracy. Ensure that the model outputs the classification accuracy for both the training and validation sets after each epoch.

from torch.optim import AdamW

from transformers import get_linear_schedule_with_warmup

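# Only hand trainable parameters to the optimizer (frozen ones are skipped)
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)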

num_epochs = 10
num_training_steps = num_epochs * len(train_loader)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=num_training_steps
)

loss_fn = torch.nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

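# Per-epoch history, kept for the loss and accuracy plots
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss, correct, total = 0.0, 0, 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs, labels = batch[0].to(device), batch[1].to(device)
        outputs = model(inputs).logits
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule is defined per optimizer step, not per epoch

        # Weight the batch-mean loss by the batch size so the epoch average is exact
        train_loss += loss.item() * labels.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    train_loss /= total
    train_acc = 100 * correct / total

    # ---- Validation ----
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs).logits
            loss = loss_fn(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_loss /= total
    val_acc = 100 * correct / total

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}%, "
          f"Val Loss {val_loss:.4f}, Val Acc {val_acc:.2f}%")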
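
To make the effect of the linear warmup and decay schedule concrete, the sketch below recomputes the learning-rate multiplier in plain Python for the settings used in this cell (500 warmup steps, 10 epochs, batch size 32, 90% of the 50,000 CIFAR-10 training images). It is an illustrative re-implementation of the schedule's shape, not the library code, and `linear_warmup_decay` is a hypothetical name.

```python
import math


def linear_warmup_decay(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 5e-5
steps_per_epoch = math.ceil(45_000 / 32)  # 90% of 50,000 images, batch size 32
total_steps = 10 * steps_per_epoch        # 10 epochs

for step in (0, 250, 500, total_steps // 2, total_steps):
    factor = linear_warmup_decay(step, 500, total_steps)
    print(f"step {step:>6}: lr = {base_lr * factor:.2e}")
```

With these numbers the warmup covers roughly the first third of the first epoch, after which the learning rate falls linearly to zero at the last step.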




In [None]:
# e. Evaluation: After training, evaluate the model on the test set and report the final accuracy. Plot the training and validation loss curves and the accuracy curves. Compare how well the model performs with and without pre-training.



model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch[0].to(device), batch[1].to(device)
        outputs = model(inputs).logits
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")
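
Task e also asks for a comparison with a model that has no pre-training, which the notebook does not include. A baseline can be created by instantiating the architecture from a configuration instead of a checkpoint, so that all weights start from random initialization. The sketch below uses a deliberately tiny, hypothetical configuration (32x32 inputs, 2 layers) so that it runs quickly on a CPU; these sizes are assumptions for illustration and not the settings of the pre-trained model.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# Tiny, hypothetical configuration chosen only so the sketch runs fast on CPU.
# For a like-for-like baseline one could instead reuse the checkpoint's config, e.g.
# ViTConfig.from_pretrained("google/vit-base-patch16-224", num_labels=10).
config = ViTConfig(
    image_size=32,
    patch_size=4,
    num_channels=3,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    num_labels=10,
)

# Building the model from a config loads no checkpoint: weights are randomly initialized
scratch_model = ViTForImageClassification(config)
scratch_model.eval()

# Forward pass on a random batch to confirm the output shape
dummy_batch = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    logits = scratch_model(dummy_batch).logits
print(logits.shape)
```

Training this model with the same loop and the same data split, without freezing any layers, gives the "without pre-training" curve to set against the fine-tuned one.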




In [None]:
# plot loss/accuracy curves

import matplotlib.pyplot as plt

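# Assumes the training loop stored one value per epoch in
# train_losses, val_losses, train_accs and val_accs.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(val_losses, label='Val Loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss')
ax_loss.legend()
ax_acc.plot(train_accs, label='Train Acc')
ax_acc.plot(val_accs, label='Val Acc')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy (%)')
ax_acc.legend()
plt.tight_layout()
plt.show()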

In [None]:
# f optional - attention visualization

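# Reuses `inputs` from the last test batch; no gradients are needed here.
# Depending on the transformers version, attention weights may only be returned
# when the model was loaded with attn_implementation="eager".
model.eval()
with torch.no_grad():
    outputs = model(inputs, output_attentions=True)
attentions = outputs.attentions  # one tensor per layer, each (batch, heads, tokens, tokens)
print(len(attentions), attentions[0].shape)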
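
To turn these tensors into something that can be plotted, a common simple approach is to take the attention of the [CLS] token to all patch tokens in one layer, average it over the heads, and reshape it to the 14x14 patch grid. The sketch below does this with NumPy on a synthetic array that stands in for one image's attention in one layer (12 heads, 197 tokens, as in ViT-Base with 16x16 patches on 224x224 inputs). `cls_attention_map` is a hypothetical helper, and the random data is only a placeholder for `attentions[-1][0].cpu().numpy()`.

```python
import numpy as np


def cls_attention_map(layer_attention, grid_size=14):
    """Average [CLS]-to-patch attention over heads and reshape to the patch grid.

    layer_attention: array of shape (heads, tokens, tokens) for a single image,
    where token 0 is [CLS] and the remaining tokens are image patches.
    """
    mean_over_heads = layer_attention.mean(axis=0)  # (tokens, tokens)
    cls_to_patches = mean_over_heads[0, 1:]         # drop the CLS-to-CLS entry
    return cls_to_patches.reshape(grid_size, grid_size)


# Synthetic stand-in for one image's attention in one layer
rng = np.random.default_rng(0)
scores = rng.normal(size=(12, 197, 197))
fake_attention = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # row-wise softmax

heatmap = cls_attention_map(fake_attention)
print(heatmap.shape)  # (14, 14)
```

The resulting 14x14 map can be upsampled to 224x224 and drawn over the input image with a semi-transparent `plt.imshow` to see which regions the classifier token attends to.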