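### Implementing Neural Networks with nn.Module and nn.functional

`nn.Module` is the base class for a custom network (or part of one). A subclass needs to implement:
- `__init__(self, ...)`: initializes the network and creates its layers
- `forward(self, x)`: the forward pass of the network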

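Commonly used `nn.Module` methods:
- `zero_grad()`: clears the gradients of all layers
- `train()`: switches to training mode
- `eval()`: switches to evaluation (test) mode
- `cuda()`/`cpu()`: moves the whole network (all layers) to the GPU or CPU
- `parameters()`: returns all parameters (may be needed during setup, for example when constructing the optimizer)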
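
The methods listed above can be tried on a tiny network. This is a minimal sketch, assuming a hypothetical four-layer `nn.Sequential` that is not part of the original notebook; it only illustrates what each call does.

```python
import torch
import torch.nn as nn

# A tiny hypothetical network, used only to exercise the methods above
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2))

# parameters(): iterate over all learnable tensors
num_params = sum(p.numel() for p in net.parameters())  # 4*8 + 8 + 8*2 + 2 = 58

# train() / eval(): toggle the `training` flag on every submodule
net.eval()
eval_flag = net.training    # False
net.train()
train_flag = net.training   # True

# zero_grad(): clear the gradients left over from a previous backward pass
net(torch.ones(3, 4)).sum().backward()
had_grad = net[0].weight.grad is not None
net.zero_grad()
grad = net[0].weight.grad
# Depending on the PyTorch version, gradients are reset to None or to zeros
grads_cleared = grad is None or bool((grad == 0).all())

# cuda() / cpu(): .to(device) covers both and is safe on machines without a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
print(num_params, eval_flag, train_flag, grads_cleared)
```

Dropout and BatchNorm are the typical layers whose behavior depends on the `training` flag, which is why `train()` and `eval()` matter.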

Commonly used components:
- `nn.Conv2d`: 2D convolution layer
- `nn.Linear`: fully connected layer
- `nn.MaxPool2d`/`nn.AvgPool2d`: pooling layers
- `nn.ReLU`: ReLU layer
- `nn.Dropout`: dropout layer
- `nn.BatchNorm2d`: batch normalization layer
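
As a rough illustration of how these components fit together, the sketch below chains them in an `nn.Sequential`. The channel counts and the use of `nn.Flatten` are assumptions chosen for the example, not taken from the notebook.

```python
import torch
import torch.nn as nn

# Hypothetical block combining the components listed above
block = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 1, 28, 28) -> (N, 8, 28, 28)
    nn.BatchNorm2d(8),                          # normalizes each of the 8 channels
    nn.ReLU(),
    nn.MaxPool2d(2),                            # -> (N, 8, 14, 14)
    nn.Flatten(),                               # -> (N, 8 * 14 * 14)
    nn.Dropout(p=0.5),
    nn.Linear(8 * 14 * 14, 10),                 # -> (N, 10)
)

x = torch.randn(4, 1, 28, 28)  # a fake batch of four MNIST-sized images
y = block(x)
print(y.shape)
```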

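How `nn.Module` and `nn.functional` relate:
Layers built on `nn.Module` wrap the functions in `nn.functional`. They encapsulate the definition of the weights, which makes them more convenient to use.
For layers without parameters, calling the function in `nn.functional` directly may be more convenient.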
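
The relationship can be checked directly. In this small sketch (the tensor sizes are arbitrary assumptions), `nn.Linear` stores its own weight and bias, while `F.linear` expects them as arguments; for a parameter-free operation such as ReLU the two forms are interchangeable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(5, 3)

# Layer with parameters: nn.Linear owns its weight and bias ...
fc = nn.Linear(3, 2)
out_module = fc(x)
# ... while F.linear needs them passed in explicitly
out_functional = F.linear(x, fc.weight, fc.bias)
same_linear = torch.allclose(out_module, out_functional)

# Layer without parameters: nn.ReLU() and F.relu compute the same values
same_relu = torch.equal(nn.ReLU()(x), F.relu(x))
print(same_linear, same_relu)
```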

Training and testing, using MNIST as an example:
- Load the data
- Initialize the model and the optimizer
- Train and test


In [1]:

# Implement an MLP with nn.Module and nn.functional
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(100, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        # softmax over the class dimension (dim=1); dim=0 would normalize across the batch
        out = F.softmax(self.fc2(x), dim=1)
        return out

In [2]:

net = MLPNet()

data = torch.ones(50, 100)
out = net(data)

print(out.size())

torch.Size([50, 10])
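
For a batch of shape `(batch, classes)`, softmax is normally taken over `dim=1`, so that each sample's class probabilities sum to 1. The sketch below uses random logits (an assumption standing in for the network output) to show how the choice of `dim` changes what gets normalized.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(50, 10)            # (batch, classes), same shape as the MLP output

per_sample = F.softmax(logits, dim=1)   # normalizes across the 10 classes
per_column = F.softmax(logits, dim=0)   # normalizes across the 50 samples

row_sums = per_sample.sum(dim=1)        # 50 values, each close to 1
col_sums = per_column.sum(dim=0)        # 10 values, each close to 1
print(row_sums[:3], col_sums[:3])
```

With `dim=0` the rows no longer form probability distributions over classes, even though the output shape is identical.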


In [3]:

# CNN for MNIST: two conv + max-pool stages, then two fully connected layers
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(3136, 512)  # 3136 = 64 channels * 7 * 7 after two 2x2 poolings
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

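    def encoder(self, x):
        x = x.view((x.shape[0], 28, 28))  # drop the channel axis: (N, 28, 28)
        x = x.unsqueeze(1)                # add it back: (N, 1, 28, 28)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # (N, 32, 14, 14)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # (N, 64, 7, 7)
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])  # flatten to (N, 3136)
        x = F.relu(self.fc1(x))
        return x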

    def decoder(self, x):
        x = self.fc2(x)
        return x
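
The input size of `fc1` (3136) follows from the feature-map shapes. The sketch below rebuilds only the two convolution layers with the same hyperparameters and traces a dummy batch through them; the batch size of 2 is an arbitrary assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same hyperparameters as conv1 / conv2 in the model
conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)

x = torch.zeros(2, 1, 28, 28)             # dummy batch of MNIST-sized images
h1 = F.max_pool2d(F.relu(conv1(x)), 2)    # padding=2 keeps 28x28, pooling halves it to 14x14
h2 = F.max_pool2d(F.relu(conv2(h1)), 2)   # 14x14 -> 7x7
flat = h2.flatten(start_dim=1)            # 64 * 7 * 7 = 3136 features per image
print(h1.shape, h2.shape, flat.shape)
```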

In [4]:

from torchvision import datasets, transforms
import torch.optim as optim
import tqdm

def get_MNIST_dataloader(batch_size=64):
    # Convert images to tensors and normalize with the MNIST mean and std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [5]:

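def train_on_MNIST(init_lr=1e-4, num_epochs=10, device=None):
    # Use the GPU when available, otherwise fall back to the CPU
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_MNIST_dataloader()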

    # prepare model and optimizer
    model = Model().to(device)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # StepLR multiplies the learning rate by gamma (default 0.1) every step_size epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # training loop
    for epoch in range(1, 1+num_epochs):
        train_one_epoch(model, device, train_loader, optimizer, epoch)
        test_one_epoch(model, device, test_loader)
        scheduler.step()

    # saving
    torch.save(model.state_dict(), 'mnist_cnn.pth')
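
Because `StepLR` is created with `step_size=1` and the default `gamma=0.1`, the learning rate shrinks tenfold after every epoch. The sketch below reproduces that schedule with a placeholder one-parameter model (an assumption for illustration only); it may explain why the test loss in the log further down stops changing after a few epochs.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)  # placeholder model, only needed to build the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)  # gamma defaults to 0.1

lrs = []
for epoch in range(4):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # stands in for one epoch of training
    scheduler.step()   # decay the learning rate once per epoch
print(lrs)             # roughly 1e-4, 1e-5, 1e-6, 1e-7
```

Passing a larger `gamma` (for example 0.7) or a larger `step_size` would slow the decay down.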

In [6]:

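from tqdm.notebook import tqdm as notebook_tqdm  # replaces the deprecated tqdm.tqdm_notebook

def train_one_epoch(model, device, train_loader, optimizer, epoch, log_interval=100):
    model.train()
    # wrap the loader (not enumerate) so the progress bar knows the total length
    for batch_idx, (data, target) in enumerate(notebook_tqdm(train_loader)):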
        data, target = data.to(device), target.to(device)

        output = model(data)
        # pass raw logits: F.cross_entropy applies log_softmax internally, so the model must not apply softmax
        loss = F.cross_entropy(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f'Train epoch: {epoch} [Iter: {batch_idx*len(data)}/{len(train_loader.dataset)}]' + \
                  f'\t Loss: {loss.item():.6f}')
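
The reason the model returns raw logits is that `F.cross_entropy` already combines `log_softmax` with the negative log-likelihood loss. A minimal check on random logits (the shapes are arbitrary assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)             # fake network outputs: 8 samples, 10 classes
target = torch.randint(0, 10, (8,))     # fake labels

loss_ce = F.cross_entropy(logits, target)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
losses_match = torch.allclose(loss_ce, loss_manual)
print(losses_match)
```

Applying softmax in the model and then calling `F.cross_entropy` would normalize twice and distort the loss.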

In [7]:

def test_one_epoch(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # index of the largest logit, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('-' * 10)
    print(f'Test: Average loss:{test_loss:.4f}, ' + \
          f'Accuracy: {100. * correct/len(test_loader.dataset):.0f}%')
    print('-' * 10)
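
The bookkeeping in the evaluation loop can be followed on a hand-made batch. The logits and labels below are made up for illustration; the sketch shows how `argmax`, `eq`, and `reduction='sum'` yield the accuracy and the average loss.

```python
import torch
import torch.nn.functional as F

# Four hypothetical samples with three classes each
logits = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 3.0, 0.0],
                       [0.0, 0.0, 1.0],
                       [4.0, 0.0, 0.0]])
target = torch.tensor([0, 1, 2, 1])     # the last sample is misclassified

pred = logits.argmax(dim=1, keepdim=True)              # shape (4, 1)
correct = pred.eq(target.view_as(pred)).sum().item()   # 3 of 4 predictions match
accuracy = correct / len(target)                       # 0.75

# Summing per-batch losses and dividing by the dataset size gives the mean loss
loss_sum = F.cross_entropy(logits, target, reduction='sum').item()
loss_mean = F.cross_entropy(logits, target).item()
print(correct, accuracy, loss_sum / len(target), loss_mean)
```

Summing first and dividing once at the end keeps the average exact even when the last batch is smaller than the others.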

In [8]:

train_on_MNIST()


Train epoch: 1 [Iter: 0/60000]	 Loss: 2.290933
Train epoch: 1 [Iter: 6400/60000]	 Loss: 0.262079
Train epoch: 1 [Iter: 12800/60000]	 Loss: 0.421512
Train epoch: 1 [Iter: 19200/60000]	 Loss: 0.247179
Train epoch: 1 [Iter: 25600/60000]	 Loss: 0.074998
Train epoch: 1 [Iter: 32000/60000]	 Loss: 0.175338
Train epoch: 1 [Iter: 38400/60000]	 Loss: 0.113575
Train epoch: 1 [Iter: 44800/60000]	 Loss: 0.033929
Train epoch: 1 [Iter: 51200/60000]	 Loss: 0.080976
Train epoch: 1 [Iter: 57600/60000]	 Loss: 0.117306
----------
Test: Average loss:0.0740, Accuracy: 98%
----------


Train epoch: 2 [Iter: 0/60000]	 Loss: 0.135707
Train epoch: 2 [Iter: 6400/60000]	 Loss: 0.116474
Train epoch: 2 [Iter: 12800/60000]	 Loss: 0.131173
Train epoch: 2 [Iter: 19200/60000]	 Loss: 0.050226
Train epoch: 2 [Iter: 25600/60000]	 Loss: 0.108192
Train epoch: 2 [Iter: 32000/60000]	 Loss: 0.070618
Train epoch: 2 [Iter: 38400/60000]	 Loss: 0.042846
Train epoch: 2 [Iter: 44800/60000]	 Loss: 0.155448
Train epoch: 2 [Iter: 51200/60000]	 Loss: 0.178205
Train epoch: 2 [Iter: 57600/60000]	 Loss: 0.050348
----------
Test: Average loss:0.0603, Accuracy: 98%
----------


Train epoch: 3 [Iter: 0/60000]	 Loss: 0.036096
Train epoch: 3 [Iter: 6400/60000]	 Loss: 0.049126
Train epoch: 3 [Iter: 12800/60000]	 Loss: 0.052084
Train epoch: 3 [Iter: 19200/60000]	 Loss: 0.060259
Train epoch: 3 [Iter: 25600/60000]	 Loss: 0.033311
Train epoch: 3 [Iter: 32000/60000]	 Loss: 0.027373
Train epoch: 3 [Iter: 38400/60000]	 Loss: 0.028097
Train epoch: 3 [Iter: 44800/60000]	 Loss: 0.091322
Train epoch: 3 [Iter: 51200/60000]	 Loss: 0.060410
Train epoch: 3 [Iter: 57600/60000]	 Loss: 0.058623
----------
Test: Average loss:0.0588, Accuracy: 98%
----------


Train epoch: 4 [Iter: 0/60000]	 Loss: 0.030468
Train epoch: 4 [Iter: 6400/60000]	 Loss: 0.010301
Train epoch: 4 [Iter: 12800/60000]	 Loss: 0.016564
Train epoch: 4 [Iter: 19200/60000]	 Loss: 0.059626
Train epoch: 4 [Iter: 25600/60000]	 Loss: 0.017018
Train epoch: 4 [Iter: 32000/60000]	 Loss: 0.032318
Train epoch: 4 [Iter: 38400/60000]	 Loss: 0.064526
Train epoch: 4 [Iter: 44800/60000]	 Loss: 0.061452
Train epoch: 4 [Iter: 51200/60000]	 Loss: 0.017773
Train epoch: 4 [Iter: 57600/60000]	 Loss: 0.058905
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 5 [Iter: 0/60000]	 Loss: 0.055722
Train epoch: 5 [Iter: 6400/60000]	 Loss: 0.123555
Train epoch: 5 [Iter: 12800/60000]	 Loss: 0.021306
Train epoch: 5 [Iter: 19200/60000]	 Loss: 0.199588
Train epoch: 5 [Iter: 25600/60000]	 Loss: 0.144897
Train epoch: 5 [Iter: 32000/60000]	 Loss: 0.086718
Train epoch: 5 [Iter: 38400/60000]	 Loss: 0.027593
Train epoch: 5 [Iter: 44800/60000]	 Loss: 0.038253
Train epoch: 5 [Iter: 51200/60000]	 Loss: 0.025756
Train epoch: 5 [Iter: 57600/60000]	 Loss: 0.066233
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 6 [Iter: 0/60000]	 Loss: 0.035878
Train epoch: 6 [Iter: 6400/60000]	 Loss: 0.069505
Train epoch: 6 [Iter: 12800/60000]	 Loss: 0.059408
Train epoch: 6 [Iter: 19200/60000]	 Loss: 0.040360
Train epoch: 6 [Iter: 25600/60000]	 Loss: 0.017304
Train epoch: 6 [Iter: 32000/60000]	 Loss: 0.051150
Train epoch: 6 [Iter: 38400/60000]	 Loss: 0.135936
Train epoch: 6 [Iter: 44800/60000]	 Loss: 0.053987
Train epoch: 6 [Iter: 51200/60000]	 Loss: 0.103070
Train epoch: 6 [Iter: 57600/60000]	 Loss: 0.029660
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 7 [Iter: 0/60000]	 Loss: 0.022966
Train epoch: 7 [Iter: 6400/60000]	 Loss: 0.042366
Train epoch: 7 [Iter: 12800/60000]	 Loss: 0.101577
Train epoch: 7 [Iter: 19200/60000]	 Loss: 0.034983
Train epoch: 7 [Iter: 25600/60000]	 Loss: 0.015327
Train epoch: 7 [Iter: 32000/60000]	 Loss: 0.050052
Train epoch: 7 [Iter: 38400/60000]	 Loss: 0.012108
Train epoch: 7 [Iter: 44800/60000]	 Loss: 0.107480
Train epoch: 7 [Iter: 51200/60000]	 Loss: 0.084626
Train epoch: 7 [Iter: 57600/60000]	 Loss: 0.192875
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 8 [Iter: 0/60000]	 Loss: 0.039832
Train epoch: 8 [Iter: 6400/60000]	 Loss: 0.030785
Train epoch: 8 [Iter: 12800/60000]	 Loss: 0.022965
Train epoch: 8 [Iter: 19200/60000]	 Loss: 0.119799
Train epoch: 8 [Iter: 25600/60000]	 Loss: 0.062288
Train epoch: 8 [Iter: 32000/60000]	 Loss: 0.137041
Train epoch: 8 [Iter: 38400/60000]	 Loss: 0.092392
Train epoch: 8 [Iter: 44800/60000]	 Loss: 0.027005
Train epoch: 8 [Iter: 51200/60000]	 Loss: 0.068983
Train epoch: 8 [Iter: 57600/60000]	 Loss: 0.037548
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 9 [Iter: 0/60000]	 Loss: 0.036981
Train epoch: 9 [Iter: 6400/60000]	 Loss: 0.064872
Train epoch: 9 [Iter: 12800/60000]	 Loss: 0.143403
Train epoch: 9 [Iter: 19200/60000]	 Loss: 0.057350
Train epoch: 9 [Iter: 25600/60000]	 Loss: 0.086227
Train epoch: 9 [Iter: 32000/60000]	 Loss: 0.046821
Train epoch: 9 [Iter: 38400/60000]	 Loss: 0.029497
Train epoch: 9 [Iter: 44800/60000]	 Loss: 0.042612
Train epoch: 9 [Iter: 51200/60000]	 Loss: 0.063328
Train epoch: 9 [Iter: 57600/60000]	 Loss: 0.013613
----------
Test: Average loss:0.0587, Accuracy: 98%
----------


Train epoch: 10 [Iter: 0/60000]	 Loss: 0.054516
Train epoch: 10 [Iter: 6400/60000]	 Loss: 0.037694
Train epoch: 10 [Iter: 12800/60000]	 Loss: 0.028837
Train epoch: 10 [Iter: 19200/60000]	 Loss: 0.024753
Train epoch: 10 [Iter: 25600/60000]	 Loss: 0.032908
Train epoch: 10 [Iter: 32000/60000]	 Loss: 0.060467
Train epoch: 10 [Iter: 38400/60000]	 Loss: 0.043509
Train epoch: 10 [Iter: 44800/60000]	 Loss: 0.059246
Train epoch: 10 [Iter: 51200/60000]	 Loss: 0.074750
Train epoch: 10 [Iter: 57600/60000]	 Loss: 0.074909
----------
Test: Average loss:0.0587, Accuracy: 98%
----------
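
The training function saves only the `state_dict`, so restoring the model later means rebuilding the architecture first and then loading the weights. A minimal round-trip sketch, assuming a small stand-in `nn.Linear` model and a temporary file instead of `mnist_cnn.pth`:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                # stand-in for the trained model
path = os.path.join(tempfile.mkdtemp(), "demo.pth")    # hypothetical file name
torch.save(model.state_dict(), path)

restored = nn.Linear(4, 2)                             # rebuild the same architecture
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()                                        # evaluation mode for inference

x = torch.ones(1, 4)
same = torch.equal(model(x), restored(x))
print(same)
```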
