In [1]:
import torch 
import torch.nn as nn  
import matplotlib.pyplot as plt 
import torchvision
import numpy as np 
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [2]:
learning_rate = 0.001
weight_decay = 0.0001
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
batch_size = 64

In [3]:
num_classes = 100
input_shape = (32,32,3)
import torchvision.transforms as transforms
#dataset download 
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 
                transforms.Resize(image_size)])

trainset = torchvision.datasets.CIFAR100(root='./data.cifar100', train=True,
                                    download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data.cifar100', train=False,
                                    download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
train_dataloader = torch.utils.data.DataLoader(trainset,batch_size=64)
test_dataloader = torch.utils.data.DataLoader(testset,batch_size=64)

len_train, len_test = len(train_dataloader), len(test_dataloader)

data_iter = iter(train_dataloader)
img, label = next(data_iter)

In [6]:
class MLPS(nn.Module): 
    def __init__(self,hidden_units,drop_prob=0.0,hidden_size1=64):
        super(MLPS,self).__init__()
        self.layers = []
        hidden_units.insert(0,hidden_size1)
        for i in range(0,len(hidden_units)-1): 
            self.layers.append(nn.Linear(hidden_units[i],hidden_units[i+1]))
            self.layers.append(nn.Dropout(drop_prob))
        
        self.layers = nn.Sequential(*self.layers)
    
    def forward(self,x): 
        for layer in self.layers:
            x = layer(x)
        return x   


In [7]:
class TransformerBlock(nn.Module):
    def __init__(self,hidden_units,num_heads,projection_dim,dropout=0.0):
        super(TransformerBlock,self).__init__()
        self.hidden_units = hidden_units
        self.attention = nn.MultiheadAttention(projection_dim,num_heads,dropout=dropout)
        self.norm = nn.LayerNorm(projection_dim,eps=1e-6)
        self.mlp = MLPS(hidden_units)

    def forward(self,encoded_patches):
        x1 = encoded_patches
        #add norm layer here
        attention_out, out_weights = self.attention(x1,x1,x1)
        x2 = x1 + x1 #residual concatenation 
        #add norm layer here
        x3 = self.mlp(x2)
        encoded = x3 + x2
        return encoded
        
  

In [49]:
class VIT(nn.Module):
   def __init__(self,transformer=None,image_size=image_size,patch_size=patch_size,num_classes=num_classes,projection_dim=projection_dim,num_patches=num_patches,
   hidden_units=transformer_units, num_heads = num_heads,mlp_head_units=mlp_head_units, drop_prob=0.0):
      super(VIT,self).__init__()
      assert image_size % patch_size == 0, "image size must be dividible by patch size" 
      self.num_patches =  num_patches
      self.patch_dim = 3 * patch_size**2
   
      self.flatten_patches = Rearrange('b c (h px1) (w px2) -> b (h w) (px1 px2 c)', px1 = patch_size, px2 = patch_size)
      self.patch_emedding = nn.Linear(self.patch_dim,projection_dim)

      self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, projection_dim))   #pose embedding 
      self.class_token = nn.Parameter(torch.randn(1, 1, projection_dim))   #class embedding 
     
      self.transformer = TransformerBlock(hidden_units,num_heads,projection_dim,drop_prob)
      self.pose_embedding = nn.Embedding(num_patches+1,projection_dim)
      self.projection = nn.Linear(108,projection_dim)
      #output
      self.mlp_head = MLPS(mlp_head_units,drop_prob=0.5,hidden_size1=9216)

      self.to_latent = nn.Linear(mlp_head_units[-1],num_classes)
      self.flatten = nn.Flatten()
      self.drop = nn.Dropout(drop_prob)
      self.softmax = nn.Softmax(dim=0)

   
   def forward(self,img):
      #"pre process"
      """
      x = self.flatten_patches(x)
      x = self.patch_emedding(x)
      """
      patch = self.flatten_patches(img)
      #patch encoding 
      positions = torch.arange(0,self.num_patches)
      encoded1 = self.pose_embedding(positions)
      encoded2 = self.projection(patch)
      encoded = encoded1 + encoded2
      block = self.transformer(encoded)
      block_norm = block #add norm here
      representation = self.drop(self.flatten(block_norm))
      features = self.mlp_head(representation)
      logits = self.to_latent(features)
      logits = self.softmax(logits)
      return logits 
   

In [50]:
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
mlp_head_units = [2048,1024] 
model = VIT(hidden_units=transformer_units)
img, label = next(data_iter)
img = img[0]
img = img.unsqueeze(0)
model(img)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], grad_fn=<SoftmaxBackward>)

In [51]:
criterion = nn.CrossEntropyLoss(reduction=True)
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)


In [52]:

device = "cuda" if torch.cuda.is_available() else "cpu" 

if device == "cuda":
    model = model.to(device)

val_loss_min = float("Inf")
for i in range(0,num_epochs):
    train_loss,val_loss = 0.,0. 
    model.train()

    for img, label in train_dataloader:
        if device == "cuda":
            img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        out = model(img)
        print(out.shape,label.shape)
        print(sum(out))
        print(label)
        loss = criterion(label,out)
        loss.backward()
        optim.step() 
        train_loss += loss.item() * batch_size
    
    #validation
    model.eval()
    for img, label in val_dataloader:
        if device == "cuda":
            img, label = img.to(device), label.to(device)
        out = model(img)
        loss = criterion(out,label)
        val_loss += loss.item() * batch_size
    
    val_loss = val_loss/len(val_dataloader.dataset)
    train_loss = train_loss/len(train_dataloader.dataset)

    # print training/validation statistics 
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, val_loss))
    
    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        val_loss_min,
        val_loss))
        torch.save(model.state_dict(), 'model_cifar.pt')
        val_loss_min = val_loss




torch.Size([64, 100]) torch.Size([64])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], grad_fn=<AddBackward0>)
tensor([19, 29,  0, 11,  1, 86, 90, 28, 

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)