# 1 数据集加载与预处理

In [None]:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# Preprocessing: convert each image into a tensor (torchvision ToTensor).
trans = transforms.ToTensor()

# Directory holding the FashionMNIST files, relative to the working directory.
data_root = "../../data"

# Load both splits. download=True does not re-download: torchvision first
# checks whether the files already exist under data_root and only fetches
# them when they are missing. With download=False, a fresh checkout that
# lacks the files fails with a "Dataset not found" error instead.
mnist_train = torchvision.datasets.FashionMNIST(
    root=data_root, train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root=data_root, train=False, transform=trans, download=True)

# Data loaders: reshuffle the training set every epoch; keep test order fixed.
batch_size = 64
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [None]:
# Number of samples in each split, one line per split.
print(f'Training dataset size: {len(mnist_train)}',
      f'Test dataset size: {len(mnist_test)}',
      sep='\n')

# Peek at a single batch from the training loader: report its shapes and stop.
for X, y in train_loader:
    print(f'Batch of data shape: {X.shape}',
          f'Batch of labels shape: {y.shape}',
          sep='\n')
    break

# Class names exposed by the training dataset.
classes = mnist_train.classes
print(f'Classes: {classes}')

# 2 定义模型

In [None]:
import torch
from torch import nn

class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear -> ReLU -> Linear.

    The defaults (784 -> 256 -> 10) match the original hard-coded sizes,
    e.g. a flattened 28x28 single-channel image mapped to 10 classes.

    Args:
        in_features: number of features per sample after flattening.
        hidden_features: width of the hidden layer.
        num_outputs: number of output logits.
    """

    def __init__(self, in_features=784, hidden_features=256, num_outputs=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_outputs)

    def forward(self, x):
        """Return unnormalized logits of shape (batch, num_outputs).

        Every dimension after the first (batch) dimension is flattened, so
        the flattened size of each sample must equal ``in_features``.
        """
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        # No softmax here: the raw logits are returned as-is.
        return self.fc2(x)


# Instantiate the model with the default (784 -> 256 -> 10) sizes.
model = MLP()

# 3 定义损失函数和优化器

In [None]:
# Optimizer: mini-batch SGD with learning rate 0.1 over every model parameter.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

# Loss: cross-entropy, computed from the model outputs and integer labels.
criterion = nn.CrossEntropyLoss()

# 4 训练循环

In [None]:
num_epochs = 10


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass of mini-batch training over ``loader``.

    Returns the sample-weighted mean training loss for the epoch, or 0.0 if
    the loader yields no samples.
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    num_samples = 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()          # clear gradients from the previous step
        outputs = model(X)             # forward pass
        loss = criterion(outputs, y)   # compute loss
        loss.backward()                # backward pass
        optimizer.step()               # update parameters
        # Assumes criterion uses mean reduction (the nn.CrossEntropyLoss
        # default); weighting by batch size then gives an exact per-sample
        # average even when the last batch is smaller.
        running_loss += loss.item() * X.size(0)
        num_samples += X.size(0)
    return running_loss / num_samples if num_samples else 0.0


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` (0.0 if empty)."""
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            predicted = model(X).argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total if total else 0.0


# Send each batch to whatever device the model's parameters live on, so the
# loop keeps working if the model is moved to a GPU.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # Print the training loss for this epoch
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')

    # Evaluate the model on the test set after each epoch
    accuracy = _evaluate(model, test_loader, device)
    print(f'Epoch {epoch + 1}, Accuracy: {accuracy:.4f}')

print("Training Complete!")