In [None]:
from data_loader import load_mnist, show_image
import os
# The dataset is looked up in a 'data' folder that sits next to the
# notebook's working directory, i.e. <parent of cwd>/data.
current_dir = os.getcwd()
project_root = os.path.split(current_dir)[0]
data_dir = os.path.join(project_root, 'data')

# Load both splits with the same call; the generator is consumed in order,
# so 'train' is read first and 't10k' second.
(train_images, train_labels), (test_images, test_labels) = (
    load_mnist(data_dir, kind=split_name) for split_name in ('train', 't10k')
)

# Hand all four loaded objects to show_image for a visual check.
show_image(train_images, train_labels, test_images, test_labels)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import logging
from pathlib import Path

# Seed PyTorch and NumPy with the same fixed value so reruns start from
# the same random state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Make sure the output folders exist before anything writes to them
# (the log file below lives inside "logs").
for output_dir in ("logs", "models"):
    Path(output_dir).mkdir(exist_ok=True)

# Route INFO-level (and above) records from the root logger to a file.
# Note: basicConfig is a no-op if the root logger already has handlers.
log_settings = {
    'filename': 'logs/training.log',
    'level': logging.INFO,
    'format': '%(asctime)s - %(message)s',
}
logging.basicConfig(**log_settings)

In [None]:
from cnn_model import FashionCNN, print_model
# Instantiate the network, move it onto the selected device, then pass it
# to print_model.
model = FashionCNN()
model = model.to(device)
print_model(model)

# Cross-entropy loss and an Adam optimizer over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
from train import train_model, preprocess_data
# Build the train and test loaders from the images and labels loaded earlier.
train_loader, test_loader = preprocess_data(train_images, train_labels, test_images, test_labels)

# Train the model.
# The duration is measured with perf_counter(): it is monotonic, so the
# reported time cannot be distorted by system clock adjustments the way a
# time.time() difference can.
start_time = time.perf_counter()
train_loss, train_acc, test_loss, test_acc = train_model(
    model, train_loader, test_loader, criterion, optimizer, num_epochs=20
)
elapsed_seconds = time.perf_counter() - start_time
print(f'Training completed in {elapsed_seconds:.2f} seconds')

In [None]:
from plot import plot_confusion_matrix, visualize_history
# Visualize the train/test loss and accuracy values returned by train_model.
history = (train_loss, train_acc, test_loss, test_acc)
visualize_history(*history)

In [None]:
from predict import evaluate_model
from prettytable import PrettyTable

# Load the best model and plot the confusion matrix.
# map_location=device remaps the stored tensors onto the device selected
# for this session, so a checkpoint written on a GPU still loads on a
# CPU-only machine (and vice versa).
best_model_path = 'models/best_model.pth'
model.load_state_dict(torch.load(best_model_path, map_location=device))
final_loss, final_acc = evaluate_model(model, test_loader, criterion)
print(f'Best Model Test Accuracy: {final_acc:.4f}')
plot_confusion_matrix(model, test_loader)

# Validate the accuracy of every saved per-epoch checkpoint.
# The epoch number is parsed from the trailing "_<n>" of the file name.
model_files = sorted(Path('models').glob('model_epoch_*.pth'))
results = []
for model_file in model_files:
    model.load_state_dict(torch.load(model_file, map_location=device))
    # `_loss` rather than `_`, to avoid shadowing IPython's last-output variable.
    _loss, acc = evaluate_model(model, test_loader, criterion)
    epoch = int(model_file.stem.split('_')[-1])
    results.append((epoch, acc))

# The loop above overwrote the model's weights with each checkpoint in turn.
# Reload the best weights so `model` does not silently keep whichever
# checkpoint happened to be evaluated last.
model.load_state_dict(torch.load(best_model_path, map_location=device))

# Print the results table, ordered by epoch number.
table = PrettyTable()
table.field_names = ["Epoch", "Test Accuracy"]
for epoch, acc in sorted(results):
    table.add_row([epoch, f"{acc:.4f}"])
print(table)