# 手写识别 (MNIST)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 下载并加载数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练
for epoch in range(5):
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} finished.")

# 保存模型参数
torch.save(model.state_dict(), 'mnist_model.pth')

Epoch 1 finished.
Epoch 2 finished.
Epoch 3 finished.
Epoch 4 finished.
Epoch 5 finished.


In [16]:
# 加载模型参数
model = Net()
model.load_state_dict(torch.load('mnist_model.pth'))
model.eval()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5):  # 再训练5轮
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(f"继续训练 Epoch {epoch+1} 完成")

# 保存新参数
torch.save(model.state_dict(), 'mnist_model_continue.pth')

继续训练 Epoch 1 完成
继续训练 Epoch 2 完成
继续训练 Epoch 3 完成
继续训练 Epoch 4 完成
继续训练 Epoch 5 完成


In [17]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
# 定义模型结构（要和训练时一致）
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 加载模型参数
model = Net()
model.load_state_dict(torch.load('mnist_model_continue.pth'))
model.eval()

# 假设有一张手写数字图片 img.png
img = Image.open('./data/Img/8.png').convert('L')
img = Image.fromarray(255 - np.array(img))  # 反色处理
transform = transforms.Compose([
    transforms.Resize((28,28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
img_tensor = transform(img).unsqueeze(0)  # 增加batch维度

with torch.no_grad():
    output = model(img_tensor)
    pred = torch.argmax(output, dim=1).item()
    print(f"预测数字: {pred}")

预测数字: 5
