# Vision Transformer (ViT) on CIFAR-10

This notebook demonstrates how to implement and train a Vision Transformer (ViT) from scratch using PyTorch on the CIFAR-10 dataset.

- **Author:** [Your Name]
- **Date:** April 2025
- **Dataset:** CIFAR-10 (60,000 32x32 color images in 10 classes)

---

## 1. Import Required Libraries
We start by importing PyTorch, torchvision, and other necessary libraries.

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim

## 2. Vision Transformer Components
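Below are the core building blocks of the Vision Transformer:
- `PatchEmbedding`: splits each image into non-overlapping `patch_size × patch_size` patches with a strided convolution and projects each patch to an `embed_dim` vector.
- `ViTEmbeddings`: prepends a learnable class token, adds learnable positional embeddings, and applies dropout.
- `TransformerEncoderBlock`: a pre-norm block with multi-head self-attention and an MLP, each wrapped in a residual connection.
- `TransformerEncoder`: stacks `num_layers` encoder blocks.
- `VisionTransformer`: embeddings → encoder → linear classification head applied to the class-token output.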

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size`` makes
    every output position see exactly one patch. The resulting feature map is
    then flattened into a sequence of patch tokens.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input image.
        embed_dim: Size of each patch embedding vector.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        # Floor division: if img_size is not a multiple of patch_size, the
        # leftover border is not counted (the strided conv skips it as well).
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed a batch of images ``x`` (B, C, H, W) as patch tokens (B, N, embed_dim)."""
        feature_map = self.proj(x)  # (B, embed_dim, H // p, W // p)
        # Merge the two spatial axes (row-major) into one token axis, then put it before channels.
        tokens = feature_map.flatten(start_dim=2)
        return tokens.permute(0, 2, 1)

In [None]:
class ViTEmbeddings(nn.Module):
    """Turn an image batch into the token sequence consumed by the encoder.

    The patch tokens produced by ``PatchEmbedding`` are prefixed with a
    learnable class token. A learnable positional embedding is added, and
    dropout is applied.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Token embedding dimension.
        dropout: Dropout probability applied after the positional embeddings
            are added. Defaults to 0.1, the value previously hard-coded here.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # Learnable class token; its final representation is read by the classifier head.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # One positional embedding per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        # Small random init (std 0.02, the scale commonly used for ViT positional
        # embeddings) so each position starts distinct rather than all-zero.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return tokens of shape (B, num_tokens, embed_dim), with the class token first."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # Slice so that inputs yielding fewer tokens than img_size implies still work.
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.dropout(x)
        return x

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Each sublayer (self-attention, then feed-forward MLP) receives a
    LayerNorm-ed input. Its output passes through dropout and is added back
    to the residual stream.

    Args:
        embed_dim: Token embedding dimension (nn.MultiheadAttention requires
            it to be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sublayer.
        dropout: Dropout rate used inside attention and on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim)
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, embed_dim) (batch_first=True)."""
        x_norm = self.norm1(x)

        # need_weights=False: the attention map is discarded, so skip computing and
        # averaging it. This also lets PyTorch use its fused attention kernels.
        attention_output, _ = self.attention(x_norm, x_norm, x_norm, need_weights=False)
        x = x + self.dropout(attention_output)

        ffn_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ffn_output)

        return x

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` identically configured encoder blocks.

    The blocks are applied one after another in construction order. Each
    block's output becomes the next block's input.

    Args:
        num_layers: Number of TransformerEncoderBlock instances to stack.
        embed_dim: Token embedding dimension for every block.
        num_heads: Attention heads per block.
        mlp_dim: Feed-forward hidden width per block.
        dropout: Dropout rate forwarded to every block.
    """

    def __init__(self, num_layers, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()

        blocks = []
        for _ in range(num_layers):
            block = TransformerEncoderBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                dropout=dropout,
            )
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Pipeline:
      1. ViTEmbeddings: patch tokens, class token, and positional embeddings.
      2. TransformerEncoder: ``num_layers`` pre-norm blocks.
      3. A final LayerNorm on the class-token output.
      4. A linear classification head.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input channels.
        embed_dim: Token embedding dimension.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of each block's feed-forward sublayer.
        num_layers: Number of encoder blocks.
        num_classes: Number of output logits.
        dropout: Dropout rate passed to the encoder blocks. The embedding layer
            is constructed without it and keeps its own default.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()
        self.embeddings = ViTEmbeddings(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        self.encoder = TransformerEncoder(
            num_layers=num_layers,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout=dropout
        )

        # The encoder blocks are pre-norm, so their output is the raw
        # (un-normalized) residual stream. Normalize it once before the head.
        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.embeddings(x)
        x = self.encoder(x)
        # Only the class token (position 0) feeds the head. LayerNorm is
        # per-token, so normalizing just that token is equivalent and cheaper.
        cls_token_output = self.norm(x[:, 0])
        logits = self.cls_head(cls_token_output)
        return logits

## 3. Data Preparation and Training Utilities
We define the data transforms, training, and evaluation functions.
- **Transform:** Resizes images to 32×32 (effectively a no-op for native CIFAR-10 images), converts them to tensors, and normalizes each channel to [-1, 1].
- **train:** Trains the model for a given number of epochs.
- **evaluate:** Evaluates the model on the test set.

In [None]:
# Preprocessing applied to both the train and test splits:
#   1. Resize to the 32x32 input size the model is built for. This is
#      effectively a no-op for native CIFAR-10 images.
#   2. ToTensor converts to a float tensor in [0, 1].
#   3. Normalize with mean 0.5 / std 0.5 per channel maps [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def train(model, train_loader, criterion, optimizer, device, epochs=10):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    After each epoch, the mean per-batch loss is printed.

    Args:
        model: Module to optimize. It is switched to training mode.
        train_loader: Iterable of ``(inputs, labels)`` batches. It does not
            need to support ``len()``.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to. The model is expected to
            already live there.
        epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one mean per-batch loss per epoch. An epoch that yields
        no batches records NaN instead of raising ZeroDivisionError.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves rather than calling len(train_loader). This
        # works for iterables without __len__ and is safe for an empty loader.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")
    return history

def evaluate(model, test_loader, device):
    """Compute, print and return the top-1 accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module to evaluate. It is switched to eval mode and left there.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        Accuracy in percent (0-100), or NaN if the loader yields no samples.
        Returning NaN avoids a ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The index of the highest logit along the class axis is the prediction.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else float("nan")
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

## 4. Training and Evaluation
We load the CIFAR-10 dataset, initialize the Vision Transformer, and train it for one epoch. The model is then evaluated on the test set.

**Note:** For demonstration, training runs for a single epoch (`epochs=1`). Increase `epochs` for better accuracy.

In [None]:
if __name__ == '__main__':
    # Download CIFAR-10 (if not already under ./data) with the shared preprocessing.
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    BATCH_SIZE = 32
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Pick the fastest available backend: CUDA GPU, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(device)

    img_size = 32
    patch_size = 4

    model = VisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=3,
        embed_dim=128,
        num_heads=4,
        mlp_dim=256,
        num_layers=6,
        num_classes=10,
        dropout=0.1
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    EPOCHS = 1  # demonstration only; increase for better accuracy
    train(model, train_loader, criterion, optimizer, device, epochs=EPOCHS)
    evaluate(model, test_loader, device)

    # NOTE: this pickles the whole module object. Reloading it requires the
    # class definitions above to be importable. On recent PyTorch releases it
    # also requires torch.load(..., weights_only=False). Saving
    # model.state_dict() is the more portable alternative.
    torch.save(model, "vit_cifar10.pth")

## 5. Results and Model Saving
After training, the test accuracy is printed and the model is saved to `vit_cifar10.pth`.

You can now use this notebook as a template for your own ViT experiments!