## **Vision Transformer: An Image is Worth 16x16 Words**

- The Transformer architecture is built on self-attention and is highly parallelizable.

- A Transformer maps an input sequence to an output sequence while attending over all positions at once, which lets it capture long-range dependencies more effectively than recurrent neural networks (RNNs), which process tokens one at a time.

- The Transformer architecture was originally designed for natural language processing (NLP) tasks.

- **Vision Transformers (ViTs) adapt the Transformer architecture to image data**.

- Paper: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)

- Applications:
   - Object detection 
   - Image segmentation
   - Image classification
   - Action recognition
   - Visual grounding
   - Visual question answering

### **Adapting Transformer from NLP to Vision**
* **From Words to Patches:**

    * Transformers in NLP: 

        * In NLP, transformers process sentences by treating each word as an input token. 
        
        * These tokens are then embedded into vectors and passed through the Transformer model to capture relationships between words.
        
    * Transformers in Vision: 
        * The core idea is to represent an **image as a sequence of smaller patches**

        * Each patch is treated as a token. 
        
        * These patches are embedded into vectors and fed into the Transformer model.

* **Self-Attention Mechanism:**

    * Global Context: 
    
        * The self-attention mechanism in Transformers allows the model to consider the relationships between all tokens (patches in the case of ViTs) simultaneously.

    * Attention Weights: 

        * Each patch's representation is updated by taking a weighted sum of the representations of all patches, where the weights are dynamically computed based on the similarity between patches. This allows the model to focus on the most relevant parts of the image for a given task.
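
The weighted-sum update described above can be sketched in a few lines. This is a minimal illustration, not the full layer: the four 8-dimensional tokens are made-up toy data, and the tokens are used directly as queries, keys, and values, whereas a real attention layer first applies learned linear projections.

```python
import math
import torch

torch.manual_seed(0)

# Hypothetical toy input: 4 patch tokens, each an 8-dimensional vector.
tokens = torch.randn(4, 8)

# Simplifying assumption: queries, keys and values are the tokens themselves.
# Scaled dot products give one similarity score per pair of patches.
scores = tokens @ tokens.T / math.sqrt(tokens.shape[-1])  # (4, 4)

# Softmax turns each row of scores into weights that sum to 1.
weights = torch.softmax(scores, dim=-1)  # (4, 4)

# Each updated token is a weighted sum over all tokens.
updated = weights @ tokens  # (4, 8)

print(weights.sum(dim=-1))
print(updated.shape)
```

Every row of `weights` covers all patches, which is what gives each patch a global view of the image.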

### **Decoding Vision Transformer Architecture**
<center>
    <img src="./assets/vit.png"/>
</center>


**1. Image Patching**
* The input image is divided into fixed-size patches, typically 16x16 pixels. 

* Each patch is flattened into a 1D vector.

* Consider an input image with dimensions $H \times W \times C$, for example $(224, 224, 3)$.

* The image is divided into patches of $P \times P$ pixels. With $P = 16$:

    * Number of patches along the height: $224 / 16 = 14$
    * Number of patches along the width: $224 / 16 = 14$
    * Total number of patches: $14 \times 14 = 196$

* Each $16 \times 16$ patch is flattened into a 1D vector of dimension $P \times P \times C = 16 \times 16 \times 3 = 768$, giving 196 vectors of size 768.
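
The patching arithmetic above can be checked with a short sketch. It assumes a dummy image stored in (height, width, channels) layout and uses plain `reshape`/`permute` calls; the implementation later in this notebook takes a different route (a strided convolution on channels-first tensors).

```python
import torch

H, W, C, P = 224, 224, 3, 16
image = torch.randn(H, W, C)  # dummy image in (height, width, channels) layout

# Split height and width into (patch index, offset inside the patch) pairs.
patches = image.reshape(H // P, P, W // P, P, C)
# Bring the two patch-grid axes to the front: (14, 14, 16, 16, 3).
patches = patches.permute(0, 2, 1, 3, 4)
# Flatten the grid into a sequence and each patch into a 1D vector.
patches = patches.reshape(-1, P * P * C)

print(patches.shape)  # torch.Size([196, 768])
```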

**2.  Patch Embedding**

* After the image is divided into patches and each patch is flattened, every patch is a 768-dimensional vector.

* The next step is to project these vectors into a $D$-dimensional embedding space using a learned embedding matrix. In this example $D$ is smaller than 768, but that is not required: ViT-Base in the paper uses $D = 768$.

* In our example, this matrix has dimensions $(768, D)$, where $D$ is the embedding dimension.

For example, if $D = 128$:

* The embedding matrix $E$ has dimensions $(768, 128)$.
* The embedded patch vectors have dimensions $(196, 128)$.
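
As a minimal sketch of this projection, the learned matrix can be held by an `nn.Linear` layer. The random tensor below stands in for the 196 flattened patches, and the bias term that `nn.Linear` adds by default is an extra not mentioned in the description above.

```python
import torch
import torch.nn as nn

num_patches, patch_dim, D = 196, 768, 128
flat_patches = torch.randn(num_patches, patch_dim)  # stand-in for flattened patches

# nn.Linear stores the learned matrix transposed, with shape (D, patch_dim),
# and adds a bias vector of size D.
projection = nn.Linear(patch_dim, D)
embedded = projection(flat_patches)

print(projection.weight.shape)  # torch.Size([128, 768])
print(embedded.shape)           # torch.Size([196, 128])
```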

**3. Positional Embedding**
 
* To preserve the spatial information of the patches, positional embeddings are added to the patch embeddings. 

* These embeddings encode the position of each patch within the image.

1. **Original Image**: $(224, 224, 3)$

2. **Patch Size**: $16 \times 16$

3. **Number of Patches**: $14 \times 14 = 196$

4. **Flattened Patch Dimension**: $768$

5. **Embedded Patch Dimension**: $128$

6. **Embedded Patches $Z$**: $(196, 128)$

7. **Positional Embedding Matrix $E_{pos}$**: $(196, 128)$

8. **Position-Encoded Patch Embeddings**: $Z + E_{pos}$, with shape $(196, 128)$
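
The addition in step 8 is element-wise, so the shape does not change. A small sketch, assuming learnable positional embeddings (one vector per patch position) and an illustrative initialization scale of 0.02:

```python
import torch
import torch.nn as nn

num_patches, D = 196, 128
Z = torch.randn(num_patches, D)  # stand-in for the embedded patches

# One learnable D-dimensional vector per patch position.
pos_emb = nn.Parameter(torch.zeros(num_patches, D))
nn.init.trunc_normal_(pos_emb, std=0.02)  # the scale is an assumption

encoded = Z + pos_emb  # element-wise sum, shape stays (196, 128)
print(encoded.shape)
```

Because the positional vectors are parameters, they receive gradients and are trained together with the rest of the model.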



**4. Transformer Encoder**
* The embedded patches are passed through a stack of transformer encoder layers.

* Each layer consists of multi-head self-attention and feed-forward neural networks (FFNs). 

* The self-attention mechanism allows the model to capture dependencies between different patches.
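
For a quick feel of what the encoder stack does to tensor shapes, here is a sketch that uses PyTorch's built-in `nn.TransformerEncoderLayer` instead of a hand-written layer. The sizes (4 heads, feed-forward width 512, 2 layers) are illustrative assumptions.

```python
import torch
import torch.nn as nn

# One encoder layer: multi-head self-attention followed by a feed-forward network.
# norm_first=True applies layer normalization before each sub-layer.
layer = nn.TransformerEncoderLayer(
    d_model=128,
    nhead=4,
    dim_feedforward=512,
    dropout=0.0,
    activation="gelu",
    batch_first=True,
    norm_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)

tokens = torch.randn(8, 196, 128)  # (batch, patches, embedding dimension)
out = encoder(tokens)
print(out.shape)  # torch.Size([8, 196, 128])
```

The output has the same shape as the input: every layer refines the patch representations without changing the sequence length or the embedding size.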



**5. Classification Head**
* The output of the transformer encoder is passed through a classification head, typically consisting of a layer normalization followed by a linear layer, to obtain the final class predictions.
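
A minimal sketch of such a head is shown below. It assumes that a [CLS] token sits at position 0 of the sequence (as in the original ViT) and that there are 10 classes; the random tensor stands in for the encoder output.

```python
import torch
import torch.nn as nn

hidden_size, num_classes = 128, 10  # num_classes is an illustrative assumption
encoder_output = torch.randn(8, 197, hidden_size)  # stand-in: [CLS] + 196 patches

# Layer normalization followed by a linear layer.
head = nn.Sequential(
    nn.LayerNorm(hidden_size),
    nn.Linear(hidden_size, num_classes),
)

cls_features = encoder_output[:, 0]  # (8, 128): the [CLS] token of each image
logits = head(cls_features)
print(logits.shape)  # torch.Size([8, 10])
```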

### **Implementing Step by Step**

In [3]:
# cell 1
import torch.nn as nn
import torch
import math

**Image Splitting and Patch Embeddings**

In [7]:
# cell 2
class PatchEmbeddings(nn.Module):
    """
    Convert the image into patches and then project them into a vector space.
    """

    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]
        
        # Calculate the number of patches from the image size and patch size
        self.num_patches = (self.image_size // self.patch_size) ** 2
        
        # Create a projection layer that splits the image into patches and embeds them.
        # A convolution with kernel_size == stride == patch_size maps each
        # non-overlapping patch to a vector of size hidden_size.
        self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)
        # Equivalent to flattening each patch and applying
        # nn.Linear(patch_size * patch_size * num_channels, hidden_size)

    def forward(self, x):
        # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, hidden_size),
        # e.g. (8, 3, 224, 224) -> (8, 196, 128)
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        return x
    

In [8]:
# cell 3
# Testing the PatchEmbeddings class with dummy input

config = {
    "image_size": 224,
    "patch_size": 16,
    "num_channels": 3,
    "hidden_size": 128
}


patch_embeddings = PatchEmbeddings(config)

input_tensor = torch.randn(8, 3, 224, 224)  # batch_size=8, num_channels=3, image_size=224
output_tensor = patch_embeddings(input_tensor)

print(output_tensor.shape) 

torch.Size([8, 196, 128])
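
The strided convolution used in `PatchEmbeddings` computes the same thing as "flatten each patch, then multiply by an embedding matrix". The sketch below checks this numerically; it rebuilds a standalone `nn.Conv2d` with the same settings rather than reusing the class above, and uses `torch.nn.functional.unfold` to extract the flattened patches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, num_channels, hidden_size = 16, 3, 128
conv = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, num_channels, 224, 224)

# Path 1: strided convolution, then flatten the 14x14 grid into a sequence.
conv_out = conv(x).flatten(2).transpose(1, 2)  # (2, 196, 128)

# Path 2: extract flattened patches, then apply the same weights as a matrix product.
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 768, 196)
patches = patches.transpose(1, 2)                                 # (2, 196, 768)
weight = conv.weight.reshape(hidden_size, -1)                     # (128, 768)
linear_out = patches @ weight.T + conv.bias                       # (2, 196, 128)

print(torch.allclose(conv_out, linear_out, atol=1e-4))
```

The convolution is usually preferred in practice because it does the patch extraction and the projection in a single call.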


**Position Encoded Embeddings**

* **Note**: 
    * After the patches are converted to a sequence of embeddings, a [CLS] token is added to the beginning of the sequence. It is used later by the classification layer to classify the image.
    
    * The [CLS] token’s embedding is learned during training.


In [11]:
# cell 4
class Embeddings(nn.Module):
    """
    Combine the patch embeddings with the class token and position embeddings.
    """
        
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)
        
        # Create a learnable [CLS] token. [CLS] token is added to the beginning of the input sequence
        # and is used to classify the entire sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))
        
        # Create position embeddings for the [CLS] token and the patch embeddings
        # Add 1 to the sequence length for the [CLS] token
        self.position_embeddings = nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches+1, config["hidden_size"]))
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.patch_embeddings(x)
        batch_size, _, _ = x.size()
        # Expand the [CLS] token to the batch size
        # (1, 1, hidden_size) -> (batch_size, 1, hidden_size)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        
        # Concatenate the [CLS] token to the beginning of the input sequence
        # This results in a sequence length of (num_patches + 1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add the position embeddings to the input sequence
        x = x + self.position_embeddings
        x = self.dropout(x)
        return x

In [12]:
# cell 5

# Testing the Position Encoded Embeddings with dummy input
config = {
    "image_size": 224,
    "patch_size": 16,
    "num_channels": 3,
    "hidden_size": 128,
    "hidden_dropout_prob": 0.5   
}

embedding = Embeddings(config)
input_tensor = torch.randn(8, 3, 224, 224)  # batch_size=8, num_channels=3, image_size=224
output_tensor = embedding(input_tensor)

print(output_tensor.shape) 

torch.Size([8, 197, 128])


**Single Attention Layer**

In [13]:
# cell 6

class AttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the MultiHeadAttention module.
    """
    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        
        # Create the query, key, and value projection layers
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Project the input into query, key, and value
        # The same input is used to generate the query, key, and value,
        # so it's usually called self-attention.
        # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, attention_head_size)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        
        # Calculate the attention scores
        # softmax(Q*K.T/sqrt(head_size))*V
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Calculate the attention output
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)

In [14]:
# cell 7

# Testing the AttentionHead class with dummy input
hidden_size = 128
attention_head_size = 32
dropout = 0.5
bias = True
attention_head = AttentionHead(hidden_size, attention_head_size, dropout, bias)
input_tensor = torch.randn(8, 197, 128)  # batch_size=8, sequence_length=197, hidden_size=128
attention_output, attention_probs = attention_head(input_tensor)

print("Attention Output:", attention_output.shape)
print("Attention Probs:", attention_probs.shape)

Attention Output: torch.Size([8, 197, 32])
Attention Probs: torch.Size([8, 197, 197])


**Multi-Head Attention**

In [15]:
# cell 8

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        
        # The attention head size is the hidden size divided by the number of attention heads
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether or not to use bias in the query, key, and value projection layers
        self.qkv_bias = config["qkv_bias"]
        # Create a list of attention heads
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias
            )
            self.heads.append(head)
        # Create a linear layer to project the attention output back to the hidden size
        # In most cases, all_head_size and hidden_size are the same
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size, bias=True)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        # Calculate the attention output for each attention head
        attention_outputs = [head(x) for head in self.heads]
        
        # Concatenate the attention outputs from each attention head
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        
        # Project the concatenated attention output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Return the attention output and the attention probabilities (optional)
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)

In [16]:
# cell 9

# Testing MultiHeadAttention class with dummy input
config = {
    "image_size": 224,
    "patch_size": 16,
    "num_channels": 3,
    "hidden_size": 128,
    "hidden_dropout_prob": 0.5,
    "qkv_bias": True,
    "num_attention_heads": 4,
    "attention_probs_dropout_prob": 0.5
}
multi_head_attention = MultiHeadAttention(config)
input_tensor = torch.randn(8, 197, 128)  # batch_size=8, sequence_length=197, hidden_size=128
attention_output, attention_probs = multi_head_attention(input_tensor)

print("Attention Output:", attention_output.shape)

Attention Output: torch.Size([8, 197, 128])
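
The per-head loop above is easy to follow, but it runs several small matrix multiplications one after another. A common alternative, sketched here as an assumption rather than as part of this notebook's model, computes the queries, keys, and values of all heads with one linear layer and splits the heads by reshaping. The class name `FusedMultiHeadAttention` is hypothetical, and dropout is left out to keep the sketch short.

```python
import math
import torch
import torch.nn as nn

class FusedMultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection (sketch)."""

    def __init__(self, hidden_size, num_heads, qkv_bias=True):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        # One layer produces queries, keys and values for every head.
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=qkv_bias)
        self.output_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.qkv(x)  # (b, n, 3 * hidden_size)
        # Split into (3, b, heads, n, head_size).
        qkv = qkv.reshape(b, n, 3, self.num_heads, self.head_size).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.head_size)
        probs = torch.softmax(scores, dim=-1)  # (b, heads, n, n)
        out = probs @ v                        # (b, heads, n, head_size)
        # Concatenate the heads back into one vector per token.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.output_projection(out), probs

attn = FusedMultiHeadAttention(hidden_size=128, num_heads=4)
x = torch.randn(8, 197, 128)
out, probs = attn(x)
print(out.shape, probs.shape)
```

The output shapes match those of the loop-based module: `(8, 197, 128)` for the attention output and one `(197, 197)` attention map per head.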


**Transformer Encoder**

<center>
    <img src="./assets/encoder.png" width=200/>
</center>

* The transformer encoder is made of a stack of transformer layers. 
* Each transformer layer mainly consists of a multi-head attention module and a feed-forward network. 
* To stabilize training and make deeper stacks easier to optimize, each transformer layer also uses two layer normalization layers and skip (residual) connections.
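
Written as equations, one transformer layer in the pre-norm arrangement (normalization applied before each sub-layer, which is how the `Block` class in this notebook is written) computes the following. The symbols $x$, $x'$ and $y$ are introduced here only for illustration:

$$
x' = x + \mathrm{MHA}\big(\mathrm{LN}_1(x)\big), \qquad y = x' + \mathrm{MLP}\big(\mathrm{LN}_2(x')\big)
$$

Here $x$ is the layer input of shape (sequence length, hidden size), $\mathrm{MHA}$ is multi-head self-attention, $\mathrm{MLP}$ is the feed-forward network, and $y$ is the layer output with the same shape as $x$.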

In [17]:
# cell 10

class MLP(nn.Module):
    """
    A multi-layer perceptron module.
    """
    def __init__(self, config):
        super().__init__()
        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])
        self.activation = nn.GELU()
        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x

In [18]:
# cell 11

class Block(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, output_attentions=False):
        # Self-attention
        attention_output, attention_probs = self.attention(self.layernorm_1(x), output_attentions=output_attentions)
        
        # Skip connection
        x = x + attention_output
        
        # Feed-forward network
        mlp_output = self.mlp(self.layernorm_2(x))
        # Skip connection
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)

In [19]:
# cell 12

class Encoder(nn.Module):
    """
    The transformer encoder module.
    """

    def __init__(self, config):
        super().__init__()
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_hidden_layers"]):
            block = Block(config)
            self.blocks.append(block)

    def forward(self, x, output_attentions=False):
        # Calculate the transformer block's output for each block
        all_attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(attention_probs)
        # Return the encoder's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_attentions)

In [20]:
# cell 13

class ViTForClassfication(nn.Module):
    """
    The ViT model for classification.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["hidden_size"]
        self.num_classes = config["num_classes"]
        # Create the embedding module
        self.embedding = Embeddings(config)
        # Create the transformer encoder module
        self.encoder = Encoder(config)
        # Create a linear layer to project the encoder's output to the number of classes
        self.classifier = nn.Linear(config["hidden_size"], config["num_classes"])
        # Initialize the weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config["initializer_range"])
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.cls_token.dtype)

    def forward(self, x, output_attentions=False):
        # Calculate the embedding output
        embedding_output = self.embedding(x)
        
        # Calculate the encoder's output
        encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
        
        # Calculate the logits, take the [CLS] token's output as features for classification
        logits = self.classifier(encoder_output[:, 0])
        # Return the logits and the attention probabilities (optional)
        if not output_attentions:
            return (logits, None)
        else:
            return (logits, all_attentions)

In [4]:
# cell 14

import torchvision
import torchvision.transforms as transforms

def prepare_data(batch_size=4, num_workers=2, train_sample_size=None, test_sample_size=None):
    train_transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    if train_sample_size is not None:
        # Randomly sample a subset of the training set
        indices = torch.randperm(len(trainset))[:train_sample_size]
        trainset = torch.utils.data.Subset(trainset, indices)
    


    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=num_workers)
    
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=test_transform)
    if test_sample_size is not None:
        # Randomly sample a subset of the test set
        indices = torch.randperm(len(testset))[:test_sample_size]
        testset = torch.utils.data.Subset(testset, indices)
    
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=num_workers)

    classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader, testloader, classes

In [21]:
# cell 15

from torch import nn, optim

class Trainer:

    def __init__(self, model, optimizer, loss_fn, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device

    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """
        Train the model for the specified number of epochs.
        """
        # Keep track of the losses and accuracies
        train_losses, test_losses, accuracies = [], [], []
        # Train the model
        for i in range(epochs):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
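            # Save a checkpoint every save_model_every_n_epochs epochs (skipping the final epoch)
            if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0 and i+1 != epochs: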
                print('\tSave checkpoint at epoch', i+1)
                torch.save(self.model.state_dict(), 'model_epoch_' + str(i+1) + '.pt')
                
        return train_losses, test_losses, accuracies

    def train_epoch(self, trainloader):
        """
        Train the model for one epoch.
        """
        self.model.train()
        total_loss = 0
        for batch in trainloader:
            # Move the batch to the device
            batch = [t.to(self.device) for t in batch]
            images, labels = batch
            # Zero the gradients
            self.optimizer.zero_grad()
            # Calculate the loss
            loss = self.loss_fn(self.model(images)[0], labels)
            # Backpropagate the loss
            loss.backward()
            # Update the model's parameters
            self.optimizer.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    def evaluate(self, testloader):
        """
        Evaluate the model on the test set and return (accuracy, average loss).
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in testloader:
                # Move the batch to the device
                batch = [t.to(self.device) for t in batch]
                images, labels = batch
                
                # Get predictions
                logits, _ = self.model(images)

                # Calculate the loss
                loss = self.loss_fn(logits, labels)
                total_loss += loss.item() * len(images)

                # Calculate the accuracy
                predictions = torch.argmax(logits, dim=1)
                correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss


In [5]:
# Cell 16

batch_size = 16

# Load the CIFAR10 dataset
trainloader, testloader, _ = prepare_data(batch_size=batch_size)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:33<00:00, 5019702.88it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [23]:
# Cell 17

config = {
    "patch_size": 4,  # 32x32 input image -> 8x8 = 64 patches of 4x4 pixels
    "hidden_size": 48,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 48, # 4 * hidden_size
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 10, # num_classes of CIFAR10
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,  # not read by any module defined in this notebook
}
epochs = 3
lr = 0.001
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Create the model, optimizer, loss function and trainer
model = ViTForClassfication(config)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
loss_fn = nn.CrossEntropyLoss()
trainer = Trainer(model, optimizer, loss_fn, device=device)
train_losses, test_losses, accuracies = trainer.train(trainloader, testloader, epochs)

Using cpu device




Epoch: 1, Train loss: 1.8559, Test loss: 1.6994, Accuracy: 0.3842
Epoch: 2, Train loss: 1.5523, Test loss: 1.4242, Accuracy: 0.4797
Epoch: 3, Train loss: 1.4072, Test loss: 1.3158, Accuracy: 0.5237
