In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torch.nn.functional as F
import tkinter as tk
from PIL import Image, ImageDraw

# Model definition (must be the same architecture that was used for training)
class Net(torch.nn.Module):
    """Small convolutional classifier producing 10 raw class scores.

    Two conv -> max-pool -> ReLU stages are followed by a single linear layer.
    The linear layer takes 320 features, which is what a 1x28x28 input
    produces after the two stages (20 channels of 4x4).

    The attribute names (``linear``, ``Conv1``, ``Conv2``, ``Pool1``) are the
    keys of the state dict, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layers are created in this order (linear first) so that seeded
        # random initialisation gives the same weights as before.
        self.linear = nn.Linear(320, 10)
        self.Conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.Conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.Pool1 = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the unnormalised class scores for every sample in ``x``."""
        n_samples = x.size(0)
        # Both stages share one pooling layer and pool before the ReLU.
        for conv in (self.Conv1, self.Conv2):
            x = F.relu(self.Pool1(conv(x)))
        # Collapse channels and spatial dims into one feature vector per sample.
        flat = x.view(n_samples, -1)
        return self.linear(flat)

# Load the trained model
def load_model(model_path=r"D:\Learning\DeepLearning\code\mnist_model.pth"):
    """Build a ``Net``, optionally load saved weights, and prepare it for inference.

    Args:
        model_path: Path to a state-dict checkpoint. A falsy value (``None`` or
            ``""``) skips loading and leaves the network randomly initialised.

    Returns:
        ``(model, device)``: the network in evaluation mode on ``device``, and
        the ``torch.device`` that was chosen (``cuda:0`` when CUDA is
        available, otherwise ``cpu``).

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing file
        or a checkpoint that does not match ``Net``.
    """
    # Pick the device first so the checkpoint can be mapped onto it while loading.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net()
    if model_path:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        # Report success only after the weights are actually in place.
        print("模型已经成功加载")
    else:
        # No path given: fall back to a randomly initialised model.
        print("警告：使用随机初始化的模型，预测结果可能不准确")

    model.to(device)
    model.eval()  # evaluation mode
    return model, device

# Preprocess a hand-drawn image
def preprocess_image(image):
    """Convert a drawn PIL image into a normalised tensor with a batch axis.

    The image is resized to 28x28, converted to 8-bit grayscale, inverted
    (``255 - value``), scaled to ``[0, 1]`` and standardised with mean 0.1307
    and std 0.3081. The inversion is meant to yield light digits on a dark
    background, so the incoming drawing is expected to be dark-on-light.

    Args:
        image: A PIL image of the drawing.

    Returns:
        A float32 tensor of shape ``(1, 1, 28, 28)``.
    """
    # Resize first, then drop to a single grayscale channel.
    small_gray = image.resize((28, 28), Image.Resampling.LANCZOS).convert('L')

    # Invert the intensities and bring them into [0, 1].
    pixels = (255 - np.array(small_gray)) / 255.0

    # Same standardisation as at training time.
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # ToTensor adds the channel axis; unsqueeze(0) adds the batch axis.
    return to_normalised_tensor(pixels.astype(np.float32)).unsqueeze(0)

# Drawing application
class DigitRecognizer:
    """Tkinter window for drawing a digit and classifying it with ``model``.

    Strokes drawn with the left mouse button appear on a 280x280 canvas and are
    mirrored into an off-screen PIL image. ``predict`` passes that off-screen
    image to ``preprocess_image`` and shows the predicted digit together with
    the per-class probabilities.
    """

    # Side length in pixels of the canvas and of its off-screen copy.
    _CANVAS_SIZE = 280
    # Stroke thickness in pixels, used on screen and off screen.
    _STROKE_WIDTH = 15
    # Off-screen polarity: dark ink on a light page. ``preprocess_image``
    # inverts its input, so this is what reaches the model as light-on-dark.
    # The on-screen canvas is only for display and stays white-on-black.
    _PAPER = 255
    _INK = 0

    def __init__(self, model, device):
        """Build the window.

        Args:
            model: Callable network returning class scores of shape (1, 10)
                for one preprocessed image.
            device: Device the input tensor is moved to before inference.
        """
        self.model = model
        self.device = device

        # Main window
        self.root = tk.Tk()
        self.root.title("手写数字识别")

        # Drawing canvas
        self.canvas = tk.Canvas(self.root, width=self._CANVAS_SIZE,
                                height=self._CANVAS_SIZE, bg='black')
        self.canvas.grid(row=0, column=0, columnspan=3, padx=10, pady=10)

        # Buttons
        self.predict_btn = tk.Button(self.root, text="预测", command=self.predict, width=10)
        self.predict_btn.grid(row=1, column=0, padx=5, pady=5)

        self.clear_btn = tk.Button(self.root, text="清除", command=self.clear_canvas, width=10)
        self.clear_btn.grid(row=1, column=1, padx=5, pady=5)

        self.quit_btn = tk.Button(self.root, text="退出", command=self.root.quit, width=10)
        self.quit_btn.grid(row=1, column=2, padx=5, pady=5)

        # Result label
        self.result_label = tk.Label(self.root, text="请绘制数字并点击预测", font=("Arial", 14))
        self.result_label.grid(row=2, column=0, columnspan=3, pady=10)

        # One probability label per digit
        self.prob_frame = tk.Frame(self.root)
        self.prob_frame.grid(row=3, column=0, columnspan=3, pady=5)

        self.prob_labels = []
        for i in range(10):
            label = tk.Label(self.prob_frame, text=f"{i}: 0.0%", font=("Arial", 10))
            label.grid(row=0, column=i, padx=5)
            self.prob_labels.append(label)

        # Drawing state: off-screen image plus the previous mouse position
        self._reset_drawing_surface()
        self.last_x, self.last_y = None, None

        # Mouse bindings
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<ButtonRelease-1>", self.reset)

    def _reset_drawing_surface(self):
        """Replace the off-screen image with a blank page and a fresh drawer."""
        size = (self._CANVAS_SIZE, self._CANVAS_SIZE)
        self.image = Image.new("L", size, self._PAPER)
        self.draw = ImageDraw.Draw(self.image)

    def paint(self, event):
        """Extend the current stroke to the mouse position in ``event``."""
        x, y = event.x, event.y
        # Compare against None explicitly: 0 is a valid coordinate on the
        # left/top edge and must not be treated as "no previous point".
        if self.last_x is not None and self.last_y is not None:
            # Visible stroke on the canvas
            self.canvas.create_line(self.last_x, self.last_y, x, y,
                                    width=self._STROKE_WIDTH, fill='white',
                                    capstyle=tk.ROUND, smooth=tk.TRUE)
            # Same stroke on the off-screen image
            self.draw.line([self.last_x, self.last_y, x, y],
                           fill=self._INK, width=self._STROKE_WIDTH)
        self.last_x, self.last_y = x, y

    def reset(self, event):
        """End the current stroke so the next one does not connect to it."""
        self.last_x, self.last_y = None, None

    def clear_canvas(self):
        """Erase the drawing and restore all labels to their initial state."""
        self.canvas.delete("all")
        self._reset_drawing_surface()
        self.result_label.config(text="请绘制数字并点击预测")
        # Also drop the highlight left over from the previous prediction.
        default_bg = self.root.cget('bg')
        for i, label in enumerate(self.prob_labels):
            label.config(text=f"{i}: 0.0%", bg=default_bg)

    def predict(self):
        """Classify the current drawing and update the result labels."""
        # Preprocess the off-screen image
        input_tensor = preprocess_image(self.image)
        input_tensor = input_tensor.to(self.device)

        # Inference without gradient tracking
        with torch.no_grad():
            output = self.model(input_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted = torch.argmax(output, dim=1)

        # Predicted digit and per-class probabilities
        predicted_num = predicted.item()
        probs = probabilities[0].cpu().numpy()

        self.result_label.config(text=f"预测结果: {predicted_num}")

        # Show each probability and highlight the predicted digit
        default_bg = self.root.cget('bg')
        for i, label in enumerate(self.prob_labels):
            prob_percent = probs[i] * 100
            highlight = "lightgreen" if i == predicted_num else default_bg
            label.config(text=f"{i}: {prob_percent:.1f}%")
            label.config(bg=highlight)

    def run(self):
        """Enter the Tk event loop; returns when the loop is stopped."""
        self.root.mainloop()

# Main program
if __name__ == "__main__":
    # Load the model (replace with the path to your own model file)
    model_path = 'mnist_model.pth'  
    model, device = load_model(model_path)
    
    # Create the application and run it
    app = DigitRecognizer(model, device)
    app.run()

模型已经成功加载
