In [None]:
# Idea: combine a residual skip connection and a squeeze-and-excitation (SE) block into a single module, SE_ResBlock.

In [2]:
import torch 
from torch import nn
import torch.nn.functional as F

In [5]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: channel-wise feature recalibration.

    Global-average-pools every channel to one value ("squeeze"), feeds the
    resulting C-vector through a bottleneck MLP (C -> hidden -> C) with a ReLU
    in between and a sigmoid at the end ("excitation"), and multiplies every
    input channel by its gate value in (0, 1).

    Args:
        C: number of input (and output) channels.
        r: reduction ratio of the bottleneck. The hidden width is
            ``max(1, C // r)`` so that inputs with fewer than ``r`` channels
            still get a gate that depends on the input.

    Shape:
        input is ``(N, C, H, W)``; the output has the same shape.
    """

    def __init__(self, C, r=16):
        super().__init__()

        self.aap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # C // r is 0 whenever C < r, giving Linear(C, 0) -> Linear(0, C):
        # the second layer would then output only its bias, so the gate could
        # not depend on the input at all. Keep at least one hidden unit.
        hidden = max(1, C // r)
        self.linear1 = nn.Linear(C, hidden)
        self.linear2 = nn.Linear(hidden, C)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # squeeze: (N, C, H, W) -> (N, C, 1, 1) -> (N, C)
        out = self.flatten(self.aap(x))

        # excitation: per-channel gates in (0, 1), shape (N, C)
        out = self.relu(self.linear1(out))
        out = self.sigmoid(self.linear2(out))

        # broadcast the gates over the spatial dims and rescale the input
        return x * out[:, :, None, None]
        

In [6]:
# Sanity check: an SE block only rescales channels, so the output shape
# must equal the input shape.
se_input = torch.rand(1, 32, 256, 256)
se_demo = SEBlock(32)
se_demo(se_input).shape

torch.Size([1, 32, 256, 256])

In [31]:
class SE_ResBlock(nn.Module):
    """Residual block with a Squeeze-and-Excitation gate on the residual branch.

    Main branch:  conv(kernel, stride) -> BN -> ReLU -> conv3x3 -> BN -> SEBlock
    Shortcut:     identity when the main branch preserves the input shape,
                  otherwise conv(kernel, stride) -> BN with the same geometry
                  as the first conv, so both branches always have equal shapes.
    The two branches are summed and passed through a final ReLU.

    Args:
        inputs: number of input channels.
        outputs: number of output channels.
        kernel: kernel size of the first conv (and of the projection shortcut).
        stride: stride of the first conv (and of the projection shortcut).
    """

    def __init__(self, inputs, outputs, kernel, stride):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(inputs, outputs, kernel_size=kernel, stride=stride, padding=1),
            nn.BatchNorm2d(outputs)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(outputs, outputs, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outputs)
        )

        # Previously add_conv was only created when inputs != outputs, yet
        # forward() always calls it -> AttributeError for equal channel counts.
        # With padding=1, conv1 keeps the spatial size only for a 3x3 kernel at
        # stride 1, so the input can be added back unchanged only in that case
        # (and when the channel count is unchanged); otherwise project it.
        if inputs == outputs and stride == 1 and kernel == 3:
            self.add_conv = nn.Identity()
        else:
            self.add_conv = nn.Sequential(
                nn.Conv2d(inputs, outputs, kernel_size=kernel, stride=stride, padding=1),
                nn.BatchNorm2d(outputs)
            )

        self.se_block = SEBlock(outputs)

    def forward(self, x):
        shortcut = self.add_conv(x)

        out = F.relu(self.conv1(x))
        out = self.conv2(out)

        # apply the channel-attention gate to the residual branch only,
        # then merge with the shortcut (out-of-place add: shortcut may be x)
        out = self.se_block(out)
        out = out + shortcut

        return F.relu(out)
        
        

In [33]:
x_in = torch.rand(1, 3, 256, 256)
se_res = SE_ResBlock(3, 32, kernel=4, stride=2)

# kernel=4, stride=2, padding=1: (256 + 2 - 4) // 2 + 1 = 128 per spatial dim
y_out = se_res(x_in)
y_out.shape

torch.Size([1, 32, 128, 128])