In [1]:
# importing necessary libraries

import torch
import torchvision
import torch.utils.data as dataloader
import torch.nn as nn

In [2]:
from torchvision import datasets, transforms

# Load the MNIST dataset
# The only preprocessing is ToTensor (image -> tensor).
transform = transforms.Compose([transforms.ToTensor()])

# Download (if not already cached under ./data) the official train / test splits.
mnist_train_full = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Validation data is carved out of the *training* split so that it never
# overlaps mnist_test; validating on the test split would make the final test
# score an optimistic estimate. The seeded generator keeps the split identical
# across kernel restarts.
val_size = 5000
mnist_train, mnist_val = dataloader.random_split(
    mnist_train_full,
    [len(mnist_train_full) - val_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 62.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.85MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.24MB/s]


In [3]:
# Wrap the datasets in mini-batch loaders.
# Training batches are reshuffled on every pass over the data. Validation keeps
# a fixed order: shuffling contributes nothing to evaluation and makes
# per-batch results impossible to compare between runs.
train_loader = dataloader.DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = dataloader.DataLoader(mnist_val, batch_size=64, shuffle=False)

In [4]:
# Hyperparameters and model configuration

# -- data -----------------------------------------------------------------
image_size = 28       # images are image_size x image_size pixels
num_channels = 1      # channels per input image
num_classes = 10      # width of the classifier output
batch_size = 64

# -- patching -------------------------------------------------------------
patch_size = 7        # patches are patch_size x patch_size pixels
num_patches = (image_size // patch_size) ** 2   # patches per image

# -- transformer ----------------------------------------------------------
embedding_dim = 64
attention_heads = 2
transformer_blocks = 4
mlp_hidden_nodes = 128

# -- optimisation ---------------------------------------------------------
learning_rate = 0.001
epochs = 5



In [5]:
# Pull a single mini-batch from the training loader and report the shapes of
# its image and label tensors.
images, labels = sample_data = next(iter(train_loader))
print(f'Image batch shape: {images.size()}')
print(f'Label batch shape: {labels.size()}')


Image batch shape: torch.Size([64, 1, 28, 28])
Label batch shape: torch.Size([64])


In [6]:
# Shape walk-through of the patch-embedding step on the sampled batch.
# Kernel size and stride both equal patch_size, so the convolution visits each
# non-overlapping patch exactly once and emits embedding_dim channels for it.
patch_embed = nn.Conv2d(
    in_channels=num_channels,
    out_channels=embedding_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
print(patch_embed(images).shape)

# Merge the two spatial axes (dims 2 and 3) into a single patch axis.
patch_embed_output_flatten = patch_embed(images).flatten(start_dim=2)
print(patch_embed_output_flatten.shape)

# Swap the last two axes so the patch axis comes before the embedding axis.
print(patch_embed_output_flatten.transpose(1, 2).shape)  # Expected output: (batch_size, num_patches, embedding_dim)

torch.Size([64, 64, 4, 4])
torch.Size([64, 64, 16])
torch.Size([64, 16, 64])


In [7]:
### Part 1: patch embedding
### Part 2: transformer encoder
### Part 3: MLP head
### Part 4: Vision Transformer class (combines parts 1-3)

### part 1 : patch embedding

In [8]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every non-overlapping patch to one ``embedding_dim``-channel vector. Layer
    sizes are read from the module-level ``num_channels``, ``embedding_dim``
    and ``patch_size`` when the module is instantiated.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels=num_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch embeddings shaped (batch_size, num_patches, embedding_dim)."""
        patch_grid = self.patch_embed(x)          # one embedding per patch position
        patch_sequence = patch_grid.flatten(2)    # merge the spatial axes into one
        return patch_sequence.transpose(1, 2)     # patch axis before embedding axis
    
        
    

### part 2 : Encoder

In [9]:
class TransformerEncoder(nn.Module):
    """One pre-norm transformer encoder block: self-attention, then an MLP.

    Each sub-layer is preceded by LayerNorm and wrapped in a residual
    connection. Layer sizes are read from the module-level ``embedding_dim``,
    ``attention_heads`` and ``mlp_hidden_nodes`` when the module is
    instantiated.

    Input and output both have shape
    ``(batch_size, sequence_length, embedding_dim)``.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)
        # batch_first=True is required here: by default nn.MultiheadAttention
        # treats dim 0 as the sequence axis, which for batch-leading input
        # would make the images of a batch attend to one another instead of
        # the patches of each image attending to one another.
        self.multihead_attention = nn.MultiheadAttention(
            embedding_dim,
            attention_heads,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_nodes),
            nn.GELU(),
            nn.Linear(mlp_hidden_nodes, embedding_dim),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Args:
            x: tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention sub-layer: query, key and value are all the normalised input.
        normed = self.layer_norm1(x)
        attended = self.multihead_attention(normed, normed, normed)[0]  # [0]: output, [1]: attention weights
        x = attended + x

        # Position-wise MLP sub-layer.
        x = self.mlp(self.layer_norm2(x)) + x
        return x
        

### part 3 : MLP Head

In [10]:
class MLP_head(nn.Module):
    """Classification head: LayerNorm followed by a single linear layer.

    The linear layer maps ``embedding_dim`` features to ``num_classes``
    outputs; both sizes are read from module-level variables when the module
    is instantiated.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Normalise ``x`` over its last dimension, then project it to ``num_classes`` outputs."""
        return self.mlp_head(self.layer_norm1(x))

In [None]:
class VisionTransformer(nn.Module):
    def __init__ (self):
        """Create the sub-modules and learnable parameters of the model.

        NOTE(review): no forward() is visible in this part of the notebook, so
        how these pieces are wired together cannot be confirmed from here.
        """
        super().__init__()
        # Converts an image batch into a sequence of patch embeddings.
        self.patch_embedding = PatchEmbedding()
        # Learnable tensor of shape (1, 1, embedding_dim), randomly initialised.
        # Presumably the class token that gets prepended to the patch sequence — TODO confirm.
        # NOTE(review): the attribute is spelled `clas_token` (single "s"); any
        # code that reads it must use the same spelling.
        self.clas_token = nn.Parameter(torch.randn(1,1,embedding_dim))
        # Learnable tensor of shape (1, num_patches + 1, embedding_dim); the extra
        # position presumably belongs to the class token — TODO confirm.
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1 , embedding_dim))
        # `transformer_blocks` (module-level int) TransformerEncoder instances chained in sequence.
        self.transformer_blocks = nn.Sequential(*[TransformerEncoder() for _ in range(transformer_blocks)])
        # LayerNorm + linear classifier producing num_classes outputs.
        self.mlp_head = MLP_head()
        
  
    