<a href="https://colab.research.google.com/github/NilakshanKunananthaseelan/research/blob/main/VisionTransformers/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
# Install einops; the notebook imports `rearrange` and `repeat` from it.
pip install einops

In [None]:

import torch
import torch.nn as nn
import torch.nn. functional as F

import PIL 
import torchvision

from einops import rearrange, repeat



class LayerNorm(nn.Module):
    """Apply layer normalization to the input, then call a wrapped callable.

    Parameters
    ----------
    normalize_dim : int or sequence of int
        Passed straight to ``nn.LayerNorm`` as the normalized shape.
    function : callable
        Called as ``function(normalized_x, **kwargs)``; its result is the
        output of this module.

    Attributes
    ----------
    layer_norm : nn.LayerNorm
        The normalization applied before ``function``.
    function : callable
        The wrapped callable.
    """

    def __init__(self, normalize_dim, function):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalize_dim)
        self.function = function

    def forward(self, x, **kwargs):
        """Return ``function(layer_norm(x), **kwargs)``."""
        normalized = self.layer_norm(x)
        return self.function(normalized, **kwargs)



class ResidualConnect(nn.Module):
    """Wrap a callable with a skip connection: ``function(x) + x``.

    Parameters
    ----------
    function : callable
        Called as ``function(x, **kwargs)``; its result must be addable to
        ``x``.
    """

    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x, **kwargs):
        """Return the wrapped callable's output added to the input."""
        branch = self.function(x, **kwargs)
        return branch + x



In [None]:
import torch 
import torch.nn as nn 


class MLP_Block(nn.Module):
    """Feed-forward block: (Linear -> GELU -> Dropout) applied twice.

    The first linear layer maps ``input_dim -> hidden_dim`` and the second
    maps back ``hidden_dim -> input_dim``, so the output has the same last
    dimension as the input.

    Parameters
    ----------
    input_dim : int
        Size of the last dimension of the input and of the output.
    hidden_dim : int
        Size of the intermediate representation.
    dropout : float, default 0.1
        Drop probability used by both dropout layers.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.gelu1 = nn.GELU()
        self.dropou1 = nn.Dropout(dropout)

        self.fc2 = nn.Linear(hidden_dim, input_dim, bias=True)
        self.gelu2 = nn.GELU()
        self.dropou2 = nn.Dropout(dropout)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and a constant 0.01 bias for both layers."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0.01)

    def forward(self, x):
        """Apply both Linear/GELU/Dropout stages and return the result."""
        x = self.fc1(x)
        x = self.gelu1(x)
        x = self.dropou1(x)

        x = self.fc2(x)
        # Use the second stage's own activation (previously gelu1 was reused
        # and gelu2 was never called; GELU is stateless so results match).
        # NOTE(review): an activation after the output projection is unusual
        # for a transformer MLP; kept to preserve behaviour -- confirm intent.
        x = self.gelu2(x)
        x = self.dropou2(x)

        return x


In [None]:

#https://jalammar.github.io/illustrated-transformer/

import torch
import torch.nn as nn
import torch.nn. functional as F

import PIL 
import torchvision



class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values, which are split
    into ``heads`` heads of size ``input_dim // heads``. Attention logits are
    ``q . k / sqrt(head_dim)``; values are left unscaled. The per-head
    results are concatenated, projected by ``fc1`` and passed through
    dropout.

    Parameters
    ----------
    input_dim : int
        Size of the last dimension of the input; must be divisible by
        ``heads``.
    heads : int, default 8
        Number of attention heads.
    dropout : float, default 0.1
        Drop probability applied after the output projection.

    Raises
    ------
    ValueError
        If ``input_dim`` is not divisible by ``heads``.
    """

    def __init__(self, input_dim, heads=8, dropout=0.1):
        super().__init__()

        if input_dim % heads != 0:
            raise ValueError(
                f'input_dim ({input_dim}) must be divisible by heads ({heads})'
            )

        self.heads = heads
        self.head_dim = input_dim // heads
        # Divisor for the attention logits: sqrt of the per-head width.
        self.scale = self.head_dim ** 0.5

        self.mat_calc = nn.Linear(input_dim, input_dim * 3)
        self.softmax = nn.Softmax(dim=-1)

        self.fc1 = nn.Linear(input_dim, input_dim, bias=True)
        self.dropou1 = nn.Dropout(dropout)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and a constant 0.01 bias for both layers."""
        for layer in (self.mat_calc, self.fc1):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0.01)

    def forward(self, x, mask=None):
        """Attend over the token axis of ``x``.

        Parameters
        ----------
        x : Tensor of shape (batch, tokens, input_dim)
        mask : Tensor, optional
            Flattened per sample to ``tokens - 1`` entries; truthy entries
            mark tokens that may be attended to. An always-visible entry is
            prepended for the first token.

        Returns
        -------
        Tensor of shape (batch, tokens, input_dim)
        """
        b, n, _ = x.shape
        h, d = self.heads, self.head_dim

        # (b, n, 3*h*d) -> (3, b, h, n, d). The last axis is laid out as
        # (qkv, head, dim), i.e. the einops pattern 'b n (w h d) -> w b h n d'.
        qkv = self.mat_calc(x).reshape(b, n, 3, h, d).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scale only the logits; scaling v as well would shrink the output.
        dot_product = torch.einsum('bhid,bhjd->bhij', q, k) / self.scale

        if mask is not None:
            # Prepend an always-visible slot for the first token.
            mask = F.pad(mask.flatten(1).bool(), (1, 0), value=True)
            assert mask.shape[-1] == dot_product.shape[-1], 'mask has incorrect dimensions'
            # Mask keys only, broadcast over heads and queries: (b, 1, 1, n).
            # Every query row keeps at least the first token, so softmax never
            # sees an all -inf row (which would yield NaN). Outputs at masked
            # query positions are finite but carry no meaning.
            key_mask = mask[:, None, None, :]
            dot_product = dot_product.masked_fill(~key_mask, float('-inf'))

        softmax_out = self.softmax(dot_product)

        attention_matrices = torch.einsum('bhij,bhjd->bhid', softmax_out, v)

        # (b, h, n, d) -> (b, n, h*d)
        out = attention_matrices.transpose(1, 2).reshape(b, n, h * d)
        out = self.fc1(out)
        out = self.dropou1(out)
        return out

class EncoderLayers(nn.Module):
  """One encoder layer: a residual attention block, then a residual MLP block.

  Each sub-block is wrapped as ``ResidualConnect(LayerNorm(input_dim, ...))``,
  so the input is normalized before the sub-block and added back after it.

  Parameters
  ----------
  input_dim : int
      Size of the last dimension of the input.
  heads : int
      Number of attention heads.
  mlp_dim : int
      Hidden size of the MLP block.
  dropout : float, default 0.1
      Drop probability forwarded to both sub-blocks.
  """

  def __init__(self,input_dim,heads,mlp_dim,dropout=0.1):
    super().__init__()

    self.attention_block = ResidualConnect(LayerNorm(input_dim,Attention(input_dim,heads=heads,dropout=dropout)))
    self.mlp_block = ResidualConnect(LayerNorm(input_dim,MLP_Block(input_dim,mlp_dim,dropout=dropout)))

  def forward(self,x,mask=None):
    """Run attention (with the optional mask) followed by the MLP block."""
    x = self.attention_block(x,mask=mask)
    x = self.mlp_block(x)

    return x


class TransformerEncoder(nn.Module):
  """A stack of ``depth`` encoder layers applied one after another.

  Each layer is stored as a two-element ``nn.ModuleList``: a residual,
  layer-normalized attention block followed by a residual, layer-normalized
  MLP block.

  Parameters
  ----------
  input_dim : int
      Size of the last dimension of the input.
  depth : int
      Number of encoder layers.
  heads : int
      Number of attention heads per layer.
  mlp_dim : int
      Hidden size of each MLP block.
  dropout : float, default 0.1
      Drop probability forwarded to every sub-block.
  """

  def __init__(self,input_dim,depth,heads,mlp_dim,dropout=0.1):
    super().__init__()

    def build_layer():
      # Attention first, then MLP, so parameters are created in the same order.
      attention = ResidualConnect(LayerNorm(input_dim,Attention(input_dim,heads=heads,dropout=dropout)))
      feed_forward = ResidualConnect(LayerNorm(input_dim,MLP_Block(input_dim,mlp_dim,dropout=dropout)))
      return nn.ModuleList([attention,feed_forward])

    self.encoder_layers = nn.ModuleList([build_layer() for _ in range(depth)])

  def forward(self,x,mask=None):
    """Pass ``x`` through every layer; only attention receives ``mask``."""
    for layer in self.encoder_layers:
      attention_block,mlp_block = layer
      x = mlp_block(attention_block(x,mask=mask))
    return x





In [None]:
# Smoke test: a one-layer encoder (width 64, 8 heads, MLP hidden size 128).
t = TransformerEncoder(64,1,8,128)
# torch.randint(0, 1, ...) has an exclusive upper bound and always produced
# zeros; draw uniform values in [0, 1) so the forward pass sees real input.
x = torch.rand(1,64,64)
def xsqr(x):
  """Return the element-wise square of ``x``."""
  return x*x

# Exercise the wrappers with a plain function instead of a module.
r = ResidualConnect(xsqr)
l = LayerNorm(1,xsqr)
l(torch.tensor(5).reshape(-1,1
                          ).float())

t(x)

In [None]:
  import torch
  import torch.nn as nn

  from einops import rearrange, repeat
 

  class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a class token is prepended, position
    embeddings are added, and the sequence is run through a
    ``TransformerEncoder``. The class-token output is fed to a linear head.

    Parameters (keyword-only)
    -------------------------
    image_size : int
        Side length of the (square) input image; must be a multiple of
        ``patch_size``.
    patch_size : int
        Side length of each patch.
    num_classes : int
        Number of output logits.
    input_dim : int
        Embedding width of each token.
    depth, heads, mlp_dim :
        Forwarded to ``TransformerEncoder``.
    channels : int, default 3
        Number of input image channels.
    dropout : float, default 0.1
        Drop probability inside the encoder.
    embedding_dropout : float, default 0.1
        Drop probability applied to the embedded sequence.
    """

    def __init__(self,*,image_size,patch_size,num_classes,input_dim,depth,heads,mlp_dim,channels=3,dropout=0.1,embedding_dropout=0.1):
      super().__init__()

      assert not(image_size%patch_size),'image_size must be a multiple of patch_size'

      num_patches = (image_size//patch_size)**2

      self.patch_size = patch_size
      # One position embedding per patch plus one for the class token.
      self.pos_embeds =  nn.Parameter(torch.empty(1, (num_patches + 1), input_dim))

      # Kernel == stride == patch_size, so each output location embeds one
      # patch. Honour `channels` instead of assuming RGB.
      self.patch_conv = nn.Conv2d(channels,input_dim,kernel_size=patch_size,stride=patch_size)

      self.cls_embeds = nn.Parameter(torch.zeros(1,1,input_dim))
      self.dropout = nn.Dropout(embedding_dropout)

      self.transformer_encoder = TransformerEncoder(input_dim,depth,heads,mlp_dim,dropout)

      self.to_cls_embeds =  nn.Identity()

      self.mlp_head = nn.Linear(input_dim,num_classes)

      self.init_weights()

    def init_weights(self):
      """Initialise the position embeddings and the classification head."""
      torch.nn.init.normal_(self.pos_embeds, std = .02) # initialized based on the paper

      torch.nn.init.xavier_uniform_(self.mlp_head.weight)
      torch.nn.init.normal_(self.mlp_head.bias, std = 1e-6)

    def forward(self,img,mask=None):
      """Return class logits of shape (batch, num_classes).

      ``img`` is a (batch, channels, height, width) tensor. ``mask`` is
      passed unchanged to the encoder.
      """
      img_patches = self.patch_conv(img)

      # Flatten the patch grid into a token sequence.
      x = rearrange(img_patches,'b c h w -> b (h w) c')
      cls_embeds = self.cls_embeds.expand(img.shape[0],-1,-1)

      x = torch.cat((cls_embeds,x),dim=1)
      x = x + self.pos_embeds
      x = self.dropout(x)
      x = self.transformer_encoder(x,mask)
      # Classify from the class token (sequence position 0).
      x = self.to_cls_embeds(x[:,0])

      x = self.mlp_head(x)

      return x







In [None]:
BATCH_SIZE_TRAIN = 500
BATCH_SIZE_TEST = 100

# Raw string so the backslashes are not parsed as (invalid) escape sequences.
DL_PATH = r"C:\Pytorch\Spyder\CIFAR10_data" # Use your own path
# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class

# NOTE(review): these look like the commonly used ImageNet channel statistics
# rather than values computed on CIFAR-10 -- confirm this is intended.
normalize = torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training pipeline: random augmentation, then tensor conversion + normalization.
# `interpolation=` replaces the `resample=` keyword of older torchvision releases.
transform = torchvision.transforms.Compose(
     [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.RandomRotation(
         10, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
     torchvision.transforms.RandomAffine(8, translate=(.15,.15)),
     torchvision.transforms.ToTensor(),
     normalize])

# Evaluation pipeline: no random augmentation, so test metrics are repeatable.
test_transform = torchvision.transforms.Compose(
     [torchvision.transforms.ToTensor(),
     normalize])


train_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=True,
                                        download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=False,
                                       download=True, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN,
                                          shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST,
                                         shuffle=False)


In [None]:
import time
def train(model, optimizer, data_loader, loss_history):
    """Train ``model`` for one pass over ``data_loader``.

    Batches are moved to the device the model's parameters live on, so the
    function works on CPU as well as GPU.

    Parameters
    ----------
    model : nn.Module
        Model returning unnormalized class scores of shape (batch, classes).
    optimizer : torch.optim.Optimizer
        Optimizer stepping ``model``'s parameters.
    data_loader : DataLoader
        Yields ``(data, target)`` batches; ``data_loader.dataset`` must
        support ``len``.
    loss_history : list
        The batch loss is appended (and printed) every 100th batch, starting
        with the first.
    """
    device = next(model.parameters()).device
    total_samples = len(data_loader.dataset)
    num_batches = len(data_loader)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        # cross_entropy == nll_loss(log_softmax(...)) in one call.
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            batch_loss = loss.item()
            print(f'[{i * len(data):5}/{total_samples:5}'
                  f' ({100 * i / num_batches:3.0f}%)]  Loss: {batch_loss:6.4f}')
            loss_history.append(batch_loss)
            
def evaluate(model, data_loader, loss_history):
    """Evaluate ``model`` on ``data_loader`` and print loss and accuracy.

    Batches are moved to the device the model's parameters live on. The
    model is put in eval mode and gradients are disabled.

    Parameters
    ----------
    model : nn.Module
        Model returning unnormalized class scores of shape (batch, classes).
    data_loader : DataLoader
        Yields ``(data, target)`` batches; ``data_loader.dataset`` must
        support ``len``.
    loss_history : list
        The average per-sample loss is appended once per call.
    """
    device = next(model.parameters()).device
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Summed (not averaged) so unequal batch sizes weigh correctly.
            loss = F.cross_entropy(output, target, reduction='sum')
            pred = output.argmax(dim=1)

            total_loss += loss.item()
            # .item() keeps the running count a plain int, not a device tensor.
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    accuracy = 100.0 * correct_samples / total_samples
    print(f'\nAverage test loss: {avg_loss:.4f}'
          f'  Accuracy:{correct_samples:5}/{total_samples:5}'
          f' ({accuracy:4.2f}%)\n')

In [None]:
N_EPOCHS = 150

# 32x32 images cut into 4x4 patches.
model = ViT(image_size=32, patch_size=4, num_classes=10, channels=3,
            input_dim=64, depth=6, heads=8, mlp_dim=128)
model = model.to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)


train_loss_history, test_loss_history = [], []
total_start_time = time.time()
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    start_time = time.time()
    train(model, optimizer, train_loader, train_loss_history)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    evaluate(model, test_loader, test_loss_history)

# Report the wall-clock time of the whole run (previously only a bare label).
print('Total execution time:', '{:5.2f}'.format(time.time() - total_start_time), 'seconds')

# Only the weights are saved; reloading needs a ViT built with the same arguments.
PATH = "ViTnet_Cifar10_4x4_aug_1.pt" # Use your own path
torch.save(model.state_dict(), PATH)