## Motivation
In the previous notebook we implemented a ViT from scratch but got poor accuracy, and we listed some possible fixes:
- Stronger augmentation
- Longer training
- Weight decay
- LR scheduler
- Proper train/eval modes

In this notebook, we apply these modifications and check whether they improve the accuracy.



## Import the dependencies

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

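## Load the CIFAR-10 dataset
This time with data augmentation (`RandomCrop` and `RandomHorizontalFlip`) applied to the training set only; the test set is just converted to tensors and normalized.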

In [11]:
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # NEW: Stronger augmentation
    transforms.RandomHorizontalFlip(),     # NEW: Helps ViT generalize on small data
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=test_transform
)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2
)
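To make the two new augmentations concrete, here is a minimal tensor-level sketch of roughly what `RandomCrop(32, padding=4)` followed by `RandomHorizontalFlip()` does: zero-pad 4 pixels on every side, take a random 32×32 window, and mirror it left-right with probability 0.5. The helper name `random_crop_flip` is hypothetical, and it assumes a `(C, H, W)` tensor with values in `[0, 1]`; in the notebook, torchvision applies the real transforms to PIL images before `ToTensor()`.

```python
import torch
import torch.nn.functional as F

def random_crop_flip(img: torch.Tensor, size: int = 32, padding: int = 4) -> torch.Tensor:
    """Hypothetical helper approximating RandomCrop(size, padding) + RandomHorizontalFlip().

    img is a (C, H, W) tensor; padding uses zeros, like RandomCrop's default fill=0.
    """
    padded = F.pad(img, (padding, padding, padding, padding))  # (left, right, top, bottom)
    _, h, w = padded.shape
    top = int(torch.randint(0, h - size + 1, (1,)))             # random window position
    left = int(torch.randint(0, w - size + 1, (1,)))
    crop = padded[:, top:top + size, left:left + size]
    if torch.rand(1).item() < 0.5:                              # flip with probability 0.5
        crop = torch.flip(crop, dims=[2])                       # reverse the width axis
    return crop

torch.manual_seed(0)
img = torch.rand(3, 32, 32)  # a fake CIFAR-10 image with values in [0, 1]
aug = random_crop_flip(img)
print(aug.shape)
```

Every call yields a slightly shifted and possibly mirrored view of the same image, so over 50 epochs the model rarely sees an identical input twice, which is the point of using augmentation for a data-hungry ViT.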

## Patch Embedding
An input image of shape `(C, H, W)` is partitioned into non-overlapping patches of size
`P × P`, giving `N = HW / P²` patches (64 for 32×32 CIFAR-10 images with `P = 4`). Each patch is flattened and linearly projected into a vector of dimension `D`.


In [12]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
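        # Conv2d with kernel_size = stride = patch_size is an efficient patch embedding:
        # it splits the image into non-overlapping P×P patches and linearly projects each one.
        # patch_size = 4 gives 8×8 = 64 patches for 32×32 CIFAR-10 images.
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )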

    def forward(self, x):
        x = self.proj(x)        # (B, E, H', W')
        x = x.flatten(2)        # (B, E, N)
        x = x.transpose(1, 2)   # (B, N, E)
        return x
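The claim that the `Conv2d` trick is "just" a patch-wise linear projection can be checked directly. The sketch below, assuming only PyTorch, cuts a random batch into 4×4 patches by hand with `Tensor.unfold`, flattens each patch, and projects it with the convolution's own weights; the result should match the conv → flatten → transpose pipeline used in `PatchEmbedding.forward` up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 4, 256
x = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

with torch.no_grad():
    # Same steps as PatchEmbedding.forward: (B, D, H/P, W/P) -> (B, N, D)
    out_conv = conv(x).flatten(2).transpose(1, 2)

    # Manual version: split into non-overlapping P×P patches ...
    patches = x.unfold(2, P, P).unfold(3, P, P)        # (B, C, H/P, W/P, P, P)
    patches = patches.permute(0, 2, 3, 1, 4, 5)        # (B, H/P, W/P, C, P, P)
    patches = patches.reshape(B, -1, C * P * P)        # (B, N, C*P*P), row-major patch order
    # ... then apply one shared linear projection (the conv kernel, flattened)
    out_linear = patches @ conv.weight.reshape(D, -1).T + conv.bias

print(out_conv.shape)
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The convolution is preferred in practice because it fuses the patch extraction and the projection into one well-optimized kernel.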


## Transformer Encoder Block
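Each encoder block uses two pre-norm residual sub-layers: `LayerNorm` → multi-head self-attention → residual add, then `LayerNorm` → MLP (two linear layers with GELU) → residual add. Self-attention allows each patch to attend to **all other patches** in the image, so the model can capture **global relationships** from the first layer instead of relying on stacked convolutions to grow its receptive field.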


In [14]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )

    def forward(self, x):
        # Pre-norm residual sub-layers: x + Attn(LN(x)), then x + MLP(LN(x))
        h = self.norm1(x)  # normalize once and reuse it as query, key and value
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
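To see what `nn.MultiheadAttention` computes inside the block, the sketch below reimplements single-head scaled dot-product attention, `softmax(Q Kᵀ / sqrt(d)) V`, using the layer's own projection weights (`in_proj_weight`, `in_proj_bias`, `out_proj`) and compares the result with the layer's output. It uses one head for simplicity (the notebook uses 8). The attention matrix has one row per token, and each row is a probability distribution over all 65 tokens, which is what "every patch attends to every other patch" means in practice.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, E = 2, 65, 256  # 64 patches + 1 class token, embed_dim = 256
x = torch.randn(B, N, E)
mha = nn.MultiheadAttention(E, num_heads=1, batch_first=True)

with torch.no_grad():
    # Project inputs to queries, keys and values with the layer's packed weights
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(E)  # (B, N, N): every token vs every token
    weights = scores.softmax(dim=-1)                 # each row sums to 1
    manual = mha.out_proj(weights @ v)               # (B, N, E)

    ref_out, ref_weights = mha(x, x, x)              # need_weights=True by default

print(torch.allclose(manual, ref_out, atol=1e-5))
print(torch.allclose(weights, ref_weights, atol=1e-5))
```

With more heads, the same computation runs on `E / num_heads`-dimensional slices of `q`, `k`, `v` in parallel, and the head outputs are concatenated before `out_proj`.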


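## Vision Transformer Model
The full model prepends a learnable class token (`cls_token`) to the patch embeddings (64 of them for 32×32 inputs and 4×4 patches) and adds learnable positional embeddings (`pos_embed`). Transformers do not inherently understand spatial order, so positional embeddings are needed to preserve information about **where each patch is located** in the image. After the encoder blocks and a final `LayerNorm`, the class-token output is passed to a linear classification head.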

In [15]:
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        num_classes=10,
        embed_dim=256,
        depth=6,
        num_heads=8
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(
            img_size, patch_size, 3, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )

        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.blocks(x)
        x = self.norm(x)

        return self.head(x[:, 0])
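Why the positional embedding matters: without it, self-attention is permutation-equivariant, so shuffling the patch tokens just shuffles the outputs in the same way and the model cannot tell where a patch came from. The sketch below checks this with a standalone `nn.MultiheadAttention` layer (8 heads, as in the notebook) and uses a random tensor `pos_embed` as an assumed stand-in for the learned `self.pos_embed`; once it is added, the equivariance no longer holds.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D = 2, 65, 256
attn = nn.MultiheadAttention(D, num_heads=8, batch_first=True)
x = torch.randn(B, N, D)
perm = torch.randperm(N)          # a random reordering of the tokens
pos_embed = torch.randn(1, N, D)  # stand-in for a learned positional embedding

def self_attend(t):
    return attn(t, t, t, need_weights=False)[0]

with torch.no_grad():
    # No positional information: attending over shuffled tokens == shuffling the outputs
    equivariant = torch.allclose(self_attend(x)[:, perm], self_attend(x[:, perm]), atol=1e-4)
    # Slot i keeps pos_embed[i] while the content is shuffled, so the outputs change
    with_pos = torch.allclose(self_attend(x + pos_embed)[:, perm],
                              self_attend(x[:, perm] + pos_embed), atol=1e-4)

print(equivariant, with_pos)
```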

## Model Optimizer, Scheduler, and Trainer

In [16]:
model = VisionTransformer().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(), lr=3e-4, weight_decay=1e-4
)

epochs = 50
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs
)
# - Added CosineAnnealingLR (smooth LR decay)
# - Increased epochs from 20 → 50
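With the default `eta_min = 0`, `CosineAnnealingLR` follows the closed form `lr_t = 0.5 * lr_max * (1 + cos(pi * t / T_max))`, starting at 3e-4 in the first epoch and approaching 0 by epoch 50. The sketch below attaches the same optimizer and scheduler settings to a single dummy parameter (an assumption, standing in for `model.parameters()`), steps once per epoch as the training loop does, and compares the recorded learning rates with that formula.

```python
import math
import torch
import torch.optim as optim

epochs, base_lr = 50, 3e-4
param = torch.nn.Parameter(torch.zeros(1))  # dummy stand-in for model.parameters()
optimizer = optim.AdamW([param], lr=base_lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

lrs = []
for epoch in range(epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()                             # the real loop trains on all batches here
    scheduler.step()                             # once per epoch, as in the training loop

def cosine_lr(t, t_max=epochs, lr_max=base_lr, lr_min=0.0):
    """Closed-form cosine annealing from lr_max down to lr_min over t_max steps."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / t_max))

matches = all(math.isclose(lr, cosine_lr(t), rel_tol=1e-6) for t, lr in enumerate(lrs))
print(matches, lrs[0], lrs[epochs // 2])
```

Because `scheduler.step()` is called once per epoch rather than per batch, `T_max` is measured in epochs; the rate halves around epoch 25 and is very small during the last few epochs, when the training loss drops fastest in the log.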

In [17]:
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()

    avg_loss = running_loss / len(trainloader)
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

Epoch [1/50] - Loss: 1.7802
Epoch [2/50] - Loss: 1.4346
Epoch [3/50] - Loss: 1.2951
Epoch [4/50] - Loss: 1.2108
Epoch [5/50] - Loss: 1.1575
Epoch [6/50] - Loss: 1.1012
Epoch [7/50] - Loss: 1.0600
Epoch [8/50] - Loss: 1.0264
Epoch [9/50] - Loss: 0.9949
Epoch [10/50] - Loss: 0.9580
Epoch [11/50] - Loss: 0.9336
Epoch [12/50] - Loss: 0.9117
Epoch [13/50] - Loss: 0.8814
Epoch [14/50] - Loss: 0.8507
Epoch [15/50] - Loss: 0.8247
Epoch [16/50] - Loss: 0.8040
Epoch [17/50] - Loss: 0.7716
Epoch [18/50] - Loss: 0.7520
Epoch [19/50] - Loss: 0.7168
Epoch [20/50] - Loss: 0.6964
Epoch [21/50] - Loss: 0.6705
Epoch [22/50] - Loss: 0.6474
Epoch [23/50] - Loss: 0.6181
Epoch [24/50] - Loss: 0.5978
Epoch [25/50] - Loss: 0.5651
Epoch [26/50] - Loss: 0.5419
Epoch [27/50] - Loss: 0.5146
Epoch [28/50] - Loss: 0.4817
Epoch [29/50] - Loss: 0.4544
Epoch [30/50] - Loss: 0.4278
Epoch [31/50] - Loss: 0.4020
Epoch [32/50] - Loss: 0.3781
Epoch [33/50] - Loss: 0.3510
Epoch [34/50] - Loss: 0.3259
Epoch [35/50] - Loss: 0

## Evaluation

In [18]:
correct, total = 0, 0
model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 76.23%
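`model.train()` and `model.eval()` only change the output of modules whose behaviour depends on the mode, such as `Dropout` and `BatchNorm`. This ViT contains neither (`nn.MultiheadAttention` defaults to `dropout=0.0`, and `LayerNorm` keeps no running statistics), so here the switch should not change the predictions; it becomes important as soon as such layers are added. The small check below illustrates the difference with standalone layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16)
norm = nn.LayerNorm(16)
drop = nn.Dropout(p=0.1)

with torch.no_grad():
    norm.train()
    ln_train = norm(x)
    norm.eval()
    ln_eval = norm(x)   # LayerNorm: identical computation in both modes

    drop.train()
    do_train = drop(x)  # zeroes ~10% of entries and scales the rest by 1/0.9
    drop.eval()
    do_eval = drop(x)   # identity in eval mode

print(torch.equal(ln_train, ln_eval), torch.equal(do_eval, x), torch.equal(do_train, x))
```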


Accuracy improved substantially, from 66.73% to 76.23%, but it is still lower than the classic ResNet. Why?

Vision Transformers are known to require large-scale datasets or pretraining to outperform convolutional models.
CIFAR-10 is used here intentionally, for a controlled comparison with the CNN, VGG, and ResNet architectures.
The achieved accuracy reflects the data-hungry nature of ViTs when trained from scratch.

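**Conclusion**: Trained from scratch on a small dataset like CIFAR-10, this Vision Transformer still underperforms CNNs. This is consistent with the ViT's weak inductive bias: unlike convolutions, self-attention does not build in locality or translation equivariance, so these priors have to be learned from data.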
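To make "inductive bias" concrete: a convolution builds in locality (each output depends on a small neighbourhood) and translation equivariance (shifting the input shifts the output). The sketch below checks the second property for a 3×3 `nn.Conv2d`; circular padding is an assumption chosen so that `torch.roll` shifts wrap around exactly. A ViT has no equivalent built-in constraint in its attention layers and has to learn comparable behaviour largely from data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular", bias=False)
x = torch.randn(1, 3, 32, 32)
shift, dims = (5, -3), (2, 3)  # move 5 pixels down and 3 pixels left, wrapping around

with torch.no_grad():
    shift_after = torch.roll(conv(x), shifts=shift, dims=dims)   # convolve, then shift
    shift_before = conv(torch.roll(x, shifts=shift, dims=dims))  # shift, then convolve

print(torch.allclose(shift_after, shift_before, atol=1e-5))
```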