In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


In [2]:
# 调整image size
transform = Compose([
    Resize((224, 224)),
    ToTensor()
])

img = Image.open("vitpic/1.png").convert("RGB")

x = transform(img)
x = x.unsqueeze(0)  # add batch dim
print(x.shape)  # torch.Size([1, 3, 224, 224])


torch.Size([1, 3, 224, 224])


第一步把image分割为pathces，然后将其flatten, 用einops

In [3]:
patch_size=16  # pixels
patches=rearrange(x,'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',s1=patch_size,s2=patch_size)
print(patches.shape) # (batch, patch数量（224/16）^2, 每一个patch的维度（16x16x3）)

torch.Size([1, 196, 768])


In [4]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size

        # ✅ 使用 Conv2D 进行 Patch Embedding
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )

        # ✅ Class Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

        # ✅ 位置编码
        num_patches = (img_size // patch_size) ** 2
        self.positions = nn.Parameter(torch.randn(1, 197, emb_size))  # ✅ 确保 shape = [1, 197, 768]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection(x)  # shape: [B, 196, 768]

        # ✅ 复制 CLS Token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # shape: [B, 1, 768]
        x = torch.cat([cls_tokens, x], dim=1)  # shape: [B, 197, 768]

        # ✅ 添加 Position Encoding
        x = x + self.positions.expand(x.shape[0], -1, -1)  # ✅ 让 positions 适应 batch 维度

        return x


transformer 在vit中only encoder


In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

# patches_embedded=PatchEmbedding()(x)
#print(MultiHeadAttention()(patches_embedded).shape) # torch.Size([1, 197, 768])


# patches_embedded = PatchEmbedding()(x)  # x: [batch_size, 3, 224, 224] -> [1, 197, 768]
# mha = MultiHeadAttention()
# print(mha(patches_embedded).shape) 


直接用调库

In [6]:


# class MultiHeadAttention(nn.Module):
#     def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
#         super().__init__()
#         # 利用 PyTorch 内置的 MultiheadAttention 实现多头注意力
#         self.attention = nn.MultiheadAttention(
#             embed_dim=emb_size,
#             num_heads=num_heads,
#             dropout=dropout,
#             batch_first=True  # 确保输入输出形状为 (batch, seq, emb)
#         )

#     def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
#         # 对于 nn.MultiheadAttention, query, key, value 一般均为 x
#         # 如果提供 mask，则传递给 attn_mask 参数
#         att_output, _ = self.attention(x, x, x, attn_mask=mask)
#         return att_output



# # 测试
# patches_embedded = PatchEmbedding()(x)  # x: [batch_size, 3, 224, 224] -> [1, 197, 768]
# mha = MultiHeadAttention()
# print(mha(patches_embedded).shape)  # torch.Size([1, 197, 768])


Res

In [7]:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self,x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
# class ResidualAdd(nn.Module):
#     def __init__(self, layer):
#         super().__init__()
#         self.layer = layer  # 任何传入的计算层（如 MHA 或 FFN）

#     def forward(self, x):
#         return x + self.layer(x)  # 直接残差连接


MLP

In [8]:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size), #dmodel dff
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
# class FeedForwardBlock(nn.Module):
#     def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
#         super().__init__()
#         self.fc1 = nn.Linear(emb_size, expansion * emb_size)  # d_model -> d_ff
#         self.act = nn.GELU()  # 激活函数
#         self.dropout = nn.Dropout(drop_p)
#         self.fc2 = nn.Linear(expansion * emb_size, emb_size)  # d_ff -> d_model

#     def forward(self, x):
#         return self.fc2(self.dropout(self.act(self.fc1(x))))  # 线性 -> GELU -> Dropout -> 线性


Encoder Block组合

In [9]:
# class TransformerEncoderBlock(nn.Sequential):
#     def __init__(self, emb_size: int = 768, num_heads: int = 8, drop_p: float = 0., forward_expansion: int = 4):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(emb_size)
#         self.attn = ResidualAdd(MultiHeadAttention(emb_size, num_heads=num_heads))
#         self.dropout1 = nn.Dropout(drop_p)

#         self.norm2 = nn.LayerNorm(emb_size)
#         self.ffn = ResidualAdd(FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=drop_p))
#         self.dropout2 = nn.Dropout(drop_p)
# patches_embedded = PatchEmbedding()(x)
    # def forward(self, x):
    #     x = self.attn(self.norm1(x))  # MHA + 残差
    #     x = self.dropout1(x)
    #     x = self.ffn(self.norm2(x))  # FFN + 残差
    #     x = self.dropout2(x)
    #     return x
# print(TransformerEncoderBlock()(patches_embedded).shape) # torch.Size([1, 197, 768])
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
    ))
patches_embedded = PatchEmbedding()(x)
# print(TransformerEncoderBlock()(patches_embedded).shape) # torch.Size([1, 197, 768])


Encoder

In [10]:
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


    # def forward(self, x):
    #     for layer in self.layers:
    #         x = layer(x)  # 依次通过每个 Transformer Encoder Block
    #     return x


分类头


In [11]:
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))


In [12]:
class ViT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )


In [None]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

# ✅ 训练目录
checkpoint_dir = Path("training_dir")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# ✅ 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ✅ 预处理（数据增强）
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # 防止 Patch 过度裁剪
    transforms.ToTensor(),
])

# ✅ 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root=checkpoint_dir, train=True, download=False, transform=transform)
# 修改后（测试集），和训练集保持一致
test_dataset = datasets.CIFAR10(
    root=checkpoint_dir, 
    train=False, 
    download=False, 
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # 这里可以加上与训练集一致的 Normalize
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

# ✅ **实例化模型**
model = ViT(
    in_channels=3,
    patch_size=16,  # ✅ 适应 CIFAR-10
    emb_size=768,
    img_size=224,
    depth=12,
    n_classes=10
).to(device)

# ✅ **优化器**
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# ✅ **学习率调度器**
scheduler = CosineAnnealingLR(optimizer, T_max=50)

# ✅ 定义损失函数
criterion = nn.CrossEntropyLoss()

# ✅ **检查是否有已训练的 Checkpoint**
checkpoint_files = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")])
latest_checkpoint = checkpoint_dir / checkpoint_files[-1] if checkpoint_files else None

# ✅ **如果存在 Checkpoint，则加载**
start_epoch = 1
if latest_checkpoint:
    print(f"🔄 加载最新的 Checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)  # ✅ 忽略 shape 不匹配
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # 从下一个 epoch 开始训练
    print(f"✅ Checkpoint {latest_checkpoint} 加载成功，从 Epoch {start_epoch} 继续训练！")


# ✅ 训练过程
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    
    print(f"Epoch {epoch} 开始训练...")
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        # ✅ 每 10 个 Batch 打印一次 Loss
        if batch_idx % 10 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}")

    scheduler.step()  # ✅ 更新学习率

    # ✅ 训练完成后，保存 Checkpoint
    checkpoint_path = checkpoint_dir / f"vit_epoch_{epoch}.pth"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item()
    }, checkpoint_path)
    print(f"✅ Checkpoint saved: {checkpoint_path}")

    

# ✅ 测试过程
def test(model, device, test_loader, criterion):
    print("testing")
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)

            # if data.shape[1:] != (3, 224, 224):  # ✅ 检查数据 shape
            #     raise ValueError(f"Test batch shape mismatch: {data.shape}")

            output = model(data)  # ✅ 确保输入一致

            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n")


# ✅ 训练循环（扩展到 50 Epoch）
num_epochs = 50
for epoch in range(start_epoch, num_epochs + 1):  # ✅ 从 start_epoch 开始
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)


Using device: cuda
🔄 加载最新的 Checkpoint: training_dir/vit_epoch_1.pth
✅ Checkpoint training_dir/vit_epoch_1.pth 加载成功，从 Epoch 2 继续训练！
Epoch 2 开始训练...
Train Epoch: 2 [0/50000] Loss: 1.744868
Train Epoch: 2 [160/50000] Loss: 1.569748
Train Epoch: 2 [320/50000] Loss: 1.814432
Train Epoch: 2 [480/50000] Loss: 1.787008
Train Epoch: 2 [640/50000] Loss: 2.029443
Train Epoch: 2 [800/50000] Loss: 2.434287
Train Epoch: 2 [960/50000] Loss: 1.763276
Train Epoch: 2 [1120/50000] Loss: 1.773651
Train Epoch: 2 [1280/50000] Loss: 2.174492
Train Epoch: 2 [1440/50000] Loss: 2.105923
Train Epoch: 2 [1600/50000] Loss: 2.190749
Train Epoch: 2 [1760/50000] Loss: 2.200714
Train Epoch: 2 [1920/50000] Loss: 1.745773
Train Epoch: 2 [2080/50000] Loss: 2.010264
Train Epoch: 2 [2240/50000] Loss: 1.897064
Train Epoch: 2 [2400/50000] Loss: 2.396065
Train Epoch: 2 [2560/50000] Loss: 2.297965
Train Epoch: 2 [2720/50000] Loss: 2.259289
Train Epoch: 2 [2880/50000] Loss: 1.870618
Train Epoch: 2 [3040/50000] Loss: 1.954506
Tr