# 利用PyTorch实现ResNet

只需约50行代码即可实现ResNet，下面直接开始动手。

## 实现说明：

1. 模型中存在大量的重复网络结构
2. 对于模型中重复的部分，实现为子module或用函数生成对应的module
3. 尽量使用nn.Sequential
4. nn.Module和nn.functional（代码中导入为F）结合使用

In [1]:
from torch import nn
import torch as t
from torch.nn import functional as F

In [11]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Main branch (``left``): conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Skip branch (``right``): the identity when ``shortcut`` is None, otherwise
    the caller-supplied ``shortcut`` module applied to the input. The skip
    output is added in place to the main-branch output, so the two must have
    compatible shapes.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 convolution.
        shortcut: optional module used on the skip branch (None = identity).
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super().__init__()
        # Layers are instantiated in this exact order so that seeded weight
        # initialisation matches layer for layer.
        conv_a = nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False)
        norm_a = nn.BatchNorm2d(outchannel)
        act = nn.ReLU(inplace=True)
        conv_b = nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False)
        norm_b = nn.BatchNorm2d(outchannel)
        self.left = nn.Sequential(conv_a, norm_a, act, conv_b, norm_b)
        # None is stored as a plain attribute and means "identity shortcut".
        self.right = shortcut

    def forward(self, x):
        main = self.left(x)
        if self.right is None:
            skip = x
        else:
            skip = self.right(x)
        # In-place add into the freshly computed main-branch tensor; the input
        # x itself is never written to.
        main += skip
        return F.relu(main)
    
    
    

In [12]:
class ResNet(nn.Module):
    """ResNet-34-style image classifier built from ``ResidualBlock`` stages.

    Layout: stem (7x7 conv stride 2 -> BN -> ReLU -> 3x3 max-pool stride 2),
    four stages of 3, 4, 6 and 3 basic residual blocks (built by
    ``_make_layer``), global average pooling and a linear classifier.

    NOTE: the stage output widths here are 128, 256, 512, 512, which differ
    from the paper's ResNet-34 (64, 128, 256, 512). They are kept as-is so the
    model's parameter shapes do not change.

    Args:
        num_class: number of output classes (size of the final linear layer).
    """

    def __init__(self, num_class=1000):
        super().__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        self.layer1 = self._make_layer(64, 128, 3)
        self.layer2 = self._make_layer(128, 256, 4, stride=2)
        self.layer3 = self._make_layer(256, 512, 6, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)

        self.fc = nn.Linear(512, num_class)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Build one stage containing ``block_num`` residual blocks.

        The first block maps ``inchannel`` -> ``outchannel`` with the given
        ``stride``. It gets a 1x1 conv + BN projection shortcut only when the
        stride or channel count changes; otherwise it uses the identity. The
        remaining blocks map ``outchannel`` -> ``outchannel`` with identity
        shortcuts. (As before, at least one block is always built.)
        """
        shortcut = None
        if stride != 1 or inchannel != outchannel:
            # Projection so the skip branch matches the main branch's shape.
            # Created before the blocks to keep the original init order.
            shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        layers.extend(
            ResidualBlock(outchannel, outchannel) for _ in range(1, block_num)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to raw class scores (no softmax)."""
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling to 1x1 works for any input resolution. For
        # 224x224 inputs the final map is 7x7, so this equals the previous
        # fixed avg_pool2d(x, 7).
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.flatten(1)

        return self.fc(x)
    
    
        
    

In [14]:
# Instantiate the network with the default 1000-way classifier head.
model = ResNet()


In [15]:
# A single random 3x224x224 image. Plain tensors participate in autograd
# directly; the old `Variable` wrapper has been a no-op since PyTorch 0.4.
# NOTE(review): `input` shadows the builtin; kept because later cells may use it.
input = t.randn(1, 3, 224, 224)
out = model(input)

In [17]:
# Show the module hierarchy via the model's string representation.
print(f"{model}")

ResNet(
  (pre): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (right): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      

)


In [18]:
# Display the raw class scores (fc output, no softmax) for the single input image.
out

tensor([[ 6.0279e-01,  3.2270e-01, -2.1346e-01, -6.4585e-02, -1.2277e-01,
          3.1704e-01, -2.6880e-01,  5.4051e-02,  6.3036e-01,  1.1401e-01,
         -1.5322e-01,  2.5815e-02,  1.5689e-01, -6.7118e-02,  7.4015e-02,
         -4.3403e-02,  9.7923e-02,  2.9348e-01, -1.9360e-01, -4.0967e-01,
         -1.3117e-01, -9.6169e-02, -1.6637e-01,  4.5972e-01, -2.5299e-01,
         -4.1093e-01,  2.6573e-01,  3.3544e-02,  2.2931e-01,  2.7117e-01,
         -2.1993e-01, -2.2382e-01, -2.0749e-01,  1.0306e+00, -1.9673e-01,
         -8.2401e-02, -3.1311e-01, -9.4969e-02,  5.4811e-01, -3.7894e-01,
          2.1383e-01, -6.1764e-01, -3.4215e-01,  4.0260e-01, -6.2899e-01,
         -9.2600e-02, -5.0489e-01, -7.1144e-01,  2.0432e-01, -2.5314e-01,
         -2.0530e-01, -4.8832e-01,  5.9186e-01,  2.6577e-01, -8.3820e-02,
         -9.7709e-02, -2.1208e-02, -2.0003e-01, -6.1839e-04, -5.2728e-01,
          1.0948e-01, -3.1185e-01, -4.3275e-02,  1.6530e-02, -1.7877e-01,
         -2.6744e-01,  9.8365e-02, -3.