In [40]:
#import libraries
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

In [41]:
# Hyperparameters and data-shape settings shared by the cells below.

# Seed PyTorch's global RNG so that random weight initialisation and the
# shuffled training order start from the same state on every fresh run.
seed = 0
torch.manual_seed(seed)

batch_size = 64        # samples per DataLoader batch
img_size = 28          # side length of the (square) input images
patch_size = 7         # side length of each square patch (conv kernel size and stride)
num_channels = 1       # channels of the input images
num_patches = (img_size // patch_size) ** 2  # patches per image: a 4 x 4 grid = 16
num_heads = 1          # attention heads in each transformer block
embed_dim = 16         # width of every token embedding
mlp_dim = 16           # hidden width of the MLP inside each transformer block
transformer_units = 1  # number of stacked transformer blocks

In [42]:
# Preprocessing pipeline applied to every image: convert it to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [43]:
# Load the MNIST training split and the held-out split used for validation.
# Both are stored under ./data (download=True is passed for each) and share
# the same preprocessing transform.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
valset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [44]:
# Wrap both splits in mini-batch iterators of `batch_size` samples.
# Only the training split is shuffled; validation keeps dataset order.
train_data = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)
val_data = torch.utils.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=False,
)

In [45]:
class PatchEmbedding(nn.Module):
  """Split an image into non-overlapping patches and embed each one.

  A single convolution whose kernel size and stride both equal ``patch_size``
  maps every patch to an ``embed_dim``-dimensional vector. The sizes are read
  from the notebook-level ``num_channels``, ``embed_dim`` and ``patch_size``.
  """

  def __init__(self):
    super().__init__()
    self.proj = nn.Conv2d(
        in_channels=num_channels,
        out_channels=embed_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

  def forward(self, x):
    """Return one embedding per patch.

    Assumes ``x`` is a batch of images laid out as (batch, channels, height,
    width); the result is then (batch, patches, embed_dim).
    """
    # The convolution yields (batch, embed_dim, rows, cols): merge the patch
    # grid into a single axis, then move the embedding axis last.
    return self.proj(x).flatten(2).transpose(1, 2)

In [46]:
class Block(nn.Module):
  """Pre-norm transformer encoder block.

  Applies LayerNorm followed by multi-head self-attention, then LayerNorm
  followed by a two-layer GELU MLP. Each sub-layer is wrapped in a residual
  connection. Sizes come from the notebook-level ``embed_dim``, ``num_heads``
  and ``mlp_dim``.
  """

  def __init__(self):
    super().__init__()
    self.ln1 = nn.LayerNorm(embed_dim)
    # batch_first=True: inputs and outputs are (batch, sequence, embed_dim)
    self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    self.ln2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim),
        nn.GELU(),
        nn.Linear(mlp_dim, embed_dim)
    )

  def forward(self, x):
    """Return ``x`` after the attention and MLP residual updates (same shape as ``x``)."""
    # Normalise once and reuse the result as query, key and value; evaluating
    # ln1 separately for each role would triple its forward and backward cost
    # without changing the values.
    normed = self.ln1(x)
    # The attention map is never used, so skip computing it.
    attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
    x = x + attn_out
    x = x + self.mlp(self.ln2(x))
    return x



In [47]:
class ViT(nn.Module):
  """Small Vision Transformer classifier producing 10 logits per image.

  Images are turned into patch embeddings, a learnable class token is
  prepended, a learnable positional embedding is added, the sequence passes
  through ``transformer_units`` blocks, and the class-token position is fed
  to a LayerNorm + Linear head.
  """

  def __init__(self):
    super().__init__()
    self.patch_embed = PatchEmbedding()
    # Learnable class token, and one positional vector per patch plus one
    # for the class token.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
    encoder_layers = [Block() for _ in range(transformer_units)]
    self.blocks = nn.ModuleList(encoder_layers)
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, 10),
    )

  def forward(self, x):
    """Return the classification logits for a batch of images."""
    tokens = self.patch_embed(x)
    batch = tokens.size(0)
    # Broadcast the single class token across the batch and put it first.
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.pos_embed
    for layer in self.blocks:
      tokens = layer(tokens)
    # Only the class-token position (index 0) is passed to the head.
    return self.mlp_head(tokens[:, 0])

In [48]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network on that device, then the loss and the Adam optimiser
# over all of its parameters.
model = ViT()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [49]:
# Train for five passes over the training set, printing a progress line every
# 100 batches and a summary (summed batch loss, overall accuracy) per epoch.
for epoch in range(5):
    print(f"\nEpoch {epoch+1}")
    model.train()
    total_loss, correct_epoch, total_epoch = 0, 0, 0

    for batch_idx, (images, labels) in enumerate(train_data):
        images = images.to(device)
        labels = labels.to(device)

        # One optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics for this batch (predicted class = highest logit).
        preds = outputs.argmax(dim=1)
        correct = (preds == labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)

        # Running totals for the epoch summary.
        total_loss += loss.item()
        correct_epoch += correct
        total_epoch += labels.size(0)

        if not batch_idx % 100:
            print(f"  Batch {batch_idx+1:3d}: Loss = {loss.item():.4f}, Accuracy = {accuracy:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")


Epoch 1
  Batch   1: Loss = 2.5558, Accuracy = 4.69%
  Batch 101: Loss = 1.0887, Accuracy = 62.50%
  Batch 201: Loss = 0.9266, Accuracy = 68.75%
  Batch 301: Loss = 0.7843, Accuracy = 68.75%
  Batch 401: Loss = 0.5988, Accuracy = 78.12%
  Batch 501: Loss = 0.6998, Accuracy = 78.12%
  Batch 601: Loss = 0.6530, Accuracy = 70.31%
  Batch 701: Loss = 0.6283, Accuracy = 79.69%
  Batch 801: Loss = 0.6526, Accuracy = 75.00%
  Batch 901: Loss = 0.4347, Accuracy = 84.38%
==> Epoch 1 Summary: Total Loss = 795.1363, Accuracy = 70.54%

Epoch 2
  Batch   1: Loss = 0.5333, Accuracy = 81.25%
  Batch 101: Loss = 0.4872, Accuracy = 84.38%
  Batch 201: Loss = 0.6539, Accuracy = 76.56%
  Batch 301: Loss = 0.6245, Accuracy = 78.12%
  Batch 401: Loss = 0.7325, Accuracy = 76.56%
  Batch 501: Loss = 0.5388, Accuracy = 76.56%
  Batch 601: Loss = 0.5595, Accuracy = 79.69%
  Batch 701: Loss = 0.3659, Accuracy = 90.62%
  Batch 801: Loss = 0.6799, Accuracy = 76.56%
  Batch 901: Loss = 0.4001, Accuracy = 87.50%
=

In [50]:
# Count every scalar entry across all of the model's parameter tensors.
total_params = sum(map(torch.numel, model.parameters()))
print("Total parameters:", total_params)


Total parameters: 2986


In [51]:
# Measure accuracy on the validation loader with the model in eval mode and
# gradient tracking switched off.
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for images, labels in val_data:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)  # predicted class = highest logit
        correct += (preds == labels).sum().item()
        total += len(labels)

test_acc = 100.0 * correct / total
print(f"\n==> Val Accuracy: {test_acc:.2f}%")



==> Val Accuracy: 89.39%
