<a href="https://colab.research.google.com/github/ajsal-ali/vit-from-scratch/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer (ViT) from Scratch

In this notebook, we implement a Vision Transformer (ViT) in PyTorch. ViT applies the transformer architecture, originally designed for NLP, to image classification by treating image patches as tokens. We build patch embeddings, multi-head self-attention, and transformer blocks, then assemble the complete ViT and train it to classify images.


In [None]:
!pip install torch
!pip install torchvision

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## ViT Architecture Components

Below is the full implementation of a Vision Transformer (ViT) broken down into modular PyTorch classes:

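- **PatchEmbedding**: Converts an image into a sequence of flattened, non-overlapping patches and projects each patch into the embedding space. It uses `unfold` to extract the patches and a linear layer to embed them. Depending on the patch size, the embedding can be smaller or larger than the raw patch (a 4×4 RGB patch has 48 values; an 8×8 patch has 192).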

- **Head & MultiHeadAttention**: Implements scaled dot-product attention in each head, concatenates the head outputs, and applies a final linear projection. Each head computes attention scores between all pairs of patches and uses them to aggregate patch information.

- **FeedForward** (spelled `FeedFoward` in the code): A two-layer MLP with a 4× hidden expansion and ReLU activation, applied to each token independently after attention.

- **Block**: Represents one transformer layer. It contains multi-head self-attention and a feedforward network, each preceded by layer normalization (pre-norm) and wrapped in a residual connection.

- **VisionTransformer**: The main ViT class. It includes a learnable class token, positional embeddings, and a sequence of transformer blocks. After processing, it extracts the class token's output for final classification via a linear layer.

This architecture works with square images whose side length is divisible by the patch size, such as 32×32, and can be trained on datasets like CIFAR-10.
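
To make the patch-extraction step concrete, here is a minimal sketch of the same `unfold`, `permute`, and flatten sequence on a dummy batch. The tensor contents and the patch size of 4 are assumptions chosen for illustration; only the shapes matter.

```python
import torch

# Hypothetical dummy batch: 2 RGB images of size 32x32 (CIFAR-10 sized).
images = torch.arange(2 * 3 * 32 * 32, dtype=torch.float32).reshape(2, 3, 32, 32)
patch_size = 4  # assumed value for this illustration

# Slide a non-overlapping window over height (dim 2), then width (dim 3).
patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# patches: (B, C, H/p, W/p, p, p)

# Put the patch-grid axes first, then flatten each patch into one vector.
tokens = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, -1, 3 * patch_size * patch_size)

print(patches.shape)  # torch.Size([2, 3, 8, 8, 4, 4])
print(tokens.shape)   # torch.Size([2, 64, 48])

# The first token is exactly the top-left 4x4 patch of the first image.
assert torch.equal(tokens[0, 0], images[0, :, :4, :4].reshape(-1))
```

With a 32×32 image and 4×4 patches this gives 64 tokens of 48 values each; the linear layer then maps each 48-value vector to the embedding dimension.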


In [4]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one."""

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.patch_dim = in_channels * patch_size * patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Linear(self.patch_dim, self.embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p)
        x = x.unfold(2, p, p).unfold(3, p, p)
        # Move the patch-grid axes ahead of the channel axis: (B, H/p, W/p, C, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5)
        # Flatten each patch into a vector: (B, num_patches, patch_dim)
        x = x.contiguous().view(B, -1, self.patch_dim)
        # Project to the embedding dimension: (B, num_patches, embed_dim)
        x = self.proj(x)
        return x



class Head(nn.Module):
    """One head of self-attention (no mask: every patch attends to every patch)."""

    def __init__(self, head_size, n_embd, block_size, dropout=0.0):
        # block_size is accepted but not used by this class.
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        head_size = k.size(-1)
        wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and project the concatenated result."""

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout=0.0):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    def __init__(self, n_embd, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):

    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size , n_embd, block_size, dropout)
        self.ffwd = FeedFoward(n_embd,dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, image_size=32, patch_size=8, num_classes=10,
                 dim=128, depth=4, heads=4, mlp_dim=256, channels=3, dropout=0.1):
        # Note: mlp_dim is accepted but not used; FeedFoward always uses a 4 * dim hidden layer.
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = PatchEmbedding(image_size, patch_size, channels, dim)
        self.class_embedding = nn.Parameter(torch.randn(1, 1, dim))
        self.possition_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(dropout)  # defined but not applied in forward()
        self.blocks = nn.Sequential(*[Block(dim, heads, num_patches + 1, dropout) for _ in range(depth)])
        self.to_class_token = nn.Identity()  # not used in forward()
        self.ln = nn.LayerNorm(dim)
        self.linear = nn.Linear(dim, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_embedding(x)                         # (B, num_patches, dim)
        cls_token = self.class_embedding.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_token, x), dim=1)                # (B, num_patches + 1, dim)
        x = x + self.possition_embedding                    # broadcast over the batch
        x = self.blocks(x)
        x = self.ln(x[:, 0])                                # class-token output: (B, dim)
        x = self.linear(x)                                  # logits: (B, num_classes)
        return x
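
The attention classes above can be hard to follow through the shape bookkeeping, so here is a minimal functional sketch of the same computation: scaled dot-product attention per head, followed by concatenation across heads. The random projection weights and the sizes (65 tokens, 128 embedding dimensions, 4 heads) are assumptions for illustration; the sketch omits dropout and the output projection.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, n_embd, n_head = 2, 65, 128, 4   # 65 tokens = 64 patches + 1 class token
head_size = n_embd // n_head            # 32 features per head

x = torch.randn(B, T, n_embd)

def one_head(x, w_q, w_k, w_v):
    """Scaled dot-product self-attention for a single head (no dropout)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v                    # each (B, T, head_size)
    scores = q @ k.transpose(-2, -1) * head_size ** -0.5   # (B, T, T)
    weights = F.softmax(scores, dim=-1)                    # each row sums to 1
    return weights @ v, weights                            # (B, T, head_size), (B, T, T)

# Hypothetical random projection weights, one (q, k, v) triple per head.
params = [[torch.randn(n_embd, head_size) * n_embd ** -0.5 for _ in range(3)]
          for _ in range(n_head)]

outs, weights = zip(*(one_head(x, *p) for p in params))
multi_head = torch.cat(outs, dim=-1)  # concatenating the heads restores n_embd

print(weights[0].shape)  # torch.Size([2, 65, 65])
print(multi_head.shape)  # torch.Size([2, 65, 128])
```

Because `head_size = n_embd // n_head`, concatenating the heads yields a tensor with the same last dimension as the input, which is what lets the residual connection in `Block` add the two directly.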



## Data Preparation: CIFAR-10

We load the CIFAR-10 dataset using `torchvision`. Basic transformations are applied to convert the images to tensors. Although no data augmentation or normalization is used here, these can be easily added for improved performance.

- **Train/Test Split**: CIFAR-10 is split into training and test sets.
- **Transformations**: Currently only `ToTensor()` is applied, which scales pixel values to [0, 1].
- **Data Loaders**: Batches are created for training and testing with a batch size of 1024.


In [6]:
import torchvision
import torchvision.transforms as T
# Define transformations for the training and test sets
train_transform = T.Compose([
    T.ToTensor(),
    # Normally, you'd add normalization and perhaps random flips/crops here for augmentation
])
test_transform = T.Compose([
    T.ToTensor(),
    # Corresponding normalization (using same mean/std as train if applied)
])
# Download and load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)
# Create data loaders for batching
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

100%|██████████| 170M/170M [00:05<00:00, 31.2MB/s]


Train batches: 49, Test batches: 10


## Training the Vision Transformer

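We initialize the ViT model with the hyperparameters shown below and train it on CIFAR-10 using the Adam optimizer and cross-entropy loss, on GPU if available. The cell sets `epochs = 200`, but the log printed beneath it reads `/20`, so that output was captured from a 20-epoch run. Its first loss (about 0.30) is also far below the roughly 2.30 expected from a freshly initialized 10-class model, which suggests the logged run continued from already-trained weights.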
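
A useful reference point when reading the loss values: a 10-class model that assigns equal probability to every class has a cross-entropy of ln(10) ≈ 2.30, so a freshly initialized network should start near that value. The sketch below shows this with hypothetical all-zero logits.

```python
import math
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical batch: all-zero logits give a uniform softmax over 10 classes.
logits = torch.zeros(8, 10)
labels = torch.randint(0, 10, (8,))

baseline = criterion(logits, labels).item()
print(f"{baseline:.4f}")  # 2.3026, i.e. ln(10)
```

Losses well below this baseline mean the model has already learned something about the training data.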


In [37]:
torch.backends.cudnn.benchmark = True  # let cuDNN pick fast kernels for fixed input shapes
model = VisionTransformer(image_size=32, patch_size=4, num_classes=10, dim=128, depth=4, heads=4, mlp_dim=256)
# Move model to device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
epochs = 200
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # compute loss
        loss.backward()                    # backpropagate gradients
        optimizer.step()                   # update parameters

        running_loss += loss.item()
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

Epoch [1/20], Loss: 0.2969
Epoch [2/20], Loss: 0.2848
Epoch [3/20], Loss: 0.2836
Epoch [4/20], Loss: 0.2870
Epoch [5/20], Loss: 0.2821
Epoch [6/20], Loss: 0.2771
Epoch [7/20], Loss: 0.2789
Epoch [8/20], Loss: 0.2752
Epoch [9/20], Loss: 0.2772
Epoch [10/20], Loss: 0.2727
Epoch [11/20], Loss: 0.2730
Epoch [12/20], Loss: 0.2705
Epoch [13/20], Loss: 0.2713
Epoch [14/20], Loss: 0.2719
Epoch [15/20], Loss: 0.2667
Epoch [16/20], Loss: 0.2670
Epoch [17/20], Loss: 0.2666
Epoch [18/20], Loss: 0.2597
Epoch [19/20], Loss: 0.2561
Epoch [20/20], Loss: 0.2627


## Model Evaluation

The trained Vision Transformer is evaluated on the test set without gradient calculations. Accuracy is computed by comparing predicted and true labels.
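
The accuracy computation can be checked by hand on a tiny example. The logits and labels below are made up for illustration; `argmax(dim=1)` returns the same indices as the second value of `torch.max(outputs, dim=1)`.

```python
import torch

# Hypothetical logits for 4 samples and 3 classes, with their true labels.
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.9, 0.8, 0.7],
                        [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 1, 2, 2])

predicted = outputs.argmax(dim=1)               # tensor([0, 1, 0, 2])
correct = (predicted == labels).sum().item()    # 3 of 4 predictions match
accuracy = 100 * correct / labels.size(0)
print(f"Test Accuracy: {accuracy:.2f}%")        # Test Accuracy: 75.00%
```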


In [38]:
model.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradient needed for eval
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 62.90%
