In [31]:
import torch
from torch import nn
from torch.nn import functional as F

In [35]:
class BasicConv2d(nn.Module):
    """Convolution -> batch norm -> ReLU building block.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalized by BN).
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...). Passing ``bias`` here is an error because it is
            already fixed to ``False``.

    The convolution carries no bias of its own: the following ``BatchNorm2d``
    (affine by default) provides the learnable shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Registration order (conv, then bn) defines the state_dict layout.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)

    def forward(self, x):
        """Apply conv, batch norm and an in-place ReLU to ``x``."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
    
    


In [13]:
print("1*1", end="\n")  # scratch check: emits the literal text 1*1 (a string, not arithmetic)

1*1


In [53]:
class Inception(nn.Module):
    """Four-branch Inception block whose branch outputs are concatenated on dim 1.

    Branches (each conv is a ``BasicConv2d``; stride 1 with matching padding,
    so every branch keeps the input's height and width):
      * branch1X1   : 1x1 conv -> 64 channels
      * branch5X5   : 1x1 conv (16) -> 5x5 conv -> 32 channels
      * branch3X3   : 1x1 conv (96) -> 3x3 conv (128) -> 3x3 conv -> 128 channels
      * branch_pool : 3x3 avg-pool (stride 1, pad 1) -> 1x1 conv -> ``pool_features``

    The output therefore has ``64 + 32 + 128 + pool_features`` channels.

    Note: '*' and 'X' are different -- attribute names use the letter 'X'
    (e.g. ``branch1X1``) because '*' is not valid in a Python identifier.

    Args:
        in_channels: number of channels in the input feature map.
        pool_features: output channels of the pooling branch.
    """

    def __init__(self, in_channels, pool_features):
        super(Inception, self).__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the printed module tree.
        self.branch1X1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5X5_1 = BasicConv2d(in_channels, 16, kernel_size=1)
        self.branch5X5_2 = BasicConv2d(16, 32, kernel_size=5, padding=2)

        self.branch3X3_1 = BasicConv2d(in_channels, 96, kernel_size=1)
        self.branch3X3_2 = BasicConv2d(96, 128, kernel_size=3, padding=1)
        self.branch3X3_3 = BasicConv2d(128, 128, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1.

        ``x`` must be a 4-D batch (``BatchNorm2d`` inside each branch rejects
        other ranks); dim 1 is the channel dimension.
        """
        branch1X1 = self.branch1X1(x)

        branch5X5 = self.branch5X5_1(x)
        branch5X5 = self.branch5X5_2(branch5X5)

        branch3X3 = self.branch3X3_1(x)
        branch3X3 = self.branch3X3_2(branch3X3)
        branch3X3 = self.branch3X3_3(branch3X3)

        # Pool first (stride 1, pad 1 keeps H and W), then project channels.
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # Bug fix: the list was previously bound to `out` but `outputs` was
        # concatenated, raising NameError on every forward pass.
        outputs = [branch1X1, branch5X5, branch3X3, branch_pool]
        return torch.cat(outputs, 1)

    
# Instantiate an Inception block with 192 input channels and a 32-channel
# pooling branch, then print its module tree.
net = Inception(in_channels=192, pool_features=32)
print(net)

Inception(
  (branch1X1): BasicConv2d(
    (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch5X5_1): BasicConv2d(
    (conv): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch5X5_2): BasicConv2d(
    (conv): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch3X3_1): BasicConv2d(
    (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch3X3_2): BasicConv2d(
    (conv): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=