Pytorch复现ResNet，对minist数据集进行分类  

In [50]:
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Prepare Dataset

In [51]:
batch_size = 64
transform = transforms.Compose([
   transforms.ToTensor(), #ToTensor()能够把灰度范围从0-255变换到0-1之间
   transforms.Normalize((0.1307,),(0.3081,)) #transform.Normalize()则把0-1变换到(-1,1)，只有一个通道所以有1个值，第一位是均值，第二位是标准差
])

train_dataset = datasets.MNIST(root = '../dataset/mnist',
                              train = True,
                              download = True,
                              transform = transform)
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_dataset = datasets.MNIST(root = '../dataset/mnist',
                             train = False,
                             download = True,
                             transform = transform)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

# Design Model

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210122517037.png" alt="image-20220210122517037" style="zoom:33%;" />

In [52]:
class ResidualBlock(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        return F.relu(y+x) #与x相加后再激活

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210135808693.png" alt="image-20220210135808693" style="zoom:33%;" />

## 其他结构的Resnet

参考文献：https://arxiv.org/pdf/1603.05027.pdf

### constant scaling

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210190421669.png" alt="image-20220210190421669" style="zoom: 50%;" />

In [75]:
class ResidualBlock_b(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock_b,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        return F.relu(0.5*y+0.5*x) #与x相加后再激活

### exclusive gating

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210193303728.png" alt="image-20220210193303728" style="zoom: 50%;" />

In [76]:
class ResidualBlock_c(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock_c,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        
        sekf.conv1x1 = nn.Conv2d(channels,channels,kernel=1)
        
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        
        conv1x1 = F.sigmoid(self.conv1x1(x))
        
        x = x*(1-conv1x1)
        
        y = y*conv1x1
        
        return F.relu(y+x) #与x相加后再激活

###  shortcut-only gating

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210194239715.png" alt="image-20220210194239715" style="zoom:50%;" />

In [77]:
class ResidualBlock_d(nn.Module):
    def __init__(self_d,channels):
        super(ResidualBlock,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        
        sekf.conv1x1 = nn.Conv2d(channels,channels,kernel=1)
        
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        
        conv1x1 = F.sigmoid(self.conv1x1(x))
        
        x = x*(1-conv1x1)
        
        return F.relu(y+x) #与x相加后再激活

### conv shortcut

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210194409750.png" alt="image-20220210194409750" style="zoom:50%;" />

In [78]:
class ResidualBlock_e(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock_e,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        
        sekf.conv1x1 = nn.Conv2d(channels,channels,kernel=1)
        
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        
        x = self.conv1x1(x)
        
        return F.relu(y+x) #与x相加后再激活

### dropout shortcut

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220210194628363.png" alt="image-20220210194628363" style="zoom:50%;" />

In [79]:
class ResidualBlock_f(nn.Module):
    def __init__(self,channels):
        super(ResidualBlock_f,self).__init__()
        self.channels=channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        
    def forward(self,x):
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        
        x = F.dropout(x,p=0.5)
        
        return F.relu(y+x) #与x相加后再激活

### 放到网络中

In [80]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,16,kernel_size=5)
        self.conv2 = nn.Conv2d(16,32,kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        
        self.rblock1 = ResidualBlock_f(16)
        self.rblock2 = ResidualBlock_f(32)
        
        self.fc = nn.Linear(512,10)
        
    def forward(self,x):
        in_size = x.size(0)
        x=self.mp(F.relu(self.conv1(x)))
        x=self.rblock1(x)
        x=self.mp(F.relu(self.conv2(x)))
        x=self.rblock2(x)
        x=x.view(in_size,-1)
        #print(x.size())
        x=self.fc(x)
        return x

In [81]:
# 初始化一个输入x，算一下全连接层的输入是几维的
x = torch.rand([1,1,28,28])
net = Net()
with torch.no_grad():
    net(x) #torch.Size([1, 512])

# construct loss and optimizer

In [82]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.5)

# Training cycle

In [83]:
def train(epoch,device):
    running_loss = 0.0
    for branch_idx,data in enumerate(train_loader,0):
        images,targets=data
        images, targets = images.to(device),targets.to(device)
        
        optimizer.zero_grad()
        outputs=net(images)
        loss=criterion(outputs,targets)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        
        if branch_idx % 299 ==0:
            print('[%d,%5d] loss:%.3f' % ((epoch+1,branch_idx+1,running_loss/300)))
            running_loss=0.0

In [84]:
def test(device):
    total = 0
    correct = 0
    with torch.no_grad(): #停止跟踪历史记录
        for data in test_loader:
            images, targets = data
            images, targets = images.to(device),targets.to(device)
            _,predicted = torch.max(net(images),dim=1)
            total+=images.size(0)
            correct+=(predicted==targets).sum().item()
    print('Accuracy on test set is %d %%'%(correct/total*100))

In [85]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.to(device)
for epoch in range(10):
    train(epoch,device)
    test(device)

[1,    1] loss:0.008
[1,  300] loss:0.770
[1,  599] loss:0.260
[1,  898] loss:0.180
Accuracy on test set is 96 %
[2,    1] loss:0.000
[2,  300] loss:0.138
[2,  599] loss:0.120
[2,  898] loss:0.110
Accuracy on test set is 97 %
[3,    1] loss:0.000
[3,  300] loss:0.096
[3,  599] loss:0.087
[3,  898] loss:0.089
Accuracy on test set is 97 %
[4,    1] loss:0.000
[4,  300] loss:0.072
[4,  599] loss:0.074
[4,  898] loss:0.070
Accuracy on test set is 98 %
[5,    1] loss:0.000
[5,  300] loss:0.064
[5,  599] loss:0.061
[5,  898] loss:0.064
Accuracy on test set is 98 %
[6,    1] loss:0.000
[6,  300] loss:0.057
[6,  599] loss:0.055
[6,  898] loss:0.052
Accuracy on test set is 98 %
[7,    1] loss:0.000
[7,  300] loss:0.056
[7,  599] loss:0.052
[7,  898] loss:0.045
Accuracy on test set is 98 %
[8,    1] loss:0.000
[8,  300] loss:0.047
[8,  599] loss:0.044
[8,  898] loss:0.049
Accuracy on test set is 98 %
[9,    1] loss:0.000
[9,  300] loss:0.042
[9,  599] loss:0.045
[9,  898] loss:0.041
Accuracy on 