# Transformer

<div style="display: flex; align-items: center;">
    <img src="../imgs/Transformer.jpg" alt="Transformer architecture" width="300" style="margin-right: 20px;">
    <div>
        <p>The Transformer is a landmark deep learning architecture built around the self-attention mechanism, and it fundamentally changed how sequences are modeled. Self-attention lets every element of a sequence attend to every other element, so all positions are processed in parallel rather than one step at a time as in recurrent neural networks (RNNs). This greatly improves computational efficiency. Multi-head attention runs several attention operations in parallel, each with its own learned projections, so the model can capture different kinds of relationships in the sequence at the same time and learn richer features.</p>
        <p>Because self-attention by itself is insensitive to the order of its inputs, the Transformer adds positional encodings to the input embeddings so that the model can make use of element order, which is essential for sequence data. In each encoder and decoder layer, the output of the attention sublayer is passed to a position-wise feed-forward network for further feature extraction. To make deep stacks trainable, every sublayer is wrapped in a residual connection, which mitigates vanishing gradients, and combined with layer normalization, which stabilizes training.</p>
        <p>These design choices quickly made the Transformer the mainstream architecture in natural language processing, especially for machine translation, text summarization, and question answering. Its flexibility and representational power have since carried over to other fields such as speech recognition and image processing, making it one of the most influential models in deep learning today.</p>
    </div>
</div>
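The claim that self-attention ignores order until positional information is added can be checked numerically. The sketch below is a minimal single-head NumPy version, not the chapter's implementation: it assumes random projection matrices, uses the sinusoidal encoding from the original Transformer paper (ViT, as implemented later in this chapter, uses a learned `pos_embed` instead), and the function names are illustrative.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over an (N, d) sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])  # (N, N): every token scores every token at once
    return softmax(scores) @ v               # (N, d): weighted sum of value vectors

def sinusoidal_positional_encoding(n, d):
    """Fixed sine/cosine encodings for n positions and an even dimension d."""
    pos = np.arange(n)[:, None]                  # (n, 1)
    i = np.arange(0, d, 2)[None, :]              # (1, d/2): even dimension indices
    angles = pos / np.power(10000.0, i / d)      # (n, d/2)
    pe = np.zeros((n, d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

rng = np.random.default_rng(0)
n, d = 6, 8
x = rng.standard_normal((n, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
perm = np.roll(np.arange(n), 1)  # a fixed, non-identity reordering of the tokens

# Without positional information, reordering the tokens only reorders the outputs.
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)
print(np.allclose(out[perm], out_shuffled))        # True

# With a position-dependent vector added to each token, order changes the result.
pe = sinusoidal_positional_encoding(n, d)
out_pe = self_attention(x + pe, w_q, w_k, w_v)
out_pe_shuffled = self_attention(x[perm] + pe, w_q, w_k, w_v)
print(np.allclose(out_pe[perm], out_pe_shuffled))  # False
```

The first check holds for any input because attention treats the sequence as an unordered set; the second fails because each token now carries information about where it sits.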

# Vision Transformer

<div style="display: flex; align-items: center;">
    <img src="../imgs/ViT.jpg" alt="Vision Transformer (ViT) architecture" width="600" style="margin-right: 20px;">
    <div>
        <p>ViT (Vision Transformer) is a model proposed by Google in 2020 that applies the Transformer directly to image classification, and many later works build on it. The idea is simple: split the image into fixed-size patches and map each patch to a patch embedding with a linear projection, much as words are mapped to word embeddings in NLP. Because a Transformer takes a sequence of token embeddings as input, the sequence of patch embeddings can be fed into it for feature extraction and classification. A learnable class token is prepended to this sequence, learnable positional embeddings are added, and the final representation of the class token is used for classification. As the schematic shows, ViT uses only the Transformer encoder to extract features; the original Transformer also has a decoder, which is used for sequence-to-sequence tasks such as machine translation.</p>
    </div>
</div>
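The "split into patches, then apply a linear transformation" step can be written explicitly. The NumPy sketch below is illustrative (the helper names and toy sizes are assumptions): it flattens non-overlapping patches and projects them with one matrix multiply, then checks the result against a convolution whose kernel size and stride both equal the patch size, which is how the `PatchEmbed` module later in this chapter implements the same step.

```python
import numpy as np

def patchify(img, p):
    """Split a (C, H, W) image into an (N, C*p*p) matrix of flattened, non-overlapping patches."""
    c, h, w = img.shape
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    gh, gw = h // p, w // p
    # (C, gh, p, gw, p) -> (gh, gw, C, p, p): patches in row-major grid order
    patches = img.reshape(c, gh, p, gw, p).transpose(1, 3, 0, 2, 4)
    return patches.reshape(gh * gw, c * p * p)

def patch_embed_linear(img, weight, bias, p):
    """Patch embedding as a single linear layer applied to every flattened patch."""
    d = weight.shape[0]                                        # weight: (D, C, p, p)
    return patchify(img, p) @ weight.reshape(d, -1).T + bias  # (N, D)

def patch_embed_conv(img, weight, bias, p):
    """Reference: a convolution with kernel_size = stride = p, written as explicit loops."""
    d = weight.shape[0]
    _, h, w = img.shape
    out = np.empty((d, h // p, w // p))
    for i in range(h // p):
        for j in range(w // p):
            window = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]  # (C, p, p)
            out[:, i, j] = np.tensordot(weight, window, axes=([1, 2, 3], [0, 1, 2])) + bias
    return out.reshape(d, -1).T  # flatten the grid and transpose: (N, D)

rng = np.random.default_rng(0)
img = rng.standard_normal((3, 32, 32))
weight = rng.standard_normal((10, 3, 8, 8))
bias = rng.standard_normal(10)
print(np.allclose(patch_embed_linear(img, weight, bias, 8),
                  patch_embed_conv(img, weight, bias, 8)))  # True

# A 224x224 RGB image with 16x16 patches gives 196 patches of 3*16*16 = 768 values each.
print(patchify(np.zeros((3, 224, 224)), 16).shape)  # (196, 768)
```

The equivalence is why a strided `Conv2d` is a convenient way to implement the patch projection: it applies the same weights to every patch without an explicit reshape.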

## ViT

Vision Transformers (ViTs) are highly effective on large-scale datasets, but when trained from scratch on smaller or simpler datasets they often underperform Convolutional Neural Networks (CNNs), for several reasons. First, ViTs have a large model capacity and need substantial data to use it; on limited data they tend to overfit. Second, they lack the inductive biases built into convolutions, such as locality and translation equivariance: self-attention models global dependencies from the first layer, while the local patterns that matter most on small datasets must be learned from the data. CNNs, by contrast, build a spatial hierarchy into their architecture and adapt quickly to the available data. Without the feature-rich initialization that pre-training provides, ViTs struggle to learn these patterns from scratch. For this reason, this chapter implements the ViT architecture but does not train it to demonstrate its practical performance.
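To put "large model capacity" in numbers, the sketch below counts the trainable parameters of the `VisionTransformer` class defined below under its default settings, which correspond to the ViT-B/16 configuration (224×224 input, 16×16 patches, 768-dimensional embeddings, 12 encoder blocks). It is plain arithmetic that mirrors each layer of that code; `vit_param_count` and its helpers are hypothetical names for illustration, and `num_heads` is omitted because splitting attention into heads does not change the number of weights.

```python
def linear_params(fan_in, fan_out, bias=True):
    """Weights plus optional bias of an nn.Linear(fan_in, fan_out)."""
    return fan_in * fan_out + (fan_out if bias else 0)

def layernorm_params(dim):
    """An nn.LayerNorm(dim) has a learnable scale and shift per feature."""
    return 2 * dim

def vit_param_count(img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                    embed_dim=768, depth=12, mlp_ratio=4.0, qkv_bias=False):
    num_patches = (img_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)
    # Conv2d(in_chans, embed_dim, kernel_size=patch_size) weights + bias
    patch_embed = in_chans * patch_size * patch_size * embed_dim + embed_dim
    cls_and_pos = embed_dim + (num_patches + 1) * embed_dim  # cls_token + pos_embed
    block = (
        layernorm_params(embed_dim)                                  # norm1
        + linear_params(embed_dim, 3 * embed_dim, bias=qkv_bias)     # fused qkv
        + linear_params(embed_dim, embed_dim)                        # attention output proj
        + layernorm_params(embed_dim)                                # norm2
        + linear_params(embed_dim, hidden)                           # MLP, first layer
        + linear_params(hidden, embed_dim)                           # MLP, second layer
    )
    final_norm = layernorm_params(embed_dim)
    head = linear_params(embed_dim, num_classes)
    return patch_embed + cls_and_pos + depth * block + final_norm + head

print(vit_param_count())               # 86540008
print(vit_param_count(qkv_bias=True))  # 86567656
```

About 86.5 million parameters, roughly 85 million of them in the 12 encoder blocks, is a lot to fit from a small dataset without pre-training.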

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Patch Embedding: split the image into non-overlapping patch_size x patch_size patches
# and linearly project each one to embed_dim. A Conv2d whose kernel_size and stride both
# equal patch_size applies the same linear map to every patch.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        _, _, H, W = x.shape
        # The positional embedding has a fixed length, so the input size must match.
        assert H == self.img_size and W == self.img_size, \
            f"expected {self.img_size}x{self.img_size} input, got {H}x{W}"
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)  # (B, embed_dim, N)
        x = x.transpose(1, 2)  # (B, N, embed_dim)
        return x

# Multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused Q, K, V projection
        self.proj = nn.Linear(dim, dim)  # output projection after merging the heads

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each shape: (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads: (B, N, C)
        x = self.proj(x)
        return x

# Encoder Block (pre-norm): LayerNorm before each sublayer, residual connection around it
class EncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias)
        self.norm2 = nn.LayerNorm(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual around multi-head self-attention
        x = x + self.mlp(self.norm2(x))   # residual around the MLP
        return x

# Encoder: a stack of `depth` blocks followed by a final LayerNorm
class Encoder(nn.Module):
    def __init__(self, depth, dim, num_heads, mlp_ratio=4., qkv_bias=False):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderBlock(dim, num_heads, mlp_ratio, qkv_bias) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

# Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # learnable class token
        # Learnable positional embeddings: one per patch plus one for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(0.1)

        self.encoder = Encoder(depth, embed_dim, num_heads, mlp_ratio, qkv_bias)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1 + N, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.encoder(x)
        cls_token_final = x[:, 0]  # final class-token representation: (B, embed_dim)
        x = self.head(cls_token_final)  # class logits: (B, num_classes)
        return x
```
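The `Attention` module computes queries, keys and values for all heads with one fused `nn.Linear(dim, dim * 3)` and separates them with `reshape` and `permute`. The NumPy sketch below checks, on arbitrary toy sizes, that this layout is equivalent to three separate projection matrices with each head run independently and the head outputs concatenated along the channel dimension. It is an illustration under stated assumptions, not part of the model: the fused weight is stored as `(in, out)` so that it multiplies from the right (PyTorch's `nn.Linear` stores the transpose), NumPy's `transpose` stands in for `permute`, and the final output projection `proj` is left out because it is applied identically in both paths.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
B, N, C, H = 2, 5, 12, 3   # batch, tokens, embedding dim, heads (toy sizes)
hd = C // H                # head_dim
scale = hd ** -0.5

x = rng.standard_normal((B, N, C))
W_qkv = rng.standard_normal((C, 3 * C)) / np.sqrt(C)  # fused projection, stored as (in, out)

# Fused path, mirroring Attention.forward
qkv = (x @ W_qkv).reshape(B, N, 3, H, hd).transpose(2, 0, 3, 1, 4)  # (3, B, H, N, hd)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = softmax(q @ k.transpose(0, 1, 3, 2) * scale)                 # (B, H, N, N)
fused_out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)       # (B, N, C)

# Reference path: separate Q, K, V matrices, each head handled on its own
W_q, W_k, W_v = W_qkv[:, :C], W_qkv[:, C:2 * C], W_qkv[:, 2 * C:]
head_outputs = []
for h in range(H):
    cols = slice(h * hd, (h + 1) * hd)  # this head's slice of the channels
    q_h, k_h, v_h = x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]  # (B, N, hd)
    a_h = softmax(q_h @ k_h.transpose(0, 2, 1) * scale)                  # (B, N, N)
    head_outputs.append(a_h @ v_h)                                        # (B, N, hd)
reference_out = np.concatenate(head_outputs, axis=-1)                     # (B, N, C)

print(fused_out.shape)                        # (2, 5, 12)
print(np.allclose(fused_out, reference_out))  # True
```

Fusing the three projections into one matrix multiply is a common efficiency choice; it changes how the weights are laid out in memory, not what the layer computes.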