# Vision Transformer (ViT) Architecture Overview

The Vision Transformer (ViT) architecture comprises four key stages:

* Image Patching and Embedding
* Positional Encoding
* Transformer Encoder
* Classification Head (MLP Head)

![ViT Diagram](https://media.geeksforgeeks.org/wp-content/uploads/20250108160202257232/Vision-Transformer-Architecture_.webp)


## Table of Contents
- [1. Image Patching and Embedding](#1-image-patching-and-embedding)
- [2. Positional Encoding](#2-positional-encoding)
- [3. Transformer Encoder Layers](#3-transformer-encoder-layers)
  - [Multi-Head Self-Attention (MSA)](#1-multi-head-self-attention-msa)
  - [Feed-Forward Network (FFN)](#2-feed-forward-network-ffn)
  - [Stacking Encoder Layers](#stacking-encoder-layers)
- [4. Classification Token (CLS Token)](#4-classification-token-cls-token)
- [5. MLP Head (Classification Head)](#5-mlp-head-classification-head)
- [Vision Transformer Architecture Summary](#vision-transformer-architecture-summary)





## 1. Image Patching and Embedding

The first and most critical step in the ViT pipeline is to convert the image into a sequence of patches, similar to the tokens in an NLP model.

* **Patch Splitting:** The input image, of size H×W×C (height, width, and channels), is divided into fixed-size patches. For example, a 224×224 image can be split into non-overlapping 16×16 patches, giving (224/16) × (224/16) = 14 × 14 = 196 patches.

* **Patch Flattening:** Each patch is flattened into a 1D vector. A patch of size P×P×C (e.g., 16×16×3) becomes a vector of length P²·C (here 16·16·3 = 768), so the image yields 196 such vectors.

* **Patch Embedding:** Each flattened patch is mapped to a D-dimensional embedding through a learnable linear projection, which lets the model learn a feature representation for each patch. The result is a sequence of N = (H/P) × (W/P) patch embeddings, one per image region; for a 224×224 image with 16×16 patches, N = 196.
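
The splitting and flattening steps above can be checked with a minimal sketch in plain Python (no PyTorch, so every step is visible). The helper name `patchify` and the nested-list layout `image[row][col][channel]` are assumptions made for illustration and are not part of the model code in this notebook.

```python
def patchify(image, patch_size):
    """Split an H x W x C nested-list image into flattened, non-overlapping patches.

    Returns N = (H / P) * (W / P) vectors, each of length P * P * C.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")

    patches = []
    for top in range(0, height, patch_size):          # patch rows, top to bottom
        for left in range(0, width, patch_size):      # patch columns, left to right
            flat = []
            for r in range(top, top + patch_size):
                for c in range(left, left + patch_size):
                    flat.extend(image[r][c])          # append all C channel values
            patches.append(flat)
    return patches


# Tiny 4x4 single-channel image whose pixel value is its row-major index.
image = [[[r * 4 + c] for c in range(4)] for r in range(4)]
patches = patchify(image, patch_size=2)
print(len(patches), len(patches[0]))   # 4 patches, each of length 2 * 2 * 1 = 4
print(patches[0])                      # top-left patch: [0, 1, 4, 5]

# The same arithmetic for the 224x224x3 example from the text.
num_patches = (224 // 16) * (224 // 16)
patch_dim = 16 * 16 * 3
print(num_patches, patch_dim)          # 196 768
```

In a real model the learnable projection to dimension D would then be applied to each of these flattened vectors.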

In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> conv -> (B, embed_dim, H/P, W/P)
        # -> flatten(2) -> (B, embed_dim, num_patches) -> transpose -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
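
The `nn.Conv2d` layer in this cell uses `kernel_size = stride = P`, so each kernel application covers exactly one patch and computes the same dot products as flattening that patch and applying a linear layer. As a worked check of the output grid, using the standard convolution output-size formula with no padding (the symbols H_out and W_out are introduced here for illustration):

$$
H_{\text{out}} = \left\lfloor \frac{H - P}{P} \right\rfloor + 1 = \frac{224 - 16}{16} + 1 = 14,
\qquad
N = H_{\text{out}} \cdot W_{\text{out}} = 14 \cdot 14 = 196
$$

With the default arguments, the tensor therefore goes from `(B, 3, 224, 224)` to `(B, 768, 14, 14)` after the convolution, `(B, 768, 196)` after `flatten(2)`, and `(B, 196, 768)` after `transpose(1, 2)`.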


## 2. Positional Encoding

Transformers do not inherently capture the order of their input. Self-attention treats the patches as an unordered set of tokens, so without extra information the model cannot tell where each patch came from in the image. **Positional encodings** are introduced to retain the spatial structure of the original image.

* **Positional Embedding:** Positional encodings are added to each patch embedding to encode information about the location of patches within the image. These embeddings help the model understand the spatial relationships between patches, similar to how transformers in NLP encode the positions of words in a sentence.

* **Learned vs. Fixed Positional Encoding:** In ViTs, positional encodings can either be learned during training or predefined (fixed). Most implementations of Vision Transformers use learnable positional encodings.
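
The learnable variant is what the `PositionalEncoding` class in this notebook implements. For contrast, here is a minimal sketch of one common fixed scheme, the 1D sinusoidal encoding from the original Transformer. The text above does not say which fixed encoding is meant, so the choice of this scheme and the helper name `sinusoidal_positions` are assumptions for illustration.

```python
import math


def sinusoidal_positions(seq_len, embed_dim):
    """Fixed 1D sinusoidal encodings: sin on even indices, cos on odd indices."""
    if embed_dim % 2:
        raise ValueError("embed_dim must be even")
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(0, embed_dim, 2):
            angle = pos / (10000 ** (i / embed_dim))
            row.append(math.sin(angle))   # dimension i
            row.append(math.cos(angle))   # dimension i + 1
        table.append(row)
    return table


pe = sinusoidal_positions(seq_len=197, embed_dim=8)  # 196 patches + 1 [CLS] token
print(len(pe), len(pe[0]))  # 197 8
print(pe[0])                # position 0: sin(0) = 0.0 and cos(0) = 1.0, alternating
```

Like the learned embeddings, such a table would be added element-wise to the token sequence; the difference is that it has no trainable parameters.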

In [2]:
# Learnable positional embeddings, added to the token sequence
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, seq_len):
        super().__init__()
        # Shape (1, seq_len + 1, embed_dim): the leading 1 broadcasts over the batch,
        # and the extra position is for the [CLS] token.
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

    def forward(self, x):
        return x + self.pos_embed

## 3. Transformer Encoder Layers

Once the patches are embedded and augmented with positional information, they are passed through a stack of transformer encoder layers. These layers consist of two primary components: Multi-Head Self-Attention (MSA) and a Feed-Forward Neural Network (FFN).

### 1. Multi-Head Self-Attention (MSA)

* **Self-Attention:** The self-attention mechanism allows each patch to attend to every patch in the sequence, itself included, so the transformer can model long-range dependencies between different parts of the image. The output for each patch is a weighted sum of the value vectors of all patches, where the weights (the attention scores) come from the similarity between that patch's query and each patch's key.

  Attention(Q, K, V) = softmax(QK^T / √d_k) V

  Here Q (query), K (key), and V (value) are learned linear projections of the input patch embeddings, and d_k is the dimension of the key vectors.
  - The dot product between queries and keys, scaled by √d_k, gives the attention scores, and softmax normalizes them so that each row sums to 1.
  - The output is the sum of the values weighted by these scores.

* **Multi-Head Attention:** The attention mechanism is computed in parallel across multiple attention heads, allowing the model to focus on different parts of the image simultaneously.
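
The attention formula can be traced by hand on a toy input. The following is a minimal single-head sketch in plain Python, not the implementation used by `nn.MultiheadAttention`; the function names and the toy vectors are assumptions for illustration, and the learned Q/K/V projections are left out.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, for one head."""
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Similarity of this query to every key (its own token included).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Output is the attention-weighted sum of the value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


# Three tokens with d_k = 2. In a real model Q, K, V come from learned projections;
# here the same toy vectors are reused for all three to keep the example small.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = attention(tokens, tokens, tokens)
for row in weights:
    print([round(w, 3) for w in row], "sum =", round(sum(row), 3))
```

Multi-head attention runs several such computations in parallel on lower-dimensional projections (for example, 768 / 8 = 96 dimensions per head with the defaults used later in this notebook) and concatenates the results.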

### 2. Feed-Forward Network (FFN)

After self-attention, the patches are passed through a **feed-forward network (FFN)**. The FFN consists of two fully connected layers with a non-linear activation function (typically GELU) in between.
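
For reference, GELU can be written in terms of the Gaussian error function. The sketch below uses the exact form and prints it next to ReLU for a few inputs; the function names are illustrative, and deep learning libraries may use an approximation instead.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x):
    """ReLU for comparison: zero for negative inputs, identity otherwise."""
    return max(0.0, x)


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x = {x:+.1f}  gelu = {gelu(x):+.4f}  relu = {relu(x):+.4f}")
```

Unlike ReLU, GELU is smooth and takes small negative values for negative inputs.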

Each transformer encoder layer includes residual (skip) connections and layer normalization to stabilize training and improve convergence. These techniques ensure that the deeper layers do not lose important information from the earlier layers.
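
A minimal sketch of these two ingredients on a single token vector, in plain Python. It mirrors the pre-norm ordering (normalize, apply the sub-layer, then add the input back) used by the encoder block in this notebook. The learned scale and shift of layer normalization are omitted, and the helper names are assumptions for illustration.

```python
import math


def layer_norm(vector, eps=1e-5):
    """Normalize one token vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(vector) / len(vector)
    variance = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(variance + eps) for v in vector]


def residual(vector, sublayer):
    """Pre-norm residual step: x + sublayer(layer_norm(x))."""
    update = sublayer(layer_norm(vector))
    return [x + u for x, u in zip(vector, update)]


token = [2.0, 4.0, 6.0, 8.0]
normed = layer_norm(token)
print([round(v, 4) for v in normed])

# With a sub-layer that outputs zeros, the skip connection passes the input through unchanged.
print(residual(token, lambda v: [0.0] * len(v)))  # [2.0, 4.0, 6.0, 8.0]
```

The second call shows why skip connections help in deep stacks: even if a sub-layer contributes nothing, the input still reaches the next layer intact.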

### Stacking Encoder Layers

Multiple transformer encoder layers (e.g., 12 or 24) are stacked on top of each other. Each layer refines the patch embeddings, allowing the model to build more complex and abstract representations of the image.
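
To get a feel for how depth affects model size, the sketch below counts the learnable parameters in one encoder block and multiplies by the number of blocks. It assumes biases on every linear layer and two layer norms per block; the helper name `encoder_block_params` is hypothetical.

```python
def encoder_block_params(embed_dim, mlp_dim):
    """Count learnable parameters in one encoder block (weights and biases)."""
    attention = 4 * embed_dim * embed_dim + 4 * embed_dim        # Q, K, V and output projections
    mlp = (embed_dim * mlp_dim + mlp_dim) + (mlp_dim * embed_dim + embed_dim)
    layer_norms = 2 * (2 * embed_dim)                            # two LayerNorms, scale + shift each
    return attention + mlp + layer_norms


# Defaults of the VisionTransformer class in this notebook: D = 768, mlp_dim = 1024, depth = 6.
per_block = encoder_block_params(embed_dim=768, mlp_dim=1024)
print(per_block, 6 * per_block)          # 3940096 23640576

# A ViT-Base-sized encoder (D = 768, mlp_dim = 3072, depth = 12) for comparison.
base_block = encoder_block_params(embed_dim=768, mlp_dim=3072)
print(base_block, 12 * base_block)       # 7087872 85054464
```

These counts cover the encoder stack only; the patch embedding, positional embeddings, CLS token, and classification head add a comparatively small number of parameters on top.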

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
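        # batch_first=True because the rest of the model uses (batch, tokens, embed_dim) tensors;
        # the default layout for nn.MultiheadAttention is (tokens, batch, embed_dim).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # Self-attention: x serves as query, key, and value.
        # The module returns (output, attention_weights); keep only the output.
        return self.attn(x, x, x)[0]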

## 4. Classification Token (CLS Token)

In Vision Transformers, a special classification token (CLS token) is introduced at the beginning of the input sequence. This token serves a critical role: it gathers information from all the patches throughout the transformer layers.

The CLS token learns to represent the entire image by attending to the different patches through the self-attention mechanism. At the output of the transformer layers, the CLS token is extracted and passed to a classifier for the final prediction.

## 5. MLP Head (Classification Head)

After the transformer encoders process the sequence of patches and the CLS token, the output corresponding to the CLS token is used for classification.

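The output of the CLS token is fed into an MLP, typically consisting of one or two fully connected layers, which produces one score (logit) per class. For classification, a softmax converts these logits into class probabilities. In PyTorch the softmax is usually folded into the loss (`nn.CrossEntropyLoss` expects raw logits), so the model itself returns logits.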
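
The relationship between the head's outputs, softmax, and the classification loss can be sketched in plain Python. This is an illustration of the math only, not PyTorch's implementation, and the function names are assumptions.

```python
import math


def softmax(logits):
    """Convert raw class scores into probabilities that sum to 1."""
    peak = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class, computed from raw logits."""
    return -math.log(softmax(logits)[target])


# Ten classes with identical logits: the model is guessing uniformly.
uniform_logits = [0.0] * 10
print(round(cross_entropy(uniform_logits, target=3), 4))   # 2.3026, i.e. ln(10)

# A confident, correct prediction gives a loss close to zero.
confident_logits = [0.0] * 10
confident_logits[3] = 10.0
print(cross_entropy(confident_logits, target=3) < 0.001)   # True
```

The value ln(10) ≈ 2.3026 is a useful reference point for a 10-class dataset such as CIFAR-10: a loss at that level means the predictions are no better than chance.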

In [4]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),  # GELU, as in the FFN description
            nn.Linear(mlp_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # pre-norm self-attention sub-layer with residual connection
        x = x + self.mlp(self.norm2(x))   # pre-norm feed-forward sub-layer with residual connection
        return x
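
In equation form, the two residual updates in `forward` can be written as follows, where LN is layer normalization, MSA is multi-head self-attention, MLP is the feed-forward network, and z_{ℓ-1} is the input to block ℓ (the symbols are chosen here for illustration):

$$
z'_{\ell} = z_{\ell-1} + \mathrm{MSA}\big(\mathrm{LN}(z_{\ell-1})\big),
\qquad
z_{\ell} = z'_{\ell} + \mathrm{MLP}\big(\mathrm{LN}(z'_{\ell})\big)
$$

Both updates keep the shape `(B, N + 1, D)`, which is what allows identical blocks to be stacked.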

## Building the Vision Transformer Architecture

In [5]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=10, embed_dim=768, num_heads=8, depth=6, mlp_dim=1024):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, (img_size // patch_size) ** 2)
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim) for _ in range(depth)
        ])
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embedding(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x)
        return self.mlp_head(x[:, 0])
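
To follow the tensor shapes through `forward` without running the model, the sketch below reproduces the shape arithmetic in plain Python. The helper name `vit_shape_trace` is hypothetical, and the defaults are copied from the `VisionTransformer` constructor in this cell.

```python
def vit_shape_trace(batch, img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_classes=10):
    """Return (stage, shape) pairs for one forward pass, using integer arithmetic only."""
    grid = img_size // patch_size
    num_patches = grid * grid
    return [
        ("input image", (batch, in_channels, img_size, img_size)),
        ("patch embedding", (batch, num_patches, embed_dim)),
        ("prepend [CLS] token", (batch, num_patches + 1, embed_dim)),
        ("add positional embedding", (batch, num_patches + 1, embed_dim)),
        ("encoder blocks", (batch, num_patches + 1, embed_dim)),
        ("select [CLS] output", (batch, embed_dim)),
        ("classification head", (batch, num_classes)),
    ]


trace = vit_shape_trace(batch=32)
for stage, shape in trace:
    print(f"{stage:<26} {shape}")
```

The sequence length grows from 196 to 197 when the CLS token is prepended, which is why the positional embedding is created with `seq_len + 1` positions.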

## Training the Vision Transformer

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Device configuration (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Dataset and DataLoader
train_data = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=32,
    shuffle=True
)

# Model
model = VisionTransformer().to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(5):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        # Move tensors to device (GPU/CPU)
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute loss
        loss = criterion(outputs, labels)

        # Backpropagation
        loss.backward()

        # Update weights
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/5], Loss: {avg_loss:.4f}")


Using device: cuda



Epoch [1/5], Loss: 2.8157
Epoch [2/5], Loss: 2.3300
Epoch [3/5], Loss: 2.3229
Epoch [4/5], Loss: 2.3198
Epoch [5/5], Loss: 2.3186

In the run logged above, the loss levels off near 2.32, close to ln(10) ≈ 2.30, which is the cross-entropy of a uniform guess over CIFAR-10's 10 classes. In other words, this from-scratch run did not learn beyond chance level within 5 epochs.


## Vision Transformer Architecture Summary

To summarize, the Vision Transformer architecture involves the following key steps:

1. **Input Image Processing:** The input image is divided into patches, which are flattened and embedded using a linear projection.

2. **Positional Encoding:** Positional encodings are added to the patch embeddings to retain spatial information.

3. **Transformer Encoder:** The patch embeddings (along with the CLS token) are passed through multiple transformer encoder layers, which include multi-head self-attention and feed-forward networks.

4. **Classification Head:** The CLS token's output is extracted and fed into an MLP for final classification.

## Pretrained Model

In [11]:
## Import Libraries

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm
from tqdm import tqdm

In [12]:
## Load CIFAR-10 Dataset
# Define transforms for CIFAR-10 (32x32) to ViT input size (224x224)
transform_train = transforms.Compose([
    transforms.Resize(224),  # Resize to ViT input size
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])


In [13]:
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

In [14]:
# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)

In [15]:
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

In [16]:

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")


Training samples: 50000
Test samples: 10000


In [17]:
# Load pretrained ViT model from timm library
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Using device: {device}")
print(f"Model: {model.__class__.__name__}")


Using device: cuda
Model: VisionTransformer


In [18]:
## Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
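
The effect of this step-decay schedule can be worked out by hand: the learning rate is multiplied by `gamma` once every `step_size` scheduler steps. The following is a minimal sketch of that rule in plain Python, assuming `scheduler.step()` is called once per epoch; the helper name `step_lr` is hypothetical.

```python
def step_lr(initial_lr, epoch, step_size, gamma):
    """Learning rate after `epoch` scheduler steps under a step-decay schedule."""
    return initial_lr * gamma ** (epoch // step_size)


# Settings from the cell above: lr = 1e-4, step_size = 5, gamma = 0.1, over 10 epochs.
schedule = [step_lr(1e-4, epoch, step_size=5, gamma=0.1) for epoch in range(10)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

With these settings the first five epochs train at 1e-4 and the remaining epochs at 1e-5.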

In [19]:
## Training Function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

In [20]:
## Evaluation Function
def evaluate(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(test_loader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

In [None]:
## Train the Model
num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Update learning rate
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


Epoch 1/10


Training: 100%|██████████| 1563/1563 [27:41<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.77it/s]


Train Loss: 0.2382, Train Acc: 92.28%
Test Loss: 0.2292, Test Acc: 92.56%

Epoch 2/10


Training: 100%|██████████| 1563/1563 [27:37<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.77it/s]


Train Loss: 0.1591, Train Acc: 94.72%
Test Loss: 0.1849, Test Acc: 93.71%

Epoch 3/10


Training: 100%|██████████| 1563/1563 [27:36<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.1301, Train Acc: 95.57%
Test Loss: 0.1861, Test Acc: 94.03%

Epoch 4/10


Training: 100%|██████████| 1563/1563 [27:36<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.1094, Train Acc: 96.26%
Test Loss: 0.2021, Test Acc: 93.71%

Epoch 5/10


Training: 100%|██████████| 1563/1563 [27:38<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.0963, Train Acc: 96.67%
Test Loss: 0.1961, Test Acc: 93.81%

Epoch 6/10


Training: 100%|██████████| 1563/1563 [27:36<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.0180, Train Acc: 99.38%
Test Loss: 0.1118, Test Acc: 96.76%

Epoch 7/10


Training: 100%|██████████| 1563/1563 [27:37<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.0032, Train Acc: 99.91%
Test Loss: 0.1340, Test Acc: 96.64%

Epoch 8/10


Training: 100%|██████████| 1563/1563 [27:36<00:00,  1.06s/it]
Evaluating: 100%|██████████| 313/313 [01:53<00:00,  2.76it/s]


Train Loss: 0.0019, Train Acc: 99.96%
Test Loss: 0.1532, Test Acc: 96.77%

Epoch 9/10


Training:  44%|████▍     | 690/1563 [12:11<15:25,  1.06s/it]

The logged run ends partway through epoch 9 (690 of 1563 batches), so no results are recorded for epochs 9 and 10. The drop in training loss between epochs 5 and 6 (0.0963 to 0.0180) coincides with `StepLR` reducing the learning rate from 1e-4 to 1e-5 after the fifth epoch.

In [1]:
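# Save the fine-tuned weights.
# Run this in the same session as training: `model` must still be in memory,
# and `torch` must be imported (a restarted kernel raises NameError otherwise).
import torch

torch.save(model.state_dict(), 'vit_cifar10.pth')
print("Model saved successfully!")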

In [None]:
# Load the saved model
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('vit_cifar10.pth', map_location=device))  # load the fine-tuned weights onto the current device
model = model.to(device)
model.eval()

In [None]:
# Test on a single image
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
# Get a sample image
sample_image, sample_label = test_dataset[0]
sample_image = sample_image.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(sample_image)
    _, predicted = output.max(1)

print(f"Predicted: {classes[predicted.item()]}")
print(f"Actual: {classes[sample_label]}")