In [5]:
!pip install torch torchvision einops tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
import numpy as np
import random



In [3]:
# Data pipeline for CIFAR-10: augmented training transform, plain test
# transform, the two datasets, and their loaders.

# Both splits end with the same tensor conversion + normalization. A mean and
# std of 0.5 per channel rescales ToTensor's [0, 1] output to [-1, 1].
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]

# Training images additionally get random crop / flip / colour augmentation.
_train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]

transform_train = transforms.Compose(_train_augmentations + _to_normalized_tensor)
transform_test = transforms.Compose(list(_to_normalized_tensor))

# Downloads the dataset into ./data on first use.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    The projection is a single ``nn.Conv2d`` whose kernel size and stride both
    equal ``patch_size``, so every output position corresponds to exactly one
    patch.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Side length of each square patch, in pixels.
        emb_size: Dimension of the embedding produced for each patch.
        img_size: Side length of the (square) input image; used only to
            compute ``n_patches``.

    Attributes:
        patch_size: The patch side length passed to the constructor.
        n_patches: ``(img_size // patch_size) ** 2``.
        proj: The patch-projection convolution.
    """

    def __init__(self, in_channels=3, patch_size=4, emb_size=256, img_size=32):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch laid out as (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, emb_size).
        """
        feature_map = self.proj(x)                    # (B, emb, H/p, W/p)
        tokens = feature_map.flatten(start_dim=2)     # (B, emb, N)
        return tokens.transpose(1, 2)                 # (B, N, emb)

class Attention(nn.Module):
    """Multi-head self-attention with dropout applied to its output.

    Wraps ``nn.MultiheadAttention`` (configured with ``batch_first=True``) and
    uses the same tensor as query, key and value.

    Args:
        emb_size: Embedding dimension of the tokens.
        num_heads: Number of attention heads.
        dropout: Dropout probability, used both inside the attention module
            and on the attention output.
    """

    def __init__(self, emb_size=256, num_heads=8, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=emb_size, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to a token sequence.

        Args:
            x: Tensor of shape (batch, seq_len, emb_size).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # The attention weights are never used, so do not ask for them: this
        # skips computing/averaging the weight matrix and allows
        # nn.MultiheadAttention to take its optimized attention path.
        out, _ = self.mha(x, x, x, need_weights=False)
        return self.dropout(out)

class TransformerBlock(nn.Module):
    """Transformer encoder block with LayerNorm applied before each sub-layer.

    The block runs self-attention and then a two-layer GELU MLP; each
    sub-layer's output is added back to its (un-normalized) input.

    Args:
        emb_size: Embedding dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``emb_size``
            (truncated to an int).
        dropout: Dropout probability used in the attention and MLP sub-layers.
    """

    def __init__(self, emb_size=256, num_heads=8, mlp_ratio=4., dropout=0.1):
        super().__init__()
        hidden_size = int(emb_size * mlp_ratio)

        self.norm1 = nn.LayerNorm(emb_size)
        self.norm2 = nn.LayerNorm(emb_size)
        self.attn = Attention(emb_size, num_heads, dropout)

        # Expand -> GELU -> dropout -> project back -> dropout.
        mlp_layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, emb_size),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Run the attention and MLP sub-layers, each with a residual connection.

        Args:
            x: Token tensor whose last dimension is ``emb_size``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are split into patch embeddings, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through ``depth`` Transformer blocks. The class token's final
    state is layer-normalized and fed to a linear classification head.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        num_classes: Number of output classes.
        emb_size: Token embedding dimension.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``emb_size``.
        dropout: Dropout probability used inside the blocks.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 emb_size=256, depth=8, num_heads=8, mlp_ratio=4., dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)

        # One position per patch plus one for the class token.
        seq_len = (img_size // patch_size) ** 2 + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, emb_size))

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                TransformerBlock(emb_size, num_heads, mlp_ratio, dropout)
            )
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(emb_size)
        self.head = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by the patch embedding.

        Returns:
            Logits of shape (batch, num_classes).
        """
        tokens = self.patch_embed(x)
        batch_size = tokens.size(0)

        # Broadcast the single learned class token across the batch and put it
        # at sequence position 0.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed

        for layer in self.blocks:
            tokens = layer(tokens)

        # Only the class token (position 0) feeds the classifier.
        cls_state = self.norm(tokens[:, 0])
        return self.head(cls_state)

# Training configuration and state. `epochs` is defined before the scheduler
# so the cosine schedule length is derived from it rather than duplicated.
epochs = 100
best_acc = 0.0  # best test accuracy seen so far

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViT().to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
# The scheduler is stepped once per epoch, so T_max=epochs makes the learning
# rate complete its cosine decay exactly at the end of training, even if
# `epochs` is changed.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


# Train for `epochs` epochs. After every epoch the model is evaluated on the
# test loader and its weights are saved whenever the accuracy improves.
# NOTE(review): the checkpoint (and the final reported accuracy) is selected
# using the test set, so `best_acc` is an optimistic estimate — consider a
# held-out validation split for model selection.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0  # running sum of per-batch losses for this epoch
    loop = tqdm(train_loader, leave=False)
    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        # Read the scalar loss back from the device once and reuse it for both
        # the running sum and the progress-bar text.
        batch_loss = loss.item()
        train_loss += batch_loss
        loop.set_description(f"Epoch [{epoch+1}/{epochs}] Loss: {batch_loss:.4f}")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # ---- evaluation pass (no gradients, dropout disabled) ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            correct += (preds.argmax(1) == labels).sum().item()
            total += labels.size(0)
    acc = correct/total

    # Keep only the weights of the best-scoring epoch on disk.
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "vit_cifar10.pth")

    print(f"Epoch {epoch+1}/{epochs} | Test Acc: {acc:.4f} | Best: {best_acc:.4f}")

print("Accuracy:", best_acc)



Epoch 1/100 | Test Acc: 0.4216 | Best: 0.4216




Epoch 2/100 | Test Acc: 0.4755 | Best: 0.4755




Epoch 3/100 | Test Acc: 0.5022 | Best: 0.5022




Epoch 4/100 | Test Acc: 0.5361 | Best: 0.5361




Epoch 5/100 | Test Acc: 0.5709 | Best: 0.5709




Epoch 6/100 | Test Acc: 0.5587 | Best: 0.5709




Epoch 7/100 | Test Acc: 0.5950 | Best: 0.5950




Epoch 8/100 | Test Acc: 0.5916 | Best: 0.5950




Epoch 9/100 | Test Acc: 0.6141 | Best: 0.6141




Epoch 10/100 | Test Acc: 0.6196 | Best: 0.6196




Epoch 11/100 | Test Acc: 0.6383 | Best: 0.6383




Epoch 12/100 | Test Acc: 0.6531 | Best: 0.6531




Epoch 13/100 | Test Acc: 0.6498 | Best: 0.6531




Epoch 14/100 | Test Acc: 0.6643 | Best: 0.6643




Epoch 15/100 | Test Acc: 0.6669 | Best: 0.6669




Epoch 16/100 | Test Acc: 0.6893 | Best: 0.6893




Epoch 17/100 | Test Acc: 0.6978 | Best: 0.6978




Epoch 18/100 | Test Acc: 0.7042 | Best: 0.7042




Epoch 19/100 | Test Acc: 0.7040 | Best: 0.7042




Epoch 20/100 | Test Acc: 0.7045 | Best: 0.7045




Epoch 21/100 | Test Acc: 0.7133 | Best: 0.7133




Epoch 22/100 | Test Acc: 0.7147 | Best: 0.7147




Epoch 23/100 | Test Acc: 0.7296 | Best: 0.7296




Epoch 24/100 | Test Acc: 0.7379 | Best: 0.7379




Epoch 25/100 | Test Acc: 0.7434 | Best: 0.7434




Epoch 26/100 | Test Acc: 0.7440 | Best: 0.7440




Epoch 27/100 | Test Acc: 0.7484 | Best: 0.7484




Epoch 28/100 | Test Acc: 0.7429 | Best: 0.7484




Epoch 29/100 | Test Acc: 0.7658 | Best: 0.7658




Epoch 30/100 | Test Acc: 0.7635 | Best: 0.7658




Epoch 31/100 | Test Acc: 0.7619 | Best: 0.7658




Epoch 32/100 | Test Acc: 0.7569 | Best: 0.7658




Epoch 33/100 | Test Acc: 0.7631 | Best: 0.7658




Epoch 34/100 | Test Acc: 0.7681 | Best: 0.7681




Epoch 35/100 | Test Acc: 0.7665 | Best: 0.7681




Epoch 36/100 | Test Acc: 0.7673 | Best: 0.7681




Epoch 37/100 | Test Acc: 0.7715 | Best: 0.7715




Epoch 38/100 | Test Acc: 0.7745 | Best: 0.7745




Epoch 39/100 | Test Acc: 0.7778 | Best: 0.7778




Epoch 40/100 | Test Acc: 0.7790 | Best: 0.7790




Epoch 41/100 | Test Acc: 0.7803 | Best: 0.7803




Epoch 42/100 | Test Acc: 0.7798 | Best: 0.7803




Epoch 43/100 | Test Acc: 0.7847 | Best: 0.7847




Epoch 44/100 | Test Acc: 0.7839 | Best: 0.7847




Epoch 45/100 | Test Acc: 0.7884 | Best: 0.7884




Epoch 46/100 | Test Acc: 0.7913 | Best: 0.7913




Epoch 47/100 | Test Acc: 0.7901 | Best: 0.7913




Epoch 48/100 | Test Acc: 0.7905 | Best: 0.7913




Epoch 49/100 | Test Acc: 0.7914 | Best: 0.7914




Epoch 50/100 | Test Acc: 0.7900 | Best: 0.7914




Epoch 51/100 | Test Acc: 0.7894 | Best: 0.7914




Epoch 52/100 | Test Acc: 0.7942 | Best: 0.7942




Epoch 53/100 | Test Acc: 0.7904 | Best: 0.7942




Epoch 54/100 | Test Acc: 0.7915 | Best: 0.7942




Epoch 55/100 | Test Acc: 0.7955 | Best: 0.7955




Epoch 56/100 | Test Acc: 0.7878 | Best: 0.7955




Epoch 57/100 | Test Acc: 0.7956 | Best: 0.7956




Epoch 58/100 | Test Acc: 0.7980 | Best: 0.7980




Epoch 59/100 | Test Acc: 0.7949 | Best: 0.7980




Epoch 60/100 | Test Acc: 0.7929 | Best: 0.7980




Epoch 61/100 | Test Acc: 0.7972 | Best: 0.7980




Epoch 62/100 | Test Acc: 0.7991 | Best: 0.7991




Epoch 63/100 | Test Acc: 0.8061 | Best: 0.8061




Epoch 64/100 | Test Acc: 0.7943 | Best: 0.8061




Epoch 65/100 | Test Acc: 0.8021 | Best: 0.8061




Epoch 66/100 | Test Acc: 0.7996 | Best: 0.8061




Epoch 67/100 | Test Acc: 0.7989 | Best: 0.8061




Epoch 68/100 | Test Acc: 0.8031 | Best: 0.8061




Epoch 69/100 | Test Acc: 0.8043 | Best: 0.8061




Epoch 70/100 | Test Acc: 0.8012 | Best: 0.8061




Epoch 71/100 | Test Acc: 0.8018 | Best: 0.8061




Epoch 72/100 | Test Acc: 0.8021 | Best: 0.8061




Epoch 73/100 | Test Acc: 0.8077 | Best: 0.8077




Epoch 74/100 | Test Acc: 0.8032 | Best: 0.8077




Epoch 75/100 | Test Acc: 0.8030 | Best: 0.8077




Epoch 76/100 | Test Acc: 0.8068 | Best: 0.8077




Epoch 77/100 | Test Acc: 0.8051 | Best: 0.8077




Epoch 78/100 | Test Acc: 0.8079 | Best: 0.8079




Epoch 79/100 | Test Acc: 0.8073 | Best: 0.8079




Epoch 80/100 | Test Acc: 0.8048 | Best: 0.8079




Epoch 81/100 | Test Acc: 0.8039 | Best: 0.8079




Epoch 82/100 | Test Acc: 0.8040 | Best: 0.8079




Epoch 83/100 | Test Acc: 0.8049 | Best: 0.8079




Epoch 84/100 | Test Acc: 0.8072 | Best: 0.8079




Epoch 85/100 | Test Acc: 0.8059 | Best: 0.8079




Epoch 86/100 | Test Acc: 0.8062 | Best: 0.8079




Epoch 87/100 | Test Acc: 0.8094 | Best: 0.8094




Epoch 88/100 | Test Acc: 0.8082 | Best: 0.8094




Epoch 89/100 | Test Acc: 0.8085 | Best: 0.8094




Epoch 90/100 | Test Acc: 0.8066 | Best: 0.8094




Epoch 91/100 | Test Acc: 0.8047 | Best: 0.8094




Epoch 92/100 | Test Acc: 0.8063 | Best: 0.8094




Epoch 93/100 | Test Acc: 0.8061 | Best: 0.8094




Epoch 94/100 | Test Acc: 0.8077 | Best: 0.8094




Epoch 95/100 | Test Acc: 0.8064 | Best: 0.8094




Epoch 96/100 | Test Acc: 0.8078 | Best: 0.8094




Epoch 97/100 | Test Acc: 0.8077 | Best: 0.8094




Epoch 98/100 | Test Acc: 0.8069 | Best: 0.8094




Epoch 99/100 | Test Acc: 0.8067 | Best: 0.8094




Epoch 100/100 | Test Acc: 0.8068 | Best: 0.8094
Accuracy: 0.8094
