In [19]:
import torch
import torch.nn as nn
from torch.optim import Adam


In [20]:
class PatchEmbedding(nn.Module):
  """Cut an image into non-overlapping patches and embed each patch as a vector.

  Patch extraction and the linear projection happen in a single strided
  convolution whose kernel size and stride both equal ``patchSize``.

  Args:
    imgSize: Side length used to compute the patch count
      (presumably the input is a square imgSize x imgSize image -- confirm with callers).
    patchSize: Side length of one square patch.
    numOfChannels: Number of channels of the input image.
    embeddingDimension: Length of the vector produced for every patch.
      NOTE(review): the default is 769, which looks like a typo for 768;
      left unchanged so existing callers keep the same behaviour.
  """
  def __init__(self,imgSize,patchSize,numOfChannels=1,embeddingDimension=769):
    super().__init__()
    self.imgSize = imgSize
    self.patchSize = patchSize
    # Patches along one side, squared, gives the total patch count.
    patchesPerSide = imgSize // patchSize
    self.numOfPatches = patchesPerSide * patchesPerSide
    self.projection = nn.Conv2d(
        in_channels=numOfChannels,
        out_channels=embeddingDimension,
        kernel_size=patchSize,
        stride=patchSize,
    )

  def forward(self,input):
    """Embed a batch of images.

    Args:
      input: Image batch fed to the convolution, i.e.
        (number of samples, number of channels, height, width).

    Returns:
      Tensor of shape (number of samples, number of patches, embedding dimension).
    """
    # (number of samples, embedding dimension, patches per side, patches per side)
    patchGrid = self.projection(input)
    # Collapse the two spatial axes: (number of samples, embedding dimension, number of patches)
    patchSequence = patchGrid.flatten(start_dim=2)
    # Put the patch axis before the embedding axis.
    return patchSequence.transpose(1, 2)

In [21]:
class AttentionModule(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    dimensions: Size of the last axis of the input tokens (and of the output).
    numOfHeads: Number of attention heads; must divide ``dimensions`` evenly.
    queryKeyValueBias: Whether the joint query/key/value projection has a bias.
    kvpDropoutProbability: Dropout probability applied to the attention weights.
    projectionDropoutProbability: Dropout probability applied after the output projection.

  Raises:
    ValueError: If ``dimensions`` is not divisible by ``numOfHeads``.
  """
  def __init__(self,dimensions,numOfHeads=12,queryKeyValueBias=True,kvpDropoutProbability=0.,projectionDropoutProbability=0.):
    super().__init__()
    # Fail early: otherwise the reshape in forward() cannot split the
    # projected features evenly across the heads.
    if dimensions % numOfHeads != 0:
      raise ValueError(
          f"dimensions ({dimensions}) must be divisible by numOfHeads ({numOfHeads})"
      )
    self.numOfHeads = numOfHeads
    self.dimensions = dimensions
    self.headDimension = dimensions//numOfHeads
    # Scaling keeps the dot products small: extremely large values fed to
    # softmax lead to small gradients.
    self.normalizationFactor = self.headDimension** -0.5

    # One linear layer produces queries, keys and values at once.
    self.queryKeyValue = nn.Linear(dimensions,dimensions*3,bias=queryKeyValueBias)
    self.kvpDropout = nn.Dropout(kvpDropoutProbability)
    self.projection = nn.Linear(dimensions,dimensions)
    self.projectionDropout = nn.Dropout(projectionDropoutProbability)

  def forward(self,input):
    """Apply self-attention over the token axis.

    Args:
      input: Tensor of shape (number of samples, number of tokens, dimensions).

    Returns:
      Tensor with the same shape as ``input``.

    Raises:
      ValueError: If the last axis of ``input`` differs from ``self.dimensions``.
    """
    numOfSamples,numOfTokens,dimensions = input.shape
    if dimensions != self.dimensions:
      raise ValueError("Dimensions shape in Attention Module")
    queryKeysValues = self.queryKeyValue(input) #(number of samples, number of tokens, dimensions * 3)
    # Fixed: the attribute is `headDimension` (the original read the
    # non-existent `headDimensions`).
    queryKeysValues = queryKeysValues.reshape(numOfSamples,numOfTokens,3,self.numOfHeads,self.headDimension)
    queryKeysValues = queryKeysValues.permute(2,0,3,1,4) #(3, number of samples, number of heads, number of tokens, head dimension)
    query,key,value = queryKeysValues[0],queryKeysValues[1],queryKeysValues[2]
    # Fixed: the tensor method is lowercase `transpose`.
    keyTranspose = key.transpose(-2,-1) #(number of samples, number of heads, head dimension, number of tokens)
    dotProduct = (query @ keyTranspose) * self.normalizationFactor #(number of samples, number of heads, number of tokens, number of tokens)
    attention = dotProduct.softmax(dim=-1) # each row sums to 1 over the key axis
    attention = self.kvpDropout(attention)

    weightedAverage = attention @ value #(number of samples, number of heads, number of tokens, head dimension)
    weightedAverage = weightedAverage.transpose(1,2) #(number of samples, number of tokens, number of heads, head dimension)
    weightedAverage = weightedAverage.flatten(2) # concatenate the heads: (number of samples, number of tokens, dimensions)
    input = self.projection(weightedAverage)
    input = self.projectionDropout(input)
    return input
     

In [22]:
class MLP(nn.Module):
  """Two-layer feed-forward network with a GELU in between.

  Args:
    inputFeatures: Size of the last axis of the input.
    hiddenFeatures: Width of the hidden layer.
    outputFeatures: Size of the last axis of the output.
    prob: Dropout probability, applied after the activation and after the second layer.
  """
  def __init__(self,inputFeatures,hiddenFeatures,outputFeatures,prob=0.):
    super().__init__()
    self.fullyConnected1 = nn.Linear(inputFeatures,hiddenFeatures)
    self.activateion = nn.GELU()
    # Fixed: this layer was registered as `fulltConnected2` while forward()
    # reads `fullyConnected2`, so every forward pass raised AttributeError.
    self.fullyConnected2 = nn.Linear(hiddenFeatures,outputFeatures)
    self.dropout = nn.Dropout(prob)

  def forward(self,input):
    """Run the network on ``input``.

    Args:
      input: Tensor whose last axis has size ``inputFeatures``.

    Returns:
      Tensor with the same leading axes and a last axis of size ``outputFeatures``.
    """
    input = self.fullyConnected1(input)
    input = self.activateion(input)
    input = self.dropout(input)
    input = self.fullyConnected2(input)
    input = self.dropout(input)
    return input

In [23]:
class TransformerBlock(nn.Module):
  """One pre-norm transformer encoder block: attention, then an MLP, each with a residual.

  Args:
    dimensions: Size of the token embedding.
    numOfHeads: Number of attention heads.
    MLPRatio: Hidden width of the MLP as a multiple of ``dimensions``.
    QKVBias: Whether the query/key/value projection has a bias.
    p: Dropout probability applied after the attention output projection.
    attentionProb: Dropout probability applied to the attention weights.
  """
  def __init__(self,dimensions,numOfHeads,MLPRatio,QKVBias=True,p=0.,attentionProb=0.):
    super().__init__()
    mlpHiddenWidth = int(dimensions * MLPRatio)

    self.normalization1 = nn.LayerNorm(dimensions, eps=1e-6)
    self.attention = AttentionModule(
        dimensions,
        numOfHeads=numOfHeads,
        queryKeyValueBias=QKVBias,
        kvpDropoutProbability=attentionProb,
        projectionDropoutProbability=p,
    )
    self.normalization2 = nn.LayerNorm(dimensions, eps=1e-6)
    # NOTE(review): `p` is not passed on, so the MLP runs with its default
    # dropout probability -- confirm whether that is intended.
    self.mlp = MLP(dimensions, mlpHiddenWidth, dimensions)

  def forward(self,input):
    """Apply the block to ``input`` and return a tensor of the same shape."""
    # First residual branch: normalise, attend, add back.
    attended = self.attention(self.normalization1(input))
    input = input + attended
    # Second residual branch: normalise, feed-forward, add back.
    transformed = self.mlp(self.normalization2(input))
    return input + transformed
     

In [24]:
class VisionTransformer(nn.Module):
  """Vision Transformer classifier.

  Images are split into patches and embedded, a learnable class token is
  prepended, learnable positional embeddings are added, the sequence runs
  through ``depth`` transformer blocks, and the final class token is mapped
  to class logits.

  Args:
    imgSize: Side length of the input image.
    patchSize: Side length of one patch.
    numOfChannels: Number of input channels.
    numOfClasses: Number of output classes.
    embeddingDim: Size of the token embedding.
    depth: Number of transformer blocks.
    numOfHeads: Number of attention heads per block.
    MLPRatio: Hidden width of each block's MLP as a multiple of ``embeddingDim``.
    QKVBias: Whether the query/key/value projections have a bias.
    p: Dropout probability used after adding positional embeddings and inside the blocks.
    attentionProb: Dropout probability applied to the attention weights.
  """
  def __init__(self,imgSize,patchSize,numOfChannels,numOfClasses,embeddingDim,depth, numOfHeads,MLPRatio,QKVBias,p=0.,attentionProb=0.):
    super(VisionTransformer,self).__init__()
    self.patchEmbedding = PatchEmbedding(imgSize,patchSize,numOfChannels,embeddingDim)
    # Learnable token whose final state is used for classification.
    self.classToken = nn.Parameter(torch.zeros(1,1,embeddingDim))
    # One positional vector per patch, plus one for the class token.
    self.positionalEmbedding = nn.Parameter(torch.zeros(1,1+self.patchEmbedding.numOfPatches,embeddingDim))
    self.positionDropout = nn.Dropout(p)
    self.blocks = nn.ModuleList([TransformerBlock(embeddingDim,numOfHeads,MLPRatio,QKVBias,p,attentionProb) for _ in range(depth)])
    self.normalization = nn.LayerNorm(embeddingDim,eps=1e-6)
    self.head = nn.Linear(embeddingDim,numOfClasses)

  # Fixed: this method was spelled `foward`, so nn.Module never found a
  # forward() implementation and calling the model raised an error.
  def forward(self,input):
    """Classify a batch of images.

    Args:
      input: Image batch of shape (number of samples, number of channels, height, width).

    Returns:
      Logits of shape (number of samples, number of classes).
    """
    numOfSamples = input.shape[0]
    input = self.patchEmbedding(input) #(number of samples, number of patches, embedding dimension)
    # Broadcast the single class token across the batch without copying it.
    classTokens = self.classToken.expand(numOfSamples,-1,-1)
    input = torch.cat((classTokens,input),dim=1) #(number of samples, 1 + number of patches, embedding dimension)
    input = input + self.positionalEmbedding
    input = self.positionDropout(input)
    for block in self.blocks:
      input = block(input)
    input = self.normalization(input)
    # Only the class token (position 0) feeds the classification head.
    finalClassTokens = input[:,0]
    input = self.head(finalClassTokens)
    return input

  # Backward-compatible alias for the original, misspelled method name.
  foward = forward

In [25]:
# Build the model; keyword arguments make each hyper-parameter explicit.
visionTransformer = VisionTransformer(
    imgSize=200,
    patchSize=20,
    numOfChannels=1,
    numOfClasses=5,
    embeddingDim=768,
    depth=12,
    numOfHeads=8,
    MLPRatio=0.4,  # NOTE(review): gives an MLP hidden width of int(768 * 0.4) = 307 -- confirm 4.0 was not intended
    QKVBias=True,
    p=0.3,
    attentionProb=0.2,
)

In [None]:
# Training objective for the classifier.
lossFunction = nn.CrossEntropyLoss()
# Adam over every model parameter, with weight decay.
optimizer = Adam(
    params=visionTransformer.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
