In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import os

In [24]:
class SimpleCNN(nn.Module):
    """Small two-convolution CNN classifier for single-channel images.

    Layer sizes assume 1x28x28 inputs (MNIST as loaded in this notebook):
    28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4,
    so the flattened feature vector is 32 * 4 * 4 = 512, matching ``fc1``.

    Attribute names (``conv2``, ``conv3``, ``fc1``) are kept unchanged so that
    state_dicts saved by earlier runs (``model_epoch_*.pth``) still load.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv2 = nn.Conv2d(1, 16, kernel_size=5)   # (N, 1, H, W) -> (N, 16, H-4, W-4)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=5)  # (N, 16, h, w) -> (N, 32, h-4, w-4)
        self.fc1 = nn.Linear(512, 10, bias=True)       # 512 = 32 channels * 4 * 4 spatial
        # Parameter-free pooling, registered once instead of being re-created on
        # every forward call; it adds no entries to the state_dict.
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return class logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        x = torch.relu(self.pool(self.conv2(x)))
        x = torch.relu(self.pool(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. ``x.view(-1, 512)`` would
        # silently re-batch inputs of the wrong spatial size (e.g. one 44x28 image
        # became two rows); this way such inputs fail loudly in fc1 instead.
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [3]:
DATA_ROOT = './data'
BATCH_SIZE = 64

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# download=True skips the download when the files already exist under DATA_ROOT,
# but lets the notebook run on a fresh machine instead of failing with "Dataset not found".
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training set; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

In [16]:
def train(model, train_loader, criterion, optimizer, num_epochs=10, weights_dir='./weights', eval_loader=None):
    """Train ``model`` for ``num_epochs`` epochs, evaluating and checkpointing after each.

    Each epoch makes one pass over ``train_loader`` (zero grads, forward, loss,
    backward, step). It then prints the mean training loss and accuracy, calls
    ``test`` on the evaluation loader, and saves ``model.state_dict()`` to
    ``<weights_dir>/model_epoch_<k>.pth``.

    Args:
        model: module returning logits of shape (N, num_classes).
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. The epoch loss
            below assumes it returns the batch *mean* (the default for
            ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        weights_dir: directory for per-epoch checkpoints; created if missing.
        eval_loader: loader passed to ``test`` after every epoch. Defaults to the
            notebook-global ``test_loader`` (the original, implicit behaviour);
            pass it explicitly to avoid depending on hidden notebook state.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if eval_loader is None:
        eval_loader = test_loader  # backward-compatible fallback to the notebook global
    # torch.save does not create directories, so make sure the target exists.
    os.makedirs(weights_dir, exist_ok=True)

    for epoch in range(num_epochs):
        # test() switches the model to eval mode, so switch back at every epoch start.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            optimizer.zero_grad()              # clear gradients from the previous step
            outputs = model(images)            # forward pass
            loss = criterion(outputs, labels)  # batch loss
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Weight each batch mean by its size so a smaller final batch does not
            # skew the epoch average.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError('train_loader yielded no samples')

        avg_loss = running_loss / total
        accuracy = 100 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
        test(model, eval_loader)  # evaluation-set output

        # Save the model parameters at the end of every epoch.
        model_path = os.path.join(weights_dir, f'model_epoch_{epoch + 1}.pth')
        torch.save(model.state_dict(), model_path)
        print(f"Model's state_dict saved to {model_path}")

In [25]:
# Seed the global torch RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All (single-process CPU run).
SEED = 0
torch.manual_seed(SEED)

# Create the checkpoint directory if needed (no-op when it already exists).
weights_dir = './weights'
os.makedirs(weights_dir, exist_ok=True)

# Initialise the model, loss function and optimizer.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model; train() also evaluates with test() and saves a checkpoint after every epoch.
train(model, train_loader, criterion, optimizer, num_epochs=10, weights_dir=weights_dir)

Epoch [1/10], Loss: 0.1996, Accuracy: 94.04%
Test Accuracy: 97.58%
Model's state_dict saved to ./weights\model_epoch_1.pth
Epoch [2/10], Loss: 0.0585, Accuracy: 98.20%
Test Accuracy: 98.54%
Model's state_dict saved to ./weights\model_epoch_2.pth
Epoch [3/10], Loss: 0.0414, Accuracy: 98.71%
Test Accuracy: 98.99%
Model's state_dict saved to ./weights\model_epoch_3.pth
Epoch [4/10], Loss: 0.0338, Accuracy: 98.96%
Test Accuracy: 99.00%
Model's state_dict saved to ./weights\model_epoch_4.pth
Epoch [5/10], Loss: 0.0277, Accuracy: 99.12%
Test Accuracy: 98.96%
Model's state_dict saved to ./weights\model_epoch_5.pth
Epoch [6/10], Loss: 0.0233, Accuracy: 99.28%
Test Accuracy: 98.92%
Model's state_dict saved to ./weights\model_epoch_6.pth
Epoch [7/10], Loss: 0.0194, Accuracy: 99.38%
Test Accuracy: 99.01%
Model's state_dict saved to ./weights\model_epoch_7.pth
Epoch [8/10], Loss: 0.0168, Accuracy: 99.46%
Test Accuracy: 99.18%
Model's state_dict saved to ./weights\model_epoch_8.pth
Epoch [9/10], Lo

测试一下