In [3]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

ModuleNotFoundError: No module named 'import_ipynb'

In [39]:
class DoubleConvolution(nn.Module):
    """
    Two identical convolution stages applied one after the other.
    Each stage: 3x3x3 convolution (padding 1, so the spatial size is kept),
    3D batch normalization, in-place ReLU activation.

    Args:
        nn.Module : base class providing the PyTorch module machinery
    """
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        Args:
            in_channels (int): number of channels of the incoming tensor
            out_channels (int): number of channels produced by both stages
        """
        super().__init__()

        def stage(stage_in : int) -> list:
            # One conv -> batch norm -> ReLU stage mapping stage_in to out_channels.
            return [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Both stages are flattened into a single Sequential, so the
        # sub-modules keep the indices 0..5 (and the matching state-dict keys).
        self.doubleConv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with in_channels channels

        Returns:
            torch.Tensor: result of the two stages, with out_channels channels
        """
        features = self.doubleConv(x)
        return features

In [51]:
class SE(nn.Module):
    """
    Auxiliary class to define a squeeze and excitation layer for 5D tensors.
    Every channel of the input is rescaled by a sigmoid gate computed from the
    global averages of all channels.

    Args:
        nn.Module: receive the nn.Module properties
    """
    def __init__(self, in_channels : int, reduction : int = 8) -> None:
        """
        Args:
            in_channels (int): amount of input channels
            reduction (int, optional): ratio between the input channels and the
                width of the bottleneck layer. Defaults to 8.

        Raises:
            ValueError: if reduction is smaller than 1
        """
        super(SE, self).__init__()

        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")

        # Keep at least one hidden unit: with fewer channels than `reduction`
        # the floor division gives 0, i.e. an empty bottleneck whose output can
        # no longer depend on the input.
        hidden_channels = max(1, in_channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool3d(1) # Global Average Pooling
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, hidden_channels), # Bottleneck (channels / reduction)
            nn.ReLU(inplace=True), # ReLU activation
            nn.Linear(hidden_channels, in_channels), # Back to one value per channel
            nn.Sigmoid() # Sigmoid activation -> gate between 0 and 1
        )

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 5D input tensor whose second dimension has in_channels entries

        Returns:
            torch.Tensor: tensor with the same shape as x, each channel scaled by its gate
        """
        batch_size, channels, _, _, _ = x.size()
        y = self.squeeze(x).view(batch_size, channels) # one average per (sample, channel)
        y = self.excitation(y).view(batch_size, channels, 1, 1, 1)
        return x * y # the singleton dimensions broadcast over the whole volume

In [41]:
class DownSampling(nn.Module):
    """
    One encoder step of the network.
    Each downsampling block: 2x2x2 max pooling, double convolution, squeeze and excitation.
    input X output: [1, 16, 128, 128, 128] ->  [1, 32, 64, 64, 64]
                    [1, 32, 64, 64, 64]    ->  [1, 64, 32, 32, 32]
                    [1, 64, 32, 32, 32]    ->  [1, 128, 16, 16, 16]

    Args:
        nn.Module: base class providing the PyTorch module machinery
    """
    def __init__(self, in_channels : int, out_channels : int) -> None:
        """
        Args:
            in_channels (int): number of channels of the incoming tensor
            out_channels (int): number of channels produced by the block
        """
        super().__init__()

        self.maxpool = nn.MaxPool3d(kernel_size=2)
        self.conv = DoubleConvolution(in_channels, out_channels)
        self.attention = SE(out_channels)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 5D input tensor with in_channels channels

        Returns:
            torch.Tensor: tensor with out_channels channels and halved spatial size
        """
        pooled = self.maxpool(x)          # halves every spatial dimension, channels unchanged
        features = self.conv(pooled)      # in_channels -> out_channels, spatial size unchanged
        return self.attention(features)   # channel-wise rescaling, shape unchanged

In [42]:
class UpSampling(nn.Module):
    """
    Auxiliary class to define a upsampling layer.
    Each upsampling block: x2 upsampling, concatenation with skip connection, double convolution.
    input X output: [1, 128, 16, 16, 16] ->  [1, 64, 32, 32, 32]
                    [1, 64, 32, 32, 32]    ->  [1, 32, 64, 64, 64]
                    [1, 32, 64, 64, 64]    ->  [1, 16, 128, 128, 128]

    Args:
        nn.Module: receive the nn.Module properties
    """
    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = False) -> None:
        """
        Args:
            in_channels (int): amount of input channels (128 or 64 or 32)
            out_channels (int): amount of output channels (64 or 32 or 16);
                the skip connection is expected to carry this many channels
            bilinear (bool, optional): if True, upsample with parameter-free
                trilinear interpolation (the 3D counterpart of bilinear)
                instead of a learned transposed convolution. Defaults to False.
        """
        super(UpSampling, self).__init__()

        # Both variants double the spatial size and keep in_channels channels,
        # so the convolution below sees the same channel count either way.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConvolution(int(in_channels + out_channels), out_channels)

    def forward(self, x : torch.Tensor, skip_connection : torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): the input tensor
            skip_connection (torch.Tensor): the skip connection from the downsampling path

        Returns:
            torch.Tensor: the output tensor, with the spatial size of skip_connection
        """
        x = self.up(x) # x2 upsampling -> double the size but same amount of channels

        # The upsampled tensor can differ from the skip connection by a voxel
        # (e.g. an odd size that was floored when it was pooled). Pad it
        # (or crop it, for negative gaps) so the concatenation is valid.
        size_gap = [s - u for s, u in zip(skip_connection.shape[2:], x.shape[2:])]
        if any(size_gap):
            padding = []
            # pad() takes (before, after) pairs starting from the LAST dimension
            for gap in reversed(size_gap):
                padding += [gap // 2, gap - gap // 2]
            x = nn.functional.pad(x, padding)

        x = torch.cat([skip_connection, x], dim=1) # concatenation with skip connection
        out = self.conv(x) # double convolution -> same size, out_channels channels
        return out
        

In [43]:
class Unet(nn.Module):
    """
    3D U-Net with squeeze-and-excitation in the input block and in every
    encoder level: an input block, three downsampling levels, three
    upsampling levels fed by skip connections, and a 1x1x1 convolution
    followed by a sigmoid.

    Args:
        nn.Module: receive the nn.Module properties
    """
    def __init__(self, in_channels : int = 1, out_channels : int = 1) -> None:
        """
        With the defaults:
        The input must have shape -> [1, 1, 128, 128, 128]
        The output shape will be -> [1, 1, 128, 128, 128]

        Args:
            in_channels (int, optional): amount of channels of the input volume. Defaults to 1.
            out_channels (int, optional): amount of channels of the output volume;
                the sigmoid is applied to each of them independently. Defaults to 1.
        """
        super(Unet, self).__init__()

        self.input = nn.Sequential(DoubleConvolution(in_channels, 16), SE(16)) # tranform the input to 16 channels and apply squeeze and excitation
        # encoding path
        self.down1 = DownSampling(16, 32)
        self.down2 = DownSampling(32, 64)
        self.down3 = DownSampling(64, 128)
        # decoding path
        self.up1 = UpSampling(128, 64)
        self.up2 = UpSampling(64, 32)
        self.up3 = UpSampling(32, 16)
        self.output = nn.Sequential(nn.Conv3d(16, out_channels, kernel_size=1), nn.Sigmoid()) # transform the output to out_channels channels and apply sigmoid activation

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """
        The three pooling levels halve every spatial dimension three times, so
        spatial sizes that are multiples of 8 (such as 128) round-trip exactly.

        Args:
            x (torch.Tensor): a tensor with shape [1, in_channels, 128, 128, 128]

        Returns:
            torch.Tensor: a tensor with shape [1, out_channels, 128, 128, 128], values between 0 and 1
        """
        # `stem` rather than `input`, to avoid shadowing the Python builtin
        stem = self.input(x) # [1, in_channels, 128, 128, 128] -> [1, 16, 128, 128, 128]
        down1_output = self.down1(stem) # [1, 16, 128, 128, 128] -> [1, 32, 64, 64, 64]
        down2_output = self.down2(down1_output) # [1, 32, 64, 64, 64] -> [1, 64, 32, 32, 32]
        down3_output = self.down3(down2_output) # [1, 64, 32, 32, 32] -> [1, 128, 16, 16, 16]
        out = self.up1(down3_output, down2_output) # [1, 128, 16, 16, 16] -> [1, 64, 32, 32, 32]
        out = self.up2(out, down1_output) # [1, 64, 32, 32, 32] -> [1, 32, 64, 64, 64]
        out = self.up3(out, stem) # [1, 32, 64, 64, 64] -> [1, 16, 128, 128, 128]
        out = self.output(out) # [1, 16, 128, 128, 128] -> [1, out_channels, 128, 128, 128]
        return out
        
    

In [48]:
# Instantiate the U-Net defined above; the parameter-count cell below reads it.
model = Unet()

In [49]:
# Dummy all-ones volume with the input shape documented on Unet
# ([1, 1, 128, 128, 128]); it is not passed to the model in the cells shown here.
x = torch.ones(1, 1, 128, 128, 128)

In [50]:
# Difference between this model's total parameter count and a fixed baseline.
# NOTE(review): the notebook does not say where 1,634,865 comes from -
# presumably the parameter count of an earlier/reference architecture; confirm.
BASELINE_PARAM_COUNT = 1634865
sum(p.numel() for p in model.parameters()) - BASELINE_PARAM_COUNT

4094