# In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# ---------构建网络----------
# input: datasets
#
# return: predicted (tensor   batch size * 10)
#
# 1. conv + relu + max pooling
# 2. conv + relu + max pooling
# 3. FC + FC + log softmax
# ---------------------------
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1 , 20 , 5)
        self.conv2 = nn.Conv2d(20 , 50 , 5)
        self.FC1 = nn.Linear(50 * 4 * 4, 500)
        self.FC2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.shape[0], 4 * 4 * 50)
        x = F.relu(self.FC1(x))
        x = self.FC2(x)
        return F.log_softmax(x, dim=1)

# Download the MNIST dataset
# mnist_data = datasets.MNIST('./mnist_data', train=True, transform=transforms.Compose(
#     [transforms.ToTensor(), ]
# ))

#定义训练过程
def train(model, device, train_data, optimizer, epoch):
    for idx, (data, target) in enumerate(train_data):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.nll_loss(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 100 == 0:
            print(f'Epoch: {epoch} , Iteration: {idx}, Loss: {loss.item()}')

# 定义测试过程
def test(model, device, test_data):
    total_loss = 0.
    correct = 0.
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_data):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction='mean').item()
            pred = output.argmax(dim = 1)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    total_loss /= len(test_data.dataset)
    acc = correct / len(test_data.dataset) 
    print(f'Test Loss: {total_loss} Accuracy: {acc:.2%}')

# 选择device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义训练集
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train = True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size = batch_size, 
    shuffle = True, 
    num_workers = 0, 
    pin_memory = True
    )

# 定义测试集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train = False, 
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)
        )])), 
    batch_size = batch_size, 
    shuffle = True, 
    num_workers = 0, 
    pin_memory = True
    )

# 主要参数
lr = 0.01
momentum = 0.5
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = momentum)
num_epoch = 2


if __name__ == "__main__":
    for epoch in range(num_epoch):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

# save
torch.save(model.state_dict(), 'mnist_cnn.pt')


# RuntimeError: Dataset not found. You can use download=True to download it