## Motivation
Traditional computer vision models such as CNNs process images **locally**: each convolution layer sees only a small region of its input, relying on local receptive fields and strong inductive biases (spatial locality and translation equivariance).
Although effective on small and medium-sized datasets, CNNs need deep stacks of layers before their receptive fields cover the whole image and capture global context.

### Why ViT?
Vision Transformers (ViT) take a different approach:
they remove convolutional operations entirely and model images as sequences of visual tokens, enabling direct global reasoning via self-attention.


## Import the dependencies

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

## Load CIFAR-10 dataset

In [2]:
# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)




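## Patch Embedding
An input image of shape `(H, W, C)` is partitioned into `N = HW / P²` non-overlapping patches of size
`P × P`. Each patch is flattened into a vector of length `P²·C` and linearly projected into a vector of dimension `D`.
In the code, a `Conv2d` whose kernel size and stride both equal `P` performs the split and the projection in one step. Note that PyTorch tensors use the channel-first layout `(B, C, H, W)`.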


In [4]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)            # (B, E, H', W')
        x = x.flatten(2)            # (B, E, N)
        x = x.transpose(1, 2)       # (B, N, E)
        return x
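
To make the link between the description and the code concrete, the sketch below checks that a strided `Conv2d` gives the same result as cutting the image into patches, flattening each one, and applying a single linear projection. It is only an illustration: the random input, the seed, and the use of `F.unfold` to extract patches are assumptions made for this demo and are not part of the notebook's model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

img_size, patch_size, in_channels, embed_dim = 32, 4, 3, 256
x = torch.randn(2, in_channels, img_size, img_size)    # dummy batch (B, C, H, W)

# Strided convolution, the same layer type PatchEmbedding uses
proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
tokens_conv = proj(x).flatten(2).transpose(1, 2)        # (B, N, E)

# Explicit version: extract patches, flatten them, reuse the conv weights as a linear map
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)   # (B, C*P*P, N)
patches = patches.transpose(1, 2)                                   # (B, N, C*P*P)
weight = proj.weight.reshape(embed_dim, -1)                         # (E, C*P*P)
tokens_linear = patches @ weight.T + proj.bias                      # (B, N, E)

num_patches = (img_size // patch_size) ** 2             # 8 * 8 = 64 patches
same = torch.allclose(tokens_conv, tokens_linear, atol=1e-4)
print(tokens_conv.shape, num_patches, same)
```

With the notebook's defaults, each 32×32 image becomes 64 tokens, and each token is a projection of 4·4·3 = 48 raw pixel values into 256 dimensions.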


## Transformer Encoder Block
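Each encoder block applies multi-head self-attention followed by an MLP, with layer normalization before each sub-layer and a residual connection around it. Transformers do not inherently understand spatial order, so positional embeddings are added to the tokens before the first block to preserve information about **where each patch is located** in the image.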


In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )

    def forward(self, x):
        h = self.norm1(x)           # normalize once, reuse as query, key and value
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
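
The statement that Transformers have no built-in notion of spatial order can be checked directly. The sketch below shuffles a token sequence and shows that self-attention returns the same outputs, shuffled in the same way, so nothing in the layer itself depends on token position. The random tokens, the seed, and the variable names are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads, num_tokens = 256, 8, 64
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

x = torch.randn(1, num_tokens, embed_dim)       # dummy token sequence (B, N, E)
perm = torch.randperm(num_tokens)               # random reordering of the tokens
x_shuffled = x[:, perm]

with torch.no_grad():
    out = attn(x, x, x, need_weights=False)[0]
    out_shuffled = attn(x_shuffled, x_shuffled, x_shuffled, need_weights=False)[0]

# Shuffling the input only shuffles the output in the same way
order_blind = torch.allclose(out[:, perm], out_shuffled, atol=1e-4)
print(order_blind)
```

Without positional information, the model would therefore treat an image and a version with scrambled patches as the same set of tokens.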

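## Vision Transformer Model
The full model prepends a learnable class token to the patch tokens, adds learnable positional embeddings, and passes the sequence through the stacked encoder blocks. Self-attention allows each patch to attend to **all other patches** in the image, which enables the model to capture **global relationships** early rather than relying on stacked convolutions. The final representation of the class token is fed to a linear classification head.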

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.blocks(x)
        x = self.norm(x)

        cls_output = x[:, 0]
        return self.head(cls_output)
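
The shape bookkeeping in `forward` is easy to lose track of, so here is a minimal sketch of the token assembly on its own. A random tensor stands in for the output of `PatchEmbedding`. The `trunc_normal_` initialization is an assumption borrowed from common ViT implementations; the notebook itself initializes both parameters to zeros.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, num_patches, embed_dim = 4, 64, 256

# Stand-in for the PatchEmbedding output, shape (B, N, E)
patch_tokens = torch.randn(batch_size, num_patches, embed_dim)

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# Optional small random init instead of zeros (assumption, not in the notebook)
nn.init.trunc_normal_(pos_embed, std=0.02)

cls_tokens = cls_token.expand(batch_size, -1, -1)    # (B, 1, E), one shared token per image
x = torch.cat((cls_tokens, patch_tokens), dim=1)     # (B, N + 1, E)
x = x + pos_embed                                    # pos_embed broadcasts over the batch

cls_output = x[:, 0]                                 # (B, E), the slot the classifier reads
print(x.shape, cls_output.shape)
```

The sequence length grows from 64 to 65 because of the class token, which is why `pos_embed` is allocated with `num_patches + 1` positions.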

## Model Training

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)  # use the selected device so the notebook also runs on CPU

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

In [8]:
model.train()  # make sure the model is in training mode
for epoch in range(20):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader):.4f}")

Epoch 1, Loss: 1.6314
Epoch 2, Loss: 1.2619
Epoch 3, Loss: 1.1252
Epoch 4, Loss: 1.0384
Epoch 5, Loss: 0.9710
Epoch 6, Loss: 0.9030
Epoch 7, Loss: 0.8483
Epoch 8, Loss: 0.7844
Epoch 9, Loss: 0.7255
Epoch 10, Loss: 0.6616
Epoch 11, Loss: 0.5961
Epoch 12, Loss: 0.5351
Epoch 13, Loss: 0.4638
Epoch 14, Loss: 0.3972
Epoch 15, Loss: 0.3345
Epoch 16, Loss: 0.2838
Epoch 17, Loss: 0.2528
Epoch 18, Loss: 0.2088
Epoch 19, Loss: 0.1896
Epoch 20, Loss: 0.1769


## Evaluation

In [9]:
correct, total = 0, 0
model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 63.68%


The accuracy is poor, and the gap between the low final training loss (0.1769) and the test accuracy suggests that the model is overfitting. Why?

1. ViTs are data-hungry, and our dataset is not rich enough: CIFAR-10 has only 50k small (32×32) training images.
2. ViTs have a weak inductive bias (no built-in locality like CNNs), so a model trained from scratch struggles to generalize.
3. No strong data augmentation was used. CNNs generalize well even with weak augmentation, whereas ViTs depend heavily on techniques such as RandomCrop, HorizontalFlip, color jitter, and MixUp / CutMix.
4. The training duration may be too short, as ViTs converge more slowly than CNNs.
5. Learning-rate scheduling may help: ViTs are very sensitive to the learning rate, and a constant learning rate can lead to suboptimal minima.

In the next notebook, we will apply these refinements and see whether they help.