In [1]:
# 导入依赖库
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Transformer Encoder基础组件（用于复用至ViT）
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three bias-free
    linear layers. Every query is scored against every key, the scores are
    scaled by ``1 / sqrt(dim_k)``, optionally masked, normalised with a
    softmax over the key axis and used to average the values.

    Args:
        dim_q: Size of the last dimension of the input tensor.
        dim_k: Output size of the query and key projections.
        dim_v: Output size of the value projection (and of the result).
    """

    def __init__(self, dim_q, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.linear_q = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_q, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_q, dim_v, bias=False)
        # Scaling factor: keeps the scores from growing with dim_k, which
        # would push the softmax into its saturated, low-gradient region.
        self._norm_fact = 1 / math.sqrt(dim_k)
        self.dim_q = dim_q
        self.dim_k = dim_k
        self.dim_v = dim_v

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim_q]``.
            mask: Optional tensor broadcastable to
                ``[batch, seq_len, seq_len]``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            A tuple ``(att, attn_weights)`` where ``att`` has shape
            ``[batch, seq_len, dim_v]`` and ``attn_weights`` has shape
            ``[batch, seq_len, seq_len]``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``dim_q``.
        """
        batch_size, seq_len, dim_q = x.shape
        assert dim_q == self.dim_q, f"输入维度{dim_q}与初始化dim_q{self.dim_q}不匹配"

        # Build Q, K and V from the same input.
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)

        # Scaled attention scores: [batch, seq_len, seq_len].
        attn_scores = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact

        if mask is not None:
            # -1e9 is not representable in float16, so clamp the fill value
            # to the finite range of the score dtype. For float32/float64
            # this still evaluates to -1e9.
            fill_value = max(torch.finfo(attn_scores.dtype).min, -1e9)
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis and average the values.
        attn_weights = F.softmax(attn_scores, dim=-1)
        att = torch.bmm(attn_weights, v)
        return att, attn_weights


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    The input is projected to queries, keys and values of width
    ``dim_model``; each projection is split into ``num_heads`` slices of
    width ``dim_model // num_heads``; scaled dot-product attention runs
    independently in every slice; the per-head results are concatenated and
    passed through a final linear layer. All four linear layers are
    bias-free.

    The module's parameters are exactly ``linear_q``, ``linear_k``,
    ``linear_v`` and ``linear_out``; there is no nested attention
    sub-module.

    Args:
        dim_model: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``dim_model``.

    Raises:
        AssertionError: If ``dim_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert dim_model % num_heads == 0, f"输入维度{dim_model}需能被头数{num_heads}整除"

        self.dim_model = dim_model
        self.num_heads = num_heads
        self.head_dim = dim_model // num_heads  # width of one attention head
        # Score scaling factor, 1 / sqrt(head_dim).
        self._norm_fact = 1 / math.sqrt(self.head_dim)

        # Q, K and V projections.
        self.linear_q = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_k = nn.Linear(dim_model, dim_model, bias=False)
        self.linear_v = nn.Linear(dim_model, dim_model, bias=False)
        # Linear layer applied to the concatenated head outputs.
        self.linear_out = nn.Linear(dim_model, dim_model, bias=False)

    def _split_heads(self, x):
        """Reshape ``[batch, seq, dim_model]`` to ``[batch, heads, seq, head_dim]``."""
        batch_size, seq_len, dim_model = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _concat_heads(self, x):
        """Reshape ``[batch, heads, seq, head_dim]`` back to ``[batch, seq, heads * head_dim]``."""
        batch_size, num_heads, seq_len, head_dim = x.shape
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim_model]``.
            mask: Optional mask; positions equal to 0 are excluded from
                attention. A 3-D mask ``[batch, seq_len, seq_len]`` is
                applied per sample and shared by all heads. Any other rank
                must broadcast against ``[batch, heads, seq_len, seq_len]``.

        Returns:
            A tuple ``(out, att_weights)`` where ``out`` has shape
            ``[batch, seq_len, dim_model]`` and ``att_weights`` is the
            attention distribution averaged over heads, with shape
            ``[batch, seq_len, seq_len]``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not
                ``dim_model``.
        """
        batch_size, seq_len, dim_model = x.shape
        assert dim_model == self.dim_model, f"输入维度{dim_model}与初始化dim_model{self.dim_model}不匹配"

        # Project, then split into heads: [batch, heads, seq_len, head_dim].
        q = self._split_heads(self.linear_q(x))
        k = self._split_heads(self.linear_k(x))
        v = self._split_heads(self.linear_v(x))

        # Scaled scores for every head at once: [batch, heads, seq, seq].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self._norm_fact

        if mask is not None:
            if mask.dim() == 3:
                # Insert a head axis so each sample's mask is broadcast to
                # all of *that sample's* heads.
                mask = mask.unsqueeze(1)
            # Clamp the fill value to the score dtype's finite range so the
            # masking also works in float16; it is -1e9 for float32/64.
            fill_value = max(torch.finfo(attn_scores.dtype).min, -1e9)
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # Concatenate the heads and mix them with the output projection.
        out = self.linear_out(self._concat_heads(context))

        # Report the attention distribution averaged over heads.
        att_weights = attn_weights.mean(dim=1)
        return out, att_weights


class AddNorm(nn.Module):
    """Residual addition followed by layer normalisation.

    Computes ``LayerNorm(x + residual_weight * residual)`` where
    ``residual_weight`` is a single learnable scalar initialised to 1.

    Args:
        dim_model: Size of the normalised (last) dimension.
        eps: Epsilon passed to ``nn.LayerNorm``.
    """

    def __init__(self, dim_model, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model, eps=eps)
        # Learnable gain on the skip path; starts as a plain residual add.
        self.residual_weight = nn.Parameter(torch.ones(1))

    def forward(self, x, residual):
        """Return the normalised sum of ``x`` and the weighted ``residual``."""
        skip = self.residual_weight * residual
        return self.norm(x + skip)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        dim_model: Size of the input and output features.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, dim_model, hidden_dim=3072, dropout=0.1):
        super().__init__()
        # Expansion to the hidden width.
        self.linear1 = nn.Linear(dim_model, hidden_dim, bias=True)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        # Projection back to the model width.
        self.linear2 = nn.Linear(hidden_dim, dim_model, bias=True)

    def forward(self, x):
        """Transform ``x`` feature-wise; the output has the same shape as ``x``."""
        hidden = self.dropout(self.gelu(self.linear1(x)))
        return self.linear2(hidden)


class EncoderLayer(nn.Module):
    """One Transformer encoder layer.

    Data flow: multi-head attention -> dropout -> Add&Norm -> feed-forward
    -> dropout -> Add&Norm. Normalisation is applied after each residual
    addition.

    Args:
        dim_model: Feature size of the token sequence.
        num_heads: Number of attention heads.
        hidden_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability, used inside the feed-forward network
            and on both sub-layer outputs.
    """

    def __init__(self, dim_model, num_heads, hidden_dim=3072, dropout=0.1):
        super().__init__()
        self.multi_head_attn = MultiHeadAttention(dim_model=dim_model, num_heads=num_heads)
        self.add_norm1 = AddNorm(dim_model=dim_model)
        self.feed_forward = FeedForward(dim_model=dim_model, hidden_dim=hidden_dim, dropout=dropout)
        self.add_norm2 = AddNorm(dim_model=dim_model)
        # One dropout module shared by both sub-layer outputs.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x``.

        Args:
            x: Token sequence passed to the attention sub-layer.
            mask: Optional attention mask forwarded unchanged to the
                attention sub-layer.

        Returns:
            A tuple ``(out, att_weights)``: the encoded sequence and the
            attention weights returned by the attention sub-layer.
        """
        # Attention sub-layer with residual connection.
        attn_branch, att_weights = self.multi_head_attn(x, mask=mask)
        hidden = self.add_norm1(self.dropout(attn_branch), residual=x)

        # Feed-forward sub-layer with residual connection.
        ffn_branch = self.dropout(self.feed_forward(hidden))
        encoded = self.add_norm2(ffn_branch, residual=hidden)

        return encoded, att_weights


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers applied one after another.

    Args:
        dim_model: Feature size of the token sequence.
        num_heads: Number of attention heads in every layer.
        num_layers: Number of stacked encoder layers.
        hidden_dim: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used in every layer.
    """

    def __init__(self, dim_model=768, num_heads=12, num_layers=12, hidden_dim=3072, dropout=0.1):
        super().__init__()
        self.dim_model = dim_model
        self.encoder_layers = nn.ModuleList(
            EncoderLayer(
                dim_model=dim_model,
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                dropout=dropout,
            )
            for _ in range(num_layers)
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every layer in order.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim_model]``.
            mask: Optional attention mask handed to every layer.

        Returns:
            A tuple ``(out, all_att_weights)``: the output of the last
            layer and the per-layer attention weights stacked along a new
            leading dimension.

        Raises:
            AssertionError: If the last dimension of ``x`` is not
                ``dim_model``.
        """
        _, _, feature_dim = x.shape
        assert feature_dim == self.dim_model, f"输入维度{feature_dim}与初始化dim_model{self.dim_model}不匹配"

        hidden = x
        per_layer_weights = []
        for layer in self.encoder_layers:
            hidden, layer_weights = layer(hidden, mask=mask)
            per_layer_weights.append(layer_weights)

        # Leading axis indexes the layer.
        return hidden, torch.stack(per_layer_weights, dim=0)


# ViT核心组件：Patch Embedding（图像分块与线性投影）
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    performs the patch extraction and the linear projection in one step.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of one patch; must divide ``img_size``.
        in_channels: Number of image channels.
        dim_model: Embedding size of each patch.

    Raises:
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, dim_model=768):
        super().__init__()
        # Non-overlapping tiling requires an exact division.
        assert img_size % patch_size == 0, f"图像尺寸{img_size}需能被块尺寸{patch_size}整除"

        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side ** 2  # total patches per image
        self.proj = nn.Conv2d(
            in_channels,
            dim_model,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``[batch, in_channels, height, width]``.

        Returns:
            Tensor of shape ``[batch, num_patches, dim_model]`` for inputs
            of the configured image size.
        """
        feature_map = self.proj(x)          # [B, dim_model, H/p, W/p]
        tokens = feature_map.flatten(2)     # [B, dim_model, num_patches]
        return tokens.transpose(1, 2)       # [B, num_patches, dim_model]


# 完整Vision Transformer（ViT）前向模型
class VisionTransformer(nn.Module):
    """Vision Transformer classifier (forward pass only).

    Pipeline: patch embedding -> prepend class token -> add positional
    embedding -> dropout -> Transformer encoder -> LayerNorm on the class
    token -> linear classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of one patch.
        in_channels: Number of image channels.
        dim_model: Token embedding size.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        hidden_dim: Hidden width of the feed-forward networks.
        num_classes: Number of output classes.
        dropout: Dropout probability for the embeddings and the encoder.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, dim_model=768,
                 num_heads=12, num_layers=12, hidden_dim=3072, num_classes=10, dropout=0.1):
        super().__init__()

        # 1. Image -> sequence of patch embeddings.
        self.patch_emb = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            dim_model=dim_model,
        )
        self.num_patches = self.patch_emb.num_patches

        # 2. Learnable class token whose output feeds the classifier.
        self.class_token = nn.Parameter(torch.zeros(1, 1, dim_model))
        # 3. Learnable positional embedding, one slot per patch plus one
        #    for the class token.
        self.pos_emb = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim_model))
        self.dropout = nn.Dropout(dropout)

        # 4. Encoder stack operating on the token sequence.
        self.encoder = TransformerEncoder(
            dim_model=dim_model,
            num_heads=num_heads,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # 5. Classification head applied to the class-token output.
        self.norm = nn.LayerNorm(dim_model)
        self.fc = nn.Linear(dim_model, num_classes)

        # Replace the zero placeholders with Xavier-uniform values.
        nn.init.xavier_uniform_(self.class_token)
        nn.init.xavier_uniform_(self.pos_emb)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``[batch, in_channels, img_size, img_size]``.

        Returns:
            A tuple ``(logits, all_att_weights)``: class scores of shape
            ``[batch, num_classes]`` and the attention weights returned by
            the encoder.
        """
        num_images = x.shape[0]

        # Patch tokens: [B, num_patches, dim_model].
        tokens = self.patch_emb(x)

        # Prepend one copy of the class token per image.
        cls_tokens = self.class_token.expand(num_images, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)

        # Inject position information, then regularise.
        tokens = self.dropout(tokens + self.pos_emb)

        encoded, all_att_weights = self.encoder(tokens)

        # The class token sits at sequence index 0.
        cls_out = self.norm(encoded[:, 0, :])
        logits = self.fc(cls_out)

        return logits, all_att_weights


# CIFAR-10数据集加载（适配ViT输入需求）
def load_cifar10(batch_size=64, root='./data', num_workers=2):
    """Build CIFAR-10 training and validation data loaders.

    The training set is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalised
    with the same per-channel mean and standard deviation. The dataset is
    downloaded into ``root`` if it is not already there.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory in which the dataset is stored or downloaded.
        num_workers: Number of worker processes per loader; use 0 to load
            in the main process.

    Returns:
        A tuple ``(train_loader, val_loader)``. The training loader
        shuffles its data; the validation loader does not.
    """
    # Per-channel CIFAR-10 statistics, shared by both splits so training
    # and evaluation see identically scaled inputs.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

    # Training: augmentation, then tensor conversion and normalisation.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation: no augmentation, only tensor conversion and normalisation.
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(
        root=root, train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=val_transform
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader


# ViT前向模型测试（验证模型完整性与维度正确性）
if __name__ == "__main__":
    # Smoke test: push one CIFAR-10 batch through an untrained ViT and
    # verify the output shapes before reporting success.

    # 1. Test configuration.
    batch_size = 2
    num_classes = 10
    num_layers = 12
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备：{device}")

    # 2. Fetch a single validation batch (forward pass only, no training).
    _, val_loader = load_cifar10(batch_size=batch_size)
    imgs, labels = next(iter(val_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    # 3. Build the model with ViT-Base-sized width, heads and depth.
    vit_model = VisionTransformer(
        img_size=32,              # CIFAR-10 image side length
        patch_size=4,             # 4x4 non-overlapping patches
        num_classes=num_classes,  # CIFAR-10 has 10 classes
        dim_model=768,
        num_heads=12,
        num_layers=num_layers,
    )
    vit_model = vit_model.to(device)

    # 4. Forward pass in eval mode with gradient tracking disabled.
    vit_model.eval()
    with torch.no_grad():
        logits, all_att_weights = vit_model(imgs)

    # 5. Report the shapes and check them against the configuration.
    seq_len = vit_model.num_patches + 1  # patches plus the class token
    shape_checks = {
        "imgs": (tuple(imgs.shape), (batch_size, 3, 32, 32)),
        "logits": (tuple(logits.shape), (batch_size, num_classes)),
        "all_att_weights": (
            tuple(all_att_weights.shape),
            (num_layers, batch_size, seq_len, seq_len),
        ),
    }

    print("\n=== ViT前向传播测试结果 ===")
    print(f"输入图像形状：{imgs.shape}（预期：[{batch_size},3,32,32]）")
    print(f"分类预测形状：{logits.shape}（预期：[{batch_size},{num_classes}]）")
    print(f"注意力权重形状：{all_att_weights.shape}（预期：[{num_layers},{batch_size},{seq_len},{seq_len}]）")
    print(f"预测类别：{torch.argmax(logits, dim=1)}（真实类别：{labels}）")

    # Only claim success when every shape really matches.
    mismatches = [
        f"{name}: got {actual}, expected {expected}"
        for name, (actual, expected) in shape_checks.items()
        if actual != expected
    ]
    if mismatches:
        raise RuntimeError("ViT forward-pass shape check failed: " + "; ".join(mismatches))

    print("\nViT前向模型实现完成，维度匹配预期，可正常用于后续视觉任务！")

使用设备：cuda:0

=== ViT前向传播测试结果 ===
输入图像形状：torch.Size([2, 3, 32, 32])（预期：[2,3,32,32]）
分类预测形状：torch.Size([2, 10])（预期：[2,10]）
注意力权重形状：torch.Size([12, 2, 65, 65])（预期：[12,2,65,65]）
预测类别：tensor([3, 8], device='cuda:0')（真实类别：tensor([3, 8], device='cuda:0')）

ViT前向模型实现完成，维度匹配预期，可正常用于后续视觉任务！
