In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [15]:
# 获取数据集
train_dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor(),download=True)

In [16]:
batch_size = 64

In [17]:
# 设置数据加载器
train_data_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_data_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

In [18]:
# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(784,512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(256,10),
        )
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        y = self.layers(x)
        prob = self.softmax(y)
        return prob

In [19]:
# 实例化模型
model = Net()
# 定义损失函数和优化器
optimizer = optim.SGD(model.parameters(), lr=0.5)
ce_loss = nn.CrossEntropyLoss()

In [20]:
# 查看模型参数
for name,para in model.named_parameters():
    print(name,para.shape)

layers.0.weight torch.Size([512, 784])
layers.0.bias torch.Size([512])
layers.3.weight torch.Size([256, 512])
layers.3.bias torch.Size([256])
layers.6.weight torch.Size([10, 256])
layers.6.bias torch.Size([10])


In [21]:
# 完成所有批次的轮询，就代表一次训练的epoch结束
def train():
    model.train()
    for _,data in enumerate(train_data_loader):
        data,label = data
        train = data.view(data.shape[0],-1)
        label = F.one_hot(label,num_classes=10).float()
        pred = model(train)      
        loss = ce_loss(label,pred)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [22]:
# 测试集数据准确率
def test():
    model.eval()
    correct = 0
    total = 0
    for _,data in enumerate(test_data_loader):
        data,label = data
        data = data.view(data.shape[0],-1)
        y_pred = model(data)
        y_pred_label = torch.argmax(y_pred,dim=1)
        correct += torch.eq(y_pred_label,label).sum().item()
        total += data.shape[0]
    print('Test Accuracy: {:.2f}%'.format(100*correct/total))
    
    correct = 0
    total = 0
    for _,data in enumerate(train_data_loader):
        data,label = data
        data = data.view(data.shape[0],-1)
        y_pred = model(data)
        y_pred_label = torch.argmax(y_pred,dim=1)
        correct += torch.eq(y_pred_label,label).sum().item()
        total += data.shape[0]
    print('Train Accuracy: {:.2f}%'.format(100*correct/total))


In [26]:
for epoch in range(20):
    print("Epoch:{:2d}".format(epoch+1))
    train()
    correct_prob = test()
    print("="*20)

Epoch: 1
Test Accuracy: 97.03%
Train Accuracy: 97.37%
Epoch: 2
Test Accuracy: 97.30%
Train Accuracy: 97.72%
Epoch: 3
Test Accuracy: 97.15%
Train Accuracy: 97.78%
Epoch: 4
Test Accuracy: 97.37%
Train Accuracy: 97.86%
Epoch: 5
Test Accuracy: 97.41%
Train Accuracy: 97.92%
Epoch: 6
Test Accuracy: 97.36%
Train Accuracy: 98.07%
Epoch: 7
Test Accuracy: 97.62%
Train Accuracy: 98.22%
Epoch: 8
Test Accuracy: 97.56%
Train Accuracy: 98.22%
Epoch: 9
Test Accuracy: 97.83%
Train Accuracy: 98.38%
Epoch:10
Test Accuracy: 97.66%
Train Accuracy: 98.28%
Epoch:11
Test Accuracy: 97.69%
Train Accuracy: 98.36%
Epoch:12
Test Accuracy: 97.81%
Train Accuracy: 98.33%
Epoch:13
Test Accuracy: 97.62%
Train Accuracy: 98.33%
Epoch:14
Test Accuracy: 97.78%
Train Accuracy: 98.49%
Epoch:15
Test Accuracy: 97.72%
Train Accuracy: 98.44%
Epoch:16
Test Accuracy: 97.85%
Train Accuracy: 98.50%
Epoch:17
Test Accuracy: 97.52%
Train Accuracy: 98.34%
Epoch:18
Test Accuracy: 97.66%
Train Accuracy: 98.46%
Epoch:19
Test Accuracy: 97.8