In [1]:
import torch

class Net(torch.nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages and a linear head.

    The head expects 320 flattened features per sample, which is what the two
    stages produce for a 1x28x28 input (20 channels x 4 x 4), and it emits
    10 values per sample.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        def conv_stage(in_channels, out_channels):
            # 5x5 convolution -> ReLU -> 2x2 max pooling.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=5),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            )

        # Attribute names and layer positions define the state_dict keys
        # (conv1.0, conv2.0, fc.0, fc.1), so they are kept as-is.
        self.conv1 = conv_stage(1, 10)
        self.conv2 = conv_stage(10, 20)
        self.fc = nn.Sequential(
            nn.Linear(320, 50),
            nn.Linear(50, 10),
        )

    def forward(self, x):
        """Return a (batch, 10) tensor of raw class scores for the batch ``x``.

        No softmax is applied here, so the values are unnormalized scores
        rather than probabilities; the index of the largest one is the
        predicted class.
        """
        features = self.conv2(self.conv1(x))
        # Collapse everything except the batch dimension before the linear head.
        flattened = features.view(x.size(0), -1)
        return self.fc(flattened)

# Checkpoint holding the state_dict for Net; resolved relative to the
# notebook's working directory.
MODEL_PATH = "model.pth"

model = Net()
# map_location="cpu" lets a checkpoint that was saved from a CUDA device load
# on a machine without a GPU; the freshly built Net lives on the CPU anyway.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
model.eval()  # switch to evaluation mode; as the last expression it also displays the module summary

Net(
  (conv1): Sequential(
    (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=320, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [None]:
import tkinter as tk
from PIL import Image, ImageDraw, ImageOps
import torch
from torchvision import transforms

class PaintApp:
    """Tkinter drawing pad that classifies the drawn digit with the global ``model``.

    Every stroke is drawn twice: on the visible ``tk.Canvas`` and on an
    off-screen PIL image of the same nominal size. The PIL image is what gets
    downscaled and passed to the network, because the canvas contents cannot be
    read back as pixels directly.

    Note: the canvas is packed with ``expand=True`` and may grow with the
    window, while the PIL image keeps its fixed size, so strokes drawn outside
    the initial ``CANVAS_SIZE`` square are not seen by the model.

    Relies on a module-level ``model`` (a callable taking a 1x1x28x28 tensor)
    being defined before the "推理" button is pressed.
    """

    CANVAS_SIZE = 140       # side of the square drawing surface, in pixels
    BRUSH_WIDTH = 16        # stroke thickness, in pixels
    MODEL_INPUT_SIZE = 28   # side of the square image handed to the model

    def __init__(self, root):
        """Build the window layout inside ``root`` and bind the mouse handlers."""
        self.root = root
        self.root.title("GUI画板")

        self.main_frame = tk.Frame(root)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        self.canvas = tk.Canvas(
            self.main_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE
        )
        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self.sidebar = tk.Frame(self.main_frame, width=200, bg="#f0f0f0")
        self.sidebar.pack(side=tk.RIGHT, fill=tk.Y)

        self.label = tk.Label(self.sidebar, text="预测值为0", bg="#f0f0f0", font=("Arial", 14))
        self.label.pack(pady=20)

        self.infer_button = tk.Button(self.sidebar, text="推理", command=self.forward)
        self.infer_button.pack(pady=10)

        self.clear_button = tk.Button(self.sidebar, text="清除", command=self.clear)
        self.clear_button.pack(pady=10)

        # Backward-compatible alias: this attribute used to be assigned twice
        # and ended up referring to the clear button.
        self.save_button = self.clear_button

        self._reset_image()

        self.canvas.bind("<B1-Motion>", self.paint)  # draw while dragging with the left button
        self.canvas.bind("<ButtonRelease-1>", self.reset_position)  # end the stroke on release

        self.last_x, self.last_y = None, None

    def _reset_image(self):
        """Replace the off-screen image with a blank white one and rebind the drawer."""
        self.image = Image.new("RGB", (self.CANVAS_SIZE, self.CANVAS_SIZE), "white")
        self.draw = ImageDraw.Draw(self.image)

    def paint(self, event):
        """Draw a segment from the previous pointer position to ``event``'s position.

        The first motion event of a stroke only records the position; nothing
        is drawn until a second point is available.
        """
        x, y = event.x, event.y
        # Compare against None explicitly: 0 is a valid coordinate on the
        # left/top edge and must not be mistaken for "no previous point".
        if self.last_x is not None and self.last_y is not None:
            width = self.BRUSH_WIDTH
            radius = width // 2
            self.canvas.create_line(self.last_x, self.last_y, x, y, fill="black", width=width)
            self.draw.line((self.last_x, self.last_y, x, y), fill="black", width=width)
            # Filled discs at both ends round off the joints between segments.
            # They are filled on the PIL image too, so the model input matches
            # what is shown on screen.
            for cx, cy in ((x, y), (self.last_x, self.last_y)):
                box = (cx - radius, cy - radius, cx + radius, cy + radius)
                self.canvas.create_oval(*box, fill="black")
                self.draw.ellipse(box, fill="black")
        self.last_x, self.last_y = x, y

    def reset_position(self, event):
        """Forget the previous pointer position so the next drag starts a new stroke."""
        self.last_x, self.last_y = None, None

    def forward(self, event=None):
        """Run the global ``model`` on the current drawing and show the predicted class.

        The drawing is downscaled to ``MODEL_INPUT_SIZE`` square, inverted
        (black strokes on white become bright strokes on a dark background) and
        converted to a single grayscale channel before being turned into a
        tensor.
        """
        size = self.MODEL_INPUT_SIZE
        digit = ImageOps.invert(self.image.resize((size, size))).convert('L')

        preprocess = transforms.Compose([
            transforms.ToTensor(),
            # mean 0 / std 1 leaves the values unchanged.
            # NOTE(review): this must match the normalization used when the
            # checkpoint was trained - confirm.
            transforms.Normalize(mean=[0], std=[1])
        ])
        batch = preprocess(digit).unsqueeze(0)  # add the batch dimension

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(batch)
        print(outputs)
        predicted = outputs.argmax(dim=1)
        self.label.config(text=f"预测值为{predicted.item()}")

    def clear(self):
        """Erase both the visible canvas and the off-screen image."""
        self.canvas.delete("all")
        self._reset_image()
        self.last_x, self.last_y = None, None

if __name__ == "__main__":
    # Create the top-level window, attach the drawing pad to it and hand
    # control to Tk's event loop; this call blocks until the window is closed.
    root = tk.Tk()
    app = PaintApp(root)
    app.root.mainloop()
