In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [3]:
class DoubleConvolution(nn.Module):
    """
    Two stacked convolution blocks, each: 3x3x3 Conv3d (padding=1), BatchNorm3d, ReLU.

    Spatial size is preserved (kernel 3, padding 1); channels go
    in_channels -> out_channels -> out_channels. Each block has its own weights.
    """
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        Args:
            in_channels (int): number of channels of the input tensor.
            out_channels (int): number of channels produced by both blocks.
        """
        super(DoubleConvolution, self).__init__()

        # Two *distinct* blocks: the second one must accept out_channels, and the
        # two must not share parameters (reusing a single module instance twice
        # both tied the weights and broke whenever in_channels != out_channels).
        self.doubleConv = nn.Sequential(
            self._conv_block(in_channels, out_channels),
            self._conv_block(out_channels, out_channels),
        )

    @staticmethod
    def _conv_block(in_channels : int, out_channels : int) -> nn.Sequential:
        """Build one Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU block."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 5D input of shape (batch, in_channels, D, H, W).

        Returns:
            torch.Tensor: tensor of shape (batch, out_channels, D, H, W), non-negative (ReLU output).
        """
        return self.doubleConv(x)

In [4]:
class SE(nn.Module):
    """
    Squeeze-and-excitation channel attention for 5D (batch, C, D, H, W) tensors.

    Global-average-pools each channel, passes the C-vector through a bottleneck
    MLP (C -> C // reduction_ratio -> C) ending in a sigmoid, and rescales each
    input channel by its resulting gate in (0, 1).
    """
    def __init__(self, in_channels : int, reduction_ratio : int = 2) -> None:
        """
        Args:
            in_channels (int): number of channels C of the input tensor.
            reduction_ratio (int, optional): bottleneck reduction factor. Defaults to 2.

        Raises:
            ValueError: if reduction_ratio < 1.
        """
        super(SE, self).__init__()
        if reduction_ratio < 1:
            raise ValueError(f"reduction_ratio must be >= 1, got {reduction_ratio}")

        # Clamp so tiny channel counts (e.g. C=1) never produce a zero-width layer.
        hidden = max(1, in_channels // reduction_ratio)

        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input of shape (batch, in_channels, D, H, W).

        Returns:
            torch.Tensor: x with each channel scaled by its learned gate; same shape as x.
        """
        batch_size, channels = x.shape[:2]
        gates = self.squeeze(x).view(batch_size, channels)
        gates = self.excitation(gates).view(batch_size, channels, 1, 1, 1)
        # Broadcasting over the trailing singleton dims replaces an explicit expand_as.
        return x * gates

In [5]:
class DownSampling(nn.Module):
    """
    Encoder stage: 2x2x2 max pooling, double convolution, squeeze-and-excitation.

    Halves each spatial dimension (floor division, per MaxPool3d(2)) and maps
    in_channels -> out_channels.
    """
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        Args:
            in_channels (int): number of channels of the input tensor.
            out_channels (int): number of channels of the output tensor.
        """
        super(DownSampling, self).__init__()

        self.maxpool = nn.MaxPool3d(2)
        self.conv = DoubleConvolution(in_channels, out_channels)
        # SE is applied to the conv output, so it must be sized for out_channels.
        self.attention = SE(out_channels)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input of shape (batch, in_channels, D, H, W).

        Returns:
            torch.Tensor: tensor of shape (batch, out_channels, D // 2, H // 2, W // 2).
        """
        out = self.maxpool(x)
        out = self.conv(out)
        out = self.attention(out)
        return out

In [6]:
class UpSampling(nn.Module):
    """
    Decoder stage: 2x upsampling, concatenation with a skip connection, double convolution.

    The upsampling step maps in_channels -> out_channels. The concatenated tensor is fed to
    DoubleConvolution(in_channels, out_channels), so the skip connection is expected to have
    in_channels - out_channels channels (e.g. in_channels == 2 * out_channels in a standard U-Net).
    """
    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False) -> None:
        """
        Args:
            in_channels (int): number of channels of the tensor being upsampled.
            out_channels (int): number of channels of the output tensor.
            bilinear (bool, optional): whether to use fixed (tri)linear interpolation followed
                by a 1x1x1 convolution instead of a learned transposed convolution.
                Defaults to False.
        """
        super(UpSampling, self).__init__()

        self.conv = DoubleConvolution(in_channels, out_channels)
        if bilinear:
            # For 5D tensors the "bilinear" analogue is trilinear interpolation; the
            # 1x1x1 conv then reduces channels so the concat width matches the transposed path.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True),
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            )
        else:
            self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x : torch.Tensor, skip_connection : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): tensor to upsample, shape (batch, in_channels, D, H, W).
            skip_connection (torch.Tensor): encoder feature map to concatenate with the
                upsampled tensor, shape (batch, in_channels - out_channels, D', H', W').

        Returns:
            torch.Tensor: tensor of shape (batch, out_channels, D', H', W').
        """
        x = self.up(x)

        # Pooling floors odd sizes, so 2x upsampling can come out smaller than the skip
        # tensor; pad (or crop, for negative diffs) so the spatial dims match before concat.
        diffs = [s - u for s, u in zip(skip_connection.shape[2:], x.shape[2:])]
        if any(diffs):
            padding = []
            for d in reversed(diffs):  # F.pad takes the last dimension first
                padding.extend([d // 2, d - d // 2])
            x = nn.functional.pad(x, padding)

        x = torch.cat([skip_connection, x], dim=1)
        out = self.conv(x)
        return out
        

In [7]:
class Unet(nn.Module):
    """
    3D U-Net with squeeze-and-excitation blocks in the encoder.

    The encoder uses widths base_channels * (1, 2, 4, 8) with three 2x downsamplings.
    The decoder mirrors it with skip connections, then applies a 1x1x1 conv and a Sigmoid.
    Because the output already passes through Sigmoid, pair it with a probability-space
    loss (e.g. nn.BCELoss) rather than a logits loss such as nn.BCEWithLogitsLoss.
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1, base_channels: int = 16) -> None:
        """
        Args:
            in_channels (int): number of channels of the input volume. Defaults to 1.
            out_channels (int): number of output channels (maps in (0, 1)). Defaults to 1.
            base_channels (int): width of the first encoder stage; deeper stages use
                2x, 4x and 8x this value. Defaults to 16 (the original architecture).
        """
        super(Unet, self).__init__()

        c1, c2, c3, c4 = (base_channels * k for k in (1, 2, 4, 8))

        self.input = nn.Sequential(DoubleConvolution(in_channels, c1), SE(c1))
        self.down1 = DownSampling(c1, c2)
        self.down2 = DownSampling(c2, c3)
        self.down3 = DownSampling(c3, c4)
        self.up1 = UpSampling(c4, c3)
        self.up2 = UpSampling(c3, c2)
        self.up3 = UpSampling(c2, c1)
        self.output = nn.Sequential(nn.Conv3d(c1, out_channels, kernel_size=1), nn.Sigmoid())

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input volume of shape (batch, in_channels, D, H, W).

        Returns:
            torch.Tensor: sigmoid output of shape (batch, out_channels, D, H, W).
        """
        # Encoder features (enc0 at full resolution, enc3 at 1/8 resolution).
        # `enc0` replaces the former local `input`, which shadowed the builtin.
        enc0 = self.input(x)
        enc1 = self.down1(enc0)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)

        # Decoder: upsample and fuse with the matching encoder feature map.
        out = self.up1(enc3, enc2)
        out = self.up2(out, enc1)
        out = self.up3(out, enc0)
        return self.output(out)
        
    

In [10]:
# Seed torch's RNG so weight initialisation is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = Unet(in_channels=1, out_channels=1)