In [1]:
!pip install torch torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

# Hyperparameters 
image_size = 32
patch_size = 4
num_classes = 10
embed_dim = 384
num_heads = 8
num_layers = 8
mlp_ratio = 4
dropout_rate = 0.1
batch_size = 128
epochs = 100
learning_rate = 3e-4
weight_decay= 0.05

# GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data augmentation 
train_transforms = transforms.Compose([
    transforms.RandomCrop(image_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.RandomRotation(degrees=15), 
    transforms.ToTensor(),

    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2471, 0.2435, 0.2616])
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2471, 0.2435, 0.2616])
])

# Load CIFAR-10 
train_dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=train_transforms)
test_dataset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=test_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s] 


Train batches: 391, Test batches: 79


In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, dropout):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # CLS token 
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.num_patches = num_patches

    def forward(self, x):
        B = x.shape[0]
        # Convert 
        x = self.proj(x)
        # Flatten patches
        x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
        # Expand CLS token 
        cls_tokens = self.cls_token.expand(B, -1, -1)  
        x = torch.cat((cls_tokens, x), dim=1)          
        x = x + self.pos_embed
        x = self.dropout(x)
        return x

class TransformerEncoderLayer(nn.Module):

    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        # MLP 
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(embed_dim * mlp_ratio, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        res = x
        x = self.norm1(x)
        x_t = x.transpose(0, 1)              
        attn_output, _ = self.attn(x_t, x_t, x_t)  
        attn_output = attn_output.transpose(0, 1)  
        x = res + self.dropout(attn_output)
        # Feed-forward MLP with residual
        res2 = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = res2 + self.dropout(x)
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) model for image classification."""
    def __init__(self, image_size=32, patch_size=4, in_channels=3,
                 num_classes=10, embed_dim=128, num_heads=8,
                 num_layers=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        # Embedding layer
        self.embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim, dropout)
        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        # Final classification head 
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):

        x = self.embedding(x)  
        for layer in self.encoder_layers:
            x = layer(x)
        # Extract the CLS token 
        cls_token_final = x[:, 0]  # (B, embed_dim)
        out = self.classifier(cls_token_final)  # (B, num_classes)
        return out

model = VisionTransformer(
    image_size=image_size, patch_size=patch_size, in_channels=3,
    num_classes=num_classes, embed_dim=embed_dim,
    num_heads=num_heads, num_layers=num_layers,
    mlp_ratio=mlp_ratio, dropout=dropout_rate
).to(device)
print(model)


VisionTransformer(
  (embedding): PatchEmbedding(
    (proj): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0-7): 8 x TransformerEncoderLayer(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
      )
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1536, out_features=384, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (classifier): Linear(in_features=384, out_features=10, bias=True)
)


In [3]:
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate,weight_decay=0.05 )


#scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(epochs*0.5), int(epochs*0.75)], gamma=0.1)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(1, epochs+1):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)                   # (B, num_classes)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

      
        total_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += images.size(0)
    # Update scheduler
    scheduler.step()

    train_loss = total_loss / total_samples
    train_acc = total_correct / total_samples * 100
    # Evaluate on test set
    model.eval()
    test_correct = 0
    test_samples = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_samples += labels.size(0)
    test_acc = test_correct / test_samples * 100

    print(f"Epoch {epoch:2d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")


Epoch  1: Train Loss=2.0251, Train Acc=30.59%, Test Acc=44.67%
Epoch  2: Train Loss=1.7463, Train Acc=42.23%, Test Acc=50.49%
Epoch  3: Train Loss=1.6580, Train Acc=46.68%, Test Acc=52.65%
Epoch  4: Train Loss=1.5890, Train Acc=49.58%, Test Acc=53.70%
Epoch  5: Train Loss=1.5390, Train Acc=52.09%, Test Acc=57.77%
Epoch  6: Train Loss=1.4963, Train Acc=54.20%, Test Acc=59.32%
Epoch  7: Train Loss=1.4559, Train Acc=56.37%, Test Acc=58.30%
Epoch  8: Train Loss=1.4231, Train Acc=58.03%, Test Acc=61.57%
Epoch  9: Train Loss=1.3967, Train Acc=59.20%, Test Acc=62.68%
Epoch 10: Train Loss=1.3710, Train Acc=60.32%, Test Acc=62.84%
Epoch 11: Train Loss=1.3414, Train Acc=61.79%, Test Acc=65.36%
Epoch 12: Train Loss=1.3102, Train Acc=63.57%, Test Acc=65.24%
Epoch 13: Train Loss=1.2913, Train Acc=64.11%, Test Acc=66.72%
Epoch 14: Train Loss=1.2629, Train Acc=65.76%, Test Acc=67.78%
Epoch 15: Train Loss=1.2457, Train Acc=66.22%, Test Acc=69.49%
Epoch 16: Train Loss=1.2240, Train Acc=67.18%, Test Acc

In [4]:
print(test_acc)

79.55
