In [1]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

## Transformer Encoder
**Transformer: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)**
<p align="center">
    <img src="./assets/Multi-Head-Attention.png" width="750">
    <img src="./assets/Transformer-Encoder.png" width="200">
</p>

In [2]:
# Norm + Multi-Head Attention
class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalized and projected to queries, keys and values with
    a single bias-free linear layer. Scaled dot-product attention is then run per
    head, and the concatenated heads are projected back to ``dim``.

    Args:
        dim: size of the input/output feature dimension.
        heads: number of attention heads.
        dim_head: feature size of each head (total inner width is ``heads * dim_head``).
        dropout: dropout probability on the attention weights and, when an output
            projection is used, after that projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection can only be skipped when there is a single head
        # whose width already equals `dim`.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d_head): keeps dot products from growing with head size

        self.norm = nn.LayerNorm(dim)  # pre-norm on the block input

        self.attend = nn.Softmax(dim = -1)  # normalize scores over the key axis
        self.dropout = nn.Dropout(dropout)  # applied to the attention weights

        # One projection produces Q, K and V side by side along the last axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),  # map concatenated heads back to `dim`
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Self-attend over the token axis of ``x``.

        Assumes ``x`` is (batch, tokens, dim); the einops patterns below require a
        3-D tensor. Returns a tensor of shape (batch, tokens, dim).
        """
        normed = self.norm(x)

        # Split Q/K/V and move the head axis forward: (b, n, h*d) -> (b, h, n, d).
        q, k, v = [
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
            for t in self.to_qkv(normed).chunk(3, dim = -1)
        ]

        # Scaled dot-product attention.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))
        mixed = weights @ v

        # Concatenate heads: (b, h, n, d) -> (b, n, h*d).
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)


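**[Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415)**

Note: `nn.GELU()` in the code below is the exact (erf-based) GELU, not the sigmoid-based QuickGELU approximation $x\,\sigma(1.702x)$.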
<p align="center">
    <img src="./assets/QuickGELU.png" width="500">
</p>

In [3]:
# Norm + MLP
class FeedForward(nn.Module):
    """Pre-norm position-wise MLP.

    Layers: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout.

    Args:
        dim: size of the last (feature) dimension of the input and output.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability, applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),           # normalize features before the MLP
            nn.Linear(dim, hidden_dim),  # expand to the hidden width
            nn.GELU(),                   # non-linearity
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to the model width
            nn.Dropout(dropout),
        ]
        # A single positional Sequential named `net` keeps state_dict keys as `net.<i>.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last axis; the output has the same shape as ``x``."""
        return self.net(x)

In [4]:
# Transformer Encoder
class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers followed by a final LayerNorm.

    Each layer applies multi-head self-attention and then a feed-forward MLP.
    Each sublayer is wrapped in a residual connection, and each normalizes its
    own input.

    Args:
        dim: token feature width.
        depth: number of encoder layers.
        heads: attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of each feed-forward MLP.
        dropout: dropout probability passed to every sublayer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # applied once, after the last layer
        # Each entry is an [attention, feed-forward] pair, built in that order.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Run all layers with residual connections, then normalize the result."""
        hidden = x
        for attention, mlp in self.layers:
            hidden = hidden + attention(hidden)  # residual around attention
            hidden = hidden + mlp(hidden)        # residual around the MLP
        return self.norm(hidden)

## Vision Transformer
**[Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929)**
<p align="center">
    <img src="./assets/Vit.png" width="700">
</p>

In [None]:
# Vision Transformer
def pair(t):
    """Normalize a size spec to a ``(height, width)`` pair.

    - An int (or any non-sequence value) ``n`` becomes ``(n, n)``.
    - A tuple is returned unchanged.
    - A list, e.g. a size read from a JSON/YAML config, is converted to a tuple.
      So ``pair([224, 160])`` returns ``(224, 160)`` instead of duplicating the list.
    """
    if isinstance(t, tuple):
        return t
    if isinstance(t, list):
        return tuple(t)
    return (t, t)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        """
        Vision Transformer (ViT) image classifier.

        Args:
            image_size: input image size, an int (square) or a (height, width) tuple.
            patch_size: patch size, an int (square) or a (height, width) tuple.
            num_classes: number of output logits.
            dim: token embedding width used throughout the encoder.
            depth: number of encoder layers.
            heads: number of attention heads per layer.
            mlp_dim: hidden width of each encoder layer's feed-forward MLP.
            pool: 'cls' to classify from the [CLS] token, 'mean' to average all tokens.
            channels: number of input image channels (default 3, i.e. RGB).
            dim_head: width of each attention head.
            dropout: dropout rate inside the encoder.
            emb_dropout: dropout rate applied to the token embeddings before the encoder.

        Raises:
            AssertionError: if the image size is not divisible by the patch size,
                or if ``pool`` is not 'cls' or 'mean'.
        """
        super().__init__()

        # Both sizes are normalized to (height, width).
        img_h, img_w = pair(image_size)
        ph, pw = pair(patch_size)

        fits_grid = img_h % ph == 0 and img_w % pw == 0
        assert fits_grid, 'Image dimensions must be divisible by the patch size.'

        grid_h, grid_w = img_h // ph, img_w // pw
        num_patches = grid_h * grid_w
        patch_dim = ph * pw * channels  # length of one flattened patch

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Patch embedding: cut the image into patches, flatten each one,
        # then LayerNorm -> Linear(patch_dim, dim) -> LayerNorm.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=ph, p2=pw),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Learned position embeddings: one per patch plus one for the [CLS] token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()  # placeholder between pooling and the head; currently a no-op

        # Classification head: a single linear layer from `dim` to `num_classes` logits.
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Classify a batch of images.

        ``img`` must be (batch, channels, height, width), as required by the
        patch Rearrange pattern. Returns logits of shape (batch, num_classes).
        """
        # Linear projection of flattened patches.
        tokens = self.to_patch_embedding(img)  # (batch, num_patches, dim)
        batch, n_patches, _ = tokens.shape

        # Prepend one [CLS] token per sample.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=batch)  # (batch, 1, dim)
        tokens = torch.cat((cls_tokens, tokens), dim=1)  # (batch, num_patches + 1, dim)

        # Add position embeddings, then embedding dropout.
        tokens += self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        encoded = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = encoded.mean(dim=1)  # average over all tokens, [CLS] included
        else:
            pooled = encoded[:, 0]  # final state of the [CLS] token

        return self.mlp_head(self.to_latent(pooled))
