In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
image_size = 28
patch_size = 7
num_classes = 10
dim = 128
depth = 6 
heads = 8
mlp_dim = 256
lr = 3e-4
epochs = 5

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [13]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=1, patch_size=7, emb_size=128, img_size=28):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches_h = img_size // patch_size
        self.n_patches_w = img_size // patch_size
        self.n_patches = self.n_patches_h * self.n_patches_w
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

In [14]:
class TransformerBlock(nn.Module):
    def __init__(self, emb_size=128, heads=8, mlp_dim=256, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn = nn.MultiheadAttention(embed_dim=emb_size, num_heads=heads, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = nn.Sequential(
            nn.Linear(emb_size, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, emb_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x2 = self.norm1(x)
        attn_out, _ = self.attn(x2, x2, x2)
        x = x + self.dropout(self.mlp(x2))
        x2 = self.norm2(x)
        x = x + self.dropout(self.mlp(x2))
        return x

In [15]:
class ViT(nn.Module):
    def __init__(self, in_channels=1, patch_size=7, emb_size=128, img_size=28, depth=6, heads=8, mlp_dim=256, num_classes=10, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.n_patches + 1, emb_size))
        self.dropout = nn.Dropout(dropout)
        self.transformer = nn.Sequential(*[TransformerBlock(emb_size, heads, mlp_dim, dropout) for _ in range(depth)])
        self.mlp_head = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B = x.size(0)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        x = self.transformer(x)
        cls_out = x[:, 0]
        out = self.mlp_head(cls_out)
        return out

In [16]:
model = ViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [17]:
for epoch in range(epochs):
    model.train()
    loop = tqdm(train_loader)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
        loop.set_postfix(loss=loss.item())

Epoch [1/5]: 100%|███████████████████████████████████████████████████| 938/938 [01:33<00:00, 10.01it/s, loss=2.41]
Epoch [2/5]: 100%|████████████████████████████████████████████████████| 938/938 [01:32<00:00, 10.18it/s, loss=2.3]
Epoch [3/5]: 100%|███████████████████████████████████████████████████| 938/938 [01:35<00:00,  9.87it/s, loss=2.29]
Epoch [4/5]: 100%|███████████████████████████████████████████████████| 938/938 [01:39<00:00,  9.40it/s, loss=2.28]
Epoch [5/5]: 100%|███████████████████████████████████████████████████| 938/938 [01:40<00:00,  9.30it/s, loss=2.31]


In [18]:
model.eval()
images, labels = next(iter(train_loader))
images = images.to(device)
with torch.no_grad():
    outputs = model(images)
preds = outputs.argmax(dim=1)
print("Predicted:", preds[:10].cpu().numpy())
print("Ground truth:", labels[:10].numpy())


Predicted: [1 1 1 1 1 1 1 1 1 1]
Ground truth: [9 4 5 7 4 8 3 7 6 9]
