1. 导入库

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
import numpy as np

2. 定义超参数

In [2]:
# Define hyperparameters
num_classes = 37  # Oxford-IIIT Pet Dataset has 37 categories
image_size = 128  # input images are resized to image_size x image_size
patch_size = 16  # each image is split into non-overlapping 16x16 patches
num_patches = (image_size // patch_size) ** 2  # number of patches per image
projection_dim = 64  # projection (token) dimension
num_heads = 4  # number of attention heads
transformer_units = [projection_dim * 2, projection_dim]  # MLP widths inside each Transformer block
transformer_layers = 8  # number of Transformer blocks
mlp_head_units = [2048, 1024]  # hidden widths of the classification MLP head
learning_rate = 0.001  # learning rate
weight_decay = 0.0001  # weight decay
batch_size = 32  # mini-batch size
num_epochs = 50  # number of training epochs

# Seed torch's global RNG. It drives weight initialization, DataLoader shuffling and
# the torchvision random augmentations, so seeding here makes a fresh
# "Restart & Run All" repeatable (up to non-deterministic GPU kernels).
seed = 42
torch.manual_seed(seed);

3. 准备数据

In [None]:
# Data preparation
# Random augmentation is for training only. The test set gets a deterministic
# transform so that evaluation measures the model, not a random flip/rotation.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # maps ToTensor's [0, 1] pixels to [-1, 1]

# Training transform: resize + random augmentation.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    normalize,
])

# Evaluation transform: resize + normalize only (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

trainset = OxfordIIITPet(root='./data', split='trainval', target_types='category', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = OxfordIIITPet(root='./data', split='test', target_types='category', download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

  9%|▉         | 69.9M/792M [00:50<02:09, 5.58MB/s] 

4. 定义MLP

In [None]:
# Define MLP: a stack of Linear -> GELU -> Dropout stages
class MLP(nn.Module):
    """Feed-forward stack built from explicit (in_features, out_features) pairs.

    Args:
        hidden_units: sequence of pairs. For each pair ``p`` the stack gets
            ``nn.Linear(p[0], p[1])``, then ``nn.GELU()``, then ``nn.Dropout(dropout_rate)``.
            The activation and dropout follow *every* Linear, including the last one.
        dropout_rate: dropout probability used after each activation.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()
        stages = []
        for pair in hidden_units:
            stages.extend((
                nn.Linear(pair[0], pair[1]),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            ))
        # The attribute is named ``net`` so state_dict keys look like ``net.0.weight``.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack (each Linear acts on the last dimension)."""
        return self.net(x)

5. 定义补丁嵌入

In [None]:
# Define Patch Embedding: cut the image into non-overlapping patches and project each one to a token
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens.

    A Conv2d with ``kernel_size == stride == patch_size`` applies one shared linear
    projection to every non-overlapping ``patch_size x patch_size`` patch.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        embed_dim: number of output channels of the projection, i.e. the token dimension.
    """

    def __init__(self, image_size, patch_size, embed_dim):
        super().__init__()
        grid = image_size // patch_size  # patches per side; leftover pixels are not covered by the strided conv
        self.patch_size = patch_size
        self.num_patches = grid ** 2
        # The conv is built for 3-channel input.
        self.projection = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``(B, 3, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens."""
        feature_map = self.projection(x)           # (B, embed_dim, H//p, W//p)
        tokens = feature_map.flatten(start_dim=2)  # (B, embed_dim, num_patches)
        return tokens.transpose(1, 2)              # (B, num_patches, embed_dim)

6. 定义Transformer块

In [None]:
# Define Transformer Block
class TransformerBlock(nn.Module):
    """One encoder block with two sub-layers, self-attention and an MLP.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm (post-norm).

    Args:
        embed_dim: token dimension (nn.MultiheadAttention requires it to be divisible by num_heads).
        num_heads: number of attention heads.
        mlp_hidden_dim: hidden width of the two-layer feed-forward MLP.
        dropout: dropout probability on the attention weights and inside the MLP.

    Input and output are ``(batch, seq_len, embed_dim)`` tensors.
    """

    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout):
        super(TransformerBlock, self).__init__()
        # batch_first=True because tokens arrive as (batch, seq, dim). With the default
        # (batch_first=False), dim 0 would be treated as the sequence, so attention would
        # mix different images of the batch instead of the patches of one image.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = MLP([[embed_dim, mlp_hidden_dim], [mlp_hidden_dim, embed_dim]], dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention: query = key = value = x. The attention weights are never
        # used, so skip computing them.
        attention_output, _ = self.attention(x, x, x, need_weights=False)
        x = self.norm1(x + attention_output)
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x

7. 定义Vision Transformer

In [None]:
# Define Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    """Image classifier built as a pipeline of five stages.

    The stages are: patch embedding, Transformer blocks, mean pooling over patches,
    MLP head, and linear classifier.

    Args:
        num_classes: number of output logits.
        image_size: side length of the square input image, forwarded to ``PatchEmbedding``.
        patch_size: side length of each square patch, forwarded to ``PatchEmbedding``.
        embed_dim: token dimension.
        num_heads: attention heads per Transformer block.
        transformer_units: MLP widths for each block. Only ``transformer_units[0]``
            (the hidden width) is read, because the block's MLP always projects back
            to ``embed_dim``.
        num_layers: number of Transformer blocks.
        mlp_head_units: hidden widths of the classification head, applied in order
            after pooling. Any length is accepted; an empty sequence feeds the pooled
            features straight to the classifier.
        block_dropout: dropout inside every Transformer block (default 0.1, the previous hard-coded value).
        head_dropout: dropout in the MLP head (default 0.5, the previous hard-coded value).
    """

    def __init__(self, num_classes, image_size, patch_size, embed_dim, num_heads, transformer_units, num_layers, mlp_head_units,
                 block_dropout=0.1, head_dropout=0.5):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim)
        # Learned absolute position embedding: one vector per patch, broadcast over the batch.
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.num_patches, embed_dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, transformer_units[0], block_dropout)
            for _ in range(num_layers)
        ])
        # Chain the head widths: embed_dim -> units[0] -> units[1] -> ...
        head_dims = [embed_dim, *mlp_head_units]
        self.mlp_head = MLP(list(zip(head_dims[:-1], head_dims[1:])), head_dropout)
        self.classifier = nn.Linear(head_dims[-1], num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)      # (B, num_patches, embed_dim)
        x = x + self.position_embedding
        for block in self.transformer_blocks:
            x = block(x)
        x = x.mean(dim=1)                # average over patch tokens -> (B, embed_dim)
        x = self.mlp_head(x)
        return self.classifier(x)        # (B, num_classes) logits

8. 训练模型

In [None]:
# Instantiate the model
model = VisionTransformer(
    num_classes=num_classes,
    image_size=image_size,
    patch_size=patch_size,
    embed_dim=projection_dim,
    num_heads=num_heads,
    transformer_units=transformer_units,
    num_layers=transformer_layers,
    mlp_head_units=mlp_head_units
)

# Choose the device once, falling back to CPU so the notebook also runs without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model once, and do it *before* building the optimizer, as the torch.optim
# docs recommend. Previously the move happened on every batch.
model = model.to(device)

# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    model.train()  # enable dropout
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")

9. 评估模型

In [None]:
# Evaluation loop with visualization: accuracy over the WHOLE test set,
# plus a visual check of a few predictions from the first test batch.
model.eval()  # disable dropout
# Evaluate on whatever device the model currently lives on (no hardcoded 'cuda').
device = next(model.parameters()).device
correct = 0
total = 0

with torch.no_grad():
    # Accuracy has to cover every test batch; the previous version scored only the first one.
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Visualize predictions for the first test batch (at most 5 images).
    images, labels = next(iter(testloader))
    predicted = model(images.to(device)).argmax(dim=1).cpu()
    num_shown = min(5, images.size(0))
    fig, axes = plt.subplots(1, num_shown, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = images[i].permute(1, 2, 0).numpy() * 0.5 + 0.5  # Unnormalize (inverse of Normalize(0.5, 0.5))
        ax.imshow(np.clip(img, 0, 1))  # clip float round-off so imshow does not warn
        ax.set_title(f"Predicted: {testset.classes[predicted[i].item()]}\nActual: {testset.classes[labels[i].item()]}")
        ax.axis("off")
    plt.show()

print(f"Accuracy: {100 * correct / total:.2f}%")