# **Vision Transformer (ViT)**

In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import random
import timeit
import timm

In [2]:
class PatchEmbeddings(nn.Module):
  def __init__(self,
               in_channels=3,
               patch_size=4,
               embedding_dim=48):
    super().__init__()

    self.patch_size = patch_size

    self.patcher = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    self.flatten = nn.Flatten(start_dim=2, end_dim=3)

  def forward(self, x):

    x_patched = self.patcher(x)
    x_flattened = self.flatten(x_patched)

    print("Flattened vector: ", x_flattened.shape)

    return x_flattened.permute(0, 2, 1)

In [3]:
#testing the patcher
random_example = torch.randn(64,3,32,32)
print("Random image shape:", random_example.shape)

patcher = PatchEmbeddings()
patcher_output = patcher(random_example)
print("Patch Embeddings output:", patcher_output.shape)

Random image shape: torch.Size([64, 3, 32, 32])
Flattened vector:  torch.Size([64, 48, 64])
Patch Embeddings output: torch.Size([64, 64, 48])


In [4]:
class VisionTransformer(nn.Module):
  def __init__(self,
               img_size=32,
               patch_size=4,
               embedding_dim=48,
               in_channels=3,
               dropout=0.1,
               num_att_heads=8,
               mlp_size = 3072,
               num_transformer_layer=12,
               num_classes = 10):
    super().__init__()

    assert img_size % patch_size == 0, f"Image size must be divisble by patch size!!! Image shape is: {img_size}, while patch size is: {patch_size}"

    self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
                                            patch_size=patch_size,
                                            embedding_dim=embedding_dim)

    self.class_token = nn.Parameter(torch.randn(size=(1, 1, embedding_dim)), requires_grad=True)

    num_patches = (img_size // patch_size) * (img_size // patch_size) # determine the number of patches

    self.positional_embedding = nn.Parameter(torch.rand(size=(1, num_patches + 1, embedding_dim)),requires_grad=True)

    self.dropout = nn.Dropout(p=dropout)

    # single transformer encoder layer
    self.transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                                nhead=num_att_heads,
                                                                dim_feedforward=mlp_size,
                                                                dropout=dropout,
                                                                activation="gelu",
                                                                batch_first=True,
                                                                norm_first=True)

    # stack of N transformer encoder layers
    self.transformer_encoder = nn.TransformerEncoder(encoder_layer=self.transformer_encoder_layer, num_layers=num_transformer_layer)

    # MLP head
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim, out_features=num_classes)
    )

  def forward(self, x):

    patch_embeddings = self.patch_embeddings(x)

    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)

    x = torch.cat((class_token, patch_embeddings), dim=1)

    x = self.positional_embedding + x

    x = self.dropout(x)

    x = self.transformer_encoder(x)

    x = self.mlp_head(x[:, 0]) # passing just the class token through the MLP head to get final classification

    return x

In [5]:
cifar10_labels = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [6]:
img_example = torch.randn(1, 3, 32, 32)
print("Random image shape:", img_example.shape)

# Create ViT
vision_transformer = VisionTransformer(num_classes=len(cifar10_labels))
print("The output of Vision Transformer:", vision_transformer(img_example))

Random image shape: torch.Size([1, 3, 32, 32])
Flattened vector:  torch.Size([1, 48, 64])
The output of Vision Transformer: tensor([[-0.6941,  0.3187,  0.2410, -0.5704,  0.4219, -0.1650,  0.0610, -1.1321,
         -0.2041, -1.1518]], grad_fn=<AddmmBackward0>)




In [7]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [8]:
from torchinfo import summary
summary(vision_transformer, input_size=img_example.shape)

Flattened vector:  torch.Size([1, 48, 64])


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 10]                   310,800
├─PatchEmbeddings: 1-1                        [1, 64, 48]               --
│    └─Conv2d: 2-1                            [1, 48, 8, 8]             2,352
│    └─Flatten: 2-2                           [1, 48, 64]               --
├─Dropout: 1-2                                [1, 65, 48]               --
├─TransformerEncoder: 1-3                     [1, 65, 48]               --
│    └─ModuleList: 2-3                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [1, 65, 48]               307,632
│    │    └─TransformerEncoderLayer: 3-2      [1, 65, 48]               307,632
│    │    └─TransformerEncoderLayer: 3-3      [1, 65, 48]               307,632
│    │    └─TransformerEncoderLayer: 3-4      [1, 65, 48]               307,632
│    │    └─TransformerEncoderLayer: 3-5      [1, 65, 48]          

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# **Pretained ViT-base-16**

In [10]:
# Feature Extraction of ViT-base-16, I am freezing all inner layer while I change only the last layer

pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)


# Freeze all layers in pretrained ViT model
for param in pretrained_vit.parameters():
    param.requires_grad=False


# Update just the last layer of pretrained ViT, making it suitable for CIFAR-10 classification
embedding_dim = 768
pretrained_vit.heads = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    nn.Linear(in_features=embedding_dim,
             out_features=len(cifar10_labels))
)

pretrained_vit.to(device)

summary(model=pretrained_vit,
        input_size=(1, 3, 224, 224)) # ViT-base-16 accepts 224x224 size images

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 68.7MB/s]


Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 10]                   768
├─Conv2d: 1-1                                 [1, 768, 14, 14]          (590,592)
├─Encoder: 1-2                                [1, 197, 768]             151,296
│    └─Dropout: 2-1                           [1, 197, 768]             --
│    └─Sequential: 2-2                        [1, 197, 768]             --
│    │    └─EncoderBlock: 3-1                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-2                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-3                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-4                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-5                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-6                 [1, 197, 768]             (7,087,872)
│    │    └─EncoderBlock: 3-

# **Processing CIFAR-10 data**

In [11]:
pretrained_vit_transforms = pretrained_vit_weights.transforms()

full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_vit_transforms)

fraction = 0.6  # fraction of the dataset to use for training and validation
num_train_samples = int(len(full_train_dataset) * fraction)
random_seed = 42  # for reproducibility

torch.manual_seed(random_seed)
random_indices = torch.randperm(len(full_train_dataset))[:num_train_samples]

train_subset = Subset(full_train_dataset, random_indices)

# spliting the subset into training (90%) and validation (10%) sets
train_indices, val_indices = train_test_split(range(len(train_subset)), test_size=0.1, random_state=random_seed, shuffle=True)

train_dataset = Subset(train_subset, train_indices)
val_dataset = Subset(train_subset, val_indices)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


print(f"Training dataset size is: {len(train_dataset)}")
print(f"Validation dataset size is: {len(val_dataset)}")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:18<00:00, 9.05MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Training dataset size is: 27000
Validation dataset size is: 3000


# **Feature Extraction of Pre-Trained ViT-base-16(i.e., training just the head(last layer) for CIFAR-10 classification)**

In [12]:
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 15
for epoch in range(epochs):
    pretrained_vit.train()
    train_loss = 0
    train_correct = 0
    total_train = 0

    for batch in tqdm(train_loader, desc="ViT training for CIFAR-10 classification"):
        images, labels = batch
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = pretrained_vit(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = torch.argmax(outputs, 1)
        train_correct += (preds == labels).sum().item()
        total_train += labels.size(0)

    train_accuracy = train_correct / total_train
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Train Accuracy: {train_accuracy:.4f}")

    pretrained_vit.eval()
    val_loss = 0
    val_correct = 0
    total_val = 0
    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            outputs = pretrained_vit(images)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()
            preds = torch.argmax(outputs, 1)
            val_correct += (preds == labels).sum().item()
            total_val += labels.size(0)

    val_accuracy = val_correct / total_val
    print(f"Epoch {epoch+1}/{epochs}, Val Loss: {val_loss/len(val_loader):.4f}, Val Accuracy: {val_accuracy:.4f}")

ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:10<00:00,  1.76s/it]


Epoch 1/15, Train Loss: 0.3168, Train Accuracy: 0.9060
Epoch 1/15, Val Loss: 0.1776, Val Accuracy: 0.9420


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:14<00:00,  1.77s/it]


Epoch 2/15, Train Loss: 0.1635, Train Accuracy: 0.9464
Epoch 2/15, Val Loss: 0.1586, Val Accuracy: 0.9443


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:14<00:00,  1.78s/it]


Epoch 3/15, Train Loss: 0.1389, Train Accuracy: 0.9546
Epoch 3/15, Val Loss: 0.1461, Val Accuracy: 0.9487


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:15<00:00,  1.78s/it]


Epoch 4/15, Train Loss: 0.1236, Train Accuracy: 0.9609
Epoch 4/15, Val Loss: 0.1445, Val Accuracy: 0.9490


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:14<00:00,  1.77s/it]


Epoch 5/15, Train Loss: 0.1133, Train Accuracy: 0.9633
Epoch 5/15, Val Loss: 0.1433, Val Accuracy: 0.9487


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:14<00:00,  1.77s/it]


Epoch 6/15, Train Loss: 0.1042, Train Accuracy: 0.9661
Epoch 6/15, Val Loss: 0.1470, Val Accuracy: 0.9490


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:13<00:00,  1.77s/it]


Epoch 7/15, Train Loss: 0.0975, Train Accuracy: 0.9684
Epoch 7/15, Val Loss: 0.1430, Val Accuracy: 0.9503


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:16<00:00,  1.78s/it]


Epoch 8/15, Train Loss: 0.0923, Train Accuracy: 0.9710
Epoch 8/15, Val Loss: 0.1444, Val Accuracy: 0.9500


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:15<00:00,  1.78s/it]


Epoch 9/15, Train Loss: 0.0868, Train Accuracy: 0.9720
Epoch 9/15, Val Loss: 0.1501, Val Accuracy: 0.9493


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:13<00:00,  1.77s/it]


Epoch 10/15, Train Loss: 0.0824, Train Accuracy: 0.9733
Epoch 10/15, Val Loss: 0.1464, Val Accuracy: 0.9520


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:13<00:00,  1.77s/it]


Epoch 11/15, Train Loss: 0.0781, Train Accuracy: 0.9749
Epoch 11/15, Val Loss: 0.1490, Val Accuracy: 0.9507


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:13<00:00,  1.77s/it]


Epoch 12/15, Train Loss: 0.0750, Train Accuracy: 0.9767
Epoch 12/15, Val Loss: 0.1511, Val Accuracy: 0.9473


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:12<00:00,  1.77s/it]


Epoch 13/15, Train Loss: 0.0711, Train Accuracy: 0.9781
Epoch 13/15, Val Loss: 0.1542, Val Accuracy: 0.9497


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:13<00:00,  1.77s/it]


Epoch 14/15, Train Loss: 0.0693, Train Accuracy: 0.9781
Epoch 14/15, Val Loss: 0.1583, Val Accuracy: 0.9470


ViT training for CIFAR-10 classification: 100%|██████████| 211/211 [06:12<00:00,  1.77s/it]


Epoch 15/15, Train Loss: 0.0647, Train Accuracy: 0.9803
Epoch 15/15, Val Loss: 0.1575, Val Accuracy: 0.9497


# **Testing new ViT on CIFAR-10 data**

In [13]:
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=pretrained_vit_transforms)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
print(f"Test dataset size is: {len(test_dataset)}")

pretrained_vit.eval()

test_loss = 0
test_correct = 0
total_test = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing new ViT"):
        images, labels = images.to(device), labels.to(device)

        outputs = pretrained_vit(images)

        loss = loss_fn(outputs, labels)
        test_loss += loss.item()

        preds = torch.argmax(outputs, 1)

        test_correct += (preds == labels).sum().item()
        total_test += labels.size(0)

test_accuracy = test_correct / total_test
print(f"Test Loss: {test_loss / len(test_loader):.4f}, Test Accuracy: {test_accuracy:.4f}")

Files already downloaded and verified
Test dataset size is: 10000


Testing new ViT: 100%|██████████| 79/79 [01:09<00:00,  1.13it/s]

Test Loss: 0.1240, Test Accuracy: 0.9643



