The basic idea of a squeeze-and-excitation (SE) network, used as an attention mechanism, is to assign a weight to each extracted feature channel that reflects how important that channel is. The SE algorithm:

1) Squeeze: Take the feature maps (for example, 32 of them, each of size 224x224) and apply global average pooling (GAP). GAP replaces each map with its mean value, compressing the whole stack into a compact 32x1x1 tensor and making the following computation cheap. A flatten layer then turns this 32x1x1 tensor into a plain vector of size 32, which is the input the next step needs.
2) Excitation: Each of the 32 values summarizes one channel. A first fully connected layer reduces the vector to 32/r values (r is the reduction ratio, 16 by default in the code below), which keeps the number of parameters and the amount of computation small. The reduced vector goes through ReLU, which sets negative values to 0 and adds a non-linearity. A second fully connected layer then restores the original size of 32, and a sigmoid squashes every value into the (0, 1) range. The sigmoid is needed because the second layer can output any real number, including negative ones, while the channel weights must lie between 0 and 1.
3) Scale: The result is a vector of values between 0 and 1. We add two size-1 spatial dimensions to it and multiply it element-wise with the original feature map. For example, if the feature map has shape 3x224x224, the sigmoid-weighted vector has shape 3x1x1 and might look like [0.9, 0.2, 0.1]. Every value of the first channel is then multiplied by 0.9 (kept almost at full strength), the second by 0.2 and the third by 0.1 (both strongly suppressed). This is how the layer expresses the priority of each channel.
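The three steps can be traced with plain tensor operations. This is a minimal sketch, not the notebook's implementation: the weight matrices `W1` and `W2` are random stand-ins for what a network would learn, and the reduction ratio `r = 3` is a hypothetical choice so that a 3-channel input still gets one hidden unit. The last lines reproduce the [0.9, 0.2, 0.1] example from step 3.

```python
import torch

torch.manual_seed(0)

N, C, H, W = 1, 3, 224, 224
r = 3  # hypothetical reduction ratio, chosen so that C // r = 1 for this tiny example
x = torch.rand(N, C, H, W)

# 1) Squeeze: global average pooling over the spatial dimensions -> shape [N, C]
z = x.mean(dim=(2, 3))

# 2) Excitation: bottleneck MLP C -> C // r -> C with ReLU in between and a sigmoid at the end
W1 = torch.randn(C // r, C)  # hypothetical random weights (learned in a real network)
W2 = torch.randn(C, C // r)
s = torch.sigmoid(torch.relu(z @ W1.T) @ W2.T)  # shape [N, C], every value in (0, 1)

# 3) Scale: reshape to [N, C, 1, 1] so the weights broadcast over H and W, then multiply
y = x * s[:, :, None, None]

# Worked example from step 3: fixed weights [0.9, 0.2, 0.1] applied to a map of ones
ones = torch.ones(1, 3, 224, 224)
weights = torch.tensor([0.9, 0.2, 0.1])
scaled = ones * weights[None, :, None, None]
print(scaled[0, :, 0, 0])  # tensor([0.9000, 0.2000, 0.1000])
```

Computing the squeeze as `x.mean(dim=(2, 3))` gives the same result as `nn.AdaptiveAvgPool2d((1, 1))` followed by `nn.Flatten()`, which is what the class below uses.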


In [2]:
import torch
from torch import nn

In [58]:
class SEBlock(nn.Module):
    def __init__(self, C, r=16):
        super().__init__()

        self.C = C  # number of input channels
        self.r = r  # reduction ratio of the bottleneck
        hidden = max(1, C // r)  # avoid zero hidden units when C < r

        # squeeze: average each channel's HxW map down to 1x1
        self.glob_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # excitation: bottleneck MLP C -> C // r -> C
        self.linear1 = nn.Linear(C, hidden)
        self.linear2 = nn.Linear(hidden, C)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x.shape = [N, C, H, W]

        # squeeze: global average pooling -> [N, C, 1, 1]
        out = self.glob_pool(x)
        # drop the 1x1 spatial dimensions -> [N, C]
        out = self.flatten(out)
        # reduce the vector to C // r values -> [N, C // r]
        out = self.linear1(out)
        # non-linearity between the two fully connected layers (negatives become 0)
        out = self.relu(out)
        # restore the original channel count -> [N, C]
        out = self.linear2(out)
        # squash every value into (0, 1): one importance weight per channel
        out = self.sigmoid(out)
        # add the 1x1 spatial dimensions back -> [N, C, 1, 1]
        out = out[:, :, None, None]
        # scale each channel of the original feature map (broadcast over H and W)
        out = out * x

        return out

In [59]:
tensor = torch.rand(1, 128, 224, 224)
block = SEBlock(128, 16)

out = block(tensor)
out.shape

torch.Size([1, 128, 224, 224])
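A common alternative writes the excitation MLP with 1x1 convolutions instead of `nn.Linear`: on a [N, C, 1, 1] tensor, a 1x1 convolution computes the same per-sample linear map, so the flatten and the `[:, :, None, None]` reshape are no longer needed. The sketch below is a variant under that assumption (the class name `SEBlockConv` is hypothetical), not a change to `SEBlock` above. It also counts the parameters to show how small the bottleneck is for C = 128 and r = 16.

```python
import torch
from torch import nn


class SEBlockConv(nn.Module):
    """Hypothetical SE variant that uses 1x1 convolutions instead of nn.Linear."""

    def __init__(self, C, r=16):
        super().__init__()
        hidden = max(1, C // r)  # guard against C < r, which would give 0 hidden units
        self.excite = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),              # squeeze: [N, C, H, W] -> [N, C, 1, 1]
            nn.Conv2d(C, hidden, kernel_size=1),  # reduce: C -> C // r
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, C, kernel_size=1),  # restore: C // r -> C
            nn.Sigmoid(),                         # per-channel weights in (0, 1)
        )

    def forward(self, x):
        # the [N, C, 1, 1] weights broadcast over H and W
        return x * self.excite(x)


block = SEBlockConv(128, 16)
x = torch.rand(1, 128, 32, 32)
out = block(x)
print(out.shape)  # torch.Size([1, 128, 32, 32])

# weights + biases: 128*8 + 8 (first conv) + 8*128 + 128 (second conv)
n_params = sum(p.numel() for p in block.parameters())
print(n_params)  # 2184
```

The count matches the `nn.Linear` version, since each 1x1 convolution has the same number of weights and biases as the corresponding fully connected layer.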