In [18]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.utils.data as dataloader
import torch.nn as nn

In [19]:
# Preprocessing pipeline handed to the datasets: the only step is ToTensor
# (no normalization or augmentation is configured).
transformation_operation = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [20]:
# Build the MNIST training split first, then the held-out split used for
# validation. Both are downloaded into ./data when missing and share the same
# transform; only the `train` flag differs.
train_dataset, val_dataset = (
    torchvision.datasets.MNIST(
        root='./data',
        train=use_train_split,
        download=True,
        transform=transformation_operation,
    )
    for use_train_split in (True, False)
)

In [63]:
# ---- data ----
batch_size = 64
num_classes = 10
num_channels = 1      # since black and white image
img_size = 28

# ---- patching ----
patch_size = 7
patch_num = (img_size // patch_size) ** 2   # patches per image (grid is square)

# ---- model ----
embed_dim = 20        # size of the embedding vector
attention_heads = 4
transformer_blocks = 4
mlp_nodes = 64        # number of neurons in the mlp hidden layers

# ---- optimisation ----
learning_rate = 1e-3

In [22]:
# Convert the datasets to mini-batches.
# Training batches are reshuffled every epoch. Validation accuracy does not
# depend on sample order, so the validation loader iterates in a fixed order
# (no random permutation is drawn, and evaluation is repeatable).
train_data = dataloader.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_data = dataloader.DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

In [74]:
# Patch embedding - part 1 of the ViT
class PatchEmbedding(nn.Module):
  """Cut an image batch into non-overlapping square patches and embed each one.

  A Conv2d whose kernel size equals its stride sees every patch exactly once,
  so it acts as one learned linear projection applied per patch.

  Args:
    in_channels: channels of the input images. Defaults to the notebook-level
      `num_channels`.
    out_dim: length of each patch embedding. Defaults to the notebook-level
      `embed_dim`.
    patch: side length of a patch in pixels. Defaults to the notebook-level
      `patch_size`.
  """
  def __init__(self, in_channels=None, out_dim=None, patch=None):
    super().__init__()
    # Only arguments left unset fall back to the notebook config, so a bare
    # `PatchEmbedding()` builds exactly the layer it built before.
    in_channels = num_channels if in_channels is None else in_channels
    out_dim = embed_dim if out_dim is None else out_dim
    patch = patch_size if patch is None else patch
    self.patch_embed = nn.Conv2d(in_channels, out_dim, kernel_size=patch, stride=patch)

  def forward(self, x):
    """Map (batch, channels, H, W) images to (batch, n_patches, out_dim) tokens."""
    # conv:      (batch, out_dim, patches_per_col, patches_per_row)
    x = self.patch_embed(x)
    # flatten:   (batch, out_dim, n_patches)
    # transpose: (batch, n_patches, out_dim)
    return x.flatten(2).transpose(1, 2)


In [51]:
# Sanity check: push one real batch through a stand-alone patch-embedding conv
# and print the tensor shape after each step.
images, labels = next(iter(train_data))
print(images.shape)

patch_embed = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
with torch.no_grad():  # shapes only - no autograd graph is needed here
    embedded_image = patch_embed(images)
print("this is the size after conv2d ", embedded_image.shape)
print("hare krishna ", embedded_image.flatten(2).transpose(1, 2).shape)

torch.Size([64, 1, 28, 28])
this is the size after conv2d  torch.Size([64, 20, 4, 4])
hare krishna  torch.Size([64, 16, 20])


In [76]:
# Transformer encoder block - part 2 of the ViT.
# Layout used here (normalisation comes before each sub-layer):
#   LayerNorm -> multi-head self-attention -> add residual
#   LayerNorm -> MLP (Linear, GELU, Linear) -> add residual
class TransformerEncoder(nn.Module):
  """One encoder block: self-attention and an MLP, each with a residual connection.

  Args:
    dim: token embedding size. Defaults to the notebook-level `embed_dim`.
    heads: number of attention heads. Defaults to the notebook-level
      `attention_heads`.
    hidden_dim: width of the MLP hidden layer. Defaults to the notebook-level
      `mlp_nodes`.
  """
  def __init__(self, dim=None, heads=None, hidden_dim=None):
    super().__init__()
    # Only arguments left unset fall back to the notebook config, so a bare
    # `TransformerEncoder()` builds exactly the block it built before.
    dim = embed_dim if dim is None else dim
    heads = attention_heads if heads is None else heads
    hidden_dim = mlp_nodes if hidden_dim is None else hidden_dim

    self.layer_norm1 = nn.LayerNorm(dim)
    self.multi_head_attention = nn.MultiheadAttention(dim, heads, batch_first=True)
    self.layer_norm2 = nn.LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )

  def forward(self, x):
    """Transform (batch, tokens, dim) tokens; the output has the same shape."""
    normed = self.layer_norm1(x)
    # Query, key and value are all the same normalised tokens (self-attention).
    # The attention map is never used, so ask the layer not to return it.
    attn_out, _ = self.multi_head_attention(normed, normed, normed, need_weights=False)
    x = x + attn_out
    return x + self.mlp(self.layer_norm2(x))




In [78]:
# Classification head - part 3 of the ViT
class MLP_Head(nn.Module):
  """LayerNorm followed by a single linear layer that produces class logits.

  The caller is expected to pass only the features to classify;
  VisionTransformer selects the CLS token before calling this head.

  Args:
    dim: size of the incoming feature vector. Defaults to the notebook-level
      `embed_dim`.
    n_classes: number of output logits. Defaults to the notebook-level
      `num_classes`.
  """
  def __init__(self, dim=None, n_classes=None):
    super().__init__()
    # Only arguments left unset fall back to the notebook config, so a bare
    # `MLP_Head()` builds exactly the head it built before.
    dim = embed_dim if dim is None else dim
    n_classes = num_classes if n_classes is None else n_classes
    # Final normalisation before the classifier (the original author's note
    # says the paper does not show it, but it is done in practice).
    self.layernorm1 = nn.LayerNorm(dim)
    self.mlphead = nn.Sequential(
        nn.Linear(dim, n_classes)
    )

  def forward(self, x):
    """Return unnormalised class scores for features of shape (..., dim)."""
    return self.mlphead(self.layernorm1(x))

In [64]:
class VisionTransformer(nn.Module):
  """Full ViT classifier: patch embedding, CLS token, encoder stack, head.

  All sizes are read from the notebook-level configuration (`embed_dim`,
  `patch_num`, `transformer_blocks`).
  """
  def __init__(self):
    super().__init__()
    self.patch_embedding = PatchEmbedding()
    # One learnable CLS vector, shared by every image in a batch.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    # One learnable position vector per token: the patches plus the CLS slot.
    self.position_embedding = nn.Parameter(torch.randn(1, patch_num + 1, embed_dim))
    encoder_stack = [TransformerEncoder() for _ in range(transformer_blocks)]
    self.transformer_blocks = nn.Sequential(*encoder_stack)
    self.mlp_head = MLP_Head()

  def forward(self, x):
    """Return class logits for a batch of images."""
    patches = self.patch_embedding(x)
    # Read the batch size from the data itself: the last batch of an epoch
    # can be smaller than the configured batch_size.
    n_images = patches.size(0)
    cls = self.cls_token.expand(n_images, -1, -1)
    tokens = torch.cat([cls, patches], dim=1) + self.position_embedding
    encoded = self.transformer_blocks(tokens)
    # Only the CLS position (index 0) is passed to the classifier.
    return self.mlp_head(encoded[:, 0])

In [70]:
# Device, model, optimizer and cross-entropy loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)
# The optimizer must hold the parameters of `model` itself. Building it from a
# second, freshly constructed VisionTransformer() would update a throwaway
# network and leave `model` untrained.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [79]:
# Define hyperparameters
epochs = 5
learning_rate = 0.001

# Model, optimizer and loss function setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


def _run_epoch(model, loader, device, criterion=None, optimizer=None):
    """Make one pass over `loader` and return the accuracy in percent.

    With an `optimizer` the model is put in train mode and updated after every
    batch (a `criterion` is then required). Without one the model is put in
    eval mode and the pass runs with gradients disabled.

    Raises:
        ValueError: if an optimizer is given without a criterion.
        ZeroDivisionError: if `loader` yields no samples.
    """
    training = optimizer is not None
    if training and criterion is None:
        raise ValueError("a criterion is required when an optimizer is given")

    model.train(training)
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)                  # Forward pass
            if training:
                loss = criterion(outputs, labels)    # Loss calculation
                loss.backward()                      # Backpropagation
                optimizer.step()                     # Optimizer step

            # Accuracy is counted on the predictions made before this batch's
            # update, the same way for training and validation.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return 100 * correct / seen


# Training loop: one optimisation pass, then one validation pass, per epoch
for epoch in range(epochs):
    train_acc = _run_epoch(model, train_data, device, criterion, optimizer)
    print(f"Epoch {epoch+1}: Training Accuracy {train_acc:.2f}%")

    val_acc = _run_epoch(model, val_data, device)
    print(f"Epoch {epoch+1}: Validation Accuracy {val_acc:.2f}%")




Epoch 1: Training Accuracy 82.40%
Epoch 1: Validation Accuracy 93.38%
Epoch 2: Training Accuracy 94.16%
Epoch 2: Validation Accuracy 95.02%
Epoch 3: Training Accuracy 95.58%
Epoch 3: Validation Accuracy 95.55%
Epoch 4: Training Accuracy 96.30%
Epoch 4: Validation Accuracy 96.34%
Epoch 5: Training Accuracy 96.87%
Epoch 5: Validation Accuracy 96.37%
