# 0.
1. 多分类实战：https://www.bilibili.com/video/BV18g4119737?p=50&vd_source=70200f7d09862fd682e5f89b22c89125
2. 用的基本操作，不是nn.linear

In [19]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

# 1. 读数据

In [20]:
def load_data(batch_size):
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=batch_size, shuffle=True)

    return train_loader, test_loader


1. 注意**初始化**非常重要，这里如果不初始化，loss会在几个epoch之后就停止更新了
2. 老师说尤其是你在设计一个新方法的时候，这个初始化很重要

# 2. The net

In [29]:
w1, b1 = torch.randn(200, 784, requires_grad=True), \
         torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \
         torch.zeros(10, requires_grad=True)

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

tensor([[-0.1379,  0.0907, -0.2397,  ..., -0.1635,  0.0773, -0.0396],
        [-0.0032, -0.1304,  0.0804,  ...,  0.0829, -0.0414, -0.2695],
        [ 0.1056, -0.1060,  0.1014,  ..., -0.0688, -0.1250, -0.2325],
        ...,
        [-0.0144,  0.1861, -0.0685,  ...,  0.2042, -0.1325, -0.1233],
        [ 0.0839,  0.0826, -0.0567,  ...,  0.1165,  0.2900, -0.1748],
        [ 0.0684, -0.0050, -0.1290,  ...,  0.0289, -0.0427, -0.1003]],
       requires_grad=True)

2. 一般没有经过relu或者softmax的东西叫做“logits”

In [30]:
def forward(x):
    x=x@w1.t()+b1
    x=F.relu(x)
    x=x@w2.t()+b2
    x=F.relu(x)
    x=x@w3.t()+b3
    x = F.relu(x)
    return x

In [31]:

optimizer = torch.optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)

criterion = nn.CrossEntropyLoss()

# 3. Training


In [32]:
train_loader, test_loader = load_data(batch_size)
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1,28*28)

        logits = forward(data)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.noem())
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criterion(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.0018, Accuracy: 8947/10000 (89%)


Test set: Average loss: 0.0014, Accuracy: 9171/10000 (92%)


Test set: Average loss: 0.0012, Accuracy: 9289/10000 (93%)


Test set: Average loss: 0.0011, Accuracy: 9364/10000 (94%)


Test set: Average loss: 0.0010, Accuracy: 9399/10000 (94%)


Test set: Average loss: 0.0009, Accuracy: 9453/10000 (95%)


Test set: Average loss: 0.0009, Accuracy: 9473/10000 (95%)


Test set: Average loss: 0.0008, Accuracy: 9511/10000 (95%)


Test set: Average loss: 0.0008, Accuracy: 9532/10000 (95%)


Test set: Average loss: 0.0008, Accuracy: 9559/10000 (96%)



# 4. Testing (代码见上)
1. 不能一直training下去，必须时不时的在test（valid）数据集上进行测试
2. 但是test会消耗时间，
    - 不能train完一个batch就做一个test，而且一个test也不会只测试一个batch的数据
    - 每次epoch做一次test也可以选择