In [92]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [93]:
class SingleConvolution(nn.Module):
    """
    Auxiliary class to define a convolutional layer.
    Each convolution block: convolution, batch normalization, ReLU activation.

    Args:
        nn.Module : receive the nn.Module properties
    """
    def __init__(self, in_channels : int, out_channels : int, kernel_size: int = 3, padding: int = 1, stride : int = 1, bias: bool = False) -> None:
        """
        Args:
            in_channels (int): amount of input channels
            out_channels (int): amount of output channels
        """ 
        super(SingleConvolution, self).__init__()
        
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.bias = bias
        
        self.singleConv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride, bias = self.bias),
            nn.BatchNorm3d(out_channels), 
            nn.ReLU(inplace=True),
        )
        
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        return self.singleConv(x)

In [94]:
class DoubleConvolution(nn.Module):
    """
    Auxiliary class to define a convolutional layer.
    Each convolution block: 3x3 convolution, batch normalization, ReLU activation.

    Args:
        nn.Module : receive the nn.Module properties
    """
    def __init__(self, in_channels : int, out_channels : int, kernel_size: int = 3, padding: int = 1, stride: int = 1, bias: bool = False) -> None:
        """
        Args:
            in_channels (int): amount of input channels (16 or 32 or 64 or 128)
            out_channels (int): amount of output channels (16 or 32 or 64 or 128)
        """ 
        super(DoubleConvolution, self).__init__()
        
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.bias = bias
        
        self.singleConvolution = SingleConvolution(in_channels, out_channels, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride, bias = self.bias)

        
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        y  = self.singleConvolution(x)
        y = self.singleConvolution(y)        
        return y

In [95]:
class ChannelGate(nn.Module): # C x 1 x 1
    def __init__(self, in_channels: int, reduction_ratio: int = 16):
        
        super(ChannelGate, self).__init__()
        
        self.in_channels = in_channels
        self.reduction_ratio = reduction_ratio
        
        self.squeezeAvg = nn.AdaptiveAvgPool3d(1) # Global Average Pooling
        self.squeezeMax = nn.AdaptiveMaxPool3d(1) # Global Max Pooling
        self.excitation = nn.Sequential(
            nn.Linear(self.in_channels, self.in_channels // self.reduction_ratio),
            nn.ReLU(inplace=True), # ReLU activation
            nn.Linear(self.in_channels // self.reduction_ratio, self.in_channels),
        )
        self.sigActivation = nn.Sigmoid()
        
    def forward(self, x):
        
        batch_size, channels, _, _, _ = x.size()
        
        yAvg = self.squeezeAvg(x).view(batch_size, channels)
        yAvg = self.excitation(yAvg).view(batch_size, channels, 1, 1, 1)
        
        yMax = self.squeezeMax(x).view(batch_size, channels)
        yMax = self.excitation(yMax).view(batch_size, channels, 1, 1, 1)
        
        sum = yAvg + yMax
        
        return self.sigActivation(sum) * x

In [96]:
class SpatialGate(nn.Module): # 1 x H x W
    def __init__(self, in_channels: int, kernel_size: int = 7, padding: int = 3, bias: bool = False):
        
        super(SpatialGate, self).__init__()
        
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.padding = padding
        self.bias = bias
        
        self.squeezeAvg = nn.AdaptiveAvgPool3d(1)
        self.squeezeMax = nn.AdaptiveMaxPool3d(1)
        self.spatial = SingleConvolution(2 * self.in_channels, self.in_channels, kernel_size = self.kernel_size, padding = self.padding)
        self.sigActivation = nn.Sigmoid()
        
    def forward(self, x):
        
        yAvg = self.squeezeAvg(x)
        yMax = self.squeezeMax(x)
        y = torch.cat([yAvg, yMax], dim=1)
        y = self.spatial(y)
        
        return self.sigActivation(y) * x

In [107]:
class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio = 16, kernel_size: int = 7):
        
        super(CBAM, self).__init__()
        
        self.ChannelGate = ChannelGate(in_channels, reduction_ratio)
        self.SpatialGate = SpatialGate(in_channels, kernel_size = kernel_size, padding = kernel_size // 2)
        
    def forward(self, x):
        
        x_out = self.ChannelGate(x)
        x_out = self.SpatialGate(x_out)
        
        return x_out * x

In [98]:
class DownSampling(nn.Module):
    """
    Auxiliary class to define a downsampling layer.
    Each downsampling block: 2x2 max pooling, double convolution and squeeze and excitation.
    input X output: [1, 16, 128, 128, 128] ->  [1, 32, 64, 64, 64] 
                    [1, 32, 64, 64, 64]    ->  [1, 64, 32, 32, 32]
                    [1, 64, 32, 32, 32]    ->  [1, 128, 16, 16, 16]

    Args:
        nn.Module: receive the nn.Module properties
    """
    def __init__(self, in_channels : int, out_channels : int, attention) -> None:
        """
        Args:
            in_channels (int): amount of input channels (16 or 32 or 64)
            out_channels (int): amount of output channels (32 or 64 or 128)
        """        
        super(DownSampling, self).__init__()
        
        self.attentionFunction = attention
        
        self.maxpool = nn.MaxPool3d(2)
        self.conv = DoubleConvolution(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.attention = self.attentionFunction(out_channels)
        
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): _description_

        Returns:
            torch.Tensor: _description_
        """
        out = self.maxpool(x) # 2x2 max pooling -> 1/2 the size but same amount of channels
        out = self.conv(out) # double convolution -> same size but double the amount of channels
        out = self.attention(out) # squeeze and excitation
        return out

In [99]:
class UpSampling(nn.Module):
    """
    Auxiliary class to define a upsampling layer.
    Each upsampling block: 2x2 upsampling, concatenation with skip connection, double convolution.
    input X output: [1, 128, 16, 16, 16] ->  [1, 64, 32, 32, 32]
                    [1, 64, 32, 32, 32]    ->  [1, 32, 64, 64, 64]
                    [1, 32, 64, 64, 64]    ->  [1, 16, 128, 128, 128]
                    
    Args:
        nn.Module: receive the nn.Module properties
    """
    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False) -> None:
        """
        Args:
            in_channels (int): amount of input channels (128 or 64 or 32)
            out_channels (int): amount of output channels (64 or 32 or 16)
        """
        super(UpSampling, self).__init__()
        
        self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConvolution(int(in_channels + out_channels), out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        
    def forward(self, x : torch.Tensor, skip_connection : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): the input tensor
            skip_connection (torch.Tensor): the skip connection from the downsampling path

        Returns:
            torch.Tensor: the output tensor
        """
        x = self.up(x) # 2x2 upsampling -> double the size but same amount of channels
        x = torch.cat([skip_connection, x], dim=1) # concatenation with skip connection
        out = self.conv(x) # double convolution -> same size but half the amount of channels
        return out
        

In [100]:
class Attention_Unet(nn.Module):
    def __init__(self, attentionFunction) -> None:
        """
        nn.Module: receive the nn.Module properties
        The input must have shape -> [1, 1, 128, 128, 128]
        The output shape will be -> [1, 1, 128, 128, 128]
        """
        super(Attention_Unet, self).__init__()
        in_channels = 1
        out_channels = 1
        
        self.attentionFunction = attentionFunction
        
        self.input = nn.Sequential(DoubleConvolution(in_channels, 16, kernel_size=3, padding = 1, stride=1, bias=False), self.attentionFunction(16)) # tranform the input to 16 channels and apply squeeze and excitation
        # encoding path
        self.down1 = DownSampling(16, 32, self.attentionFunction) 
        self.down2 = DownSampling(32, 64, self.attentionFunction) 
        self.down3 = DownSampling(64, 128, self.attentionFunction)
        # decoding path
        self.up1 = UpSampling(128, 64)
        self.up2 = UpSampling(64, 32)
        self.up3 = UpSampling(32, 16)
        self.output = nn.Sequential(nn.Conv3d(16, out_channels, kernel_size=1), nn.Sigmoid()) # transform the output to 1 channel and apply sigmoid activation
    
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): a tensor with shape [1, 1, 128, 128, 128]

        Returns:
            torch.Tensor: a tensor with shape [1, 1, 128, 128, 128]
        """
        input = self.input(x) # [1, 1, 128, 128, 128] -> [1, 16, 128, 128, 128]
        down1_output = self.down1(input)# [1, 16, 128, 128, 128] ->[1, 32, 64, 64, 64]
        down2_output = self.down2(down1_output) # [1, 32, 64, 64, 64] -> [1, 64, 32, 32, 32]
        down3_output = self.down3(down2_output) # [1, 64, 32, 32, 32] -> [1, 128, 16, 16, 16]
        out = self.up1(down3_output, down2_output) # [1, 128, 16, 16, 16] -> [1, 64, 32, 32, 32]
        out = self.up2(out, down1_output) # [1, 64, 32, 32, 32] -> [1, 32, 64, 64, 64]
        out = self.up3(out, input) # [1, 32, 64, 64, 64] -> [1, 16, 128, 128, 128]
        out = self.output(out) # [1, 16, 128, 128, 128] -> [1, 1, 128, 128, 128]
        return out
        
    

In [111]:
model = Attention_Unet(CBAM)

In [112]:
x = torch.ones((1, 1, 128, 128, 128))

In [113]:
with torch.no_grad():
    model.eval()
    s = model(x).shape
print(s)

RuntimeError: Given groups=1, weight of size [16, 1, 3, 3, 3], expected input[1, 16, 128, 128, 128] to have 1 channels, but got 16 channels instead

In [21]:
sum(p.numel() for p in model.parameters()) - 1465515

15098549