#### 作业五

比较LeNet在使用和不使用批量规范化（batch normalization）的情况下，采用不同学习率和不同批量大小（batch size）时，在Fashion-MNIST数据集上的训练收敛情况。

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Device configuration: run on the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Sweep settings: every run picks one value from each of the lists below.
num_epochs = 10
batch_sizes = [64, 128, 256]         # mini-batch sizes to compare
learning_rates = [0.001, 0.01, 0.1]  # learning rates to compare
use_batch_norm = [True, False]       # with / without batch normalization

# Fashion-MNIST dataset. ToTensor scales pixels to [0, 1]; Normalize with
# mean 0.5 and std 0.5 then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

train_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform
)

# One shuffled training loader per batch size in the sweep; the test loader
# keeps a fixed order and always uses batches of 100.
train_loaders = {
    bs: DataLoader(train_dataset, batch_size=bs, shuffle=True)
    for bs in batch_sizes
}
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# LeNet model definition
class LeNet(nn.Module):
    """LeNet-style CNN producing 10 class logits from single-channel images.

    The fully connected head expects 16 x 4 x 4 features, which is what two
    conv(5x5) + max-pool(2) stages produce from a 28x28 input (e.g. Fashion-MNIST).

    Args:
        use_batch_norm: if True, a BatchNorm2d layer is applied in each of the
            two convolution stages (after the ReLU, before pooling).
    """

    def __init__(self, use_batch_norm=False):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.use_batch_norm = use_batch_norm
        if self.use_batch_norm:
            # NOTE(review): BN is applied after the ReLU here, whereas the original
            # batch-norm paper normalizes before the nonlinearity. Kept as-is so
            # the experiment's architecture is unchanged.
            self.bn1 = nn.BatchNorm2d(6)
            self.bn2 = nn.BatchNorm2d(16)

    def forward(self, x):
        """Return (batch, 10) logits for a (batch, 1, H, W) input batch."""
        x = torch.relu(self.conv1(x))
        if self.use_batch_norm:
            x = self.bn1(x)
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        if self.use_batch_norm:
            x = self.bn2(x)
        x = torch.max_pool2d(x, 2)
        # flatten(x, 1) keeps the batch dimension intact, so a wrongly sized input
        # fails in fc1. The previous view(-1, 256) would silently regroup samples
        # into a different batch size instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Training function
def train_model(model, train_loader, optimizer, criterion):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Puts ``model`` in train mode, then for every (images, labels) batch does a
    forward pass, backpropagates ``criterion`` and takes one optimizer step.
    Batches are moved to the device holding the model's parameters. The model
    must have parameters anyway, or the optimizer could not be built.

    Returns:
        Per-sample average training loss over the epoch, useful for tracking
        convergence. Returns ``nan`` if the loader yields no samples. Assumes
        ``criterion`` returns the batch *mean*, as nn.CrossEntropyLoss does by
        default.
    """
    model.train()
    device = next(model.parameters()).device
    loss_sum = 0.0
    n_seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_n = labels.size(0)
        # Weight each batch mean by its size so a short final batch is not
        # over-counted. detach() keeps the running sum off the autograd graph,
        # and the single float() at the end avoids a device sync every step.
        loss_sum = loss_sum + loss.detach() * batch_n
        n_seen += batch_n
    if n_seen == 0:
        return float('nan')
    return float(loss_sum) / n_seen

# Evaluation function
def evaluate_model(model, test_loader):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Switches the model to eval mode (so BatchNorm uses its running statistics)
    and leaves it there. Runs the forward passes without gradient tracking and
    counts a sample as correct when its arg-max logit equals its label. Batches
    are moved to the device holding the model's parameters.

    Returns:
        Fraction of correctly classified samples, in [0, 1], or ``nan`` when the
        loader yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return float('nan')
    return correct / total

# Training sweep over every (batch norm on/off, learning rate, batch size) combination.
SEED = 0  # every run restarts from this seed so configurations are comparable
results = {}  # (use_bn, lr, batch_size) -> list of per-epoch test accuracies
criterion = nn.CrossEntropyLoss()  # stateless, so one instance serves every run
for use_bn in use_batch_norm:
    for lr in learning_rates:
        for batch_size in batch_sizes:
            # Reseed the global RNG so weight initialization and the DataLoader's
            # shuffle order start from the same state in every configuration.
            # Otherwise accuracy differences would also reflect seed noise.
            torch.manual_seed(SEED)
            model = LeNet(use_batch_norm=use_bn).to(device)
            optimizer = optim.Adam(model.parameters(), lr=lr)
            train_loader = train_loaders[batch_size]  # same loader for all epochs

            print(f"训练LeNet{'使用' if use_bn else '不使用'}批量规范化，学习率: {lr}, 批量大小: {batch_size}")

            history = []
            for epoch in range(num_epochs):
                train_model(model, train_loader, optimizer, criterion)
                accuracy = evaluate_model(model, test_loader)
                history.append(accuracy)
                print(f'Epoch [{epoch+1}/{num_epochs}], 准确率: {accuracy:.4f}')
            results[(use_bn, lr, batch_size)] = history

            print()


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%


Extracting ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

训练LeNet使用批量规范化，学习率: 0.001, 批量大小: 64
Epoch [1/10], 准确率: 0.8748
Epoch [2/10], 准确率: 0.8835
Epoch [3/10], 准确率: 0.8840
Epoch [4/10], 准确率: 0.8924
Epoch [5/10], 准确率: 0.8959
Epoch [6/10], 准确率: 0.8916
Epoch [7/10], 准确率: 0.8989
Epoch [8/10], 准确率: 0.8939
Epoch [9/10], 准确率: 0.8940
Epoch [10/10], 准确率: 0.8981

训练LeNet使用批量规范化，学习率: 0.001, 批量大小: 128
Epoch [1/10], 准确率: 0.8562
Epoch [2/10], 准确率: 0.8715
Epoch [3/10], 准确率: 0.8778
Epoch [4/10], 准确率: 0.8893
Epoch [5/10], 准确率: 0.8926
Epoch [6/10], 准确率: 0.8928
Epoch [7/10], 准确率: 0.8962
Epoch [8/10], 准确率: 0.8959
Epoch [9/10], 准确率: 0.8991
Epoch [10/10], 准确率: 0.8983

训练LeNet使用批量规范化，学习率: 0.001, 批量大小: 256
Epoch [1/10], 准确率: 0.8409
Epoch [2/10], 准确率: 0.8718
Epoch [3/10], 准确率: 0.8778
Epoch [4/10], 准确率: 0.8832
Epoch [5/10], 准确率: 0.8839
Epoch [6/10], 准确率: 0.8857
Epoch [7/10], 准确率: 0.8897
Epoch [8/10], 准确率: 0.8885
Epoch [9/10], 准确率: 0.8943
Epoch [10/10], 准确率: 0.8923

训练LeNet使用批量规范化，