In [None]:
import torch
import pandas as pd
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Read the train / test splits from CSV.
fashion_mnist_train = pd.read_csv("data/fashion-mnist_train.csv")
fashion_mnist_test = pd.read_csv("data/fashion-mnist_test.csv")


def _frame_to_tensors(frame):
    """Convert one CSV split into an ``(images, labels)`` tensor pair.

    Column 0 is read as the class label (int64); every remaining column is
    read as a pixel and reshaped to ``(N, 1, 28, 28)`` float32 images.

    Pixels are divided by 255 so the network sees inputs in [0, 1].
    NOTE(review): this assumes the CSV stores raw 0-255 intensities, as the
    standard Fashion-MNIST CSV export does -- confirm against the data file.
    Unscaled 0-255 inputs drive the sigmoid activations into saturation.
    """
    pixels = torch.tensor(frame.iloc[:, 1:].values, dtype=torch.float32)
    images = pixels.reshape(-1, 1, 28, 28) / 255.0
    labels = torch.tensor(frame.iloc[:, 0].values, dtype=torch.int64)
    return images, labels


# Convert both splits to tensors.
X_train, y_train = _frame_to_tensors(fashion_mnist_train)
X_test, y_test = _frame_to_tensors(fashion_mnist_test)
# Wrap the tensors as datasets.
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

In [None]:
# Build the model: a LeNet-style CNN taking 1x28x28 inputs and emitting 10 logits.
net = nn.Sequential(
    # Convolutional stage 1: padding=2 keeps the 28x28 spatial size, pooling halves it.
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(2, 2),
    # Convolutional stage 2: unpadded 5x5 convolution, then another 2x2 pooling.
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(2, 2),
    # Classifier head: flatten the 16 feature maps of 5x5 and map down to 10 outputs.
    nn.Flatten(),
    nn.Linear(in_features=16 * 5 * 5, out_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [None]:
# Model training
def train(net, train_dataset, test_dataset, lr, epoch_num, batch_size, device):
    """Train ``net`` with SGD and cross-entropy loss, printing metrics per epoch.

    Args:
        net: Module to train. Layers whose type is exactly ``nn.Linear`` or
            ``nn.Conv2d`` get their weights re-initialised with Xavier
            uniform, then the module is moved to ``device``.
        train_dataset: Dataset yielding ``(X, y)`` pairs used for optimisation.
        test_dataset: Dataset yielding ``(X, y)`` pairs, evaluated after every
            epoch. May be ``None`` to skip evaluation.
        lr: Learning rate for ``torch.optim.SGD``.
        epoch_num: Number of passes over ``train_dataset``.
        batch_size: Mini-batch size for both data loaders.
        device: Device (or device string) the model and batches are moved to.

    Returns:
        None. One line is printed per epoch with the sample-weighted mean
        training loss and, when ``test_dataset`` is given, the test accuracy.
        When evaluation runs, ``net`` is left in eval mode on return.
    """
    def init_weights(layer):
        # Exact type match on purpose: subclasses are left with their own init.
        if type(layer) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(layer.weight)

    net.apply(init_weights)
    net.to(device)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)

    # The loaders do not change between epochs, so build them once; the
    # shuffling loader still draws a new permutation each time it is iterated.
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = None if test_dataset is None else DataLoader(test_dataset, batch_size)

    for epoch in range(epoch_num):
        net.train()
        # Accumulate on the device so the host only synchronises once per epoch.
        loss_sum = torch.zeros((), device=device)
        sample_num = 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            # Clear gradients before backward so stale gradients left on the
            # parameters (e.g. from an interrupted run) never leak into a step.
            optimizer.zero_grad()
            loss_value = loss(net(X), y)
            loss_value.backward()
            optimizer.step()
            # The criterion returns a batch mean; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            loss_sum += loss_value.detach() * X.shape[0]
            sample_num += X.shape[0]
        # An empty training set yields NaN instead of an unbound-variable error.
        mean_loss = loss_sum.item() / sample_num if sample_num else float("nan")
        message = f"epoch:{epoch} loss:{mean_loss}"

        if test_loader is not None:
            net.eval()
            correct_num = torch.zeros((), dtype=torch.int64, device=device)
            test_num = 0
            with torch.no_grad():
                for X, y in test_loader:
                    X, y = X.to(device), y.to(device)
                    correct_num += (net(X).argmax(dim=1) == y).sum()
                    test_num += y.shape[0]
            test_acc = correct_num.item() / test_num if test_num else float("nan")
            message += f" test_acc:{test_acc}"

        print(message)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass the detected device rather than a hard-coded "cuda" string, so the
# notebook also runs on machines without a GPU.
train(net, train_dataset, test_dataset, lr=0.9, epoch_num=10, batch_size=128, device=device)