In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock3D(nn.Module):
    """Two stacked 3D convolutions with ReLU, channel dropout and batch norm.

    Layer order in ``forward``:
    conv1 -> ReLU -> Dropout3d -> conv2 -> ReLU -> BatchNorm3d.

    The convolution padding is derived from ``kernel_size`` (``k // 2`` per
    axis), so odd kernel sizes keep the spatial dimensions unchanged.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size, an int or a 3-tuple. Use odd
            sizes; even sizes grow each spatial dimension by one per conv.
        dropout: Probability passed to ``nn.Dropout3d`` (zeroes whole channels).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), dropout=0.0):
        super().__init__()
        # Derive padding from the kernel instead of hard-coding 1, which only
        # preserves spatial size for 3x3x3 kernels. For the default kernel this
        # is (1, 1, 1), identical to the previous behaviour.
        padding = self._same_padding(kernel_size)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=padding)
        self.dropout = nn.Dropout3d(dropout)
        self.relu = nn.ReLU(inplace=True)
        # NOTE(review): a single BatchNorm is applied only after the second
        # conv+ReLU. Kept as-is so existing checkpoints (state_dict keys) load.
        self.batch_norm = nn.BatchNorm3d(out_channels)

    @staticmethod
    def _same_padding(kernel_size):
        """Return per-axis padding ``k // 2`` for an int or 3-tuple kernel size."""
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        return tuple(k // 2 for k in kernel_size)

    def forward(self, x):
        """Apply the block to a channels-first tensor of shape (B, C_in, D, H, W)."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.batch_norm(x)
        return x

class UNet3D(nn.Module):
    """3D U-Net with four down/up-sampling stages and a softmax output head.

    Encoder widths are 16 -> 32 -> 64 -> 128 with a 256-channel bottleneck.
    Each decoder stage upsamples with a stride-2 transposed convolution and
    concatenates the matching encoder feature map (skip connection).

    Tensors are channels-last at the interface:
      input:  (batch, height, width, depth, in_channels)
      output: (batch, height, width, depth, out_channels), softmax-normalized
              over the last (class) axis.

    Because of the four 2x max-pools and matching 2x upsamplings, height, width
    and depth must each be divisible by 16; otherwise the skip concatenations
    fail with a size mismatch.

    NOTE(review): ``forward`` returns probabilities, not logits. If this is
    trained with ``nn.CrossEntropyLoss`` (which applies log-softmax itself) the
    output would be normalized twice — confirm which loss is used.

    Args:
        in_channels: Number of channels in the input volume.
        out_channels: Number of output classes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.encoder1 = ConvBlock3D(in_channels, 16, dropout=0.1)
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.encoder2 = ConvBlock3D(16, 32, dropout=0.1)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.encoder3 = ConvBlock3D(32, 64, dropout=0.2)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.encoder4 = ConvBlock3D(64, 128, dropout=0.2)
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.bottleneck = ConvBlock3D(128, 256, dropout=0.3)

        # Each decoder block takes upsampled features + the skip connection,
        # hence twice the channels of its output.
        self.upconv4 = nn.ConvTranspose3d(256, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.decoder4 = ConvBlock3D(256, 128, dropout=0.2)

        self.upconv3 = nn.ConvTranspose3d(128, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.decoder3 = ConvBlock3D(128, 64, dropout=0.2)

        self.upconv2 = nn.ConvTranspose3d(64, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.decoder2 = ConvBlock3D(64, 32, dropout=0.1)

        self.upconv1 = nn.ConvTranspose3d(32, 16, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.decoder1 = ConvBlock3D(32, 16, dropout=0.1)

        self.final_conv = nn.Conv3d(16, out_channels, kernel_size=(1, 1, 1))
        # Applied to the channels-first logits (B, C, D, H, W), where dim=1 is
        # the class axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Segment a channels-last volume.

        Args:
            x: Tensor of shape (batch, height, width, depth, in_channels).

        Returns:
            Contiguous tensor of shape (batch, height, width, depth, out_channels)
            with per-voxel class probabilities summing to 1 over the last axis.
        """
        # (B, H, W, D, C) -> (B, C, D, H, W), the layout nn.Conv3d expects.
        x = x.permute(0, 4, 3, 1, 2)

        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Decoder: upsample, concatenate the skip connection on channels, refine.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        logits = self.final_conv(dec1)

        # Normalize while the class axis is still dim 1. After the permute
        # below, dim 1 is height, so softmax there would normalize across rows
        # instead of classes.
        probs = self.softmax(logits)

        # (B, C, D, H, W) -> (B, H, W, D, C), matching the input layout.
        # .contiguous() so callers can .view() the result as before.
        return probs.permute(0, 3, 4, 2, 1).contiguous()

# # Test if everything is working 
# model = UNet3D(4, 4)  

# # Test input and output shape
# input_tensor = torch.randn(1, 128, 128, 128, 4)  # (batch_size, height, width, depth, channels)
# output_tensor = model(input_tensor)
# print(input_tensor.shape)
# print(output_tensor.shape)


UNet3D(
  (encoder1): ConvBlock3D(
    (conv1): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (conv2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (dropout): Dropout3d(p=0.1, inplace=False)
    (relu): ReLU(inplace=True)
    (batch_norm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (encoder2): ConvBlock3D(
    (conv1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (dropout): Dropout3d(p=0.1, inplace=False)
    (relu): ReLU(inplace=True)
    (batch_norm): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (encoder3): ConvBlock