In [41]:
import torch 
import torch.nn as nn  
import matplotlib.pyplot as plt 
import torchvision
import numpy as np 
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [42]:
learning_rate = 0.001
weight_decay = 0.0001
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
batch_size = 64

In [43]:
num_classes = 100
input_shape = (32,32,3)
import torchvision.transforms as transforms
#dataset download 
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 
                transforms.Resize(image_size)])

trainset = torchvision.datasets.CIFAR100(root='./data.cifar100', train=True,
                                    download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data.cifar100', train=False,
                                    download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [44]:
train_dataloader = torch.utils.data.DataLoader(trainset,batch_size=64)
test_dataloader = torch.utils.data.DataLoader(testset,batch_size=64)

len_train, len_test = len(train_dataloader), len(test_dataloader)

data_iter = iter(train_dataloader)
img, label = next(data_iter)

In [45]:
"""


class Transformer(nn.Module):  
    def __init__(self,dim,depth,heads,head_dim,mlp_dim,drop_prob=0.0):   
        super(Attention,self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm_Residual(dim,nn.MultiheadAttention(dim,heads,drop_prob)), 
                PreNorm_Residual(dim,FeedForward(dim,mlp_dim,dropout=drop_prob))
            ]))
    
    def forward(self,x):  
        for attention, 

class PreNorm_Residual(nn.Module):
    def __init__(self,layer,dim):
        super(PreNorm,self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.layer = layer

    def forward(self,x):  
        return self.layer(self.norm(x)) + x

class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,drop_prob=0.0):
        super(FeedForward,self).__init__()
        self.fc1 = nn.Linear(dim,hidden_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim,dim)
        self.drop = nn.Dropout(drop_prob)

    def forward(self,x): 
        x = self.drop(self.gelu(self.fc1(x))) 
        x = self.drop(self.gelu(self.fc2(x))) 
        return x 


class VIT(nn.Module):
   def __init__(self,transformer,image_size=image_size,patch_size=patch_size,num_classes=num_classes,projection_dim=projection_dim,num_patches=num_patches):
      super(VIT,self).__init__()
      assert image_size % patch_size == 0, "image size must be dividible by patch size" 
      self.num_patches =  num_patches
      self.patch_dim = 3 * patch_size**2
   
      self.flatten_patches = Rearrange('b c (h px1) (w px2) -> b (h w) (px1 px2 c)', p1 = patch_size, p2 = patch_size)
      self.patch_emedding = nn.Linear(self.patch_dim,projection_dim)

      self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, projection_dim))   #pose embedding 
      self.class_token = nn.Parameter(torch.randn(1, 1, projection_dim))   #class embedding 
      #output
      self.mlp_head = nn.Sequential(
               nn.LayerNorm(projection_dim),
               nn.Linear(projection_dim,num_classes)
      )

      self.transformer = transformer
      self.to_latent = nn.Identity()  
   
   def forward(self,img):
      #"pre process"
      x = self.flatten_patches(x)
      x = self.patch_emedding(x)

      b, n, _ = x.shape 
      class_token = repeat(self.class_token,'() n d -> b n d', b = b)
      x = torch.cat((class_token,x),dim=1)
      x += self.position_embedding[:,:(n+1)] 
      x = self.transformer(x)
      x = x[:,0]
      x = self.to_latent(x)
      x = self.mlp_head(x) 
      return x 
"""

In [177]:
class MLPS(nn.Module): 
    def __init__(self,hidden_units,drop_prob=0.0,hidden_size1=64):
        super(MLPS,self).__init__()
        self.layers = []
        hidden_units.insert(0,hidden_size1)
        for i in range(0,len(hidden_units)-1): 
            self.layers.append(nn.Linear(hidden_units[i],hidden_units[i+1]))
            self.layers.append(nn.Dropout(drop_prob))
        
        self.layers = nn.Sequential(*self.layers)
    
    def forward(self,x): 
        for layer in self.layers:
            x = layer(x)
        return x   


In [178]:
class TransformerBlock(nn.Module):
    def __init__(self,hidden_units,num_heads,projection_dim,dropout=0.0):
        super(TransformerBlock,self).__init__()
        self.hidden_units = hidden_units
        self.attention = nn.MultiheadAttention(projection_dim,num_heads,dropout=dropout)
        self.norm = nn.LayerNorm(projection_dim,eps=1e-6)
        self.mlp = MLPS(hidden_units)

    def forward(self,encoded_patches):
        x1 = encoded_patches
        #add norm layer here
        attention_out, out_weights = self.attention(x1,x1,x1)
        x2 = x1 + x1 #residual concatenation 
        #add norm layer here
        x3 = self.mlp(x2)
        encoded = x3 + x2
        return encoded
        
  

In [179]:
class VIT(nn.Module):
   def __init__(self,transformer=None,image_size=image_size,patch_size=patch_size,num_classes=num_classes,projection_dim=projection_dim,num_patches=num_patches,
   hidden_units=transformer_units, num_heads = num_heads,mlp_head_units=mlp_head_units, drop_prob=0.0):
      super(VIT,self).__init__()
      assert image_size % patch_size == 0, "image size must be dividible by patch size" 
      self.num_patches =  num_patches
      self.patch_dim = 3 * patch_size**2
   
      self.flatten_patches = Rearrange('b c (h px1) (w px2) -> b (h w) (px1 px2 c)', px1 = patch_size, px2 = patch_size)
      self.patch_emedding = nn.Linear(self.patch_dim,projection_dim)

      self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, projection_dim))   #pose embedding 
      self.class_token = nn.Parameter(torch.randn(1, 1, projection_dim))   #class embedding 
     
      self.transformer = TransformerBlock(hidden_units,num_heads,projection_dim,drop_prob)
      self.pose_embedding = nn.Embedding(num_patches+1,projection_dim)
      self.projection = nn.Linear(108,projection_dim)
      #output
      self.mlp_head = MLPS(mlp_head_units,drop_prob=0.5,hidden_size1=9216)

      self.to_latent = nn.Linear(mlp_head_units[-1],num_classes)
      self.flatten = nn.Flatten()
      self.drop = nn.Dropout(drop_prob)

   
   def forward(self,img):
      #"pre process"
      """
      x = self.flatten_patches(x)
      x = self.patch_emedding(x)
      """
      patch = self.flatten_patches(img)
      #patch encoding 
      positions = torch.arange(0,self.num_patches)
      encoded1 = self.pose_embedding(positions)
      encoded2 = self.projection(patch)
      encoded = encoded1 + encoded2
      block = self.transformer(encoded)
      block_norm = block #add norm here
      representation = self.drop(self.flatten(block_norm))
      features = self.mlp_head(representation)
      logits = self.to_latent(features)
      return logits 
   

In [181]:
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
mlp_head_units = [2048,1024] 
model = VIT(hidden_units=transformer_units)
img, label = next(data_iter)
img = img[0]
img = img.unsqueeze(0)
model(img)

4 64
layers Sequential(
  (0): Linear(in_features=64, out_features=128, bias=True)
  (1): Dropout(p=0.0, inplace=False)
  (2): Linear(in_features=128, out_features=64, bias=True)
  (3): Dropout(p=0.0, inplace=False)
)
layers Sequential(
  (0): Linear(in_features=9216, out_features=9216, bias=True)
  (1): Dropout(p=0.5, inplace=False)
  (2): Linear(in_features=9216, out_features=2048, bias=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=2048, out_features=1024, bias=True)
  (5): Dropout(p=0.5, inplace=False)
)
img torch.Size([1, 3, 72, 72])


1