In [3]:
import warnings
import numpy as np
import nibabel as nib

In [1]:
from tensorflow import keras

In [2]:
import tensorflow as tf
from keras.layers import Conv3D, MaxPooling3D, Dense, BatchNormalization, Activation
from keras.models import Model, Sequential


In [3]:
class DoubleConv(Model):
    """Two stacked (Conv3D 3x3x3 -> BatchNormalization -> ReLU) stages.

    Both convolutions use stride 1 and ``padding='same'``, so the spatial size
    is preserved; the channel count becomes ``out_channels``.

    Args:
        in_channels: Expected number of input channels. Keras infers the input
            channel count when the layer is first built, so this value is only
            kept for API symmetry with the other blocks.
        out_channels: Number of filters of both convolutions.
        num_groups: Accepted for backward compatibility; currently unused.
    """

    def __init__(self, in_channels, out_channels, num_groups=8):
        # Subclass Model (not Sequential): an empty Sequential with an
        # overridden `call` cannot be built by Keras.
        super().__init__()
        # No `input_shape=` on the first Conv3D: it expects a shape tuple, and
        # an int channel count is not a valid shape, so Keras raised. Conv3D
        # infers its input channels at build time.
        self.double_conv = Sequential([
            # Convolution set one
            Conv3D(out_channels, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same'),
            BatchNormalization(),
            Activation('relu'),

            # Convolution set two
            Conv3D(out_channels, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same'),
            BatchNormalization(),
            Activation('relu'),
        ])

    def call(self, x, training=None):
        # Forward `training` explicitly so BatchNormalization switches between
        # batch statistics and its moving averages.
        return self.double_conv(x, training=training)

In [4]:
from keras.models import Sequential, Model
from keras.layers import MaxPooling3D

class Down(Model):
    """Encoder stage: 2x max-pooling followed by a ``DoubleConv``.

    The pooling window and stride are both 2 on each of the three pooled axes
    (spatial axes under Keras' default channels_last layout), so those sizes
    are halved (odd sizes are floored with the default 'valid' padding).

    Args:
        in_channels: Channel count entering the stage (passed to DoubleConv).
        out_channels: Channel count produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stage_layers = [
            MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2)),
            DoubleConv(in_channels, out_channels),
        ]
        self.encoder = Sequential(stage_layers)

    def call(self, x):
        # Pool first, then convolve at the reduced resolution.
        encoded = self.encoder(x)
        return encoded

In [5]:
from keras.models import Sequential
from keras.layers import UpSampling3D, Conv3DTranspose, Concatenate, ZeroPadding3D

class Up(Model):
    """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

    Expects channels_last tensors ``(batch, D, H, W, C)`` (the Keras default):
    axes 1-3 are treated as spatial and the last axis as channels.

    Args:
        in_channels: Channel count of the concatenated (upsampled + skip)
            tensor fed to the DoubleConv.
        out_channels: Channel count produced by the stage.
        trilinear: If True, use parameter-free ``UpSampling3D``; otherwise a
            learned ``Conv3DTranspose`` with ``in_channels // 2`` filters.
            NOTE: Keras' ``UpSampling3D`` has no ``interpolation`` argument and
            only repeats voxels (nearest neighbour). Passing one made
            construction fail, so this flag now just selects the
            parameter-free path.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()

        if trilinear:
            self.up = UpSampling3D(size=(2, 2, 2))
        else:
            self.up = Conv3DTranspose(filters=in_channels // 2, kernel_size=(2, 2, 2), strides=(2, 2, 2))

        # Join along the channel axis (last axis for channels_last); axis=1
        # would have concatenated along depth instead.
        self.concat = Concatenate(axis=-1)
        self.conv = DoubleConv(in_channels, out_channels)

    def call(self, x1, x2):
        """Upsample ``x1``, zero-pad it to ``x2``'s spatial size, concat, convolve."""
        x1 = self.up(x1)

        # Pooling floors odd sizes, so the upsampled tensor can be smaller than
        # the skip tensor. Shapes are read dynamically so inputs with unknown
        # (None) spatial dimensions also work.
        up_channels = x1.shape[-1]
        diff = tf.shape(x2)[1:4] - tf.shape(x1)[1:4]
        before = diff // 2
        after = diff - before
        paddings = tf.concat(
            [[[0, 0]], tf.stack([before, after], axis=1), [[0, 0]]], axis=0
        )
        x1 = tf.pad(x1, paddings)
        # A data-dependent `paddings` hides the static shape; restore it so the
        # following convolutions can see the channel dimension.
        x1 = tf.ensure_shape(x1, x2.shape[:-1].concatenate([up_channels]))

        x = self.concat([x2, x1])
        return self.conv(x)


In [6]:
from keras.layers import Conv3D

class Out(Model):
    """Output head: a 1x1x1 Conv3D mapping features to ``out_channels`` maps.

    No activation argument is passed to the convolution.

    Args:
        in_channels: Incoming channel count; unused because Keras infers it
            at build time.
        out_channels: Number of output channels (e.g. classes).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1x1 kernel mixes channels per voxel without changing spatial size.
        self.conv = Conv3D(out_channels, (1, 1, 1))

    def call(self, x):
        projected = self.conv(x)
        return projected


In [None]:
from keras.models import Model
from keras.layers import Input
import tensorflow as tf

class UNet3d(Model):
    """3-D U-Net: stem DoubleConv, four Down stages, four Up stages, 1x1x1 head.

    Encoder widths are n, 2n, 4n, 8n, 8n (with n = ``n_channels``). Each Up
    stage receives the concatenation of the upsampled tensor and the matching
    encoder output, hence the 16n/8n/4n/2n decoder input widths.

    Args:
        in_channels: Number of channels of the input volume.
        n_classes: Number of output channels produced by the head.
        n_channels: Width of the first stage.
        trilinear: Forwarded to every ``Up`` stage (True = parameter-free
            upsampling, False = learned transposed convolution).
    """

    def __init__(self, in_channels, n_classes, n_channels, trilinear=True):
        # Subclass the same `Model` as the other blocks; mixing `tf.keras` and
        # standalone `keras` classes breaks when they resolve to different
        # Keras installations.
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.trilinear = trilinear

        # extracting the features by incrementally multiplying the number of channels
        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 8 * n_channels)

        # Decoder input widths = upsampled channels + skip channels.
        self.dec1 = Up(16 * n_channels, 4 * n_channels, trilinear)
        self.dec2 = Up(8 * n_channels, 2 * n_channels, trilinear)
        self.dec3 = Up(4 * n_channels, n_channels, trilinear)
        self.dec4 = Up(2 * n_channels, n_channels, trilinear)
        self.out = Out(n_channels, n_classes)

    def call(self, x):
        x1 = self.conv(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)

        mask = self.dec1(x5, x4)
        mask = self.dec2(mask, x3)
        mask = self.dec3(mask, x2)
        mask = self.dec4(mask, x1)
        mask = self.out(mask)

        return mask

# Create an instance of the UNet3d model.
# Example values -- set these for your dataset. (The previous `...`
# placeholders made this cell fail on Restart & Run All.)
in_channels = 1   # number of input channels
n_classes = 2     # number of output classes
n_channels = 16   # width of the first U-Net stage
# Spatial dims (D, H, W) are None so volumes of any size are accepted
# (zero-sized dims describe empty volumes). Sizes divisible by 16 (four 2x
# poolings) avoid padding in the decoder. Layout is channels_last.
input_shape = (None, None, None, in_channels)
inputs = Input(shape=input_shape)
unet_model = UNet3d(in_channels, n_classes, n_channels)
outputs = unet_model(inputs)
model = Model(inputs=inputs, outputs=outputs)
assert outputs.shape[-1] == n_classes, "head must emit one channel per class"
