In [1]:
import torch.nn as nn
import torch
from torchsummary import summary

In [4]:
# 构建残差基本模块
class ResBlock(nn.Module):
    def __init__(self, downSample, in_channels, out_channels):  # downSample判断是否需要下采样
        super().__init__()
        if downSample:
            # 需要下采样
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
            #
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, 2, 0),
                nn.BatchNorm2d(out_channels)
            )

        else:
            # 不需要下采样
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
            # 对输入不做处理
            self.shortcut = nn.Sequential()

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

        self.relu3 = nn.ReLU()

    def forward(self, x):
        # 对输入进行处理
        shortcut = self.shortcut(x)

        #conv1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        #conv2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        # 残差链接
        x = x + shortcut
        x = self.relu3(x)

        return x





In [5]:
resblock = ResBlock(True,64,128).to('cuda:0')

In [6]:
summary(resblock,(64,56,56))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 28, 28]           8,320
       BatchNorm2d-2          [-1, 128, 28, 28]             256
            Conv2d-3          [-1, 128, 28, 28]          73,856
       BatchNorm2d-4          [-1, 128, 28, 28]             256
              ReLU-5          [-1, 128, 28, 28]               0
            Conv2d-6          [-1, 128, 28, 28]         147,584
       BatchNorm2d-7          [-1, 128, 28, 28]             256
              ReLU-8          [-1, 128, 28, 28]               0
              ReLU-9          [-1, 128, 28, 28]               0
Total params: 230,528
Trainable params: 230,528
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.77
Forward/backward pass size (MB): 6.89
Params size (MB): 0.88
Estimated Total Size (MB): 8.54
-------------------------------------------

In [12]:
class ResNet18(nn.Module):
    def __init__(self,num_class):  # 类别数量
        super().__init__()

        #layer 0
        self.layer_0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 通道大小列表
        channel_list = [64,128,256,512]

        #layer 1 都不用下采样
        self.layer_1 = nn.Sequential(
            # 两个残差块不需要下采样
            ResBlock(False,channel_list[0],channel_list[0]),
            ResBlock(False,channel_list[0],channel_list[0])
        )

        #layer 2
        self.layer_2 = nn.Sequential(
            # 第一个残差块需要下采样
            ResBlock(True,channel_list[0],channel_list[1]),
            ResBlock(False,channel_list[1],channel_list[1])
        )

        #layer 3
        self.layer_3 = nn.Sequential(
            # 第一个残差块需要下采样
            ResBlock(True,channel_list[1],channel_list[2]),
            ResBlock(False,channel_list[2],channel_list[2])
        )

        #layer 4
        self.layer_4 = nn.Sequential(
            # 第一个残差块需要下采样
            ResBlock(True,channel_list[2],channel_list[3]),
            ResBlock(False,channel_list[3],channel_list[3])
        )

        # layer 5
        self.AdaptiveAvgPool5 = nn.AdaptiveAvgPool2d(1)  # output size = 1
        self.flatten5 = nn.Flatten(start_dim=1)
        self.fc5 = nn.Linear(channel_list[3],num_class)


    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.AdaptiveAvgPool5(x)
        x = self.flatten5(x)
        x = self.fc5(x)

        return x

In [13]:
# 实例化网络
resnet = ResNet18(10).to('cuda:0')  # num_class 分类个数

In [14]:
summary(resnet,(3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,928
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,928
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
             ReLU-11           [-1, 64, 56, 56]               0
         ResBlock-12           [-1, 64, 56, 56]               0
           Conv2d-13           [-1, 64, 56, 56]          36,928
      BatchNorm2d-14           [-1, 64,