## SENet (SE-ResNet) on CIFAR-10

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel of the input feature map is squeezed to a scalar by global
    average pooling. The resulting vector goes through a bottleneck MLP
    (channels -> hidden -> channels) that ends in a sigmoid. Every input
    channel is then rescaled by its gate value.

    Args:
        channels: number of channels C of the input feature map.
        reduction: bottleneck ratio of the MLP. The hidden width is
            ``channels // reduction``, clamped to at least 1 so that
            ``channels < reduction`` still builds a valid layer.
    """

    def __init__(self, channels, reduction=16):
        super(SEModule, self).__init__()
        hidden = max(1, channels // reduction)
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned sigmoid gate.

        Args:
            x: feature map of shape (N, C, H, W) with C == ``channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels = x.size(0), x.size(1)
        # Reshape (N, C, 1, 1) -> (N, C) explicitly. A bare squeeze would also
        # drop the batch dimension when N == 1.
        gate = self.globalAvgPool(x).view(batch, channels)
        gate = self.fc(gate).view(batch, channels, 1, 1)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        return x * gate


class SEBlock(nn.Module):
    """Residual basic block with a squeeze-and-excitation gate.

    Data flow: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> SE gate, then the
    shortcut is added and a final ReLU is applied. The shortcut is the
    identity, or a strided 1x1 conv + BN projection when the stride or the
    channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
        reduction: channel reduction ratio forwarded to SEModule.
    """

    def __init__(self, in_channels, out_channels, stride=1, reduction=16):
        super(SEBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        # No ReLU after this BN: in the SE-ResNet block the second
        # non-linearity is applied only after the residual addition
        # (see forward). ReLU has no parameters, so state_dict keys are
        # unaffected.
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.se = SEModule(out_channels, reduction=reduction)

        # Projection shortcut, needed only when the output shape differs
        # from the input shape.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the residual branch plus shortcut, followed by a ReLU."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.se(out)

        residual = x if self.downsample is None else self.downsample(x)

        return F.relu(out + residual)


class SENet(nn.Module):
    """SE-ResNet classifier for small RGB images (designed for CIFAR-10).

    A 3x3 stem with 16 channels is followed by three stages of SEBlocks with
    16, 32 and 64 channels. The second and third stages halve the spatial
    resolution in their first block. Global average pooling and a linear
    layer then produce the class logits.

    Args:
        num_blocks: sequence of three ints giving the number of SEBlocks per
            stage. Stages two and three always contain at least their
            down-sampling block.
        num_classes: number of output logits.
    """

    def __init__(self, num_blocks, num_classes=10):
        super(SENet, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )

        # nn.Sequential yields the same "layerK.<i>." state_dict keys as the
        # ModuleList + manual loop it replaces.
        self.layer2 = nn.Sequential(
            *[SEBlock(16, 16) for _ in range(num_blocks[0])]
        )
        self.layer3 = nn.Sequential(
            SEBlock(16, 32, 2),
            *[SEBlock(32, 32) for _ in range(num_blocks[1] - 1)]
        )
        self.layer4 = nn.Sequential(
            SEBlock(32, 64, 2),
            *[SEBlock(64, 64) for _ in range(num_blocks[2] - 1)]
        )

        # Adaptive pooling equals AvgPool2d(8) on the 8x8 map that 32x32
        # inputs produce, but does not hard-code the input resolution.
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        Args:
            x: image batch of shape (N, 3, H, W); CIFAR images are 32x32.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.globalAvgPool(out)
        # flatten(1) keeps the batch dimension even when N == 1. A bare
        # squeeze would pass fc an unbatched vector.
        out = out.flatten(1)
        return self.fc(out)