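# Inception Module

## Introduction

The core idea of the Inception module is to combine several convolutional layers in parallel: the feature maps produced by the different branches are concatenated along the depth (channel) dimension, forming a deeper output tensor.  
Inception modules can be stacked repeatedly to build a larger network. Stacking them grows both the depth and the width of the network efficiently, which improves accuracy while helping to limit overfitting.  
A further advantage is that large feature maps can be reduced in dimensionality first, while visual information is aggregated at several sizes, so features can be extracted at multiple scales.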

![Inception.png](./imgs/Inception.png)

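*The Inception module in the figure has four branches: three apply convolutions with different kernel sizes and one applies max pooling. The branch outputs are then concatenated along the depth dimension to form a deeper tensor.*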

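**Main improvements:**
1. A single block contains a 1x1 convolution, a 3x3 convolution, a 5x5 convolution, and 3x3 max pooling, so every layer of the network can learn features at different scales, both "sparse" (3x3, 5x5) and "non-sparse" (1x1). This widens the network and makes it more adaptable to scale.
2. The branch outputs are merged by concatenation at the end of each block. The nonlinearity comes from the activation applied after each convolution, not from the concatenation itself.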
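
The channel-wise merge described above can be illustrated in isolation. The following is a minimal sketch, assuming four hypothetical branch outputs that share batch size and spatial size but differ in depth; the channel counts are arbitrary illustrative values, not the ones used by any particular network.

```python
import torch

# Hypothetical branch outputs: same batch size and spatial size, different depths.
batch, height, width = 2, 28, 28
branch_channels = [64, 128, 32, 32]  # illustrative values only
branch_outputs = [torch.randn(batch, c, height, width) for c in branch_channels]

# dim=1 is the channel axis in PyTorch's [batch, channel, height, width] layout.
merged = torch.cat(branch_outputs, dim=1)
print(merged.shape)  # the channel count is the sum of the branch depths
```

Concatenation only works if every branch keeps the same height and width, which is why each branch in the module pads its input to preserve the spatial size.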

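To reduce the computational cost, an extra 1x1 convolutional layer is added before the 3x3 and 5x5 convolutions. It limits the number of input channels, reducing the dimensionality of the input feature map and therefore the cost of the convolutions that follow.  
Adding extra 1x1 convolutions may seem counterintuitive, but a 1x1 convolution is far cheaper than a 5x5 convolution, and the savings from the reduced channel count outweigh its cost.  
Note that in the pooling branch the 1x1 convolution comes after the max pooling.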

![Inception_DR.png](./imgs/Inception_DR.png)
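
To make the saving concrete, the multiply-accumulate (MAC) count of a 5x5 convolution can be compared with and without a 1x1 reduction in front of it. This is a rough sketch under stated assumptions: the helper `conv_macs` is hypothetical, the feature-map size (28x28), the channel counts (192 in, 16 after reduction, 32 out) are illustrative values, and bias and BN costs are ignored.

```python
def conv_macs(height, width, in_channels, out_channels, kernel_size):
    """Multiply-accumulate count of a stride-1 convolution that preserves spatial size."""
    return height * width * out_channels * kernel_size * kernel_size * in_channels

# Illustrative sizes (assumptions, not taken from the module defined in this notebook).
h, w, c_in, c_mid, c_out = 28, 28, 192, 16, 32

# 5x5 convolution applied directly to all 192 input channels.
direct = conv_macs(h, w, c_in, c_out, 5)

# 1x1 reduction to 16 channels, followed by the 5x5 convolution.
reduced = conv_macs(h, w, c_in, c_mid, 1) + conv_macs(h, w, c_mid, c_out, 5)

print(direct, reduced, round(direct / reduced, 1))
```

With these sizes the reduced path needs roughly a tenth of the multiply-accumulates of the direct 5x5 convolution, even though it contains one more layer.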

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision

## Implementing the Inception module with dimensionality reduction

In [2]:
# Basic convolution block: convolution + BN + ReLU activation
class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):  # kwargs forwards kernel_size, stride, padding, etc. to nn.Conv2d
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              bias=False,                   # BN subtracts the per-channel mean, which cancels a conv bias, so the bias is omitted
                              **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
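
The `bias=False` choice in `BasicConv2d` can be checked numerically. The sketch below uses arbitrary layer sizes chosen only for illustration: two convolutions share the same weights, one has a deliberately large bias, and both are followed by a `BatchNorm2d` layer in training mode, which normalises with the statistics of the current batch.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Two convolutions with identical weights; only one has a bias term.
conv_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_nobias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_nobias.weight.copy_(conv_bias.weight)
    conv_bias.bias.fill_(5.0)  # deliberately large constant offset per channel

# A freshly constructed module is in training mode, so BN uses batch statistics.
bn = nn.BatchNorm2d(8)

x = torch.randn(4, 3, 16, 16)
out_bias = bn(conv_bias(x))
out_nobias = bn(conv_nobias(x))

# The per-channel mean subtraction removes the constant offset added by the bias.
max_diff = (out_bias - out_nobias).abs().max().item()
print(max_diff)
```

The two outputs agree up to floating-point error, so the bias would add parameters without changing what the block computes during training.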

In [3]:
# Inception module
class InceptionBlock(nn.Module):

    def __init__(self, in_channels, pool_features):     # pool_features: number of output channels of the pooling branch
        super().__init__()
        self.branch_1x1 = BasicConv2d(in_channels=in_channels,
                                     out_channels=64,
                                     kernel_size=1)
        
        self.branch_3x3_1 = BasicConv2d(in_channels=in_channels,
                                       out_channels=64,
                                       kernel_size=1)
        self.branch_3x3_2 = BasicConv2d(in_channels=64, 
                                       out_channels=96,
                                       kernel_size=3,
                                       padding=1)
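        # nn.Conv2d defaults to padding=0 (the 'valid' behaviour), so a stride-1 k x k kernel shrinks
        # each spatial dimension by k - 1; padding = k // 2 (for odd k) keeps the size unchanged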

        self.branch_5x5_1 = BasicConv2d(in_channels=in_channels,
                                       out_channels=48,
                                       kernel_size=1)
        self.branch_5x5_2 = BasicConv2d(in_channels=48,
                                       out_channels=64,
                                       kernel_size=5,
                                       padding=2)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch_1x1_out = self.branch_1x1(x)

        branch_3x3_out_1 = self.branch_3x3_1(x)
        branch_3x3_out_2 = self.branch_3x3_2(branch_3x3_out_1)

        branch_5x5_out_1 = self.branch_5x5_1(x)
        branch_5x5_out_2 = self.branch_5x5_2(branch_5x5_out_1)

        branch_pool_out_1 = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool_out_2 = self.branch_pool(branch_pool_out_1)

        outputs = [branch_1x1_out, branch_3x3_out_2, branch_5x5_out_2, branch_pool_out_2]
        return torch.cat(outputs, 1)  # dim=1 concatenates along the channel dimension of [batch, channel, height, width]
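
The output shape of the block can be worked out by hand from the standard output-size formula for convolution and pooling layers, `floor((n + 2p - k) / s) + 1`. The sketch below is plain Python; the helper `conv_out_size` and the 35x35 input size are assumptions made for illustration, while the kernel sizes, paddings, and channel counts are those of the final layer in each branch of `InceptionBlock`.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 35            # illustrative input height/width
pool_features = 64   # same value used when the block is instantiated in the next cell

# (kernel_size, padding, out_channels) of the last layer in each branch
branches = {
    "1x1": (1, 0, 64),
    "3x3": (3, 1, 96),
    "5x5": (5, 2, 64),
    "pool": (3, 1, pool_features),  # 3x3 max pool with padding=1, then a 1x1 conv
}

out_sizes = {name: conv_out_size(size, k, padding=p) for name, (k, p, _) in branches.items()}
out_channels = sum(c for _, _, c in branches.values())
unpadded_5x5 = conv_out_size(size, 5)  # what the 5x5 branch would give without padding

print(out_sizes, out_channels, unpadded_5x5)
```

Every branch keeps the 35x35 spatial size, so the outputs can be concatenated, and the block produces 64 + 96 + 64 + `pool_features` channels regardless of `in_channels`.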

In [4]:
inception_block = InceptionBlock(in_channels=32, pool_features=64)

In [5]:
inception_block

InceptionBlock(
  (branch_1x1): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch_3x3_1): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch_3x3_2): BasicConv2d(
    (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch_5x5_1): BasicConv2d(
    (conv): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch_5x5_2): BasicConv2d(
    (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, af

In [None]:
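# Load Inception v3 with ImageNet weights (torchvision >= 0.13; older releases use pretrained=True)
model = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT)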

In [None]:
model