<a href="https://colab.research.google.com/github/Pinery-lee/dl-interview-map/blob/main/src/ViT_2020.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from functools import partial
from collections import OrderedDict
from typing import Optional, Callable

class MLPBlock(nn.Module):
    """Feed-forward sub-layer of a Transformer block.

    Computes Linear(in_dim -> mlp_dim) -> GELU -> Dropout ->
    Linear(mlp_dim -> in_dim) -> Dropout, so the output has the same
    trailing dimension as the input.
    """
    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        # Attribute names and creation order are kept fixed: they determine
        # state_dict keys, parameter order and RNG consumption at init.
        self.layer0 = nn.Linear(in_dim, mlp_dim)   # expand
        self.layer1 = nn.GELU()
        self.layer2 = nn.Dropout(dropout)
        self.layer3 = nn.Linear(mlp_dim, in_dim)   # project back
        self.layer4 = nn.Dropout(dropout)

    def forward(self, x):
        out = x
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out

class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    forward computes
        h   = input + Dropout(SelfAttention(LN_1(input)))
        out = h + MLP(LN_2(h))
    with inputs laid out batch-first (``batch_first=True`` on the attention).
    """
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # --- self-attention sub-layer ---
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

        # --- feed-forward sub-layer ---
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        normed = self.ln_1(input)
        # Attention weights are not needed, so skip computing/returning them.
        attn_out, _ = self.self_attention(normed, normed, normed, need_weights=False)
        residual = self.dropout(attn_out) + input

        return residual + self.mlp(self.ln_2(residual))

class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderBlocks with learnable position embeddings.

    Adds a learnable (1, seq_length, hidden_dim) position embedding to the
    token sequence, applies dropout, runs the blocks in order and finishes
    with a final normalization layer.
    """
    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Position embeddings drawn from N(0, 0.02^2).
        pos_init = torch.empty(1, seq_length, hidden_dim)
        pos_init.normal_(std=0.02)
        self.pos_embedding = nn.Parameter(pos_init)
        self.dropout = nn.Dropout(dropout)

        # Named sub-modules encoder_layer_0 .. encoder_layer_{num_layers-1}.
        blocks = OrderedDict(
            (
                f"encoder_layer_{idx}",
                EncoderBlock(num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer),
            )
            for idx in range(num_layers)
        )
        self.layers = nn.Sequential(blocks)
        self.ln = norm_layer(hidden_dim)

    def forward(self, x: torch.Tensor):
        tokens = self.dropout(x + self.pos_embedding)
        return self.ln(self.layers(tokens))

class VisionTransformer(nn.Module):
    """Minimal ("plain") Vision Transformer classifier.

    A strided convolution cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and projects each to
    ``hidden_dim``. A learnable class (CLS) token is prepended, the sequence
    goes through ``Encoder`` (which adds position embeddings), and the final
    CLS representation is fed to the classification head.

    Args:
        image_size: Side length of the square RGB images the model accepts.
        patch_size: Side length of each square patch (conv kernel and stride).
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per block.
        hidden_dim: Token embedding width.
        mlp_dim: Hidden width of each block's MLP.
        dropout: Dropout probability passed to the encoder.
        attention_dropout: Dropout probability passed to the attention layers.
        num_classes: Number of output logits.
        representation_size: If given, a ``Linear -> Tanh`` pre-logits layer of
            this width is inserted before the final classifier.
    """
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        # Patch embedding: kernel == stride == patch_size, so every
        # non-overlapping patch becomes one hidden_dim vector.
        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # One token per patch plus one for the prepended CLS token.
        seq_length = (image_size // patch_size) ** 2 + 1

        # Learnable CLS token, zero-initialized.
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        self.encoder = Encoder(
            seq_length, num_layers, num_heads, hidden_dim, mlp_dim, dropout, attention_dropout
        )

        # Classification head, optionally preceded by a Linear -> Tanh pre-logits layer.
        head_layers = OrderedDict()
        if representation_size is None:
            head_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            head_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            head_layers["act"] = nn.Tanh()
            head_layers["head"] = nn.Linear(representation_size, num_classes)
        self.heads = nn.Sequential(head_layers)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Convert an (n, 3, image_size, image_size) batch into (n, seq_len, hidden_dim) patch tokens.

        Raises:
            ValueError: if the input height/width differ from ``image_size``.
                The position embeddings are sized for exactly that resolution.
                Without this check, inputs a few pixels larger (e.g. 230x230
                with 16x16 patches) are silently accepted with their border
                pixels dropped. Other sizes fail later with an opaque
                broadcasting error.
        """
        n, _, h, w = x.shape
        if h != self.image_size or w != self.image_size:
            raise ValueError(
                f"Expected input height and width {self.image_size}, got {h}x{w}"
            )
        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, seq_len)
        x = x.reshape(n, self.hidden_dim, -1)
        # (n, hidden_dim, seq_len) -> (n, seq_len, hidden_dim)
        return x.permute(0, 2, 1)

    def forward(self, x: torch.Tensor):
        """Return class logits of shape (n, num_classes) for an image batch ``x``."""
        x = self._process_input(x)
        n = x.shape[0]

        # Prepend the CLS token to every sequence in the batch.
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classify from the CLS token's final representation.
        x = x[:, 0]
        return self.heads(x)

def vit_b_16(num_classes: int = 1000, **kwargs):
    """Build a ViT-B/16.

    The architecture is a 224x224 input, 16x16 patches, 12 layers x 12
    heads, width 768 and MLP width 3072.

    Extra keyword arguments (e.g. ``dropout``) are forwarded to
    ``VisionTransformer``. Passing one of the fixed architecture arguments
    raises ``TypeError`` (duplicate keyword).
    """
    base_config = dict(
        image_size=224,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    return VisionTransformer(num_classes=num_classes, **base_config, **kwargs)

# --- Smoke test ---
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = vit_b_16(num_classes=10)
    model = model.to(device)
    print("vit_b_16结构：", model)

    # A single random 3-channel 224x224 "image" (batch of 1).
    img = torch.randn(1, 3, 224, 224)
    img = img.to(device)
    output = model(img)

    print(f"输入形状: {img.shape}")
    print(f"输出形状: {output.shape}")  # expected: torch.Size([1, 10])

vit_b_16结构： VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (layer0): Linear(in_features=768, out_features=3072, bias=True)
          (layer1): GELU(approximate='none')
          (layer2): Dropout(p=0.0, inplace=False)
          (layer3): Linear(in_features=3072, out_features=768, bias=True)
          (layer4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, ele

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Oxford-IIIT Pet has 37 breed classes.
NUM_CLASSES = 37

# 1. Load the Hugging Face image processor and the ViT model.
#    Passing num_labels makes transformers build the classification head
#    (`model.classifier`) with NUM_CLASSES outputs and keeps `model.config`
#    consistent with it. Swapping in a hand-made nn.Linear afterwards would
#    leave the config's label count out of sync with the actual head.
model_name = "google/vit-base-patch16-224-in21k"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=NUM_CLASSES)
print("vit-base-patch16-224-in21k结构：", model)
model.to(device)

# 2. Dataset. Normalize with the processor's own mean/std so the tensors
#    match the preprocessing this checkpoint expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

train_dataset = torchvision.datasets.OxfordIIITPet(
    root='./data', split='trainval', download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# 3. Optimizer and loss.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

# 4. Training loop.
print("开始微调 vit-base-patch16-224-in21k...")
model.train()
for epoch in range(5):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # transformers models return an output object; class scores live in `.logits`.
        logits = model(inputs).logits

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch + 1} Average Loss: {running_loss / len(train_loader):.4f}')

print("训练结束")

Using device: cuda


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vit-base-patch16-224-in21k结构： ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bi

In [7]:
# Evaluation data: the Oxford-IIIT Pet test split, preprocessed with the
# same `transform` as the training split.
test_dataset = torchvision.datasets.OxfordIIITPet(
    root="./data",
    split="test",
    transform=transform,
    download=True,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)
def evaluate_accuracy(model, data_loader, device):
    """Print and return the top-1 accuracy (percent) of ``model`` on ``data_loader``.

    Args:
        model: Classifier whose forward returns either a logits tensor
            (expected shape (batch, num_classes)) or an object with a
            ``.logits`` attribute, such as a Hugging Face classification output.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a float percentage in [0, 100].

    Raises:
        ValueError: If ``data_loader`` yields no samples.

    The model runs in eval mode (disabling dropout) during evaluation and is
    restored to its previous train/eval mode afterwards, even on error.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # No gradients are needed for evaluation; this saves memory and compute.
        with torch.no_grad():
            for batch in data_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                outputs = model(images)
                # Accept both raw logits tensors and HF-style output objects.
                logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits

                # The highest-scoring class is the prediction.
                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f'模型在测试集上的准确率: {accuracy:.2f}%')
    return accuracy

# Evaluate once training has finished; the returned accuracy is displayed as the cell output.
evaluate_accuracy(model=model, data_loader=test_loader, device=device)

模型在测试集上的准确率: 91.17%


91.16925592804579

In [4]:
# Install torchinfo, which provides `summary` for the model-summary cell (IPython shell magic).
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [5]:
# Show per-layer output shapes and parameter counts of `model` (the ViT set up
# in the fine-tuning cell) for a single 3x224x224 input.
from torchinfo import summary
summary(model, input_size=(1,3,224,224))

Layer (type:depth-idx)                                  Output Shape              Param #
ViTForImageClassification                               [1, 37]                   --
├─ViTModel: 1-1                                         [1, 197, 768]             --
│    └─ViTEmbeddings: 2-1                               [1, 197, 768]             152,064
│    │    └─ViTPatchEmbeddings: 3-1                     [1, 196, 768]             590,592
│    │    └─Dropout: 3-2                                [1, 197, 768]             --
│    └─ViTEncoder: 2-2                                  [1, 197, 768]             --
│    │    └─ModuleList: 3-3                             --                        85,054,464
│    └─LayerNorm: 2-3                                   [1, 197, 768]             1,536
├─Linear: 1-2                                           [1, 37]                   28,453
Total params: 85,827,109
Trainable params: 85,827,109
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 200.84
