<a href="https://colab.research.google.com/github/BarryLiu-97/Pytorch-Tutorial/blob/master/09_Advanced_CNN_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import time

In [3]:
batch_size = 64
transform = transforms.Compose([
  transforms.ToTensor(),              #将数据转换为channel×width×height格式，为了更高效地进行运算
  transforms.Normalize((0.1307, ), (0.3081, ))  #均值和标准差，用于数据标准化，这是对MNIST进行计算后得到的结果，已经算好了
])

train_dataset = datasets.MNIST(root='../dataset/mnist',
                train=True, download=True,
                transform = transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

test_dataset = datasets.MNIST(root='../dataset/mnist',
                train=False, download=True,
                transform = transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [10]:
class ResidualBlock(nn.Module):
  def __init__(self, channels):
    super(ResidualBlock, self).__init__()
    self.channels = channels
    self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

  def forward(self, x):
    y = F.relu(self.conv1(x))
    y = self.conv2(y)
    return F.relu(x+y)  # 残差块先求和再激活,所以残差块的输入输出同维度 是求和不是拼接

In [11]:
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1 ,16, kernel_size=5)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
    self.mp = nn.MaxPool2d(2)

    self.rblock1 = ResidualBlock(16)
    self.rblock2 = ResidualBlock(32)

    self.fc = nn.Linear(512,10)

  def forward(self, x):
    in_size = x.size(0)
    x = self.mp(F.relu(self.conv1(x)))
    x = self.rblock1(x)
    x = self.mp(F.relu(self.conv2(x)))
    x = self.rblock2(x)
    x = x.view(in_size, -1)
    x = self.fc(x)
    return x

In [12]:
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  #冲量值设置为0.5，优化训练过程
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 若当前cuda是可行的，使用第一个可见的设备(即GPU)
model.to(device) # 参数、缓存等，所有的模块放入cuda，使用GPU

Net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (mp): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (rblock1): ResidualBlock(
    (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (rblock2): ResidualBlock(
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

In [13]:
def train(epoch):
  running_loss = 0.
  for batch_idx, data in enumerate(train_loader, 0):
    inputs, target = data
    inputs, target = inputs.to(device), target.to(device)  # 转移到cuda，且在同一块显卡上
    optimizer.zero_grad()

    # forward + backard + update
    outputs = model(inputs)
    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if batch_idx % 300 == 299:
      print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx + 1, running_loss / 2000))
      running_loss = 0.0

In [14]:
def test():
  correct = 0
  total = 0
  with torch.no_grad():
    for data in test_loader:
      inputs, target = data
      inputs, target = inputs.to(device), target.to(device)
      outputs = model(inputs)
      _, predicted = torch.max(outputs.data, dim=1)
      total += target.size(0)
      correct += (predicted == target).sum().item()
  print('Accuracy on test set: %d %% [%d/%d]' % (100*correct / total, correct, total))

In [15]:
start = time.time()
for epoch in range(10):
  train(epoch)
  test()
end = time.time()
print(str(end-start) + 's')

[1,   300] loss: 0.077
[1,   600] loss: 0.024
[1,   900] loss: 0.017
Accuracy on test set: 97 % [9703/10000]
[2,   300] loss: 0.013
[2,   600] loss: 0.012
[2,   900] loss: 0.012
Accuracy on test set: 97 % [9768/10000]
[3,   300] loss: 0.009
[3,   600] loss: 0.009
[3,   900] loss: 0.009
Accuracy on test set: 98 % [9819/10000]
[4,   300] loss: 0.007
[4,   600] loss: 0.007
[4,   900] loss: 0.007
Accuracy on test set: 98 % [9817/10000]
[5,   300] loss: 0.006
[5,   600] loss: 0.006
[5,   900] loss: 0.006
Accuracy on test set: 98 % [9876/10000]
[6,   300] loss: 0.005
[6,   600] loss: 0.006
[6,   900] loss: 0.005
Accuracy on test set: 98 % [9868/10000]
[7,   300] loss: 0.005
[7,   600] loss: 0.005
[7,   900] loss: 0.004
Accuracy on test set: 98 % [9875/10000]
[8,   300] loss: 0.004
[8,   600] loss: 0.004
[8,   900] loss: 0.004
Accuracy on test set: 99 % [9902/10000]
[9,   300] loss: 0.003
[9,   600] loss: 0.004
[9,   900] loss: 0.004
Accuracy on test set: 98 % [9896/10000]
[10,   300] loss: 0