# Assignment: Vision Transformers on CIFAR10

In [1]:
#imports
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [2]:
# Scaffold data pipeline from the assignment template: CIFAR-10 resized to
# 64 px, converted to a tensor and normalised with mean/std 0.5 per channel.
# No `train=` argument is passed, so torchvision's default split is used.
nc = 3  # presumably the number of image channels (RGB) — not used in this cell

dataset = dset.CIFAR10(
    root="./data",
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    ),
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)


In [3]:
# Select the compute device: use CUDA when torch reports a usable GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Tasks:
* Try to get the best test accuracy on CIFAR-10 using a transformer model
* pre-trained models allowed - see [here](https://docs.pytorch.org/vision/main/models/vision_transformer.html) for list of models in TorchVision
* **hint**: just like with the CNN in Week 5, we need to change the classification layer to fit our 10-class CIFAR-10 problem before we can fine-tune the model.
* **hint**: Transformers need a lot of compute + memory - use the A100 GPU



In [4]:
!pip install transformers accelerate

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from transformers import ViTForImageClassification, ViTConfig
from accelerate import Accelerator

# Compute device: GPU when torch can see one, CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mini-batch size shared by both loaders; lower it if the GPU runs out of memory.
batch_size = 32

# Pre-processing applied to every image: upscale to 224 px (the input size in
# the name of the ViT checkpoint loaded later in this cell), convert to a
# tensor, then normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train and test splits, downloaded into ./data when not already there.
trainset = CIFAR10(root="./data", train=True, download=True, transform=transform)
testset = CIFAR10(root="./data", train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Seed torch's global RNG before the model is built so the randomly
# initialised classification head starts from the same weights on every run.
# (GPU kernels may still be non-deterministic, so results can differ slightly.)
SEED = 42
torch.manual_seed(SEED)

# Load a pre-trained Vision Transformer with a fresh 10-way classification head.
#
# The head size is requested through `from_pretrained` instead of overwriting
# `model.classifier` afterwards: that way `model.config.num_labels` and
# `model.num_labels` describe the head that is actually installed, so saving
# and reloading the fine-tuned model (or passing `labels=` to the forward
# pass) stays consistent.
# `ignore_mismatched_sizes=True` lets the loader skip the checkpoint's own
# classifier weights, whose shape does not match a 10-class head, and
# initialise new ones instead.
model_name = "google/vit-base-patch16-224"
num_classes = 10
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

model.to(device)

# Define the optimizer
# AdamW over all parameters (backbone and head) with a small learning rate,
# as is usual when fine-tuning pre-trained weights.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

# Define the loss function
# Applied to the raw logits and integer class labels in the training loop.
criterion = nn.CrossEntropyLoss()

# Initialize Accelerate
# `prepare` returns wrapped versions of the objects; the names are rebound so
# the rest of the cell uses the prepared model, optimizer and loaders.
accelerator = Accelerator()
model, optimizer, trainloader, testloader = accelerator.prepare(
    model, optimizer, trainloader, testloader
)

def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Puts the model in eval mode, disables gradient tracking, and counts how
    many arg-max predictions match the labels across every batch.

    Args:
        model: model whose forward output exposes a ``.logits`` tensor.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the largest logit in each row.
            predicted = model(images).logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Training loop
num_epochs = 10
log_every = 100  # print the mean training loss once per this many batches

# Use the device chosen by the Accelerator that prepared the model, so the
# batches and the model are guaranteed to agree. If the prepared loaders have
# already placed the batch there, the `.to(...)` calls below are no-ops.
run_device = accelerator.device

for epoch in range(num_epochs):
    model.train()  # evaluation at the end of each epoch switches to eval mode
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(run_device), labels.to(run_device)

        optimizer.zero_grad()

        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        # Backward goes through the accelerator rather than loss.backward().
        accelerator.backward(loss)
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_every == 0:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    # Evaluation on the test set
    accuracy = _evaluate_accuracy(model, testloader, run_device)
    print(f'Epoch {epoch + 1} - Test Accuracy: {accuracy:.2f}%')

print('Finished Training')

# Final evaluation
final_accuracy = _evaluate_accuracy(model, testloader, run_device)
print(f'Final Test Accuracy: {final_accuracy:.2f}%')




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


[Epoch 1, Batch 100] loss: 0.502
[Epoch 1, Batch 200] loss: 0.111
[Epoch 1, Batch 300] loss: 0.093
[Epoch 1, Batch 400] loss: 0.097
[Epoch 1, Batch 500] loss: 0.078
[Epoch 1, Batch 600] loss: 0.091
[Epoch 1, Batch 700] loss: 0.072
[Epoch 1, Batch 800] loss: 0.092
[Epoch 1, Batch 900] loss: 0.072
[Epoch 1, Batch 1000] loss: 0.062
[Epoch 1, Batch 1100] loss: 0.087
[Epoch 1, Batch 1200] loss: 0.066
[Epoch 1, Batch 1300] loss: 0.066
[Epoch 1, Batch 1400] loss: 0.055
[Epoch 1, Batch 1500] loss: 0.070
Epoch 1 - Test Accuracy: 97.57%
[Epoch 2, Batch 100] loss: 0.025
[Epoch 2, Batch 200] loss: 0.025
[Epoch 2, Batch 300] loss: 0.024
[Epoch 2, Batch 400] loss: 0.019
[Epoch 2, Batch 500] loss: 0.033
[Epoch 2, Batch 600] loss: 0.041
[Epoch 2, Batch 700] loss: 0.032
[Epoch 2, Batch 800] loss: 0.033
[Epoch 2, Batch 900] loss: 0.029
[Epoch 2, Batch 1000] loss: 0.031
[Epoch 2, Batch 1100] loss: 0.034
[Epoch 2, Batch 1200] loss: 0.051
[Epoch 2, Batch 1300] loss: 0.050
[Epoch 2, Batch 1400] loss: 0.039
