In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
%matplotlib inline
import matplotlib.pyplot as plt
import torchvision
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [11]:
class Basic_conv2d(nn.Module):
    def __init__(self,in_channels,out_channels,**kwargs):
        super(Basic_conv2d,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,bias=False,**kwargs)
        self.BN = nn.BatchNorm2d(out_channels)
    def forward(self,x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x,inplace=True)

In [12]:
class InceptionBlock(nn.Module):
    def __init__(self,in_channels,pool_features):
        super(InceptionBlock,self).__init__()
        self.b1x1 = Basic_conv2d(in_channels,64,kernel_size = 1)
              
        self.b3x3_1 = Basic_conv2d(in_channels,64,kernel_size = 1)
        self.b3x3_2 = Basic_conv2d(64,96,kernel_size = 3,padding=1)      
        
        self.b5x5_1 = Basic_conv2d(in_channels,48,kernel_size = 1)
        self.b5x5_2 = Basic_conv2d(48,64,kernel_size=2) #保持图像大小不变，使用padding填充
       
        self.b_pool = Basic_conv2d(in_channels,pool_features,kernel_size=1)
    def forward(self,x):
        b_1x1_out = self.b1x1(x)
        b_3x3 = self.b3x3_1(x)
        b_3x3_out = self.b3x3_2(b_3x3) 
        b_5x5 = self.b5x5_1(x)
        b_5x5_out = self.b5x5_2(b_5x5)
        b_pool_out = F.max_pool2d(x,kernel_size=3,stride=1,padding=1) #保持大小不变
        b_pool_out = self.b_pool(b_pool_out)
        
        outputs = [b_1x1_out,b_3x3_out,b_5x5_out,b_pool_out]
        return torch.cat(outputs,dim=1)

In [13]:
my_inception = InceptionBlock(32,64)
my_inception

InceptionBlock(
  (b1x1): Basic_conv2d(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (b3x3_1): Basic_conv2d(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (b3x3_2): Basic_conv2d(
    (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (BN): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (b5x5_1): Basic_conv2d(
    (conv): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (BN): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (b5x5_2): Basic_conv2d(
    (conv): Conv2d(48, 64, kernel_size=(2, 2), stride=(1, 1), bias=False)
    (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
