# Assignment: Vision Transformers on CIFAR10

In [18]:
#imports
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [19]:
#loading the dataset
dataset = dset.CIFAR10(root="./data", download=True,
  transform=transforms.Compose([
  transforms.Resize(256, interpolation=3),
  transforms.CenterCrop(224),
  #transforms.Resize(64),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
nc=3

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, num_workers=2)

#dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

In [20]:
#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Tasks:
* try to get the best test Accuracy on Cifar10 using a transformer model
* pre-trained models allowed - see [here](https://docs.pytorch.org/vision/main/models/vision_transformer.html) for list of models in TorchVision
* **hint**: just like with the CNN in Week 5 - wee need to change the classification layer to fit our 10 class CIFAR-10 problem before we can fine-tune it...
* **hint**: Transformers need a lot of compute + memory - use the A100 GPU



In [21]:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

NUM_CLASSES = 10

# freeze
for param in model.parameters():
    param.requires_grad = False

model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

# Only train the classifier
for param in model.head.parameters():
    param.requires_grad = True

model.to(device)

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.head.parameters(), lr=5e-4)  # Only update classifier

In [23]:
from tqdm import tqdm

def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    correct, total = 0, 0

    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    print(f"Train Loss: {total_loss/len(loader):.4f} | Accuracy: {100. * correct / total:.2f}%")

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

    print(f"Test Loss: {total_loss/len(loader):.4f} | Accuracy: {100. * correct / total:.2f}%")


In [24]:
EPOCHS = 5
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    train(model, train_loader, optimizer, criterion)
    evaluate(model, test_loader, criterion)



Epoch 1/5


100%|██████████| 313/313 [06:44<00:00,  1.29s/it]

Train Loss: 0.6635 | Accuracy: 88.30%





Test Loss: 0.3065 | Accuracy: 93.00%

Epoch 2/5


100%|██████████| 313/313 [06:45<00:00,  1.29s/it]

Train Loss: 0.2542 | Accuracy: 93.35%





Test Loss: 0.2330 | Accuracy: 93.64%

Epoch 3/5


100%|██████████| 313/313 [06:45<00:00,  1.29s/it]

Train Loss: 0.2077 | Accuracy: 94.06%





Test Loss: 0.2069 | Accuracy: 94.03%

Epoch 4/5


100%|██████████| 313/313 [06:44<00:00,  1.29s/it]

Train Loss: 0.1858 | Accuracy: 94.53%





Test Loss: 0.1922 | Accuracy: 94.21%

Epoch 5/5


100%|██████████| 313/313 [06:44<00:00,  1.29s/it]

Train Loss: 0.1720 | Accuracy: 94.81%





Test Loss: 0.1839 | Accuracy: 94.52%
