In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Seed every RNG used in this notebook so results are reproducible.
SEED = 42
torch.manual_seed(SEED)  # CPU and all CUDA devices
np.random.seed(SEED)     # NumPy's legacy global RandomState

In [2]:
# 1. 数据加载与预处理
def load_data(root='data', val_fraction=0.2, seed=None):
    """Load MNIST and split its official training set into train/validation.

    Images are converted to tensors and normalized with the usual MNIST
    mean/std (0.1307, 0.3081).

    Args:
        root: Directory where the MNIST files are stored / downloaded to.
        val_fraction: Fraction of the training split held out for validation
            (must be strictly between 0 and 1). The default 0.2 reproduces the
            original 80/20 split.
        seed: Optional integer. When given, the train/val split uses its own
            seeded ``torch.Generator`` so it does not depend on how much of the
            global RNG was consumed earlier. When ``None`` (default), the
            global torch RNG is used, as before.

    Returns:
        Tuple ``(train_set, val_set, test_data)``.

    Raises:
        ValueError: If ``val_fraction`` is not in the open interval (0, 1).
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
    ])

    # download=True is a no-op when the files already exist under `root`.
    train_data = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root, train=False, download=True, transform=transform)

    # Hold out part of the official training split for validation.
    train_size = int((1 - val_fraction) * len(train_data))
    val_size = len(train_data) - train_size
    if seed is None:
        train_set, val_set = random_split(train_data, [train_size, val_size])
    else:
        split_gen = torch.Generator().manual_seed(seed)
        train_set, val_set = random_split(train_data, [train_size, val_size],
                                          generator=split_gen)

    return train_set, val_set, test_data

In [3]:
# 2. 定义多层RNN模型（层数由 num_layers 参数控制）
class RNN(nn.Module):
    """Stacked vanilla-RNN classifier that reads an input as a sequence.

    Each sample is a sequence of time steps with ``input_size`` features per
    step. For example, when fed MNIST images with the channel dim squeezed,
    the steps are the image rows. Features are layer-normalized per step, run
    through ``nn.RNN``, and the last step's output goes through a two-layer
    MLP head that produces class logits.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden-state size of every RNN layer.
        num_layers: Number of stacked RNN layers.
        num_classes: Number of output logits.
        dropout: Dropout probability applied between stacked RNN layers.
            ``nn.RNN`` only uses it when ``num_layers > 1``, so it is forced to
            0 for a single layer (avoids PyTorch's UserWarning).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.3):
        super(RNN, self).__init__()

        self.num_layers = num_layers
        # Normalizes over the last dimension only (each step's feature vector).
        self.layer_norm = nn.LayerNorm(input_size)

        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers=num_layers,  # depth is controlled directly here
            batch_first=True,
            # Inter-layer dropout is meaningless with one layer and makes
            # nn.RNN emit a warning, so disable it in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )

        self.fc1 = nn.Linear(hidden_size, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, sequence_length, input_size)``
                (``batch_first=True``).
        """
        x = self.layer_norm(x)

        # Zero initial hidden state created with x's device *and* dtype, so the
        # model also works after .double() / .half().
        h0 = x.new_zeros((self.num_layers, x.size(0), self.rnn.hidden_size))

        out, _ = self.rnn(x, h0)

        # Keep only the last time step's output: (batch, hidden_size).
        out = out[:, -1, :]

        # MLP classification head.
        out = self.relu(self.fc1(out))
        return self.fc2(out)

In [4]:
# 3. 训练模型
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for data, target in tqdm(loader, desc=desc):
        # Drop the channel dim so each image row becomes one RNN time step.
        data = data.squeeze(1).to(device)
        target = target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _validate(model, loader, criterion, device, desc):
    """Evaluate on ``loader`` without gradients; return (mean batch loss, accuracy %)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in tqdm(loader, desc=desc):
            data = data.squeeze(1).to(device)  # same reshaping as in training
            target = target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()

            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return total_loss / len(loader), 100. * correct / total


def train(model, train_loader, val_loader, epochs, lr):
    """Train ``model`` with Adam and restore the best-validation-loss weights.

    Each epoch does one pass over ``train_loader`` and then a full pass over
    ``val_loader``. Whenever the validation loss improves, a snapshot of the
    weights is taken. That snapshot is loaded back into ``model`` before
    returning.

    Args:
        model: Network to train. It is moved to CUDA when available.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of epochs to run.
        lr: Adam learning rate.

    Returns:
        ``(model, train_losses, val_losses, val_accuracies)``. Each list has
        one entry per epoch. Losses are means over batches; accuracies are
        percentages.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    val_losses = []
    val_accuracies = []
    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device,
                                      desc=f'Epoch {epoch+1}/{epochs} [Train]')
        train_losses.append(train_loss)

        val_loss, val_accuracy = _validate(model, val_loader, criterion, device,
                                           desc=f'Epoch {epoch+1}/{epochs} [Val]')
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}, '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, '
              f'Val Acc: {val_accuracy:.2f}%')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() holds references to the live parameter tensors, and
            # dict.copy() is only shallow. Without cloning, later optimizer
            # steps would overwrite this "best" snapshot in place.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f'Best model saved at epoch {epoch+1} with val loss: {val_loss:.4f}')

    # Restore the best weights. If no epoch ever improved (epochs == 0 or
    # every val loss was NaN), keep the current weights instead of crashing.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_losses, val_losses, val_accuracies

In [5]:
# 4. 评估模型
def evaluate(model, test_loader, num_classes=10):
    """Evaluate ``model`` on ``test_loader`` and build a confusion matrix.

    Prints the mean test loss (averaged over batches) and the accuracy.

    Args:
        model: Trained network. It is moved to CUDA when available.
        test_loader: Iterable of ``(images, labels)`` batches.
        num_classes: Number of classes. The confusion matrix is always
            ``num_classes x num_classes``, with row/column ``i`` = class ``i``.

    Returns:
        ``(test_accuracy, cm)``: the accuracy in percent and the confusion
        matrix (rows = true labels, columns = predicted labels).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    test_loss = 0.0
    correct = 0
    total = 0
    all_targets = []
    all_predictions = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc='Evaluating'):
            # Drop the channel dim so each image row becomes one RNN time step.
            data = data.squeeze(1).to(device)
            target = target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()

            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            all_targets.extend(target.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / total

    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

    # Pin the label set so the matrix keeps a fixed shape even if some class
    # never appears in the targets or the predictions.
    cm = confusion_matrix(all_targets, all_predictions, labels=list(range(num_classes)))

    return test_accuracy, cm

In [6]:
# 5. 可视化函数
def plot_training_history(train_losses, val_losses, val_accuracies,
                          save_path='training_history.png'):
    """Plot loss curves and validation accuracy side by side and save to disk.

    Epochs on the x-axis are numbered from 1, matching the training log.
    The figure is closed after saving (it is not displayed inline).

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        val_accuracies: Per-epoch validation accuracies in percent.
        save_path: Output image path (default keeps the original filename).
    """
    def _epochs(values):
        # 1-based epoch indices for a per-epoch series.
        return range(1, len(values) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(_epochs(train_losses), train_losses, label='Train Loss')
    loss_ax.plot(_epochs(val_losses), val_losses, label='Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    acc_ax.plot(_epochs(val_accuracies), val_accuracies, label='Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('Validation Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

def plot_confusion_matrix(cm, save_path='confusion_matrix.png'):
    """Render ``cm`` as an annotated heatmap and save it to disk.

    Tick labels are ``0 .. n-1``, where ``n`` is the size of ``cm``. This
    assumes row/column ``i`` corresponds to class ``i``. The figure is closed
    after saving.

    Args:
        cm: Square integer confusion matrix (rows = true, columns = predicted).
        save_path: Output image path (default keeps the original filename).
    """
    class_labels = list(range(len(cm)))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels, ax=ax)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')
    fig.savefig(save_path)
    plt.close(fig)

In [7]:
# 6. 主函数
def main(input_size=28, hidden_size=1024, num_classes=10, num_layers=4, epochs=50, lr=0.0001,
         batch_size=128):
    """Run the full pipeline: load data, train, evaluate, plot and save the model.

    Args:
        input_size: Features per RNN time step.
        hidden_size: RNN hidden-state size.
        num_classes: Number of output classes.
        num_layers: Number of stacked RNN layers.
        epochs: Training epochs.
        lr: Adam learning rate.
        batch_size: Batch size shared by the train/val/test loaders
            (default 128, as before).

    Returns:
        ``(model, test_accuracy)``: the trained model (best-validation weights
        as returned by ``train``) and its test accuracy in percent.
    """
    # Load data
    train_set, val_set, test_data = load_data()

    # Only the training loader is shuffled; evaluation order does not matter.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    # Build the model
    model = RNN(input_size, hidden_size, num_layers, num_classes)

    # Train
    model, train_losses, val_losses, val_accuracies = train(model, train_loader, val_loader, epochs, lr)

    # Evaluate on the held-out test set
    test_accuracy, cm = evaluate(model, test_loader)

    # Save training curves and the confusion matrix as images
    plot_training_history(train_losses, val_losses, val_accuracies)
    plot_confusion_matrix(cm)

    # Persist the weights
    torch.save(model.state_dict(), 'mnist_rnn_model.pth')
    print("Model saved as 'mnist_rnn_model.pth'")

    return model, test_accuracy

In [8]:
if __name__ == "__main__":
    # Hyperparameters for this run; these override main()'s defaults.
    run_config = dict(
        input_size=28,
        hidden_size=512,
        num_classes=10,
        num_layers=4,
        epochs=20,
        lr=0.0001,
    )
    main(**run_config)

Epoch 1/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 66.21it/s]
Epoch 1/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 83.42it/s]


Epoch 1/20, Train Loss: 0.7354, Val Loss: 0.3507, Val Acc: 89.34%
Best model saved at epoch 1 with val loss: 0.3507


Epoch 2/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 67.24it/s]
Epoch 2/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 83.15it/s]


Epoch 2/20, Train Loss: 0.2825, Val Loss: 0.2183, Val Acc: 93.47%
Best model saved at epoch 2 with val loss: 0.2183


Epoch 3/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 66.63it/s]
Epoch 3/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 79.95it/s]


Epoch 3/20, Train Loss: 0.1980, Val Loss: 0.1803, Val Acc: 94.83%
Best model saved at epoch 3 with val loss: 0.1803


Epoch 4/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 64.68it/s]
Epoch 4/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 82.32it/s]


Epoch 4/20, Train Loss: 0.1512, Val Loss: 0.1748, Val Acc: 94.61%
Best model saved at epoch 4 with val loss: 0.1748


Epoch 5/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 63.93it/s]
Epoch 5/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 78.74it/s]


Epoch 5/20, Train Loss: 0.1246, Val Loss: 0.1180, Val Acc: 96.49%
Best model saved at epoch 5 with val loss: 0.1180


Epoch 6/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 63.84it/s]
Epoch 6/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 77.90it/s]


Epoch 6/20, Train Loss: 0.1068, Val Loss: 0.1105, Val Acc: 96.71%
Best model saved at epoch 6 with val loss: 0.1105


Epoch 7/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 64.65it/s]
Epoch 7/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 80.38it/s]


Epoch 7/20, Train Loss: 0.0943, Val Loss: 0.1297, Val Acc: 96.27%


Epoch 8/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 64.73it/s]
Epoch 8/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 84.71it/s]


Epoch 8/20, Train Loss: 0.0832, Val Loss: 0.0977, Val Acc: 97.28%
Best model saved at epoch 8 with val loss: 0.0977


Epoch 9/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 65.61it/s]
Epoch 9/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 82.18it/s]


Epoch 9/20, Train Loss: 0.0796, Val Loss: 0.0796, Val Acc: 97.67%
Best model saved at epoch 9 with val loss: 0.0796


Epoch 10/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 66.31it/s]
Epoch 10/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 82.17it/s]


Epoch 10/20, Train Loss: 0.0736, Val Loss: 0.0789, Val Acc: 97.69%
Best model saved at epoch 10 with val loss: 0.0789


Epoch 11/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 63.75it/s]
Epoch 11/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 81.02it/s]


Epoch 11/20, Train Loss: 0.0653, Val Loss: 0.0796, Val Acc: 97.68%


Epoch 12/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 65.25it/s]
Epoch 12/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 81.25it/s]


Epoch 12/20, Train Loss: 0.0628, Val Loss: 0.0760, Val Acc: 97.78%
Best model saved at epoch 12 with val loss: 0.0760


Epoch 13/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 64.94it/s]
Epoch 13/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 80.48it/s]


Epoch 13/20, Train Loss: 0.0610, Val Loss: 0.0726, Val Acc: 97.88%
Best model saved at epoch 13 with val loss: 0.0726


Epoch 14/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 63.63it/s]
Epoch 14/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 77.72it/s]


Epoch 14/20, Train Loss: 0.0518, Val Loss: 0.0641, Val Acc: 98.25%
Best model saved at epoch 14 with val loss: 0.0641


Epoch 15/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 63.99it/s]
Epoch 15/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 79.56it/s]


Epoch 15/20, Train Loss: 0.0517, Val Loss: 0.0781, Val Acc: 97.78%


Epoch 16/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 66.06it/s]
Epoch 16/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 80.11it/s]


Epoch 16/20, Train Loss: 0.0477, Val Loss: 0.0688, Val Acc: 98.04%


Epoch 17/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 65.70it/s]
Epoch 17/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 80.81it/s]


Epoch 17/20, Train Loss: 0.0454, Val Loss: 0.0652, Val Acc: 98.08%


Epoch 18/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 64.95it/s]
Epoch 18/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 82.22it/s]


Epoch 18/20, Train Loss: 0.0465, Val Loss: 0.0583, Val Acc: 98.31%
Best model saved at epoch 18 with val loss: 0.0583


Epoch 19/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 66.47it/s]
Epoch 19/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 83.34it/s]


Epoch 19/20, Train Loss: 0.0416, Val Loss: 0.0709, Val Acc: 98.03%


Epoch 20/20 [Train]: 100%|██████████| 375/375 [00:05<00:00, 65.45it/s]
Epoch 20/20 [Val]: 100%|██████████| 94/94 [00:01<00:00, 83.24it/s]


Epoch 20/20, Train Loss: 0.0396, Val Loss: 0.0518, Val Acc: 98.60%
Best model saved at epoch 20 with val loss: 0.0518


Evaluating: 100%|██████████| 79/79 [00:00<00:00, 80.46it/s]


Test Loss: 0.0499, Test Acc: 98.57%
Model saved as 'mnist_rnn_model.pth'
