In [None]:
from IPython.display import clear_output

In [None]:
%pip install torch
%pip install torchvision
%pip install matplotlib

%pip install vit-pytorch

clear_output()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from vit_pytorch import SimpleViT

from tqdm import tqdm

# Contents:

We'll build an image classifier for the CIFAR-10 dataset in PyTorch using a Vision Transformer (ViT).

Note: instead of using the full-scale ViT (which is very large), we will use the vit-pytorch library to build a smaller model based on the ViT architecture.

About CIFAR10:

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

![CIFAR-10 image](https://production-media.paperswithcode.com/datasets/4fdf2b82-2bc3-4f97-ba51-400322b228b1.png)

## Loading the datasets and data loaders

In [None]:
# Preprocessing: upscale every CIFAR image to the 224x224 resolution the model
# below is configured for, then convert it to a tensor.
# Channel normalisation is currently switched off; to try it, append
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) to the list.
preprocessing_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# CIFAR-10 train / test splits, downloaded into ./data when not already there.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

# Mini-batch loaders: only the training split is reshuffled each epoch.
loader_batch_size = 256
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=loader_batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=loader_batch_size)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 94840203.49it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Defining the model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Small Vision Transformer from vit-pytorch, sized for 224x224 inputs split
# into 32x32 patches and a 10-way classification head (CIFAR-10 has 10 classes).
vit_config = dict(
    image_size=224,
    patch_size=32,
    num_classes=10,
    dim=512,
    depth=2,
    heads=4,
    mlp_dim=2048,
)

# nn.Module.to() moves the parameters in place and returns the same module,
# so building and placing the model can be chained.
model = SimpleViT(**vit_config).to(device)

# Empty lists intended to collect loss values; re-running this cell resets them.
train_losses, val_losses = [], []

In [None]:
# # You can also use the full scale model if you have the resources to train and experiment with it.
# from torchvision import models as vision_models

# model = vision_models.vit_b_32(weights='DEFAULT')  # weights='DEFAULT' downloads the pretrained model weights (trained on imagenet data.)

# model.to(device)

# # Note: the model's last layer must be replaced, because the original ViT classifies between 1000 classes and CIFAR-10 has only 10 classes.

# train_losses = []
# val_losses = []

In [None]:
# Objective: cross-entropy between the model's class scores and the labels.
criterion = nn.CrossEntropyLoss()

# Optimiser: Adam over all model parameters with a fixed learning rate.
learning_rate = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train for a fixed number of epochs; after each epoch, evaluate on the test
# split and record the epoch's losses in train_losses / val_losses (the lists
# created in the model-definition cell) so they can be inspected afterwards.
num_epochs = 20

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    correct = 0
    total = 0
    for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion returns the batch MEAN (CrossEntropyLoss default), so scale
        # by the batch size; otherwise the smaller final batch is over-weighted
        # in the epoch average.
        batch_size = targets.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        predicted = outputs.argmax(dim=1)
        correct += predicted.eq(targets).sum().item()

    train_loss = running_loss / total
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)

    # ---- Evaluation pass (no gradients, eval-mode layers) ----
    model.eval()
    eval_loss = 0.0  # sum of per-sample losses on the test split
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)
            eval_loss += criterion(outputs, targets).item() * batch_size
            total += batch_size
            predicted = outputs.argmax(dim=1)
            correct += predicted.eq(targets).sum().item()

    val_loss = eval_loss / total
    test_accuracy = 100. * correct / total
    val_losses.append(val_loss)

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
          f"Test Accuracy: {test_accuracy:.2f}%")


Epoch 1/10: 100%|██████████| 196/196 [01:47<00:00,  1.82it/s]


Epoch [1/10], Train Loss: 1.8632, Train Accuracy: 31.89%, Test Accuracy: 39.75%


Epoch 2/10: 100%|██████████| 196/196 [01:44<00:00,  1.87it/s]


Epoch [2/10], Train Loss: 1.5628, Train Accuracy: 43.16%, Test Accuracy: 45.35%


Epoch 3/10: 100%|██████████| 196/196 [01:45<00:00,  1.85it/s]


Epoch [3/10], Train Loss: 1.4289, Train Accuracy: 47.88%, Test Accuracy: 49.44%


Epoch 4/10: 100%|██████████| 196/196 [01:44<00:00,  1.87it/s]


Epoch [4/10], Train Loss: 1.3248, Train Accuracy: 52.17%, Test Accuracy: 53.88%


Epoch 5/10: 100%|██████████| 196/196 [01:45<00:00,  1.86it/s]


Epoch [5/10], Train Loss: 1.2317, Train Accuracy: 55.66%, Test Accuracy: 55.02%


Epoch 6/10: 100%|██████████| 196/196 [01:45<00:00,  1.86it/s]


Epoch [6/10], Train Loss: 1.1650, Train Accuracy: 57.95%, Test Accuracy: 56.01%


Epoch 7/10: 100%|██████████| 196/196 [01:45<00:00,  1.85it/s]


Epoch [7/10], Train Loss: 1.1044, Train Accuracy: 60.44%, Test Accuracy: 60.58%


Epoch 8/10: 100%|██████████| 196/196 [01:45<00:00,  1.86it/s]


Epoch [8/10], Train Loss: 1.0434, Train Accuracy: 62.62%, Test Accuracy: 60.52%


Epoch 9/10: 100%|██████████| 196/196 [01:45<00:00,  1.87it/s]


Epoch [9/10], Train Loss: 0.9984, Train Accuracy: 64.18%, Test Accuracy: 62.48%


Epoch 10/10: 100%|██████████| 196/196 [01:45<00:00,  1.86it/s]


Epoch [10/10], Train Loss: 0.9567, Train Accuracy: 65.91%, Test Accuracy: 64.28%
