In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [5]:
# 获取数据集
train_dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor(),download=True)

In [6]:
batch_size = 64

In [7]:
# 设置数据加载器
train_data_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_data_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

In [8]:
# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(784,10)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        y = self.fc1(x)
        prob = self.softmax(y)
        return prob

In [9]:
# 实例化模型
model = Net()
# 定义损失函数和优化器
optimizer = optim.SGD(model.parameters(), lr=0.5)
ce_loss = nn.CrossEntropyLoss()

In [10]:
# 查看模型参数
for name,para in model.named_parameters():
    print(name,para.shape)

fc1.weight torch.Size([10, 784])
fc1.bias torch.Size([10])


In [11]:
# 完成所有批次的轮询，就代表一次训练的epoch结束
def train():
    for _,data in enumerate(train_data_loader):
        data,label = data
        train = data.view(data.shape[0],-1)
        label = F.one_hot(label,num_classes=10).float()
        pred = model(train)      
        loss = ce_loss(label,pred)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [12]:
# 测试集数据准确率
def test():
    correct = 0
    total = 0
    for _,data in enumerate(test_data_loader):
        data,label = data
        data = data.view(data.shape[0],-1)
        y_pred = model(data)
        y_pred_label = torch.argmax(y_pred,dim=1)
        correct += torch.eq(y_pred_label,label).sum().item()
        total += data.shape[0]
    return correct/total

In [14]:
for epoch in range(20):
    train()
    correct_prob = test()
    print('Epoch: {}, Correct Prob: {:.4f}'.format(epoch+1, correct_prob))

Epoch: 1, Correct Prob: 0.9217
Epoch: 2, Correct Prob: 0.9220
Epoch: 3, Correct Prob: 0.9227
Epoch: 4, Correct Prob: 0.9244
Epoch: 5, Correct Prob: 0.9240
Epoch: 6, Correct Prob: 0.9251
Epoch: 7, Correct Prob: 0.9250
Epoch: 8, Correct Prob: 0.9255
Epoch: 9, Correct Prob: 0.9260
Epoch: 10, Correct Prob: 0.9262
Epoch: 11, Correct Prob: 0.9252
Epoch: 12, Correct Prob: 0.9269
Epoch: 13, Correct Prob: 0.9251
Epoch: 14, Correct Prob: 0.9259
Epoch: 15, Correct Prob: 0.9268
Epoch: 16, Correct Prob: 0.9259
Epoch: 17, Correct Prob: 0.9286
Epoch: 18, Correct Prob: 0.9285
Epoch: 19, Correct Prob: 0.9275
Epoch: 20, Correct Prob: 0.9280
