In [11]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 设置设备为MPS
device = torch.device("mps")

# 定义网络并移动到MPS设备
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.Tanh(),
                    nn.Linear(256, 10))
net = net.to(device)

# 初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

# 定义超参数
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 加载数据集并将数据移动到MPS设备
def load_data_fashion_mnist(batch_size):
    trans = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
    mnist_test = datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)
    train_iter = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_iter, test_iter

train_iter, test_iter = load_data_fashion_mnist(batch_size)

# 训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer, device):
    net.to(device)
    print("training on", device)
    for epoch in range(num_epochs):
        metric = [0.0, 0.0, 0.0]  # 初始化为0的累加器
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            trainer.zero_grad()
            l.backward()
            trainer.step()
            metric[0] += l.item()
            metric[1] += (y_hat.argmax(dim=1) == y).sum().item()
            metric[2] += y.size(0)
        test_acc = evaluate_accuracy(net, test_iter, device)
        print(f'epoch {epoch + 1}, loss {metric[0] / metric[2]:.4f}, '
              f'train acc {metric[1] / metric[2]:.4f}, test acc {test_acc:.4f}')

# 评估函数
def evaluate_accuracy(net, data_iter, device):
    net.eval()
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.size(0)
    net.train()
    return acc_sum / n

# 调用训练函数
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer, device)

training on mps
epoch 1, loss 2460.7307, train acc 0.1173, test acc 0.1000
epoch 2, loss 1951.3880, train acc 0.1479, test acc 0.1000
epoch 3, loss 2142.0016, train acc 0.1450, test acc 0.1966
epoch 4, loss 2157.5633, train acc 0.1476, test acc 0.1948
epoch 5, loss 2024.4969, train acc 0.1527, test acc 0.1000
epoch 6, loss 1661.8179, train acc 0.1800, test acc 0.1956
epoch 7, loss 1812.3240, train acc 0.1715, test acc 0.1939
epoch 8, loss 1403.7622, train acc 0.2090, test acc 0.1782
epoch 9, loss 1333.7406, train acc 0.2210, test acc 0.2750
epoch 10, loss 1443.4080, train acc 0.1985, test acc 0.1994
