In [1]:
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib
matplotlib.use('QtAgg')  # Select the Qt backend; must run before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
                             QLabel, QComboBox, QPushButton, QGroupBox, QSplitter, QTabWidget,
                             QProgressBar, QMessageBox, QTextEdit)
from PyQt6.QtCore import Qt, QThread, pyqtSignal
from PyQt6.QtGui import QImage, QPainter, QPixmap
from PIL import Image, ImageDraw

# Configure matplotlib to use fonts with Chinese glyph support, since the plot
# titles and axis labels in this file are written in Chinese.
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]

# 定义模型
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 10.

    Each hidden layer is followed by ReLU and a shared dropout (p=0.3).
    The output is raw class logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return 10 logits per row."""
        hidden = x.view(-1, 784)
        # The three hidden layers share the same ReLU + dropout treatment.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.dropout(torch.relu(layer(hidden)))
        return self.fc4(hidden)

class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    conv(1->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU -> 2x2 max-pool
    -> flatten (64*7*7) -> fc(128) -> ReLU -> dropout(0.4) -> fc(10) logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Reshape the input to (N, 1, 28, 28) and return 10 logits per image."""
        images = x.view(-1, 1, 28, 28)
        features = self.pool(torch.relu(self.conv1(images)))
        features = self.pool(torch.relu(self.conv2(features)))
        # Two 2x2 poolings reduce 28x28 to 7x7 across 64 channels.
        flat = features.view(-1, 64 * 7 * 7)
        hidden = self.dropout(torch.relu(self.fc1(flat)))
        return self.fc2(hidden)

class LSTM(nn.Module):
    """Recurrent classifier that reads an image row by row.

    Each 28x28 image is treated as a sequence of 28 steps with 28 features.
    Two stacked unidirectional LSTMs (128 then 64 hidden units) feed the last
    time step into fc(64) -> ReLU -> dropout(0.3) -> fc(10) logits.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTM(28, 128, batch_first=True)
        self.lstm2 = nn.LSTM(128, 64, batch_first=True)
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return 10 logits per image from the final time step of the second LSTM."""
        rows = x.view(-1, 28, 28)  # (batch_size, seq_len, input_size)
        first_pass, _ = self.lstm1(rows)
        second_pass, _ = self.lstm2(first_pass)
        final_step = second_pass[:, -1, :]  # keep only the last time step
        hidden = self.dropout(torch.relu(self.fc1(final_step)))
        return self.fc2(hidden)

class Attention(nn.Module):
    """Bidirectional LSTM with additive attention pooling over time steps.

    The image is read as 28 steps of 28 features. A small scoring network
    produces one weight per time step (softmax over the sequence axis), and
    the weighted sum of LSTM outputs is classified by
    fc(64) -> ReLU -> dropout(0.3) -> fc(10) logits.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(28, 64, batch_first=True, bidirectional=True)
        # Scores every time step; Softmax(dim=1) normalises across the sequence.
        self.attention = nn.Sequential(
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softmax(dim=1),
        )
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return 10 logits per image from the attention-pooled LSTM outputs."""
        rows = x.view(-1, 28, 28)  # (batch_size, seq_len, input_size)
        sequence, _ = self.lstm(rows)  # (batch_size, seq_len, hidden_size*2)
        weights = self.attention(sequence)  # (batch_size, seq_len, 1)
        context = (sequence * weights).sum(dim=1)  # (batch_size, hidden_size*2)
        hidden = self.dropout(torch.relu(self.fc1(context)))
        return self.fc2(hidden)

# 训练线程类
class TrainThread(QThread):
    """Background worker that trains a model and reports back through Qt signals.

    Signals:
        update_progress(int): overall progress in percent (0-100); emitted only
            when the integer percentage changes.
        training_message(str): one human-readable line per finished epoch, or
            an error line if training failed.
        training_complete(dict): emitted once on success with the keys
            "model", "train_losses", "train_accs", "test_accs",
            "training_time", "final_test_acc" and "total_params".

    Training stops early (without emitting ``training_complete``) when
    ``requestInterruption()`` has been called on the thread.
    """

    update_progress = pyqtSignal(int)
    training_message = pyqtSignal(str)
    training_complete = pyqtSignal(dict)

    def __init__(self, model, train_loader, test_loader, epochs=5, lr=0.001):
        """Store the training configuration; nothing is computed until the thread runs.

        Args:
            model: the ``nn.Module`` to train (moved to the selected device in ``run``).
            train_loader: iterable of ``(data, target)`` training batches.
            test_loader: iterable of ``(data, target)`` evaluation batches.
            epochs: number of passes over ``train_loader``.
            lr: learning rate for the Adam optimizer.
        """
        super().__init__()
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.lr = lr
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def run(self):
        """Thread entry point: train, then emit the results or an error message."""
        # Timing starts here rather than in an overridden start(), so run()
        # does not depend on how the thread was launched and QThread.start()
        # keeps its full signature.
        self.start_time = time.time()
        try:
            results = self._train()
        except Exception as exc:
            # An exception escaping run() would abort the whole Qt application,
            # so report it (traceback to stderr, summary to the log) instead.
            sys.excepthook(*sys.exc_info())
            self.training_message.emit(f"训练失败: {exc}")
            return
        if results is not None:
            self.training_complete.emit(results)

    def _train(self):
        """Run the training loop; return the results dict, or None if interrupted."""
        self.model.to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        train_losses = []
        train_accs = []
        test_accs = []
        num_batches = len(self.train_loader)
        total_steps = num_batches * self.epochs
        step_count = 0
        last_progress = -1

        for epoch in range(self.epochs):
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for data, target in self.train_loader:
                if self.isInterruptionRequested():
                    return None

                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                step_count += 1
                progress = 100 * step_count // total_steps
                # Emit only on change: one signal per batch would queue
                # thousands of redundant events on the GUI thread.
                if progress != last_progress:
                    last_progress = progress
                    self.update_progress.emit(progress)

            # Guard the averages so an empty loader cannot divide by zero.
            epoch_loss = running_loss / num_batches if num_batches else 0.0
            epoch_acc = 100. * correct / total if total else 0.0
            train_losses.append(epoch_loss)
            train_accs.append(epoch_acc)

            # Evaluate on the test set after every epoch.
            test_acc = self.evaluate()
            test_accs.append(test_acc)

            self.training_message.emit(f"Epoch {epoch+1}/{self.epochs}, Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {test_acc:.2f}%")

        # Count trainable parameters.
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        # The weights have not changed since the last per-epoch evaluation, so
        # reuse that value; only evaluate again when no epoch ran.
        final_test_acc = test_accs[-1] if test_accs else self.evaluate()
        training_time = time.time() - self.start_time

        return {
            "model": self.model,
            "train_losses": train_losses,
            "train_accs": train_accs,
            "test_accs": test_accs,
            "training_time": training_time,
            "final_test_acc": final_test_acc,
            "total_params": total_params
        }

    def evaluate(self):
        """Return the model's accuracy on ``test_loader`` in percent (0.0 if it is empty)."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        return 100. * correct / total if total else 0.0

# 主应用类
class MNISTApp(QMainWindow):
    """Main window for training, comparing and trying out MNIST digit classifiers.

    The left panel holds the model selector, training controls, a log and a
    drawing area for handwritten input. The right side shows three tabs:
    training curves, a confusion matrix and a cross-model comparison.

    Attributes:
        models: trained model per architecture name; ``None`` until that
            architecture has finished training at least once.
        histories: full results dict of the latest training run per architecture.
        model_perf: accuracy (0-1), training time (s) and parameter count per
            trained architecture.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("MNIST手写字符识别模型比较系统")
        self.setGeometry(100, 100, 1200, 800)

        # A model is published here only after training completes, so the GUI
        # thread never touches a network the worker thread is still updating.
        self.models = {
            "MLP": None,
            "CNN": None,
            "LSTM": None,
            "Attention": None
        }
        self.histories = {}
        self.model_perf = {}
        self.train_thread = None
        # Architecture name of the run in progress (None when idle).
        self._training_model_type = None

        # Load the MNIST data before building the UI.
        self.load_data()

        self.main_widget = QWidget()
        self.setCentralWidget(self.main_widget)
        # NOTE(review): this attribute shadows QWidget.layout(); the name is
        # kept so existing references to `self.layout` continue to work.
        self.layout = QHBoxLayout(self.main_widget)

        self._build_control_panel()
        self._build_visualization_tabs()

        # Control panel on the left, plots on the right.
        splitter = QSplitter(Qt.Orientation.Horizontal)
        splitter.addWidget(self.control_panel)
        splitter.addWidget(self.viz_tabs)
        splitter.setSizes([400, 800])

        self.layout.addWidget(splitter)

        # Initial state.
        self.result_label.setText("提示: 请先训练模型")
        self.log_area.append("系统已启动，请选择模型并训练")

    def _build_control_panel(self):
        """Create the left-hand panel: model choice, training, log and drawing area."""
        self.control_panel = QGroupBox("模型控制")
        control_layout = QVBoxLayout()

        self.model_combo = QComboBox()
        self.model_combo.addItems(["MLP", "CNN", "LSTM", "Attention"])
        control_layout.addWidget(QLabel("选择模型:"))
        control_layout.addWidget(self.model_combo)

        self.epoch_spin = QLabel("训练轮数: 5 (默认)")
        control_layout.addWidget(self.epoch_spin)

        self.train_btn = QPushButton("训练模型")
        self.train_btn.clicked.connect(self.train_model)
        control_layout.addWidget(self.train_btn)

        self.progress_bar = QProgressBar()
        self.progress_bar.setValue(0)
        control_layout.addWidget(self.progress_bar)

        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        control_layout.addWidget(QLabel("训练日志:"))
        control_layout.addWidget(self.log_area)

        self.compare_btn = QPushButton("比较所有模型")
        self.compare_btn.clicked.connect(self.compare_models)
        control_layout.addWidget(self.compare_btn)

        # Handwriting area.
        self.canvas_label = QLabel("手写区域:")
        control_layout.addWidget(self.canvas_label)

        self.drawing_canvas = DrawingCanvas()
        control_layout.addWidget(self.drawing_canvas)

        self.result_label = QLabel("识别结果: ")
        control_layout.addWidget(self.result_label)

        # Recognise / clear buttons side by side.
        button_layout = QHBoxLayout()

        self.recognize_btn = QPushButton("识别数字")
        self.recognize_btn.clicked.connect(self.recognize_digit)
        button_layout.addWidget(self.recognize_btn)

        self.clear_btn = QPushButton("清空手写区域")
        self.clear_btn.clicked.connect(self.drawing_canvas.clear)
        button_layout.addWidget(self.clear_btn)

        control_layout.addLayout(button_layout)

        control_layout.addStretch()
        self.control_panel.setLayout(control_layout)

    def _build_visualization_tabs(self):
        """Create the right-hand tab widget with one matplotlib canvas per tab."""
        self.viz_tabs = QTabWidget()

        # Training-history tab.
        self.training_tab = QWidget()
        self.training_layout = QVBoxLayout(self.training_tab)
        self.training_figure = Figure(figsize=(8, 5))
        self.training_canvas = FigureCanvas(self.training_figure)
        self.training_layout.addWidget(self.training_canvas)

        # Confusion-matrix tab.
        self.confusion_tab = QWidget()
        self.confusion_layout = QVBoxLayout(self.confusion_tab)
        self.confusion_figure = Figure(figsize=(8, 5))
        self.confusion_canvas = FigureCanvas(self.confusion_figure)
        self.confusion_layout.addWidget(self.confusion_canvas)

        # Model-comparison tab.
        self.comparison_tab = QWidget()
        self.comparison_layout = QVBoxLayout(self.comparison_tab)
        self.comparison_figure = Figure(figsize=(12, 5))
        self.comparison_canvas = FigureCanvas(self.comparison_figure)
        self.comparison_layout.addWidget(self.comparison_canvas)

        self.viz_tabs.addTab(self.training_tab, "训练过程")
        self.viz_tabs.addTab(self.confusion_tab, "混淆矩阵")
        self.viz_tabs.addTab(self.comparison_tab, "模型比较")

    def load_data(self):
        """Load the MNIST train/test sets (downloading into ./data if needed)."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST('data', train=False, transform=transform)

        self.train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        self.test_loader = DataLoader(test_dataset, batch_size=1000)

        # Keep one test batch around for visualisation.
        self.sample_data, self.sample_targets = next(iter(self.test_loader))

    def train_model(self):
        """Start training a fresh instance of the selected architecture in a worker thread."""
        model_type = self.model_combo.currentText()
        self.log_area.clear()
        self.log_area.append(f"开始训练{model_type}模型...")
        self.progress_bar.setValue(0)

        model_classes = {"MLP": MLP, "CNN": CNN, "LSTM": LSTM, "Attention": Attention}
        model = model_classes[model_type]()

        # Remember what is being trained; the model itself is only stored in
        # self.models once training has completed.
        self._training_model_type = model_type

        self.train_thread = TrainThread(
            model,
            self.train_loader,
            self.test_loader,
            epochs=5,  # default number of epochs
            lr=0.001
        )

        self.train_thread.update_progress.connect(self.update_progress)
        self.train_thread.training_message.connect(self.update_log)
        self.train_thread.training_complete.connect(self.on_training_complete)
        # QThread.finished fires however run() ends, so the controls are
        # re-enabled even when training stops without reporting results.
        self.train_thread.finished.connect(self._on_training_finished)

        # Block a second training run while this one is active.
        self._set_training_controls_enabled(False)

        self.train_thread.start()

    def _set_training_controls_enabled(self, enabled):
        """Enable or disable the widgets that must not be used during training."""
        self.train_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def _on_training_finished(self):
        """Restore the UI once the worker thread has ended, for whatever reason."""
        self._training_model_type = None
        self._set_training_controls_enabled(True)

    def update_progress(self, value):
        """Show the training progress (percent) in the progress bar."""
        self.progress_bar.setValue(value)

    def update_log(self, message):
        """Append one line to the training log."""
        self.log_area.append(message)

    def on_training_complete(self, results):
        """Store a finished training run, log its summary and refresh the plots.

        Args:
            results: the dict emitted by ``TrainThread.training_complete``.
        """
        # Use the architecture recorded when training started; fall back to
        # the combo box only if this slot is invoked outside a training run.
        model_type = self._training_model_type or self.model_combo.currentText()
        self.models[model_type] = results["model"]
        self.histories[model_type] = results
        self.model_perf[model_type] = {
            "accuracy": results["final_test_acc"] / 100.0,
            "time": results["training_time"],
            "params": results["total_params"]
        }

        self.log_area.append(f"{model_type}模型训练完成!")
        self.log_area.append(f"测试准确率: {results['final_test_acc']:.2f}%")
        self.log_area.append(f"训练时间: {results['training_time']:.2f}秒")
        self.log_area.append(f"参数数量: {results['total_params']:,}")

        self.plot_training_history(results, model_type)
        self.plot_confusion_matrix(model_type)

        self._set_training_controls_enabled(True)

        self.result_label.setText(f"{model_type}模型已训练完成! 测试准确率: {results['final_test_acc']:.2f}%")

    def plot_training_history(self, results, model_name):
        """Draw the per-epoch loss and accuracy curves on the training tab."""
        self.training_figure.clear()

        # Loss curve.
        ax1 = self.training_figure.add_subplot(121)
        ax1.plot(results['train_losses'], label='训练损失')
        ax1.set_title(f'{model_name}训练损失')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('损失')
        ax1.legend()

        # Accuracy curves.
        ax2 = self.training_figure.add_subplot(122)
        ax2.plot(results['train_accs'], label='训练准确率')
        ax2.plot(results['test_accs'], label='测试准确率')
        ax2.set_title(f'{model_name}准确率')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('准确率 (%)')
        ax2.legend()

        self.training_figure.tight_layout()
        self.training_canvas.draw()

    def plot_confusion_matrix(self, model_name):
        """Evaluate ``model_name`` on the test set and draw its 10x10 confusion matrix.

        Rows are true labels, columns are predicted labels.
        """
        self.confusion_figure.clear()
        model = self.models[model_name]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, preds = torch.max(output, 1)
                all_preds.extend(preds.cpu().tolist())
                all_targets.extend(target.cpu().tolist())

        # Accumulate counts in one vectorised step; np.add.at handles repeated
        # (true, pred) index pairs correctly, unlike plain fancy-index +=.
        cm = np.zeros((10, 10), dtype=int)
        true_idx = np.asarray(all_targets, dtype=int)
        pred_idx = np.asarray(all_preds, dtype=int)
        np.add.at(cm, (true_idx, pred_idx), 1)

        ax = self.confusion_figure.add_subplot(111)
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)

        # Cell labels: white text on dark cells, black on light ones.
        threshold = cm.max() / 2
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, str(cm[i, j]),
                        ha="center", va="center",
                        color="white" if cm[i, j] > threshold else "black")

        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=[str(i) for i in range(10)],
               yticklabels=[str(i) for i in range(10)],
               xlabel='预测标签',
               ylabel='真实标签',
               title=f'{model_name}混淆矩阵')

        self.confusion_figure.tight_layout()
        self.confusion_canvas.draw()

    def compare_models(self):
        """Plot and log a comparison of every model trained so far."""
        if not self.model_perf:
            QMessageBox.warning(self, "警告", "请先训练至少一个模型!")
            return

        self.log_area.append("开始比较所有模型性能...")

        self.plot_model_comparison()

        # Text table for the log.
        lines = [
            "模型性能比较:\n",
            f"{'模型':<10} | {'测试准确率':<10} | {'训练时间(s)':<12} | {'参数数量':<10}\n",
            "-" * 50 + "\n",
        ]
        for model_name, perf in self.model_perf.items():
            lines.append(
                f"{model_name:<10} | {perf['accuracy']*100:<10.2f}% | "
                f"{perf['time']:<12.2f} | {perf['params']:<10,}\n"
            )

        self.log_area.append("".join(lines))
        self.result_label.setText("模型比较完成! 查看'模型比较'标签页")

    def plot_model_comparison(self):
        """Draw bar charts of accuracy, training time and parameter count per model."""
        self.comparison_figure.clear()

        if not self.model_perf:
            return

        models = list(self.model_perf.keys())
        accuracies = [perf['accuracy'] * 100 for perf in self.model_perf.values()]
        times = [perf['time'] for perf in self.model_perf.values()]
        params = [perf['params'] / 1e6 for perf in self.model_perf.values()]  # in millions
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

        charts = [
            (131, accuracies, '模型准确率比较', '测试准确率 (%)', '{:.2f}%'),
            (132, times, '训练时间比较', '训练时间(秒)', '{:.1f}s'),
            (133, params, '模型参数数量比较', '参数数量(百万)', '{:.2f}M'),
        ]
        axes = []
        for position, values, title, ylabel, label_fmt in charts:
            ax = self.comparison_figure.add_subplot(position)
            bars = ax.bar(models, values, color=colors)
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        label_fmt.format(height), ha='center', va='bottom')
            axes.append(ax)

        # Zoom in on the 95-100% band when every model is that good, but widen
        # the range otherwise so weaker models still get a visible bar.
        lower = min(95.0, max(0.0, float(np.floor(min(accuracies))) - 1.0))
        axes[0].set_ylim(lower, 100)

        self.comparison_figure.tight_layout()
        self.comparison_canvas.draw()

    def recognize_digit(self):
        """Classify the drawing with the trained model of the selected architecture."""
        model_type = self.model_combo.currentText()
        model = self.models[model_type]

        if model is None:
            QMessageBox.warning(self, "警告", f"请先训练{model_type}模型!")
            return

        # Preprocessed drawing from the canvas.
        digit_img = self.drawing_canvas.get_image()

        tensor = torch.tensor(digit_img, dtype=torch.float32)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

        with torch.no_grad():
            tensor = tensor.to(device)
            output = model(tensor)
            probs = torch.softmax(output, dim=1)
            conf, pred = torch.max(probs, 1)

        self.result_label.setText(f"识别结果: {pred.item()} (置信度: {conf.item():.2%})")

    def closeEvent(self, event):
        """Ask a running training thread to stop before the window is destroyed."""
        thread = self.train_thread
        if thread is not None and thread.isRunning():
            thread.requestInterruption()
            # Bounded wait so closing the window can never hang indefinitely.
            thread.wait(3000)
        super().closeEvent(event)

class DrawingCanvas(QLabel):
    """Mouse-driven drawing pad backed by a 280x280 greyscale PIL image.

    Strokes are drawn in black on a white image. The widget paints a preview
    of what the model will see (the image reduced to 28x28 and enlarged
    again) as a centred square that fits the widget, and mouse positions are
    mapped back into image coordinates so strokes land where the cursor is
    regardless of the widget's current size.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setMinimumSize(280, 280)
        self.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.setStyleSheet("background-color: white; border: 1px solid black;")
        self.image = Image.new("L", (280, 280), 255)
        self.draw = ImageDraw.Draw(self.image)
        # Previous stroke position as an (x, y) tuple in image coordinates.
        self.last_point = None

    def _display_square(self):
        """Return (left, top, side) of the centred square the image is painted into."""
        area = self.contentsRect()
        side = max(1, min(area.width(), area.height()))
        left = area.x() + (area.width() - side) // 2
        top = area.y() + (area.height() - side) // 2
        return left, top, side

    def _to_image_point(self, pos):
        """Map a widget-space position to (x, y) in 280x280 image coordinates."""
        left, top, side = self._display_square()
        scale = 280 / side
        return ((pos.x() - left) * scale, (pos.y() - top) * scale)

    def mousePressEvent(self, event):
        """Start a stroke at the pressed position (left button only)."""
        if event.button() == Qt.MouseButton.LeftButton:
            self.last_point = self._to_image_point(event.pos())

    def mouseMoveEvent(self, event):
        """Extend the current stroke to the cursor position while the left button is held."""
        # Compare against None explicitly: a valid position can be falsy.
        if not (event.buttons() & Qt.MouseButton.LeftButton) or self.last_point is None:
            return
        current_point = self._to_image_point(event.pos())
        self.draw.line([*self.last_point, *current_point], fill=0, width=15)
        self.last_point = current_point
        self.update()

    def mouseReleaseEvent(self, event):
        """End the current stroke."""
        self.last_point = None

    def paintEvent(self, event):
        """Paint the stylesheet background/border, then the image preview on top."""
        super().paintEvent(event)
        left, top, side = self._display_square()
        # Draw directly instead of calling setPixmap() here: setPixmap()
        # schedules another paint, which would repaint endlessly and show the
        # drawing one frame late.
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform)
            painter.drawPixmap(left, top, side, side, self._get_pixmap())
        finally:
            painter.end()

    def _get_pixmap(self):
        """Return a 280x280 QPixmap previewing the drawing at model resolution."""
        img = self.image.resize((28, 28), Image.LANCZOS).resize((280, 280), Image.NEAREST)
        data = img.tobytes("raw", "L")
        # One byte per pixel, so each scanline is exactly 280 bytes.
        qimage = QImage(data, 280, 280, 280, QImage.Format.Format_Grayscale8)
        # fromImage copies the pixels, so `data` need not outlive this call.
        return QPixmap.fromImage(qimage)

    def get_image(self):
        """Return the drawing as a float32 array of shape (1, 1, 28, 28).

        The image is reduced to 28x28, scaled to [0, 1], inverted so the
        background is 0 and strokes approach 1, then normalised with the same
        mean (0.1307) and std (0.3081) used for the training data.
        """
        img = self.image.resize((28, 28), Image.LANCZOS)

        img_array = np.array(img, dtype=np.float32) / 255.0

        # Invert: the canvas is black-on-white, MNIST is white-on-black.
        img_array = 1.0 - img_array

        img_array = (img_array - 0.1307) / 0.3081

        return img_array.reshape(1, 1, 28, 28)

    def clear(self):
        """Erase the drawing and repaint the empty canvas."""
        self.image = Image.new("L", (280, 280), 255)
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None
        self.update()

if __name__ == "__main__":
    # Create the Qt application, show the main window and hand control to the
    # event loop; its return value becomes the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = MNISTApp()
    main_window.show()
    exit_code = qt_app.exec()
    sys.exit(exit_code)

SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
