# MNIST 手写数字识别 - 模型推理

这个笔记本用于加载训练好的模型并进行推理。

第一个代码块 - 导入库：

In [None]:
# 导入必要的库
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from google.colab import drive
from PIL import Image

第二个代码块 - 挂载 Google Drive：

In [None]:
# Mount Google Drive at /content/drive
drive.mount('/content/drive')

第三个代码块 - 定义模型：

In [None]:
class ConvNet(nn.Module):
    """Two-stage convolutional classifier that outputs 10 class scores.

    Each convolutional stage is a 5x5 convolution (stride 1, padding 2)
    followed by ReLU and 2x2 max pooling. The first fully connected layer
    takes ``32 * 7 * 7`` flattened features, which matches a single-channel
    28x28 input after the two pooling steps.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and positions inside each Sequential determine the
        # state_dict keys, so they must stay as they are for saved weights
        # to load.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5,
                      stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5,
                      stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for a batch ``x``."""
        features = self.conv2(self.conv1(x))
        # Collapse everything except the batch dimension before the
        # fully connected layers.
        flat = features.view(features.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

# Create the model instance and pick the GPU when one is available
model = ConvNet()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained parameters. map_location remaps the stored tensors onto
# the selected device, so a checkpoint saved from a CUDA model still loads
# on a CPU-only runtime instead of failing during deserialization.
# NOTE(review): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load(
        '/content/drive/MyDrive/mnist_training_results/mnist_model.pth',
        map_location=device,
    )
)
model = model.to(device)
# Put the module in evaluation mode for inference
model.eval()

第四个代码块 - 测试模型：

In [None]:
# Preprocessing pipeline: convert to a tensor, then normalize with
# mean 0.1307 and standard deviation 0.3081
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

# Load the MNIST test split (downloaded into ./data when not present)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Randomly pick a few test images and show each with its prediction
n_samples = 5
# .tolist() turns the 0-dim tensors into plain ints for dataset indexing
indices = torch.randperm(len(test_dataset))[:n_samples].tolist()
# squeeze=False always returns a 2-D array of axes, so the code also works
# when n_samples == 1; ravel() restores the flat, per-sample layout.
fig, axes = plt.subplots(1, n_samples, figsize=(15, 3), squeeze=False)
axes = axes.ravel()

with torch.no_grad():
    for ax, idx in zip(axes, indices):
        img, label = test_dataset[idx]
        # Add the batch dimension and move to the model's device
        output = model(img.unsqueeze(0).to(device))
        pred = output.argmax(dim=1).item()

        # Show the image with the predicted and actual labels. The tensor
        # fetched above is reused instead of reading the sample a second time.
        ax.imshow(img.squeeze(), cmap='gray')
        ax.axis('off')
        ax.set_title(f'预测: {pred}\n实际: {label}')

plt.tight_layout()
plt.show()

第五个代码块 - 混淆矩阵：

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Create the test data loader (fixed order, 1000 images per batch)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Collect every prediction together with its ground-truth label
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # argmax over the class dimension, matching the other cells; this
        # avoids the legacy `.data` attribute, which is unnecessary under
        # no_grad.
        predicted = outputs.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Compute the confusion matrix. Passing labels explicitly keeps it 10x10
# with rows/columns in digit order even if some digit never occurs in
# either the targets or the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))

# Plot the confusion matrix as an annotated heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('混淆矩阵')
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.show()

第六个代码块 - 自定义图像测试：

In [None]:
from google.colab import files

def predict_digit(image_path):
    """Classify a handwritten digit image and display it with the result.

    The image is converted to grayscale, resized to 28x28, passed through
    the module-level ``transform`` and fed to the module-level ``model`` on
    ``device``. The resized image is shown with the predicted class and its
    softmax probability in the title.

    Args:
        image_path: Path of the image file, as accepted by ``Image.open``.

    Returns:
        A ``(pred, confidence)`` tuple: the predicted class index (int) and
        the softmax probability of that class (float).
    """
    # Open inside a context manager so the file handle is released;
    # convert() and resize() return new images that stay usable afterwards.
    with Image.open(image_path) as source:
        img = source.convert('L').resize((28, 28))

    # Apply the same preprocessing as the test set and add a batch dimension
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Run the prediction without tracking gradients
    with torch.no_grad():
        output = model(img_tensor)
        prob = torch.softmax(output, dim=1)[0]
        pred = output.argmax(dim=1).item()
    confidence = prob[pred].item()

    # Show the image together with the prediction
    plt.figure(figsize=(4, 4))
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title(f'预测结果: {pred}\n置信度: {confidence:.2%}')
    plt.show()

    return pred, confidence

# Prompt for an upload; the printed hint recommends a white-on-black
# 28x28 pixel image
print("请上传一张手写数字图像（建议使用黑底白字的 28x28 像素图像）")
uploaded = files.upload()

# Iterating the returned mapping yields its keys, which are passed to
# predict_digit as file names
for filename in uploaded:
    predict_digit(filename)