# Vision Transformer (ViT) from Scratch on CIFAR-10
This notebook implements a Vision Transformer (ViT) model from scratch in PyTorch and trains it on the CIFAR-10 dataset.

## Setup & Imports

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data.dataloader as dataloader

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

# Pick the GPU when CUDA is available, otherwise run everything on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


## Hyperparameters

In [None]:

batch_size = 128
num_epochs = 30
learning_rate = 1e-4
patch_size = 4          # side length (pixels) of each square patch / token
data_set_root = "./data"

# Seed torch's global RNG (weight init, DataLoader shuffling, dropout, AutoAugment sampling)
# so re-running the notebook reproduces the run (up to non-deterministic GPU kernels).
# The train/validation split below uses its own explicitly seeded generator.
seed = 42
_ = torch.manual_seed(seed)


## Dataset & Dataloaders
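We use CIFAR-10. Training images are augmented with AutoAugment (CIFAR-10 policy), and all images are normalized with the ImageNet channel mean/std. 10% of the official training set is held out as a validation split.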

In [None]:

# Training images get AutoAugment (CIFAR-10 policy); evaluation images are only normalized.
# NOTE: the mean/std below are the ImageNet channel statistics, not CIFAR-10's own.
transform = transforms.Compose([
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_data = datasets.CIFAR10(data_set_root, train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(data_set_root, train=False, download=True, transform=test_transform)
# A second view of the same training images WITHOUT augmentation, used for the validation split
train_data_eval = datasets.CIFAR10(data_set_root, train=True, download=True, transform=test_transform)

# Train/Validation split. Despite its name, `validation_split` is the fraction of the
# training set kept for TRAINING (0.9 -> 90% train / 10% validation).
validation_split = 0.9
n_train_examples = int(len(train_data) * validation_split)
n_valid_examples = len(train_data) - n_train_examples
train_data, valid_split = torch.utils.data.random_split(train_data, [n_train_examples, n_valid_examples],
                                                        generator=torch.Generator().manual_seed(42))
# random_split subsets share the parent dataset's (augmenting) transform. Re-index the same
# held-out images from the un-augmented copy so validation accuracy is measured on clean images.
valid_data = torch.utils.data.Subset(train_data_eval, valid_split.indices)

train_loader = dataloader.DataLoader(train_data, shuffle=True, batch_size=batch_size)
valid_loader = dataloader.DataLoader(valid_data, batch_size=batch_size)
test_loader  = dataloader.DataLoader(test_data, batch_size=batch_size)

classes = train_data.dataset.classes
print("Classes:", classes)


## Patch Extraction
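`extract_patches` splits each image into non-overlapping `patch_size × patch_size` patches and flattens every patch into a vector of length `channels × patch_size²`. The linear projection to the embedding dimension happens later, inside the model (`fc_in`). Together these steps turn the 2D image into a sequence of tokens, much like words are tokenized in NLP tasks.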

In [None]:

def extract_patches(image_tensor, patch_size=8):
    """Cut a batch of images into non-overlapping square patches, flattened per patch.

    Args:
        image_tensor: 4-D tensor of shape (batch, channels, height, width).
        patch_size: side length of each square patch. The stride equals the patch
            size, so patches do not overlap.

    Returns:
        Tensor of shape (batch, num_patches, channels * patch_size * patch_size).
    """
    batch, channels, _, _ = image_tensor.size()
    # F.unfold yields (batch, channels * patch_size**2, num_patches); move patches to dim 1
    columns = F.unfold(image_tensor, kernel_size=patch_size, stride=patch_size)
    return columns.transpose(1, 2).reshape(batch, -1, channels * patch_size * patch_size)

# Visualize the patches of the first test image, laid out on the grid they were cut from
test_images, test_labels = next(iter(test_loader))
n_channels = test_images.shape[1]
patches = extract_patches(test_images, patch_size=patch_size)  # (bs, num_patches, c * p * p)
# Undo the per-patch flattening so each patch is a (channels, p, p) image again
patches_square = patches.reshape(test_images.shape[0], -1, n_channels, patch_size, patch_size)

grid_size = test_images.shape[2] // patch_size
print("Sequence Length:", grid_size**2)

plt.figure(figsize=(5, 5))
out = torchvision.utils.make_grid(patches_square[0], grid_size, normalize=True, pad_value=0.5)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))


## Vision Transformer Architecture

## TransformerBlock

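Implements a single (pre-norm) block of the Vision Transformer. It has:

- Multi-Head Self-Attention (MHSA): lets every patch attend to every other patch, so the model can learn relationships between distant image regions.

- Layer Normalization & Residual Connections: each sub-layer receives a layer-normalized input, and its output is added back to that sub-layer's input. This stabilizes training and keeps gradients flowing through deep stacks.

- Feed-Forward Network (MLP): a two-layer MLP (with a LayerNorm and ELU activation between the layers) applied to each token independently after attention, adding non-linearity and representational power.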

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: x + MHSA(LN(x)), then x + MLP(LN(x)).

    Args:
        hidden_size: embedding dimension of every token (must be divisible by num_heads).
        num_heads: number of attention heads.
        mlp_ratio: integer width multiplier of the MLP hidden layer
            (default 2 -> 2 * hidden_size units).
        dropout: dropout probability applied inside the attention layer.
    """

    def __init__(self, hidden_size=128, num_heads=4, mlp_ratio=2, dropout=0.1):
        super(TransformerBlock, self).__init__()
        mlp_hidden = hidden_size * mlp_ratio
        self.norm1 = nn.LayerNorm(hidden_size)
        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=num_heads,
                                                    batch_first=True, dropout=dropout)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.LayerNorm(mlp_hidden),
            nn.ELU(),
            nn.Linear(mlp_hidden, hidden_size)
        )

    def forward(self, x):
        """Apply the block to x of shape (batch, seq_len, hidden_size); same shape out."""
        norm_x = self.norm1(x)
        # Only the attended values are used, so skip materializing the head-averaged
        # attention map (need_weights=False also lets PyTorch use its fused SDPA kernel)
        attn_out, _ = self.multihead_attn(norm_x, norm_x, norm_x, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

## ViT Class (Vision Transformer)

- Patch Embedding Layer: Projects flattened patches into a hidden embedding dimension.

- Positional Embeddings: Learnable vectors added to the patch embeddings so the model can tell where each patch came from; self-attention on its own is order-agnostic.

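- [CLS]-like Output Token (out_vec): A learnable token prepended to the patch sequence (it receives no positional embedding). Through attention it aggregates information from all patches, and its final embedding is used for classification.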

- Stack of Transformer Blocks: Deep self-attention layers capture complex dependencies between patches.

- Classification Head (fc_out): Maps the output embedding to class logits (10 for CIFAR-10).

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier built from TransformerBlock layers.

    Each image is cut into non-overlapping patch_size x patch_size patches
    (extract_patches), every patch is linearly projected to hidden_size (fc_in), and a
    learnable positional embedding is added. A learnable output token (out_vec) is
    prepended, the sequence runs through num_layers TransformerBlocks, and the output
    token's final embedding is layer-normalized and mapped to class logits (fc_out).

    Args:
        image_size: side length of the square input images.
        channels_in: number of input channels.
        patch_size: side length of each square patch.
        hidden_size: token embedding dimension.
        num_layers: number of TransformerBlock layers.
        num_heads: attention heads per block.
        num_classes: number of output logits.
    """

    def __init__(self, image_size, channels_in, patch_size, hidden_size, num_layers, num_heads=8, num_classes=10):
        super(ViT, self).__init__()
        self.patch_size = patch_size
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)
        ])
        # The blocks are pre-norm (LayerNorm on each sub-layer's input), so the residual
        # stream leaving the last block is un-normalized; normalize it before the head
        self.norm_out = nn.LayerNorm(hidden_size)
        self.fc_out = nn.Linear(hidden_size, num_classes)
        self.out_vec = nn.Parameter(torch.zeros(1, 1, hidden_size))
        seq_length = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_size).normal_(std=0.001))

    def forward(self, image):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        bs = image.shape[0]
        patch_seq = extract_patches(image, patch_size=self.patch_size)
        patch_emb = self.fc_in(patch_seq) + self.pos_embedding
        # Prepend the learnable output token (it gets no positional embedding)
        embs = torch.cat((self.out_vec.expand(bs, 1, -1), patch_emb), 1)
        for block in self.blocks:
            embs = block(embs)
        return self.fc_out(self.norm_out(embs[:, 0]))

# Create model: a 6-layer, 128-dim, 8-head ViT sized to match the sample test batch
model = ViT(
    image_size=test_images.size(2),
    channels_in=test_images.size(1),
    patch_size=patch_size,
    hidden_size=128,
    num_layers=6,
    num_heads=8,
)
model = model.to(device)

print(model)


## Training Utilities

In [None]:

# Adam optimizer. The cosine schedule decays the learning rate from `learning_rate` to 0
# over `num_epochs` scheduler steps (the training loop steps it once per epoch).
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)
# Cross-entropy on the raw logits produced by the model's final Linear layer
loss_fun = nn.CrossEntropyLoss()

def train(model, optimizer, loader, device, loss_fun, loss_logger):
    """Run one epoch of supervised training over `loader`.

    Every batch is moved to `device`, scored with `loss_fun`, and used for a single
    optimizer step. Each batch loss is appended (as a Python float) to `loss_logger`.

    Returns:
        The same (model, optimizer, loss_logger) objects that were passed in.
    """
    model.train()
    batches = tqdm(loader, leave=False, desc="Training")
    for inputs, targets in batches:
        logits = model(inputs.to(device))
        batch_loss = loss_fun(logits, targets.to(device))

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_logger.append(batch_loss.item())
    return model, optimizer, loss_logger

def evaluate(model, device, loader):
    """Return the top-1 accuracy of `model` over every sample in `loader`.

    Switches the model to eval mode (it is left there) and disables gradient tracking.
    The accuracy is the number of argmax predictions matching the labels divided by
    len(loader.dataset).
    """
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for inputs, targets in tqdm(loader, leave=False, desc="Evaluating"):
            predictions = model(inputs.to(device)).argmax(1)
            n_correct += (predictions == targets.to(device)).sum().item()
    return n_correct / len(loader.dataset)


## Training Loop

In [None]:

training_loss_logger = []
validation_acc_logger = []
training_acc_logger = []

valid_acc = 0
train_acc = 0

pbar = trange(0, num_epochs, leave=False, desc="Epoch")
for epoch in pbar:
    model, optimizer, training_loss_logger = train(model=model,
                                                   optimizer=optimizer,
                                                   loader=train_loader,
                                                   device=device,
                                                   loss_fun=loss_fun,
                                                   loss_logger=training_loss_logger)

    # NOTE: train_loader applies AutoAugment, so "training accuracy" is measured on augmented
    # images, and this costs a second full pass over the training split every epoch.
    train_acc = evaluate(model=model, device=device, loader=train_loader)
    valid_acc = evaluate(model=model, device=device, loader=valid_loader)

    validation_acc_logger.append(valid_acc)
    training_acc_logger.append(train_acc)
    lr_scheduler.step()

    # Update after evaluation so the bar shows this epoch's accuracies (including the last epoch)
    pbar.set_postfix_str('Train %.2f%% | Val %.2f%%' % (train_acc * 100, valid_acc * 100))

print("Training Complete")


## Results

In [None]:

# Training loss per optimizer step; x is in epochs (step index / batches per epoch), so the
# axis stays correct even if training was stopped before num_epochs
steps_per_epoch = len(train_loader)
loss_epochs = np.arange(1, len(training_loss_logger) + 1) / steps_per_epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(loss_epochs, training_loss_logger)
ax.set(title="ViT Training Loss", xlabel="Epoch", ylabel="Cross-entropy loss")
plt.show()

# Accuracies are recorded once at the end of each epoch, so epoch k is plotted at x = k
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np.arange(1, len(training_acc_logger) + 1), training_acc_logger,
        c="y", label="Training accuracy")
ax.plot(np.arange(1, len(validation_acc_logger) + 1), validation_acc_logger,
        c="k", label="Validation accuracy")
ax.set(title="ViT Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax.legend()
plt.show()


## Test Accuracy & Predictions

In [None]:

test_acc = evaluate(model=model, device=device, loader=test_loader)
print("The total test accuracy is: %.2f%%" % (test_acc * 100))

# Visualize predictions for the first few images of the sample test batch.
# evaluate() already leaves the model in eval mode; set it explicitly anyway so this
# inference (dropout off) does not silently depend on that side effect.
n_show = 8
model.eval()
with torch.no_grad():
    pred = model(test_images[:n_show].to(device)).argmax(-1).cpu()

plt.figure(figsize=(20, 10))
out = torchvision.utils.make_grid(test_images[:n_show], n_show, normalize=True)
plt.imshow(out.numpy().transpose((1, 2, 0)))
plt.axis("off")

# Show class names; .tolist() yields plain ints (list(ndarray) prints np.int64(...) under NumPy 2)
print("Predicted:", [classes[i] for i in pred.tolist()])
print("True:     ", [classes[i] for i in test_labels[:n_show].tolist()])


## Positional Embedding Visualization

In [None]:

# Cosine similarity between every pair of learned positional embeddings:
# similarity[i, j] compares patch position i with patch position j
pos_embs = model.pos_embedding.detach().cpu().squeeze(0)  # (num_patches, hidden_size)
similarity = F.cosine_similarity(pos_embs.unsqueeze(1), pos_embs.unsqueeze(0), dim=-1).numpy()

# Patch-grid side length, derived from the number of positions instead of a hard-coded 32px image
n_rows_cols = int(round(pos_embs.shape[0] ** 0.5))
assert n_rows_cols ** 2 == pos_embs.shape[0], "positional embeddings do not form a square patch grid"

# Subplot (row, col) shows how similar patch position (row, col) is to every other position.
# squeeze=False keeps `axes` 2-D even for a 1x1 grid.
fig, axes = plt.subplots(n_rows_cols, n_rows_cols, figsize=(5, 5), squeeze=False)
for idx, ax in enumerate(axes.flat):
    ax.imshow(similarity[idx].reshape(n_rows_cols, n_rows_cols))
    ax.axis("off")
plt.tight_layout()
plt.show()
