In [9]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import einsum
import math
from einops import rearrange, repeat
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # ImageNet标准化
])
transform_train = transforms.Compose([
    transforms.Resize((40, 40)),  # 先放大
    transforms.RandomCrop(32),     # 再随机裁剪
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # ImageNet标准化
])
batch_size=4
trainset=torchvision.datasets.CIFAR10(root=r'./',train=True,download=True,transform=transform_train)
testset=torchvision.datasets.CIFAR10(root=r'./',train=False,download=True,transform=transform)
trainloader=torch.utils.data.DataLoader(dataset=trainset,batch_size=batch_size,shuffle=True,num_workers=2)
testloader=torch.utils.data.DataLoader(dataset=testset,batch_size=batch_size,shuffle=False,num_workers=2)
classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

In [10]:
import torchvision.models as models
from torch.nn import MultiheadAttention
class PatchEmbedding(nn.Module):
    """将图像分割成patch并嵌入"""
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size)
        
        # 可学习的位置编码
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
    def forward(self, x):
        B, C, H, W = x.shape
        # 投影到patch嵌入
        x = self.proj(x)  # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # 添加cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # 添加位置编码
        x = x + self.pos_embed
        
        return x

In [11]:
class MultiHeadAttention(nn.Module):
    """多头自注意力机制"""
    def __init__(self, embed_dim=256, num_heads=8, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "embed_dim必须能被num_heads整除"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, C = x.shape
        
        # 生成Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 计算注意力权重
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        # 应用注意力权重
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

In [12]:
class MLP(nn.Module):
    """多层感知机"""
    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.1):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [13]:
class TransformerBlock(nn.Module):
    """Transformer模块"""
    def __init__(self, embed_dim=256, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = MLP(embed_dim, mlp_hidden_dim, dropout=dropout)
        
    def forward(self, x):
        # 残差连接 + 层归一化 + 注意力
        x = x + self.attn(self.norm1(x))
        # 残差连接 + 层归一化 + MLP
        x = x + self.mlp(self.norm2(x))
        return x

In [14]:
class VisionTransformer(nn.Module):
    """完整的Vision Transformer模型"""
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = self.patch_embed.cls_token
        
        # Transformer层
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # 初始化权重
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def forward(self, x):
        # Patch嵌入
        x = self.patch_embed(x)
        
        # 通过Transformer层
        for block in self.blocks:
            x = block(x)
            
        # 层归一化
        x = self.norm(x)
        
        # 使用cls token进行分类
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        
        return x

In [19]:
# 创建模型实例
def create_vit_cifar10(model_name='small'):
    """创建不同规模的ViT模型"""
    configs = {
        'tiny': {'embed_dim': 192, 'depth': 6, 'num_heads': 6},
        'small': {'embed_dim': 256, 'depth': 8, 'num_heads': 8},
        'base': {'embed_dim': 384, 'depth': 12, 'num_heads': 12}
    }
    
    config = configs.get(model_name, configs['small'])
    
    model = VisionTransformer(
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=config['embed_dim'],
        depth=config['depth'],
        num_heads=config['num_heads'],
        mlp_ratio=4.0,
        dropout=0.1
    ).to(device)
    
    return model


In [21]:
import matplotlib.pyplot as plt
model = create_vit_cifar10('tiny')
LR=0.0001
criterion=nn.CrossEntropyLoss().cuda()
optimizer=optim.Adam(model.parameters(),LR,weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=10,     # 第一次重启的周期
    T_mult=2,   # 每次重启后周期倍增
    eta_min=1e-6  # 最小学习率
)
max_grad_norm = 1.0  # 梯度裁剪阈值
# 添加图像描述变量
image_captions = []

# 创建存储训练指标的列表
train_losses = []
val_losses = []  # 新增验证损失列表
val_accuracies = []
learning_rates = []
epochs=30
for epoch in range(epochs):
    running_loss=0.0
    total=0
    correct=0
    model.train()
    # 添加当前epoch的图像描述
    epoch_caption = f"Epoch {epoch+1}/{epochs}: Training in progress..."
    image_captions.append(epoch_caption)
    for i,data in enumerate(trainloader,0):
        inputs,labels=data
        inputs=inputs.cuda()
        labels=labels.cuda()
        optimizer.zero_grad()
        outputs=model(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        running_loss+=loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if i % 2000 == 1999:  # 每2000个batch打印一次
            train_acc = 100 * correct / total
            batch_caption = (f'Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{len(trainloader)}], '
                            f'Loss: {running_loss/2000:.4f}, Acc: {train_acc:.2f}%')
            print(batch_caption)
            image_captions.append(batch_caption)
            running_loss = 0.0
            total = 0
            correct = 0
            
        # 更新学习率
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)

     # 验证过程
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_acc = 100 * val_correct / val_total
    avg_val_loss = val_loss / len(testloader)
    val_losses.append(avg_val_loss)  # 记录验证损失
    val_accuracies.append(val_acc)
    
    epoch_summary = (f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, '
                    f'Accuracy: {val_acc:.2f}%, LR: {current_lr:.6f}')
    print(epoch_summary)
    image_captions.append(epoch_summary)
    train_losses.append(val_loss)


print('Finished Training')

Epoch [1/30], Batch [2000/12500], Loss: 2.1268, Acc: 19.06%
Epoch [1/30], Batch [4000/12500], Loss: 2.0096, Acc: 24.31%
Epoch [1/30], Batch [6000/12500], Loss: 1.9687, Acc: 25.77%
Epoch [1/30], Batch [8000/12500], Loss: 1.9362, Acc: 28.30%
Epoch [1/30], Batch [10000/12500], Loss: 1.8778, Acc: 30.85%
Epoch [1/30], Batch [12000/12500], Loss: 1.8916, Acc: 29.02%
Epoch 1, Validation Loss: 1.7575, Accuracy: 36.21%, LR: 0.000088
Epoch [2/30], Batch [2000/12500], Loss: 1.8340, Acc: 32.06%
Epoch [2/30], Batch [4000/12500], Loss: 1.7986, Acc: 34.30%
Epoch [2/30], Batch [6000/12500], Loss: 1.7576, Acc: 34.95%
Epoch [2/30], Batch [8000/12500], Loss: 1.7419, Acc: 36.40%
Epoch [2/30], Batch [10000/12500], Loss: 1.7966, Acc: 34.24%
Epoch [2/30], Batch [12000/12500], Loss: 1.7821, Acc: 34.74%
Epoch 2, Validation Loss: 1.6334, Accuracy: 41.22%, LR: 0.000089
Epoch [3/30], Batch [2000/12500], Loss: 1.7372, Acc: 35.90%
Epoch [3/30], Batch [4000/12500], Loss: 1.6890, Acc: 38.41%
Epoch [3/30], Batch [6000/

KeyboardInterrupt: 

In [None]:
# 保存最终模型
torch.save(model.state_dict(), 'ViT2.pth')
final_caption = "Training completed. Model saved as final_model.pth"
print(final_caption)
image_captions.append(final_caption)

# 绘制测试集准确率变化折线图
plt.figure(figsize=(10, 6))
plt.plot(range(1, epochs+1), val_accuracies, 'b-o', linewidth=2)
plt.title('Test Accuracy Over Epochs', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.xticks(range(1, epochs+1))
plt.ylim(0, 100)  # 确保y轴从0到100%

# 标记最高准确率
max_acc = max(val_accuracies)
max_epoch = val_accuracies.index(max_acc) + 1
plt.annotate(f'Max: {max_acc:.2f}%', 
             xy=(max_epoch, max_acc),
             xytext=(max_epoch+1, max_acc-5),
             arrowprops=dict(facecolor='red', shrink=0.05),
             fontsize=12)

# 保存图表
plt.savefig('ViT.png', dpi=300, bbox_inches='tight')
plt.close()

# 添加图表描述
plot_caption = ("Accuracy Plot: Shows the model's performance improvement on the test set over training epochs. "
               f"Highest accuracy of {max_acc:.2f}% achieved at epoch {max_epoch}.")
image_captions.append(plot_caption)

# 保存所有图像描述到文件
with open('training_report.txt', 'w') as f:
    for caption in image_captions:
        f.write(caption + '\n')

print("Training report and accuracy plot saved successfully.")