In [37]:
import torch
import torch.nn as nn

In [38]:
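# SE-ResNet bottleneck block (se_block below):

# convolution 1x1
# batchnorm
# ReLU

# convolution 3x3 (stride 2 when downsampling)
# batchnorm
# ReLU

# convolution 1x1
# batchnorm
# ReLU

# Squeeze-and-Excitation (SE) branch: channel-wise rescaling

# skip connection (addition)
# ReLU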

class se_block(nn.Module):
  """Bottleneck residual block with a Squeeze-and-Excitation (SE) branch.

  Main branch: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv
  -> BN -> ReLU. Its output is rescaled channel-wise by the SE branch
  (global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid), added to
  the shortcut, and passed through a final ReLU.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the block; the bottleneck
      width is ``out_channels // 4``.
    identity_downsample: if true, the 3x3 conv and the shortcut use stride 2.
      The shortcut is a 1x1 conv + BN projection whenever this is true or the
      channel count changes; otherwise it is the identity.
    reduction: channel reduction ratio of the SE bottleneck (default 16, the
      value previously hard-coded). The SE width is clamped to at least 1.
  """
  def __init__(self, in_channels, out_channels, identity_downsample, reduction=16):
    super(se_block, self).__init__()
    mid_channels = out_channels // 4
    stride = 2 if identity_downsample else 1

    self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(mid_channels)
    self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1)
    # conv2 needs its own BatchNorm: reusing bn1 would share affine parameters
    # and running statistics between two different layers.
    self.bn_mid = nn.BatchNorm2d(mid_channels)
    self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU()
    self.identity_downsample = identity_downsample
    # An empty Sequential returns its input unchanged (identity shortcut).
    self.shortcut = nn.Sequential()

    # SE branch: squeeze (global average pool) + excitation (two 1x1 convs).
    se_channels = max(1, out_channels // reduction)
    self.gl_avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
    self.conv4 = nn.Conv2d(out_channels, se_channels, kernel_size=1, stride=1)
    self.conv5 = nn.Conv2d(se_channels, out_channels, kernel_size=1, stride=1)
    self.sigmoid = nn.Sigmoid()

    # Projection shortcut when the spatial size or channel count changes.
    if self.identity_downsample or in_channels != out_channels:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
          nn.BatchNorm2d(out_channels)
      )

  def forward(self, x):
    """Apply the bottleneck, SE rescaling, residual addition and final ReLU."""
    shortcut = self.shortcut(x)

    # Bottleneck: in_channels -> out_channels//4 -> out_channels//4 -> out_channels
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn_mid(self.conv2(out)))
    # NOTE(review): the SE-ResNet paper ends the residual branch with BN only
    # (no ReLU before SE / the add); this ReLU is kept to preserve the design.
    out = self.relu(self.bn2(self.conv3(out)))

    # SE weights in (0, 1) with shape (N, C, 1, 1), broadcast over H and W.
    weights = self.sigmoid(self.conv5(self.relu(self.conv4(self.gl_avg_pool(out)))))

    # Out-of-place ops on purpose: an in-place `*=` would overwrite the ReLU
    # output that autograd saved for backward and make .backward() raise.
    out = out * weights + shortcut
    return self.relu(out)


In [39]:
# SENet
class SENet(nn.Module):
  """SE-ResNet classifier: a 7x7 stem, four residual stages, pooling and a linear head.

  Stem: 7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2). Then
  four stages with 256/512/1024/2048 output channels, where stages 2-4 pass
  ``identity_downsample=True`` to their first block. Finally, global average
  pooling and a fully connected layer.

  Args:
    block: residual block class, constructed as
      ``block(in_channels, out_channels, identity_downsample)``.
    layers: number of blocks in each of the four stages, e.g. ``[3, 4, 6, 3]``.
    image_channels: number of channels of the input images.
    num_classes: number of output logits.
  """
  def __init__(self, block, layers, image_channels, num_classes):
    super(SENet, self).__init__()
    # Channel count produced by the stem; make_layer updates it after each stage.
    self.in_channels = 64
    # The stem differs from the residual stages: a single 7x7 convolution.
    self.conv1 = nn.Conv2d(image_channels, self.in_channels, kernel_size=7, stride=2, padding=3)
    self.bn1 = nn.BatchNorm2d(self.in_channels)
    self.relu = nn.ReLU()
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Residual stages.
    self.layer1 = self.make_layer(block, layers[0], out_channels=256, identity_downsample=False)
    self.layer2 = self.make_layer(block, layers[1], out_channels=512, identity_downsample=True)
    self.layer3 = self.make_layer(block, layers[2], out_channels=1024, identity_downsample=True)
    self.layer4 = self.make_layer(block, layers[3], out_channels=2048, identity_downsample=True)

    # Global average pooling before the fully connected classifier.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    # Size the classifier from the last stage's width (self.in_channels after
    # layer4) instead of duplicating it as a hard-coded 512*4.
    self.fc = nn.Linear(self.in_channels, num_classes)

  def make_layer(self, block, num_residual_blocks, out_channels, identity_downsample):
    """Build one stage of residual blocks as an ``nn.Sequential``.

    Only the first block changes the channel count (``self.in_channels`` ->
    ``out_channels``) and receives ``identity_downsample``. The remaining
    blocks map ``out_channels`` -> ``out_channels`` with
    ``identity_downsample=False``. The method sets ``self.in_channels`` to
    ``out_channels``. At least one block is always created.
    """
    stage = [block(self.in_channels, out_channels, identity_downsample)]
    self.in_channels = out_channels

    for _ in range(num_residual_blocks - 1):
      stage.append(block(self.in_channels, out_channels, False))

    return nn.Sequential(*stage)

  def forward(self, x):
    """Map an image batch to class logits of shape (N, num_classes)."""
    x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.reshape(x.shape[0], -1)
    return self.fc(x)

In [40]:
# SE-ResNet with 50 layers
def SE_ResNet50(img_channels=3, num_classes=1000):
    """Build a 50-layer SE-ResNet from se_block stages of 3, 4, 6 and 3 blocks.

    Args:
        img_channels: number of input image channels (default 3).
        num_classes: number of output logits (default 1000).
    """
    stage_depths = [3, 4, 6, 3]
    return SENet(block=se_block, layers=stage_depths,
                 image_channels=img_channels, num_classes=num_classes)

In [41]:
# Test
def test(batch_size=2, image_size=224):
  """Smoke-test SE_ResNet50 on a random batch and print the logits shape.

  Runs a forward pass only, under ``torch.no_grad()`` so no autograd graph
  is built. Asserts the output shape is ``(batch_size, 1000)``, the
  default ``num_classes`` of SE_ResNet50.
  """
  net = SE_ResNet50()
  x = torch.rand(batch_size, 3, image_size, image_size)
  with torch.no_grad():
    y = net(x)
  assert y.shape == (batch_size, 1000), f"unexpected output shape {tuple(y.shape)}"
  print(y.shape)

In [42]:
# Run the smoke test; it prints the logits shape, torch.Size([2, 1000]) below.
test()

torch.Size([2, 1000])
