In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score

# Hyperparameters
seed = 42
batch_size = 64
num_epochs = 5
learning_rate = 0.001

# Seed PyTorch's global RNG before anything stochastic runs, so that weight
# initialisation and the training loader's shuffle order are repeatable when
# the notebook is re-run from a fresh kernel.
torch.manual_seed(seed)

# Load the MNIST train/test splits (downloaded into ./data when missing);
# ToTensor converts each image into a tensor.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          download=True, transform=transforms.ToTensor())

# Shuffle only the training data; the test set is iterated in a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Define the model
class NeuralNet(nn.Module):
    """Fully connected classifier with a single hidden layer.

    The forward pass applies ``fc1`` (input_size -> hidden_size), a ReLU
    non-linearity, and ``fc2`` (hidden_size -> num_classes). The output of
    ``fc2`` is returned as-is; no softmax is applied inside the module.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Layers are created in the same order they are applied in forward().
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the ``fc2`` outputs for the input ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Define the loss function (cross-entropy over the class scores).
criterion = nn.CrossEntropyLoss()

# Build the network: 28*28 input features, 512 hidden units, 10 output classes.
model = NeuralNet(input_size=28*28, hidden_size=512, num_classes=10)

# Adam optimizer over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)  # mini-batches per epoch, shown in the progress log
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Clear the gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass: reshape the batch to rows of 28*28 values, as the model expects.
        flat_images = images.reshape(-1, 28*28)
        outputs = model(flat_images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Only report the current mini-batch loss every 100 steps.
        if (i+1) % 100 != 0:
            continue
        print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

# Evaluate the model
# Put the model into evaluation mode. With only Linear/ReLU layers this does
# not change the outputs, but it keeps this cell correct if dropout or
# batch-norm layers are added later. Call model.train() before training again.
model.eval()
with torch.no_grad():  # gradients are not needed for inference
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images.reshape(-1, 28*28))
        # Index of the highest score in each row is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Report the number of images actually evaluated instead of a hard-coded count.
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total} %')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: ./data\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ./data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

100%|█████████████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<?, ?B/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch [1/5], Step [100/938], Loss: 0.3161
Epoch [1/5], Step [200/938], Loss: 0.2051
Epoch [1/5], Step [300/938], Loss: 0.3086
Epoch [1/5], Step [400/938], Loss: 0.1177
Epoch [1/5], Step [500/938], Loss: 0.1565
Epoch [1/5], Step [600/938], Loss: 0.1312
Epoch [1/5], Step [700/938], Loss: 0.1938
Epoch [1/5], Step [800/938], Loss: 0.0932
Epoch [1/5], Step [900/938], Loss: 0.1100
Epoch [2/5], Step [100/938], Loss: 0.2632
Epoch [2/5], Step [200/938], Loss: 0.0603
Epoch [2/5], Step [300/938], Loss: 0.1277
Epoch [2/5], Step [400/938], Loss: 0.1430
Epoch [2/5], Step [500/938], Loss: 0.0310
Epoch [2/5], Step [600/938], Loss: 0.0200
Epoch [2/5], Step [700/938], Loss: 0.0717
Epoch [2/5], Step [800/938], Loss: 0.0981
Epoch [2/5], Step [900/938], Loss: 0.1119
Epoch [3/5], Step [100/938], Loss: 0.1509
Epoch [3/5], Step [200/938], Loss: 0.0559
Epoch [3/5], Step [300/938], Loss: 0.0563
Epoch [3/5], Step [400/938], Loss: 0.1342
E