In [123]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [124]:
class SingleConvolution(nn.Module):
    """
    Auxiliary class to define a convolutional layer.
    Each convolution block: convolution, batch normalization, ReLU activation.

    Args:
        nn.Module : receive the nn.Module properties
    """
    def __init__(self, in_channels : int, out_channels : int, kernel_size: int, stride: int, padding: int, bias: bool) -> None:
        """
        Args:
            in_channels (int): amount of input channels
            out_channels (int): amount of output channels
        """ 
        super(SingleConvolution, self).__init__()
        
        self.singleConv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias),
            nn.BatchNorm3d(out_channels), 
            nn.ReLU(inplace=True),
        )
        
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        return self.singleConv(x)

In [125]:
class ChannelGate(nn.Module): # C x 1 x 1
    def __init__(self, in_channels: int, reduction_ratio: int = 16):
        
        super(ChannelGate, self).__init__()
        
        self.in_channels = in_channels
        self.reduction_ratio = reduction_ratio
        
        self.squeezeAvg = nn.AdaptiveAvgPool3d(1) # Global Average Pooling
        self.squeezeMax = nn.AdaptiveMaxPool3d(1) # Global Max Pooling
        self.excitation = nn.Sequential(
            nn.Linear(self.in_channels, self.in_channels // self.reduction_ratio),
            nn.ReLU(inplace=True), # ReLU activation
            nn.Linear(self.in_channels // self.reduction_ratio, self.in_channels),
        )
        self.sigActivation = nn.Sigmoid()
        
    def forward(self, x):
        
        batch_size, channels, _, _, _ = x.size()
        
        yAvg = self.squeezeAvg(x).view(batch_size, channels)
        yAvg = self.excitation(yAvg).view(batch_size, channels, 1, 1, 1)
        
        yMax = self.squeezeMax(x).view(batch_size, channels)
        yMax = self.excitation(yMax).view(batch_size, channels, 1, 1, 1)
        
        sum = yAvg + yMax
        
        return self.sigActivation(sum) * x

In [129]:
class SpatialGate(nn.Module): # 1 x H x W
    def __init__(self, in_channels: int, kernel_size: int = 7, padding: int = 3, bias: bool = False):
        
        super(SpatialGate, self).__init__()
        
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.padding = padding
        self.bias = bias
        
        self.squeezeAvg = nn.AdaptiveAvgPool3d(1)
        self.squeezeMax = nn.AdaptiveMaxPool3d(1)
        self.spatial = SingleConvolution(2 * self.in_channels, self.in_channels, self.kernel_size, 1, kernel_size // 2 , self.bias)
        self.sigActivation = nn.Sigmoid()
        
    def forward(self, x):
        
        yAvg = self.squeezeAvg(x)
        yMax = self.squeezeMax(x)
        y = torch.cat([yAvg, yMax], dim=1)
        y = self.spatial(y)
        
        return self.sigActivation(sum) * x

In [130]:
class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio = 16, kernel_size: int = 7):
        
        super(CBAM, self).__init__()
        
        self.ChannelGate = ChannelGate(in_channels, reduction_ratio)
        self.SpatialGate = SpatialGate(in_channels, kernel_size, kernel_size // 2, False)
        
    def forward(self, x):
        
        x_out = self.ChannelGate(x)
        x_out = self.SpatialGate(x_out)
        
        return x_out * x

In [131]:
model = CBAM(16)

TypeError: __init__() missing 1 required positional argument: 'bias'

In [117]:
x = torch.ones((1, 1, 128, 128, 128))
conv = nn.Conv3d(1, 16, kernel_size=3, padding=1)

In [118]:
teste = conv(x)

In [119]:
teste.shape

torch.Size([1, 16, 128, 128, 128])

In [120]:
with torch.no_grad():
    model.eval()
    s = model(teste).shape
print(s)

torch.Size([1, 16, 1, 1, 1])
torch.Size([1, 16, 1, 1, 1])
torch.Size([1, 32, 1, 1, 1])
torch.Size([1, 16, 1, 1, 1])
torch.Size([1, 16, 1, 1, 1])
torch.Size([1, 16, 128, 128, 128])
