训练脚本

In [14]:
import pandas as pd
import torch
from loguru import logger
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

数据加载

In [15]:
# Read the image table and divide every value by 255.0
# (presumably 8-bit pixel intensities scaled into [0, 1] -- confirm against the dataset).
X = pd.read_csv("Dataset/train_image_labeled.csv").to_numpy() / 255.0
# Read the label table and collapse it into a 1-D array with one label per row.
y = pd.read_csv("Dataset/train_label.csv").to_numpy().flatten()

划分训练集和验证集

In [16]:
# Hold out 20% of the samples for validation; the fixed random_state keeps the
# partition identical across re-runs.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

转换为 Tensor

In [17]:
# Features are stored as float32; labels as int64 (torch.long), the integer type
# used for class-index targets.
X_train_tensor, X_val_tensor = (
    torch.tensor(features, dtype=torch.float32) for features in (X_train, X_val)
)
y_train_tensor, y_val_tensor = (
    torch.tensor(labels, dtype=torch.long) for labels in (y_train, y_val)
)

# Pair each feature tensor with its label tensor.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

# Only the training loader reshuffles its samples each epoch; the validation
# loader iterates in a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)

定义 MLP 模型

In [18]:
class MLP(nn.Module):
    """Fully connected classifier built from Linear + ReLU hidden layers.

    With the default arguments the network is 784 -> 512 -> 256 -> 128 -> 10,
    with dropout (p=0.2) after every hidden activation except the last one.
    The layers are stored in ``self.model`` (an ``nn.Sequential``) in that
    order, so the default layout and its state_dict keys are unchanged.

    Args:
        in_features: Size of each (already flattened) input sample.
        hidden_sizes: Widths of the hidden layers, in order. May be empty, in
            which case the network is a single linear layer.
        num_classes: Number of output logits.
        dropout: Dropout probability used after each hidden activation except
            the last one.
    """

    def __init__(self, in_features=784, hidden_sizes=(512, 256, 128), num_classes=10, dropout=0.2):
        super().__init__()
        hidden_sizes = tuple(hidden_sizes)
        last_hidden = len(hidden_sizes) - 1

        layers = []
        width_in = in_features
        for index, width_out in enumerate(hidden_sizes):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU())
            # The final hidden layer feeds the output layer directly, without dropout.
            if index < last_hidden:
                layers.append(nn.Dropout(dropout))
            width_in = width_out
        layers.append(nn.Linear(width_in, num_classes))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for a batch ``x`` of shape (batch, in_features).

        No softmax is applied; the last layer is a plain ``nn.Linear``.
        """
        return self.model(x)

# Seed PyTorch's global RNG before any weights are created. The data split is
# already pinned with random_state=42; without this seed the weight
# initialisation, batch shuffling and dropout masks would differ on every
# Restart & Run All, so the reported accuracies could not be reproduced.
torch.manual_seed(42)

model = MLP()
# Adam over all model parameters with a learning rate of 0.002.
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
# Cross-entropy on raw logits against integer class targets.
loss_fn = nn.CrossEntropyLoss()

训练

In [19]:
for epoch in range(100):
    # ---- optimisation pass over the training batches (train mode: dropout active) ----
    model.train()
    for xb, yb in train_loader:
        # Clear gradients accumulated by the previous step before computing new ones.
        optimizer.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        optimizer.step()

    # ---- validation accuracy (eval mode: dropout disabled, autograd off) ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            # Predicted class = index of the largest logit in each row.
            pred = torch.argmax(model(xb), dim=1)
            correct += torch.eq(pred, yb).sum().item()
            total += yb.shape[0]
    print(f"Epoch {epoch+1}, Accuracy: {correct/total:.4f}")

Epoch 1, Accuracy: 0.7671
Epoch 2, Accuracy: 0.7929
Epoch 3, Accuracy: 0.8233
Epoch 4, Accuracy: 0.8317
Epoch 5, Accuracy: 0.8392
Epoch 6, Accuracy: 0.8350
Epoch 7, Accuracy: 0.8396
Epoch 8, Accuracy: 0.8404
Epoch 9, Accuracy: 0.8367
Epoch 10, Accuracy: 0.8425
Epoch 11, Accuracy: 0.8275
Epoch 12, Accuracy: 0.8150
Epoch 13, Accuracy: 0.8483
Epoch 14, Accuracy: 0.8454
Epoch 15, Accuracy: 0.8550
Epoch 16, Accuracy: 0.8496
Epoch 17, Accuracy: 0.8442
Epoch 18, Accuracy: 0.8483
Epoch 19, Accuracy: 0.8500
Epoch 20, Accuracy: 0.8542
Epoch 21, Accuracy: 0.8654
Epoch 22, Accuracy: 0.8542
Epoch 23, Accuracy: 0.8425
Epoch 24, Accuracy: 0.8458
Epoch 25, Accuracy: 0.8546
Epoch 26, Accuracy: 0.8550
Epoch 27, Accuracy: 0.8567
Epoch 28, Accuracy: 0.8504
Epoch 29, Accuracy: 0.8567
Epoch 30, Accuracy: 0.8512
Epoch 31, Accuracy: 0.8617
Epoch 32, Accuracy: 0.8525
Epoch 33, Accuracy: 0.8467
Epoch 34, Accuracy: 0.8483
Epoch 35, Accuracy: 0.8617
Epoch 36, Accuracy: 0.8446
Epoch 37, Accuracy: 0.8583
Epoch 38, 