In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
import time
from dataclasses import dataclass

In [4]:
pwd

'/Users/ahs/Library/Mobile Documents/com~apple~CloudDocs/深度学习与神经网络/PJ2/codes'

In [5]:
@dataclass
class TrainingConfig:
    batch_size: int = 128
    num_epochs: int = 50
    learning_rate: float = 0.001
    weight_decay: float = 1e-4
    data_path: str = 'data'
    checkpoint_dir: str = 'checkpoints'
    num_workers: int = 2

def set_seed(seed=42):
    """设置随机种子以确保可重复性"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

def get_device():
    """获取可用的设备"""
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        print('Using MPS (Metal Performance Shaders)')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        print('Using CUDA')
    else:
        device = torch.device('cpu')
        print('Using CPU')
    return device

def get_transforms():
    """获取数据预处理转换"""
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    return transform_train, transform_test

def load_data(config: TrainingConfig):
    """加载数据集"""
    transform_train, transform_test = get_transforms()
    
    trainset = torchvision.datasets.CIFAR10(
        root=config.data_path, 
        train=True, 
        download=False,
        transform=transform_train)
    
    trainloader = DataLoader(
        trainset, 
        batch_size=config.batch_size, 
        shuffle=True, 
        num_workers=config.num_workers)
    
    testset = torchvision.datasets.CIFAR10(
        root=config.data_path, 
        train=False, 
        download=False,
        transform=transform_test)
    testloader = DataLoader(
        testset, 
        batch_size=config.batch_size, 
        shuffle=False, 
        num_workers=config.num_workers)
    
    return trainloader, testloader

def generate_alexnet():
    """生成AlexNet模型"""
    return nn.Sequential(
        # 第一层卷积：32x32 -> 16x16
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        
        # 第二层卷积：16x16 -> 8x8
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        
        # 第三层卷积：8x8 -> 4x4
        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        
        # 第四层卷积：4x4 -> 4x4
        nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        
        # 第五层卷积：4x4 -> 2x2
        nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        
        nn.Flatten(),
        # 2x2x512 = 2048
        nn.Linear(2048, 1024), nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(1024, 512), nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(512, 10)
    )

def train_epoch(model, train_loader, criterion, optimizer, device):
    """训练一个epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc='Training')
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        pbar.set_postfix({'loss': running_loss/total, 'acc': 100.*correct/total})
    
    return running_loss/len(train_loader), 100.*correct/total

def validate(model, val_loader, criterion, device):
    """验证模型"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    return running_loss/len(val_loader), 100.*correct/total

def plot_training_curves(train_losses, train_accs, val_losses, val_accs):
    """绘制训练曲线"""
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

def save_checkpoint(model, optimizer, epoch, val_acc, config: TrainingConfig):
    """保存检查点"""
    if not os.path.exists(config.checkpoint_dir):
        os.makedirs(config.checkpoint_dir)
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
    }, f'{config.checkpoint_dir}/best_model.pth')

def train(config: TrainingConfig):
    """主训练函数"""
    # 设置设备
    device = get_device()
    
    # 设置随机种子
    set_seed()
    
    # 加载数据
    trainloader, testloader = load_data(config)
    
    # 创建模型
    model = generate_alexnet().to(device)
    
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5)
    
    # 训练记录
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    best_acc = 0
    
    # 训练循环
    for epoch in range(config.num_epochs):
        print(f'\nEpoch {epoch+1}/{config.num_epochs}')
        
        # 训练
        train_loss, train_acc = train_epoch(
            model, trainloader, criterion, optimizer, device)
        
        # 验证
        val_loss, val_acc = validate(model, testloader, criterion, device)
        
        # 更新学习率
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_acc)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != old_lr:
            print(f'Learning rate changed from {old_lr:.6f} to {new_lr:.6f}')
        
        # 记录结果
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        
        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(model, optimizer, epoch, val_acc, config)
        
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    
    # 绘制训练曲线
    plot_training_curves(train_losses, train_accs, val_losses, val_accs)

In [6]:
config = TrainingConfig()
train(config)

Using MPS (Metal Performance Shaders)

Epoch 1/50


Training: 100%|██████████| 391/391 [00:50<00:00,  7.76it/s, loss=0.0145, acc=27.2]


Train Loss: 1.8570, Train Acc: 27.16%
Val Loss: 1.5874, Val Acc: 40.12%

Epoch 2/50


Training: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s, loss=0.0112, acc=46.6]


Train Loss: 1.4383, Train Acc: 46.65%
Val Loss: 1.2480, Val Acc: 53.91%

Epoch 3/50


Training: 100%|██████████| 391/391 [00:40<00:00,  9.56it/s, loss=0.00949, acc=56.5]


Train Loss: 1.2130, Train Acc: 56.49%
Val Loss: 1.0800, Val Acc: 61.25%

Epoch 4/50


Training: 100%|██████████| 391/391 [00:40<00:00,  9.72it/s, loss=0.00837, acc=62]  


Train Loss: 1.0703, Train Acc: 62.05%
Val Loss: 0.9537, Val Acc: 66.12%

Epoch 5/50


Training: 100%|██████████| 391/391 [00:41<00:00,  9.45it/s, loss=0.00756, acc=66.2]


Train Loss: 0.9666, Train Acc: 66.19%
Val Loss: 0.8877, Val Acc: 68.68%

Epoch 6/50


Training: 100%|██████████| 391/391 [00:45<00:00,  8.63it/s, loss=0.00699, acc=69.4]


Train Loss: 0.8936, Train Acc: 69.36%
Val Loss: 0.8383, Val Acc: 71.24%

Epoch 7/50


Training: 100%|██████████| 391/391 [00:45<00:00,  8.64it/s, loss=0.00649, acc=71.8]


Train Loss: 0.8297, Train Acc: 71.75%
Val Loss: 0.7735, Val Acc: 72.41%

Epoch 8/50


Training: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s, loss=0.0061, acc=73.3] 


Train Loss: 0.7806, Train Acc: 73.27%
Val Loss: 0.7648, Val Acc: 73.56%

Epoch 9/50


Training: 100%|██████████| 391/391 [01:09<00:00,  5.60it/s, loss=0.00579, acc=74.8]


Train Loss: 0.7402, Train Acc: 74.83%
Val Loss: 0.7316, Val Acc: 74.81%

Epoch 10/50


Training: 100%|██████████| 391/391 [01:04<00:00,  6.08it/s, loss=0.00557, acc=75.8]


Train Loss: 0.7127, Train Acc: 75.78%
Val Loss: 0.6780, Val Acc: 76.68%

Epoch 11/50


Training: 100%|██████████| 391/391 [00:49<00:00,  7.90it/s, loss=0.00529, acc=77.1]


Train Loss: 0.6763, Train Acc: 77.09%
Val Loss: 0.6815, Val Acc: 76.76%

Epoch 12/50


Training: 100%|██████████| 391/391 [00:46<00:00,  8.49it/s, loss=0.00512, acc=77.8]


Train Loss: 0.6553, Train Acc: 77.75%
Val Loss: 0.6390, Val Acc: 77.98%

Epoch 13/50


Training: 100%|██████████| 391/391 [00:57<00:00,  6.77it/s, loss=0.005, acc=78.5]  


Train Loss: 0.6393, Train Acc: 78.46%
Val Loss: 0.6286, Val Acc: 78.61%

Epoch 14/50


Training: 100%|██████████| 391/391 [00:59<00:00,  6.62it/s, loss=0.00487, acc=78.8]


Train Loss: 0.6224, Train Acc: 78.82%
Val Loss: 0.6193, Val Acc: 78.79%

Epoch 15/50


Training: 100%|██████████| 391/391 [01:32<00:00,  4.22it/s, loss=0.00476, acc=79.3]


Train Loss: 0.6091, Train Acc: 79.29%
Val Loss: 0.6421, Val Acc: 78.04%

Epoch 16/50


Training: 100%|██████████| 391/391 [01:20<00:00,  4.89it/s, loss=0.00456, acc=80.3]


Train Loss: 0.5832, Train Acc: 80.33%
Val Loss: 0.5935, Val Acc: 79.31%

Epoch 17/50


Training: 100%|██████████| 391/391 [01:06<00:00,  5.89it/s, loss=0.00447, acc=80.6]


Train Loss: 0.5712, Train Acc: 80.64%
Val Loss: 0.6499, Val Acc: 77.74%

Epoch 18/50


Training: 100%|██████████| 391/391 [00:48<00:00,  8.00it/s, loss=0.00438, acc=80.9]


Train Loss: 0.5603, Train Acc: 80.87%
Val Loss: 0.5723, Val Acc: 80.13%

Epoch 19/50


Training: 100%|██████████| 391/391 [01:07<00:00,  5.78it/s, loss=0.00437, acc=81.2]


Train Loss: 0.5586, Train Acc: 81.16%
Val Loss: 0.5811, Val Acc: 80.39%

Epoch 20/50


Training: 100%|██████████| 391/391 [01:21<00:00,  4.79it/s, loss=0.00417, acc=81.9]


Train Loss: 0.5335, Train Acc: 81.88%
Val Loss: 0.5932, Val Acc: 80.34%

Epoch 21/50


Training: 100%|██████████| 391/391 [00:52<00:00,  7.49it/s, loss=0.00415, acc=82.1]


Train Loss: 0.5312, Train Acc: 82.06%
Val Loss: 0.5521, Val Acc: 81.16%

Epoch 22/50


Training: 100%|██████████| 391/391 [00:41<00:00,  9.44it/s, loss=0.00403, acc=82.4]


Train Loss: 0.5149, Train Acc: 82.44%
Val Loss: 0.5670, Val Acc: 80.69%

Epoch 23/50


Training: 100%|██████████| 391/391 [00:43<00:00,  9.00it/s, loss=0.00399, acc=82.7]


Train Loss: 0.5104, Train Acc: 82.69%
Val Loss: 0.5573, Val Acc: 80.96%

Epoch 24/50


Training: 100%|██████████| 391/391 [00:53<00:00,  7.25it/s, loss=0.00396, acc=83.1]


Train Loss: 0.5061, Train Acc: 83.10%
Val Loss: 0.5407, Val Acc: 81.84%

Epoch 25/50


Training: 100%|██████████| 391/391 [00:55<00:00,  7.01it/s, loss=0.00392, acc=83.1]


Train Loss: 0.5007, Train Acc: 83.11%
Val Loss: 0.5591, Val Acc: 81.02%

Epoch 26/50


Training: 100%|██████████| 391/391 [00:40<00:00,  9.58it/s, loss=0.00386, acc=83.4]


Train Loss: 0.4932, Train Acc: 83.38%
Val Loss: 0.5429, Val Acc: 81.37%

Epoch 27/50


Training: 100%|██████████| 391/391 [00:45<00:00,  8.59it/s, loss=0.00371, acc=83.8]


Train Loss: 0.4741, Train Acc: 83.77%
Val Loss: 0.5650, Val Acc: 81.02%

Epoch 28/50


Training: 100%|██████████| 391/391 [00:51<00:00,  7.56it/s, loss=0.0037, acc=83.9] 


Train Loss: 0.4731, Train Acc: 83.95%
Val Loss: 0.5483, Val Acc: 82.04%

Epoch 29/50


Training: 100%|██████████| 391/391 [00:43<00:00,  8.94it/s, loss=0.00365, acc=84.3]


Train Loss: 0.4671, Train Acc: 84.30%
Val Loss: 0.5364, Val Acc: 82.05%

Epoch 30/50


Training: 100%|██████████| 391/391 [00:57<00:00,  6.86it/s, loss=0.0036, acc=84.6] 


Train Loss: 0.4599, Train Acc: 84.57%
Val Loss: 0.5500, Val Acc: 81.51%

Epoch 31/50


Training: 100%|██████████| 391/391 [00:54<00:00,  7.11it/s, loss=0.00359, acc=84.4]


Train Loss: 0.4586, Train Acc: 84.36%
Val Loss: 0.5286, Val Acc: 82.46%

Epoch 32/50


Training: 100%|██████████| 391/391 [00:53<00:00,  7.31it/s, loss=0.00347, acc=84.9]


Train Loss: 0.4435, Train Acc: 84.90%
Val Loss: 0.5501, Val Acc: 81.97%

Epoch 33/50


Training: 100%|██████████| 391/391 [00:48<00:00,  8.14it/s, loss=0.00351, acc=84.8]


Train Loss: 0.4487, Train Acc: 84.84%
Val Loss: 0.5428, Val Acc: 82.11%

Epoch 34/50


Training: 100%|██████████| 391/391 [00:47<00:00,  8.16it/s, loss=0.00346, acc=84.9]


Train Loss: 0.4428, Train Acc: 84.94%
Val Loss: 0.5422, Val Acc: 81.84%

Epoch 35/50


Training: 100%|██████████| 391/391 [00:55<00:00,  7.07it/s, loss=0.00343, acc=85.2]


Train Loss: 0.4381, Train Acc: 85.23%
Val Loss: 0.5511, Val Acc: 81.93%

Epoch 36/50


Training: 100%|██████████| 391/391 [00:47<00:00,  8.21it/s, loss=0.00337, acc=85.4]


Train Loss: 0.4305, Train Acc: 85.36%
Val Loss: 0.5280, Val Acc: 82.03%

Epoch 37/50


Training: 100%|██████████| 391/391 [00:50<00:00,  7.68it/s, loss=0.00337, acc=85.5]


Train Loss: 0.4315, Train Acc: 85.46%
Val Loss: 0.5293, Val Acc: 82.48%

Epoch 38/50


Training: 100%|██████████| 391/391 [00:58<00:00,  6.65it/s, loss=0.00328, acc=85.8]


Train Loss: 0.4199, Train Acc: 85.76%
Val Loss: 0.5091, Val Acc: 83.35%

Epoch 39/50


Training: 100%|██████████| 391/391 [00:50<00:00,  7.71it/s, loss=0.0033, acc=85.6] 


Train Loss: 0.4223, Train Acc: 85.61%
Val Loss: 0.5436, Val Acc: 82.57%

Epoch 40/50


Training: 100%|██████████| 391/391 [00:47<00:00,  8.31it/s, loss=0.00329, acc=85.8]


Train Loss: 0.4207, Train Acc: 85.78%
Val Loss: 0.5316, Val Acc: 82.73%

Epoch 41/50


Training: 100%|██████████| 391/391 [01:08<00:00,  5.69it/s, loss=0.00325, acc=86]  


Train Loss: 0.4158, Train Acc: 86.01%
Val Loss: 0.5407, Val Acc: 82.65%

Epoch 42/50


Training: 100%|██████████| 391/391 [01:02<00:00,  6.23it/s, loss=0.00317, acc=86.4]


Train Loss: 0.4055, Train Acc: 86.45%
Val Loss: 0.5330, Val Acc: 82.89%

Epoch 43/50


Training: 100%|██████████| 391/391 [00:54<00:00,  7.14it/s, loss=0.00322, acc=86.2]


Train Loss: 0.4111, Train Acc: 86.22%
Val Loss: 0.5588, Val Acc: 81.70%

Epoch 44/50


Training: 100%|██████████| 391/391 [00:55<00:00,  7.07it/s, loss=0.00317, acc=86.6]


Learning rate changed from 0.001000 to 0.000500
Train Loss: 0.4051, Train Acc: 86.61%
Val Loss: 0.5331, Val Acc: 82.71%

Epoch 45/50


Training: 100%|██████████| 391/391 [01:12<00:00,  5.41it/s, loss=0.00266, acc=88.5]


Train Loss: 0.3400, Train Acc: 88.49%
Val Loss: 0.4922, Val Acc: 84.24%

Epoch 46/50


Training: 100%|██████████| 391/391 [01:05<00:00,  5.96it/s, loss=0.00253, acc=89.1]


Train Loss: 0.3234, Train Acc: 89.13%
Val Loss: 0.4893, Val Acc: 84.55%

Epoch 47/50


Training: 100%|██████████| 391/391 [01:03<00:00,  6.18it/s, loss=0.00251, acc=89]  


Train Loss: 0.3215, Train Acc: 88.96%
Val Loss: 0.4987, Val Acc: 84.45%

Epoch 48/50


Training: 100%|██████████| 391/391 [01:19<00:00,  4.90it/s, loss=0.00243, acc=89.5]


Train Loss: 0.3110, Train Acc: 89.52%
Val Loss: 0.5009, Val Acc: 84.36%

Epoch 49/50


Training: 100%|██████████| 391/391 [00:54<00:00,  7.15it/s, loss=0.00241, acc=89.5]


Train Loss: 0.3083, Train Acc: 89.49%
Val Loss: 0.4864, Val Acc: 84.72%

Epoch 50/50


Training: 100%|██████████| 391/391 [00:59<00:00,  6.54it/s, loss=0.00232, acc=89.9]


Train Loss: 0.2961, Train Acc: 89.91%
Val Loss: 0.4792, Val Acc: 85.10%
