<a href="https://colab.research.google.com/github/Hanhpt23/3D-Unet-Pytorch/blob/main/Unet3D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Volumetric segmentation using 3D U-Net
- Volumetric segmentation enables detailed analysis of anatomical structures, supporting more personalized and precise treatment plans. This is especially critical in fields such as oncology, where knowing the exact location and size of a tumor is essential.

- Like the 2D U-Net, a 3D U-Net is built by replacing each 2D layer with its 3D counterpart: Conv3d, MaxPool3d, BatchNorm3d and ConvTranspose3d.

In this repository, a 3D U-Net is implemented in PyTorch by extending the 2D U-Net implementation linked below.
- [Unet paper](https://arxiv.org/pdf/1505.04597.pdf)
- [Unet 2D](https://github.com/Hanhpt23/Unet-Segmentation.git)

In [1]:
import torch
import torch.nn as nn
import math, time
from torchsummary import summary

In [4]:

class UNet3D(nn.Module):
    """3D U-Net for volumetric segmentation.

    Four encoder stages (each a double 3x3x3 conv block followed by a 2x2x2
    max-pool), a bottleneck block, and four decoder stages.  Each decoder
    stage upsamples with a stride-2 transposed convolution, concatenates the
    encoder output of the same level along the channel axis (skip
    connection), and applies another double conv block.  A final 1x1x1
    convolution maps to ``out_channels``.  No activation is applied to the
    output, so it contains raw per-voxel scores.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels produced by the final 1x1x1 conv.
        init_features: Channel width of the first encoder stage.  Each deeper
            stage doubles it, so the bottleneck has ``16 * init_features``.

    The input is a 5-D tensor ``(N, in_channels, D, H, W)``, as ``nn.Conv3d``
    expects.  The output is ``(N, out_channels, D, H, W)``.

    Spatial sizes do not have to be divisible by 16.  When pooling floors an
    odd size, the upsampled decoder tensor is zero-padded to the size of its
    skip connection.  Each spatial dimension must still be at least 16, so
    that four poolings leave a non-empty volume.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        features = init_features

        # Encoder: channel width doubles at every stage, while each 2x2x2
        # max-pool halves (floors) every spatial dimension.
        self.encoder1 = self._block(in_channels, features)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder2 = self._block(features, features * 2)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder3 = self._block(features * 2, features * 4)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder4 = self._block(features * 4, features * 8)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.bottleneck = self._block(features * 8, features * 16)

        # Decoder: each stride-2 transposed conv doubles the spatial size and
        # halves the channel count.  The block after it receives that tensor
        # concatenated with the same-level encoder output, which is why its
        # input width is doubled.
        self.upconv4 = nn.ConvTranspose3d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = self._block((features * 8) * 2, features * 8)
        self.upconv3 = nn.ConvTranspose3d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = self._block((features * 4) * 2, features * 4)
        self.upconv2 = nn.ConvTranspose3d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = self._block((features * 2) * 2, features * 2)
        self.upconv1 = nn.ConvTranspose3d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = self._block(features * 2, features)

        # 1x1x1 projection to the requested number of output channels.
        self.conv = nn.Conv3d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    @staticmethod
    def _block(in_channels, features):
        """Build two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages.

        ``padding=1`` keeps the spatial size unchanged.  The conv bias is
        disabled because each conv is immediately followed by BatchNorm3d,
        whose learned shift makes a bias redundant.  The layer order (indices
        0-5 of the Sequential) determines the state_dict keys.
        """
        return nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=features,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels=features, out_channels=features,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(num_features=features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its (D, H, W) equal those of ``skip``.

        A 2x2x2 max-pool floors odd sizes, so the transposed conv can come
        back one voxel short of the skip connection, and ``torch.cat`` would
        then fail.  When the sizes already agree, the tensor is returned
        unchanged.
        """
        diffs = [s - u for s, u in zip(skip.shape[2:], upsampled.shape[2:])]
        if not any(diffs):
            return upsampled
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        pad = []
        for diff in reversed(diffs):
            pad.extend((diff // 2, diff - diff // 2))
        return nn.functional.pad(upsampled, tuple(pad))

    def forward(self, x):
        """Segment ``x`` and return per-voxel scores of the same spatial size."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self._pad_to_match(self.upconv4(bottleneck), enc4)
        dec4 = self.decoder4(torch.cat((dec4, enc4), dim=1))
        dec3 = self._pad_to_match(self.upconv3(dec4), enc3)
        dec3 = self.decoder3(torch.cat((dec3, enc3), dim=1))
        dec2 = self._pad_to_match(self.upconv2(dec3), enc2)
        dec2 = self.decoder2(torch.cat((dec2, enc2), dim=1))
        dec1 = self._pad_to_match(self.upconv1(dec2), enc1)
        dec1 = self.decoder1(torch.cat((dec1, enc1), dim=1))

        return self.conv(dec1)

# Smoke test: push one random volume through the network and print the
# output shape.
aa = torch.rand((1, 3, 64, 128, 128))
unet = UNet3D()
# Inference only: skip building the autograd graph, which would otherwise
# keep every intermediate activation of this volume alive.
with torch.no_grad():
    print('Output Unet:', unet(aa).shape)

# torchsummary builds its own dummy input on the device given by `device`,
# which defaults to "cuda".  On a GPU runtime that input would not match this
# CPU-resident model, so pass the model's own device type explicitly.
# NOTE(review): the summary uses depth 16 rather than the 64 used for `aa`,
# presumably to keep the summary pass cheap.
start_time = time.perf_counter()
summary(
    model=unet,
    input_size=(3, 16, 128, 128),
    device=next(unet.parameters()).device.type,
)
print(f"--- {round((time.perf_counter() - start_time), 4)} seconds ---")

Output Unet: torch.Size([1, 1, 64, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 32, 16, 128, 128]           2,592
       BatchNorm3d-2     [-1, 32, 16, 128, 128]              64
              ReLU-3     [-1, 32, 16, 128, 128]               0
            Conv3d-4     [-1, 32, 16, 128, 128]          27,648
       BatchNorm3d-5     [-1, 32, 16, 128, 128]              64
              ReLU-6     [-1, 32, 16, 128, 128]               0
         MaxPool3d-7        [-1, 32, 8, 64, 64]               0
            Conv3d-8        [-1, 64, 8, 64, 64]          55,296
       BatchNorm3d-9        [-1, 64, 8, 64, 64]             128
             ReLU-10        [-1, 64, 8, 64, 64]               0
           Conv3d-11        [-1, 64, 8, 64, 64]         110,592
      BatchNorm3d-12        [-1, 64, 8, 64, 64]             128
             ReLU-13        [-1, 64, 8, 64, 64]          