In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from pathlib import Path
import os

# ===== Path setup =====
# The notebook file name is resolved against the current working directory;
# the project root is the directory three `.parent` steps above that path.
current_path = Path("vitnew.ipynb").resolve()
root_dir = current_path.parent.parent.parent
data_dir = root_dir.joinpath("vittraining")       # dataset directory
checkpoint_dir = root_dir.joinpath("VITtest2")    # checkpoint directory
# Make sure the checkpoint directory exists before anything is saved or globbed.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# ===== 模型定义 =====

class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a CLS token prepended.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of each square patch (also the conv stride).
        emb_size: dimensionality of every token embedding.
        img_size: side length used to size the positional embedding table;
            ``(img_size // patch_size) ** 2`` patch positions are allocated.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        patches_per_side = img_size // patch_size
        token_count = patches_per_side ** 2 + 1  # all patches plus one CLS token

        # A strided convolution embeds every non-overlapping patch; the
        # resulting spatial grid is then flattened into a token axis.
        patch_conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        grid_to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, grid_to_tokens)

        # Learnable CLS token, broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # Learnable positional embedding: one vector per token.
        self.positions = nn.Parameter(torch.randn(1, token_count, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``[B, num_patches + 1, emb_size]`` tokens for image batch ``x``."""
        patches = self.projection(x)                       # [B, num_patches, emb_size]
        batch_size = patches.shape[0]
        cls = self.cls_token.expand(batch_size, -1, -1)    # [B, 1, emb_size]
        tokens = torch.cat((cls, patches), dim=1)          # [B, num_patches + 1, emb_size]
        return tokens + self.positions

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        emb_size: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0.):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        self.num_heads = num_heads

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x`` of shape ``[B, N, emb_size]``.

        Args:
            x: input tokens, ``[B, N, emb_size]``.
            mask: optional boolean tensor broadcastable to ``[B, H, N, N]``;
                ``True`` marks key positions a query may attend to, ``False``
                positions are suppressed.

        Returns:
            Tensor of shape ``[B, N, emb_size]``.
        """
        batch_size, num_tokens, emb_size = x.shape
        head_dim = emb_size // self.num_heads

        # The fused projection is laid out as (three, heads, head_dim) along
        # the last axis; split it and move to [3, B, H, N, D].
        qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv[0], qkv[1], qkv[2]

        # Scale the logits by sqrt(head_dim) *before* softmax so that each
        # attention row is still a probability distribution afterwards.
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / head_dim ** 0.5
        if mask is not None:
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        att = self.att_drop(F.softmax(energy, dim=-1))

        # Contract over the key axis only, yielding [B, H, N, D]. The previous
        # subscripts summed out both k and the value dim and returned
        # [B, H, N, N], which broke the output projection.
        out = torch.einsum('bhqk, bhkd -> bhqd', att, values)
        out = out.transpose(1, 2).reshape(batch_size, num_tokens, emb_size)  # merge heads
        return self.projection(out)

class ResidualAdd(nn.Module):
    """Wrap a callable ``fn`` with a skip connection: ``x + fn(x, **kwargs)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return the input added to the wrapped branch's output.

        Any keyword arguments are forwarded unchanged to ``fn``.
        """
        shortcut = x
        branch = self.fn(x, **kwargs)
        return shortcut + branch

class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: input and output feature size.
        expansion: the hidden layer is ``expansion * emb_size`` wide.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        stages = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*stages)

class TransformerEncoderBlock(nn.Sequential):
    """One encoder layer: a residual attention branch then a residual MLP branch.

    Each branch applies LayerNorm first, then its sub-layer, then dropout.

    Args:
        emb_size: token embedding size.
        drop_p: dropout probability applied at the end of each branch.
        forward_expansion: hidden-size multiplier of the feed-forward block.
        forward_drop_p: dropout probability inside the feed-forward block.
        num_heads: number of attention heads.
    """

    def __init__(self, emb_size: int = 768, drop_p: float = 0.,
                 forward_expansion: int = 4, forward_drop_p: float = 0.,
                 num_heads: int = 8):
        # Only emb_size and num_heads are forwarded to the attention module,
        # so its internal attention dropout keeps its default value.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads=num_heads),
            nn.Dropout(drop_p),
        )
        mlp_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(mlp_branch))

class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` identically configured ``TransformerEncoderBlock`` layers.

    Args:
        depth: number of encoder blocks.
        emb_size: token embedding size.
        drop_p: dropout probability at the end of each residual branch.
        forward_expansion: hidden-size multiplier of the feed-forward block.
        forward_drop_p: dropout probability inside the feed-forward block.
        num_heads: number of attention heads per block.
    """

    def __init__(self, depth: int = 12, emb_size: int = 768, drop_p: float = 0.,
                 forward_expansion: int = 4, forward_drop_p: float = 0.,
                 num_heads: int = 8):
        blocks = []
        for _ in range(depth):
            blocks.append(
                TransformerEncoderBlock(emb_size, drop_p, forward_expansion, forward_drop_p, num_heads)
            )
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Mean-pool the token axis, normalise, and project to class logits.

    Args:
        emb_size: token embedding size.
        n_classes: number of output classes.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 10):
        # Average over every token the head receives: [B, N, E] -> [B, E].
        token_pool = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(token_pool, norm, classifier)

class ViT(nn.Sequential):
    """Vision Transformer: patch embedding -> encoder stack -> classification head.

    Args:
        in_channels: number of image channels.
        patch_size: side length of each square patch.
        emb_size: token embedding size used throughout the model.
        img_size: side length of the input images.
        depth: number of encoder blocks.
        n_classes: number of output classes.
        drop_p: dropout probability at the end of each residual branch.
        forward_expansion: hidden-size multiplier of the feed-forward blocks.
        forward_drop_p: dropout probability inside the feed-forward blocks.
        num_heads: number of attention heads per block.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768,
                 img_size: int = 224, depth: int = 12, n_classes: int = 10,
                 drop_p: float = 0., forward_expansion: int = 4, forward_drop_p: float = 0.,
                 num_heads: int = 8):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size, drop_p, forward_expansion, forward_drop_p, num_heads)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

# ===== Data preprocessing and loading =====

def _build_transform(augment):
    """Compose the image pipeline; ``augment`` adds a random horizontal flip."""
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])
    return transforms.Compose(steps)


# Training images get light augmentation; test images are only resized and normalised.
train_transform = _build_transform(augment=True)
test_transform = _build_transform(augment=False)

train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform)

# Both loaders share batch size and worker count; only the training set is shuffled.
_loader_options = {"batch_size": 32, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# ===== Training setup =====

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ViT().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Cosine learning-rate decay with a period of 50 scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=50)

# ===== Checkpoint loading =====

def _checkpoint_epoch(path):
    """Return the epoch number encoded at the end of a checkpoint file name.

    Returns -1 when the part after the last underscore is not an integer, so
    unrelated files matching the glob can be filtered out.
    """
    try:
        return int(path.stem.rsplit("_", 1)[-1])
    except ValueError:
        return -1


start_epoch = 1
# Order checkpoints by their numeric epoch. A plain sorted() on the paths is
# lexicographic and ranks "checkpoint_epoch_9" after "checkpoint_epoch_10",
# which would resume training from a stale file once epoch 10 is reached.
checkpoint_files = sorted(
    (p for p in checkpoint_dir.glob("checkpoint_epoch_*.pth") if _checkpoint_epoch(p) >= 0),
    key=_checkpoint_epoch,
)
if checkpoint_files:
    latest_checkpoint = checkpoint_files[-1]
    print(f"加载 checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    # Restore model, optimizer and scheduler so training resumes where it stopped.
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"从 epoch {start_epoch} 开始继续训练。")

# ===== 训练与测试函数 =====

def train_epoch(model, device, loader, optimizer, criterion, epoch, lr_scheduler=None):
    """Train ``model`` for one epoch and step the learning-rate scheduler once.

    Args:
        model: the network to train.
        device: device that each batch is moved to.
        loader: iterable of ``(data, target)`` batches exposing ``.dataset``
            (used for progress output) and ``len()``.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss function called as ``criterion(output, target)``.
        epoch: epoch number, used only in the printed messages.
        lr_scheduler: scheduler stepped once after the epoch. When omitted,
            the module-level ``scheduler`` is used, which preserves the
            behaviour of existing callers.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    if lr_scheduler is None:
        # Backward-compatible fallback; resolved up front so that a missing
        # scheduler fails before an entire epoch of work is done.
        lr_scheduler = scheduler

    model.train()
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Report progress every 10 batches.
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(loader.dataset)}] Loss: {loss.item():.6f}")
    lr_scheduler.step()
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch} 平均 loss: {avg_loss:.6f}")
    return avg_loss

def test_epoch(model, device, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
    avg_loss = total_loss / len(loader.dataset)
    accuracy = 100.0 * correct / len(loader.dataset)
    print(f"Test set: 平均 loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

# ===== Main training loop =====

num_epochs = 50
for epoch in range(start_epoch, num_epochs + 1):
    print(f"=== Epoch {epoch} ===")
    train_epoch(model, device, train_loader, optimizer, criterion, epoch)
    test_epoch(model, device, test_loader, criterion)

    # Save a checkpoint after every epoch with everything needed to resume.
    checkpoint_path = checkpoint_dir.joinpath(f"checkpoint_epoch_{epoch}.pth")
    snapshot = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
    torch.save(snapshot, checkpoint_path)
    print(f"Checkpoint 保存于 {checkpoint_path}")


=== Epoch 1 ===


  scaler = torch.cuda.amp.GradScaler()  # 混合精度梯度缩放器
  with torch.cuda.amp.autocast():


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3152x1576 and 768x768)