In [6]:
import torch
import torch.nn as nn

class Conv3DBlock(nn.Module):
    """Two stacked 3D convolution stages with dropout in between.

    Each stage is ``Conv3d -> BatchNorm3d -> ReLU``. Channel-wise dropout
    (``nn.Dropout3d``) is applied once, after the first stage only.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size forwarded to both ``nn.Conv3d`` layers.
        padding: Padding forwarded to both ``nn.Conv3d`` layers. With the
            defaults (``kernel_size=3``, ``padding=1``, stride 1) the
            spatial size of the input is preserved.
        dropout_p: Probability handed to ``nn.Dropout3d``. Defaults to
            ``0.1``, the value that used to be hard-coded; pass ``0.0`` to
            disable dropout.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1,
                 dropout_p=0.1):
        super().__init__()
        # Attribute names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # A single ReLU module is shared by both stages; it holds no state.
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout3d(dropout_p)

    def forward(self, x):
        """Run ``x`` through conv/bn/relu, dropout, then conv/bn/relu."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.dropout(x)
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class UNet3D(nn.Module):
    """3D U-Net: four encoder levels, a bridge, and four decoder levels.

    Channel widths are 16 -> 32 -> 64 -> 128 in the encoder, 256 in the
    bridge, and mirrored in the decoder. Every encoder level is followed by
    a 2x max-pool; every decoder level upsamples with a stride-2 transposed
    convolution, concatenates the matching encoder output along the channel
    dimension (``dim=1``), and refines the result with a ``Conv3DBlock``.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Number of output channels of the final 1x1x1 convolution.

    Note:
        ``forward`` returns softmax probabilities over ``dim=1``, not raw
        logits. Losses that apply their own (log-)softmax, such as
        ``nn.CrossEntropyLoss``, would therefore normalise a second time.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # Encoder (Contracting Path)
        self.encoder1 = Conv3DBlock(in_channels, 16)
        self.pool1 = nn.MaxPool3d(2)
        self.encoder2 = Conv3DBlock(16, 32)
        self.pool2 = nn.MaxPool3d(2)
        self.encoder3 = Conv3DBlock(32, 64)
        self.pool3 = nn.MaxPool3d(2)
        self.encoder4 = Conv3DBlock(64, 128)
        self.pool4 = nn.MaxPool3d(2)

        # Bridge
        self.bridge = Conv3DBlock(128, 256)

        # Decoder (Expanding Path).
        # Each decoder block takes twice the upconv's output channels because
        # the skip connection is concatenated on the channel dimension.
        self.upconv4 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder4 = Conv3DBlock(256, 128)
        self.upconv3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder3 = Conv3DBlock(128, 64)
        self.upconv2 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.decoder2 = Conv3DBlock(64, 32)
        self.upconv1 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)
        self.decoder1 = Conv3DBlock(32, 16)

        self.final_conv = nn.Conv3d(16, num_classes, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _upsample_and_merge(upconv, decoder, x, skip):
        """Upsample ``x``, align it with ``skip``, concatenate, and decode."""
        up = upconv(x)
        # MaxPool3d(2) floors odd sizes, so the stride-2 upsample can come out
        # one voxel short of the skip tensor. Zero-pad the trailing edge of
        # each spatial dimension so the concatenation is valid; this is a
        # no-op when the sizes already agree.
        pad_d = skip.size(2) - up.size(2)
        pad_h = skip.size(3) - up.size(3)
        pad_w = skip.size(4) - up.size(4)
        if pad_d or pad_h or pad_w:
            # F.pad orders the pairs from the last dimension backwards.
            up = nn.functional.pad(up, (0, pad_w, 0, pad_h, 0, pad_d))
        return decoder(torch.cat((up, skip), dim=1))

    def forward(self, x):
        """Segment a batch of volumes.

        Args:
            x: Tensor of shape ``(N, in_channels, D, H, W)``. Each spatial
                dimension must be at least 16 so that four successive 2x
                poolings leave a non-empty tensor; sizes that are not
                multiples of 16 are handled by padding in the decoder.

        Returns:
            Tensor of shape ``(N, num_classes, D, H, W)`` holding softmax
            probabilities over the channel dimension.
        """
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # Bridge
        bridge = self.bridge(self.pool4(enc4))

        # Decoder: deepest level first, each stage consuming its skip tensor.
        dec4 = self._upsample_and_merge(self.upconv4, self.decoder4, bridge, enc4)
        dec3 = self._upsample_and_merge(self.upconv3, self.decoder3, dec4, enc3)
        dec2 = self._upsample_and_merge(self.upconv2, self.decoder2, dec3, enc2)
        dec1 = self._upsample_and_merge(self.upconv1, self.decoder1, dec2, enc1)

        return self.softmax(self.final_conv(dec1))



In [7]:
# # Test the model
# if __name__ == "__main__":
#     model = UNet3D(in_channels=3, num_classes=4)
#     x = torch.randn(1, 3, 128, 128, 128)
#     print(f"Input shape: {x.shape}")
#     y = model(x)
#     print(f"Output shape: {y.shape}")

Input shape: torch.Size([1, 3, 128, 128, 128])
Output shape: torch.Size([1, 4, 128, 128, 128])
