<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/SENet(SqueezeExcitationNet).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Squeeze and Excitation Network

A typical convolutional network has kernels running over the image channels and combining
the feature maps generated per channel. Each channel has a separate kernel whose
weights are learned through backpropagation.

The idea is to explicitly model the interdependencies between the channels of the feature maps.
This makes the network more sensitive to informative features, which are then exploited by
the subsequent transformations.

* The squeeze (global information embedding) operation reduces each channel's feature map to a single value.
* The excitation (adaptive recalibration) operation turns these per-channel values into per-channel weights.

Squeeze turns (C x H x W) into (C x 1 x 1) using global average pooling.
Excitation turns this (C x 1 x 1) descriptor into (C x 1 x 1) channel weights. It uses two FC layers with a ReLU
in between and a sigmoid at the end. The weights are then broadcast to the same (C x H x W) size as the input.

Scale: each channel of the original feature maps is multiplied by its excitation weight.

The role played by the SE block depends on its depth in the network. In early layers,
it excites shared low-level representations regardless of the class. In later layers, the SE
block responds differently depending on the input class.

The SE block is simple and can be added to existing CNN architectures, such as ResNet or
Inception, to improve their performance.

This example implements an SE-ResNet: the `SENet` class below, built from SE residual blocks.

> https://github.com/kuangliu/pytorch-cifar/blob/master/models/senet.py

> https://github.com/YeonwooSung/PyTorch_CNN_Architectures/blob/master/models/senet.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

> Note: with `F.adaptive_avg_pool2d(out, 1)` you don't need to specify the width and height of the incoming input.

> With `F.avg_pool2d(out, out.size(2))` you must pass the kernel size yourself. Using the height `out.size(2)` only gives true global pooling when the feature map is square (height == width).

In [None]:
class BasicBlock(nn.Module):
    """Post-activation ResNet basic block with a Squeeze-and-Excitation (SE) gate.

    Residual branch: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The SE gate then
    rescales each channel of that branch. The result is added to the shortcut
    and passed through a final ReLU.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 convolution. Also used by the 1x1
            projection shortcut when one is needed.
        reduction: channel reduction ratio of the SE bottleneck. The hidden
            width is ``max(1, ch_out // reduction)``, so narrow blocks
            (``ch_out < reduction``) still get a working gate instead of a
            zero-width layer.
    """

    def __init__(self, ch_in, ch_out, stride=1, reduction=16):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut, unless the spatial size or channel count changes.
        # In that case a 1x1 projection (+BN) matches the residual branch.
        self.shortcut = nn.Sequential()
        if stride != 1 or ch_in != ch_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(ch_out),
            )

        # SE gate: 1x1 convs act as fully connected layers on the pooled
        # (C x 1 x 1) channel descriptor.
        hidden = max(1, ch_out // reduction)
        self.fc1 = nn.Conv2d(ch_out, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, ch_out, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Squeeze: global average pool to one value per channel.
        # adaptive_avg_pool2d also handles non-square maps. avg_pool2d with
        # kernel out.size(2) is only global when H == W.
        w = F.adaptive_avg_pool2d(out, 1)
        # Excitation: bottleneck MLP -> per-channel weights in (0, 1).
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        # Scale: the (N, C, 1, 1) weights broadcast over H x W.
        out = out * w

        out = out + self.shortcut(x)
        return F.relu(out)

In [None]:
from torchsummary import summary

In [None]:
blk0 = BasicBlock(ch_in=3, ch_out=64, stride=2)
summary(blk0, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           1,728
       BatchNorm2d-2         [-1, 64, 112, 112]             128
            Conv2d-3         [-1, 64, 112, 112]          36,864
       BatchNorm2d-4         [-1, 64, 112, 112]             128
            Conv2d-5              [-1, 4, 1, 1]             260
            Conv2d-6             [-1, 64, 1, 1]             320
            Conv2d-7         [-1, 64, 112, 112]             192
       BatchNorm2d-8         [-1, 64, 112, 112]             128
Total params: 39,748
Trainable params: 39,748
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 36.75
Params size (MB): 0.15
Estimated Total Size (MB): 37.48
----------------------------------------------------------------


In [None]:
class PreActBlock(nn.Module):
    """Pre-activation ResNet basic block with a Squeeze-and-Excitation (SE) gate.

    Residual branch: BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv3x3. The SE
    gate then rescales each output channel before the shortcut is added.
    No activation is applied after the addition.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 convolution. Also used by the 1x1
            projection shortcut when one is needed.
        reduction: channel reduction ratio of the SE bottleneck. The hidden
            width is ``max(1, ch_out // reduction)``, so narrow blocks
            (``ch_out < reduction``) still get a working gate instead of a
            zero-width layer.
    """

    def __init__(self, ch_in, ch_out, stride=1, reduction=16):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(ch_in)
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False)

        # The projection is only created when the shape changes.
        # forward() falls back to the identity when this attribute is absent.
        if stride != 1 or ch_in != ch_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False)
            )

        # SE gate: 1x1 convs act as fully connected layers on the pooled
        # (C x 1 x 1) channel descriptor.
        hidden = max(1, ch_out // reduction)
        self.fc1 = nn.Conv2d(ch_out, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, ch_out, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        # NOTE(review): the projection shortcut is applied to the raw input x.
        # The pytorch-cifar reference appears to apply it to the pre-activated
        # `out` instead. Kept as-is to preserve this notebook's model.
        shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze: global average pool to one value per channel.
        # adaptive_avg_pool2d also handles non-square maps. avg_pool2d with
        # kernel out.size(2) is only global when H == W.
        w = F.adaptive_avg_pool2d(out, 1)
        # Excitation: bottleneck MLP -> per-channel weights in (0, 1).
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        # Scale: the (N, C, 1, 1) weights broadcast over H x W.
        out = out * w
        out = out + shortcut
        return out
        #excitation

In [None]:
blk2 = PreActBlock(ch_in=3, ch_out=32, stride=1)
summary(blk2, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 3, 224, 224]               6
            Conv2d-2         [-1, 32, 224, 224]              96
            Conv2d-3         [-1, 32, 224, 224]             864
       BatchNorm2d-4         [-1, 32, 224, 224]              64
            Conv2d-5         [-1, 32, 224, 224]           9,216
            Conv2d-6              [-1, 2, 1, 1]              66
            Conv2d-7             [-1, 32, 1, 1]              96
Total params: 10,408
Trainable params: 10,408
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 50.15
Params size (MB): 0.04
Estimated Total Size (MB): 50.76
----------------------------------------------------------------


In [None]:

class SENet(nn.Module):
    """SE-ResNet backbone followed by a linear classifier.

    Layout: 3x3 stride-1 stem conv -> BN -> ReLU (no max-pool), then four
    stages of ``block`` with 64/128/256/512 base channels. Stages 2-4 halve
    the spatial size in their first block. The network ends with global
    average pooling and a linear layer.

    Args:
        block: residual block class, called as ``block(ch_in, ch_out, stride)``.
            A block whose output has ``ch_out * expansion`` channels (e.g. a
            bottleneck block) may declare a class attribute ``expansion``.
            The default is 1.
        num_blocks: number of blocks in each of the four stages (each >= 1).
        n_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, n_classes):
        super(SENet, self).__init__()
        expansion = getattr(block, 'expansion', 1)
        self.ch_in = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layers(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layers(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layers(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layers(block, 512, num_blocks[3], stride=2)
        # The last stage emits 512 * expansion channels.
        self.linear = nn.Linear(512 * expansion, n_classes)

    def _make_layers(self, block, ch_out, num_residuals, stride):
        """Build one stage of ``num_residuals`` blocks; only the first uses ``stride``.

        Updates ``self.ch_in`` so the next stage starts from this stage's
        output width.

        Raises:
            ValueError: if ``num_residuals`` is less than 1.
        """
        if num_residuals < 1:
            raise ValueError(f"each stage needs at least one block, got {num_residuals}")
        expansion = getattr(block, 'expansion', 1)
        strides = [stride] + [1] * (num_residuals - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.ch_in, ch_out, block_stride))
            self.ch_in = ch_out * expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to (N, C, 1, 1), then flatten to (N, C).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)

In [None]:
SENet(block=PreActBlock, num_blocks=[2, 2, 2, 2], n_classes=10)

SENet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (fc1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
      (fc2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine

In [None]:
def SENet18(n_classes=10):
    """Build an 18-layer SE pre-activation ResNet ([2, 2, 2, 2] blocks per stage).

    Args:
        n_classes: number of output logits (default 10).
    """
    return SENet(PreActBlock, [2, 2, 2, 2], n_classes)

def seresnet34(n_classes=10):
    """Build a 34-layer SE pre-activation ResNet ([3, 4, 6, 3] blocks per stage).

    Args:
        n_classes: number of output logits (default 10).
    """
    return SENet(PreActBlock, [3, 4, 6, 3], n_classes)
# ResNet-50/101/152 variants need a bottleneck block instead of PreActBlock.
net = SENet18()
summary(net, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,728
       BatchNorm2d-2         [-1, 64, 224, 224]             128
       BatchNorm2d-3         [-1, 64, 224, 224]             128
            Conv2d-4         [-1, 64, 224, 224]          36,864
       BatchNorm2d-5         [-1, 64, 224, 224]             128
            Conv2d-6         [-1, 64, 224, 224]          36,864
            Conv2d-7              [-1, 4, 1, 1]             260
            Conv2d-8             [-1, 64, 1, 1]             320
       PreActBlock-9         [-1, 64, 224, 224]               0
      BatchNorm2d-10         [-1, 64, 224, 224]             128
           Conv2d-11         [-1, 64, 224, 224]          36,864
      BatchNorm2d-12         [-1, 64, 224, 224]             128
           Conv2d-13         [-1, 64, 224, 224]          36,864
           Conv2d-14              [-1, 

In [None]:
net(torch.randn(2, 3, 224, 224)).size()

torch.Size([2, 10])