# Colab Setup

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that files
# stored on Drive can be reached through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
"""
Change directory to where this file is located
"""
# `%cd` is an IPython magic, not plain Python. Replace the quoted placeholder
# with the folder that contains this notebook before running the cell; later
# cells use relative paths (e.g. root='./data') that resolve against it.
%cd 'COPY&PASTE FILE DIRECTORY HERE'

# Import Modules

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm.auto import tqdm

# ViT Model  

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale  
https://arxiv.org/pdf/2010.11929.pdf

![Image](architecture.png)

In [None]:
class Patchification(nn.Module):
  """
  Split an image into non-overlapping patches and embed each patch.

  A convolution whose kernel size and stride both equal the patch size
  visits every patch exactly once, so each output location is the linear
  embedding of one patch.

  Input shape: [batch, channel, height, width]
  Return: [batch, number_of_patches, embedding_dimension]
  """
  def __init__(self, in_channels, patch_size, embedding_dim):
    super().__init__()

    # The number of output channels of the convolution is the embedding size.
    self.conv = nn.Conv2d(
        in_channels,
        embedding_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

  def forward(self, x):
    # [batch, embedding_dim, height/patch_size, width/patch_size]
    feature_map = self.conv(x)
    batch, embed, _, _ = feature_map.shape

    # Put the embedding axis last: [batch, height/patch_size, width/patch_size, embedding_dim]
    tokens = feature_map.permute(0, 2, 3, 1)

    # Merge the two spatial axes (row-major) into one patch axis:
    # [batch, number_of_patches, embedding_dim]
    return tokens.reshape(batch, -1, embed)

class Linear_Patchification(nn.Module):
  """
  Patch embedding implemented with a single linear projection.

  Each patch is cut out of the image, flattened in (channel, row, column)
  order and projected to the embedding dimension. This is a drop-in
  alternative to the convolution-based Patchification.

  patch_size may be a (height, width) pair or a single int for square
  patches, matching what nn.Conv2d accepts for kernel_size.

  Input shape: [batch, channel, height, width]
  Return: [batch, number_of_patches, embedding_dimension]

  The image height/width must be divisible by the patch height/width;
  otherwise the reshape in forward() fails.
  """
  def __init__(self, in_channels, patch_size, embedding_dim):
    super().__init__()

    # Normalise an int patch size to a (height, width) pair.
    if isinstance(patch_size, int):
      patch_size = (patch_size, patch_size)
    self.patch_size = tuple(patch_size)
    self.embedding_dim = embedding_dim

    # One flattened patch holds in_channels * patch_h * patch_w values.
    patch_h, patch_w = self.patch_size
    self.linear = nn.Linear(in_channels * patch_h * patch_w, embedding_dim)

  def forward(self, x):
    # x shape [batch, channel, height, width]
    B, C, H, W = x.shape
    patch_h, patch_w = self.patch_size

    # Split each spatial axis into (patch index, offset inside the patch):
    # x shape = [B, C, H/P, P, W/P, P]
    x = x.reshape(B, C, H // patch_h, patch_h, W // patch_w, patch_w)

    # Bring the two patch-index axes forward:
    # x shape = [B, H/P, W/P, C, P, P]
    x = x.permute(0, 2, 4, 1, 3, 5)

    # Merge the patch grid into one axis and flatten every patch:
    # x shape = [B, (H/P)*(W/P), C*P*P]
    x = x.reshape(B, -1, C * patch_h * patch_w)

    # x shape = [B, number_of_patches, embedding_dim]
    return self.linear(x)

In [None]:
class MLP(nn.Module):
  """
  Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

  Input shape: [batch, number_of_patches, embedding_dimension]
  Return: [batch, number_of_patches, embedding_dimension]
  """
  def __init__(self, dim, hidden_dim, dropout = 0.):
    super().__init__()
    # Expand to the hidden width, then project back to the embedding width.
    self.fc1 = nn.Linear(dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, dim)
    self.dropout = nn.Dropout(dropout)
    self.activation = nn.GELU()

  def forward(self, x):
    # hidden shape [batch, number_of_patches, hidden_dim]
    hidden = self.dropout(self.activation(self.fc1(x)))
    # output shape [batch, number_of_patches, embedding_dimension]
    return self.fc2(hidden)

In [None]:
class Attention(nn.Module):
  """
  Multi-head self-attention.

  Every token produces a query, key and value per head. Within each head,
  a token's output is the softmax-weighted sum of the values of all tokens,
  with weights derived from scaled query-key dot products.

  dim must be divisible by num_heads (head_dim = dim // num_heads).

  Input shape: [batch, number_of_patches, embedding_dimension]
  Return: [batch, number_of_patches, embedding_dimension]
  """
  def __init__(self, dim, num_heads, dropout = 0.):
    super().__init__()
    self.head_dim = dim // num_heads
    self.dim = dim
    self.num_heads = num_heads
    # Scores are divided by sqrt(head_dim).
    self.scale = self.head_dim ** 0.5
    self.dropout = nn.Dropout(dropout)
    # A single projection produces queries, keys and values for all heads.
    self.qkv = nn.Linear(dim, dim * 3, bias=False)
    self.fc = nn.Linear(dim, dim)

  def forward(self, x):
    B, N, E = x.shape
    # qkv shape [batch, number_of_patches, 3*embedding_dimension]
    qkv = self.qkv(x)
    # qkv shape [batch, number_of_patches, 3, num_heads, head_dim]
    qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
    # Move the head axis in front of the token axis so the matmuls below
    # contract over tokens (per head) instead of over heads (per token).
    # qkv shape [3, batch, num_heads, number_of_patches, head_dim]
    qkv = qkv.permute(2, 0, 3, 1, 4)

    # Each of q, k, v has shape [batch, num_heads, number_of_patches, head_dim]
    q, k, v = qkv[0], qkv[1], qkv[2]

    # scores shape [batch, num_heads, number_of_patches, number_of_patches]
    scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
    attn = torch.softmax(scores, dim=-1)
    attn = self.dropout(attn)

    # weighted_values shape [batch, num_heads, number_of_patches, head_dim]
    weighted_values = torch.matmul(attn, v)

    # Put tokens back in front of heads and concatenate the heads:
    # [batch, number_of_patches, num_heads, head_dim] -> [batch, number_of_patches, dim]
    weighted_values = weighted_values.transpose(1, 2).reshape(B, N, self.dim)
    x = self.fc(weighted_values)
    return x

class Block(nn.Module):
  """
  Transformer encoder block with pre-normalisation.

  Two residual sub-layers are applied in sequence: LayerNorm followed by
  multi-head attention, then LayerNorm followed by the feed-forward MLP.

  Input shape: [batch, number_of_patches, embedding_dimension]
  Return: [batch, number_of_patches, embedding_dimension]
  """
  def __init__(self, dim, num_heads, mlp_dim, dropout=0.):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.attention = Attention(dim, num_heads, dropout=dropout)
    self.norm2 = nn.LayerNorm(dim)
    self.mlp = MLP(dim, mlp_dim, dropout=dropout)

  def forward(self, x):
    # Attention sub-layer with its residual connection.
    x = x + self.attention(self.norm1(x))
    # Feed-forward sub-layer with its residual connection.
    return x + self.mlp(self.norm2(x))

In [None]:
class ViT(nn.Module):
    """
    Vision Transformer classifier.

    The image is cut into patches, each patch is embedded, a learnable class
    token is prepended, learnable positional embeddings are added, and the
    sequence runs through a stack of attention blocks. The final state of
    the class token is mapped to class scores.
    """

    def __init__(self, image_shape, patch_size, num_classes, dim, num_heads, depth, mlp_dim, dropout = 0.):
        """
        image_shape: [channel, height, width]
        patch_size: [height, width]
        num_classes: Number of output classes
        dim: Embedding dimension
        num_heads: Number of heads to be used in Multi-head Attention
        depth: Number of attention blocks to be used
        mlp_dim: Hidden dimension to be used in MLP layer (=feedforward layer)
        dropout: Dropout probability passed to every attention block
        """
        super().__init__()

        # The channel count is 3 (RGB) for the CIFAR10 dataset.
        channels, height, width = image_shape
        patch_height, patch_width = patch_size

        assert height % patch_height == 0 and width % patch_width == 0, 'Image height & width must be divisible by those of patch respectively.'
        assert dim % num_heads == 0, 'Embedding dimension should be divisible by number of heads.'

        # e.g. a [32 x 32] image with [8 x 8] patches gives a 4 x 4 grid = 16 patches.
        grid_rows = height // patch_height
        grid_cols = width // patch_width
        sequence_length = 1 + grid_rows * grid_cols   # the extra slot holds the class token

        # Convolution-based patch embedding. Linear_Patchification takes the
        # same arguments and can be used here instead.
        self.patchify = Patchification(channels, patch_size, dim)

        # Learnable positional encoding, one vector per sequence position.
        self.pos_embedding = nn.Parameter(torch.randn(1, sequence_length, dim))

        # Learnable class token, prepended to the patch sequence of every image.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Stack of `depth` attention blocks.
        blocks = [Block(dim, num_heads, mlp_dim, dropout) for _ in range(depth)]
        self.attention_blocks = nn.ModuleList(blocks)

        # Classification head: maps the final class-token vector to class scores.
        self.classification_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        """
        Classify a batch of images.

        A class token is prepended to every image's patch sequence, and its
        final state is the only vector used for classification.
        Returns a tensor of shape [batch, num_classes].
        """
        batch_size = img.shape[0]

        # patches shape: [batch, number_of_patches, dim]
        patches = self.patchify(img)

        # cls_tokens shape: [batch, 1, dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # tokens shape: [batch, 1 + number_of_patches, dim]
        tokens = torch.cat([cls_tokens, patches], dim=1)
        tokens = tokens + self.pos_embedding

        # Pass the sequence through the attention blocks; the shape is preserved.
        for block in self.attention_blocks:
            tokens = block(tokens)

        # Keep only the class token: [batch, dim]
        cls_state = tokens[:, 0, :]

        # Class scores: [batch, num_classes]
        return self.classification_head(cls_state)

# ViT Image Classification

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print("Using PyTorch version: {}, Device: {}".format(torch.__version__, DEVICE))

In [None]:
def train(model, train_loader, optimizer, criterion, DEVICE, epoch=None):
    """
    Train `model` for one pass over `train_loader`.

    model: network to optimise; it is switched to training mode
    train_loader: iterable yielding (image, label) batches
    optimizer: optimizer that updates the model parameters
    criterion: loss function called as criterion(output, label)
    DEVICE: device the batches are moved to before the forward pass
    epoch: epoch number shown in the progress bar. When omitted, a
        module-level variable named `epoch` is used if one exists (this is
        how callers that do not pass the argument keep working), otherwise
        '?' is shown.

    Returns the mean of the per-batch losses (0.0 if the loader is empty).
    """
    if epoch is None:
        # Older callers relied on a global `epoch` set by the training loop;
        # look it up explicitly instead of failing with NameError when absent.
        epoch = globals().get("epoch", "?")

    model.train()
    tqdm_bar = tqdm(train_loader)
    running_loss = 0.0
    num_batches = 0
    for image, label in tqdm_bar:
        image = image.to(DEVICE)
        label = label.to(DEVICE)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        tqdm_bar.set_description("Epoch {} - train loss: {:.6f}".format(epoch, batch_loss))

    return running_loss / num_batches if num_batches else 0.0


def evaluate(model, test_loader, criterion, DEVICE):
    """
    Evaluate `model` on every batch of `test_loader` without tracking gradients.

    model: network to evaluate; it is switched to evaluation mode
    test_loader: iterable yielding (image, label) batches
    criterion: loss function called as criterion(output, label)
    DEVICE: device the batches are moved to before the forward pass

    Returns (test_loss, test_accuracy): the mean loss per sample and the
    percentage of correctly classified samples. Both are 0.0 when the loader
    yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = 0

    # A criterion with reduction='mean' (the default for the torch.nn losses)
    # returns the batch average, which has to be weighted by the batch size
    # before dividing by the sample count. A 'sum' criterion already returns
    # the batch total.
    reduction = getattr(criterion, "reduction", "mean")

    with torch.no_grad():
        for image, label in tqdm(test_loader):
            image = image.to(DEVICE)
            label = label.to(DEVICE)
            output = model(image)

            batch_size = label.size(0)
            batch_loss = criterion(output, label).item()
            if reduction == "sum":
                total_loss += batch_loss
            else:
                total_loss += batch_loss * batch_size

            # The predicted class is the index of the largest score.
            prediction = output.argmax(dim=1)
            correct += (prediction == label).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        return 0.0, 0.0

    # Divide by the samples actually seen, which stays correct for loaders
    # that skip part of the dataset (e.g. drop_last or a custom sampler).
    test_loss = total_loss / num_samples
    test_accuracy = 100. * correct / num_samples
    return test_loss, test_accuracy

In [None]:
BATCH_SIZE = 100

# Per-channel (mean, std) normalisation applied identically to both splits.
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training images are additionally flipped horizontally at random.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

# Prepare Dataset (downloaded into ./data when missing)
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Prepare DataLoader: only the training split is shuffled.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Optimisation hyperparameters
EPOCHS = 10
learning_rate = 0.001
dropout = 0.

# Architecture hyperparameters
patch_size = (4,4)      # a 32 x 32 image becomes an 8 x 8 grid of patches
dim = 128               # embedding dimension
depth = 8               # number of attention blocks
num_heads = 8           # attention heads per block
mlp_dim = 256           # hidden width of the feed-forward layer

# Build the model for 3-channel 32 x 32 inputs and 10 classes, then move it to DEVICE.
model = ViT(
    image_shape = (3,32,32),
    patch_size = patch_size,
    num_classes = 10,
    dim = dim,
    num_heads = num_heads,
    depth = depth,
    mlp_dim = mlp_dim,
    dropout = dropout,
)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train 

# One training pass followed by one evaluation pass per epoch.
# The loop variable has to stay named `epoch`: train() reads it for its progress bar.
for epoch in range(1, EPOCHS + 1):
    train(model, trainloader, optimizer, criterion, DEVICE)
    test_loss, test_accuracy = evaluate(model, testloader, criterion, DEVICE)
    summary = "\n[EPOCH: {}], \tModel: ViT, \tTest Loss: {:.4f}, \tTest Accuracy: {:.2f} % \n".format(
        epoch, test_loss, test_accuracy)
    print(summary)