In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import os

In [None]:
# Hyperparameters
batch_size = 128
img_size = 32           # CIFAR-10 images are 32x32
patch_size = 4          # 32 / 4 = 8 -> an 8x8 grid = 64 patches
num_channels = 3        # RGB images
num_patches = (img_size // patch_size) ** 2
embed_dim = 192         # must be divisible by num_heads (192 / 8 = 24 dims per head)
num_heads = 8
mlp_dim = 4 * embed_dim # hidden width of each encoder block's MLP
transformer_layers = 6  # number of stacked encoder blocks
dropout_rate = 0.1
learning_rate = 0.003
weight_decay = 0.05
epochs = 50
warmup_epochs = 5       # epochs of linear LR warmup (presumably before cosine annealing)
seed = 42               # global RNG seed for reproducible Restart & Run All

# Fail fast on inconsistent settings instead of deep inside model construction.
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
assert 0 <= warmup_epochs < epochs, "warmup_epochs must be smaller than epochs"

# Seed the RNGs the notebook imports so weight init and augmentation are repeatable.
torch.manual_seed(seed)
np.random.seed(seed)

# Per-channel (R, G, B) normalization statistics for CIFAR-10
mean = [0.4914, 0.4822, 0.4465]
std = [0.2471, 0.2435, 0.2616]

# Output class names
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
# Training pipeline for CIFAR-10: random crop of the 4-px padded image back to
# img_size, horizontal flip and a small (+/-10 degree) rotation for augmentation,
# then tensor conversion and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(img_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Test pipeline: no augmentation, only the tensor conversion and the same
# normalization used for training, so evaluation is deterministic.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride applies one shared linear map
    to every ``patch x patch`` tile, which is exactly a per-patch projection.

    Every argument defaults to ``None``, meaning "use the notebook-level
    hyperparameter" (``num_channels``, ``embed_dim``, ``patch_size``,
    ``dropout_rate``), read when the module is constructed. Pass explicit values
    to reuse the layer with another configuration.

    Args:
        in_channels: Number of input image channels.
        dim: Size of each output patch embedding.
        patch: Side length of a square patch (also the convolution stride).
        drop: Dropout probability applied to the patch embeddings.
    """

    def __init__(self, in_channels=None, dim=None, patch=None, drop=None):
        super().__init__()
        in_channels = num_channels if in_channels is None else in_channels
        dim = embed_dim if dim is None else dim
        patch = patch_size if patch is None else patch
        drop = dropout_rate if drop is None else drop

        self.projection = nn.Conv2d(
            in_channels,
            dim,
            kernel_size=patch,
            stride=patch)
        self.dropout = nn.Dropout(p=drop)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, (H // patch) * (W // patch), dim)``; patches are
            ordered row-major over the patch grid.
        """
        x = self.projection(x)            # (B, dim, H // patch, W // patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, dim)
        return self.dropout(x)

In [None]:
class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    ``x -> x + Dropout(MHSA(LayerNorm(x)))`` followed by
    ``x -> x + MLP(LayerNorm(x))``; the token sequence shape is preserved.

    Every argument defaults to ``None``, meaning "use the notebook-level
    hyperparameter" (``embed_dim``, ``num_heads``, ``mlp_dim``, ``dropout_rate``),
    read when the module is constructed.

    Args:
        dim: Token embedding size (must be divisible by ``heads``).
        heads: Number of attention heads.
        hidden_dim: Hidden width of the MLP.
        drop: Dropout probability used in attention, after attention, and in the MLP.
    """

    def __init__(self, dim=None, heads=None, hidden_dim=None, drop=None):
        super().__init__()
        dim = embed_dim if dim is None else dim
        heads = num_heads if heads is None else heads
        hidden_dim = mlp_dim if hidden_dim is None else hidden_dim
        drop = dropout_rate if drop is None else drop

        # Normalization before self-attention (pre-norm)
        self.norm1 = nn.LayerNorm(dim)

        # Self-attention: query, key and value are all the normalized tokens.
        self.attention = nn.MultiheadAttention(dim, heads,
                                               dropout=drop,
                                               batch_first=True)

        self.attention_dropout = nn.Dropout(p=drop)

        # Normalization before the MLP (pre-norm)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise MLP: expand to hidden_dim, GELU, project back to dim.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=drop)
        )

    def forward(self, x):
        """Apply the block to a token sequence.

        Args:
            x: Tokens of shape ``(B, seq_len, dim)`` (``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        norm_x = self.norm1(x)
        # The attention map is discarded, so skip computing it: need_weights=False
        # avoids materializing the averaged weights and allows PyTorch's optimized
        # scaled-dot-product attention path.
        attention_output, _ = self.attention(norm_x, norm_x, norm_x,
                                             need_weights=False)

        # Residual connections around attention and MLP
        x = x + self.attention_dropout(attention_output)
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from the notebook-level hyperparameters.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    positional embeddings -> ``transformer_layers`` encoder blocks -> LayerNorm +
    linear head applied to the [CLS] token only.

    Args:
        num_classes: Number of output logits. Defaults to 10, matching the
            previously hard-coded head size (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding()

        # Learnable [CLS] token; its final state is what the head classifies.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # One learnable positional embedding per token ([CLS] + every patch).
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Small truncated-normal init (std 0.02), as commonly used for ViT.
        # A plain torch.randn init (std 1) makes these embeddings large relative
        # to the patch projections at the start of training.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.positional_embedding, std=0.02)

        self.transformer = nn.Sequential(*[TransformerEncoder() for _ in range(transformer_layers)])
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, patch_inputs):
        """Classify a batch of images.

        Args:
            patch_inputs: Raw image batch ``(B, num_channels, H, W)`` (patching
                happens inside). The patch grid must yield exactly ``num_patches``
                tokens, e.g. ``img_size x img_size`` images, or the positional
                embedding addition fails to broadcast.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B = patch_inputs.shape[0]
        patch_inputs = self.patch_embedding(patch_inputs)             # [B, num_patches, embed_dim]
        cls_tokens = self.cls_token.expand(B, -1, -1)                 # [B, 1, embed_dim]
        patch_inputs = torch.cat((cls_tokens, patch_inputs), dim=1)   # [B, num_patches + 1, embed_dim]
        patch_inputs = patch_inputs + self.positional_embedding

        patch_inputs = self.transformer(patch_inputs)
        patch_inputs = patch_inputs[:, 0]  # keep only the [CLS] token
        return self.head(patch_inputs)