In [1]:
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Make the project root (parent of the notebook's working directory) importable
# so the local packages below resolve. The membership check keeps this cell
# idempotent: re-running it no longer piles duplicate entries onto sys.path.
# Inserting at the front keeps generic local names (`utils`, `data`, `models`)
# from being shadowed by same-named installed packages.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.insert(0, project_root)

print("Project root:", project_root)

from utils.seed import set_seed
from models import MLP, SimpleCNN, ResNet18, ViTTiny
from trainer import Trainer
from data import get_emnist_dataloaders

# Seed once, before any data loading or model construction, for reproducibility.
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Project root: f:\.proj\AI\MU\D2L\emnist_project
Using device: cuda


In [None]:
# Data-loading configuration.
DATA_DIR = "../data"
# The recorded output of the next cell (1763 train / 294 val batches) is
# consistent with batch size 64 on EMNIST "balanced" (112,800 train images;
# presumably the 18,800-image test split for validation). batch_size=1 cannot
# reproduce those outputs and would make each ViT epoch ~112k optimizer steps.
BATCH_SIZE = 64

# `num_classes` is taken from the loader helper rather than hard-coded
# (EMNIST "balanced" has 47 classes, not 10).
train_loader, val_loader, num_classes = get_emnist_dataloaders(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    split="balanced",
)

In [3]:
# get_emnist_dataloaders presumably returns torch DataLoaders, whose len() is
# the number of batches (not examples); the labels below say so explicitly.
# The second loader is the validation loader, so it is labelled as such.
print(f'训练批次数:{len(train_loader)}')
print(f'验证批次数:{len(val_loader)}')

训练例数量:1763
测试例数量:294


In [None]:
# Initialize the model.
model = ViTTiny(num_classes=num_classes).to(device)

# Optimizer & loss function.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# TensorBoard writer. It is closed in the `finally` below so buffered events
# reach disk even if the run is interrupted (the recorded log stops at epoch
# 7/20). SummaryWriter recreates its file writer on the next add_* call, so a
# later cell can still log through `writer` after it has been closed.
writer = SummaryWriter("../logs/ViTTiny_exp1")

# Trainer wires model, optimizer, loss and logging together.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    writer=writer,
    num_classes=num_classes,
    model_name="ViTTiny"
)

# Training loop.
epochs = 20

try:
    for epoch in range(epochs):
        train_loss, train_acc = trainer.train_one_epoch(train_loader, epoch)
        # NOTE(review): the second value of validate() is reported as AUC;
        # confirm against Trainer.validate.
        val_loss, val_auc = trainer.validate(val_loader, epoch)

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {train_loss:.4f}, "
            f"Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, "
            f"Val AUC: {val_auc:.4f}"
        )
finally:
    # Flush and close TensorBoard event files even on KeyboardInterrupt/error.
    writer.close()



Epoch [1/20] Train Loss: 1.3374, Train Acc: 0.6109, Val Loss: 0.7531, Val AUC: 0.9915
Epoch [2/20] Train Loss: 0.7318, Train Acc: 0.7614, Val Loss: 0.5894, Val AUC: 0.9940
Epoch [3/20] Train Loss: 0.6258, Train Acc: 0.7923, Val Loss: 0.5448, Val AUC: 0.9946
Epoch [4/20] Train Loss: 0.5754, Train Acc: 0.8061, Val Loss: 0.5098, Val AUC: 0.9954
Epoch [5/20] Train Loss: 0.5344, Train Acc: 0.8161, Val Loss: 0.4811, Val AUC: 0.9957
Epoch [6/20] Train Loss: 0.5126, Train Acc: 0.8235, Val Loss: 0.4733, Val AUC: 0.9958
Epoch [7/20] Train Loss: 0.4922, Train Acc: 0.8292, Val Loss: 0.4488, Val AUC: 0.9961
