In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms

In [2]:
batch_size=200
learning_rate=0.01
epochs=10

# 加载数据
train_loader=torch.utils.data.DataLoader(
    datasets.MNIST('./data/mnist_data',train=True,download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)
test_loader=torch.utils.data.DataLoader(
    datasets.MNIST('./data/mnist_data',train=False,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)

# 设置 w,b
w1,b1=torch.randn(200,784,requires_grad=True),torch.zeros(200,requires_grad=True)
w2,b2=torch.randn(200,200,requires_grad=True),torch.zeros(200,requires_grad=True)
w3,b3=torch.randn(10,200,requires_grad=True),torch.zeros(10,requires_grad=True)

# 凯明初始化
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

tensor([[-0.0423,  0.0632, -0.0175,  ...,  0.0806, -0.0335,  0.1036],
        [-0.0469,  0.0462, -0.0763,  ...,  0.0747,  0.0031,  0.0391],
        [-0.0893,  0.0443, -0.0561,  ..., -0.0358, -0.0032,  0.0118],
        ...,
        [ 0.0126, -0.0641,  0.0088,  ...,  0.0512,  0.0392, -0.1226],
        [ 0.0011, -0.0091,  0.0948,  ..., -0.0441, -0.0434, -0.0412],
        [ 0.0497, -0.0045,  0.0017,  ...,  0.0639, -0.0478, -0.0056]],
       requires_grad=True)

tensor([[-0.1636, -0.0047, -0.0737,  ..., -0.0167,  0.0511, -0.0890],
        [ 0.0469,  0.2278,  0.0908,  ..., -0.0747, -0.1335, -0.0593],
        [-0.0363, -0.0550, -0.0865,  ..., -0.0886,  0.1190,  0.1841],
        ...,
        [ 0.0617, -0.0483, -0.0846,  ...,  0.0553, -0.1021,  0.0264],
        [ 0.0164,  0.0867,  0.1176,  ...,  0.0875, -0.1059, -0.0507],
        [-0.1996, -0.0749,  0.0290,  ..., -0.0624,  0.0489,  0.0447]],
       requires_grad=True)

tensor([[ 0.0585,  0.0637,  0.1080,  ..., -0.1817, -0.0669, -0.0616],
        [ 0.1255,  0.0199, -0.1292,  ...,  0.0520,  0.0520, -0.0572],
        [ 0.1271,  0.1009, -0.1160,  ..., -0.0662,  0.0811,  0.0907],
        ...,
        [ 0.0328, -0.0123, -0.0406,  ..., -0.0608,  0.0188,  0.0404],
        [ 0.1325, -0.0342,  0.1442,  ..., -0.0628,  0.0576, -0.2505],
        [ 0.0906,  0.0835,  0.1371,  ..., -0.0644,  0.0615, -0.0891]],
       requires_grad=True)

In [3]:
# 前向计算
def forward(x):
    x=x@w1.t()+b1
    x=F.relu(x)
    x=x@w2.t()+b2
    x=F.relu(x)
    x=x@w3.t()+b3
    x=F.relu(x)
    return x

In [4]:
optimizer=optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)
criteon=nn.CrossEntropyLoss()

for epoch in range(epochs):
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)

        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()
        
        if batch_idx % 100 ==0: 
            print("Train epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}".format(
                epoch,batch_idx * len(data),len(train_loader.dataset),
                100.*batch_idx/len(train_loader),loss.item()))
        
    test_loss=0 
    correct=0 
    for data,target in test_loader:
        data=data.view(-1,28*28)
        logits=forward(data)
        test_loss+=criteon(logits,target).item()
        
        pred=logits.data.max(1)[1]
        correct+=pred.eq(target.data).sum()
        
    test_loss /= len(test_loader.dataset)
    print('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss,correct,len(test_loader.dataset),
        100.*correct/len(test_loader.dataset)))


 Test set: Average loss: 0.0034, Accuracy: 7761/10000 (78%)


 Test set: Average loss: 0.0015, Accuracy: 9141/10000 (91%)


 Test set: Average loss: 0.0012, Accuracy: 9280/10000 (93%)


 Test set: Average loss: 0.0011, Accuracy: 9362/10000 (94%)


 Test set: Average loss: 0.0010, Accuracy: 9403/10000 (94%)


 Test set: Average loss: 0.0009, Accuracy: 9431/10000 (94%)


 Test set: Average loss: 0.0009, Accuracy: 9473/10000 (95%)


 Test set: Average loss: 0.0008, Accuracy: 9505/10000 (95%)


 Test set: Average loss: 0.0008, Accuracy: 9529/10000 (95%)


 Test set: Average loss: 0.0008, Accuracy: 9542/10000 (95%)

